backward函数

官方定义:

torch.autograd.backward(tensorsgrad_tensors=Noneretain_graph=Nonecreate_graph=Falsegrad_variables=None)

Computes the sum of gradients of given tensors w.r.t. graph leaves.The graph is differentiated using the chain rule. If any of tensors are non-scalar (i.e. their data has more than one element) and require gradient, the function additionally requires specifying grad_tensors. It should be a sequence of matching length, that contains gradient of the differentiated function w.r.t. corresponding tensors (None is an acceptable value for all tensors that don’t need gradient tensors). This function accumulates gradients in the leaves - you might need to zero them before calling it.

翻译和解释:参数tensors如果是标量,函数backward计算参数tensors对于给定图叶子节点的梯度( graph leaves,即为设置requires_grad=True的变量)。

参数tensors如果不是标量,需要另外指定参数grad_tensors参数grad_tensors必须和参数tensors的长度相同。在这一种情况下,backward实际上实现的是代价函数(loss = torch.sum(tensors*grad_tensors); 注:torch中向量*向量实际上是点积,因此tensors和grad_tensors的维度必须一致 )关于叶子节点的梯度计算,而不是参数tensors对于给定图叶子节点的梯度。如果指定参数grad_tensors=torch.ones((size(tensors))),显而易见,代价函数关于叶子节点的梯度,也就等于参数tensors对于给定图叶子节点的梯度。

每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。

下面给出具体的例子:

 import torch
x=torch.randn((3),dtype=torch.float32,requires_grad=True)
y = torch.randn((3),dtype=torch.float32,requires_grad=True)
z = torch.randn((3),dtype=torch.float32,requires_grad=True)
t = x + y
loss = t.dot(z) #求向量的内积

在调用 backward 之前,可以先手动求一下导数,应该是:

用代码实现求导:

 loss.backward(retain_graph=True)
print(z,x.grad,y.grad) #预期打印出的结果都一样
print(t,z.grad) #预期打印出的结果都一样
print(t.grad) #在这个例子中,x,y,z就是叶子节点,而t不是,t的导数在backward的过程中求出来回传之后就会被释放,因而预期结果是None

结果和预期一致:

tensor([-2.6752,  0.2306, -0.8356], requires_grad=True) tensor([-2.6752,  0.2306, -0.8356]) tensor([-2.6752,  0.2306, -0.8356])
tensor([-1.1916, -0.0156, 0.8952], grad_fn=<AddBackward0>) tensor([-1.1916, -0.0156, 0.8952])
None

敲重点:注意到前面函数的解释中,在参数tensors不是标量的情况下,tensor.backward(grad_tensors)实现的是代价函数torch.sum(tensors*grad_tensors))关于叶子节点的导数。在上面例子中,loss = t.dot(z),因此用t.backward(z),实现的就是loss对于所有叶子结点的求导,实际运算结果和预期吻合。

 t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果如下:

tensor([-0.7830,  1.4468,  1.2440], requires_grad=True) tensor([-0.7830,  1.4468,  1.2440]) tensor([-0.7830,  1.4468,  1.2440])
tensor([-0.7145, -0.7598, 2.0756], grad_fn=<AddBackward0>) None

上面的结果中,出现了一个问题,虽然loss关于x和y的导数正确,但是z不再是叶子节点了。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

另外一个尝试,loss = t.dot(z)=z.dot(t),但是如果用z.backward(t)替换t.backward(z,retain_graph=True),结果却不同。

 z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果:

tensor([-1.0716, -1.3643, -0.0016], requires_grad=True) None None
tensor([-0.7324, 0.9763, -0.4036], grad_fn=<AddBackward0>) tensor([-0.7324, 0.9763, -0.4036])

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。

上述仿真出现的两个问题,我还不能解释,希望和大家交流,我的邮箱yangyuwen_yang@126.com,欢迎来信讨论。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。



另外强调一下,每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。
简单的代码可以看出:
 #测试1,:对比上两次单独执行backward,此处连续执行两次backward
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)
# 结果x.grad,y.grad本应该是None,因为保留了第一次backward的结果而打印出上一次梯度的结果

tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) None
tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) tensor([-1.7914, 0.8761, -0.3462])

 #测试2,:连续执行两次backward,并且清零,可以验证第二次backward没有计算x和y的梯度
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
x.grad.data.zero_()
y.grad.data.zero_()
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([ 0.8671, 0.6503, -1.6643]) tensor([ 0.8671, 0.6503, -1.6643])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) None
tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([0., 0., 0.]) tensor([0., 0., 0.])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) tensor([1.6231e+00, 1.3842e+00, 4.6492e-06])

