平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考。以下笔记基于Pytorch1.0

Tensor

Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor。如果我们需要计算某个Tensor的导数,那么我们需要设置其.requires_grad属性为True。为方便说明,在本文中对于这种我们自己定义的变量,我们称之为叶子节点(leaf nodes),而基于叶子节点得到的中间或最终变量则可称之为结果节点。例如下面例子中的x则是叶子节点,y则是结果节点。

x = torch.rand(3, requires_grad=True)
y = x**2
z = x + x

另外一个Tensor中通常会记录如下图中所示的属性:

  • data: 即存储的数据信息
  • requires_grad: 设置为True则表示该Tensor需要求导
  • grad: 该Tensor的梯度值,每次在计算backward时都需要将前一时刻的梯度归零,否则梯度值会一直累加,这个会在后面讲到。
  • grad_fn: 叶子节点通常为None,只有结果节点的grad_fn才有效,用于指示梯度函数是哪种类型。例如上面示例代码中的y.grad_fn=<PowBackward0 at 0x213550af048>, z.grad_fn=<AddBackward0 at 0x2135df11be0>
  • is_leaf: 用来指示该Tensor是否是叶子节点。

torch.autograd.backward

有如下代码:

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2+y
z.backward()
print(z, x.grad, y.grad) >>> tensor(3., grad_fn=<AddBackward0>) tensor(2.) tensor(1.)

可以z是一个标量,当调用它的backward方法后会根据链式法则自动计算出叶子节点的梯度值。

但是如果遇到z是一个向量或者是一个矩阵的情况,这个时候又该怎么计算梯度呢?这种情况我们需要定义grad_tensor来计算矩阵的梯度。在介绍为什么使用之前我们先看一下源代码中backward的接口是如何定义的:

torch.autograd.backward(
tensors,
grad_tensors=None,
retain_graph=None,
create_graph=False,
grad_variables=None)
  • tensor: 用于计算梯度的tensor。也就是说这两种方式是等价的:torch.autograd.backward(z) == z.backward()
  • grad_tensors: 在计算矩阵的梯度时会用到。他其实也是一个tensor,shape一般需要和前面的tensor保持一致。
  • retain_graph: 通常在调用一次backward后,pytorch会自动把计算图销毁,所以要想对某个变量重复调用backward,则需要将该参数设置为True
  • create_graph: 当设置为True的时候可以用来计算更高阶的梯度
  • grad_variables: 这个官方说法是grad_variables' is deprecated. Use 'grad_tensors' instead.也就是说这个参数后面版本中应该会丢弃,直接使用grad_tensors就好了。

好了,参数大致作用都介绍了,下面我们看看pytorch为什么设计了grad_tensors这么一个参数,以及它有什么用呢?

还是用代码做个示例

x = torch.ones(2,requires_grad=True)
z = x + 2
z.backward() >>> ...
RuntimeError: grad can be implicitly created only for scalar outputs

当我们运行上面的代码的话会报错,报错信息为RuntimeError: grad can be implicitly created only for scalar outputs

上面的报错信息意思是只有对标量输出它才会计算梯度,而求一个矩阵对另一矩阵的导数束手无策。

\[X = \left[\begin{array}{cc}
x_0 & x_1 \\
\end{array}\right] \,\,\,\,\,\,\,\,\,\
Z=X+2=\left[\begin{array}{cc}
x_0+2 & x_1+2 \\
\end{array}\right]
\Rightarrow \frac{\partial{Z}}{\partial{X}}=?
\]

那么我们只要想办法把矩阵转变成一个标量不就好了?比如我们可以对z求和,然后用求和得到的标量在对x求导,这样不会对结果有影响,例如:

\[\begin{align}
&Z_{sum}=\sum{z_i}=x_0+x_1+8 \notag \\
&\text{then} \,\,\,\,\, \frac{\partial{Z_{sum}}}{\partial{x_0}}=\frac{\partial{Z_{sum}}}{\partial{x_1}}=1 \notag
\end{align}
\]

我们可以看到对z求和后再计算梯度没有报错,结果也与预期一样:

x = torch.ones(2,requires_grad=True)
z = x + 2
z.sum().backward()
print(x.grad) >>> tensor([1., 1.])

我们再仔细想想,对z求和不就是等价于z点乘一个一样维度的全为1的矩阵吗?即\(sum(Z)=dot(Z,I)\),而这个I也就是我们需要传入的grad_tensors参数。(点乘只是相对于一维向量而言的,对于矩阵或更高为的张量,可以看做是对每一个维度做点乘)

代码如下:

x = torch.ones(2,requires_grad=True)
z = x + 2
z.backward(torch.ones_like(z)) # grad_tensors需要与输入tensor大小一致
print(x.grad) >>> tensor([1., 1.])

弄个再复杂一点的:

