PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD
训练神经网络时,最常用的算法就是反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。
为了计算这些梯度,PyTorch内置了名为 torch.autograd
的微分引擎。它支持任意计算图的自动梯度计算。
一个最简单的单层神经网络,输入 x
,参数 w
和 b
,某个损失函数。它可以用PyTorch这样定义:
import torch
x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b # 矩阵乘法
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
Tensors、Functions and Computational graph
上述代码定义了下面的computational graph:
在该网络中,w
和 b
是parameters,是我们需要优化的。因此,我们需要能够计算损失函数关于这些变量的梯度。因此,我们设置了这些tensor的 requires_grad
属性。
注意:在创建tensor时可以设置 requires_grad
的值,或者创建之后使用 x.requires_grad_(True)
方法。
我们应用到tensor上构成计算图的function实际上是 Function
类的对象。该对象知道如何计算前向的函数,还有怎么计算反向传播步骤中函数的导数。反向传播函数存储在tensor的 grad_fn
属性中。You can find more information of Function
in the documentation。
print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)
输出:
Gradient function for z = <AddBackward0 object at 0x7faea5ef7e10>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7faea5ef7e10>
计算梯度
为了优化神经网络的参数权重,我们需要计算损失函数关于参数的导数,即,我们需要利用一些固定的 x
和 y
计算\(\frac{\partial loss}{\partial w}\)和\(\frac{\partial loss}{\partial b}\)。为计算这些导数,可以调用 loss.backward()
,然后从 w.grad
和 b.grad
:
loss.backward()
print(w.grad)
print(b.grad)
输出:
tensor([[0.0043, 0.2572, 0.3275],
[0.0043, 0.2572, 0.3275],
[0.0043, 0.2572, 0.3275],
[0.0043, 0.2572, 0.3275],
[0.0043, 0.2572, 0.3275]])
tensor([0.0043, 0.2572, 0.3275])
注意:
- 我们只能在计算图中
requires_grad=True
的叶节点获得grad
属性。对于其它节点,梯度是无效的。 - 出于性能原因,我们只能对给定的graph使用
backward
执行梯度计算。如果需要在同一graph调用若干次backward
,在调用时,需要传入retain_graph=True
。
禁用梯度跟踪
默认情况下,所有 requires_grad=True
的tensor都会跟踪它们的计算历史,并支持梯度计算。但是在一些情况下并不需要,例如,当我们已经训练了一个模型,并将其用在一些输入数据上,即,仅仅经过网络做前向运算。那么可以在我们的计算代码外包围 torch.no_grad()
块停止跟踪计算。
z = torch.matmul(x, w) + b
print(z.requires_grad())
with torch.no_grad():
z = torch.matmul(x, w) + b
print(z.requires_grad)
输出:
True
False
在tensor上使用 detach()
也能达到同样的效果
z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)
输出:
False
禁止梯度跟踪的几个原因:
- 将神经网络的一些参数标记为frozen parameters。这在finetuning a pretrained network中是非常常见的脚本。
- 当你只做前向过程,用于speed up computations,因为tensor计算而不跟踪梯度将会更有效。
More on Coputational Graphs
概念上,autograd在一个由Function对象组成的有向无环图(DAG)中保留了数据(tensors)记录,还有所有执行的操作(以及由此产生的新的tensors)。在DAG中,叶节点是输入tensor,根节点是输出tensors。通过从根到叶跟踪该图,可以使用链式法则自动地计算梯度。
在前向过程中,autograd同时进行两件事:
- 运行请求的操作计算结果tensor
- 在DAG中保存操作的梯度函数
当在DAG根部调用 .backward()
时,后向过程就会开始。autograd
会:
- 由每一个
.grad_fn
计算梯度。 - 在对应tensor的 '.grad' 属性累积梯度
- 使用链式法则,一直传播到叶tensor
注意:DAGs在PyTorch是动态的,需要注意的一点是,graph是从头开始创建的;在每次调用 .backward()
之后,autograd开始生成一个新的graph。这允许你在模型中使用控制流语句;如果需要,你可以在每次迭代中改变shape,size,and operations。
选读:Tensor梯度和Jacobian Products
延伸阅读
PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD的更多相关文章
- DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TORCH.AUTOGRAD
torch.autograd 是PyTorch的自动微分引擎,用以推动神经网络训练.在本节,你将会对autograd如何帮助神经网络训练的概念有所理解. 背景 神经网络(NNs)是在输入数据上执行的嵌 ...
- pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分
参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...
- Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析
backward函数 官方定义: torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph ...
- (转)自动微分(Automatic Differentiation)简介——tensorflow核心原理
现代深度学习系统中(比如MXNet, TensorFlow等)都用到了一种技术——自动微分.在此之前,机器学习社区中很少发挥这个利器,一般都是用Backpropagation进行梯度求解,然后进行SG ...
- [深度学习] Pytorch学习(一)—— torch tensor
[深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...
- PyTorch Tutorials 2 AUTOGRAD: AUTOMATIC DIFFERENTIATION
%matplotlib inline Autograd: 自动求导机制 PyTorch 中所有神经网络的核心是 autograd 包. 我们先简单介绍一下这个包,然后训练第一个简单的神经网络. aut ...
- PyTorch源码解读之torch.utils.data.DataLoader(转)
原文链接 https://blog.csdn.net/u014380165/article/details/79058479 写得特别好!最近正好在学习pytorch,学习一下! PyTorch中数据 ...
- PyTorch 介绍 | DATSETS & DATALOADERS
用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
随机推荐
- leetcode日记本
写在前面: 2019.6开始经过一年的学习,我依然没有学会算法,依然停留在最基本的阶段,面对题目依然一头雾水 但是难不是放弃的理由,根据毛主席的论持久战原理,我决定一天看一点循序渐进,相信总有一天可以 ...
- Ubuntu mininet+Ryu环境安装
我们使用下载Ryu源代码进行那个安装 Ryu官方文档:http://ryu.readthedocs.io/en/latest/ Ryu电子书:http://osrg.github.io/ryu/res ...
- 网站迁移纪实:从Web Form 到 Asp.Net Core (Abp vNext 自定义开发)
问题和需求 从2004年上线,ZLDNN.COM运行已经超过16年了,一直使用DotNetNuke平台(现在叫DNN Platform),从最初的DotNetNuke 2.1到现在使用的7.4.先是在 ...
- 离线版centos8环境部署迁移监控操作笔记
嗨咯,前两天总结记录了离线版centos8下docker的部署笔记,今天正好是2021年的最后一天,今天正好坐在本次出差回家的列车上,车上没有上面事做,索性不如把本次离线版centos8环境安装的其他 ...
- MySQL数据库基础(1)数据库基础
目录 一.数据库简介 二.mysql数据库 三.客户端连接mysql服务 四.Navicat for mysql 一.数据库简介 1.概念 (1)数据:如文字.图形.图像.声音以及学生的档案记录等,这 ...
- MySQL数据操作与查询笔记 • 【第3章 DDL 和 DML】
全部章节 >>>> 本章目录 3.1 使用 DDL 定义数据库表结构 3.1.1 SQL 简介 3.1.2 维护数据库和创建数据表 3.2 使用 DDL 维护数据库表结构 ...
- js 关于 data.xuNum = xuNum++; 赋值写法 的探讨
1 .源码 let xuNum = 0; let data = []; data.xuNum = xuNum++; console.log(data.xuNum) 2.打印结果 // 0 3.原因 ...
- websocket 使用 spring 的service层 ,进而调用里面的 dao层 来操作数据库 ,包括redis、mysql等通用
1.前言 描述一下今天用websocket踩得坑 --->空指针异常! 我想在websocket里面使用service 层的接口,从中获取数据库的一些信息 , 使用 @Autowired 注 ...
- GORM学习指南
orm是一个使用Go语言编写的ORM框架.它文档齐全,对开发者友好,支持主流数据库. 一.初识Gorm Github GORM 中文官方网站内含十分齐全的中文文档,有了它你甚至不需要再继续向下阅读本文 ...
- SpringBoot学习笔记五之管理员后台维护
注:图片如果损坏,点击文章链接:https://www.toutiao.com/i6803544440112677379/ 首先完成分页 引入PageHelper(之前已经添加过了) 在spring- ...