查看非叶节点梯度的两种方法

在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

  • 使用autograd.grad函数
  • 使用hook

autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

求z对y的导数

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() # hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):
print("hook梯度输出:\r\n",grad) hook_handle = y.register_hook(variable_hook) # 注册hook
z.backward(retain_graph=True) # 内置输出上面的hook
hook_handle.remove() # 释放 print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出:
Variable containing:
1
1
1
[torch.FloatTensor of size 3] autograd.grad输出:
(Variable containing:
1
1
1
[torch.FloatTensor of size 3]
,)

多次反向传播试验

实际就是使用retain_graph参数,

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
1
1
1
[torch.FloatTensor of size 3] Variable containing:
2
2
2
[torch.FloatTensor of size 3]

如果不使用retain_graph参数,

实际上效果是一样的,AccumulateGrad object仍然会积累梯度

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() z.backward()
print(w.grad)
y = w.mul(x) # <-----
z = y.sum() # <-----
z.backward()
print(w.grad)
Variable containing:
1
1
1
[torch.FloatTensor of size 3] Variable containing:
2
2
2
[torch.FloatTensor of size 3]

分析:

这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的"动态"过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。

实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。

总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。

『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究的更多相关文章

  1. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  2. 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数

    一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...

  3. 『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较_&_广播原理简介

    一.简单数学操作 1.逐元素操作 t.clamp(a,min=2,max=4)近似于tf.clip_by_value(A, min, max),修剪值域. a = t.arange(0,6).view ...

  4. 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor

    Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...

  5. 『PyTorch』第五弹_深入理解Tensor对象_中上:索引

    一.普通索引 示例 a = t.Tensor(4,5) print(a) print(a[0:1,:2]) print(a[0,:2]) # 注意和前一种索引出来的值相同,shape不同 print( ...

  6. 『PyTorch』第五弹_深入理解Tensor对象_上:初始化以及尺寸调整

    一.创建Tensor 特殊方法: t.arange(1,6,2)t.linspace(1,10,3)t.randn(2,3) # 标准分布,*size t.randperm(5) # 随机排序,从0到 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  9. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

随机推荐

  1. Linux基础命令---cal

    cal cal指令可以显示一个日历信息,如果没有指定选项和参数,那么就会显示当前的月份. 此命令的适用范围:RedHat.RHEL.Ubuntu.CentOS.SUSE.openSUSE.Fedora ...

  2. bzoj3524 [Poi2014]Couriers/2223 [Coci 2009]PATULJCI

    题目链接1 题目链接2 主席树模板题 两题有细节不同 #include<algorithm> #include<iostream> #include<cstdlib> ...

  3. Linux学习笔记之Linux Centos关闭防火墙

    # Centos6.x /etc/init.d/iptables stop chkconfig iptables off sed -i 's/SELINUX=enforcing/SELINUX=dis ...

  4. UVA302 John's trip(欧拉回路)

    UVA302 John's trip 欧拉回路 attention: 如果有多组解,按字典序输出. 起点为每组数据所给的第一条边的编号较小的路口 每次输出完额外换一行 保证连通性 每次输入数据结束后, ...

  5. c++继承、多态以及与java的行为差异之处

    对于面向对象而言,多态是最有用的基本特性之一,相对于函数指针,易用得多.下面看下c++继承和多态行为的基本特性,最后说明下和java的基本差别. 首先定义父类和子类. base.h #pragma o ...

  6. 08:Python数据分析之pandas学习

    1.1 数据结构介绍 参考博客:http://www.cnblogs.com/nxld/p/6058591.html 1.pandas介绍 1. 在pandas中有两类非常重要的数据结构,即序列Ser ...

  7. cron表达式增加一段时间变为新的表达式

    cron表达式是使用任务调度经常使用的表达式了.对于通常的简单任务,我们只需要一条cron表达式就能满足.但是有的时候任务也可以很复杂. 最近我遇到了一个问题,一条任务在开始的时候要触发A方法,在结束 ...

  8. 关于定时器、波特率、TH和TL值的计算

    假设晶振位6MHZ,定时10ms 单片机系统晶振频率为6mhz,系统时钟频率 (也是计时脉冲频率)为500KHZ,一个脉冲周期2us ,10ms是5000个脉冲,因此TMOD=0X01;TH0=(65 ...

  9. IE6里样式表不起作用解决方法

    写的html页面引用外部css文件的时候在IE7,IE8和FF中能正常作用,即能正常显示,可在IE6中却完全没有作用到,即css文件里的样式根本未被解析到我们的html页面,这是什么原因? 开 始把c ...

  10. JavaScript 获取地址栏参数

    1. function a() { console.log(this); } a.call(null); window 如果第一个参数传入的对象调用者是null或者undefined的话,call方法 ...