『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究
查看非叶节点梯度的两种方法
在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:
- 使用autograd.grad函数
- 使用hook
autograd.grad和hook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。
求z对y的导数
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() # hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):
print("hook梯度输出:\r\n",grad) hook_handle = y.register_hook(variable_hook) # 注册hook
z.backward(retain_graph=True) # 内置输出上面的hook
hook_handle.remove() # 释放 print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出:
Variable containing:
1
1
1
[torch.FloatTensor of size 3] autograd.grad输出:
(Variable containing:
1
1
1
[torch.FloatTensor of size 3]
,)
多次反向传播试验
实际就是使用retain_graph参数,
# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
1
1
1
[torch.FloatTensor of size 3] Variable containing:
2
2
2
[torch.FloatTensor of size 3]
如果不使用retain_graph参数,
实际上效果是一样的,AccumulateGrad object仍然会积累梯度
# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum() z.backward()
print(w.grad)
y = w.mul(x) # <-----
z = y.sum() # <-----
z.backward()
print(w.grad)
Variable containing:
1
1
1
[torch.FloatTensor of size 3] Variable containing:
2
2
2
[torch.FloatTensor of size 3]
分析:
这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的"动态"过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。
实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。
总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。
『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究的更多相关文章
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数
一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...
- 『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较_&_广播原理简介
一.简单数学操作 1.逐元素操作 t.clamp(a,min=2,max=4)近似于tf.clip_by_value(A, min, max),修剪值域. a = t.arange(0,6).view ...
- 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor
Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...
- 『PyTorch』第五弹_深入理解Tensor对象_中上:索引
一.普通索引 示例 a = t.Tensor(4,5) print(a) print(a[0:1,:2]) print(a[0,:2]) # 注意和前一种索引出来的值相同,shape不同 print( ...
- 『PyTorch』第五弹_深入理解Tensor对象_上:初始化以及尺寸调整
一.创建Tensor 特殊方法: t.arange(1,6,2)t.linspace(1,10,3)t.randn(2,3) # 标准分布,*size t.randperm(5) # 随机排序,从0到 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
随机推荐
- linux常用命令:route 命令
Linux系统的route 命令用于显示和操作IP路由表(show / manipulate the IP routing table).要实现两个不同的子网之间的通信,需 要一台连接两个网络的路由器 ...
- windows下nodejs express安装及入门网站,视频资料,开源项目介绍
windows下nodejs express安装及入门网站,视频资料,开源项目介绍,pm2,supervisor,npm,Pomelo,Grunt安装使用注意事项等总结 第一步:下载安装文件下载地址: ...
- Socket和ServletSocket的交互
ServerSocket(int port) 是服务端绑定port端口,调accept()监听等待客户端连接,它返回一个连接队列中的一个socket. Socket(InetAddress addre ...
- php在Nginx环境下进行刷新缓存立即输出,实现常驻进程轮询。
以下面这段代码并不会逐个输出,而是当浏览器筹够一定字节数进行统一输出,结果显而易见,10秒后一次性输出所有内容 for($i=0;$i<10;$i++){ echo $i.'</br> ...
- 对于phy芯片的认识
一,关于phy芯片 以RTL8211E(G)为例 PHY是IEEE802.3中定义的一个标准模块,STA(station management entity,管理实体,一般为MAC或CPU) 通过SM ...
- POJ3241 Object Clustering(最小生成树)题解
题意:求最小生成树第K大的边权值 思路: 如果暴力加边再用Kruskal,边太多会超时.这里用一个算法来减少有效边的加入. 边权值为点间曼哈顿距离,那么每个点的有效加边选择应该是和他最近的4个象限方向 ...
- 51nod 1070 Bash游戏 V4
这种博弈题 都是打表找规律 可我连怎么打表都不会 这个是凑任务的吧....以后等脑子好些了 再琢磨吧 就是斐波那契数列中的数 是必败态 #include<bits/stdc++.h> u ...
- git连接华为开发云devcloud
华为开发运在代码托管方面的个github很类似,引入了代码仓库的概念,同时需要本地安装git客户端,且只能与git进行连接,从这个角度上讲,华为开发云的代码管理部分就是github的功能,下面对git ...
- 纯CSS实现一个微信logo,需要几个标签?
博客已迁移至http://lwzhang.github.io. 纯CSS实现一个微信logo并不难,难的是怎样用最少的html标签实现.我一直在想怎样用一个标签就能实现,最后还是没想出来,就只好用两个 ...
- JavaScript页面跳转的一些实现方法
第一种 <script language=”javascript” type=”text/javascript”> window.location.href=”login.jsp?back ...