本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/computational_graph.py

计算图

深度学习就是对张量进行一系列的操作,随着操作种类和数量的增多,会出现各种值得思考的问题。比如多个操作之间是否可以并行,如何协同底层的不同设备,如何避免冗余的操作,以实现最高效的计算效率,同时避免一些 bug。因此产生了计算图 (Computational Graph)。

计算图是用来描述运算的有向无环图,有两个主要元素:节点 (Node) 和边 (Edge)。节点表示数据,如向量、矩阵、张量。边表示运算,如加减乘除卷积等。

用计算图表示:$y=(x+w)*(w+1)$,如下所示:

可以看作, $y=a \times b$ ,其中 $a=x+w$,$b=w+1$。

计算图与梯度求导

这里求 $y$ 对 $w$ 的导数。根复合函数的求导法则,可以得到如下过程。

$\begin{aligned} \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \ &=b * 1+a * 1 \ &=b+a \ &=(w+1)+(x+w) \ &=2 * w+x+1 \ &=2 * 1+2+1=5\end{aligned}$

体现到计算图中,就是根节点 $y$ 到叶子节点 $w$ 有两条路径 y -> a -> wy ->b -> w。根节点依次对每条路径的孩子节点求导,一直到叶子节点w,最后把每条路径的导数相加即可。

代码如下:

import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
# y=(x+w)*(w+1)
a = torch.add(w, x) # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)
# y 求导
y.backward()
# 打印 w 的梯度,就是 y 对 w 的导数
print(w.grad)

结果为tensor([5.])

我们回顾前面说过的 Tensor 中有一个属性is_leaf标记是否为叶子节点。

在上面的例子中,$x$ 和 $w$ 是叶子节点,其他所有节点都依赖于叶子节点。叶子节点的概念主要是为了节省内存,在计算图中的一轮反向传播结束之后,非叶子节点的梯度是会被释放的。

代码示例:

# 查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf) # 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

结果为:

is_leaf:
True True False False False
gradient:
tensor([5.]) tensor([2.]) None None None

非叶子节点的梯度为空,如果在反向传播结束之后仍然需要保留非叶子节点的梯度,可以对节点使用retain_grad()方法。

而 Tensor 中的 grad_fn 属性记录的是创建该张量时所用的方法 (函数)。而在反向传播求导梯度时需要用到该属性。

示例代码:

# 查看梯度
print("w.grad_fn = ", w.grad_fn)
print("x.grad_fn = ", x.grad_fn)
print("a.grad_fn = ", a.grad_fn)
print("b.grad_fn = ", b.grad_fn)
print("y.grad_fn = ", y.grad_fn)

结果为

w.grad_fn =  None
x.grad_fn = None
a.grad_fn = <AddBackward0 object at 0x000001D8DDD20588>
b.grad_fn = <AddBackward0 object at 0x000001D8DDD20588>
y.grad_fn = <MulBackward0 object at 0x000001D8DDD20588>

PyTorch 的动态图机制

PyTorch 采用的是动态图机制 (Dynamic Computational Graph),而 Tensorflow 采用的是静态图机制 (Static Computational Graph)。

动态图是运算和搭建同时进行,也就是可以先计算前面的节点的值,再根据这些值搭建后面的计算图。优点是灵活,易调节,易调试。PyTorch 里的很多写法跟其他 Python 库的代码的使用方法是完全一致的,没有任何额外的学习成本。

静态图是先搭建图,然后再输入数据进行运算。优点是高效,因为静态计算是通过先定义后运行的方式,之后再次运行的时候就不再需要重新构建计算图,所以速度会比动态图更快。但是不灵活。TensorFlow 每次运行的时候图都是一样的,是不能够改变的,所以不能直接使用 Python 的 while 循环语句,需要使用辅助函数 tf.while_loop 写成 TensorFlow 内部的形式。

参考资料

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学

