论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果



来源:晓飞的算法工程笔记 公众号

论文: Accelerating CNN Training by Pruning

Activation Gradients

Introduction


  在训练过程中,特征值梯度的回传和权值梯度的计算占了大部分的计算消耗。由于这两个操作都是以特征值梯度作为输入,而且零梯度不会占用计算资源,所以稀疏化特征值梯度可以降低回传阶段的计算消耗以及内存消耗。论文的目标在于高效地降低训练负载,从而在资源有限的平台进行大规模数据集的训练。

  论文假设特征值梯度服从正态分布,基于此计算阈值$\tau$,随后使用随机剪枝算法(stochastic pruning)将小于阈值的特征值梯度随机置为零或$\pm \tau$。经理论推理和实验证明,这种方法不仅能够有效地稀疏化特征值梯度,还能在加速训练的同时,不影响训练的收敛性。

General Dataflow


  卷积层通常包含4个阶段:推理、特征值梯度回传、权值梯度计算和权值更新。为了表示这些阶段的计算,论文定义了一些符号:

  卷积层的四个训练阶段的总结为:

  论文通过可视化发现,回传阶段的特征值梯度几乎全是非常小的、接近于零的值,自然而然地想到将这些值去掉不会对权值更新阶段造成很大的影响,所以论文认为剪枝特征值梯度能够加速卷积层在训练时的计算。

Sparsification Algorithms


Distribution Based Threshold Determination (DBTD)

  剪枝操作最关键的步骤是决定选择哪些元素进行消除,先前有研究使用最小堆进行元素选择,但这会带来较大的额外计算开销。为此,论文采用简单的阈值过滤进行元素选择。

  论文首先分析了两种经典的卷积网络结构的特征值梯度分布:Conv-ReLU结构和Conv-BN-ReLU结构:

  • 对于Conv-ReLU结构,输出的特征值梯度$dO$是稀疏的,但其分布是无规律的,而结构的输入特征值梯度$dI$几乎全是非零值。通过统计发现,$dI(\cdot)$的分布以零值对称分布,且密度随着梯度值的增加而下降。
  • 对于Conv-BN-ReLU结构,BN层设置在卷积层与ReLU层中间,改变了梯度的分布,且$dO$的分布与$dI$类似,。

  所以,上述的两种结构的梯度都可认为服从零均值、方差为$\sigma^2$的正态分布。对于Conv-ReLu结构,由于ReLU不会降低稀疏性,$dO$能够继承$dI$的稀疏性,将$dI$是作为Conv-ReLU结构中的剪枝目标梯度$g$。而对于Conv-BN-ReLU结构,则将$dO$作为剪枝目标$g$。这样,两种结构的剪枝目标都可统一为正态分布。假设$g$的数量为$n$,可以计算梯度的绝对值的均值,并得到该均值的期望为:

  这里的期望为从分布中采样$n$个点的期望,而非分布的整体期望,再定义以下公式

  将公式2代入公式1中,可以得到:

  从公式3可以看出$\tilde{\sigma}$为参数$\sigma$的无偏估计,接近于真实的均值,且$\tilde{\sigma}$的整体计算消耗是可以接受的。基于上面的分析,论文结合正态分布的累积函数$\Phi$、剪枝率$p$和$\tilde{\sigma}$计算阈值$\tau$:

Stochastic Pruning

  剪枝少量值较小的梯度几乎对权值的更新没有影响,但如果将这些值较小的梯度全部设为零,则会对特征值梯度的分布影响很大,进而影响梯度更新,造成严重的精度损失。参考Stochastic Rounding算法,论文采用随机剪枝来解决这个问题。

  随机剪枝逻辑如算法1所示,对于小于阈值$\tau$的梯度值,随机采样一个缩放权重来计算新阈值,再根据新阈值将梯度值置为零或$\pm \tau$。

  随机剪枝的效果如图2所示,能够在保持梯度分布的数学期望的情况下进行剪枝,与当前的方法相比,论文提出的方法的优点如下:

  • Lower runtime cost:DBTD的计算复杂度$O(n)$小于top-k算法$O(nlogk)$,且DBTD对硬件更友好,能够在异构平台实现。
  • Lower memory footprint:随机裁剪能保持收敛性,且不需要存储而外的内存。

  至此,Sparsification Algorithms在梯度回传时的特征值梯度计算为:

Experimental Results


  在CIFAR-10、CIFAR-100以及ImageNet上进行准确率验证。

  在CIFAR-10和ImageNet上进行收敛性验证。

  在不同的设备上进行加速效果验证。

Conclustion


  论文通过DBTD方法计算过滤阈值,再结合随机剪枝算法对特征值梯度进行裁剪,稀疏化特征值梯度,能够降低回传阶段的计算量,在CPU和ARM上的训练分别有3.99倍和5.92倍的加速效果。论文提出的特征值稀疏化算法看似很简单,其实进行了充分的理论推导以及实验验证,才得到最终合理的过滤方法,唯一可惜的是没在GPU设备上进行实验验证。论文对算法的收敛性以及期望有详细的理论验证,不过这里没有列出来,有兴趣的可以去看看原文。





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

