Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。
书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。
那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。
所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss) # 设置保存模型
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。
最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。
感谢几位大佬的博客解惑:
https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0
http://www.cnblogs.com/hrlnw/p/7227447.html
Tensorflow训练和预测中的BN层的坑的更多相关文章
- tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图
tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...
- TensorFlow使用记录 (七): BN 层及 Dropout 层的使用
参考:tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究 1. Batch Normalization ...
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...
- 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...
- tensorflow 训练最后预测结果为一个定值,可能的原因
训练一个分类网络,没想到预测结果为一个定值. 找了很久发现,是因为tensor的维度的原因. 注意:我说的是我的label数据的维度. 我的输入是: y_= tf.placeholder(tf.in ...
- 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层
原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...
- tensorflow在文本处理中的使用——Word2Vec预测
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- BN层
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
- 【卷积神经网络】对BN层的解释
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
随机推荐
- 在win7上跑基于任少卿作者代码修改的RPN+BF实验
1.前言 之前在win10上成功的跑起来faster-rcnn的实验,并且跑了一下CaltechPedestrian的数据集,但是效果一直不理想,折腾了好久也没弄清楚到底原因出在哪里,直到读了Is F ...
- Jvisualvm 添加插件
1.访问地址:https://visualvm.github.io/pluginscenters.html,找到自己JDK版本对应的插件下载地址(我的JDK版本为1.7.0_67): 2.点击该链接进 ...
- Javascript循环删除数组中元素的几种方法示例
发现问题 大家在码代码的过程中,经常会遇到在循环中移除指定元素的需求.按照常规的思路,直接一个for循环,然后在循环里面来个if判断,在判断中删除掉指定元素即可.但是实际情况往往不会像预想的那样顺利运 ...
- Object:所有类的超类
Java中每个类都是由Object类扩展而来 1.equals方法 在Object类中,这个方法用于判断两个对象是否具有相同的引用,然而对于大多数类来说,经常需要检测两个对象状态的相等性. publi ...
- 在vue中使用echarts图表
在vue中使用echarts图表 转载请注明出处:https://www.cnblogs.com/wenjunwei/p/9815290.html 安装vue依赖 使用npm npm instal ...
- [转载来之雨松:NGUI研究院之为什么打开界面太慢(十三)]
本文固定链接: http://www.xuanyusong.com/archives/2799
- Shrinking images on Linux
When creating images from existing ISOs you often need to allocate a number of MB for the image to a ...
- 【数据结构】算法 LinkList (Merge Two Sorted Lists)
合并2个有序链表 list A, list B, Solution: 对A,B 表按序读取数据,比较大小后插入新链表C. 由于两个输入链表的长度可能不同,所以最终会有一个链表先完成插入所有元素,则直接 ...
- 在Ubuntu16.04中python环境下实现tab键补全
1.编写tab.py的代码: 1 #!/usr/bin/env python 2 # python startup file 3 import sys 4 import readline 5 impo ...
- django的url反向解析
目的:防止页面中url地址改变,其他与这个URL地址有关联的都要改,减少耦合度 使用:主要分为在html中和视图函数中的使用 HTML中的使用: 如果我们在项目的url文件中通过include导入了应 ...