神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情。

而深度学习框架可以帮助我们自动地完成这种求梯度运算。

Tensorflow一般使用梯度磁带tf.GradientTape来记录正向运算过程,然后反播磁带自动得到梯度值。

这种利用tf.GradientTape求微分的方法叫做Tensorflow的自动微分机制。

一,利用梯度磁带求导数

import tensorflow as tf
import numpy as np # f(x) = a*x**2 + b*x + c的导数 x = tf.Variable(0.0,name = "x",dtype = tf.float32)
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0) with tf.GradientTape() as tape:
y = a*tf.pow(x,2) + b*x + c dy_dx = tape.gradient(y,x)
print(dy_dx)

tf.Tensor(-2.0, shape=(), dtype=float32)

# 对常量张量也可以求导,需要增加watch

with tf.GradientTape() as tape:
tape.watch([a,b,c])
y = a*tf.pow(x,2) + b*x + c dy_dx,dy_da,dy_db,dy_dc = tape.gradient(y,[x,a,b,c])
print(dy_da)
print(dy_dc)

tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)

# 可以求二阶导数
with tf.GradientTape() as tape2:
with tf.GradientTape() as tape1:
y = a*tf.pow(x,2) + b*x + c
dy_dx = tape1.gradient(y,x)
dy2_dx2 = tape2.gradient(dy_dx,x)
tf.Tensor(2.0, shape=(), dtype=float32)
# 可以在autograph中使用

@tf.function
def f(x):
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0) # 自变量转换成tf.float32
x = tf.cast(x,tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
y = a*tf.pow(x,2)+b*x+c
dy_dx = tape.gradient(y,x) return((dy_dx,y)) tf.print(f(tf.constant(0.0)))
tf.print(f(tf.constant(1.0)))
(-2, 1)
(0, 0)

二,利用梯度磁带和优化器求最小值

# 求f(x) = a*x**2 + b*x + c的最小值
# 使用optimizer.apply_gradients x = tf.Variable(0.0,name = "x",dtype = tf.float32)
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0) optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
for _ in range(1000):
with tf.GradientTape() as tape:
y = a*tf.pow(x,2) + b*x + c
dy_dx = tape.gradient(y,x)
optimizer.apply_gradients(grads_and_vars=[(dy_dx,x)]) tf.print("y =",y,"; x =",x)

y = 0 ; x = 0.999998569

# 求f(x) = a*x**2 + b*x + c的最小值
# 使用optimizer.minimize
# optimizer.minimize相当于先用tape求gradient,再apply_gradient x = tf.Variable(0.0,name = "x",dtype = tf.float32) # 注意f()无参数
def f():
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0)
y = a*tf.pow(x,2)+b*x+c
return(y) optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
for _ in range(1000):
optimizer.minimize(f,[x]) tf.print("y =",f(),"; x =",x)
y = 0 ; x = 0.999998569
# 在autograph中完成最小值求解
# 使用optimizer.apply_gradients x = tf.Variable(0.0,name = "x",dtype = tf.float32)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) @tf.function
def minimizef():
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0) for _ in tf.range(1000): #注意autograph时使用tf.range(1000)而不是range(1000)
with tf.GradientTape() as tape:
y = a*tf.pow(x,2) + b*x + c
dy_dx = tape.gradient(y,x)
optimizer.apply_gradients(grads_and_vars=[(dy_dx,x)]) y = a*tf.pow(x,2) + b*x + c
return y tf.print(minimizef())
tf.print(x)
0
0.999998569
# 在autograph中完成最小值求解
# 使用optimizer.minimize x = tf.Variable(0.0,name = "x",dtype = tf.float32)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) @tf.function
def f():
a = tf.constant(1.0)
b = tf.constant(-2.0)
c = tf.constant(1.0)
y = a*tf.pow(x,2)+b*x+c
return(y) @tf.function
def train(epoch):
for _ in tf.range(epoch):
optimizer.minimize(f,[x])
return(f()) tf.print(train(1000))
tf.print(x)
0
0.999998569

参考:

开源电子书地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub 项目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days

