滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf  

if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

|  Example usage when creating a training model:
 |  
 |  ```python
 |  # Create variables.
 |  var0 = tf.Variable(...)
 |  var1 = tf.Variable(...)
 |  # ... use the variables to build a training model...
 |  ...
 |  # Create an op that applies the optimizer.  This is what we usually
 |  # would use as a training op.
 |  opt_op = opt.minimize(my_loss, [var0, var1])
 |  
 |  # Create an ExponentialMovingAverage object
 |  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 |  
 |  with tf.control_dependencies([opt_op]):
 |      # Create the shadow variables, and add ops to maintain moving averages
 |      # of var0 and var1. This also creates an op that will update the moving
 |      # averages after each training step.  This is what we will use in place
 |      # of the usual training op.
 |      training_op = ema.apply([var0, var1])
 |  
 |  ...train the model by running training_op...
 |  ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed

『TensorFlow』滑动平均的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  3. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  4. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  5. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  6. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  9. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

随机推荐

  1. PyQt5简介及demo

    PyQt5说明 pyqt5是一套Python绑定Digia QT5应用的框架.它可用于Python 2和3.本教程使用Python 3.Qt库是最强大的GUI库之一.pyqt5的官方网站http:// ...

  2. django时区设置 media配置 日期截断函数 上传图片管理设计方案

    1.django时区 修改一下app里的设置 TIME_ZONE = 'Asia/Shanghai' USE_I18N = True USE_L10N = True # 不用UTC时间 USE_TZ ...

  3. openshift 容器云从入门到崩溃之四《配置用户验证》

    1.配置本地用户 之前安装的时候选择了htpasswd验证方式 先创建用户 # htpasswd -c /etc/origin/master/htpasswd admin 授权为集群管理员 # oc ...

  4. PowerDesigner导出pdm设计为Word文档

    点击Report->Reports 点击New Report 选择Standard Physical Report,语言选择简体中文,如下图 此时目录下就会多一个Report 右窗口: 根据自己 ...

  5. 数据库SQL的分组函数

    分组函数:(五个) 1···max(expr):求expr的最大值 }\ 2···min(expr):求expr的最小值 }-- 数据类型是有规定的 3···sum(expr):求expr的总和 }- ...

  6. SQL 增加列、修改列、删除列

    SQL语句增加列.修改列.删除列 1.增加列: alter table tableName add columnName varchar(30) 2.1. 修改列类型: alter table tab ...

  7. JavaScript 声明提前机制

    声明提前机制 在JavaScript存在着这样一种预处理机制,即浏览器在解析JS代码时会将var声明的变量和function声明的函数提升到当前作用域的顶部.但是解析JS代码时对var和functio ...

  8. Html 标签初知

    Html 标签初知 什么是Html 标签 超文本标记语言(外国语简称:HTML)标记标签通常被称为HTML标签,HTML标签是HTML语言中最基本的单位,HTML标签是HTML(标准通用标记语言下的一 ...

  9. Python assert断言

    assert断言:指定某个对象判断类型,不成立则报错. 使用环境  :接下来程序的执行,如果依赖前面的类型,不能报错的情况下使用. assert type(obj) is str print(&quo ...

  10. kruscal重构树略解

    我们先看一道题:Luogu P4197 Peaks 这道题珂以用启发式合并+主席树来做 那么强制在线呢?(bzoj 3551 [ONTAK2010]Peaks加强版) 离线做法就不行了 我们就要用一个 ...