torch.autograd 是PyTorch的自动微分引擎,用以推动神经网络训练。在本节,你将会对autograd如何帮助神经网络训练的概念有所理解。

背景

神经网络(NNs)是在输入数据上执行的嵌套函数的集合。这些函数由参数(权重、偏置)定义,并在PyTorch中保存于tensors中。

训练NN需要两个步骤:

  • 前向传播:在前向传播中(forward prop),神经网络作出关于正确输出的最佳预测。它使输入数据经过每一个函数来作出预测。
  • 反向传播:在反向传播中(backprop),神经网络根据其预测中的误差来调整其参数,它通过从输出向后遍历,收集关于函数参数的误差的导数(梯度),并使用梯度下降优化参数。有关更多关于反向传播的细节,参见video from 3Blue1Brownvideo from 3Blue1Brown。

在PyTorch中的使用

让我们来看一下单个训练步骤。对于这个例子,我们从 torchvision 加载了一个预训练的resnet18模型。我们创建了一个随机数据tensor,用以表示一个3通道图片,其高和宽均为64,而其对应的 label 初始化为某一随机值。

import torch, torchvision
model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 64, 64)
labels = torch.rand(1, 1000)

接下来,我们将数据输入模型,经过模型的每一层最后作出预测。这是前向过程

prediction = model(data) # forward pass

我们使用模型的预测及其对应的标签计算误差(loss)。下一步是通过网络反向传播误差。当在误差tensor上调用.backward()时,反向传播开始。然后,Autograd计算针对每一个模型参数的梯度,并将其保存在参数的 .grad 属性中。

loss = (prediction - labels).sum()
loss.backward() # backward pass

接下来,我们加载一个优化器,在此案例中是SGD,学习率是0.01,动量参数(momentum)是0.9。我们在优化器中注册所有的模型参数。

optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

最后,我们调用 .step()启动梯度下降。优化器会通过保存在 .grad 的参数梯度调整所有参数。

optim.step() # gradient descent

此时,你已拥有训练神经网络所需的一切。以下部分详细介绍了autograd的工作原理 - 可随意跳过。

Autograd中的微分

让我们来看一下 autograd是如何收集梯度的。创建两个tensor ab,并且 requires_grad=True。这向 autograd 发出信号,跟踪在它们上执行的每一个操作。

import torch
a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)

ab 创建tensor Q

\[Q = 3a^2 - b^2
\]
Q = 3*a**2 - b**2

假设 ab 是一个神经网络的参数,Q 是误差。在NN训练中,求解关于参数的梯度,即:

\[\frac{\partial Q}{\partial a} = 9a^2
\]
\[\frac{\partial Q}{\partial b} = -2b
\]

当我们在 Q 上调用 .backward(),autograd计算以上梯度并保存在对应tensor的 .grad 属性中。

Q.backward() 是一个向量,因此我们需要在 Q.backward() 中显示地传递一个 gradient 参数。gradient 是一个和 Q相同形状的tensor,它表示Q关于其本身的梯度,即:

\[\frac{\partial Q}{\partial Q} = 1
\]

等效地,我们还可以将Q聚合为一个标量,并隐式的向后调用,如 Q.sum().backward()

external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)

梯度现在杯保存在 a.gradb.grad

## 检查收集的梯度是否正确
print(9*a**2 == a.grad)
print(-2*b == b.grad)

输出:

tensor([True, True])
tensor([Ture, True])

选读 - 使用 autograd 进行矢量微分

计算图

从概念上来说,autograd在一个由Function对象组成的有向无环图(DAG)中记录了数据(tensors)和所有执行的操作(连同由此产生的新tensors)。在DAG中,叶节点是输入tensors,根节点是输出tensors。通过从根节点到叶节点跟踪此图,你可以使用链式法则自动计算梯度。

在前向过程中,autograd同时进行两件事:

  • 执行请求的操作计算结果tensor,
  • 在DAG中保留操作的 gradient function

在DAG根节点处调用 .backward() 时启动反向过程。然后autograd

  • 由每个 .grad_fn计算梯度,
  • 将梯度累积在其对应tensor的 .grad 属性中,
  • 使用链式法则,将梯度一直传播到叶节点。

下图是以上例子中DAG的可视化表示。在该图中,箭头表示前向过程的方向。节点表示在前向过程中每一个操作的backward functions。蓝色叶节点表示我们的tensor ab

注意:DAGs在PyTorch中是动态的。需要重点注意的是:DAG是从头开始重新创建的,在每次 .backward调用时,autograd开始填充一个新图。这正是在模型中允许你使用控制流语句的原因。如果需要,你可以在每次迭代中更改形状、大小和操作。

从DAG中排除

torch.autograd 跟踪所有 requires_grad=True 的tensor上的操作。对于不要求计算梯度的tensor,requires_grad=False,并将其从梯度计算DAG中排除。

当一个操作就算只有一个输入tensor有 requires_grad=True,其输出的tensor仍然要计算梯度。

x = torch.rand(5, 5)
y = torch.rand(5, 5)
z = torch.rand((5, 5), requires_grad=True) a = x + y
print(f"Does 'a' require gradients? : {a.requires_grad}")
b = x + z
print(f"Does 'b' require gradients? : {b.requires_grad}")

输出:

Does `a` require gradients? : False
Does `b` require gradients?: True

在神经网络中,不计算梯度的参数通常成为冻结参数。如果你事先知道不需要这些参数的梯度,那冻结模型的一部分很有用(这通过减少autograd计算量提供了一些性能优势)。