【tensorflow2.0】自动微分机制的更多相关文章

  1. 理解PyTorch的自动微分机制

    参考Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works 非常好的文章,讲解的非 ...

  2. Unity3.0基于约定的自动注册机制

    前文<Unity2.0容器自动注册机制>中,介绍了如何在 Unity 2.0 版本中使用 Auto Registration 自动注册机制.在 Unity 3.0 版本中(2013年),新 ...

  3. Unity2.0容器自动注册机制

    现如今可能每个人都会在项目中使用着某种 IoC 容器,并且我们的意识中已经形成一些固定的使用模式,有时会很难想象如果没有 IoC 容器工作该怎么进展. IoC 容器通过某种特定设计的配置,用于在运行时 ...

  4. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  5. 推荐模型DeepCrossing: 原理介绍与TensorFlow2.0实现

    DeepCrossing是在AutoRec之后,微软完整的将深度学习应用在推荐系统的模型.其应用场景是搜索推荐广告中,解决了特征工程,稀疏向量稠密化,多层神经网路的优化拟合等问题.所使用的特征在论文中 ...

  6. Senparc.Weixin.MP SDK 微信公众平台开发教程(十六):AccessToken自动管理机制

    在<Senparc.Weixin.MP SDK 微信公众平台开发教程(八):通用接口说明>中,我介绍了获取AccessToken(通用接口)的方法. 在实际的开发过程中,所有的高级接口都需 ...

  7. 关于thinkphp 中的字段自动检查机制

    在thinkphp中有很好用的自动检查机制$_validate() 但是必须与create接收配合使用 可以很方便的帮助我们去判断 namespace Home\Model;use Think\Mod ...

  8. 微软IOC容器Unity简单代码示例3-基于约定的自动注册机制

    @(编程) [TOC] Unity在3.0之后,支持基于约定的自动注册机制Registration By Convention,本文简单介绍如何配置. 1. 通过Nuget下载Unity 版本号如下: ...

  9. ArrayList源码解析(二)自动扩容机制与add操作

    本篇主要分析ArrayList的自动扩容机制,add和remove的相关方法. 作为一个list,add和remove操作自然是必须的. 前面说过,ArrayList底层是使用Object数组实现的. ...

随机推荐

  1. Windows10 JDK1.8安装及环境变量配置

    一.下载JDK1.8: 下载地址:https://www.oracle.com/java/technologies/javase-jdk8-downloads.html  二.安装步骤: 我们通常选择 ...

  2. Python - requests发送请求报错:UnicodeEncodeError: 'latin-1' codec can't encode characters in position 13-14: 小明 is not valid Latin-1. Use body.encode('utf-8') if you want to send it encoded in UTF-8.

    背景 在做接口自动化的时候,Excel作为数据驱动,里面存了中文,通过第三方库读取中文当请求参数传入 requests.post() 里面,就会报错 UnicodeEncodeError: 'lati ...

  3. JavaScript(js)函数声明与函数表达式的区别

    在JavaScript中,函数是经常用到的,在实际开发的时候,我想很多人都没有太在意函数的声明与函数表达式的区别,但是呢,这种细节的东西对于学好js是非常重要的. 函数声明与函数表达式用代码写出来是这 ...

  4. 编译putty 源码去掉 Are you sure you want to close this session? 提示

    0, 为什么要编译 putty ?在关闭窗口的时候,会弹出一个 Are you sure you want to close this session?要把这个去掉.当然也可以用 OD 之类的来修改. ...

  5. Markdown中插入复杂的合并表格方法

    由于Markdown自身的语法限制,不能直接插入有合并单元格的复杂表格. 姓名 学号 专业 张三 2018123456 计算机 赵四 2018222356 自动化 李六 2018666666 信息工程 ...

  6. 前端要了解的seo

    一.搜索引擎工作原理 当我们在输入框中输入关键词,点击搜索或查询时,然后得到结果.深究其背后的故事,搜索引擎做了很多事情. 在搜索引擎网站,比如百度,在其后台有一个非常庞大的数据库,里面存储了海量的关 ...

  7. chrome DevTools 里面 css样式里面 勾上 :hover 会将鼠标移上的效果一直保持,技巧:要在鼠标上的 div上 勾 :hover

    chrome DevTools 里面 css样式里面 勾上 :hover 会将鼠标移上的效果一直保持,技巧:要在鼠标上的 div上 勾 :hover

  8. File 关键词

    getParent() 获取父路径 getAbsoluteFile 获取绝对路径 length()  获得文件的字节数 getName() 获取路径中最后部分的名字 getPath() 获取整体路径. ...

  9. centOS6.5桌面版用不了中文输入法解决方案

    1:centos6.5中   系统->首选项->输入法中选择“使用iBus(推荐)”,点击首选输入法n遍,没有任何效果. 2.我也弄了很多种方式包括用 yum install " ...

  10. Vue中使用echarts,ajax请求的远程数据赋值给图表不刷新的问题和解决办法

    问题: vue-cli搭建的项目,在mounted钩子函数里面创建echarts图表,本地模拟数据可以正常显示,但是当将ajax请求的远程数据赋值给图表时,图表并不会刷新. 解决办法: 刚开始以为是v ...