『TensorFlow』批处理类
基础知识

下面有莫凡的对于批处理的解释:
fc_mean,fc_var = tf.nn.moments(
Wx_plus_b,
axes=[0],
# 想要 normalize 的维度, [0] 代表 batch 维度
# 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度
)
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b,fc_mean,fc_var,shift,scale,epsilon)
# 上面那一步, 在做如下事情:
# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
# Wx_plus_b = Wx_plus_b * scale + shift
tf.contrib.layers.batch_norm:封装好的批处理类
class batch_norm():
'''batch normalization层''' def __init__(self, epsilon=1e-5,
momentum=0.9, name='batch_norm'):
'''
初始化
:param epsilon: 防零极小值
:param momentum: 滑动平均参数
:param name: 节点名称
'''
with tf.variable_scope(name):
self.epsilon = epsilon
self.momentum = momentum
self.name = name def __call__(self, x, train=True):
# 一个封装了的会在内部调用batch_normalization进行正则化的高级接口
return tf.contrib.layers.batch_norm(x,
decay=self.momentum, # 滑动平均参数
updates_collections=None,
epsilon=self.epsilon,
scale=True,
is_training=train, # 影响滑动平均
scope=self.name)
1.
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
need to be added as a dependency to the `train_op`. For example:
```python
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
```
One can set updates_collections=None to force the updates in place, but that
can have a speed penalty, especially in distributed settings.2.
is_training: Whether or not the layer is in training mode. In training mode
it would accumulate the statistics of the moments into `moving_mean` and
`moving_variance` using an exponential moving average with the given
`decay`. When it is not in training mode then it would use the values of
the `moving_mean` and the `moving_variance`.
tf.nn.batch_normalization:原始接口封装使用
实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装,这个类又进行了一次封装(主要是制订了一部分默认参数),实际操作时可以仅仅使用tf.contrib.layers.batch_norm函数,它已经足够方便了。
添加了滑动平均处理之后,也就是不使用封装,直接使用tf.nn.moments和tf.nn.batch_normalization实现的batch_norm函数:
def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed
另一种将滑动平均展开了的方式,
def batch_norm(x, size, training, decay=0.999):
beta = tf.Variable(tf.zeros([size]), name='beta')
scale = tf.Variable(tf.ones([size]), name='scale')
pop_mean = tf.Variable(tf.zeros([size]))
pop_var = tf.Variable(tf.ones([size]))
epsilon = 1e-3 batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) def batch_statistics():
with tf.control_dependencies([train_mean, train_var]):
return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm') def population_statistics():
return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm') return tf.cond(training, batch_statistics, population_statistics)
注, tf.cond:流程控制,参数一True,则执行参数二的函数,否则执行参数三函数。
『TensorFlow』批处理类的更多相关文章
- 『TensorFlow』专题汇总
		
TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...
 - 『TensorFlow』流程控制
		
『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...
 - 『TensorFlow』读书笔记_降噪自编码器
		
『TensorFlow』降噪自编码器设计 之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...
 - 『TensorFlow』梯度优化相关
		
tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...
 - 『TensorFlow』模型保存和载入方法汇总
		
『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...
 - 『TensorFlow』命令行参数解析
		
argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...
 - 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
		
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
 - 『TensorFlow』DCGAN生成动漫人物头像_下
		
『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...
 - 『TensorFlow』滑动平均
		
滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...
 
随机推荐
- C#中的double类型数据向SQL sqerver 存储与读取问题
			
1.存储 由于double类型在SQLsever中并没有对应数据,试过对应float.real类型,发现小数位都存在四舍五入的现象,目前我使用的是decimal类型,用此类型时个人觉得小数位数应该比自 ...
 - 4、 LwIP协议栈规范翻译——流程模型
			
4.流程模型 协议实现的流程模型描述了系统被划分为不同的流程的方式.用于实现通信协议的一个流程模型是让每个协议作为一个独立的进程运行.有了这个模型,严格的协议分层被强制执行,并且协议之间的通信点必须严 ...
 - Hibernate 补充 ManyToOne、OneToMany、OneToOne的使用例
			
1.前言 Hibernate 为程序员提供一种级联操作,在编写程序时,通过 Hibernate 的级联功能可以很方便的操作数据库的主从表的数据, 我们最常用的级联是级联保存和级联删除. ...
 - int bool 字符串 列表 字典 集合
			
1.int和bool 输出i的最大二进制位数inti = 1000 print(i.bit_length()) 2. str int bool list set dict tuple 相互转换 pr ...
 - 用git如何把单个文件回退到某一版本
			
暂定此文件为a.jsp 1.首先到a.jsp所在目录: 通过 git log a.jsp 查看a.jsp的更改记录 2.找到想要回退的版本号:例如 fcd2093 通过 git reset fcd ...
 - ECMAScript6 入门教程 初学记录let命令 块级作用域
			
一.基本语法-let命令 (1)ES6新增了let命令,用来声明变量.所声明的变量,只在let命令所在的代码块内有效. 循环的计数器,就很合适使用let命令.计数器i只在for循环体内有效,在循环体外 ...
 - Spring整合MyBatis(简单登录Demo)
			
SpringMvc简单整合(登录模块) 1.1 整合思路 1.SqlSessionFactory对象应该放到spring容器中作为单例存在. 2.传统dao的开发方式中,应该从spring容器中获得s ...
 - “==”和equals
			
== 比较的是变量(栈)内存中存放的对象的(堆)内存地址,用来判断两个对象的地址是否相同,即是否是指相同一个对象.比较的是真正意义上的指针操作. 1.比较的是操作符两端的操作数是否是同一个对象.2.两 ...
 - Jmeter接口测试+压力测试+环境配置+证书导出
			
jmeter是apache公司基于java开发的一款开源压力测试工具,体积小,功能全,使用方便,是一个比较轻量级的测试工具,使用起来非常简单.因为jmeter是java开发的,所以运行的时候必须先要安 ...
 - linux bash tutorial
			
bash read-special-keys-in-bash xdotool linux 登录启动顺序