从DAG中排除的另一个重要的常见用法是finetuning a pretrained network

在finetune中,我们冻结模型的大部分参数,并且通常只修改分类层以对新的标签作出预测。让我们通过一个小例子来演示这一点。像之前一样,我们加载一个预训练resnet18模型,并且冻结所有参数。

from torch import nn, optim

model = torchvision.models.resnet18(pretrained=True)

# 冻结网络中的所有参数
for param in model.parameters():
param.requires_grad = False

假设我们要在一个10标签数据集上微调模型。在resnet中,分类层是最后的线性层 model.fc。我们可以简单地用一个新的线性层(默认情况下未冻结)替换它作为我们的分类器。

model.fc = nn.Linear(512, 10)

模型中除了 model.fc 的所有参数均被冻结。需要计算梯度的参数仅仅是 model.fc 的权重和偏置

# 仅优化分类层
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

注意,尽管我们在优化器中注册了所有参数,但是计算梯度(在梯度下降中更新)的参数仅是分类层的权重和偏置。

The same exclusionary functionality is available as a context manager in torch.no_grad().

DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TORCH.AUTOGRAD的更多相关文章

  1. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TENSORS

    Tensor是一种特殊的数据结构,非常类似于数组和矩阵.在PyTorch中,我们使用tensor编码模型的输入和输出,以及模型的参数. Tensor类似于Numpy的数组,除了tensor可以在GPU ...

  2. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | NEURAL NETWORKS

    神经网络可以使用 torch.nn包构建. 现在你已经对autograd有所了解,nn依赖 autograd 定义模型并对其求微分.nn.Module 包括层,和一个返回 output 的方法 - f ...

  3. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TRAINING A CLASSIFIER

    你已经知道怎样定义神经网络,计算损失和更新网络权重.现在你可能会想, 那么,数据呢? 通常,当你需要解决有关图像.文本或音频数据的问题,你可以使用python标准库加载数据并转换为numpy arra ...

  4. Deep learning with PyTorch: A 60 minute blitz _note(1) Tensors

    Tensors 1. construst matrix 2. addition 3. slice from __future__ import print_function import torch ...

  5. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  6. Neural Network Programming - Deep Learning with PyTorch with deeplizard.

    PyTorch Prerequisites - Syllabus for Neural Network Programming Series PyTorch先决条件 - 神经网络编程系列教学大纲 每个 ...

  7. PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...

  8. Neural Network Programming - Deep Learning with PyTorch - YouTube

    百度云链接: 链接:https://pan.baidu.com/s/1xU-CxXGCvV6o5Sksryj3fA 提取码:gawn

  9. (zhuan) Where can I start with Deep Learning?

    Where can I start with Deep Learning? By Rotek Song, Deep Reinforcement Learning/Robotics/Computer V ...

随机推荐

  1. java 图形化工具Swing 监听键盘输入字符触发动作getInputMap();getActionMap();

    双缓冲技术的介绍: 所有的Swing组件默认启用双缓冲绘图技术.使用双缓冲技术能改进频繁重绘GUI组件的显示效果(避免闪烁现象)JComponent组件默认启用双缓冲,无须自己实现双缓冲.如果想关闭双 ...

  2. Photoshop学习笔记(一)

    1.Alt+delete,用前景色填充选区 2.按住shift键可以新加选区 3.按住alt键可以减去选区 4.第一次选择选区时按住shift键制作出正方形或者圆形 5.第一次选择选区时按住alt键将 ...

  3. 【LeetCode】913. Cat and Mouse 解题报告(Python)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 参考资料 日期 题目地址:https://leetc ...

  4. 使用Java对接永中格式转换

    永中格式转换服务基于永中DCS的文档转换能力,支持不同格式文件之间的高质量互转,可实现PDF文档与Word.Excel.PPT.图片的高质量互转,PDF文档转换完美保留原文档的版式,格式等,转换效果出 ...

  5. Sentry 开发者贡献指南 - SDK 开发(事件负载)

    内容整理自官方开发文档 系列 Docker Compose 部署与故障排除详解 1 分钟快速使用 Docker 上手最新版 Sentry-CLI - 创建版本 快速使用 Docker 上手 Sentr ...

  6. StringBoot

    1.首先我们需要依赖SpringBoot父工程,这是每个项目中必须要有的. <!--引入SpringBoot父依赖--><parent>        <groupId& ...

  7. Chapter 2 Randomized Experiments

    目录 概 2.1 Randomization 2.2 Conditional randomization 2.3 Standardization 2.4 Inverse probability wei ...

  8. v75.01 鸿蒙内核源码分析(远程登录篇) | 内核如何接待远方的客人 | 百篇博客分析OpenHarmony源码

    子曰:"不学礼,无以立 ; 不学诗,无以言 " <论语>:季氏篇 百篇博客分析.本篇为: (远程登录篇) | 内核如何接待远方的客人 设备驱动相关篇为: v67.03 ...

  9. 每天学一点——while循环(2)、for循环

    while循环(2) while+continue 打印数字的话相信朋友们在python中不会一个个的print吧 eg: 或者是打印列表里的元素 eg 这种方法只适用于你知道里面有多少个元素, 不然 ...

  10. Linux下设置普通用户使用sudo命令

    1.登录root用户 2.增加root用户对文件sudoers的写权限 chmod u+w /etc/sudoers 3.编辑sudoers,把用户mysql添加进去 vi /etc/sudoers ...