pytorch-backword函数的理解

函数:\(tensor.backward(params)\)

这个params的维度一定要和tensor的一致,因为tensor如果是一个向量y = [y1,y2,y3],那么传入的params=[a1,a2,a3],这三个值是系数,那么是什么的系数呢?
假定对x =[ x1,x2]求导,那么我们知道,
\(dy/dx\) 为:
第一列: \(dy1/dx1,dy2/dx1,dy3/dx1\)
第二列:\(dy1/dx2, dy2/dx2,dy3/dx2\)
从而 \(dy/dx\)是一个3行2列的矩阵,每一列对应了对x1的导数,每一列也就是\(x1\)的梯度向量
而反向计算的时候,并不是返回这个矩阵,而是返回这个矩阵每列的和作为梯度,也就是:\(dy1/dx1+dy2/dx1+dy3/dx1\) 是y对x1的梯度
这就好理解了,系数为\(params=[a1,a2,a3]\)就对应了这加和的三项!也就是,对\(x1\)的梯度实际上是\(a1*dy1/dx1+a2*dy2/dx1+a3*dy3/dx1\)
而输出y是标量的时候,就不需要了,默认的就是\(1.\)

自己重写backward函数时,要写上一个grad_output参数,这个参数就是上面提到的params

这个grad_output参数究竟是什么呢?下面作出解释:
是这样的,假如网络有两层, h = h(x),y = y(h)
你可以计算\(dy/dx\),这样,y.backward(),因为\(dy/dy=1\),那么,backward的参数就可以省略
如果计算h.backward(),因为你想求的是\(dy/dx\),(这才是输出对于输入的梯度),那么,计算图中的y = y(h)就没有考虑到
因为\(dy/dx = dy/dh * dh/dx\),h.backward()求得是\(dh/dx\),那么你必须传入之前的梯度\(dy/dh\)才行,也就是说,h.backward(params=dy/dh)这里面的参数就是\(dy/dh\)

这就好理解了,如果我们自己实现了一层,继承自Function,自己实现静态方法forwardbackward时,backward必须有个grad_output参数,这个参数就是计算图中输出对该自定义层的梯度,这样才能求出对输入的梯度。

另外,假设定义的层计算出的是y,调用的就是y.backward(grad_output),这个里面的参数的维度必须和y是相同的。这也就是为什么前面提到对于输出是多维的,会有个“系数”的原因,这个系数就是后向传播时,该层之前的梯度的累积,这样与本层再累积,才实现了完整的链式法则,最终求出outinput的梯度。

另外,自定义实现forwardbackward时,两函数的输入输出是有要求的,即forward的输入必须和~的return相对应,如forwardinput有个w参数,那么backwardreturn就必须在对应的位置返回grad_w,因为只有这样,才能够对相应的输入参数梯度下降。

【pytorch】pytorch-backward()的理解的更多相关文章

  1. ARTS-S pytorch中backward函数的gradient参数作用

    导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...

  2. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  3. Pytorch 之 backward

    首先看这个自动求导的参数: grad_variables:形状与variable一致,对于y.backward(),grad_variables相当于链式法则dz/dx=dz/dy × dy/dx 中 ...

  4. [pytorch] Pytorch入门

    Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...

  5. pytorch lstm crf 代码理解 重点

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  6. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  7. Pytorch的LSTM的理解

    class torch.nn.LSTM(*args, **kwargs) 参数列表 input_size:x的特征维度 hidden_size:隐藏层的特征维度 num_layers:lstm隐层的层 ...

  8. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  9. pytorch的backward

    在学习的过程中遇见了一个问题,就是当使用backward()反向传播时传入参数的问题: net.zero_grad() #所有参数的梯度清零 output.backward(Variable(t.on ...

随机推荐

  1. SqlServer注意事项总结,高级程序员必背!

    本篇文章主要介绍SqlServer使用时的注意事项. 想成为一个高级程序员,数据库的使用是必须要会的.而数据库的使用纯熟程度,也侧面反映了一个开发的水平. 下面介绍SqlServer在使用和设计的过程 ...

  2. Python编程从入门到实践笔记——if语句

    Python编程从入门到实践笔记——if语句 #coding=utf-8 cars=['bwm','audi','toyota','subaru','maserati'] bicycles = [&q ...

  3. 使用Atlas进行元数据管理之Glossary(术语)

    背景:笔者和团队的小伙伴近期在进行数据治理/元数据管理方向的探索, 在接下来的系列文章中, 会陆续与读者们进行分享在此过程中踩过的坑和收获. 元数据管理系列文章: [0] - 使用Atlas进行元数据 ...

  4. 如何使用git和ssh部署本地代码到服务器

    一.首先设置好自己本地的Git用户名和密码: git config --global user.name "your name" git config --global user. ...

  5. SQL Server读写分离之发布订阅

    一.发布 上面有多种发布方式,这里我选择事物发布,具体区别请自行百度. 点击下一步.然后继续选择需要发布的对象.  如果需要筛选发布的数据点击添加. 根据自己的计划选择发布的时间. 点击安全设置,设置 ...

  6. java实现 批量转换文件编码格式

    一.场景说明 不知道大家有没有遇到过之前项目是GBK,现在需要全部换成UTF-8的情况.反正我是遇到了. eclipse可以改变项目的编码格式,但是文件如果直接转换的话里面的中文就会全部乱码,需要先复 ...

  7. 查看apk签名 和 keystore 的信息

    原文出处:https://www.jianshu.com/p/90b698002215 1.keytool -printcert -file ***(把apk文件下的META- INF文件夹解压出来, ...

  8. Spark RPC框架源码分析(一)简述

    Spark RPC系列: Spark RPC框架源码分析(一)运行时序 Spark RPC框架源码分析(二)运行时序 Spark RPC框架源码分析(三)运行时序 一. Spark rpc框架概述 S ...

  9. MySQL, XE7使用FireDAC连接MySQL数据库

    发现使用DBExpress进行MySQL连接老是有莫名其妙的问题,直接改为FireDAC 在上一篇的DataSnap服务框架程序中,将连接的数据库由MSSQL改为本文的MySQL 使用的MySQL数据 ...

  10. nginx常用场景

    1.浏览器缓存 server { listen 8083; server_name 127.0.0.1; sendfile on; access_log /var/log/nginx/static_s ...