『教程』Batch Normalization 层介绍

基础知识

下面有莫凡的对于批处理的解释:

fc_mean,fc_var = tf.nn.moments(
Wx_plus_b,
axes=[0],
# 想要 normalize 的维度, [0] 代表 batch 维度
# 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度
)
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b,fc_mean,fc_var,shift,scale,epsilon)
# 上面那一步, 在做如下事情:
# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
# Wx_plus_b = Wx_plus_b * scale + shift

tf.contrib.layers.batch_norm:封装好的批处理类

class batch_norm():
'''batch normalization层''' def __init__(self, epsilon=1e-5,
momentum=0.9, name='batch_norm'):
'''
初始化
:param epsilon: 防零极小值
:param momentum: 滑动平均参数
:param name: 节点名称
'''
with tf.variable_scope(name):
self.epsilon = epsilon
self.momentum = momentum
self.name = name def __call__(self, x, train=True):
# 一个封装了的会在内部调用batch_normalization进行正则化的高级接口
return tf.contrib.layers.batch_norm(x,
decay=self.momentum, # 滑动平均参数
updates_collections=None,
epsilon=self.epsilon,
scale=True,
is_training=train, # 影响滑动平均
scope=self.name)

1.

Note: when training, the moving_mean and moving_variance need to be updated.
    By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
    need to be added as a dependency to the `train_op`. For example:
    
    ```python
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    ```
    
    One can set updates_collections=None to force the updates in place, but that
    can have a speed penalty, especially in distributed settings.

2.

is_training: Whether or not the layer is in training mode. In training mode
        it would accumulate the statistics of the moments into `moving_mean` and
        `moving_variance` using an exponential moving average with the given
        `decay`. When it is not in training mode then it would use the values of
        the `moving_mean` and the `moving_variance`.

tf.nn.batch_normalization:原始接口封装使用

实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装,这个类又进行了一次封装(主要是制订了一部分默认参数),实际操作时可以仅仅使用tf.contrib.layers.batch_norm函数,它已经足够方便了。

添加了滑动平均处理之后,也就是不使用封装,直接使用tf.nn.moments和tf.nn.batch_normalization实现的batch_norm函数:

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed

另一种将滑动平均展开了的方式,

def batch_norm(x, size, training, decay=0.999):
beta = tf.Variable(tf.zeros([size]), name='beta')
scale = tf.Variable(tf.ones([size]), name='scale')
pop_mean = tf.Variable(tf.zeros([size]))
pop_var = tf.Variable(tf.ones([size]))
epsilon = 1e-3 batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) def batch_statistics():
with tf.control_dependencies([train_mean, train_var]):
return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm') def population_statistics():
return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm') return tf.cond(training, batch_statistics, population_statistics)

注, tf.cond:流程控制,参数一True,则执行参数二的函数,否则执行参数三函数。

『TensorFlow』批处理类的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  3. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  4. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  5. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  6. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  7. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  8. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  9. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

随机推荐

  1. C++11 std::ref使用场景

    C++本身有引用(&),为什么C++11又引入了std::ref(或者std::cref)? 主要是考虑函数式编程(如std::bind)在使用时,是对参数直接拷贝,而不是引用.如下例子: # ...

  2. intellij idea 的全局搜索快捷键方法

    1.Ctrl+N按名字搜索类 相当于eclipse的ctrl+shift+R,输入类名可以定位到这个类文件,就像idea在其它的搜索部分的表现一样,搜索类名也能对你所要搜索的内容多个部分进行匹配,而且 ...

  3. Cartographer源码阅读(7):轨迹推算和位姿推算的原理

    其实也就是包括两个方面的内容:类似于运动模型的位姿估计和扫描匹配,因为需要计算速度,所以时间就有必要了! 1. PoseExtrapolator解决了IMU数据.里程计和位姿信息进行融合的问题. 该类 ...

  4. select,poll,epoll

    1. Epoll 是何方神圣? Epoll 可是当前在 Linux 下开发大规模并发网络程序的热门人选, Epoll 在 Linux2.6 内核中正式引入,和 select 相似,其实都 I/O 多路 ...

  5. Vue:Promise概要

    1.Promise中then是异步的 2.Promise 的then里面两个回调,默认第一个resolve,第二个reject:不会进入catch:如果只有一个回调则进入catch var p1=ne ...

  6. node (02 CommonJs 和 Nodejs 中自定义模块)顺便讲讲module.exports和exports的区别 dependencies 与 devDependencies 之间的区别

    CommonJS 规范的提出,主要是为了弥补当前 JavaScript 没有标准的缺陷.它的终极目标就是:提供一个类似 Python,Ruby 和 Java 语言的标准库,而不只是停留在小脚本程序的阶 ...

  7. Unity之显示fps功能

    如下: using UnityEngine; using System.Collections; public class ShowFpsOnGUI : MonoBehaviour { public ...

  8. 在java中,将String类型字符串s赋值为null后,将字符串与其他字符串拼接后得到结果出现了null字符串与其他字符连接的样式

    String s = null; s  += "hello"; System.out.println(s); 结果为:nullhello 原因: 先应用String.valueOf ...

  9. Porsche Piwis Tester II V12.100 Version Released

    Piwis Tester II v12.100 Version released today! In this new version we can find the latest type Pors ...

  10. Docker Kubernetes 创建管理 Deployment

    Docker Kubernetes YAML文件创建容器 通过创建Deployment来管理pods从而创建容器.它会同时创建容器.pod.以及Deployment ! 环境: 系统:Centos 7 ...