tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均。

当训练模型的时候,保持训练参数的滑动平均是非常有益的。评估时使用取平均后的参数有时会产生比使用最终训练好的参数值好很多的效果。方法apply()会添加被训练变量的影子副本和在影子副本中维持被训练变量的滑动平均的若干操作。该方法在创建训练模型时使用。那些保持维持滑动平均的操作(ops)一般会在每个训练步骤之后被执行。average()和average_name()方法分别提供了对影子变量和影子变量名字访问的途径。它们在建立评估模型或者从checkpoint文件恢复模型时能够用到,主要是帮助使用滑动平均代替最终训练结果进行评估。

滑动平均计算时使用指数衰减。当创建ExponentialMovingAverage对象时,衰减率应该被指定。影子变量和被训练参数的初始值相同。当执行更新滑动平均的操作时,每个影子变量会按照下面的公式进行更新:

shadow_variable -= (1 - decay) * (shadow_variable - variable)

上面的公式与下面的公式相同:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

decay决定了模型更新的速度,越大越趋于稳定。decay的合理取值接近1.0,所以 decay的取值一般包含多个9,如0.999、0.9999等。

创建训练模型时的用法示例:

# Create variables.
var0 = tf.Variable(...)
var1 = tf.Variable(...)
# ... use the variables to build a training model...
...
# Create an op that applies the optimizer. This is what we usually
# would use as a training op.
opt_op = opt.minimize(my_loss, [var0, var1]) # Create an ExponentialMovingAverage object
ema = tf.train.ExponentialMovingAverage(decay=0.9999) with tf.control_dependencies([opt_op]):
# Create the shadow variables, and add ops to maintain moving averages
# of var0 and var1. This also creates an op that will update the moving
# averages after each training step. This is what we will use in place
# of the usual training op.
training_op = ema.apply([var0, var1]) ...train the model by running training_op...

有两种使用滑动平均进行评估的方法:

  • 建立一个使用影子变量(shadow variables)而非变量(variables)的模型。为此,需要使用返回给定变量的影子变量的average()方法
  • 创建一个正常的模型,但是使用影子变量名加载checkpoint文件进行评估。为此,需要使用average_name()方法

恢复影子变量值的示例:

# Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values

部分方法:

__init__(decay,
num_updates=None,
zero_debias=False,
name='ExponentialMovingAverage')
# 创建一个ExponentialMovingAverage对象
# 为了创建影子变量和添加维持滑动平均的操作,apply()方法必须被调用
        # 可选参数num_updates允许对衰减率进行动态微调。典型的方式是通过记录训练次数,在每次训练开始之前降低衰减率。这样做可以使模型在训练的初始阶段更新
        # 得更快
        # zero_debias: 如果为True,Tensor objects会被初始化为无偏滑动平均
        # 衰减率更新公式为:
actual_decay = min(decay, (1 + num_updates) / (10 + num_updates))
        可选参数name是被添加到apply()方法中的操作名称的前缀。
apply(var_list=None)
# 维持变量的滑动平均,即对shadow variables进行计算
# var_list必须是Variable或者Tensor objects构成的列表。该方法会为列表中的所有元素创建影子变量,且变量对象的影子变量初始值和变量相同。影子变量
           也会被添加到GraphKeys.MOVING_AVERAGE_VARIABLES集合中。对于Tensor objects,影子变量会被初始化为0,同时被设置为无偏。
# 影子变量被设置trainable=False,并且被添加到GraphKeys.MOVING_AVERAGE_VARIABLES集合中,它们会在调用tf.global_variables()时被返回。
# 该方法返回一个按照要求更新所有影子变量的操作。同时需要注意的是,apply()可以在不同的var_list下被多次调用。
average(var)
# 返回变量的影子变量值,即读取影子变量shadow variables

average_name(var)
# 返回变量的影子变量名,即读取影子变量名
# 在模型训练期间计算变量的滑动平均和在评估时从计算得到的滑动平均恢复变量是ExponentialMovingAverage的典型应用。
# 为了恢复变量,必须知道影子变量名。然后影子变量名和对应的变量被传给Saver()对象来从计算得到滑动平均值恢复变量。
# Saver=tf.train.Saver({ema.average_name(var):var})
# 不管apply()方法有没有被调用,average_name()都可以被调用
variables_to_restore(moving_avg_variables=None)
# 返回要恢复的变量的名称映射
# moving_avg_variables : 需要使用滑动平均名进行恢复的变量构成的list;如果为None,会默认为variables.moving_average_variables() + va
                                   riables.trainable_variables()
# 如果变量有滑动平均,那么使用滑动平均变量名作为恢复时使用的名称;否则,使用变量名。
# 例如:
# variables_to_restore = ema.variables_to_restore()
# saver = tf.train.Saver(variables_to_restore)
# 以下是返回的一个映射的示例:
# conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
# conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
# global_step: global_step

示例:参考链接

