滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf  

if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

|  Example usage when creating a training model:
 |  
 |  ```python
 |  # Create variables.
 |  var0 = tf.Variable(...)
 |  var1 = tf.Variable(...)
 |  # ... use the variables to build a training model...
 |  ...
 |  # Create an op that applies the optimizer.  This is what we usually
 |  # would use as a training op.
 |  opt_op = opt.minimize(my_loss, [var0, var1])
 |  
 |  # Create an ExponentialMovingAverage object
 |  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 |  
 |  with tf.control_dependencies([opt_op]):
 |      # Create the shadow variables, and add ops to maintain moving averages
 |      # of var0 and var1. This also creates an op that will update the moving
 |      # averages after each training step.  This is what we will use in place
 |      # of the usual training op.
 |      training_op = ema.apply([var0, var1])
 |  
 |  ...train the model by running training_op...
 |  ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed

『TensorFlow』滑动平均的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  3. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  4. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  5. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  6. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  9. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

随机推荐

  1. 11.0-uC/OS-III就绪列表(优先级)

    准备运行的任务被放置于就绪列表中.就绪列表包括2个部分:位映像组包含了优先级信息,一个表包含了所有指向就绪任务的指针. 1.优先级 图6-1到6-3显示了优先级的位映像组.它的宽度取决于CPU_DAT ...

  2. Linux下安装jieba

    Jieba代码对 Python 2/3 均兼容 * 全自动安装:`easy_install jieba` 或者 `pip install jieba` / `pip3 install jieba` * ...

  3. LCA Tarjan方法

    LCA Tarjan方法 不得不说,高中生好厉害,OI大佬,感觉上个大学好憋屈啊! 说多了都是眼泪 链接拿去:http://www.cnblogs.com/JVxie/p/4854719.html

  4. 4.Python3运算符

    4.1算数运算符(以下假设变量a为10,变量b为21) 实例操作: print(3 + 5) #数字3与5相加 print(3 - 5) #数字3与5相减 print(3 * 5) #数字3与5相乘 ...

  5. sitecore开发入门Sitecore的CRUD操作 - 第一部分

    在本文中,讨论如何使用Sitecore.Data.Items.Item并对这些项执行CRUD(创建,读取,更新和删除)操作.我还将介绍如何使用Glass和Fortis类库进行相同的操作,这些操作都是对 ...

  6. python IO 多路复用

    一.epoll epoll 参考链接: https://www.cnblogs.com/Alanpy/articles/5125986.html epoll  参考链接: https://www.cn ...

  7. 安装OpenResty开发环境

    OpenResty是一个基于 Nginx 与 Lua 的高性能 Web 平台,其内部集成了大量精良的 Lua 库.第三方模块以及大多数的依赖项.用于方便地搭建能够处理超高并发.扩展性极高的动态 Web ...

  8. detours express版本的使用

    原文最早发表于百度空间2012-03-21 一.编译lib 1)拷贝它的src文件夹和system.mak文件到VS的VCVARS32.BAT所在的目录下 2)在命令提示符中运行VCVARS32.BA ...

  9. event.target.dataset

    dataset并不是典型意义上的JavaScript对象,而是个DOMStringMap对象,DOMStringMap是HTML5一种新的含有多个名-值对的交互变量. 1.event.target.d ...

  10. Learning-Python【13】:迭代器和生成器

    一.什么是迭代器 迭代指的是一个重复的过程,每一次重复都是基于上一次的结果而来的 # 这里的循环也是一个迭代,每次基于上一次的结果而取值 li = 'hello' i = 0 while i < ...