在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现。在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型。在初始化ExponentialMovingAverage时,需要提供一个衰减率(decay)。这个衰减率将用于控制模型更新的速度。ExponentialMovingAverage对每一个变量会维护一个影子变量(shadowvariable),这个影子变量的初始值就是相应变量的初始值,而每次运行变量更新时,影子变量的值会更新为:

shadow_variable=decay x shadow_variable+(1-decay) x variable

其中shadow_variable 为影子变量,variable为待更新的变量,decay为衰减率。decay决定了模型更新的速度,decay越大模型越趋于稳定。在实际应用中,decay一般会设成非常接近1的数(比如0.999或0.9999)。为了使得模型在训练前期可以更新得更快,ExponentialMovingAverage还提供了num_updates参数来动态设置decay的大小.

下面是ExponentailMovingAverage使用示例

# -*- coding:UTF- -*-
import tensorflow as tf
# 定义一个初始为0的变量来计算滑动平均 v1=tf.Variable(,dtype=tf.float32) #这里的step变量模拟神经网络中迭代的轮数,可以用于动态控制衰减率
step=tf.Variable(,trainable=False) #定义一个滑动平均的类,初始化时给定了衰减率(0.99)和控制衰减率的变量step
ema=tf.train.ExponentialMovingAverage(0.99,step) # 定义一个更新变量滑动平均的操作,这里给定一个列表,每次执行这个操作时,这个列表中的变量的值都会更新 maintain_averages_op=ema.apply([v1])
with tf.Session() as sess:
# 初始化所有变量
init_op=tf.global_variables_initializer()
sess.run(init_op) # 通过ema.average(v1)获取滑动平均之后变量的取值。在初始化之后变量v1的值和v1的滑动平均都为0 print sess.run([v1,ema.average(v1)])
# 更新变量v1的值到5
sess.run(tf.assign(v1,))
# 更新v1的滑动平均值,衰减率为min{0.99,(+step)/(+step)=0.1}=0.1
# 所以v1的滑动平均会被更新为0.*+0.9*=4.5 sess.run(maintain_averages_op)
print sess.run([v1,ema.average(v1)]) # 更新 step的值为10000
sess.run(tf.assign(step,))
# 更新 v1的值为10。
sess.run(tf.assign(v1,))
# 更新v1 的滑动平均值。衰减率为min(0.99,(+step)/(+step)≈0.999}=0.99
# 所以v1的滑动平均会被更新为0.*4.5+0.01*=4.555 sess.run(maintain_averages_op)
print sess.run([v1,ema.average(v1)]) #再次更新滑动平均值,得到的新滑动平均值为0.*4.555+0.01*=4.60945 sess.run(maintain_averages_op)
print sess.run([v1,ema.average(v1)])

结果如下

[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.555]
[10.0, 4.60945]

tensorflow随机梯度下降算法使用滑动平均模型的更多相关文章

  1. Tensorflow中的滑动平均模型

    原文链接 在Tensorflow的教程里面,使用梯度下降算法训练神经网络时,都会提到一个使模型更加健壮的策略,即滑动平均模型. 基本思想 在使用梯度下降算法训练模型时,每次更新权重时,为每个权重维护一 ...

  2. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  3. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  4. 监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  5. 监督学习——随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  6. tensorflow入门笔记(二) 滑动平均模型

    tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均. 当训练模型的时候,保持训练参数的滑动平均是非常有益的.评估时使用取平均 ...

  7. 78、tensorflow滑动平均模型,用来更新迭代的衰减系数

    ''' Created on 2017年4月21日 @author: weizhen ''' #4.滑动平均模型 import tensorflow as tf #定义一个变量用于计算滑动平均,这个变 ...

  8. 吴裕雄 PYTHON 神经网络——TENSORFLOW 滑动平均模型

    import tensorflow as tf v1 = tf.Variable(0, dtype=tf.float32) step = tf.Variable(0, trainable=False) ...

  9. 随机梯度下降算法求解SVM

    测试代码(matlab)如下: clear; load E:\dataset\USPS\USPS.mat; % data format: % Xtr n1*dim % Xte n2*dim % Ytr ...

随机推荐

  1. day01-h1字体大小和文本居中

    <!doctype html><html > <head> <meta charset="utf-8"> <link rel= ...

  2. struts2中的session、request 、和action往页面中传值的方法

    ActionContext.getContext().put("list", list); ActionContext.getContext().getValueStack().p ...

  3. AngularJS中获取数据源的几种方式

    在AngularJS中,可以从$rootScope中获取数据源,也可以把获取数据的逻辑封装在service中,然后注入到app.run函数中,或者注入到controller中.本篇就来整理获取数据的几 ...

  4. 你真的会用Gson吗?Gson使用指南(1)

    JSON (官网) 是一种文本形式的数据交换格式,它比XML更轻量.比二进制容易阅读和编写,调式也更加方便.其重要性不言而喻.解析和生成的方式很多,Java中最常用的类库有:JSON-Java.Gso ...

  5. No module named 'pandas._libs.tslib'

    用pip命令安装: pip install pandas python3的: pip3 install pandas

  6. JavaScript 空间分析库——JSTS和Turf【转】

    https://blog.csdn.net/neimeng0/article/details/80363468 前言 项目中有管线的空间拓扑关系查询需求,在npm中检索到JSTS和Turf两个Java ...

  7. IDEA下使用Maven的test命令乱码

    IDEA下使用Maven的test命令乱码的时候,加上 -Dfile.encoding=GBK 就可以解决啦   如下图所示:   或者在Maven的pom.xml文件中增加: <propert ...

  8. Hbase 命令小结

    1.创建test,如果存在先删除 hbase(main)::> disable 'test' row(s) in 1.4250 seconds hbase(main)::> drop 't ...

  9. ORA-16447 Redo apply was not active at the target standby database

    Cause ALTER SYSTEM FLUSH REDO TO STANDBY failed because redo apply is not active at the target datab ...

  10. Jmeter笔记:响应断言详解

    转自:http://www.51testing.com/html/80/n-2430180.html 平时我们使用jmeter进行性能测试时,经常会用到断言.jmeter提供了很多种断言,本来想全都写 ...