Gradient Centralization: 简单的梯度中心化,一行代码加速训练并提升泛化能力 | ECCV 2020 Oral
梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理
来源:晓飞的算法工程笔记 公众号
论文: Gradient Centralization: A New Optimization Technique for Deep Neural Networks

Introduction
优化器(Optimizer)对于深度神经网络在大型数据集上的训练是十分重要的,如SGD和SGDM,优化器的目标有两个:加速训练过程和提高模型的泛化能力。目前,很多工作研究如何提高如SGD等优化器的性能,如克服训练中的梯度消失和梯度爆炸问题,有效的trick有权值初始化、激活函数、梯度裁剪以及自适应学习率等。而一些工作则从统计的角度对权值和特征值进行标准化来让训练更稳定,比如特征图标准化方法BN以及权值标准化方法WN。。

与在权值和特征值进行标准化方法不同,论文提出作用于权值梯度的高性能网络优化算法梯度中心化(GC, gradient centralization),能够加速网络训练,提高泛化能力以及兼容模型fine-tune。如图a所示,GC的思想很简单,零均值化梯度向量,能够轻松地嵌入各种优化器中。论文主要贡献如下:
- 提出新的通用网络优化方法,梯度中心化(GC),不仅能平滑和加速训练过程,还能提高模型的泛化能力。
- 分析了GC的理论属性,表明GC能够约束损失函数,标准化权值空间和特征值空间,提升模型的泛化能力。另外,约束的损失函数有更好的Lipschitzness(抗扰动能力,函数斜率恒定小于一个Lipschitze常数),让训练更稳定、更高效。
Gradient Centralization
Motivation
BN和WS使用Z-score标准化分别操作于特征值和权重,实际是间接地对权值的梯度进行约束,从而提高优化时损失函数的Lipschitz属性。受此启发,论文直接对梯度操作,首先尝试了Z-score标准化,但实验发现并没有提升训练的稳定性。之后,尝试计算梯度向量的均值,对梯度向量进行零均值化,实验发现能够有效地提高损失函数的Lipschitz属性,使网络训练更稳定、更具泛化能力,得到梯度中心化(GC)算法。
Notations
定义一些基础符号,使用$W \in \mathbb{R}^{M \times N}$统一表示全连接层的权值矩阵$W_{fc} \in \mathbb{R}^{C_{in}\times C_{out}}$和卷积层的权值张量$W_{conv} \in \mathbb{R}^{(C_{in} k_1 k_2)\times C_{out}}$,$w_i \in \mathbb{R}M$为权值矩阵$W$的第$i$列,$\mathcal{L}$为目标函数,$\nabla_{W}\mathcal{L}$和$\nabla_{w_i}\mathcal{L}$为$\mathcal{L}$对$W$和$w_i$的梯度,$W$与$\nabla_{W}\mathcal{L}$的大小一样。定义$X$为输入特征图,则$WT X$为输出特征图,$e=\frac{1}{\sqrt{M}}1$为$M$位单位向量(unit vector),$I\in\mathbb{R}^{M\times M}$为单位矩阵(identity matrix)。
Formulation of GC

对于卷积层或全连接层的权值向量$w_i$,通过反向传播得到其梯度$\nabla_{w_i}\mathcal{L}$,然后如图b所示计算其均值$\mu\nabla_{w_i}\mathcal{L}=\frac{1}{M}{\sum}^M_{j=1} \nabla_{w_{i,j}\mathcal{L}}$,GC操作$\Phi$定义如下:

也可以将公式1转换为矩阵形式:

$P$由单位矩阵以及单位向量形成矩阵构成,分别负责保留原值以及求均值。
Embedding of GC to SGDM/Adam
GC能够简单地嵌入当前的主流网络优化算法中,如SGDM和Adam,直接使用零均值化的梯度$\Phi_{GC}(\nabla_w \mathcal{L})$进行权值的更新。

算法1和算法2分别展示了将GC嵌入到SGDM和Adam中,基本上不需要对原优化器算法进行修改,仅需加入一行梯度零均值化计算即可,大约仅需0.6sec。
Properties of GC
下面从理论的角度分析GC为何能提高模型的泛化能力以及加速训练。
Improving Generalization Performance
GC有一个很重要的优点是提高模型的泛化能力,主要得益于权值空间正则化和特征值空间正则化。
Weight space regularization
首先介绍$P$的物理意义,经过推算可以得到:

