滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf  

if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

|  Example usage when creating a training model:
 |  
 |  ```python
 |  # Create variables.
 |  var0 = tf.Variable(...)
 |  var1 = tf.Variable(...)
 |  # ... use the variables to build a training model...
 |  ...
 |  # Create an op that applies the optimizer.  This is what we usually
 |  # would use as a training op.
 |  opt_op = opt.minimize(my_loss, [var0, var1])
 |  
 |  # Create an ExponentialMovingAverage object
 |  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 |  
 |  with tf.control_dependencies([opt_op]):
 |      # Create the shadow variables, and add ops to maintain moving averages
 |      # of var0 and var1. This also creates an op that will update the moving
 |      # averages after each training step.  This is what we will use in place
 |      # of the usual training op.
 |      training_op = ema.apply([var0, var1])
 |  
 |  ...train the model by running training_op...
 |  ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed

『TensorFlow』滑动平均的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  3. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  4. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  5. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  6. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  9. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

随机推荐

  1. python 容器类型数据 (str list tuple set dict)

    # ###容器类型数据(str list tuple set dict) var1 = "今天心情非常美丽" var2 = [1,2,3,4] var3 = ("黄将用& ...

  2. 266B

    #include <stdio.h> #define MAXSIZE 1024 char que[MAXSIZE]; int main() { int n, t; scanf(" ...

  3. iPhoneX快速适配,简单到你想哭。

    研究了5个小时的iPhoneX适配. 从catalog,storyboard,safearea等一系列文章中发现.如果我们想完全撑满全屏.那直接建一个storyboard就好了.但撑满全屏后,流海就是 ...

  4. torch随机数 manual_seed

    import torch seed = 2018 torch.manual_seed(seed) torch.cuda.manual_seed(seed) a=torch.rand([1,5]) # ...

  5. 检查文件是否被修改或者被破坏工具 md5

    检查文件和对应的md5值是否一致.

  6. Nginx技术研究系列7-Azure环境中Nginx高可用性和部署架构设计

    前几篇文章介绍了Nginx的应用.动态路由.配置.在实际生产环境部署时,我们需要同时考虑Nginx的高可用性和部署架构. Nginx自身不支持集群以保证自身的高可用性,商业版本的Nginx+推荐: T ...

  7. sqlserver with(nolock)

    所有Select加 With (NoLock)解决阻塞死锁 在查询语句中使用 NOLOCK 和 READPAST 处理一个数据库死锁的异常时候,其中一个建议就是使用 NOLOCK 或者 READPAS ...

  8. How use Nmon and "Java Nmon Analyzer" for Monitor Linux Performance

    Nmon is a  resource monitoring tools which can monitor CPU, Memory, Disks, Network and even Filesyst ...

  9. rman实验——测试备份压缩

    oracle rman自带的备份压缩机制,可以有效的压缩备份的大小,降低磁盘的占用率.但是也会因为压缩而消耗更多的系统性能,和增加备份时间.现在就通过实验来看压缩和不压缩的区别. 进行不压缩全备 RM ...

  10. 2018-2019-1 20189206 《Linux内核原理与分析》第八周作业

    #linux内核分析学习笔记 --第七章 可执行程序工作原理 学习目标:了解一个可执行程序是如何作为一个进程工作的. ELF文件 目标文件:是指由汇编产生的(*.o)文件和可执行文件. 即 可执行或可 ...