批量归一化

在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。

在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。

了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。

批量归一化的定义

在TensorFlow中有自带的BN函数定义:

tf.nn.batch_normalization(x,
maen,
variance,
offset,
scale,
variance_epsilon)

各个参数的含义如下:

x:代表输入

mean:代表样本的均值

variance:代表方差

offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。

scale:代表缩放,即乘以一个转化值,同理,一般是1

variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。

要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。

为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用

tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。

批量归一化的简单用法

下面介绍具体的用法,在使用的时候需要引入头文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函数的定义如下:

batch_norm(inputs,
decay,
center,
scale,
epsilon,
activation_fn,
param_initializers=None,
param_regularizers=None,
updates_collections=ops.GraphKeys.UPDATE_OPS,
is_training=True,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
batch_weights=None,
fused=False,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
scope=None,
renorm=False,
renorm_clipping=None,
renorm_decay=0.99)

各参数的具体含义如下:

inputs:输入

decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。

scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。

epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。

is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。

updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。

reuse:支持变量共享。

具体的代码如下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
train = tf.Variable(tf.constant(False)) x_images = tf.reshape(x, [-1, 32, 32, 3]) def batch_norm_layer(value, train=False, name='batch_norm'):
if train is not False:
return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)
else:
return batch_norm(value, decay=0.9, updates_collections=None, is_training=False) w_conv1 = init_cnn.weight_variable([3, 3, 3, 64]) # [-1, 32, 32, 3]
b_conv1 = init_cnn.bias_variable([64])
h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train))
h_pool1 = init_cnn.max_pool_2x2(h_conv1) w_conv2 = init_cnn.weight_variable([3, 3, 64, 64]) # [-1, 16, 16, 64]
b_conv2 = init_cnn.bias_variable([64])
h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train))
h_pool2 = init_cnn.max_pool_2x2(h_conv2) w_conv3 = init_cnn.weight_variable([3, 3, 64, 32]) # [-1, 18, 8, 32]
b_conv3 = init_cnn.bias_variable([32])
h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train))
h_pool3 = init_cnn.max_pool_2x2(h_conv3) w_conv4 = init_cnn.weight_variable([3, 3, 32, 16]) # [-1, 18, 8, 32]
b_conv4 = init_cnn.bias_variable([16])
h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train))
h_pool4 = init_cnn.max_pool_2x2(h_conv4) w_conv5 = init_cnn.weight_variable([3, 3, 16, 10]) # [-1, 4, 4, 16]
b_conv5 = init_cnn.bias_variable([10])
h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train))
h_pool5 = init_cnn.avg_pool_4x4(h_conv5) # [-1, 4, 4, 10] y_pool = tf.reshape(h_pool5, shape=[-1, 10]) cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool)) optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。

TensorFlow——批量归一化操作的更多相关文章

  1. Batch Normalization批量归一化

    BN的深度理解:https://www.cnblogs.com/guoyaohua/p/8724433.html BN: BN的意义:在激活函数之前将输入归一化到高斯分布,控制到激活函数的敏感区域,避 ...

  2. 第十八节,TensorFlow中使用批量归一化(BN)

    在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...

  3. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  4. 批量归一化batch_normalization

    为了解决在深度神经网络训练初期降低梯度消失/爆炸问题,Sergey loffe和Christian Szegedy提出了使用批量归一化的技术的方案,该技术包括在每一层激活函数之前在模型里加一个操作,简 ...

  5. 深度学习面试题21:批量归一化(Batch Normalization,BN)

    目录 BN的由来 BN的作用 BN的操作阶段 BN的操作流程 BN可以防止梯度消失吗 为什么归一化后还要放缩和平移 BN在GoogLeNet中的应用 参考资料 BN的由来 BN是由Google于201 ...

  6. 对抗生成网络-图像卷积-mnist数据生成(代码) 1.tf.layers.conv2d(卷积操作) 2.tf.layers.conv2d_transpose(反卷积操作) 3.tf.layers.batch_normalize(归一化操作) 4.tf.maximum(用于lrelu) 5.tf.train_variable(训练中所有参数) 6.np.random.uniform(生成正态数据

    1. tf.layers.conv2d(input, filter, kernel_size, stride, padding) # 进行卷积操作 参数说明:input输入数据, filter特征图的 ...

  7. 从头学pytorch(十九):批量归一化batch normalization

    批量归一化 论文地址:https://arxiv.org/abs/1502.03167 批量归一化基本上是现在模型的标配了. 说实在的,到今天我也没搞明白batch normalize能够使得模型训练 ...

  8. 【转】批量复制操作(SqlBulkCopy)的出错处理:事务提交、回滚

    原文地址:http://blog.csdn.net/westsource/article/details/6658109 默认情况下,批量复制操作作为独立的操作执行. 批量复制操作以非事务性方式发生, ...

  9. 使用Ajax实现的批量删除操作(C#)

    今天做了一个简单的批量删除操作,虽然简单,但是很多问题出现,终究还是技术不够熟练. 现在在这里跟大家分享一下.仅供学习... 1.在前台获取用户点击的信息id,把这里id封装到一个数组里面:(rows ...

随机推荐

  1. C#的选择语句练习(二)

    1.输入a,b,c三个数,计算一元二次方程ax²+bx+c的根:若a=0,则不是一元二次方程:△=b²-4ac,根的计算公式为-b±√b²-4ac/2a:若△=b²-4ac>0,则方程有两个不一 ...

  2. 教你如何成为Spark大数据高手?

    教你如何成为Spark大数据高手? Spark目前被越来越多的企业使用,和Hadoop一样,Spark也是以作业的形式向集群提交任务,那么如何成为Spark大数据高手?下面就来个深度教程. Spark ...

  3. Yarn install 报错 Resolving packages... [2/4] Fetching packages... info There appears to be trouble with your network connection. Retrying

    1.设置淘宝代理 yarn config set registry 'https://registry.npm.taobao.org' 2.如果网址本地可以打开,说明你本地有代理设置 所以需要按本地的 ...

  4. Native memory allocation (mmap) failed to map xxx bytes for committing reserved memory

    遇到问题 在服务器上运行 nexus 出现Native memory allocation (mmap) failed to map 838860800 bytes for committing re ...

  5. H3C PPP协议的组成

  6. linux 如何编译安装软件

  7. centos linux mysql 10060远程错误代码

    Navicat for MySQL远程连接数据错误代码10060 1.登陆远程linux服务器命令界面 vim /etc/sysconfig/iptables  进入防火墙配置修改 增加以下两条防火墙 ...

  8. el-tree文本内容过多显示不完全问题(解决)

    布局: <span class="custom-tree-node" slot-scope="{ node, data }"> 外层span 树节点 ...

  9. dot net double 数组转 float 数组

    本文告诉大家如果遇到 double 数组转 float 数组千万不要使用 Cast ,一般都使用 select 强转. 最近在开发Avalonia ,有大神告诉我,下面的代码可以这样写 dashes ...

  10. H3C 路由器单跳操作