即$P$可以看作映射矩阵,将$\nabla_W \mathcal{L}$映射到空间向量中法向量为$e$的超平面,$P\nabla_W \mathcal{L}$为映射梯度。

以SGD优化为例,权值梯度的映射能够将权值空间约束在一个超平面或黎曼流形(Riemannian manifold)中,如图2所示,梯度首先映射到$eT(w-wt)=0$的超平面中,然后跟随映射梯度$-P\nabla_{wt}\mathcal{L}$的方向进行更新。从$eT(w-wt)=0$可以得到$eTw{t+1}=eTwt=\cdots=eTw^0$,目标函数实际变为:

这是一个权值空间$w$的约束优化问题,正则化$w$的解空间,降低了过拟合的可能性(过拟合通常是学习了复杂的权值来适应训练数据),能够提升网络的泛化能力,特别是当训练样本较少的情况下。
WS对权值进行$eTw=0$的约束,当初始权值不满足约束时,会直接修改权值来满足约束条件。假设进行fine-tune训练,WS则会完全丢弃预训练模型的优势,而GC可以适应任何初始权值$eT(w0-w0)=0$。
Output feature space regularization
以SGD优化方法为例,权值更新$w{t+1}=wt-\alphatP\nabla_{w_t}\mathcal{L}$,可以推导得到$wt=w0-P{\sum}{t-1}_{i=0}\alpha{(i)}\nabla_{w{(i)}}\mathcal{L}$。对于任何输入特征向量$x$,有以下定理:

相关证明可以看原文附录,定理4.1表明输入特征的常量变化会造成输出的变化,而输出的变化量仅与标量$\gamma$和$1Tw0$相关,与当前权值$wt$无关。$\gamma1Tw0$为初始化权值向量缩放后的均值,假设$w0$接近0,则输入特征值的常量变化将几乎不会改变输出特征值,意味着输出特征空间对训练样本的变化更鲁棒。

对ResNet50的不同初始权值进行可视化,可以看到权值都非常小(小于$e^{-7}$),这说明如果使用GC来训练,输出特征不会对输入特征的变化过于敏感。这个属性正则化输出特征空间,并且提升网络训练的泛化能力。
Accelerating Training Process
Optimization landscape smoothing
前面提到BN和WS都间接地对权值梯度进行约束,使损失函数满足Lipschitz属性,$||\nabla_w\mathcal{L}||_2$和$||\nabla^2_w\mathcal{L}||_2$($w$的Hessian矩阵)都有上界。GC直接对梯度进行约束,也有类似于BN和WS的属性,对比原损失函数满足以下定理:

相关证明可以看原文附录,定理4.2表明GC比原函数有更好的Lipschitzness,更好的Lipschitzness意味着梯度更加稳定,优化过程也更加平滑,能够类似于BN和WS那样加速训练过程。
Gradient explosion suppression
GC的另一个优点是防止梯度爆炸,使得训练更加稳定,作用原理类似于梯度裁剪。过大的梯度会导致损失严重震荡,难以收敛,而梯度裁剪能够抑制大梯度,使得训练更稳定、更快。

对梯度的$L_2$ norm和最大值进行了可视化,可以看到使用GC后的值均比原函数要小,这也与定理4.2一致,GC能够让训练过程更平滑、更快。
Experiment

与BN和WS结合的性能对比。

Mini-ImageNet上的对比实验。

CIFAR100上的对比实验。

ImageNet上的对比实验。

细粒度数据集上的性能对比。

检测与分割任务上的性能对比。
Conclustion
梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理。
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

