梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理



来源:晓飞的算法工程笔记 公众号

论文: Gradient Centralization: A New Optimization Technique for Deep Neural Networks

Introduction


  优化器(Optimizer)对于深度神经网络在大型数据集上的训练是十分重要的,如SGD和SGDM,优化器的目标有两个:加速训练过程和提高模型的泛化能力。目前,很多工作研究如何提高如SGD等优化器的性能,如克服训练中的梯度消失和梯度爆炸问题,有效的trick有权值初始化、激活函数、梯度裁剪以及自适应学习率等。而一些工作则从统计的角度对权值和特征值进行标准化来让训练更稳定,比如特征图标准化方法BN以及权值标准化方法WN。。

  与在权值和特征值进行标准化方法不同,论文提出作用于权值梯度的高性能网络优化算法梯度中心化(GC, gradient centralization),能够加速网络训练,提高泛化能力以及兼容模型fine-tune。如图a所示,GC的思想很简单,零均值化梯度向量,能够轻松地嵌入各种优化器中。论文主要贡献如下:

  • 提出新的通用网络优化方法,梯度中心化(GC),不仅能平滑和加速训练过程,还能提高模型的泛化能力。
  • 分析了GC的理论属性,表明GC能够约束损失函数,标准化权值空间和特征值空间,提升模型的泛化能力。另外,约束的损失函数有更好的Lipschitzness(抗扰动能力,函数斜率恒定小于一个Lipschitze常数),让训练更稳定、更高效。

Gradient Centralization


Motivation

  BN和WS使用Z-score标准化分别操作于特征值和权重,实际是间接地对权值的梯度进行约束,从而提高优化时损失函数的Lipschitz属性。受此启发,论文直接对梯度操作,首先尝试了Z-score标准化,但实验发现并没有提升训练的稳定性。之后,尝试计算梯度向量的均值,对梯度向量进行零均值化,实验发现能够有效地提高损失函数的Lipschitz属性,使网络训练更稳定、更具泛化能力,得到梯度中心化(GC)算法。

Notations

  定义一些基础符号,使用$W \in \mathbb{R}^{M \times N}$统一表示全连接层的权值矩阵$W_{fc} \in \mathbb{R}^{C_{in}\times C_{out}}$和卷积层的权值张量$W_{conv} \in \mathbb{R}^{(C_{in} k_1 k_2)\times C_{out}}$,$w_i \in \mathbb{R}M$为权值矩阵$W$的第$i$列,$\mathcal{L}$为目标函数,$\nabla_{W}\mathcal{L}$和$\nabla_{w_i}\mathcal{L}$为$\mathcal{L}$对$W$和$w_i$的梯度,$W$与$\nabla_{W}\mathcal{L}$的大小一样。定义$X$为输入特征图,则$WT X$为输出特征图,$e=\frac{1}{\sqrt{M}}1$为$M$位单位向量(unit vector),$I\in\mathbb{R}^{M\times M}$为单位矩阵(identity matrix)。

Formulation of GC

  对于卷积层或全连接层的权值向量$w_i$,通过反向传播得到其梯度$\nabla_{w_i}\mathcal{L}$,然后如图b所示计算其均值$\mu\nabla_{w_i}\mathcal{L}=\frac{1}{M}{\sum}^M_{j=1} \nabla_{w_{i,j}\mathcal{L}}$,GC操作$\Phi$定义如下:

  也可以将公式1转换为矩阵形式:

  $P$由单位矩阵以及单位向量形成矩阵构成,分别负责保留原值以及求均值。

Embedding of GC to SGDM/Adam

  GC能够简单地嵌入当前的主流网络优化算法中,如SGDM和Adam,直接使用零均值化的梯度$\Phi_{GC}(\nabla_w \mathcal{L})$进行权值的更新。

  算法1和算法2分别展示了将GC嵌入到SGDM和Adam中,基本上不需要对原优化器算法进行修改,仅需加入一行梯度零均值化计算即可,大约仅需0.6sec。

Properties of GC


  下面从理论的角度分析GC为何能提高模型的泛化能力以及加速训练。

