PyTorch官方中文文档:自动求导机制
自动求导机制
本说明将概述Autograd如何工作并记录操作。了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序,并可帮助您进行调试。
从后向中排除子图
每个变量都有两个标志:requires_grad
和volatile
。它们都允许从梯度计算中精细地排除子图,并可以提高效率。
艾伯特(http://www.aibbt.com/)国内第一家人工智能门户
requires_grad
如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。
>>> x = Variable(torch.randn(5, 5))
>>> y = Variable(torch.randn(5, 5))
>>> z = Variable(torch.randn(5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True
这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad
标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
volatile
纯粹的inference模式下推荐使用volatile
,当你确定你甚至不会调用.backward()
时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile
也决定了require_grad is False
。
volatile
不同于require_grad
的传递。如果一个操作甚至只有有一个volatile
的输入,它的输出也将是volatile
。Volatility
比“不需要梯度”更容易传递——只需要一个volatile
的输入即可得到一个volatile
的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile
的输入就够了,这将保证不会保存中间状态。
>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad
True
>>> model(volatile_input).requires_grad
False
>>> model(volatile_input).volatile
True
>>> model(volatile_input).creator is None
True
自动求导如何编码历史信息
每个变量都有一个.creator
属性,它指向把它作为输出的函数。这是一个由Function
对象作为节点组成的有向无环图(DAG)的入口点,它们之间的引用就是图的边。每次执行一个操作时,一个表示它的新Function
就被实例化,它的forward()
方法被调用,并且它输出的Variable
的创建者被设置为这个Function
。然后,通过跟踪从任何变量到叶节点的路径,可以重建创建数据的操作序列,并自动计算梯度。
需要注意的一点是,整个图在每次迭代时都是从头开始重新创建的,这就允许使用任意的Python控制流语句,这样可以在每次迭代时改变图的整体形状和大小。在启动训练之前不必对所有可能的路径进行编码—— what you run is what you differentiate.
Variable上的In-place操作
在自动求导中支持in-place操作是一件很困难的事情,我们在大多数情况下都不鼓励使用它们。Autograd的缓冲区释放和重用非常高效,并且很少场合下in-place操作能实际上明显降低内存的使用量。除非您在内存压力很大的情况下,否则您可能永远不需要使用它们。
限制in-place操作适用性主要有两个原因:
1.覆盖梯度计算所需的值。这就是为什么变量不支持log_
。它的梯度公式需要原始输入,而虽然通过计算反向操作可以重新创建它,但在数值上是不稳定的,并且需要额外的工作,这往往会与使用这些功能的目的相悖。
2.每个in-place操作实际上需要实现重写计算图。不合适的版本只需分配新对象并保留对旧图的引用,而in-place操作则需要将所有输入的creator
更改为表示此操作的Function
。这就比较棘手,特别是如果有许多变量引用相同的存储(例如通过索引或转置创建的),并且如果被修改输入的存储被任何其他Variable
引用,则in-place函数实际上会抛出错误。
In-place正确性检查
每个变量保留有version counter,它每次都会递增,当在任何操作中被使用时。当Function
保存任何用于后向的tensor时,还会保存其包含变量的version counter。一旦访问self.saved_tensors
,它将被检查,如果它大于保存的值,则会引起错误。
PyTorch官方中文文档:自动求导机制的更多相关文章
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- PyTorch官方中文文档:torch.optim 优化器参数
内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...
- PyTorch官方中文文档:PyTorch中文文档
PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...
- Pytorch学习(一)—— 自动求导机制
现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API进行学 ...
- PyTorch官方中文文档:torch
torch 包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作.另外,它也提供了多种工具,其中一些可以更有效地对张量和任意类型进行序列化. 它有CUDA 的对应实现,可以在NVIDIA ...
- PyTorch官方中文文档:torch.optim
torch.optim torch.optim是一个实现了各种优化算法的库.大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法. 如何使用optimizer 为了使用t ...
- PyTorch官方中文文档:torch.Tensor
torch.Tensor torch.Tensor是一种包含单一数据类型元素的多维矩阵. Torch定义了七种CPU tensor类型和八种GPU tensor类型: Data tyoe CPU te ...
- Pytorch Autograd (自动求导机制)
Pytorch Autograd (自动求导机制) Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logisti ...
- ReactNative官方中文文档0.21
整理了一份ReactNative0.21中文文档,提供给需要的reactnative爱好者.ReactNative0.21中文文档.chm 百度盘下载:ReactNative0.21中文文档 来源: ...
随机推荐
- 2018/1/19 Netty学习笔记(一)
这段时间学了好多好多东西,不过更多是细节和思想上的,比如分布式事物,二次提交,改善代码质量,还有一些看了一些源码什么的; 记录一下真正的技术学习,关于Netty的学习过程; 首先说Netty之前先说一 ...
- 软件开发:网站&视频&书籍推荐(不断更新)
利用书籍进行系统学习,凭借博客/新闻等资料开阔眼界,辅之以代码及项目实战,并勤加以总结,方可进步. 常用网站: Leetcode刷题:https://leetcode.com/ ,练习数据结构和算法必 ...
- Halcon一日一练:CAD类型的相关操作
大很多场合,需要在视觉程序中导入CAD文档,比如,在3C行业,需要对手机外壳进行CNC加工,或者点胶操作,此时,需要获取产品的各个点的数据.如果将CAD直接导入,就会大的减少编程工作量,同时也能达到很 ...
- linux使用tcpdump抓包工具抓取网络数据包,多示例演示
tcpdump是linux命令行下常用的的一个抓包工具,记录一下平时常用的方式,测试机器系统是ubuntu 12.04. tcpdump的命令格式 tcpdump的参数众多,通过man tcpdump ...
- 分布式高性能消息处理中心HPMessageCenter
# HPMessageCenter 高性能消息分发中心.用户只需写好restful接口,在portal里面配置消息的处理地址,消息消费者就会自动访问相关接口,完成消息任务. ### 部署说明 **创建 ...
- 3.2 while 循环
Python 编程中 while 语句用于循环执行程序,即在条件满足的情况下,循环执行某段代码.所以就需要在循环的代码块中设计一种使代码块循环执行一定次数后是while语句的条件不满足,从而中止whi ...
- qt 字符数组如何转换字符串?
char 字符数组如何转换成 QString? char source{1024} = {0}; QString des = QString::fromLocal8Bit(source);
- NJU 1010 Air
思路:把那张图打表(吐血...),然后就按照规则输出就行. AC代码 #include <cstdio> #include <cmath> #include <cctyp ...
- uva 116 单向TSP
这题的状态很明显. 转移方程就是 d(i,j)=min(d(i+1,j+1),d(i,j+1),d(i-1,j+1)) //注意边界 我用了一个next数组方便打印结果,但是一直编译错误,原来是不能用 ...
- Docker系统五:Docker仓库
创建Docker Hub账户 登录和上传镜像到Hub.docker.com docker login //登陆hub.docker.com docker tag ubutun1404-baseimag ...