backward函数

官方定义:

torch.autograd.backward(tensorsgrad_tensors=Noneretain_graph=Nonecreate_graph=Falsegrad_variables=None)

Computes the sum of gradients of given tensors w.r.t. graph leaves.The graph is differentiated using the chain rule. If any of tensors are non-scalar (i.e. their data has more than one element) and require gradient, the function additionally requires specifying grad_tensors. It should be a sequence of matching length, that contains gradient of the differentiated function w.r.t. corresponding tensors (None is an acceptable value for all tensors that don’t need gradient tensors). This function accumulates gradients in the leaves - you might need to zero them before calling it.

翻译和解释:参数tensors如果是标量,函数backward计算参数tensors对于给定图叶子节点的梯度( graph leaves,即为设置requires_grad=True的变量)。

参数tensors如果不是标量,需要另外指定参数grad_tensors参数grad_tensors必须和参数tensors的长度相同。在这一种情况下,backward实际上实现的是代价函数(loss = torch.sum(tensors*grad_tensors); 注:torch中向量*向量实际上是点积,因此tensors和grad_tensors的维度必须一致 )关于叶子节点的梯度计算,而不是参数tensors对于给定图叶子节点的梯度。如果指定参数grad_tensors=torch.ones((size(tensors))),显而易见,代价函数关于叶子节点的梯度,也就等于参数tensors对于给定图叶子节点的梯度。

每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。

下面给出具体的例子:

 import torch
x=torch.randn((3),dtype=torch.float32,requires_grad=True)
y = torch.randn((3),dtype=torch.float32,requires_grad=True)
z = torch.randn((3),dtype=torch.float32,requires_grad=True)
t = x + y
loss = t.dot(z) #求向量的内积

在调用 backward 之前,可以先手动求一下导数,应该是:

用代码实现求导:

 loss.backward(retain_graph=True)
print(z,x.grad,y.grad) #预期打印出的结果都一样
print(t,z.grad) #预期打印出的结果都一样
print(t.grad) #在这个例子中,x,y,z就是叶子节点,而t不是,t的导数在backward的过程中求出来回传之后就会被释放,因而预期结果是None

结果和预期一致:

tensor([-2.6752,  0.2306, -0.8356], requires_grad=True) tensor([-2.6752,  0.2306, -0.8356]) tensor([-2.6752,  0.2306, -0.8356])
tensor([-1.1916, -0.0156, 0.8952], grad_fn=<AddBackward0>) tensor([-1.1916, -0.0156, 0.8952])
None

敲重点:注意到前面函数的解释中,在参数tensors不是标量的情况下,tensor.backward(grad_tensors)实现的是代价函数torch.sum(tensors*grad_tensors))关于叶子节点的导数。在上面例子中,loss = t.dot(z),因此用t.backward(z),实现的就是loss对于所有叶子结点的求导,实际运算结果和预期吻合。

 t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果如下:

tensor([-0.7830,  1.4468,  1.2440], requires_grad=True) tensor([-0.7830,  1.4468,  1.2440]) tensor([-0.7830,  1.4468,  1.2440])
tensor([-0.7145, -0.7598, 2.0756], grad_fn=<AddBackward0>) None

上面的结果中,出现了一个问题,虽然loss关于x和y的导数正确,但是z不再是叶子节点了。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

另外一个尝试,loss = t.dot(z)=z.dot(t),但是如果用z.backward(t)替换t.backward(z,retain_graph=True),结果却不同。

 z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果:

tensor([-1.0716, -1.3643, -0.0016], requires_grad=True) None None
tensor([-0.7324, 0.9763, -0.4036], grad_fn=<AddBackward0>) tensor([-0.7324, 0.9763, -0.4036])

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。

上述仿真出现的两个问题,我还不能解释,希望和大家交流,我的邮箱yangyuwen_yang@126.com,欢迎来信讨论。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。



另外强调一下,每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。
简单的代码可以看出:
 #测试1,:对比上两次单独执行backward,此处连续执行两次backward
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)
# 结果x.grad,y.grad本应该是None,因为保留了第一次backward的结果而打印出上一次梯度的结果

tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) None
tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) tensor([-1.7914, 0.8761, -0.3462])

 #测试2,:连续执行两次backward,并且清零,可以验证第二次backward没有计算x和y的梯度
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
x.grad.data.zero_()
y.grad.data.zero_()
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([ 0.8671, 0.6503, -1.6643]) tensor([ 0.8671, 0.6503, -1.6643])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) None
tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([0., 0., 0.]) tensor([0., 0., 0.])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) tensor([1.6231e+00, 1.3842e+00, 4.6492e-06])