Improving Generalization Performance

  GC有一个很重要的优点是提高模型的泛化能力,主要得益于权值空间正则化和特征值空间正则化。

  • Weight space regularization

  首先介绍$P$的物理意义,经过推算可以得到:

  即$P$可以看作映射矩阵,将$\nabla_W \mathcal{L}$映射到空间向量中法向量为$e$的超平面,$P\nabla_W \mathcal{L}$为映射梯度。

  以SGD优化为例,权值梯度的映射能够将权值空间约束在一个超平面或黎曼流形(Riemannian manifold)中,如图2所示,梯度首先映射到$eT(w-wt)=0$的超平面中,然后跟随映射梯度$-P\nabla_{wt}\mathcal{L}$的方向进行更新。从$eT(w-wt)=0$可以得到$eTw{t+1}=eTwt=\cdots=eTw^0$,目标函数实际变为:

  这是一个权值空间$w$的约束优化问题,正则化$w$的解空间,降低了过拟合的可能性(过拟合通常是学习了复杂的权值来适应训练数据),能够提升网络的泛化能力,特别是当训练样本较少的情况下。

  WS对权值进行$eTw=0$的约束,当初始权值不满足约束时,会直接修改权值来满足约束条件。假设进行fine-tune训练,WS则会完全丢弃预训练模型的优势,而GC可以适应任何初始权值$eT(w0-w0)=0$。

  • Output feature space regularization

  以SGD优化方法为例,权值更新$w{t+1}=wt-\alphatP\nabla_{w_t}\mathcal{L}$,可以推导得到$wt=w0-P{\sum}{t-1}_{i=0}\alpha{(i)}\nabla_{w{(i)}}\mathcal{L}$。对于任何输入特征向量$x$,有以下定理:

  相关证明可以看原文附录,定理4.1表明输入特征的常量变化会造成输出的变化,而输出的变化量仅与标量$\gamma$和$1Tw0$相关,与当前权值$wt$无关。$\gamma1Tw0$为初始化权值向量缩放后的均值,假设$w0$接近0,则输入特征值的常量变化将几乎不会改变输出特征值,意味着输出特征空间对训练样本的变化更鲁棒。

  对ResNet50的不同初始权值进行可视化,可以看到权值都非常小(小于$e^{-7}$),这说明如果使用GC来训练,输出特征不会对输入特征的变化过于敏感。这个属性正则化输出特征空间,并且提升网络训练的泛化能力。

Accelerating Training Process

  • Optimization landscape smoothing

  前面提到BN和WS都间接地对权值梯度进行约束,使损失函数满足Lipschitz属性,$||\nabla_w\mathcal{L}||_2$和$||\nabla^2_w\mathcal{L}||_2$($w$的Hessian矩阵)都有上界。GC直接对梯度进行约束,也有类似于BN和WS的属性,对比原损失函数满足以下定理:

  相关证明可以看原文附录,定理4.2表明GC比原函数有更好的Lipschitzness,更好的Lipschitzness意味着梯度更加稳定,优化过程也更加平滑,能够类似于BN和WS那样加速训练过程。

  • Gradient explosion suppression

  GC的另一个优点是防止梯度爆炸,使得训练更加稳定,作用原理类似于梯度裁剪。过大的梯度会导致损失严重震荡,难以收敛,而梯度裁剪能够抑制大梯度,使得训练更稳定、更快。

  对梯度的$L_2$ norm和最大值进行了可视化,可以看到使用GC后的值均比原函数要小,这也与定理4.2一致,GC能够让训练过程更平滑、更快。

Experiment


  与BN和WS结合的性能对比。

  Mini-ImageNet上的对比实验。

  CIFAR100上的对比实验。

  ImageNet上的对比实验。

  细粒度数据集上的性能对比。

  检测与分割任务上的性能对比。

Conclustion


  梯度中心化GC对权值梯度进行零均值化,能够使得网络的训练更加稳定,并且能提高网络的泛化能力,算法思路简单,论文的理论分析十分充分,能够很好地解释GC的作用原理。





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

