以下这些函数可以用于解决梯度消失或梯度爆炸问题上。

1. tf.clip_by_value

tf.clip_by_value(
t,
clip_value_min,
clip_value_max,
name=None
)

Returns:A clipped Tensor.

输入一个张量t,把t中的每一个元素的值都压缩在clip_value_min和clip_value_max之间。小于min的让它等于min,大于max的元素的值等于max。

例子:

import tensorflow as tf;
import numpy as np; A = np.array([[1,1,2,4], [3,4,8,5]]) with tf.Session() as sess:
print sess.run(tf.clip_by_value(A, 2, 5)) >>>
[[2 2 2 4]
[3 4 5 5]]

2. tf.clip_by_norm

tf.clip_by_norm(
t,
clip_norm,
axes=None,
name=None
)

Returns:A clipped Tensor.

指对梯度进行裁剪,通过控制梯度的最大范式,防止梯度爆炸的问题,是一种比较常用的梯度规约的方式。

  • t: 输入tensor,也可以是list
  • clip_norm: 一个具体的数,如果\(l_2 \, norm(t)≤clip\_norm\), 则t不变化;否则\(t=\frac{t*clip\_norm}{l_2norm(t)}\)

注意上面的t可以是list,所以最后做比较的时候是将t的二范式和clip_norm作比较。看下面的例子:

a = np.array([2.,5.])
b = tf.clip_by_norm(a, 5)
with tf.Session() as sess:
print(sess.run(tf.norm(a)))
print(sess.run(b)) >>>
5.3851647
[1.8569534 4.6423836]

3. tf.clip_by_average_norm

tf.clip_by_average_norm(
t,
clip_norm,
name=None
)

Returns:A clipped Tensor.

其实和tf.clip_by_norm类似,只不过把\(l_2\,norm(t)\)改成了\(l_2\,norm_avg(t)=\frac{1}{n} \, l_2\,norm(t)\),\(n\)表示t的元素个数。

例子

a = np.array([3, 4]).astype('float32')
e = tf.clip_by_average_norm(a, 1)
with tf.Session() as sess:
print(sess.run(e)) >>>
[1.2 1.6]

验证一下:\(\frac{3*1}{\frac{1}{2}\sqrt{3^2+4^2}}=\frac{3}{2.5}=1.2\)。

4. tf.clip_by_global_norm

tf.clip_by_global_norm(
t_list,
clip_norm,
use_norm=None,
name=None
)

Returns:

  • list_clipped: A list of Tensors of the same type as list_t.
  • global_norm: A 0-D (scalar) Tensor representing the global norm.

注意这里的t_list是a tuple or list of tensors。

global_norm计算公式如下:

\[global\_norm=\sqrt{\sum_i^n{l_2\,norm(t[i])^2}}
\]

如果global_norm>clip_norm,则t_list中所有元素若如下计算:

\[t\_list[i]=\frac{t\_list[i]*clip\_norm}{max(global\_norm,clip\_norm)}
\]

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com


2018-12-2

TensorFlow学习笔记之--[tf.clip_by_global_norm,tf.clip_by_value,tf.clip_by_norm等的区别]的更多相关文章

  1. TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]

    I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  4. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

  5. Tensorflow学习笔记2019.01.03

    tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...

  6. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  7. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  8. tensorflow学习笔记(4)-学习率

    tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...

  9. tensorflow学习笔记(3)前置数学知识

    tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个   b为4* ...

  10. tensorflow学习笔记(2)-反向传播

    tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...

随机推荐

  1. (线性DP LIS)POJ2533 Longest Ordered Subsequence

    Longest Ordered Subsequence Time Limit: 2000MS   Memory Limit: 65536K Total Submissions: 66763   Acc ...

  2. Python网络编程之socket编程

    什么是Socket? Socket是应用层与TCP/IP协议族通信的中间软件抽象层,它是一组接口.在设计模式中,Socket其实就是一个门面模式,它把复杂的TCP/IP协议族隐藏在Socket接口后面 ...

  3. Python模块初识

    目录 一 模块初识 二 模块分类 三 导入模块 四 Python文件的两种用途 五 模板查找顺序 六 软件开发目录规范 一.模块初识 模块是自我包含并且有组织的代码片段,是一系列功能的集合体,一个py ...

  4. MySQL利用binlog恢复误操作数据(python脚本)

    在人工手动进行一些数据库写操作的时候(比方说数据订正),尤其是一些不可控的批量更新或删除,通常都建议备份后操作.不过不怕万一,就怕一万,有备无患总是好的.在线上或者测试环境误操作导致数据被删除或者更新 ...

  5. 有了这8款Mac安全杀毒和流氓防护软件,让你的mac清理优化,更加安全

    其实Mac系统相对Windows来说更加安全,主要原因是针对Mac系统的病毒和流氓软件并不多,而且Mac系统的安全机制也更加完善,不过为了更加安全的使用Mac,使用以下8款Mac 杀毒安全.安全防护和 ...

  6. jmeter源码导入eclipse步骤

    1.新建标准java项目2.右击项目选import filesystem 将apache-jmeter-4.0整个目录勾选allow output folders for source folders ...

  7. 6.Hystrix-超时设置

    由于客户端请求服务端方法时,服务端方法响应超过1秒将会触发降级,所以我们可以配置Hystrix默认的超时配置 如果我们没有配置默认的超时时间,Hystrix将取default_executionTim ...

  8. ACM-ICPC 2018 焦作赛区网络预赛 J Participate in E-sports(大数开方)

    https://nanti.jisuanke.com/t/31719 题意 让你分别判断n或(n-1)*n/2是否是完全平方数 分析 二分高精度开根裸题呀.经典题:bzoj1213 用java套个板子 ...

  9. 用过企业微信APP 后,微信接收不到消息,解决方案

    用过企业微信APP 后,微信接收不到消息的,怎么办? 请打开企业微信,找到:我----设置----新消息通知----仅在企业微信中接收消息

  10. Vuex笔记

    Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式 Vuex - 状态管理器,可以管理你的数据状态(类似于 React的 Redux) 一个 Vuex 应用的核心是 store(仓库,一个 ...