pytorch-backword函数的理解

函数:\(tensor.backward(params)\)

这个params的维度一定要和tensor的一致,因为tensor如果是一个向量y = [y1,y2,y3],那么传入的params=[a1,a2,a3],这三个值是系数,那么是什么的系数呢?
假定对x =[ x1,x2]求导,那么我们知道,
\(dy/dx\) 为:
第一列: \(dy1/dx1,dy2/dx1,dy3/dx1\)
第二列:\(dy1/dx2, dy2/dx2,dy3/dx2\)
从而 \(dy/dx\)是一个3行2列的矩阵,每一列对应了对x1的导数,每一列也就是\(x1\)的梯度向量
而反向计算的时候,并不是返回这个矩阵,而是返回这个矩阵每列的和作为梯度,也就是:\(dy1/dx1+dy2/dx1+dy3/dx1\) 是y对x1的梯度
这就好理解了,系数为\(params=[a1,a2,a3]\)就对应了这加和的三项!也就是,对\(x1\)的梯度实际上是\(a1*dy1/dx1+a2*dy2/dx1+a3*dy3/dx1\)
而输出y是标量的时候,就不需要了,默认的就是\(1.\)

自己重写backward函数时,要写上一个grad_output参数,这个参数就是上面提到的params

这个grad_output参数究竟是什么呢?下面作出解释:
是这样的,假如网络有两层, h = h(x),y = y(h)
你可以计算\(dy/dx\),这样,y.backward(),因为\(dy/dy=1\),那么,backward的参数就可以省略
如果计算h.backward(),因为你想求的是\(dy/dx\),(这才是输出对于输入的梯度),那么,计算图中的y = y(h)就没有考虑到
因为\(dy/dx = dy/dh * dh/dx\),h.backward()求得是\(dh/dx\),那么你必须传入之前的梯度\(dy/dh\)才行,也就是说,h.backward(params=dy/dh)这里面的参数就是\(dy/dh\)

这就好理解了,如果我们自己实现了一层,继承自Function,自己实现静态方法forwardbackward时,backward必须有个grad_output参数,这个参数就是计算图中输出对该自定义层的梯度,这样才能求出对输入的梯度。

另外,假设定义的层计算出的是y,调用的就是y.backward(grad_output),这个里面的参数的维度必须和y是相同的。这也就是为什么前面提到对于输出是多维的,会有个“系数”的原因,这个系数就是后向传播时,该层之前的梯度的累积,这样与本层再累积,才实现了完整的链式法则,最终求出outinput的梯度。

另外,自定义实现forwardbackward时,两函数的输入输出是有要求的,即forward的输入必须和~的return相对应,如forwardinput有个w参数,那么backwardreturn就必须在对应的位置返回grad_w,因为只有这样,才能够对相应的输入参数梯度下降。

【pytorch】pytorch-backward()的理解的更多相关文章

  1. ARTS-S pytorch中backward函数的gradient参数作用

    导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...

  2. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  3. Pytorch 之 backward

    首先看这个自动求导的参数: grad_variables:形状与variable一致,对于y.backward(),grad_variables相当于链式法则dz/dx=dz/dy × dy/dx 中 ...

  4. [pytorch] Pytorch入门

    Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...

  5. pytorch lstm crf 代码理解 重点

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  6. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  7. Pytorch的LSTM的理解

    class torch.nn.LSTM(*args, **kwargs) 参数列表 input_size:x的特征维度 hidden_size:隐藏层的特征维度 num_layers:lstm隐层的层 ...

  8. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  9. pytorch的backward

    在学习的过程中遇见了一个问题,就是当使用backward()反向传播时传入参数的问题: net.zero_grad() #所有参数的梯度清零 output.backward(Variable(t.on ...

随机推荐

  1. Java并发专题(三)深入理解volatile关键字

    前言 上一章节简单介绍了线程安全以及最基础的保证线程安全的方法,建议大家手敲代码去体会.这一章会提到volatile关键字,虽然看起来很简单,但是想彻底搞清楚需要具备JMM.CPU缓存模型的知识.不要 ...

  2. Springboot 系列(十三)使用邮件服务

    在我们这个时代,邮件服务不管是对于工作上的交流,还是平时的各种邮件通知,都是一个十分重要的存在.Java 从很早时候就可以通过 Java mail 支持邮件服务.Spring 更是对 Java mai ...

  3. centos7+rsyslog+loganalyzer+mysql 搭建rsyslog日志服务器

    一.简介 在centos7系统中,默认的日志系统是rsyslog,它是一类unix系统上使用的开源工具,用于在ip网络中转发日志信息,rsyslog采用模块化设计,是syslog的替代品. 1.rsy ...

  4. 查询拼接SQL语句,多条件模糊查询

    多条件查询,使用StringBuilder拼接SQL语句,效果如下: 当点击按钮时代码如下: private void button1_Click(object sender, EventArgs e ...

  5. Php7.3 could not find driver

    今天phpstudy升级php7.3,发现框架报错:could not find driver,后来发现默认php.ini的配置有几个是注释掉的,配置php.ini,修改如下 extension=my ...

  6. 观察者模式与.Net Framework中的委托与事件

    本文文字内容均选自<大话设计模式>一书. 解释:观察者模式定义了一种一对多的依赖关系,让多个观察者对象同时监听某一个主题对象.这个主题对象在状态发生变化时,会通知所有观察者对象,使它们能够 ...

  7. BestSync多终端文件资料同步利器

    分享一款多终端文件同步的强力软件,windows下使用. 我这里的多终端意思是,多台电脑.移动存储.云端. 就我个人而言,实用性在于移动硬盘和电脑上都有的文件,比如保存项目资料,电脑上需要编辑,有时外 ...

  8. 仿微信未读RecyclerView平滑滚动定位效果

    效果图有红点的地方表示有未读消息,依次双击首页图标定位,然后定位到某个未读在手动下滑一点距离在次点击定位效果 用过 RecyclerView 的人都知道,自带有几个滚动到item下标的方法,但是不靠谱 ...

  9. linux添加crontab定时任务

    1.crontab -e命令进入linux定时任务编辑界面,举个简单的例子,比如我要定时往txt文件写入 */ * * * * .txt */1就是每隔一分钟像文件写入,其他一些详细的操作大家可以去网 ...

  10. python-对requests请求简单的封装

    # coding:utf-8 import requests class send_request: def __init__(self,url,method,data=None): self.res ...