5 自动微分

求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

5.1 一个简单的例子

作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。

import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])

在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。

x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None

现在计算y。

y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)

x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。

y.backward()
x.grad
tensor([ 0.,  4.,  8., 12.])

函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。

x.grad == 4 * x
tensor([True, True, True, True])

现在计算x的另一个函数。

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])

5.2 非标量变量的反向传播

当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])

5.3 分离计算

有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))

由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])

5.4 Python控制流的梯度计算

使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c

计算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。

a.grad == d / a
tensor(True)

【pytorch学习】之自动微分的更多相关文章

  1. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  2. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  3. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  4. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  5. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  6. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  7. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  8. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  9. PyTorch 学习

    PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...

  10. MindSpore:自动微分

    MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...

随机推荐

  1. Win10使用Dism++离线安装.Net3.5

    .Net3.5的安装包在Win10已经不能使用了,在线安装.Net3.5会很卡(跟网络无关),最好是使用Dism++提取Win10系统镜像文件离线安装. 打开Dism++软件,按照如下步骤操作: 选择 ...

  2. 在Linux平台使用wps卡顿现象解决方法

    是因为软件占据的内存过多,需要关掉目前不使用的软件,以释放系统内存.

  3. max30100心率血氧健康传感器调试总结备忘

    前记  在健康监测领域,心率血氧传感器是一个非常重要的前端采集设备.了解,研究并使用它,是一个方案商的基本素质.鉴于此,笔者花了一些时间在不同的硬件平台来使用它.中间遇到了一些问题值得总结和反思一下. ...

  4. Android 8.0 Only fullscreen activities can request orientation解决方法

    原文:Android 8.0 Only fullscreen activities can request orientation解决方法 | Stars-One的杂货小窝 公司的项目坑太多,现在适配 ...

  5. ubuntu重启网卡

    1.关闭接口:sudo ifconfig eth0 down 2.然后打开:sudo ifconfig eth0 up

  6. WINDOWS.H already included. MFC apps must not #include Windows.h

    做C++.C#和C++/CLI的混合编程有一段时间了,填了不少的坑. 今天又遇到一错误,想着挺容易就解决,估计是大脑疲惫,折腾许久才找到原因. 错误: 错误 C1189 #error: WINDOWS ...

  7. AOSP编译成功后关闭终端emulator命令找不到

    当我们编译好AOSP系统源码后,可以通过emulator命令打开模拟器,但是当我们关闭终端后,在次打开终端输入emulator命令,提示未找到命令: 此时我们需要重新执行下面语句 source bui ...

  8. 访问Webapp目录下面的html文件变为代码

    一.问题由来 一位朋友在学习使用Servlet做练习的时候,突然出现一个问题,他去访问自己创建的html文件时,发现返回的数据是html代码,而不是解析后的页面. 很是疑惑,自己尝试着解决这个问题,很 ...

  9. 记录--你还在使用websocket实现实时消息推送吗?

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前言 在日常的开发中,我们经常能碰见服务端需要主动推送给客户端数据的业务场景,比如数据大屏的实时数据,比如消息中心的未读消息,比如聊天功能 ...

  10. 记录-vue项目中使用PWA

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前言: 梳理了一下项目中的PWA的相关用法,下面我会正对vue2和vue3的用法进行一些教程示例,引入离线缓存机制,即使你断网,也能访问页 ...