以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss) # 设置保存模型
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

Tensorflow训练和预测中的BN层的坑的更多相关文章

  1. tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图

    tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...

  2. TensorFlow使用记录 (七): BN 层及 Dropout 层的使用

    参考:tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究 1. Batch Normalization ...

  3. 【转载】 Pytorch(1) pytorch中的BN层的注意事项

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...

  4. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  5. tensorflow 训练最后预测结果为一个定值,可能的原因

    训练一个分类网络,没想到预测结果为一个定值. 找了很久发现,是因为tensor的维度的原因.  注意:我说的是我的label数据的维度. 我的输入是: y_= tf.placeholder(tf.in ...

  6. 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层

    原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...

  7. tensorflow在文本处理中的使用——Word2Vec预测

    代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...

  8. BN层

    论文名字:Batch Normalization: Accelerating Deep Network Training by  Reducing Internal Covariate Shift 论 ...

  9. 【卷积神经网络】对BN层的解释

    前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...

随机推荐

  1. readlink 获取进程的绝对路径

    readlink可以获取exe所在的路径(直接和进程关联);无法获得so的路径,so路径可以用dladdr,参考另一篇文章linux系统中有个符号链接:/proc/self/exe 它代表当前程序,所 ...

  2. less is more,so 只记 less

    less + 文件名 1.Enter键 :向下翻一行 2.空格键 :向下翻一屏 3.j键 :想下翻一行 4.k键 :向上翻一行 5.f键 :向下翻一屏 6.b键 : 向上翻一屏 7.d键 :向下翻半屏 ...

  3. 运维面试题之linux编程

    吐槽: linux下的编程基本上都很简单包括shell 三剑客和vim的使用,也可能写ansible的playbook,有基础都是一两天可以学会的,正则表达式都是试出来的不知道有些面试官让我们在纸上写 ...

  4. composer学习之路01

    以前对composer还是的理解很模糊,直到最近看一些资料,稍微有了一些浅显的了解. /* composer依赖包管理工具,如果一个项目是windows操作系统,那么composer就是360,他可以 ...

  5. table添加正确的样式

    以前在做表格的时候,会在表格<table>标签中添加一些属性,来改变表格的样式,经常用到的有这几个 width 表格的宽度border 表格边框的宽度cellpadding  单元边沿与其 ...

  6. java mysql连接时出现的问题

    当出现Caused by: java.sql.SQLException: Unknown system variable ‘tx_isolation’ 一般是mysql-connector-java的 ...

  7. libpng warning: iCCP: known incorrect sRGB profile

    参考  http://www.cocos2d-x.org/forums/6/topics/49093 解决 I got the following warnings in console when r ...

  8. Signalr实时通讯

    我们直接来干货~~~~~~觉得好推荐一下哈  研究不易 参考--https://www.jb51.net/article/133202.htm  这是基本教程 下面是重点: 如果你想允许跨域 具体代码 ...

  9. Asp.net(C#)年月日时分秒毫秒

    年月日时分秒毫秒格式:yyyyMMddHHmmssfff

  10. Android -- 带你从源码角度领悟Dagger2入门到放弃(三)

    1, 前面两篇文章我们知道了怎么使用常用的四种标签,现在我们结合我们自己的项目中去简单的使用 在我们搭建项目的时候,一般会创建自己的Application,在里面进行一些初始化如一些第三方的Green ...