问题

训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题,例如在训练过程中每一层输入数据分布发生了改变了,那么我们就需要使用更小的learning rate去训练,这一现象被称为internal covariate shiftBatch Normalization能够很好的解决这一问题。目前该算法已经被广泛应用在深度学习模型中,该算法的强大至于在于:

  • 可以选择一个较大的学习率,能够达到快速收敛的效果。
  • 能够起到Regularizer的效果,在一些情况下可以不使用Dropout,因为BN提高了模型的泛化能力

介绍

我们在将数据输入到神经网络中往往需要对数据进行归一化,原因在于模型的目的就是为了学习模型的数据的分布,如果训练集的数据分布和测试集的不一样那么模型的泛化能力就会很差,另一方面如果模型的每一 batch的数据分布都不一样,那么模型就需要去学习不同的分布,这样模型的训练速度会大大降低。

BN是一个独立的步骤,被应用在激活函数之前,它简单地对输入进行零中心(zero-center)和归一化(normalize),然后使用两个新参数来缩放和移动结果(一个用于缩放,另一个用于缩放转移)。 换句话说,BN让模型学习最佳的尺度和 每层的输入的平均值。

为了零中心和归一化数据的分布,BN需要去估算输入的mean和standard deviation,算法的计算过程如下:



其中:

  • \(u_B\)是mini-btach \(B\)的均值,\(\sigma\)是mini-btach的标准差
  • \(m_B\)是mini-batch中的样本
  • \(\hat{x}^{(i)}\) 是zero-center和normalize后的输入
  • 公式4是一个线性变换,是对数据分布的重构,\(z^{(i)}\)是算法对数据重构的output,\(\gamma\)和\(\beta\)分别代表的是对数据的scaleshift,是我们需要学习的参数

应用

接下来我们就使用TensorFlow来实现带有BN的神经网络,步骤和前面讲到的很多一样,只是在输入激活函数之前多处理了一部而已,在TF中我们使用的实现是tf.layers.batch_normalization

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./") #自动下载数据到这个目录
tf.reset_default_graph()
n_inputs = 28 * 28
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")
training = tf.placeholder_with_default(False, shape=(), name='training') hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1")
bn1 = tf.layers.batch_normalization(hidden1, training=training, momentum=0.9)
bn1_act = tf.nn.elu(bn1) hidden2 = tf.layers.dense(bn1_act, n_hidden2, name="hidden2")
bn2 = tf.layers.batch_normalization(hidden2, training=training, momentum=0.9)
bn2_act = tf.nn.elu(bn2) logits_before_bn = tf.layers.dense(bn2_act, n_outputs, name="outputs")
logits = tf.layers.batch_normalization(logits_before_bn, training=training,
momentum=0.9) with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)#labels允许的数据类型有int32, int64
loss = tf.reduce_mean(xentropy,name="loss")
learning_rate = 0.01
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
training_op = optimizer.minimize(loss)
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(logits,y,1) #取值最高的一位
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #结果boolean转为0,1
init = tf.global_variables_initializer()
saver = tf.train.Saver() extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) n_epochs = 20
batch_size = 200
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run([training_op, extra_update_ops],
feed_dict={training: True, X: X_batch, y: y_batch})
accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
y: mnist.test.labels})
print(epoch, "Test accuracy:", accuracy_val)

在上面代码中有一句需要解释一下

extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

这是因为在计算BN中需要计算moving_meanmoving_variance并且更新,所以在执行run的时候需要将其添加到执行列表中。我们还可以这样写

with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
training_op = optimizer.minimize(loss)

在训练的时候就只需要更新一个参数

sess.run(training_op, feed_dict={training: True, X: X_batch, y: y_batch})

此外,我们会发现在编写神经网络代码中,很多代码都是重复的可以将其模块化,例如将构建每一层神经网络的代码封装成一个function,不过这都是后话,看个人喜好吧。

使用TensorFlow中的Batch Normalization的更多相关文章

  1. 在tensorflow中使用batch normalization

    问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题, ...

  2. tensorflow中使用Batch Normalization

    在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2.dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络 ...

  3. Pytorch中的Batch Normalization操作

    之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接 这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的f ...

  4. PyTorch中的Batch Normalization

    Pytorch中的BatchNorm的API主要有: 1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 ...

  5. 神经网络中使用Batch Normalization 解决梯度问题

    BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...

  6. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  7. Batch Normalization原理及其TensorFlow实现——为了减少深度神经网络中的internal covariate shift,论文中提出了Batch Normalization算法,首先是对”每一层“的输入做一个Batch Normalization 变换

    批标准化(Bactch Normalization,BN)是为了克服神经网络加深导致难以训练而诞生的,随着神经网络深度加深,训练起来就会越来越困难,收敛速度回很慢,常常会导致梯度弥散问题(Vanish ...

  8. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  9. BN(Batch Normalization)

    Batch Nornalization Question? 1.是什么? 2.有什么用? 3.怎么用? paper:<Batch Normalization: Accelerating Deep ...

随机推荐

  1. sql基础题目测试及正确答案

    在网上做了一套基本的sql题目,以下是我的写的答案,适合基础人员练练 --创建测试数据 use test create table Student(S# varchar(10),Sname nvarc ...

  2. 遍历map的几种方式

    1,平时开发中对map的使用很多,然后发现了很多map可能存在的各种问题:如HashMap 需要放置 1024 个元素,由于没有设置容量初始大小,随着元素不断增加,容量 7 次被迫扩大,resize ...

  3. wait/notify 实现多线程交叉备份

    一.任务 创建20个线程,其中10个线程是将数据备份到 A 数据库中,另外10 个线程将数据备份到 B 数据库中,并且备份 A 数据库和 备份 B 数据库的是交叉运行的. 二.实现 1.实现备份 A ...

  4. arcgis api for js热力图优化篇-不依赖地图服务

    前面我写过一篇文章,介绍如何实现arcgis api的热力图效果,但是依赖arcgis server发布的地图服务来获取热力图的数据源.实际应用中,很多业务数据来源数据库,并不一定是从地图服务来获取的 ...

  5. 记vue API 知识点

    1. v-cloak指令:这个指令保持在元素上直到关联实例结束编译.和 CSS 规则如 [v-cloak] { display: none } 一起用时,这个指令可以隐藏未编译的 Mustache 标 ...

  6. Docker(一):Docker安装

    简介:Docker是一个开源的引擎,可以轻松的为任何应用创建一个轻量级的.可移植的.自给自足的容器.开发者在笔记本上编译测试通过的容器可以批量地在生产环境中部署,包括VMs(虚拟机).bare met ...

  7. npm lodash

    在数据操作时,Lodash 就是我的弹药库,不管遇到多复杂的数据结构都能用一些函数轻松拆解. ES6 中也新增了诸多新的对象函数,一些简单的项目中 ES6 就足够使用了,但还是会有例外的情况引用了少数 ...

  8. springBoot系列教程04:mybatis及druid数据源的集成及查询缓存的使用

    首先说下查询缓存:查询缓存就是相同的数据库查询请求在设定的时间间隔内仅查询一次数据库并保存到redis中,后续的请求只要在时间间隔内都直接从redis中获取,不再查询数据库,提高查询效率,降低服务器负 ...

  9. C# DataGridView 列的显示顺序

    this.dataGridView1.Columns["列名"].DisplayIndex=Convert.ToInt32("你要放置的位置")

  10. popupwindow那些坑

    1. new PopupWindow(vw, ViewGroup.LayoutParams.MATCH_PARENT, ViewGroup.LayoutParams.MATCH_PARENT); 如果 ...