一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。

tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:

1.构建训练模型时,添加如下代码

 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
train_op = tf.group(train_op, variables_averages_op)

第1行创建了一个指数移动平均类 variable_averages

第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符

第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型

第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均

2.开始训练前,创建saver时,使用如下代码

 save_vars = tf.trainable_variables() + ave_vars
saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。

第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。

另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客

3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型

 variable_averages = tf.train.ExponentialMovingAverage(0.999)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, model_path)

这几行很简单,就不做解释了。

实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。

tensorflow中moving average的用法的更多相关文章

  1. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  2. [LeetCode] Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  3. [Swift]LeetCode346. 从数据流中移动平均值 $ Moving Average from Data Stream

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  4. [转载]Tensorflow中reduction_indices 的用法

    Tensorflow中reduction_indices 的用法 默认时None 压缩成一维

  5. LeetCode 346. Moving Average from Data Stream (数据流动中的移动平均值)$

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  6. [LeetCode] 346. Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  7. TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同

    tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...

  8. 第十八节,TensorFlow中使用批量归一化(BN)

    在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...

  9. 理解滑动平均(exponential moving average)

    1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...

随机推荐

  1. 第三届“百越杯”福建省高校网络空间安全大赛writeup--Do you know upload?

    一打开网址,可以看出应该是文件上传漏洞,查看源码,也有可能是文件包含 上传个图片,成功,然后上传一句话木马 通过bp进行上传绕过 , 开始菜刀连接http://e00b6eca3c9c4e14a31c ...

  2. Chrome_调试js出现Uncaught SyntaxError: Unexpected identifier

    转载自:http://blog.csdn.net/yiluoak_47/article/details/7663952 chrome下运行编写的JavaScript代码时,在工具javascript控 ...

  3. 【容斥】Four-tuples @山东省第九届省赛 F

    时间限制: 10 Sec 内存限制: 128 MB 题目描述 Given l1,r1,l2,r2,l3,r3,l4,r4, please count the number of four-tuples ...

  4. JS自学笔记03

    JS自学笔记03 1.函数练习: 如果函数所需参数为数组,在声明和定义时按照普通变量名书写参数列表,在编写函数体内容时体现其为一个数组即可,再传参时可以直接将具体的数组传进去 即 var max=ge ...

  5. 虚拟串口VSPD破解版 亲测win10 64可用

    虚拟串口VSPD破解版 亲测win10 64可用 点击下载

  6. 用DirectX实现多视图渲染

    什么是多视图 一般的3D程序都只有一个视图,对应整个窗口的客户区.多视图就是在一个窗口中放置多个视图,以便从不同的角度观察模型或者场景.很多图形软件都有这个功能,比如大家熟知的3DMax就有四个视图, ...

  7. nginx的http_sub_module模块使用之替换字符串

    Nginx可以实现很多功能,提供了许多插件,其中一个比较冷门的http_sub_module,是用来替换指定字符串的,它的原理是Nginx解析到文件后,执行这个插件进行拦截后返回. 昨天碰到一个场景, ...

  8. Springboot 之 多配置文件

    六.Springboot 之 多配置文件   说明:在程序开发过程中可能会有这样的需求:开发和部署的配置信息可能会不同,以传统的方式就是在配置文件里面写好配置,在部署的时候再去修改这些配置,这样肯定会 ...

  9. Apache Kafka 快速入门

    概述 Apache Kafka是一个分布式发布-订阅消息系统和强大的队列,可以处理大量的数据,将消息从一个端点传递到另一个端点.Kafka适合离线和在线消息消费,Kafka消息保存在磁盘上,并在集群内 ...

  10. iOS:苹果内购实践

    iOS 苹果的内购 一.介绍 苹果规定,凡是虚拟的物品(例如:QQ音乐的乐币)进行交易时,都必须走苹果的内购通道,苹果要收取大约30%的抽成,所以不允许接入第三方的支付方式(微信.支付宝等),当然开发 ...