论文地址

channel pruning是指给定一个CNN模型,去掉卷积层的某几个输入channel以及相应的卷积核,
并最小化裁剪channel后与原始输出的误差。

可以分两步来解决:

  1. channel selection
    利用LASSO回归裁剪掉多余的channel,求出每个channel的权重,如果为0即是被裁减。
  2. feature map reconstruction
    利用剩下的channel重建输出,直接使用最小平方误差来拟合原始卷积层的输出,求出新的卷积核W。

二、优化目标

2.1 定义优化目标

输入c个channel,输出n个channel,卷积核W的大小是

我们对输入做了采样,假设对每个输入,对channel采样出来N个大小为块

为了把输入channel从c个剪裁到c’个,我们定义最优化的目标为

其中是每个channel的权重向量,如果是0,意味着裁剪当前channel,相应的也被裁减。

2.2 求最优目标

为了最优化目标,分为如下两步

2.2.1 固定W,求

其中,大小是,

这里之所以加上关于的L1正则项,是为了避免所有的都为1,而是让它们趋于0。

2.2.2 固定,求W

利用剩下的channel重建输出,直接求最小平方误差

其中,大小为,
W’也被reshape为。

2.2.3 多分支的情况

论文只考虑了常见的残差网络,设residual分支的输出为,shortcut 分支的输出为。

这里首先在residual分支的第一层前做了channel采样,从而减少计算量(训练过程中做的,即filter layer)。

设为原始的上一层的输出,
那么channel pruning中,residual分支的输出拟合,其中是上一层裁减后的shortcut。

三、实现

实现的时候,不是按照不断迭代第一步和第二步,因为比较耗时。
而是先不断的迭代第一步,直到裁剪剩下的channel个数为c’,然后执行第二步求出最终的W。

3.1 第一步Channel Selection

如何得到LASSO回归的输入:

(1)首先把输入做转置

# (N, c, hw) --> (c, N, hw)
inputs = np.transpose(inputs, [1, 0, 2])

(2)把weigh做转置

# (n, c, hw) --> (c, hw, n)
weights = np.transpose(weights, [1, 2, 0]))

(3)最后两维做矩阵乘法

# (c, N, n), matmul apply dot on the last two dim
outputs = np.matmul(inputs, weights)

(4)把输出做reshape和转置

# (Nn, c)
outputs = np.transpose(outputs.reshape(outputs.shape[0], -1))

LASSO回归的目标值即是对应的Y,大小为

的大小影响了最终为0的个数,为了找出合适的,需要尝试不同的值,直到裁剪剩下的channel个数为为止。

为了找到合适的可以使用二分查找,
或者不断增大直到裁剪剩下的channel个数,然后降序排序取前,剩下的为0。

while True:
coef = solve(alpha)
if sum(coef != 0) < rank:
break
last_alpha = alpha
last_coef = coef
alpha = 4 * alpha + math.log(coef.shape[0])
if not fast:
# binary search until compression ratio is satisfied
left = last_alpha
right = alpha
while True:
alpha = (left + right) / 2
coef = solve(alpha)
if sum(coef != 0) < rank:
right = alpha
elif sum(coef != 0) > rank:
left = alpha
else:
break
else:
last_coef = np.abs(last_coef)
sorted_coef = sorted(last_coef, reverse=True)
rank_max = sorted_coef[rank - 1]
coef = np.array([c if c >= rank_max else 0 for c in last_coef])

3.2 第二步Feature Map Reconstruction

直接利用最小平方误差,求出最终的卷积核。

from sklearn import linear_model
def LinearRegression(input, output):
clf = linear_model.LinearRegression()
clf.fit(input, output)
return clf.coef_, clf.intercept_
pruned_weights, pruned_bias = LinearRegression(input=inputs, output=real_outputs)

3.3 一些细节

  1. 将Relu层和卷积层分离
    因为Relu一般会使用inplace操作来节省内存/显存,如果不分离开,那么得到的卷积层的输出是经过了Relu激活函数计算后的结果。

  2. 每次裁减完一个卷积层后,需要对该层的bottom和top层的输入或输出大小作相应的改变。

  3. 第一步求出后,若为0,则说明要裁减对应的channel,否则置为1,表示保留channel。

参考链接

https://github.com/yihui-he/channel-pruning