Gradient Centralization: 简单的梯度中心化,一行代码加速训练并提升泛化能力 | ECCV 2020 Oral的更多相关文章

  1. 简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020

    论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果   来源:晓飞的算 ...

  2. 一行代码调用实现带字段选取+条件判断+排序+分页功能的增强ORM框架

    问题:3行代码 PDF.NET是一个开源的数据开发框架,它的特点是简单.轻量.快速,易上手,而且是一个注释完善的国产开发框架,受到不少朋友的欢迎,也在我们公司的项目中多次使用.但是,PDF.NET比起 ...

  3. 学习笔记57—归一化 (Normalization)、标准化 (Standardization)和中心化/零均值化 (Zero-centered)

    1 概念   归一化:1)把数据变成(0,1)或者(1,1)之间的小数.主要是为了数据处理方便提出来的,把数据映射到0-1范围之内处理,更加便捷快速.2)把有量纲表达式变成无量纲表达式,便于不同单位或 ...

  4. 一个轻client,多语言支持,去中心化,自己主动负载,可扩展的实时数据写服务的实现方案讨论

    背景 背景是设计一个实时数据接入的模块,负责接收client的实时数据写入(如日志流,点击流),数据支持直接下沉到HBase上(兴许提供HBase上的查询),或先持久化到Kafka里.方便兴许进行一些 ...

  5. [数据预处理]-中心化 缩放 KNN(一)

    据预处理是总称,涵盖了数据分析师使用它将数据转处理成想要的数据的一系列操作.例如,对某个网站进行分析的时候,可能会去掉 html 标签,空格,缩进以及提取相关关键字.分析空间数据的时候,一般会把带单位 ...

  6. 在dotnet core下去中心化访问HTTP服务集群

    一般应用服务都会部署到多台服务器之上,一.可以通过硬件得到更多的并发处理能力:二.可以避免单点太故障的出现,从而确保服务7X24有效运作.当访问这些HTTP服务的情况一般都是经过反向代理服务进行统一处 ...

  7. 理解去中心化 稳定币 DAI

    本文转载于深入浅出区块链, 原文链接 随着摩根大通推出JPM Coin 稳定币,可以预见稳定币将成为区块链落地的一大助推器. 坦白来讲,对于一个程序员的我来讲(不懂一点专业经济和金融),理解DAI的机 ...

  8. 去中心化存储项目终极指南 | Filecoin, Storj 和 PPIO 项目异同

    Filecoin,Storj 以及 PPIO 这三个存储公链的设计思路是不一样的,没有优劣之分,写这篇文章也并不是为了争论各项目的好坏对错.去中心化存储是一个长期商业赛道,不同团队在同一个赛道上往不同 ...

  9. 为什么比特币和以太坊未必真得比EOS更去中心化?

    在区块链行业里,有两派人一直在争论:一个是以比特币和以太坊为首的社群,另一个是以EOS为首的社群.这两群人一直在争论谁才是真正的未来,双方都认为自己这边更有未来.其中EOS抗争的重点就是100万TPS ...

随机推荐

  1. URL编码转换函数:escape()、encodeURI()、encodeURIComponent()讲解

    转自:https://www.cnblogs.com/douJiangYouTiao888/p/6473874.html 函数出现时间:         escape()                ...

  2. Hyperledger Fabric 2.1 搭建教程

    Hyperledger Fabric 2.1 搭建教程 环境准备 版本 Ubuntu 18.04 go 1.14.4 fabric 2.1 fabric-sample v1.4.4 nodejs 12 ...

  3. Linux 相关学习内容(不定期更新)

    Linux 主要目录 / 根目录,在 linux 下有且只有一个根目录,所有的东西都是从这里开始 /bin 可执行二进制文件的目录,如常用的命令,ls, tar, mv, cat.. /boot 放置 ...

  4. response对象乱码--解决

    中文乱码 响应对象中文乱码,即就是response对象乱码. response对象输出中文数据乱码解决方案: 1 字节流输出响应乱码. 该情况不一定乱码.但是解决乱码的步骤是: 1) 设置浏览器打开文 ...

  5. 赋值,逻辑,运算符, 控制流程之if 判断

    赋值运算 (1). 增量运算 age += 1 # age = age + 1 print(age) age -= 10 # age = age - 10 (2).交叉赋值 x = 111 y = 2 ...

  6. day16 三层装饰器和迭代器

    一. 经典的两层装饰器,也是标准装饰器 案例 import time def outter1(func): def wrapper(*args, **kwargs): start = time.tim ...

  7. redis(十六):Redis 安装,部署(LINUX环境下)

    第一步:下载安装包 访问https://redis.io/download  到官网进行下载.这里下载最新的4.0版本. 第二步:安装 1.通过远程管理工具,将压缩包拷贝到Linux服务器中,执行解压 ...

  8. 数据可视化实例(十五):有序条形图(matplotlib,pandas)

    偏差 (Deviation) 有序条形图 (Ordered Bar Chart) 有序条形图有效地传达了项目的排名顺序. 但是,在图表上方添加度量标准的值,用户可以从图表本身获取精确信息. https ...

  9. 浏览器常见攻击方式(XSS和CSRF)

    常见的浏览器攻击分为两种,一种为XSS(跨站脚本攻击),另一种则为CSRF(跨站请求伪造). XSS(跨站脚本攻击) 定义 XSS 全称是 Cross Site Scripting,为了与“CSS”区 ...

  10. Android 高德地图 java.lang.UnsatisfiedlinkError Native method not found: com.autonavi.amap.mapcore.MapCore.nativeNewInstance:(Ljava/lang/String;)

    在Android项目中引用高德地图,程序运行时出现上述问题,如果引用了Map3D的jar包,则需要在引入Jar文件的同时引入so文件,在高德地图的demo中,找到so文件: 然后将其复制到jniLib ...