简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020
论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果
来源:晓飞的算法工程笔记 公众号
论文: Accelerating CNN Training by Pruning
Activation Gradients
Introduction
在训练过程中,特征值梯度的回传和权值梯度的计算占了大部分的计算消耗。由于这两个操作都是以特征值梯度作为输入,而且零梯度不会占用计算资源,所以稀疏化特征值梯度可以降低回传阶段的计算消耗以及内存消耗。论文的目标在于高效地降低训练负载,从而在资源有限的平台进行大规模数据集的训练。
论文假设特征值梯度服从正态分布,基于此计算阈值$\tau$,随后使用随机剪枝算法(stochastic pruning)将小于阈值的特征值梯度随机置为零或$\pm \tau$。经理论推理和实验证明,这种方法不仅能够有效地稀疏化特征值梯度,还能在加速训练的同时,不影响训练的收敛性。
General Dataflow
卷积层通常包含4个阶段:推理、特征值梯度回传、权值梯度计算和权值更新。为了表示这些阶段的计算,论文定义了一些符号:
卷积层的四个训练阶段的总结为:
论文通过可视化发现,回传阶段的特征值梯度几乎全是非常小的、接近于零的值,自然而然地想到将这些值去掉不会对权值更新阶段造成很大的影响,所以论文认为剪枝特征值梯度能够加速卷积层在训练时的计算。
Sparsification Algorithms
Distribution Based Threshold Determination (DBTD)
剪枝操作最关键的步骤是决定选择哪些元素进行消除,先前有研究使用最小堆进行元素选择,但这会带来较大的额外计算开销。为此,论文采用简单的阈值过滤进行元素选择。
论文首先分析了两种经典的卷积网络结构的特征值梯度分布:Conv-ReLU结构和Conv-BN-ReLU结构:
- 对于Conv-ReLU结构,输出的特征值梯度$dO$是稀疏的,但其分布是无规律的,而结构的输入特征值梯度$dI$几乎全是非零值。通过统计发现,$dI(\cdot)$的分布以零值对称分布,且密度随着梯度值的增加而下降。
- 对于Conv-BN-ReLU结构,BN层设置在卷积层与ReLU层中间,改变了梯度的分布,且$dO$的分布与$dI$类似,。
所以,上述的两种结构的梯度都可认为服从零均值、方差为$\sigma^2$的正态分布。对于Conv-ReLu结构,由于ReLU不会降低稀疏性,$dO$能够继承$dI$的稀疏性,将$dI$是作为Conv-ReLU结构中的剪枝目标梯度$g$。而对于Conv-BN-ReLU结构,则将$dO$作为剪枝目标$g$。这样,两种结构的剪枝目标都可统一为正态分布。假设$g$的数量为$n$,可以计算梯度的绝对值的均值,并得到该均值的期望为:
这里的期望为从分布中采样$n$个点的期望,而非分布的整体期望,再定义以下公式
将公式2代入公式1中,可以得到:
从公式3可以看出$\tilde{\sigma}$为参数$\sigma$的无偏估计,接近于真实的均值,且$\tilde{\sigma}$的整体计算消耗是可以接受的。基于上面的分析,论文结合正态分布的累积函数$\Phi$、剪枝率$p$和$\tilde{\sigma}$计算阈值$\tau$:
Stochastic Pruning
剪枝少量值较小的梯度几乎对权值的更新没有影响,但如果将这些值较小的梯度全部设为零,则会对特征值梯度的分布影响很大,进而影响梯度更新,造成严重的精度损失。参考Stochastic Rounding算法,论文采用随机剪枝来解决这个问题。
随机剪枝逻辑如算法1所示,对于小于阈值$\tau$的梯度值,随机采样一个缩放权重来计算新阈值,再根据新阈值将梯度值置为零或$\pm \tau$。
随机剪枝的效果如图2所示,能够在保持梯度分布的数学期望的情况下进行剪枝,与当前的方法相比,论文提出的方法的优点如下:
- Lower runtime cost:DBTD的计算复杂度$O(n)$小于top-k算法$O(nlogk)$,且DBTD对硬件更友好,能够在异构平台实现。
- Lower memory footprint:随机裁剪能保持收敛性,且不需要存储而外的内存。
至此,Sparsification Algorithms在梯度回传时的特征值梯度计算为:
Experimental Results
在CIFAR-10、CIFAR-100以及ImageNet上进行准确率验证。
在CIFAR-10和ImageNet上进行收敛性验证。
在不同的设备上进行加速效果验证。
Conclustion
论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果。论文提出的特征值稀疏化算法看似很简单,其实进行了充分的理论推导以及实验验证,才得到最终合理的过滤方法,唯一可惜的是没在GPU设备上进行实验验证。论文对算法的收敛性以及期望有详细的理论验证,不过这里没有列出来,有兴趣的可以去看看原文。
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】
简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020的更多相关文章
- 简单线性回归(梯度下降法) python实现
grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...
- MPlayer在ARM上的移植(S5PV210开发板)
MPlayer 1.0已经把大部分解码库都自带了,如ffmpeg,但是自带的音频库在S5PV210下效果非常不好.换成使用libmad效果不错.因此MPlayer 在ARM-Linux的最简单的移植只 ...
- ARM上的linux如何实现无线网卡的冷插拔和热插拔
ARM上的linux如何实现无线网卡的冷插拔和热插拔 fulinux 凌云实验室 1. 冷插拔 如果在系统上电之前就将RT2070/RT3070芯片的无线网卡(以下简称wlan)插上,即冷插拔.我们通 ...
- arm上的参数列表传递的分析(以android为例)
1. Linux中可变列表实现的源码分析 查看Linux0.11的内核源代码,对va_list, va_start, va_arg 的实现如下: va_list的实现没有差别,chartypedef ...
- OpenCV在ARM上的移植
OpenCV在ARM上的移植 与X86 Linux类似,请参考:Linux 下编译安装OpenCV 本文在此基础上进行进一步操作. 网络上很多移植编译的方法比较老,多数针对OpenCV 1.0,而且方 ...
- 【Qt开发】【ARM-Linux开发】 QT在ARM上显示字体的问题
在PC机上利用QT开发的应用程序在设置字体时,在PC上运行,可根据自己的设置,字体随之变大或变小.而移植到ARM上运行时发现,显示字体与所设置的字体不用,字体普遍偏小.经过上网搜索发现,是环境变量字库 ...
- 我写的界面,在ARM上跑
这个...其实,我对ARM了解并不多,我顶多也就算是知道ARM怎么玩,EMMC干啥,MMU干啥,还有早期的叫法,比如那个NorFlash NandFlash ,然后也就没啥了. 然后写个裸机什么的,那 ...
- Ubuntu在ARM上建立NFS服务
先引用别人的做法: 1.进行NFS服务器端与客户端的安装: sudo apt-get install nfs-kernel-server nfs-common portmap 安装客户端的作用是可以在 ...
- 用于ARM上的FFT与IFFT源代码(C语言,不依赖特定平台)(转)
源:用于ARM上的FFT与IFFT源代码(C语言,不依赖特定平台) 代码在2011年全国电子大赛结束后(2011年9月3日)发布,多个版本,注释详细. /*********************** ...
随机推荐
- 1. JDK基础说明
1. JDK基础说明 版本及新特性获取 作为技术人,关注新技术必不可少,那么最佳的途径...看下面. 在 Oracle Java 官方站点有这个非常好的引导地图 官方站点 https://docs.o ...
- 37 Reasons why your Neural Network is not working
37 Reasons why your Neural Network is not working Neural Network Check List 如何使用这个指南 数据问题 检查输入数据 试一下 ...
- Kafka2.6.0发布——性能大幅提升
近日Kafka2.6版本发布,距离2.5.0发布只过去了不到四个月的时间. Kafka 2.6.0包含许多重要的新功能.以下是一些重要更改的摘要: 默认情况下,已为Java 11或更高版本启用TLSv ...
- MSF常用命令备忘录
msf下的命令 set session x:设置要攻击的session #监听端口反弹PHP shell use exploit/multi/handler set payload php/meter ...
- Mac中的垃圾文件的清理
一 前言 最近发现mac的存储空间不够了,看一下系统的存储空间如下图所示,这个其他占了160+G的存储空间,那么这个其他到底包含什么东西呢?在网上查了很久,找到一种比较认可的说法是这样的: 不同Mac ...
- 7.hbase shell命令 cmd
$HADOOP_USER_NAME #创建命名空间create_namespace 'bd1902' #展示所有命名空间 list_namespace #删除命名空间,The namespace mu ...
- linux下top命令详细介绍
linux下top命令详细介绍 top 命令是 Linux 下常用的系统资源占用查看及性能分析工具,能够实时显示系统中各个进程的资源(比如cpu.内存的使用)占用状况,top命令的执行结果是一个动态显 ...
- 简单的股票信息查询系统 1 程序启动后,给用户提供查询接口,允许用户重复查股票行情信息(用到循环) 2 允许用户通过模糊查询股票名,比如输入“啤酒”, 就把所有股票名称中包含“啤酒”的信息打印出来 3 允许按股票价格、涨跌幅、换手率这几列来筛选信息, 比如输入“价格>50”则把价格大于50的股票都打印,输入“市盈率<50“,则把市盈率小于50的股票都打印,不用判断等于。
'''需求:1 程序启动后,给用户提供查询接口,允许用户重复查股票行情信息(用到循环)2 允许用户通过模糊查询股票名,比如输入“啤酒”, 就把所有股票名称中包含“啤酒”的信息打印出来3 允许按股票价格 ...
- python数据处理工具 -- pandas(序列与数据框的构造)
Pandas模块的核心操作对象就是对序列(Series)和数据框(Dataframe).序列可以理解为数据集中的一个字段,数据框是值包含至少两个字段(或序列) 的数据集. 构造序列 1.通过同质的列表 ...
- Windows中使用PowerShell查看和卸载补丁
查看:get-hotfix -id KB4470788 卸载:wusa /uninstall /kb:3045999 get-hotfix -id KB4470788 wusa /uninstall ...