模型压缩之Channel Pruning的更多相关文章

  1. 【转载】NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩、机器学习及最优化算法

    原文:NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩.机器学习及最优化算法 导读 AI领域顶会NeurIPS正在加拿大蒙特利尔举办.本文针对实验室关注的几个研究热点,模型压缩.自 ...

  2. 【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法

    论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning 论文地址:https://arxiv.org/ ...

  3. 模型压缩-Learning Efficient Convolutional Networks through Network Slimming

    Zhuang Liu主页:https://liuzhuang13.github.io/ Learning Efficient Convolutional Networks through Networ ...

  4. [论文分享]Channel Pruning via Automatic Structure Search

    authors: Mingbao Lin, Rongrong Ji, etc. comments: IJCAL2020 cite: [2001.08565v3] Channel Pruning via ...

  5. CNN 模型压缩与加速算法综述

    本文由云+社区发表 导语:卷积神经网络日益增长的深度和尺寸为深度学习在移动端的部署带来了巨大的挑战,CNN模型压缩与加速成为了学术界和工业界都重点关注的研究领域之一. 前言 自从AlexNet一举夺得 ...

  6. 论文笔记——Channel Pruning for Accelerating Very Deep Neural Networks

    论文地址:https://arxiv.org/abs/1707.06168 代码地址:https://github.com/yihui-he/channel-pruning 采用方法 这篇文章主要讲诉 ...

  7. 【DMCP】2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks-论文阅读

    DMCP 2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks Shaopeng Guo(sensetime ...

  8. 对抗性鲁棒性与模型压缩:ICCV2019论文解析

    对抗性鲁棒性与模型压缩:ICCV2019论文解析 Adversarial Robustness vs. Model Compression, or Both? 论文链接: http://openacc ...

  9. 模型压缩,模型减枝,tf.nn.zero_fraction,统计0的比例,等。

    我们刚接到一个项目时,一开始并不是如何设计模型,而是去先跑一个现有的模型,看在项目需求在现有模型下面效果怎么样.当现有模型效果不错需要深入挖掘时,仅仅时跑现有模型是不够的,比如,如果你要在嵌入式里面去 ...

随机推荐

  1. Popular generalized linear models|GLMM| Zero-truncated Models|Zero-Inflated Models|matched case–control studies|多重logistics回归|ordered logistics regression

    ============================================================== Popular generalized linear models 将不同 ...

  2. 关于电脑识别不出自己画的板子上的USB接口问题

    现在在画一个Cortex-A5的底板,现在已经完成,正在测试各个模块,发现USB插上后,电脑提示报错,如下: 网上查了很多,有的说是配置问题,有的说是走线问题,首先配置肯定没问题,因为同一台电脑,在买 ...

  3. linux 中的.so和.a文件

    Linux下的.so是基于Linux下的动态链接,其功能和作用类似与windows下.dll文件. 下面是关于.so的介绍: 一.引言 通常情况下,对函数库的链接是放在编译时期(compile tim ...

  4. 93.QuerySet转换为SQL的条件:迭代,切片(指定步长),len函数,list函数,判断

    生成一个QuerySet对象并不会马上转换为SQL语句去执行. books = Book.objects.filter(pk=3) print(connection.queries) 打印出djang ...

  5. VMware12 + Ubuntu16.04 虚拟磁盘扩容

    转载自:https://blog.csdn.net/Timsley/article/details/50742755 今天用虚拟机的时候,发现虚拟机快满了,提示磁盘空间小,不得不扩充虚拟机空间.经过百 ...

  6. 你必须知道的基本位运算技巧(状压DP、搜索优化都会用到)

    一. 位操作基础 基本的位操作符有与.或.异或.取反.左移.右移这6种,它们的运算规则如下所示: 符号 描述 运算规则 & 与 两个位都为1时,结果才为1 | 或 两个位都为0时,结果才为0 ...

  7. Sqlite教程(2) Data Access Object

    因为这个项目的业务层很薄,因此想在架构上尽量保持着「轻」,不会把创建DbHelper的interface. 而是直接用DAO创建DbHelper对象. DAO和DbHelper也是同样使用懒汉模式. ...

  8. poj-3259 Wormholes(无向、负权、最短路之负环判断)

    http://poj.org/problem?id=3259 Description While exploring his many farms, Farmer John has discovere ...

  9. 吴裕雄--天生自然 JAVA开发学习:Applet 基础

    import java.applet.*; import java.awt.*; public class HelloWorldApplet extends Applet { public void ...

  10. HDU 6126 Give out candies(网络流)

    题目给出n,m,k 然后给出n*m的矩阵a[i][j]代表第i个人在获得j 颗糖果能得到的满足值, 然后k是k行每行输入三个整数x,y,z     ,x,y,z表示一组限制表示第x个人分到的糖数减去第 ...