5 自动微分

求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

5.1 一个简单的例子

作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。

import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])

在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。

x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None

现在计算y。

y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)

x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。

y.backward()
x.grad
tensor([ 0.,  4.,  8., 12.])

函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。

x.grad == 4 * x
tensor([True, True, True, True])

现在计算x的另一个函数。

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])

5.2 非标量变量的反向传播

当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])

5.3 分离计算

有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))

由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])

5.4 Python控制流的梯度计算

使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c

计算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。

a.grad == d / a
tensor(True)

【pytorch学习】之自动微分的更多相关文章

  1. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  2. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  3. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  4. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  5. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  6. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  7. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  8. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  9. PyTorch 学习

    PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...

  10. MindSpore:自动微分

    MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...

随机推荐

  1. 个人呕心沥血编写的全网最详细的kettle教程书籍

    笔者呕心沥血编写的kettle教程,涉及到kettle的每个控件的讲解和详细的实战示例 可以说是全网最详细的kettle教程,三天学完你就可以成为优秀的ETL专家!!! 现在免费分享出来!视频教程也已 ...

  2. Android Progressbar进度条样式调整为圆角矩形,且改变颜色

    原文地址: Android Progressbar进度条样式调整为圆角矩形,且改变颜色 美工设计的进度条是圆角矩形的,与Android默认的样式有所区别,可以通过样式progressDrawable属 ...

  3. Grails的数据库相关开发

    1.开发domain和service 在出来的输入框里输入domain的名字,可以包括包名. 这里我输入test.domain.House,点finish 创建了两个groovy文件,一个当然是tes ...

  4. 云VR的未来发展方向

    随着元宇宙元年的到来,VR正呈现出蓬勃的发展势头.然而,更好的用户体验大多依赖于高性能PC或主机进行本地渲染,这使得用户的VR消费成本更高,在一定程度上影响了产业发展,成为业界亟待解决的问题. 的确, ...

  5. Python简单程序设计(计算程序设计(存款利息)篇)

    如题: 解题方式如下:

  6. 记录--手写$forceUpdate,vm.$destroy方法

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 vm.$forceUpdate (1)作用 迫使Vue.js实例重新渲染.注意它仅仅影响实例本身以及插入插槽内容的子组件,而不是所有子组件 ...

  7. 《Go程序设计语言》学习笔记之map

    <Go程序设计语言>学习笔记之map 一. 环境 Centos8.5, go1.17.5 linux/amd64 二. 概念 1) 散列表,是一个拥有键值对元素的无序集合.在这个集合中,键 ...

  8. 超详细的彻底卸载VMware虚拟机方法

    一.在卸载VMware虚拟机之前,要先把与VMware相关的服务和进程终止 1.在windows中按下[Windows键],搜索[服务]设置,然后打开: 2.找到以VM打头命名的服务,然后右键停止这些 ...

  9. 【K8S】Kubernetes中暴露外部IP地址来访问集群中的应用

    本文是Kubernetes.io官方文档中介绍如何创建暴露外部IP地址的Kubernetes Service 对象. 学习目标 运行Hello World应用程序的五个实例. 创建一个暴露外部IP地址 ...

  10. ChatGPT写作提示词指令大全

    1 .用ChatGPT写影评 指令:你是一个自媒体人,同时也是一个专业的影评人.最近熬夜看完了韩剧黑暗荣耀第一季和第二季,忍不住想在公众号分享给粉丝们,请写一篇1000字左右的自媒体文章,并且加上一个 ...