tensorflow中batch normalization的用法
网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下:
1.原理
公式如下:
y=γ(x-μ)/σ+β
其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。
一般来讲,这些参数都是基于channel来做的,比如输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是可有可无的,有的话,就是一个可以学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。
2.tensorflow中使用
tensorflow中batch normalization的实现主要有下面三个:
tf.nn.batch_normalization
tf.layers.batch_normalization
tf.contrib.layers.batch_norm
封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。
3.训练
训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
4.测试
测试时需要注意一点,输入参数training=False,其他就没了
5.预测
预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:
var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也没法获取bn层中的μ和σ(也可能是我用法不对),不过好在所有的参数都在tf.global_variables()中,因此可以这么写:
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
按照上述写法,即可把μ和σ保存下来,读取模型预测时也不会报错,当然输入参数training=False还是要的。
注意上面有个不严谨的地方,因为我的网络结构中只有bn层包含moving_mean和moving_variance,因此只根据这两个字符串做了过滤,如果你的网络结构中其他层也有这两个参数,但你不需要保存,建议使用诸如bn/moving_mean的字符串进行过滤。
2018.4.22更新
提供一个基于mnist的示例,供大家参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计得到的mean和var。注意示例中需要下载mnist数据集,要保持电脑可以联网。
tensorflow中batch normalization的用法的更多相关文章
- tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)
tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...
- 深度学习中 Batch Normalization
深度学习中 Batch Normalization为什么效果好?(知乎) https://www.zhihu.com/question/38102762
- 深度学习中batch normalization
目录 1 Batch Normalization笔记 1.1 引包 1.2 构建模型: 1.3 构建训练函数 1.4 结论 Batch Normalization笔记 我们将会用MNIST数 ...
- 深度学习中 Batch Normalization为什么效果好
看mnist数据集上其他人的CNN模型时了解到了Batch Normalization 这种操作.效果还不错,至少对于训练速度提升了很多. batch normalization的做法是把数据转换为0 ...
- tensorflow中moving average的用法
一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些. tensorflow-model ...
- Batch Normalization原理
Batch Normalization导读 博客转载自:https://blog.csdn.net/malefactor/article/details/51476961 作者: 张俊林 为什么深度神 ...
- Batch Normalization
一.BN 的作用 1.具有快速训练收敛的特性:采用初始很大的学习率,然后学习率的衰减速度也很大 2.具有提高网络泛化能力的特性:不用去理会过拟合中drop out.L2正则项参数的选择问题 3.不需要 ...
- Deep Learning 27:Batch normalization理解——读论文“Batch normalization: Accelerating deep network training by reducing internal covariate shift ”——ICML 2015
这篇经典论文,甚至可以说是2015年最牛的一篇论文,早就有很多人解读,不需要自己着摸,但是看了论文原文Batch normalization: Accelerating deep network tr ...
- zz详解深度学习中的Normalization,BN/LN/WN
详解深度学习中的Normalization,BN/LN/WN 讲得是相当之透彻清晰了 深度神经网络模型训练之难众所周知,其中一个重要的现象就是 Internal Covariate Shift. Ba ...
随机推荐
- thinkphp 判断有没有登录 没有登录就自动给他打到登录去
html 代码 <a href="javascript:;" class="user_uid"> <i class="iconfon ...
- python学习之旅(五)
Python基础知识(4):基础数据类型之字符串(Ⅰ) 字符串是 Python 中最常用的数据类型.可以使用引号“ ”来创建字符串,只要为变量分配一个值即可.例如: name=“Alice” 注:字符 ...
- python学习之旅(二)
Python基础知识(1) 一.变量 变量名可以由字母.数字.下划线任意组合而成. 注意:1.变量名不能以数字开头: 2.变量名不能为关键字: 3.变量名尽量起有意义的,能够通过变量名知道代表的是什么 ...
- 关闭shift中英文切换 英文代码/中文注释随意切换着写。
x 背景 写代码的时候总是意外的就切成中文了,特别是代码中大小写切换的这种情况... 例如:"public static TimeZone CurrentTime..."publi ...
- 终于等到你!WebOptimizer - A bundler and minifier for ASP.NET Core
迷人的 ASP.NET Core 有一个美中不足之处,自从一开始接触它到现在,我就一直不喜欢,一直想找到替代品,甚至想过自己实现一个,它就是 BundlerMinifier . 昨天面对 bundle ...
- linux命令: 两个查找工具 locate,find
linux 中有很多查找工具,今天主要讲解locate,find两个工具. 1.locate (1)查询系统上预建的文件索引数据库 /var/lib/mlocate/mlocate.db 注意:如果这 ...
- Web开发——JavaScript基础(JSON教程)
参考: JSON:JavaScript 对象表示法(JavaScript Object Notation). JSON 是存储和交换文本信息的语法.类似 XML. JSON 比 XML 更小.更快,更 ...
- iOS系统NSNotificationCenter中的常用通知名称
//音频 AVF_EXPORT NSString *const AVAudioSessionInterruptionNotification //音频中断出现 AVF_EXPORT NSString ...
- Win10部署IIS 10.0
win10自带IIS10.0 控制面板 >> 程序 >>启用或关闭Windows功能 勾选完之后会安装IIS,安装完成后 计算机管理 >> 服务和应用程序 > ...
- springboot中通过cors协议解决跨域问题
1.对于前后端分离的项目来说,如果前端项目与后端项目部署在两个不同的域下,那么势必会引起跨域问题的出现. 针对跨域问题,我们可能第一个想到的解决方案就是jsonp,并且以前处理跨域问题我基本也是这么处 ...