5 自动微分

求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

5.1 一个简单的例子

作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。

import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])

在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。

x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None

现在计算y。

y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)

x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。

y.backward()
x.grad
tensor([ 0.,  4.,  8., 12.])

函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。

x.grad == 4 * x
tensor([True, True, True, True])

现在计算x的另一个函数。

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])

5.2 非标量变量的反向传播

当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])

5.3 分离计算

有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))

由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])

5.4 Python控制流的梯度计算

使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c

计算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。

a.grad == d / a
tensor(True)

【pytorch学习】之自动微分的更多相关文章

  1. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  2. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  3. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  4. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  5. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  6. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  7. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  8. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  9. PyTorch 学习

    PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...

  10. MindSpore:自动微分

    MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...

随机推荐

  1. archlinux 使用ventoyU盘启动器(ISO)

    ventoy详细介绍https://www.ventoy.net/cn/doc_start.html Linux系统安装 Ventoy -- 命令行界面 下载安装包,例如 ventoy-1.0.00- ...

  2. 基于ADS1292芯片的解决方案之芯片简析

    基本资料: ADS1292芯片是多通道同步采样 24 位 Δ-Σ 模数转换器 (ADC),它们具有内置的可编程增益放大器 (PGA).内部基准和板载振荡器. ADS1292 包含 便携式 低功耗医疗心 ...

  3. springMVC+JDBC:分页示例

    文章来源:http://liuzidong.iteye.com/blog/1067492 一 环境:XP3+Oracle10g+MyEclipse6+(Tomcat)+JDK1.5 二 工程相关图片: ...

  4. eclipse错误之Errors occurred during the build. Errors running builder 'JavaScript Validator' on project

    把JavaScript Validator去掉.去掉的方法是:选择一个项目--右键Properties--Builders(排第二)--点一下右侧会有四项--取消第一项"JavaScript ...

  5. Android设备上运行live555的推流程序

    在live555使用NDK21编译出arm64-v8a和armeabi-v7a中我们编译出了v8a和v7a的可执行文件 我们可以使用testH264VideoStreamer程序进行推流 我们将tes ...

  6. C++ 调用 Python 总结(一)

    PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明   本文作为本人csdn blog的主站的备份.(Bl ...

  7. .NET周刊【3月第2期 2024-03-17】

    国内文章 开源.NET8.0小项目伪微服务框架(分布式.EFCore.Redis.RabbitMQ.Mysql等) https://www.cnblogs.com/aehyok/p/18058032 ...

  8. Python 合并Excel文件(Excel文件多sheet)

    一.Python合并Excel文件多sheet<方法1> import os import pandas as pd # 指定包含Excel文件的文件夹路径 folder_path = ' ...

  9. 【K8S】Docker向私有仓库拉取/推送镜像报错(http: server gave HTTP response to HTTPS client)

    这里,我们搭建的Harbor仓库的地址为 http://192.168.175.101:1180. 报错信息如下所示. [root@binghe101 ~]# docker login 192.168 ...

  10. KingbaseES 中的xmin,xmax等系统字段说明

    在KingbaseES中,当我们创建一个数据表时,数据库会隐式增加几个系统字段.这些字段由系统进行维护,用户一般不会感知它们的存在. 例如,以下语句创建了一个简单的表: create table te ...