在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness)

tensorflow 下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

shadow_variable=decay×shadow_variable+(1−decay)×variable" role="presentation">shadow_variable=decay×shadow_variable+(1−decay)×variableshadow_variable=decay×shadow_variable+(1−decay)×variable

由上述公式可知, decay" role="presentation">decaydecay 控制着模型更新的速度,越大越趋于稳定。实际运用中,decay" role="presentation">decaydecay 一般会设置为十分接近 1 的常数(0.99或0.999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

decay=min{decay,1+num_updates10+num_updates}" role="presentation">decay=min{decay,1+num_updates10+num_updates}decay=min{decay,1+num_updates10+num_updates}
import tensorflow as tf

v1 =tf.Variable(dtype=tf.float32, initial_value=0.)
decay = .99
num_updates = tf.Variable(0, trainable=False)
ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates) update_var_list = [v1] # 定义更新变量列表
ema_apply = ema.apply(update_var_list) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([v1, ema.average(v1)]))
# [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),shadow_variable = variable = 0. sess.run(tf.assign(v1, 5))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# 此时,num_updates = 0 ⇒ decay =.1, v1 = 5;
# shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable
sess.run(tf.assign(num_updates, 10000))
sess.run(tf.assign(v1, 10))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99,
# shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99
# shadow_variable = .99*4.555 + .01*10 = 4.609

tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage的更多相关文章

  1. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  2. tf.train.ExponentialMovingAverage

    这个函数可以参考吴恩达deeplearning.ai中的指数加权平均. 和指数加权平均不一样的是,tensorflow中提供的这个函数,能够让decay_rate随着step的变化而变化.(在训练初期 ...

  3. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  4. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  5. tensorflow:实战Google深度学习框架第四章02神经网络优化(学习率,避免过拟合,滑动平均模型)

    1.学习率的设置既不能太小,又不能太大,解决方法:使用指数衰减法 例如: 假设我们要最小化函数 y=x2y=x2, 选择初始点 x0=5x0=5  1. 学习率为1的时候,x在5和-5之间震荡. im ...

  6. (转)深入解析TensorFlow中滑动平均模型与代码实现

    本文链接:https://blog.csdn.net/m0_38106113/article/details/81542863 指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动 ...

  7. day-18 滑动平均模型测试样例

    为了使训练模型在测试数据上有更好的效果,可以引入一种新的方法:滑动平均模型.通过维护一个影子变量,来代替最终训练参数,进行训练模型的验证. 在tensorflow中提供了ExponentialMovi ...

  8. deep_learning_Function_tf.train.ExponentialMovingAverage()滑动平均

    近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录. tf.train.ExponentialMovingA ...

  9. TensorFlow函数(四)tf.trainable_variable() 和 tf.all_variable()

    tf.trainable_variable() 此函数返回的是需要训练的变量列表 tf.all_variable() 此函数返回的是所有变量列表 v = tf.Variable(tf.constant ...

随机推荐

  1. 小米开源文件管理器MiCodeFileExplorer-源码研究(1)-2个模型Model

    上篇说到,把小米的Java代码整理成了5个包,其中1个是net.micode.fileexplorer.model.这个包就2个模型类,最基本了,FileInfo和FavoriteItem. pack ...

  2. Python - 字典(dict)删除元素

    字典(dict)删除元素, 能够选择两种方式, dict.pop(key)和del dict[key]. 代码 # -*- coding: utf-8 -*- def remove_key(d, ke ...

  3. Filebeat的下载(图文讲解)

    第一步:进入Elasticsearch的官网 https://www.elastic.co/ 第二步:点击downloads https://www.elastic.co/downloads 第三步: ...

  4. jsp+tomcat+ 创建project 配置project

    *如今我们已经下载到了 tomcat 7.0+ eclipse for java ee 直接解压,打开eclipse. 接下来是步骤: eclipse 打开的界面.空空如也 ! ..* 点击 file ...

  5. [NPM] Publish npm packages using npm publish

    In this lesson we will publish our package. We will first add a prepublish script that runs our buil ...

  6. 24岁程序员, 一个人撑起App开发项目

    "疲惫吾心,怎样躲藏! 四处荒芜,怎话忧伤?"临近中秋,看到艾瑞斯的QQ签名,无尽的伤感.这个年仅24的青年.连续3年没有回家了,近期一个月总是失眠,没有家人的陪伴,就连女朋友都没 ...

  7. J2SE核心开发实战(一)——认识J2SE

    认识J2SE 一.课程简单介绍 在本章学习開始前,你应该具备一些Java的基础知识. 我们将在本章来认识J2SE,并复习一下前面学过的面向对象的相关知识. 注:全部的蓝色文字都是带超链接的,这些链接是 ...

  8. bootstrap课程12 滚动监听如何实现(bootstrap方式和自定义方式)

    bootstrap课程12 滚动监听如何实现(bootstrap方式和自定义方式) 一.总结 一句话总结:通过监听滚动的高,判断滚动的高是否大于元素距离顶端的距离 1.如何知道屏幕滚动的高? st=$ ...

  9. ORACLE11g R2【RAC+ASM→单实例FS】

    ORACLE11g R2[RAC+ASM→单实例FS] 11g R2 RAC+ASMà单实例FS的DG,建议禁用OMF. 本演示案例所用环境:   primary standby OS Hostnam ...

  10. 原生js大总结九

    81.ES6的Symbol的作用是什么?   ES6引入了一种新的原始数据类型Symbol,表示独一无二的值   82.ES6中字符串和数组新增了那些方法   字符串       1.字符串模板    ...