附参考学习的链接如下,并对作者表示感谢:
PyTorch 的 backward 为什么有一个 grad_variables 参数?该话题在我的知乎专栏里同时转载,欢迎更多讨论交流,参见链接

Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析的更多相关文章

  1. PyTorch 中 torch.matmul() 函数的文档详解

    官方文档 torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定. 关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的 ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. 模式识别 - libsvm该函数的调用方法 详细说明

    libsvm该函数的调用方法 详细说明 本文地址: http://blog.csdn.net/caroline_wendy/article/details/26261173 须要载入(load)SVM ...

  4. Pytorch中torch.load()中出现AttributeError: Can't get attribute

    原因:保存下来的模型和参数不能在没有类定义时直接使用. Pytorch使用Pickle来处理保存/加载模型,这个问题实际上是Pickle的问题,而不是Pytorch. 解决方法也非常简单,只需显式地导 ...

  5. 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍

    squeeze用来减少维度, unsqueeze用来增加维度 具体可见下方博客. pytorch中squeeze和unsqueeze

  6. pytorch 中的Variable一般常用的使用方法

    Variable一般的初始化方法,默认是不求梯度的 import torch from torch.autograd import Variable x_tensor = torch.randn(2, ...

  7. javascript中的闭包、函数的toString方法

    闭包: 闭包可以理解为定义在一个函数内部的函数, 函数A内部定义了函数B, 函数B有访问函数A内部变量的权力: 闭包是函数和子函数之间的桥梁: 举个例子: let func = function() ...

  8. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  9. pytorch中torch.narrow()函数

    torch.narrow(input, dim, start, length) → Tensor Returns a new tensor that is a narrowed version of  ...

随机推荐

  1. genymotion的安装及运行

    一.下载工具 安装genymontion一共需要下载三个东西,分别是genymotion.虚拟机virtualbox和ova 笔者提供百度云下载:mac版虚拟机 mac上genymotion.wind ...

  2. 数据结构的javascript实现

    栈 栈(stack)又名堆栈,是一种遵循后进先出(LIFO)原则的有序集合.新添加或待删除的元素都保存在栈的末尾,称作栈顶,另一端称作栈底.在栈里,新元素都靠近栈顶,旧元素都接近栈底. functio ...

  3. Ubuntu 16.04下安装64位谷歌Chrome浏览器

    1.进入 Ubuntu 16.04 桌面,按下 Ctrl + Alt + t 键盘组合键,启动终端. 也可以按下 Win 键(或叫 Super 键),在 Dash 的搜索框中输入 terminal 或 ...

  4. Instrumentation(1)

    Instrumentation介绍: JavaInstrumentation指的是可以用独立于应用程序之外的代理(agent)程序来监测和协助运行在JVM上的应用程序.这种监测和协助包括但不限于获取J ...

  5. Solr的中英文分词实现

    对于Solr应该不需要过多介绍了,强大的功能也是都体验过了,但是solr一个较大的问题就是分词问题,特别是中英文的混合分词,处理起来非常棘手. 虽然solr自带了支持中文分词的cjk,但是其效果实在不 ...

  6. STM32基于固件库新建MDK工程模板(精简版)

    上个博文理论讲解的东西太多,太复杂,这里把所有步骤全部贴出 1.新建一个工程文件夹LED 2.LED文件夹下建立如下文件夹 3.Project –>New Uvision Project 到US ...

  7. 基于Java实现红黑树的基本操作

    首先,在阅读文章之前,我希望读者对二叉树有一定的了解,因为红黑树的本质就是一颗二叉树.所以本篇博客中不在将二叉树的增删查的基本操作了,需要了解的同学可以到我之前写的一篇关于二叉树基本操作的博客:htt ...

  8. 一文读懂Asp.net core 依赖注入(Dependency injection)

    一.什么是依赖注入 首先在Asp.net core中是支持依赖注入软件设计模式,或者说依赖注入是asp.net core的核心: 依赖注入(DI)和控制反转(IOC)基本是一个意思,因为说起来谁都离不 ...

  9. 微信小程序之表单验证

    表单验证 何为表单验证呢? 百度百科给出的回答是这样的: 表单验证是javascript中的高级选项之一.JavaScript 可用来在数据被送往服务器前对 HTML 表单中的这些输入数据进行验证 [ ...

  10. 重学前端 --- Promise里的代码为什么比setTimeout先执行?

    首先通过一段代码进入讨论的主题 var r = new Promise(function(resolve, reject){ console.log("a"); resolve() ...