[PyTorch 学习笔记] 1.4 计算图与动态图机制的更多相关文章

  1. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

  2. [PyTorch 学习笔记] 1.1 PyTorch 简介与安装

    PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...

  3. WPF学习笔记-用Expression Design制作矢量图然后导出为XAML

    WPF学习笔记-用Expression Design制作矢量图然后导出为XAML 第一次用Windows live writer写东西,感觉不错,哈哈~~ 1.在白纸上完全凭感觉,想象来画图难度很大, ...

  4. Introduction to 3D Game Programming with DirectX 12 学习笔记之 --- 第九章:贴图

    原文:Introduction to 3D Game Programming with DirectX 12 学习笔记之 --- 第九章:贴图 代码工程地址: https://github.com/j ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. [PyTorch 学习笔记] 5.2 Hook 函数与 CAM 算法

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py https://gi ...

  7. Nodejs学习笔记(三)——一张图看懂Nodejs建站

    前言:一条线,竖着放,如果做不到精进至深,那就旋转90°,至少也图个幅度宽广. 通俗解释上面的胡言乱语:还没学会爬,就学起走了?! 继上篇<Nodejs学习笔记(二)——Eclipse中运行调试 ...

  8. Spring学习(二)—— java的动态代理机制

    在学习Spring的时候,我们知道Spring主要有两大思想,一个是IoC,另一个就是AOP,对于IoC,依赖注入就不用多说了,而对于Spring的核心AOP来说,我们不但要知道怎么通过AOP来满足的 ...

  9. PyTorch学习笔记之计算图

    1. **args, **kwargs的区别 def build_vocab(self, *args, **kwargs): counter = Counter() sources = [] for ...

随机推荐

  1. .NET Core + K8S + Loki 玩转日志聚合

    1. Intro 最近在了解日志聚合系统,正好前几天看到一篇文章<用了日志系统新贵Loki,ELK突然不香了!>,所以就决定动手体验一下.本文就带大家快速了解下Loki,并简单介绍.NET ...

  2. 动态页面技术(JSP)

    JSP技术 jsp脚本和注释 jsp脚本: 1)<%java代码%> ----- 内部的java代码翻译到service方法的内部 2)<%=java变量或表达式> ----- ...

  3. TCP 服务器端

    """ 建立tcp服务器 绑定本地服务器信息(ip地址,端口号) 进行监听 获取监听数据(监听到的客户端和地址) 使用监听到的客户端client_socket获取数据 输 ...

  4. mit-6.828 Lab Tools

    Lab Tools 目录 Lab Tools 写在前面 GDB GNU GPL (通用公共许可证) QEMU ELF 可执行文件的格式 Verbose mode Makefile 写在前面 操作系统小 ...

  5. win10下visual studio code安装及mingw C/C++编译器配置,launch.json和task.json文件的配置

    快一年了,我竟然还有脸回来..... 过去一年,由于毕设.找工作的原因,发生太多变故,所以一直没更(最主要的原因还是毅力不够...),至于发生了什么事,以后想说的时候再更吧..依然是小白,下面说正事. ...

  6. 一文学会MySQL的explain工具

    开篇说明 (1) 本文将细致介绍MySQL的explain工具,是下一篇<一文读懂MySQL的索引机制及查询优化>的准备篇. (2) 本文主要基于MySQL5.7版本(https://de ...

  7. 【AHOI2009】同类分布 题解(数位DP)

    题目大意:求$[l,r]$中各位数之和能被该数整除的数的个数.$0\leq l\leq r\leq 10^{18}$. ------------------------ 显然数位DP. 搜索时记录$p ...

  8. TF签名是什么?比企业签名好在哪里?

      现在苹果企业签名的服务大致分为三类,苹果企业签名.超级签名和TF签名,而TF签名TF签名又称 TestFlight 签名,是目前最稳定的签名方式. ​   「优势」   关键词:零风险;限制少;安 ...

  9. 分布式任务调度平台 → XXL-JOB 实战

    开心一刻 老师:谁知道鞭炮用英语怎么说? 甲:老师!老师!我知道,鞭炮的英文是pilipala. 老师:那闪电呢? 乙:kucha kucha 老师:那舞狮呢? 丙:dong dong qiang 老 ...

  10. Linux系统之《消息队列》入手应用

    目录 简述 代码 编译 运行 简述 消息队列是Linux进程间通信方式之一,消息队列一般是用于简单的通信,数据量不大,通信不频繁的情况.如果交互频繁或者数据量大就不适合了. 代码 下面直接上代码,发送 ...