import os
import tensorflow as tf os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 创建待训练参数
variable1 = tf.Variable(initial_value=0, trainable=True, dtype=tf.float32)
# 训练次数,不可训练
step_var = tf.Variable(initial_value=0, trainable=False)
# 创建滑动平均对象
ema = tf.train.ExponentialMovingAverage(decay=0.999, num_updates=step_var)
# 计算变量variable1的滑动平均操作
maintain_average_op = ema.apply([variable1]) # 初始化操作
init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# 初始值输出
# 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新变量
sess.run(tf.assign(variable1, 5))
# 更新影子变量
# decay = min(decay, (1+step_var) / (10+step_var))
# shadow_variable = decay * shadow_variable + (1 - decay) * variable
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新step_var
sess.run(tf.assign(step_var, 10000))
# 更新变量
sess.run(tf.assign(variable1, 10))
# 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新影子变量 # 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)]))

输出如下:

[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.5054998]
[10.0, 4.5109944]

tensorflow入门笔记(二) 滑动平均模型的更多相关文章

  1. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  2. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  3. tensorflow随机梯度下降算法使用滑动平均模型

    在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现.在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模 ...

  4. Tensorflow中的滑动平均模型

    原文链接 在Tensorflow的教程里面,使用梯度下降算法训练神经网络时,都会提到一个使模型更加健壮的策略,即滑动平均模型. 基本思想 在使用梯度下降算法训练模型时,每次更新权重时,为每个权重维护一 ...

  5. tensorflow学习笔记二:入门基础 好教程 可用

    http://www.cnblogs.com/denny402/p/5852083.html tensorflow学习笔记二:入门基础   TensorFlow用张量这种数据结构来表示所有的数据.用一 ...

  6. 1 TensorFlow入门笔记之基础架构

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  7. 78、tensorflow滑动平均模型,用来更新迭代的衰减系数

    ''' Created on 2017年4月21日 @author: weizhen ''' #4.滑动平均模型 import tensorflow as tf #定义一个变量用于计算滑动平均,这个变 ...

  8. tensorflow入门笔记(三) tf.GraphKeys

    tf.GraphKeys类存放了图集用到的标准名称. 该标准库使用各种已知的名称收集和检索图中相关的值.例如,tf.Optimizer子类在没有明确指定待优化变量的情况下默认优化被收集到tf.Grap ...

  9. 吴裕雄 PYTHON 神经网络——TENSORFLOW 滑动平均模型

    import tensorflow as tf v1 = tf.Variable(0, dtype=tf.float32) step = tf.Variable(0, trainable=False) ...

随机推荐

  1. sql操作总结

    SQL 语句的多表查询方式例如:按照 department_id 查询 employees(员工表)和 departments(部门表)的信息.方式一(通用型):SELECT ... FROM ... ...

  2. [原创]MSP430FR4133练习(一):GPIO输入电平状态判断

    硬件环境:MSP430FR4133 LANCHPAD开发板 软件环境:IARV7.10 For 430 源代码: #include "driverlib.h" void main( ...

  3. MapReduce处理HBase出错:XXX.jar is not a valid DFS filename

    原因:Hadoop文件系统没有检查路径时没有区分是本地windows系统还是Hadoop集群文件系统 解决:  只需将Map和Reduce的init方法最后一个参数(boolean addDepend ...

  4. libuv示例代码

    https://github.com/nikhilm/uvbook/tree/master/code

  5. mvc4安装、新建、模版简介

    第一安装 mvc4 1.Visual Studio 2012本身就包含MVC4另外无需安装. 2.Vs2010 需要安装vs2010 sp1补丁,后再安装mvc4安装包(官网下载即可) 第二 创建mv ...

  6. day_5.5 单例

    2018-5-5 15:00:25 单例 : 就是对象只有一个 ''' class main(object): __instance = None def __new__(cls,): if cls. ...

  7. Intersection Observer API 可以让你知道被观察元素何时进入或退出浏览器的视口

    google 文档 https://developers.google.cn/web/updates/2016/04/intersectionobserver MDN 文档 https://devel ...

  8. MyBatis limit分页设置

    错误的写法: <select id="queryMyApplicationRecord" parameterType="MyApplicationRequest&q ...

  9. Google、微软、Linkedln、Uber、亚马逊等15+海外技术专家聚首2018TOP100Summit

    11月30日-12月3日,由msup主办的第七届全球软件案例研究峰会(以下简称为TOP100Summit)将在北京国家会议中心举办.本届峰会以“释放AI生产力,让组织向智能化演进”作为开幕式主题, 4 ...

  10. 专访姚冬:All-in-One,智能时代下企业需要更快速的变革

    2017年,msup将咨询服务列入公司发展战略目标,并邀请前IBM大中华区技术总监姚冬成为咨询合伙人.近一年来,msup在咨询服务方面持续发力,与包括百度.平安科技.用友等在内的大型公司形成企业合作联 ...