tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage
在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness)。
tensorflow 下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,
由上述公式可知, decay" role="presentation">decaydecay 控制着模型更新的速度,越大越趋于稳定。实际运用中,decay" role="presentation">decaydecay 一般会设置为十分接近 1 的常数(0.99或0.999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:
import tensorflow as tf
v1 =tf.Variable(dtype=tf.float32, initial_value=0.)
decay = .99
num_updates = tf.Variable(0, trainable=False)
ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates)
update_var_list = [v1] # 定义更新变量列表
ema_apply = ema.apply(update_var_list)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([v1, ema.average(v1)]))
# [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),shadow_variable = variable = 0.
sess.run(tf.assign(v1, 5))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# 此时,num_updates = 0 ⇒ decay =.1, v1 = 5;
# shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable
sess.run(tf.assign(num_updates, 10000))
sess.run(tf.assign(v1, 10))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99,
# shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99
# shadow_variable = .99*4.555 + .01*10 = 4.609
tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage的更多相关文章
- Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析
觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...
- tf.train.ExponentialMovingAverage
这个函数可以参考吴恩达deeplearning.ai中的指数加权平均. 和指数加权平均不一样的是,tensorflow中提供的这个函数,能够让decay_rate随着step的变化而变化.(在训练初期 ...
- TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model
TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...
- TensorFlow 实战(二)—— tf.train(优化算法)
Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...
- tensorflow:实战Google深度学习框架第四章02神经网络优化(学习率,避免过拟合,滑动平均模型)
1.学习率的设置既不能太小,又不能太大,解决方法:使用指数衰减法 例如: 假设我们要最小化函数 y=x2y=x2, 选择初始点 x0=5x0=5 1. 学习率为1的时候,x在5和-5之间震荡. im ...
- (转)深入解析TensorFlow中滑动平均模型与代码实现
本文链接:https://blog.csdn.net/m0_38106113/article/details/81542863 指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动 ...
- day-18 滑动平均模型测试样例
为了使训练模型在测试数据上有更好的效果,可以引入一种新的方法:滑动平均模型.通过维护一个影子变量,来代替最终训练参数,进行训练模型的验证. 在tensorflow中提供了ExponentialMovi ...
- deep_learning_Function_tf.train.ExponentialMovingAverage()滑动平均
近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录. tf.train.ExponentialMovingA ...
- TensorFlow函数(四)tf.trainable_variable() 和 tf.all_variable()
tf.trainable_variable() 此函数返回的是需要训练的变量列表 tf.all_variable() 此函数返回的是所有变量列表 v = tf.Variable(tf.constant ...
随机推荐
- janos.io
http://janos.io/ https://developer.mozilla.org/en-US/Firefox_OS/Developing_Firefox_OS/Porting
- mycat 不得不说的缘分(转)
,尾声,左兄与任正非.leader-us与马云 新成立的公司里面,有个左兄,很传奇,大一在大学入伍,然后复员专业,来上海学IT,年纪轻轻,睡在地铁站,苦心专研数据库.系统.中间件,现在已经成为了业界大 ...
- Css3 过渡(Transition)特效回调函数
Css3 出来之后,能够说是替代了Flash,通过使用Html5和Css3的完美结合.就能够做出不论什么你想得到的特效,这里不再阐述... 近期在做一个喝水签到的功能.在想签到成功之后,签到框能够模拟 ...
- 雅黑PHP探针 For PHP7
雅黑PHP探针 For PHP7资料来源: https://kn007.net/topics/yahei-php-probe-for-php7/在v0.4.7版本的基础上,修正了废弃函数及错误语法.使 ...
- theme-不同主题资源更改
1.找到了影响桌面小部件的布局文件packages/apps/Mms$ vim res/layout/widget.xml修改里面的背景颜色属性,可以实现预期效果,至于里面的 <LinearLa ...
- [NowCoder]牛客OI周赛3
A.地斗主 题意:\(4\times N\) 的地板,在上面铺 \(1\times 2\) 和 \(2\times 1\) 的地砖,求铺满方案数, \(N\le 10^9\) 原题..先把一列的状态压 ...
- 6.Windows 二进制文件 (.exe)安装--终端安装
转自:http://www.runoob.com/nodejs/nodejs-tutorial.html 32 位安装包下载地址 : http://nodejs.org/dist/v0.10.26/n ...
- 备库报 ORA-00313、ORA-00312、ORA-27037
备库alert日志如下:Errors in file /data/app/oracle/diag/rdbms/standby/orcl/trace/orcl_m000_31006.trc:ORA-00 ...
- Non-resolvable parent POM for **: Could not find artifact **
注意查看这句: 原因是本地仓库缺少了easybuy-parent:pom:0.0.1-SNAPSHOT, 原来是忘记了将父工程打包到本地仓库 ,运行聚合工程前记得先将依赖的工程都先Maven inst ...
- SSO单点登录学习总结(3)—— 基于CAS实现单点登录实例
第一: 本demo在一个机器上实现(三个虚拟主机),来看SSO单点登录实例(我们可以布到多个机器上使用都是同一个道理的),一个服务器主机,和两个客户端虚拟主机 [html] view plaincop ...