梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理



来源:晓飞的算法工程笔记 公众号

论文: Gradient Centralization: A New Optimization Technique for Deep Neural Networks

Introduction


  优化器(Optimizer)对于深度神经网络在大型数据集上的训练是十分重要的,如SGD和SGDM,优化器的目标有两个:加速训练过程和提高模型的泛化能力。目前,很多工作研究如何提高如SGD等优化器的性能,如克服训练中的梯度消失和梯度爆炸问题,有效的trick有权值初始化、激活函数、梯度裁剪以及自适应学习率等。而一些工作则从统计的角度对权值和特征值进行标准化来让训练更稳定,比如特征图标准化方法BN以及权值标准化方法WN。。

  与在权值和特征值进行标准化方法不同,论文提出作用于权值梯度的高性能网络优化算法梯度中心化(GC, gradient centralization),能够加速网络训练,提高泛化能力以及兼容模型fine-tune。如图a所示,GC的思想很简单,零均值化梯度向量,能够轻松地嵌入各种优化器中。论文主要贡献如下:

  • 提出新的通用网络优化方法,梯度中心化(GC),不仅能平滑和加速训练过程,还能提高模型的泛化能力。
  • 分析了GC的理论属性,表明GC能够约束损失函数,标准化权值空间和特征值空间,提升模型的泛化能力。另外,约束的损失函数有更好的Lipschitzness(抗扰动能力,函数斜率恒定小于一个Lipschitze常数),让训练更稳定、更高效。

Gradient Centralization


Motivation

  BN和WS使用Z-score标准化分别操作于特征值和权重,实际是间接地对权值的梯度进行约束,从而提高优化时损失函数的Lipschitz属性。受此启发,论文直接对梯度操作,首先尝试了Z-score标准化,但实验发现并没有提升训练的稳定性。之后,尝试计算梯度向量的均值,对梯度向量进行零均值化,实验发现能够有效地提高损失函数的Lipschitz属性,使网络训练更稳定、更具泛化能力,得到梯度中心化(GC)算法。

Notations

  定义一些基础符号,使用$W \in \mathbb{R}^{M \times N}$统一表示全连接层的权值矩阵$W_{fc} \in \mathbb{R}^{C_{in}\times C_{out}}$和卷积层的权值张量$W_{conv} \in \mathbb{R}^{(C_{in} k_1 k_2)\times C_{out}}$,$w_i \in \mathbb{R}M$为权值矩阵$W$的第$i$列,$\mathcal{L}$为目标函数,$\nabla_{W}\mathcal{L}$和$\nabla_{w_i}\mathcal{L}$为$\mathcal{L}$对$W$和$w_i$的梯度,$W$与$\nabla_{W}\mathcal{L}$的大小一样。定义$X$为输入特征图,则$WT X$为输出特征图,$e=\frac{1}{\sqrt{M}}1$为$M$位单位向量(unit vector),$I\in\mathbb{R}^{M\times M}$为单位矩阵(identity matrix)。

Formulation of GC

  对于卷积层或全连接层的权值向量$w_i$,通过反向传播得到其梯度$\nabla_{w_i}\mathcal{L}$,然后如图b所示计算其均值$\mu\nabla_{w_i}\mathcal{L}=\frac{1}{M}{\sum}^M_{j=1} \nabla_{w_{i,j}\mathcal{L}}$,GC操作$\Phi$定义如下:

  也可以将公式1转换为矩阵形式:

  $P$由单位矩阵以及单位向量形成矩阵构成,分别负责保留原值以及求均值。

Embedding of GC to SGDM/Adam

  GC能够简单地嵌入当前的主流网络优化算法中,如SGDM和Adam,直接使用零均值化的梯度$\Phi_{GC}(\nabla_w \mathcal{L})$进行权值的更新。

  算法1和算法2分别展示了将GC嵌入到SGDM和Adam中,基本上不需要对原优化器算法进行修改,仅需加入一行梯度零均值化计算即可,大约仅需0.6sec。

Properties of GC


  下面从理论的角度分析GC为何能提高模型的泛化能力以及加速训练。

