干货 | 这可能全网最好的BatchNorm详解
文章来自:公众号【机器学习炼丹术】。求关注~
其实关于BN层,我在之前的文章“梯度爆炸”那一篇中已经涉及到了,但是鉴于面试经历中多次问道这个,这里再做一个更加全面的讲解。
Internal Covariate Shift(ICS)
Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。
这里做一个简单的数学定义,对于全链接网络而言,第i层的数学表达可以体现为:
\(Z^i=W^i\times input^i+b^i\)
\(input^{i+1}=g^i(Z^i)\)
- 第一个公式就是一个简单的线性变换;
- 第二个公式是表示一个激活函数的过程。
【怎么理解ICS问题】
我们知道,随着梯度下降的进行,每一层的参数\(W^i,b^i\)都会不断地更新,这意味着\(Z^i\)的分布也不断地改变,从而\(input^{i+1}\)的分布发生了改变。这意味着,除了第一层的输入数据不改变,之后所有层的输入数据的分布都会随着模型参数的更新发生改变,而每一层就要不停的去适应这种数据分布的变化,这个过程就是Internal Covariate Shift。
BN解决的问题
【ICS带来的收敛速度慢】
因为每一层的参数不断发生变化,从而每一层的计算结果的分布发生变化,后层网络不断地适应这种分布变化,这个时候会让整个网络的学习速度过慢。
【梯度饱和问题】
因为神经网络中经常会采用sigmoid,tanh这样的饱和激活函数(saturated actication function),因此模型训练有陷入梯度饱和区的风险。解决这样的梯度饱和问题有两个思路:第一种就是更为非饱和性激活函数,例如线性整流函数ReLU可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也就是Normalization的思路。
Batch Normalization
batchNormalization就像是名字一样,对一个batch的数据进行normalization。
现在假设一个batch有3个数据,每个数据有两个特征:(1,2),(2,3),(0,1)
如果做一个简单的normalization,那么就是计算均值和方差,把数据减去均值除以标准差,变成0均值1方差的标准形式。
对于第一个特征来说:
\(\mu=\frac{1}{3}(1+2+0)=1\)
\(\sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67\)
【通用公式】
\(\mu=\frac{1}{m}\sum_{i=1}^m{Z}\)
\(\sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu)\)
\(\hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}\)
- 其中m表示一个batch的数量。
- \(\epsilon\)是一个极小数,防止分母为0。
目前为止,我们做到了让每个特征的分布均值为0,方差为1。这样分布都一样,一定不会有ICS问题
如同上面提到的,Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。每一层的分布都相同,所有任务的数据分布都相同,模型学啥呢
【0均值1方差数据的弊端】
- 数据表达能力的缺失;
- 通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。(线性区域和饱和区域都不理想,最好是非线性区域)
为了解决这个问题,BN层引入了两个可学习的参数\(\gamma\)和\(\beta\),这样,经过BN层normalization的数据其实是服从\(\beta\)均值,\(\gamma^2\)方差的数据。
所以对于某一层的网络来说,我们现在变成这样的流程:
- \(Z=W\times input^i+b\)
- \(\hat{Z}=\gamma \times \frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta\)
- \(input^{i+1}=g(\hat{Z})\)
(上面公式中,省略了\(i\),总的来说是表示第i层的网络层产生第i+1层输入数据的过程)
测试阶段的BN
我们知道BN在每一层计算的\(\mu\)与\(\sigma^2\) 都是基于当前batch中的训练数据,但是这就带来了一个问题:我们在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,这样的\(\sigma^2\)和\(\mu\)要怎么计算呢?
利用训练集训练好模型之后,其实每一层的BN层都保留下了每一个batch算出来的\(\mu\)和\(\sigma^2\).然后呢利用整体的训练集来估计测试集的\(\mu_{test}\)和\(\sigma_{test}^2\)
\(\mu_{test}=E(\mu_{train})\)
\(\sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2)\)
然后再对测试机进行BN层:

当然,计算训练集的\(\mu\)和\(\simga\)的方法除了上面的求均值之外。吴恩达老师在其课程中也提出了,可以使用指数加权平均的方法。不过都是同样的道理,根据整个训练集来估计测试机的均值方差。
BN层的好处有哪些
BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度。
BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题
通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习\(\gamma\)与 \(\beta\) 又让数据保留更多的原始信息。BN具有一定的正则化效果
在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音
BN与其他normalizaiton的比较
【weight normalization】
Weight Normalization是对网络权值进行normalization,也就是L2 norm。
相对于BN有下面的优势:
- WN通过重写神经网络的权重的方式来加速网络参数的收敛,不依赖于mini-batch。BN因为以来minibatch所以BN不能用于RNN网路,而WN可以。而且BN要保存每一个batch的均值方差,所以WN节省内存;
- BN的优点中有正则化效果,但是添加噪音不适合对噪声敏感的强化学习、GAN等网络。WN可以引入更小的噪音。
但是WN要特别注意参数初始化的选择。
【Layer normalization】
更常见的比较是BN与LN的比较。
BN层有两个缺点:
- 无法进行在线学习,因为在线学习的mini-batch为1;LN可以
- 之前提到的BN不能用在RNN中;LN可以
- 消耗一定的内存来记录均值和方差;LN不用
但是,在CNN中LN并没有取得比BN更好的效果。
参考链接:
- https://zhuanlan.zhihu.com/p/34879333
- https://www.zhihu.com/question/59728870
- https://zhuanlan.zhihu.com/p/113233908
- https://www.zhihu.com/question/55890057/answer/267872896
干货 | 这可能全网最好的BatchNorm详解的更多相关文章
- [转帖]HTTPS系列干货(一):HTTPS 原理详解
HTTPS系列干货(一):HTTPS 原理详解 https://tech.upyun.com/article/192/HTTPS%E7%B3%BB%E5%88%97%E5%B9%B2%E8%B4%A7 ...
- 【转】HTTPS系列干货(一):HTTPS 原理详解
HTTPS系列干货(一):HTTPS 原理详解 前言 HTTPS(全称:HyperText Transfer Protocol over Secure Socket Layer),其实 HTTPS 并 ...
- Mybatis系列全解(五):全网最全!详解Mybatis的Mapper映射文件
封面:洛小汐 作者:潘潘 若不是生活所迫,谁愿意背负一身才华. 前言 上节我们介绍了 < Mybatis系列全解(四):全网最全!Mybatis配置文件 XML 全貌详解 >,内容很详细( ...
- HTTPS系列干货(一):HTTPS 原理详解
HTTPS(全称:HyperText Transfer Protocol over Secure Socket Layer),其实 HTTPS 并不是一个新鲜协议,Google 很早就开始启用了,初衷 ...
- 【腾讯Bugly干货分享】iOS10 SiriKit QQ适配详解
本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/57ece0331288fb4d31137da6 1. 概述 苹果在iOS10开放 ...
- 干货分享:Academic Essay写作套路详解
你想过如何中立的表达自己吗?大概只有10%不到的同学,会真正重视这个细节.但很多留学生能顺利写完作文已经不容易,还要注意什么中立不中立的.我知道这个标准,对许多同学有些过分,但很残酷的告诉你,这的确是 ...
- 干货分享:Research Essay写作规范详解
同学们在刚到国外时觉得一切都很新鲜,感觉到处都在吸引着他们,但是大部分留学生在刚碰到Research Essay便是一头包.其实Research Essay也没有想象中的那么难,只是留学生们初次接触, ...
- 全程干货,requests模块与selenium框架详解
requests模块 前言: 通常我们利用Python写一些WEB程序.webAPI部署在服务端,让客户端request,我们作为服务器端response数据: 但也可以反主为客利用Python的re ...
- Mybatis系列全解(四):全网最全!Mybatis配置文件XML全貌详解
封面:洛小汐 作者:潘潘 做大事和做小事的难度是一样的.两者都会消耗你的时间和精力,所以如果决心做事,就要做大事,要确保你的梦想值得追求,未来的收获可以配得上你的努力. 前言 上一篇文章 <My ...
随机推荐
- socket 建立网络连接,client && server
client代码: package socket; import java.io.IOException; import java.net.Socket; /** * 客户端_聊天室 * * @aut ...
- 与跨域相关的 jsonp 劫持与 CORS 配置错误
参考文章: CORS(跨域资源共享)错误配置漏洞的高级利用 JSONP劫持CORS跨源资源共享漏洞 JSONP绕过CSRF防护token 读取型CSRF-需要交互的内容劫持 跨域资源共享 CORS 详 ...
- 记一道CTF隐写题解答过程
0x00 前言 由于我是这几天才开始接触隐写这种东西,所以作为新手我想记录一下刚刚所学.这道CTF所需的知识点包括了图片的内容隐藏,mp3隐写,base64解密,当铺解密,可能用到的工具包括bin ...
- 深入理解JVM(③)学习Java的内存模型
前言 Java内存模型(Java Memory Model)用来屏蔽各种硬件和操作系统的内存访问差异,这使得Java能够变得非常灵活而不用考虑各系统间的兼容性等问题.定义Java内存模型并非一件容易的 ...
- day64 django模型层
目录 一.单表操作(增删改) 二.必知必会13个方法 三.查看内部的sql语句的方法 四.神奇的双下划线查询 五.一对多外键的增删改查 六.多对多外键的增删改查 七.正反向查询概念 八.多表查询 1 ...
- 【XCTF】ics-05
信息: 题目来源:XCTF 4th-CyberEarth 标签:PHP.伪协议 题目描述:其他破坏者会利用工控云管理系统设备维护中心的后门入侵系统 解题过程 题目给了一个工控管理系统,并提示存在后门, ...
- 攻防世界-Web-ics-05
根据题目提示直接进入设备维护中心 点击云平台设备维护中心发现page=index LFI漏洞的黑盒判断方法: 单纯的从URL判断的话,URL中path.dir.file.pag.page.archiv ...
- scrapy (三) : 请求传参
scrapy 请求传参 1.定义数据结构item.py文件 ''' field: item.py ''' # -*- coding: utf-8 -*- # Define here the model ...
- Python之 爬虫(十二)关于深度优先和广度优先
网站的树结构 深度优先算法和实现 广度优先算法和实现 网站的树结构 通过伯乐在线网站为例子: 并且我们通过访问伯乐在线也是可以发现,我们从任何一个子页面其实都是可以返回到首页,所以当我们爬取页面的数据 ...
- java大数据最全课程学习笔记(1)--Hadoop简介和安装及伪分布式
Hadoop简介和安装及伪分布式 大数据概念 大数据概论 大数据(Big Data): 指无法在一定时间范围内用常规软件工具进行捕捉,管理和处理的数据集合,是需要新处理模式才能具有更强的决策力,洞察发 ...