附参考学习的链接如下,并对作者表示感谢:
PyTorch 的 backward 为什么有一个 grad_variables 参数?该话题在我的知乎专栏里同时转载,欢迎更多讨论交流,参见链接

Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析的更多相关文章

  1. PyTorch 中 torch.matmul() 函数的文档详解

    官方文档 torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定. 关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的 ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. 模式识别 - libsvm该函数的调用方法 详细说明

    libsvm该函数的调用方法 详细说明 本文地址: http://blog.csdn.net/caroline_wendy/article/details/26261173 须要载入(load)SVM ...

  4. Pytorch中torch.load()中出现AttributeError: Can't get attribute

    原因:保存下来的模型和参数不能在没有类定义时直接使用. Pytorch使用Pickle来处理保存/加载模型,这个问题实际上是Pickle的问题,而不是Pytorch. 解决方法也非常简单,只需显式地导 ...

  5. 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍

    squeeze用来减少维度, unsqueeze用来增加维度 具体可见下方博客. pytorch中squeeze和unsqueeze

  6. pytorch 中的Variable一般常用的使用方法

    Variable一般的初始化方法,默认是不求梯度的 import torch from torch.autograd import Variable x_tensor = torch.randn(2, ...

  7. javascript中的闭包、函数的toString方法

    闭包: 闭包可以理解为定义在一个函数内部的函数, 函数A内部定义了函数B, 函数B有访问函数A内部变量的权力: 闭包是函数和子函数之间的桥梁: 举个例子: let func = function() ...

  8. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  9. pytorch中torch.narrow()函数

    torch.narrow(input, dim, start, length) → Tensor Returns a new tensor that is a narrowed version of  ...

随机推荐

  1. 7-25 :active :after :before :disabled

    1:<list,<datalist>,required,<select>,<option>,title,draggable,hidden 2:data-*和命 ...

  2. 如何将网页保存为PDF文件

    怎样将网页保存为PDF文件... 问题: 很多时候我们需要将网页上的内容,在排版不变的情况下完整的保存下来,那么用pdf格式是最好的效果了,还图文并茂,效果与真实的网页很相似,如果另存为网页的话,会下 ...

  3. Python模拟登陆万能法-微博|知乎

    Python模拟登陆让不少人伤透脑筋,今天奉上一种万能登陆方法.你无须精通HTML,甚至也无须精通Python,但却能让你成功的进行模拟登陆.本文讲的是登陆所有网站的一种方法,并不局限于微博与知乎,仅 ...

  4. 关于Kafka __consumer_offests的讨论

    众所周知,__consumer__offsets是一个内部topic,对用户而言是透明的,除了它的数据文件以及偶尔在日志中出现这两点之外,用户一般是感觉不到这个topic的.不过我们的确知道它保存的是 ...

  5. java基础学习周计划之3--每周一练

    每周一练第一周 一. 关键代码:1.斐波那契数列指的是这样一个数列 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, ...数列第一项和第二项是1, 从第三项开始, ...

  6. Git详解及github与gitlab使用

    第一章 关于版本控制 第二章 GIT简介 第三章 GIT安装 第四章 初次运行GIT前配置 第五章 初始化仓库 第六章 GIT命令操作 第七章 GIT分支结构

  7. Hibernate-ORM:02.Hibernate增删改入门案例

    ------------吾亦无他,唯手熟尔,谦卑若愚,好学若饥------------- 本笔者使用的是Idea+mysql+maven做Hibernate的博客,本篇及其以后都是如此! 首先写好思路 ...

  8. Charpter3 名字 作用域 约束

    一个对象拥有其语义价值的区域<其作用域 当一个变量将不再被使用,那它应该被理想的回收机制回收.但现实是我们仅当一个变量离开了其作用域,或变成不可访问,才考虑回收. 然而,作用域规则有其优点:1. ...

  9. Itest(爱测试),最懂测试人的开源测试管理软件隆重发布

    测试人自己开发,汇聚10年沉淀,独创流程驱动测试.度量展现测试人价值的测试协同软件,开源免费   官网刚上线,近期发布源码:http://www.itest.work 在线体验 http://www. ...

  10. SpringCloud-分布式链路跟踪配置详解

    SpringCloud-分布式链路跟踪 作者 : Stanley 罗昊 [转载请注明出处和署名,谢谢!] 注:作者使用IDEA + Gradle 注:需要有一定的java SpringBoot and ...