Improving Generalization Performance

  GC有一个很重要的优点是提高模型的泛化能力,主要得益于权值空间正则化和特征值空间正则化。

  • Weight space regularization

  首先介绍$P$的物理意义,经过推算可以得到:

  即$P$可以看作映射矩阵,将$\nabla_W \mathcal{L}$映射到空间向量中法向量为$e$的超平面,$P\nabla_W \mathcal{L}$为映射梯度。

  以SGD优化为例,权值梯度的映射能够将权值空间约束在一个超平面或黎曼流形(Riemannian manifold)中,如图2所示,梯度首先映射到$eT(w-wt)=0$的超平面中,然后跟随映射梯度$-P\nabla_{wt}\mathcal{L}$的方向进行更新。从$eT(w-wt)=0$可以得到$eTw{t+1}=eTwt=\cdots=eTw^0$,目标函数实际变为:

  这是一个权值空间$w$的约束优化问题,正则化$w$的解空间,降低了过拟合的可能性(过拟合通常是学习了复杂的权值来适应训练数据),能够提升网络的泛化能力,特别是当训练样本较少的情况下。

  WS对权值进行$eTw=0$的约束,当初始权值不满足约束时,会直接修改权值来满足约束条件。假设进行fine-tune训练,WS则会完全丢弃预训练模型的优势,而GC可以适应任何初始权值$eT(w0-w0)=0$。

  • Output feature space regularization

  以SGD优化方法为例,权值更新$w{t+1}=wt-\alphatP\nabla_{w_t}\mathcal{L}$,可以推导得到$wt=w0-P{\sum}{t-1}_{i=0}\alpha{(i)}\nabla_{w{(i)}}\mathcal{L}$。对于任何输入特征向量$x$,有以下定理:

  相关证明可以看原文附录,定理4.1表明输入特征的常量变化会造成输出的变化,而输出的变化量仅与标量$\gamma$和$1Tw0$相关,与当前权值$wt$无关。$\gamma1Tw0$为初始化权值向量缩放后的均值,假设$w0$接近0,则输入特征值的常量变化将几乎不会改变输出特征值,意味着输出特征空间对训练样本的变化更鲁棒。

  对ResNet50的不同初始权值进行可视化,可以看到权值都非常小(小于$e^{-7}$),这说明如果使用GC来训练,输出特征不会对输入特征的变化过于敏感。这个属性正则化输出特征空间,并且提升网络训练的泛化能力。

Accelerating Training Process

  • Optimization landscape smoothing

  前面提到BN和WS都间接地对权值梯度进行约束,使损失函数满足Lipschitz属性,$||\nabla_w\mathcal{L}||_2$和$||\nabla^2_w\mathcal{L}||_2$($w$的Hessian矩阵)都有上界。GC直接对梯度进行约束,也有类似于BN和WS的属性,对比原损失函数满足以下定理:

  相关证明可以看原文附录,定理4.2表明GC比原函数有更好的Lipschitzness,更好的Lipschitzness意味着梯度更加稳定,优化过程也更加平滑,能够类似于BN和WS那样加速训练过程。

  • Gradient explosion suppression

  GC的另一个优点是防止梯度爆炸,使得训练更加稳定,作用原理类似于梯度裁剪。过大的梯度会导致损失严重震荡,难以收敛,而梯度裁剪能够抑制大梯度,使得训练更稳定、更快。

  对梯度的$L_2$ norm和最大值进行了可视化,可以看到使用GC后的值均比原函数要小,这也与定理4.2一致,GC能够让训练过程更平滑、更快。

Experiment


  与BN和WS结合的性能对比。

  Mini-ImageNet上的对比实验。

  CIFAR100上的对比实验。

  ImageNet上的对比实验。

  细粒度数据集上的性能对比。

  检测与分割任务上的性能对比。

Conclustion


  梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理。





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