简单的特征值梯度剪枝,CPU和ARM上带来4-5倍的训练加速 | ECCV 2020的更多相关文章

  1. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

  2. MPlayer在ARM上的移植(S5PV210开发板)

    MPlayer 1.0已经把大部分解码库都自带了,如ffmpeg,但是自带的音频库在S5PV210下效果非常不好.换成使用libmad效果不错.因此MPlayer 在ARM-Linux的最简单的移植只 ...

  3. ARM上的linux如何实现无线网卡的冷插拔和热插拔

    ARM上的linux如何实现无线网卡的冷插拔和热插拔 fulinux 凌云实验室 1. 冷插拔 如果在系统上电之前就将RT2070/RT3070芯片的无线网卡(以下简称wlan)插上,即冷插拔.我们通 ...

  4. arm上的参数列表传递的分析(以android为例)

    1. Linux中可变列表实现的源码分析 查看Linux0.11的内核源代码,对va_list, va_start, va_arg 的实现如下: va_list的实现没有差别,chartypedef ...

  5. OpenCV在ARM上的移植

    OpenCV在ARM上的移植 与X86 Linux类似,请参考:Linux 下编译安装OpenCV 本文在此基础上进行进一步操作. 网络上很多移植编译的方法比较老,多数针对OpenCV 1.0,而且方 ...

  6. 【Qt开发】【ARM-Linux开发】 QT在ARM上显示字体的问题

    在PC机上利用QT开发的应用程序在设置字体时,在PC上运行,可根据自己的设置,字体随之变大或变小.而移植到ARM上运行时发现,显示字体与所设置的字体不用,字体普遍偏小.经过上网搜索发现,是环境变量字库 ...

  7. 我写的界面,在ARM上跑

    这个...其实,我对ARM了解并不多,我顶多也就算是知道ARM怎么玩,EMMC干啥,MMU干啥,还有早期的叫法,比如那个NorFlash NandFlash ,然后也就没啥了. 然后写个裸机什么的,那 ...

  8. Ubuntu在ARM上建立NFS服务

    先引用别人的做法: 1.进行NFS服务器端与客户端的安装: sudo apt-get install nfs-kernel-server nfs-common portmap 安装客户端的作用是可以在 ...

  9. 用于ARM上的FFT与IFFT源代码(C语言,不依赖特定平台)(转)

    源:用于ARM上的FFT与IFFT源代码(C语言,不依赖特定平台) 代码在2011年全国电子大赛结束后(2011年9月3日)发布,多个版本,注释详细. /*********************** ...

随机推荐

  1. 1. JDK基础说明

    1. JDK基础说明 版本及新特性获取 作为技术人,关注新技术必不可少,那么最佳的途径...看下面. 在 Oracle Java 官方站点有这个非常好的引导地图 官方站点 https://docs.o ...

  2. 37 Reasons why your Neural Network is not working

    37 Reasons why your Neural Network is not working Neural Network Check List 如何使用这个指南 数据问题 检查输入数据 试一下 ...

  3. Kafka2.6.0发布——性能大幅提升

    近日Kafka2.6版本发布,距离2.5.0发布只过去了不到四个月的时间. Kafka 2.6.0包含许多重要的新功能.以下是一些重要更改的摘要: 默认情况下,已为Java 11或更高版本启用TLSv ...

  4. MSF常用命令备忘录

    msf下的命令 set session x:设置要攻击的session #监听端口反弹PHP shell use exploit/multi/handler set payload php/meter ...

  5. Mac中的垃圾文件的清理

    一 前言 最近发现mac的存储空间不够了,看一下系统的存储空间如下图所示,这个其他占了160+G的存储空间,那么这个其他到底包含什么东西呢?在网上查了很久,找到一种比较认可的说法是这样的: 不同Mac ...

  6. 7.hbase shell命令 cmd

    $HADOOP_USER_NAME #创建命名空间create_namespace 'bd1902' #展示所有命名空间 list_namespace #删除命名空间,The namespace mu ...

  7. linux下top命令详细介绍

    linux下top命令详细介绍 top 命令是 Linux 下常用的系统资源占用查看及性能分析工具,能够实时显示系统中各个进程的资源(比如cpu.内存的使用)占用状况,top命令的执行结果是一个动态显 ...

  8. 简单的股票信息查询系统 1 程序启动后,给用户提供查询接口,允许用户重复查股票行情信息(用到循环) 2 允许用户通过模糊查询股票名,比如输入“啤酒”, 就把所有股票名称中包含“啤酒”的信息打印出来 3 允许按股票价格、涨跌幅、换手率这几列来筛选信息, 比如输入“价格>50”则把价格大于50的股票都打印,输入“市盈率<50“,则把市盈率小于50的股票都打印,不用判断等于。

    '''需求:1 程序启动后,给用户提供查询接口,允许用户重复查股票行情信息(用到循环)2 允许用户通过模糊查询股票名,比如输入“啤酒”, 就把所有股票名称中包含“啤酒”的信息打印出来3 允许按股票价格 ...

  9. python数据处理工具 -- pandas(序列与数据框的构造)

    Pandas模块的核心操作对象就是对序列(Series)和数据框(Dataframe).序列可以理解为数据集中的一个字段,数据框是值包含至少两个字段(或序列) 的数据集. 构造序列 1.通过同质的列表 ...

  10. Windows中使用PowerShell查看和卸载补丁

    查看:get-hotfix -id KB4470788 卸载:wusa /uninstall /kb:3045999 get-hotfix -id KB4470788 wusa /uninstall ...