BN 详解和使用Tensorflow实现(参数理解)
Tensorflow BN具体实现(多种方式):
理论知识(参照大佬):https://blog.csdn.net/hjimce/article/details/50866313
补充知识:
① tf.nn.moments 这个函数的输出就是BN需要的mean和variance。
方式1:
tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None
):原始接口封装使用
x
·mean moments方法的输出之一
·variance moments方法的输出之一
·offset BN需要学习的参数
·scale BN需要学习的参数
·variance_epsilon 归一化时防止分母为0加的一个常量
实现代码:
import tensorflow as tf # 实现Batch Normalization
def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
# 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
shape = x.shape
assert len(shape) in [2,4] param_shape = shape[-1]
with tf.variable_scope(name):
# 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
gamma = tf.get_variable('gamma',param_shape,initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape,initializer=tf.constant_initializer(0)) # 计算当前整个batch的均值与方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(x,axes,name='moments') # 采用滑动平均更新均值与方差
ema = tf.train.ExponentialMovingAverage(moving_decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var) # 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
mean, var = tf.cond(tf.equal(is_training,True),mean_var_with_update,
lambda:(ema.average(batch_mean),ema.average(batch_var))) # 最后执行batch normalization
return tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)
方式2:
tf.contrib.layers.batch_norm:封装好的批处理类
实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装
参数:
1 inputs: 输入
2 decay :衰减系数。合适的衰减系数值接近1.0,特别是含多个9的值:0.999,0.99,0.9。如果训练集表现很好而验证/测试集表现得不好,选择
小的系数(推荐使用0.9)。如果想要提高稳定性,zero_debias_moving_mean设为True
3 center:如果为True,有beta偏移量;如果为False,无beta偏移量
4 scale:如果为True,则乘以gamma。如果为False,gamma则不使用。当下一层是线性的时(例如nn.relu),由于缩放可以由下一层完成,
所以可以禁用该层。
5 epsilon:避免被零除
6 activation_fn:用于激活,默认为线性激活函数
7 param_initializers : beta, gamma, moving mean and moving variance的优化初始化
8 param_regularizers : beta and gamma正则化优化
9 updates_collections :Collections来收集计算的更新操作。updates_ops需要使用train_op来执行。如果为None,则会添加控件依赖项以
确保更新已计算到位。
10 is_training:图层是否处于训练模式。在训练模式下,它将积累转入的统计量moving_mean并 moving_variance使用给定的指数移动平均值 decay。当它不是在训练模式,那么它将使用的数值moving_mean和moving_variance。
11 scope:可选范围variable_scope
注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作被放入
tf.GraphKeys.UPDATE_OPS,所以需要添加它们作为依赖项train_op
。例如:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss)
可以将updates_collections = None设置为强制更新,但可能会导致速度损失,尤其是在分布式设置中。
实现代码:
import tensorflow as tf def batch_norm(x,epsilon=1e-5, momentum=0.9,train=True, name="batch_norm"):
with tf.variable_scope(name):
epsilon = epsilon
momentum = momentum
name = name
return tf.contrib.layers.batch_norm(x, decay=momentum, updates_collections=None, epsilon=epsilon,
scale=True, is_training=train,scope=name)
BN一般放哪一层?
BN层的设定一般是按照conv->bn->scale->relu的顺序来形成一个block
训练和测试时 BN的区别???
bn层训练的时候,基于当前batch的mean和std调整分布;当测试的时候,也就是测试的时候,基于全部训练样本的mean和std调整分布
所以,训练的时候需要让BN层工作,并且保存BN层学习到的参数。测试的时候加载训练得到的参数来重构测试集。
BN 详解和使用Tensorflow实现(参数理解)的更多相关文章
- Java虚拟机详解(五)------JVM参数(持续更新)
JVM参数有很多,其实我们直接使用默认的JVM参数,不去修改都可以满足大多数情况.但是如果你想在有限的硬件资源下,部署的系统达到最大的运行效率,那么进行相关的JVM参数设置是必不可少的.下面我们就来对 ...
- 一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)
雷锋网按:本文作者陆池,原文载于作者个人博客,雷锋网已获授权. 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实例,这个星期就用 ...
- Java虚拟机详解03----常用JVM配置参数
[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/4 ...
- Apache:详解QSA,PT,L,E参数的作用
[QSA] 当被替换的URI包含有query string的时候,apache的默认行为是,丢弃原有的query string 并直接使用新产生的query string,如果加上了[QSA]选项,那 ...
- C#调用存储过程详解(带返回值、参数输入输出等)
CREATE PROCEDURE [dbo].[GetNameById] @studentid varchar(8), @studentname nvarchar(50) OUTPUT AS BEGI ...
- object detection api调参详解(兼SSD算法参数详解)
一.引言 使用谷歌提供的object detection api图像识别框架,我们可以很方便地重新训练一个预训练模型,用于自己的具体业务.以我所使用的ssd_mobilenet_v1预训练模型为例,训 ...
- Python趣味入门9:函数是你走过的套路,详解函数、调用、参数及返回值
1.概念 琼恩·雪诺当上守夜人的司令后,为训练士兵对付僵尸兵团,把成功斩杀僵尸的一系列动作编排成了"葵花宝典剑法",这就是函数.相似,在计算机世界,一系列前后连续的计算机语句组合在 ...
- Java 虚拟机系列二:垃圾收集机制详解,动图帮你理解
前言 上篇文章已经给大家介绍了 JVM 的架构和运行时数据区 (内存区域),本篇文章将给大家介绍 JVM 的重点内容--垃圾收集.众所周知,相比 C / C++ 等语言,Java 可以省去手动管理内存 ...
- Eureka(一)术语详解(用具体的事物理解抽象的概念)
最近工作较闲,所以自己研究了下eureka的原理,实现,和集群搭建等.(注:我没实操过eureka集群项目,都是自己做的demo产生的结论,如果有错误欢迎指出) 首先说一下我对eureka的一些术语的 ...
随机推荐
- day03-执行python方式、变量及数据类型简介
目录 执行Python程序的两种方式 1. 第一种:交互式 2. 第二种:命令式 3. Python执行程序的三个阶段 变量 变量 什么是变量 Python中的变量 变量名的命名规范 内存管理 定义变 ...
- 浏览器 <html>相关
若页面需默认用极速核,增加标签:<meta name="renderer" content="webkit"> https://blog.csdn ...
- Secret of Chocolate Poles (Aizu1378——dp)
Select Of Chocolate Poles 题意:有一个竖直放置的高度为l cm的盒子,现在有三种方块分别为1cm的白块,1cm的黑块,k cm的黑块,要求第一块放进去的必须是黑色的,盒子最上 ...
- [luogu2594 ZJOI2009]染色游戏(博弈论)
传送门 Solution 对于硬币问题,结论是:当前局面的SG值等于所有背面朝上的单个硬币SG值的异或和 对于求单个背面朝上的硬币SG值...打表找规律吧 Code //By Menteur_Hxy ...
- 【[Offer收割]编程练习赛13 D】骑士游历(矩阵模板,乘法,加法,乘方)
[题目链接]:http://hihocoder.com/problemset/problem/1504 [题意] [题解] 可以把二维的坐标转成成一维的; 即(x,y)->(x-1)*8+y 然 ...
- Win8.1 Hyper-V 共享本机IP上网
公司的Win8.1自带了Hyper v,可是死活连接不到网络. 原因是公司只给每人分配一个局域网IP,而默认情况下Hyper-V的虚拟机会动态分配了一个没有经过MIS人员许可的IP…… 百度了N久解决 ...
- 清北学堂模拟赛d3t4 a
分析:很水的一道题,就是用栈来看看是不是匹配就好了,只是最后没有判断栈是否为空而WA了一个点,以后做题要注意了. #include <bits/stdc++.h> using namesp ...
- 使用git bash向github远程仓库提交代码
1.登录github,创建仓库. 2.切换到要提交的文件目录下. 3.打开git bash 3.1.初始化仓库 git init 3.2.将本地仓库与远程仓库关联 git remote add ori ...
- Servlet请求参数编码处理(POST & GET)
小巧,但在中文语境下,还是要注意的. 以下是关键语句,注意转码的先后顺序,这源于GET是HTTP服务器处理,而POST是WEB容器处理: String name = request.getParame ...
- P1294 高手去散步 洛谷
https://www.luogu.org/problem/show?pid=1294#sub 题目背景 高手最近谈恋爱了.不过是单相思.“即使是单相思,也是完整的爱情”,高手从未放弃对它的追求.今天 ...