TensorFlow学习笔记之--[tf.clip_by_global_norm,tf.clip_by_value,tf.clip_by_norm等的区别]
以下这些函数可以用于解决梯度消失或梯度爆炸问题上。
1. tf.clip_by_value
tf.clip_by_value(
t,
clip_value_min,
clip_value_max,
name=None
)
Returns:A clipped Tensor.
输入一个张量t,把t中的每一个元素的值都压缩在clip_value_min和clip_value_max之间。小于min的让它等于min,大于max的元素的值等于max。
例子:
import tensorflow as tf;
import numpy as np;
A = np.array([[1,1,2,4], [3,4,8,5]])
with tf.Session() as sess:
print sess.run(tf.clip_by_value(A, 2, 5))
>>>
[[2 2 2 4]
[3 4 5 5]]
2. tf.clip_by_norm
tf.clip_by_norm(
t,
clip_norm,
axes=None,
name=None
)
Returns:A clipped Tensor.
指对梯度进行裁剪,通过控制梯度的最大范式,防止梯度爆炸的问题,是一种比较常用的梯度规约的方式。
- t: 输入tensor,也可以是list
- clip_norm: 一个具体的数,如果\(l_2 \, norm(t)≤clip\_norm\), 则t不变化;否则\(t=\frac{t*clip\_norm}{l_2norm(t)}\)
注意上面的t可以是list,所以最后做比较的时候是将t的二范式和clip_norm作比较。看下面的例子:
a = np.array([2.,5.])
b = tf.clip_by_norm(a, 5)
with tf.Session() as sess:
print(sess.run(tf.norm(a)))
print(sess.run(b))
>>>
5.3851647
[1.8569534 4.6423836]
3. tf.clip_by_average_norm
tf.clip_by_average_norm(
t,
clip_norm,
name=None
)
Returns:A clipped Tensor.
其实和tf.clip_by_norm类似,只不过把\(l_2\,norm(t)\)改成了\(l_2\,norm_avg(t)=\frac{1}{n} \, l_2\,norm(t)\),\(n\)表示t的元素个数。
例子
a = np.array([3, 4]).astype('float32')
e = tf.clip_by_average_norm(a, 1)
with tf.Session() as sess:
print(sess.run(e))
>>>
[1.2 1.6]
验证一下:\(\frac{3*1}{\frac{1}{2}\sqrt{3^2+4^2}}=\frac{3}{2.5}=1.2\)。
4. tf.clip_by_global_norm
tf.clip_by_global_norm(
t_list,
clip_norm,
use_norm=None,
name=None
)
Returns:
list_clipped: A list of Tensors of the same type as list_t.global_norm: A 0-D (scalar) Tensor representing the global norm.
注意这里的t_list是a tuple or list of tensors。
global_norm计算公式如下:
\]
如果global_norm>clip_norm,则t_list中所有元素若如下计算:
\]
TensorFlow学习笔记之--[tf.clip_by_global_norm,tf.clip_by_value,tf.clip_by_norm等的区别]的更多相关文章
- TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]
I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- tensorflow学习笔记(4)-学习率
tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...
- tensorflow学习笔记(3)前置数学知识
tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个 b为4* ...
- tensorflow学习笔记(2)-反向传播
tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...
随机推荐
- DK location not found. Define location with sdk.dir in the local.properties file or with an ANDROID_HOME
根据提示,我们可以新建一个项目或者以前自己使用过没问题的工程,从中把local.properties文件copy到我们从github中想要导入的工程中,我自己就是这样的,然后这个问题就解决了. ndk ...
- Qt ------ QDir 路径
QDir::currentPath() 获取当前路径 QDir::toNativeSeparators 把windows下的路径转换为Qt可以识别的路径,如"C:/Users/Adminis ...
- js 各种事件 如:点击事件、失去焦点、键盘事件等
事件驱动: 我们点击按钮 按钮去掉用相应的方法. demo: <input type="button" v ...
- 2017-12-15python全栈9期第二天第七节之x or y ,x 为 非 0时,则返回x
#!/user/bin/python# -*- coding:utf-8 -*-# x or y ,x 为 非 0时,则返回xprint(1 or 2)print(3 or 2)print(0 or ...
- CM记录-部署cdh5.3.3集群
1.安装操作系统,保证联网环境,本文以CentOS 6.8为操作系统(略) 2.wget下载安装包(以5.3.3为例) #mkdir /usr/cdh ---新建cm安装目录 #cd /usr/cdh ...
- 在Ajax返回多个值
<html> <head> <title>AjaxTest</title> <script type="text/javascript& ...
- Ubuntu 下使用 putty并通过 ch340 usb 串口线进行调试
安装putty sudo apt-get install putty -y 插入usb转串口线 由于linux下没有Windos类似的设备管理器,所以我们可以通过其他方法获取对应的串口号 可以在插拔之 ...
- jQuery图片灯箱和视频灯箱
在一些前端页面中经常需要文件上传,为了美观,我们经常做一个灯箱来显示我们选择的文件, 而不是简单的input标签. html 代码:这个是多图片上传 <div class="layui ...
- ACM-ICPC 2018 焦作赛区网络预赛 B Mathematical Curse(DP)
https://nanti.jisuanke.com/t/31711 题意 m个符号必须按顺序全用,n个房间需顺序选择,有个初始值,问最后得到的值最大是多少. 分析 如果要求出最大解,维护最大值是不能 ...
- 解决浏览器跨域限制方案之CORS
一.什么是CORS CORS是解决浏览器跨域限制的W3C标准,详见:https://www.w3.org/TR/cors/. 根据CORS标准的定义,在浏览器中访问跨域资源时,需要做如下实现: 服务端 ...