【小白学PyTorch】20 TF2的eager模式与求导
【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.
参考目录:
之前讲解了如何构建数据集,如何创建TFREC文件,如何构建模型,如何存储模型。这一篇文章主要讲解,TF2中提出的一个eager模式,这个模式大大简化了TF的复杂程度。
1 什么是eager模式
Eager模式(积极模式),我认为是TensorFlow2.0最大的更新,没有之一。
Tensorflow1.0的时候还是静态计算图,在《小白学PyTorch》系列的第一篇内容,就讲解了Tensorflow的静态特征图和PyTorch的动态特征图的区别。Tensorflow2.0提出了eager模式,在这个模式下,也支持了动态特征图的构建
不得不说,改的和PyTorch越来越像了,但是人类的工具总是向着简单易用的方向发展,这肯定是无可厚非的。
2 TF1.0 vs TF2.0
TF1.0中加入要计算梯度,是只能构建静态计算图的。
- 是先构建计算流程;
- 然后开始起一个会话对象;
- 把数据放到这个静态的数据图中。
整个流程非常的繁琐。
# 这个是tensorflow1.0的代码
import tensorflow as tf
a = tf.constant(3.0)
b = tf.placeholder(dtype = tf.float32)
c = tf.add(a,b)
sess = tf.Session() #创建会话对象
init = tf.global_variables_ini tializer()
sess.run(init) #初始化会话对象
feed = {
b: 2.0
} #对变量b赋值
c_res = sess.run(c, feed) #通过会话驱动计算图获取计算结果
print(c_res)
代码中,我们需要用palceholder先开辟一个内存空间,然后构建好静态计算图后,在把数据赋值到这个被开辟的内存中,然后再运行整个计算流程。
下面我们来看在eager模式下运行上面的代码
import tensorflow as tf
a = tf.Variable(2)
b = tf.Variable(20)
c = a + b
没错,这样的话,就已经完成一个动态计算图的构建,TF2是默认开启eager模式的,所以不需要要额外的设置了。这样的构建方法,和PyTorch是非常类似的。
3 获取导数/梯度
假如我们使用的是PyTorch,那么我们如何得到\(w\times x + b\)的导数呢?
import torch
# Create tensors.
x = torch.tensor(10., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
# Build a computational graph.
y = w * x + b # y = 2 * x + 3
# Compute gradients.
y.backward()
# Print out the gradients.
print(x.grad) # tensor(2.)
print(w.grad) # tensor(10.)
print(b.grad) # tensor(1.)
都没问题吧,下面用Tensorflow2.0来重写一下上面的内容:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape() as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
print(dz_dw)
>>> tf.Tensor(10.0, shape=(), dtype=float32)
我们需要注意这几点:
- 首先结果来看,没问题,w的梯度就是10;
- 对于参与计算梯度、也就是参与梯度下降的变量,是需要用
tf.Varaible来定义的; - 不管是变量还是输入数据,都要求是浮点数float,如果是整数的话会报错,并且梯度计算输出None;
- tensorflow提供tf.GradientTape来实现自动求导,所以在tf.GradientTape内进行的操作,都会记录在tape当中,这个就是tape的概念。一个摄影带,把计算的过程录下来,然后进行求导操作
现在我们不仅要输出w的梯度,还要输出b的梯度,我们把上面的代码改成:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape() as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
dz_db = tape.gradient(z,b)
print(dz_dw)
print(dz_db)
运行结果为:
这个错误翻译过来就是一个non-persistent的录像带,只能被要求计算一次梯度。 我们用tape计算了w的梯度,然后这个tape清空了数据,所有我们不能再计算b的梯度。
解决方法也很简单,我们只要设置这个tape是persistent就行了:
import tensorflow as tf
x = tf.convert_to_tensor(10.)
w = tf.Variable(2.)
b = tf.Variable(3.)
with tf.GradientTape(persistent=True) as tape:
z = w * x + b
dz_dw = tape.gradient(z,w)
dz_db = tape.gradient(z,b)
print(dz_dw)
print(dz_db)
运行结果为:
4 获取高阶导数
import tensorflow as tf
x = tf.Variable(1.0)
with tf.GradientTape() as t1:
with tf.GradientTape() as t2:
y = x * x * x
dy_dx = t2.gradient(y, x)
print(dy_dx)
d2y_d2x = t1.gradient(dy_dx, x)
print(d2y_d2x)
>>> tf.Tensor(3.0, shape=(), dtype=float32)
>>> tf.Tensor(6.0, shape=(), dtype=float32)
想要得到二阶导数,就要使用两个tape,然后对一阶导数再求导就行了。
【小白学PyTorch】20 TF2的eager模式与求导的更多相关文章
- 小白学PyTorch 动态图与静态图的浅显理解
文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...
- 【小白学PyTorch】1 搭建一个超简单的网络
文章目录: 目录 1 任务 2 实现思路 3 实现过程 3.1 引入必要库 3.2 创建训练集 3.3 搭建网络 3.4 设置优化器 3.5 训练网络 3.6 测试 1 任务 首先说下我们要搭建的网络 ...
- PyTorch官方中文文档:自动求导机制
自动求导机制 本说明将概述Autograd如何工作并记录操作.了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序,并可帮助您进行调试. 从后向中排除子图 每个变量都有 ...
- 【小白学PyTorch】15 TF2实现一个简单的服装分类任务
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】18 TF2构建自定义模型
[机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...
- 【小白学PyTorch】16 TF2读取图片的方法
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...
- 【小白学PyTorch】19 TF2模型的存储与载入
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
- 【小白学PyTorch】8 实战之MNIST小试牛刀
文章来自微信公众号[机器学习炼丹术].有什么问题都可以咨询作者WX:cyx645016617.想交个朋友占一个好友位也是可以的~好友位快满了不过. 参考目录: 目录 1 探索性数据分析 1.1 数据集 ...
随机推荐
- Java开发之javaEE(java2EE)的介绍,java软件工程师初步阶段知识
1. 为什么需要JavaEE 我们编写的JSP代码中,由于大量的显示代码和业务逻辑混淆在一起,彼此嵌套,不利于程序的维护和扩展.当业务需求发生变化的时候,对于程序员和美工都是一个很重的负担. 为了程序 ...
- RabbitMQ 3.8.7 在 centos7 上安装
1.安装 erlang 因为 RabbitMQ 是 erlang 语言开发,所以需要依赖 erlang 环境,所以在安装 RabbitMQ 前需要先安装 erlang wget https://pac ...
- Ubuntu 20.04 手动安装 sublime_text 并建立搜索栏图标(解决 Ubuntu 20.04 桌面图标无法双击打开问题)
下载sublime_text_3离线程序包 wget https://download.sublimetext.com/sublime_text_3_build_3211_x64.tar.bz2 #x ...
- Spring security OAuth2.0认证授权学习第二天(基础概念-RBAC)
RBAC 基于角色的访问控制 基于角色的访问控制用代码实现一下其实就是一个if的问题if(如果有角色1){ } 如果某个角色可以访问某个功能,当某一天其他的另一个角色也可以访问了,那么代码就需要变化, ...
- JVM学习第三天(JVM的执行子系统)之类加载机制补充
昨晚没看完,今天继续 系统的类加载器 对于任意一个类,都需要由加载它的类加载器和这个类本身一同确立其在Java虚拟机中的唯一性,每一个类加载器,都拥有一个独立的类名称空间.这句话可以表达得更通俗一些: ...
- Solon详解(二)- Solon的核心
Solon详解系列文章: Solon详解(一)- 快速入门 Solon详解(二)- Solon的核心 Solon详解(三)- Solon的web开发 Solon详解(四)- Solon的事务传播机制 ...
- rabbitmq集成和实战
与 Spring 集成 pom 文件 使用 Maven,这里使用的 4.3.11,所以这里引入的是 rabbit 是 2.0.0,如果兼容性的话请自行去 Spring 的官网上去查 这里补充一下,sp ...
- 洛谷P1712 [NOI2016]区间 尺取法+线段树+离散化
洛谷P1712 [NOI2016]区间 noi2016第一题(大概是签到题吧,可我还是不会) 链接在这里 题面可以看链接: 先看题意 这么大的l,r,先来个离散化 很容易,我们可以想到一个结论 假设一 ...
- Java的foreach用法
foreach其实就是for的加强版,其语法如下: for(元素类型type 元素变量value : 遍历对象obj) { 引用x的java语句; } 举个例子,比如定义一个数组,使用foreach以 ...
- 别再眼高手低了! 这些Linq方法都清楚地掌握了吗?
不要再眼高手低了,这些Enumerable之常见Linq扩展方法都清楚掌握了吗?其实这是对我自己来说的! 例如:一个人这个技术掌握了一点那个技术也懂一点,其他的好像也了解一些,感觉自己啥都会一点,又觉 ...