自动求导机制

本说明将概述Autograd如何工作并记录操作。了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序,并可帮助您进行调试。

从后向中排除子图

每个变量都有两个标志:requires_gradvolatile。它们都允许从梯度计算中精细地排除子图,并可以提高效率。

艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

requires_grad

如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。

>>> x = Variable(torch.randn(5, 5))
>>> y = Variable(torch.randn(5, 5))
>>> z = Variable(torch.randn(5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

volatile

纯粹的inference模式下推荐使用volatile,当你确定你甚至不会调用.backward()时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile也决定了require_grad is False

volatile不同于require_grad的传递。如果一个操作甚至只有有一个volatile的输入,它的输出也将是volatileVolatility比“不需要梯度”更容易传递——只需要一个volatile的输入即可得到一个volatile的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile的输入就够了,这将保证不会保存中间状态。

>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad
True
>>> model(volatile_input).requires_grad
False
>>> model(volatile_input).volatile
True
>>> model(volatile_input).creator is None
True

自动求导如何编码历史信息

每个变量都有一个.creator属性,它指向把它作为输出的函数。这是一个由Function对象作为节点组成的有向无环图(DAG)的入口点,它们之间的引用就是图的边。每次执行一个操作时,一个表示它的新Function就被实例化,它的forward()方法被调用,并且它输出的Variable的创建者被设置为这个Function。然后,通过跟踪从任何变量到叶节点的路径,可以重建创建数据的操作序列,并自动计算梯度。

需要注意的一点是,整个图在每次迭代时都是从头开始重新创建的,这就允许使用任意的Python控制流语句,这样可以在每次迭代时改变图的整体形状和大小。在启动训练之前不必对所有可能的路径进行编码—— what you run is what you differentiate.

Variable上的In-place操作

在自动求导中支持in-place操作是一件很困难的事情,我们在大多数情况下都不鼓励使用它们。Autograd的缓冲区释放和重用非常高效,并且很少场合下in-place操作能实际上明显降低内存的使用量。除非您在内存压力很大的情况下,否则您可能永远不需要使用它们。

限制in-place操作适用性主要有两个原因:

1.覆盖梯度计算所需的值。这就是为什么变量不支持log_。它的梯度公式需要原始输入,而虽然通过计算反向操作可以重新创建它,但在数值上是不稳定的,并且需要额外的工作,这往往会与使用这些功能的目的相悖。

2.每个in-place操作实际上需要实现重写计算图。不合适的版本只需分配新对象并保留对旧图的引用,而in-place操作则需要将所有输入的creator更改为表示此操作的Function。这就比较棘手,特别是如果有许多变量引用相同的存储(例如通过索引或转置创建的),并且如果被修改输入的存储被任何其他Variable引用,则in-place函数实际上会抛出错误。

In-place正确性检查

每个变量保留有version counter,它每次都会递增,当在任何操作中被使用时。当Function保存任何用于后向的tensor时,还会保存其包含变量的version counter。一旦访问self.saved_tensors,它将被检查,如果它大于保存的值,则会引起错误。

PyTorch官方中文文档:自动求导机制的更多相关文章

  1. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  2. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  3. PyTorch官方中文文档:PyTorch中文文档

    PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...

  4. Pytorch学习(一)—— 自动求导机制

    现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API进行学 ...

  5. PyTorch官方中文文档:torch

    torch 包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作.另外,它也提供了多种工具,其中一些可以更有效地对张量和任意类型进行序列化. 它有CUDA 的对应实现,可以在NVIDIA ...

  6. PyTorch官方中文文档:torch.optim

    torch.optim torch.optim是一个实现了各种优化算法的库.大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法. 如何使用optimizer 为了使用t ...

  7. PyTorch官方中文文档:torch.Tensor

    torch.Tensor torch.Tensor是一种包含单一数据类型元素的多维矩阵. Torch定义了七种CPU tensor类型和八种GPU tensor类型: Data tyoe CPU te ...

  8. Pytorch Autograd (自动求导机制)

    Pytorch Autograd (自动求导机制) Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logisti ...

  9. ReactNative官方中文文档0.21

    整理了一份ReactNative0.21中文文档,提供给需要的reactnative爱好者.ReactNative0.21中文文档.chm  百度盘下载:ReactNative0.21中文文档 来源: ...

随机推荐

  1. ajax请求 readyState为0 可能原因之一

    问题:同样的代码逻辑,PC端和iOS都能正常访问,但是Android系统请求都是报错: 上网查阅,关于ajax请求失败且状态码都是0的情况有很多,最后排查的原因是:域名证书问题:

  2. ajax常用操作

    load的方法的使用(现在已不常用) <!doctype html><html lang="en"><head> <meta charse ...

  3. redis —主从&&集群(CLUSTER)

    REDIS主从配置 为了节省资源,本实验在一台机器进行.即,在一台机器上启动两个端口,模拟两台机器. 机器准备: [root@adailinux ~]# cp /etc/redis.conf /etc ...

  4. [HNOI2008] GT考试

    [HNOI2008] GT考试 标签 : DP 矩阵乘法 题目链接 题意 n位数中不出现一个子串的方案数. 题解 \(设dp[i][j]\)为前i位匹配到j时的合法方案数.(所谓合法,就是不能有别的匹 ...

  5. laravel框架基础知识点

    一.数据库:DB    1.db查    DB::table('msg')->where('id','>',$id)->get()       查询单行    DB::table(' ...

  6. MySQL数据库基础(一)(启动/停止、登录/退出、语法规范及最基础操作)

    1.启动/停止MySQL服务 启动:net start mysql    停止:net stop mysql 2.MySQL登录/退出 登录:mysql 参数:如果连接的是本地服务器,一般用命令:my ...

  7. 基于JDK1.8的ConcurrentHashMap分析

    之前看过ConcurrentHashMap的分析,感觉也了解的七七八八了.但昨晚接到了面试,让我把所知道的ConcurrentHashMap全部说出来. 然后我结结巴巴,然后应该毫无意外的话就G了,今 ...

  8. C语言_初步了解一下指针

    指针的基本概念 在计算机中,所有的数据都是存放在存储器中的. 一般把存储器中的一个字节称为一个内存单元, 不同的数据类型所占用的内存单元数不等,如整型量占2个单元,字符量占1个单元等.为了正确地访问这 ...

  9. 《android开发进阶从小工到专家》读书笔记--网络框架的设计与实现

    第一步: 第一层:Request--请求类型,JSON,字符串,文件 第二层:消息队列--维护了提交给网络框架的请求列表,并且根据响应的规则进行排序.默认情况下按照优先级和进入队列的顺序来执行,该队列 ...

  10. lower_bound()返回值

    lower_bound()函数实现功能就是二分查找,函数lower_bound()在first和last中的前闭后开区间进行二分查找,返回大于或等于val的第一个元素位置.如果所有元素都小于val,则 ...