Gradient Centralization: 简单的梯度中心化,一行代码加速训练并提升泛化能力 | ECCV 2020 Oral的更多相关文章

  1. 简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020

    论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果   来源:晓飞的算 ...

  2. 一行代码调用实现带字段选取+条件判断+排序+分页功能的增强ORM框架

    问题:3行代码 PDF.NET是一个开源的数据开发框架,它的特点是简单.轻量.快速,易上手,而且是一个注释完善的国产开发框架,受到不少朋友的欢迎,也在我们公司的项目中多次使用.但是,PDF.NET比起 ...

  3. 学习笔记57—归一化 (Normalization)、标准化 (Standardization)和中心化/零均值化 (Zero-centered)

    1 概念   归一化:1)把数据变成(0,1)或者(1,1)之间的小数.主要是为了数据处理方便提出来的,把数据映射到0-1范围之内处理,更加便捷快速.2)把有量纲表达式变成无量纲表达式,便于不同单位或 ...

  4. 一个轻client,多语言支持,去中心化,自己主动负载,可扩展的实时数据写服务的实现方案讨论

    背景 背景是设计一个实时数据接入的模块,负责接收client的实时数据写入(如日志流,点击流),数据支持直接下沉到HBase上(兴许提供HBase上的查询),或先持久化到Kafka里.方便兴许进行一些 ...

  5. [数据预处理]-中心化 缩放 KNN(一)

    据预处理是总称,涵盖了数据分析师使用它将数据转处理成想要的数据的一系列操作.例如,对某个网站进行分析的时候,可能会去掉 html 标签,空格,缩进以及提取相关关键字.分析空间数据的时候,一般会把带单位 ...

  6. 在dotnet core下去中心化访问HTTP服务集群

    一般应用服务都会部署到多台服务器之上,一.可以通过硬件得到更多的并发处理能力:二.可以避免单点太故障的出现,从而确保服务7X24有效运作.当访问这些HTTP服务的情况一般都是经过反向代理服务进行统一处 ...

  7. 理解去中心化 稳定币 DAI

    本文转载于深入浅出区块链, 原文链接 随着摩根大通推出JPM Coin 稳定币,可以预见稳定币将成为区块链落地的一大助推器. 坦白来讲,对于一个程序员的我来讲(不懂一点专业经济和金融),理解DAI的机 ...

  8. 去中心化存储项目终极指南 | Filecoin, Storj 和 PPIO 项目异同

    Filecoin,Storj 以及 PPIO 这三个存储公链的设计思路是不一样的,没有优劣之分,写这篇文章也并不是为了争论各项目的好坏对错.去中心化存储是一个长期商业赛道,不同团队在同一个赛道上往不同 ...

  9. 为什么比特币和以太坊未必真得比EOS更去中心化?

    在区块链行业里,有两派人一直在争论:一个是以比特币和以太坊为首的社群,另一个是以EOS为首的社群.这两群人一直在争论谁才是真正的未来,双方都认为自己这边更有未来.其中EOS抗争的重点就是100万TPS ...

随机推荐

  1. Scala 基础(四):Scala变量 (一) 基础

    1.概念 变量相当于内存中一个数据存储空间的表示,你可以把变量看做是一个房间的门 牌号,通过门牌号我们可以找到房间,而通过变量名可以访问到变量(值). 2 变量使用的基本步骤 1) 声明/定义变量 ( ...

  2. JVM 专题九:运行时数据区(四)本地方法栈

    1. 本地方法栈 2. 什么是本地方法栈? Java虚拟机栈用于管理Java方法的调用,而本地方法栈用于管理本地方法的调用   本地方法栈,也是线程私有的. 允许被实现成固定或者是可动态拓展的内存大小 ...

  3. python之爬虫(十一) 实例爬取上海高级人民法院网开庭公告数据

    通过前面的文章已经学习了基本的爬虫知识,通过这个例子进行一下练习,毕竟前面文章的知识点只是一个 一个单独的散知识点,需要通过实际的例子进行融合 分析网站 其实爬虫最重要的是前面的分析网站,只有对要爬取 ...

  4. python之爬虫(十) Selenium库的使用

    一.什么是Selenium selenium 是一套完整的web应用程序测试系统,包含了测试的录制(selenium IDE),编写及运行(Selenium Remote Control)和测试的并行 ...

  5. 安装 VsCode 插件安装以及配置

    安装vscode 官方网站 https://code.visualstudio.com/ 下载后 1.双击vscode.exe 2.选择 我接受  3.一路下一步,遇到方框就选4.点击  安装按钮 v ...

  6. CMMI规范目录结构

  7. Spring配置类深度剖析-总结篇(手绘流程图,可白嫖)

    生命太短暂,不要去做一些根本没有人想要的东西.本文已被 https://www.yourbatman.cn 收录,里面一并有Spring技术栈.MyBatis.JVM.中间件等小而美的专栏供以免费学习 ...

  8. 使用Java带你打造一款简单的英语学习系统

    [一.项目背景] 随着移动互联网的发展,英语学习系统能结构化的组织海量资料.针对用户个性需求,有的放矢地呈现给用户,从而为英语学习者提供便利,提升他们的学习效率. [二.项目目标] 1. 实现美观的界 ...

  9. (2)简单理解和使用webpack-dev-server

    webpack-dev-server能做什么? 每次打包都得像之前一样使用webapck 入口文件 -o 出口文件,每次修改都得打包一次过于麻烦,可以使用webpack-dev-server实现自动打 ...

  10. CUDA Programming Guide 学习笔记

    CUDA学习笔记 GPU架构 GPU围绕流式多处理器(SM)的可扩展阵列搭建,每个GPU有多个SM,每个SM支持数百个线程并发执行.目前Nvidia推出了6种GPU架构(按时间顺序,详见下图):Fer ...