x = torch.tensor([2., 1.], requires_grad=True).view(1, 2)
y = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True) z = torch.mm(x, y)
print(f"z:{z}")
z.backward(torch.Tensor([[1., 0]]), retain_graph=True)
print(f"x.grad: {x.grad}")
print(f"y.grad: {y.grad}") >>> z:tensor([[5., 8.]], grad_fn=<MmBackward>)
x.grad: tensor([[1., 3.]])
y.grad: tensor([[2., 0.],
[1., 0.]])

结果解释如下:

总结:

说了这么多,grad_tensors的作用其实可以简单地理解成在求梯度时的权重,因为可能不同值的梯度对结果影响程度不同,所以pytorch弄了个这种接口,而没有固定为全是1。引用自知乎上的一个评论:如果从最后一个节点(总loss)来backward,这种实现(torch.sum(y*w))的意义就具体化为 multiple loss term with difference weights 这种需求了吧。

torch.autograd.grad

torch.autograd.grad(
outputs,
inputs,
grad_outputs=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False)

看了前面的内容后在看这个函数就很好理解了,各参数作用如下:

  • outputs: 结果节点,即被求导数
  • inputs: 叶子节点
  • grad_outputs: 类似于backward方法中的grad_tensors
  • retain_graph: 同上
  • create_graph: 同上
  • only_inputs: 默认为True, 如果为True, 则只会返回指定input的梯度值。 若为False,则会计算所有叶子节点的梯度,并且将计算得到的梯度累加到各自的.grad属性上去。
  • allow_unused: 默认为False, 即必须要指定input,如果没有指定的话则报错。

参考

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com




2020-01-23 17:45:35



2019-9-18

Pytorch autograd,backward详解的更多相关文章

  1. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  2. Pytorch数据读取详解

    原文:http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler ...

  3. 【小白学PyTorch】10 pytorch常见运算详解

    参考目录: 目录 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 近似值运算 7 剪裁运算 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些 ...

  4. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

  5. Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析

    backward函数 官方定义: torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph ...

  6. pytorch之nn.Conv1d详解

    转自:https://blog.csdn.net/sunny_xsc1994/article/details/82969867,感谢分享 pytorch之nn.Conv1d详解

  7. [转载]Pytorch详解NLLLoss和CrossEntropyLoss

    [转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...

  8. pytorch nn.LSTM()参数详解

    输入数据格式:input(seq_len, batch, input_size)h0(num_layers * num_directions, batch, hidden_size)c0(num_la ...

  9. 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...

随机推荐

  1. leetcode 13. Integer to Roman

    使用eval,特别处理6个case var romanToInt = function (s) { const map = { 'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C ...

  2. 【操作系统之六】Linux常用命令之less

    一.概念less 工具也是对文件或其它输出进行分页显示的工具,是linux正统查看文件内容的工具,功能极其强大.less 的用法比起 more .tail更加的有弹性.在 more 的时候,我们并没有 ...

  3. Alpha冲刺——测试篇

    课程信息 课程 软件工程1916|W(福州大学) 团队名称 修!咻咻! 作业要求 项目Alpha冲刺 团队目标 切实可行的计算机协会维修预约平台 团队信息 队员学号 队员姓名 个人博客地址 备注 22 ...

  4. Django CBV和FBV

    Django CBV和FBV Django内部CBV内部接收方法操作: 1.通过客户端返回的请求头RequestMethod与RequesrtURL,会以字符串形式发送到服务器端. 2.取到值后通过d ...

  5. trie、FSA、FST(转)

    add by zhj: 在学习Lucene的存储结构时,看到其使用了FST,这篇文章写的不错. trie,FSA,FST都是用来解决有限状态机的存储,trie是树,它进一步演化为FSA和FST,这两者 ...

  6. 开源规则引擎 Drools 学习笔记 之 -- 1 cannot be cast to org.drools.compiler.kie.builder.impl.InternalKieModule

    直接进入正题 我们在使用开源规则引擎 Drools 的时候, 启动的时候可能会抛出如下异常: Caused by: java.lang.ClassCastException: cn.com.cheng ...

  7. Python开发【第十五篇】模块的导入

    的导入语句 import 语句 语法: import 模块名1 [as 模块别名] 作用: 将某模块整体导入到当前模块 示例: import math import sys,os 用法: 模块名.属性 ...

  8. c# mvc webapi的put报405错误

    程序在本机调试可正常修改,本机是iis11 放到服务器上,报错了:405.服务器iis7.0 返回的错误页面: <!DOCTYPE html PUBLIC "-//W3C//DTD X ...

  9. Jwt身份验证

    转载自博友(TerryTon)  1.因为json是通用的,所以jwt可以在绝大部分平台可以通用,如java,python,php,.net等  2.基于jwt是无状态的,jwt可以用于分布式等现在比 ...

  10. web技术栈开发原生应用-多端共用一套代码

    weex: vuejs开发原生应用 nativescript: vuejs开发原生应用 ReactNative = reactjs开发原生应用 ionic = angularjs 开发原生应用