Gradient Centralization: 简单的梯度中心化,一行代码加速训练并提升泛化能力 | ECCV 2020 Oral的更多相关文章
- 简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020
论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果 来源:晓飞的算 ...
- 一行代码调用实现带字段选取+条件判断+排序+分页功能的增强ORM框架
问题:3行代码 PDF.NET是一个开源的数据开发框架,它的特点是简单.轻量.快速,易上手,而且是一个注释完善的国产开发框架,受到不少朋友的欢迎,也在我们公司的项目中多次使用.但是,PDF.NET比起 ...
- 学习笔记57—归一化 (Normalization)、标准化 (Standardization)和中心化/零均值化 (Zero-centered)
1 概念 归一化:1)把数据变成(0,1)或者(1,1)之间的小数.主要是为了数据处理方便提出来的,把数据映射到0-1范围之内处理,更加便捷快速.2)把有量纲表达式变成无量纲表达式,便于不同单位或 ...
- 一个轻client,多语言支持,去中心化,自己主动负载,可扩展的实时数据写服务的实现方案讨论
背景 背景是设计一个实时数据接入的模块,负责接收client的实时数据写入(如日志流,点击流),数据支持直接下沉到HBase上(兴许提供HBase上的查询),或先持久化到Kafka里.方便兴许进行一些 ...
- [数据预处理]-中心化 缩放 KNN(一)
据预处理是总称,涵盖了数据分析师使用它将数据转处理成想要的数据的一系列操作.例如,对某个网站进行分析的时候,可能会去掉 html 标签,空格,缩进以及提取相关关键字.分析空间数据的时候,一般会把带单位 ...
- 在dotnet core下去中心化访问HTTP服务集群
一般应用服务都会部署到多台服务器之上,一.可以通过硬件得到更多的并发处理能力:二.可以避免单点太故障的出现,从而确保服务7X24有效运作.当访问这些HTTP服务的情况一般都是经过反向代理服务进行统一处 ...
- 理解去中心化 稳定币 DAI
本文转载于深入浅出区块链, 原文链接 随着摩根大通推出JPM Coin 稳定币,可以预见稳定币将成为区块链落地的一大助推器. 坦白来讲,对于一个程序员的我来讲(不懂一点专业经济和金融),理解DAI的机 ...
- 去中心化存储项目终极指南 | Filecoin, Storj 和 PPIO 项目异同
Filecoin,Storj 以及 PPIO 这三个存储公链的设计思路是不一样的,没有优劣之分,写这篇文章也并不是为了争论各项目的好坏对错.去中心化存储是一个长期商业赛道,不同团队在同一个赛道上往不同 ...
- 为什么比特币和以太坊未必真得比EOS更去中心化?
在区块链行业里,有两派人一直在争论:一个是以比特币和以太坊为首的社群,另一个是以EOS为首的社群.这两群人一直在争论谁才是真正的未来,双方都认为自己这边更有未来.其中EOS抗争的重点就是100万TPS ...
随机推荐
- GAN网络从入门教程(三)之DCGAN原理
目录 DCGAN简介 DCGAN的特点 几个重要概念 下采样(subsampled) 上采样(upsampling) 反卷积(Deconvolution) 批标准化(Batch Normalizati ...
- day60 django入门
目录 一.静态文件配置 1 引子 2 如何配置 1 在settins.py中的具体配置 2 静态文件的动态解析(html页面中) 二.request对象方法初识 三.pycharm链接数据库(mysq ...
- Vue 的响应式原理中 Object.defineProperty 有什么缺陷?
Object.defineProperty只能劫持对象的属性,从而需要对每个对象,每个属性进行遍历,如果,属性值是对象,还需要深度遍历.Proxy可以劫持整个对象,并返回一个新的对象. Proxy不仅 ...
- Alink漫谈(十) :线性回归实现 之 数据预处理
Alink漫谈(十) :线性回归实现 之 数据预处理 目录 Alink漫谈(十) :线性回归实现 之 数据预处理 0x00 摘要 0x01 概念 1.1 线性回归 1.2 优化模型 1.3 损失函数& ...
- Sql Or NoSql,看完这一篇你就懂了(转五月的仓颉)
前言 你是否在为系统的数据库来一波大流量就几乎打满CPU,日常CPU居高不下烦恼?你是否在各种NoSql间纠结不定,到底该选用那种最好?今天的你就是昨天的我,这也是写这篇文章的初衷. 这篇文章是我好几 ...
- 前端04 /css样式
前端04 /css样式 目录 前端04 /css样式 昨日内容回顾 css引入 选择器 基础选择器 组合选择器 属性选择器 伪类选择器 伪元素选择器 优先级(权重) 通用选择器 css样式 1高度宽度 ...
- MySql-Binlog协议
MySQL主备复制原理 MySQL master 将数据变更写入二进制日志( binary log, 其中记录叫做二进制日志事件binary log events,可以通过 show binlog e ...
- Ethical Hacking - NETWORK PENETRATION TESTING(15)
ARP Poisoning - arpspoof Arpspoof is a tool part of a suit called dsniff, which contains a number of ...
- C#中的类与对象
类:说白了就是类型,是对具体事物的一种抽象总结. 对象:一个具体的事物. 类与对象的关系,类实例化就会得到一个对象,同样一个对象也应该属于某一个类.例如张三这个人,他是一个对象,同时他属于人类,在程序 ...
- vue :没有全局变量的计数器
created: created () { let num = null this.mFun(num) }, methods: methods:{ mFun(m){ if (m === null) { ...