一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。

tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:

1.构建训练模型时,添加如下代码

 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
train_op = tf.group(train_op, variables_averages_op)

第1行创建了一个指数移动平均类 variable_averages

第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符

第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型

第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均

2.开始训练前,创建saver时,使用如下代码

 save_vars = tf.trainable_variables() + ave_vars
saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。

第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。

另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客

3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型

 variable_averages = tf.train.ExponentialMovingAverage(0.999)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, model_path)

这几行很简单,就不做解释了。

实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。

tensorflow中moving average的用法的更多相关文章

  1. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  2. [LeetCode] Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  3. [Swift]LeetCode346. 从数据流中移动平均值 $ Moving Average from Data Stream

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  4. [转载]Tensorflow中reduction_indices 的用法

    Tensorflow中reduction_indices 的用法 默认时None 压缩成一维

  5. LeetCode 346. Moving Average from Data Stream (数据流动中的移动平均值)$

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  6. [LeetCode] 346. Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  7. TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同

    tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...

  8. 第十八节,TensorFlow中使用批量归一化(BN)

    在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...

  9. 理解滑动平均(exponential moving average)

    1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...

随机推荐

  1. Python中join 和 split详解(推荐)

    http://www.jb51.net/article/87700.htm python join 和 split方法简单的说是:join用来连接字符串,split恰好相反,拆分字符串的. .join ...

  2. shell脚本使用技巧3--函数调用

    定义函数 function fname() { statements; } 或者 fname() { statements; } 传递参数给函数: fname arg1 arg2; ex: 函数参数定 ...

  3. HttpClient异步调用引发的程序挂起问题排查及解决

    在搭建搭建分布式系统时,基础组件与框架的重要性不言而喻.但是如果组件出现bug,真的很要命.虽然我们通过各种单元测试,拼命找bug,但是总有一些问题被盲目自信蒙蔽了双眼,很多时候我们认为这段代码100 ...

  4. PAT Basic 1011

    1011 A+B 和 C (15 分) 给定区间 [−2​31​​,2​31​​] 内的 3 个整数 A.B 和 C,请判断 A+B 是否大于 C. 输入格式: 输入第 1 行给出正整数 T (≤10 ...

  5. 初次接触Jenkins遇到的几个问题

    1,Jenkins一直显示pending-Waiting for next available executor 网上已经提到的原因 1>,磁盘满了 2>,节点管理 刷新状态 我遇到的情况 ...

  6. Python基础-字符串、集合类型、判断、深拷贝与浅拷贝、文件读写

    字符串 1.定义三个变量: 2.交换两个变量值 1)引入第三个变量: 2)Python引入第三方变量: 3)不引入第三方变量: 3. isalpha 是否是汉字或字母 4.Isalnum  是否是汉字 ...

  7. HTML常用标签1

    1 html html:超文本标记(标签)语言 通过标签语言来标记要显示的网页中的各个部分.一套规则,浏览器认识的规则 浏览器渲染的顺序:从上到下,从左到右 对于不同的浏览器,对同一标签可能会有不完全 ...

  8. SpringMVC知识点

    一.SpringMVC 1.HelloWorld案例 ①步骤: 加jar包 在web.xml文件中配置DispatcherServlet 加入SpringMVC的配置文件 编写处理请求的处理器,并标识 ...

  9. 使用idea+springboot+Mybatis搭建web项目

    使用idea+springboot+Mybatis搭建web项目 springboot的优势之一就是快速搭建项目,省去了自己导入jar包和配置xml的时间,使用非常方便. 1.创建项目project, ...

  10. UML建模——用例图(Use Case Diagram)

    用例图主要用来描述角色以及角色与用例之间的连接关系.说明的是谁要使用系统,以及他们使用该系统可以做些什么.一个用例图包含了多个模型元素,如系统.参与者和用例,并且显示这些元素之间的各种关系,如泛化.关 ...