使用TensorFlow中的Batch Normalization
问题
训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题,例如在训练过程中每一层输入数据分布发生了改变了,那么我们就需要使用更小的learning rate去训练,这一现象被称为internal covariate shift,Batch Normalization能够很好的解决这一问题。目前该算法已经被广泛应用在深度学习模型中,该算法的强大至于在于:
- 可以选择一个较大的学习率,能够达到快速收敛的效果。
- 能够起到Regularizer的效果,在一些情况下可以不使用Dropout,因为BN提高了模型的泛化能力
介绍
我们在将数据输入到神经网络中往往需要对数据进行归一化,原因在于模型的目的就是为了学习模型的数据的分布,如果训练集的数据分布和测试集的不一样那么模型的泛化能力就会很差,另一方面如果模型的每一 batch的数据分布都不一样,那么模型就需要去学习不同的分布,这样模型的训练速度会大大降低。
BN是一个独立的步骤,被应用在激活函数之前,它简单地对输入进行零中心(zero-center)和归一化(normalize),然后使用两个新参数来缩放和移动结果(一个用于缩放,另一个用于缩放转移)。 换句话说,BN让模型学习最佳的尺度和 每层的输入的平均值。
为了零中心和归一化数据的分布,BN需要去估算输入的mean和standard deviation,算法的计算过程如下:

其中:
- \(u_B\)是mini-btach \(B\)的均值,\(\sigma\)是mini-btach的标准差
- \(m_B\)是mini-batch中的样本
- \(\hat{x}^{(i)}\) 是zero-center和normalize后的输入
- 公式4是一个线性变换,是对数据分布的重构,\(z^{(i)}\)是算法对数据重构的output,\(\gamma\)和\(\beta\)分别代表的是对数据的
scale和shift,是我们需要学习的参数
应用
接下来我们就使用TensorFlow来实现带有BN的神经网络,步骤和前面讲到的很多一样,只是在输入激活函数之前多处理了一部而已,在TF中我们使用的实现是tf.layers.batch_normalization。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./") #自动下载数据到这个目录
tf.reset_default_graph()
n_inputs = 28 * 28
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")
training = tf.placeholder_with_default(False, shape=(), name='training')
hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1")
bn1 = tf.layers.batch_normalization(hidden1, training=training, momentum=0.9)
bn1_act = tf.nn.elu(bn1)
hidden2 = tf.layers.dense(bn1_act, n_hidden2, name="hidden2")
bn2 = tf.layers.batch_normalization(hidden2, training=training, momentum=0.9)
bn2_act = tf.nn.elu(bn2)
logits_before_bn = tf.layers.dense(bn2_act, n_outputs, name="outputs")
logits = tf.layers.batch_normalization(logits_before_bn, training=training,
momentum=0.9)
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)#labels允许的数据类型有int32, int64
loss = tf.reduce_mean(xentropy,name="loss")
learning_rate = 0.01
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
training_op = optimizer.minimize(loss)
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(logits,y,1) #取值最高的一位
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #结果boolean转为0,1
init = tf.global_variables_initializer()
saver = tf.train.Saver()
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
n_epochs = 20
batch_size = 200
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run([training_op, extra_update_ops],
feed_dict={training: True, X: X_batch, y: y_batch})
accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
y: mnist.test.labels})
print(epoch, "Test accuracy:", accuracy_val)
在上面代码中有一句需要解释一下
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
这是因为在计算BN中需要计算moving_mean和moving_variance并且更新,所以在执行run的时候需要将其添加到执行列表中。我们还可以这样写
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
training_op = optimizer.minimize(loss)
在训练的时候就只需要更新一个参数
sess.run(training_op, feed_dict={training: True, X: X_batch, y: y_batch})
此外,我们会发现在编写神经网络代码中,很多代码都是重复的可以将其模块化,例如将构建每一层神经网络的代码封装成一个function,不过这都是后话,看个人喜好吧。
使用TensorFlow中的Batch Normalization的更多相关文章
- 在tensorflow中使用batch normalization
问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题, ...
- tensorflow中使用Batch Normalization
在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2.dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络 ...
- Pytorch中的Batch Normalization操作
之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接 这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的f ...
- PyTorch中的Batch Normalization
Pytorch中的BatchNorm的API主要有: 1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 ...
- 神经网络中使用Batch Normalization 解决梯度问题
BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...
- tensorflow中batch normalization的用法
网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...
- Batch Normalization原理及其TensorFlow实现——为了减少深度神经网络中的internal covariate shift,论文中提出了Batch Normalization算法,首先是对”每一层“的输入做一个Batch Normalization 变换
批标准化(Bactch Normalization,BN)是为了克服神经网络加深导致难以训练而诞生的,随着神经网络深度加深,训练起来就会越来越困难,收敛速度回很慢,常常会导致梯度弥散问题(Vanish ...
- tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)
tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...
- BN(Batch Normalization)
Batch Nornalization Question? 1.是什么? 2.有什么用? 3.怎么用? paper:<Batch Normalization: Accelerating Deep ...
随机推荐
- springMVC(3)---利用pdf模板下载
springMVC(3)---利用pdf模板下载 在实际开发中,很多时候需要通过把数据库中的数据添加到pdf模板中,然后供客户下载,那我们该如何中呢? 本文主要内容是:用java在pdf模板中加入数据 ...
- Linux网络配置。Win10能ping虚拟机但虚拟机ping不通Win10,关闭Win10防火墙就好。
仅主机模式:配置VMnet1与虚拟机在同一个网段 ifconfig查看网卡名: ifconfig ens33 192.168.1.2: ifconfig查看是否配置成功: 在Windows物理机上pi ...
- 【python】入门:打印字符串、简单计算
- MongoDB集群搭建-副本集
MongoDB集群搭建-副本集 概念性的知识,可以参考本人博客地址: 一.Master-Slave方案: 主从: 二.Replica Set方案: 副本集: 步骤:(只要按步骤操作,100%成功) 1 ...
- [数据结构]C语言链表实现
我学数据结构的时候也是感觉很困难,当我学完后我发现了之所以困难时因为我没有系统的进行学习,而且很多教授都只是注重数据结构思想,而忽略了代码方面,为此我写了这些博文给那些试图自学数据结构的朋友,希望你们 ...
- Hibernate缓存和状态
缓存是介于应用程序和物理数据源之间,其作用是为了降低应用程序对物理数据源访问的频次,从而提高了应用的运行性能. 缓存的介质一般是内存,所以读写速度很快.但如果缓存中存放的数据量非常大时,也会用硬盘 ...
- COM_第四讲_保存GUID_优化使用代码
优化以前的代码,让使用者更方便 一丶 优化思路 1.我们可以将我们写的GUID(类工厂的ID)保存到注册表中,并且保存一下DLL的文件路径,遍历注册表去DLL路径即可. 2.每个类工厂我们就要使用一个 ...
- js功能代码大全
1.日期格式化 //化为2017-08-14 function formatDate (date) { var y = date.getFullYear(); var m = date.getMont ...
- scrapy框架第一章
操作环境:python2.7+scrapy 安装比较简单,网上教程也超多,就不在此赘述. 示例网站:https://www.cnblogs.com/cate/python/ (爬去关于博客园所有pyt ...
- Spring 自动装配及自动注册的相关配置
Spring支持好几种自动装配(Autowiring)的方式,以及自动扫描并注册Bean的配置(在beans.xml中配置). 下文我们进行一个小结. 1. <context: annotati ...