论文地址

channel pruning是指给定一个CNN模型,去掉卷积层的某几个输入channel以及相应的卷积核,
并最小化裁剪channel后与原始输出的误差。

可以分两步来解决:

  1. channel selection
    利用LASSO回归裁剪掉多余的channel,求出每个channel的权重,如果为0即是被裁减。
  2. feature map reconstruction
    利用剩下的channel重建输出,直接使用最小平方误差来拟合原始卷积层的输出,求出新的卷积核W。

二、优化目标

2.1 定义优化目标

输入c个channel,输出n个channel,卷积核W的大小是

我们对输入做了采样,假设对每个输入,对channel采样出来N个大小为块

为了把输入channel从c个剪裁到c’个,我们定义最优化的目标为

其中是每个channel的权重向量,如果是0,意味着裁剪当前channel,相应的也被裁减。

2.2 求最优目标

为了最优化目标,分为如下两步

2.2.1 固定W,求

其中,大小是,

这里之所以加上关于的L1正则项,是为了避免所有的都为1,而是让它们趋于0。

2.2.2 固定,求W

利用剩下的channel重建输出,直接求最小平方误差

其中,大小为,
W’也被reshape为。

2.2.3 多分支的情况

论文只考虑了常见的残差网络,设residual分支的输出为,shortcut 分支的输出为。

这里首先在residual分支的第一层前做了channel采样,从而减少计算量(训练过程中做的,即filter layer)。

设为原始的上一层的输出,
那么channel pruning中,residual分支的输出拟合,其中是上一层裁减后的shortcut。

三、实现

实现的时候,不是按照不断迭代第一步和第二步,因为比较耗时。
而是先不断的迭代第一步,直到裁剪剩下的channel个数为c’,然后执行第二步求出最终的W。

3.1 第一步Channel Selection

如何得到LASSO回归的输入:

(1)首先把输入做转置

# (N, c, hw) --> (c, N, hw)
inputs = np.transpose(inputs, [1, 0, 2])

(2)把weigh做转置

# (n, c, hw) --> (c, hw, n)
weights = np.transpose(weights, [1, 2, 0]))

(3)最后两维做矩阵乘法

# (c, N, n), matmul apply dot on the last two dim
outputs = np.matmul(inputs, weights)

(4)把输出做reshape和转置

# (Nn, c)
outputs = np.transpose(outputs.reshape(outputs.shape[0], -1))

LASSO回归的目标值即是对应的Y,大小为

的大小影响了最终为0的个数,为了找出合适的,需要尝试不同的值,直到裁剪剩下的channel个数为为止。

为了找到合适的可以使用二分查找,
或者不断增大直到裁剪剩下的channel个数,然后降序排序取前,剩下的为0。

while True:
coef = solve(alpha)
if sum(coef != 0) < rank:
break
last_alpha = alpha
last_coef = coef
alpha = 4 * alpha + math.log(coef.shape[0])
if not fast:
# binary search until compression ratio is satisfied
left = last_alpha
right = alpha
while True:
alpha = (left + right) / 2
coef = solve(alpha)
if sum(coef != 0) < rank:
right = alpha
elif sum(coef != 0) > rank:
left = alpha
else:
break
else:
last_coef = np.abs(last_coef)
sorted_coef = sorted(last_coef, reverse=True)
rank_max = sorted_coef[rank - 1]
coef = np.array([c if c >= rank_max else 0 for c in last_coef])

3.2 第二步Feature Map Reconstruction

直接利用最小平方误差,求出最终的卷积核。

from sklearn import linear_model
def LinearRegression(input, output):
clf = linear_model.LinearRegression()
clf.fit(input, output)
return clf.coef_, clf.intercept_
pruned_weights, pruned_bias = LinearRegression(input=inputs, output=real_outputs)

3.3 一些细节

  1. 将Relu层和卷积层分离
    因为Relu一般会使用inplace操作来节省内存/显存,如果不分离开,那么得到的卷积层的输出是经过了Relu激活函数计算后的结果。

  2. 每次裁减完一个卷积层后,需要对该层的bottom和top层的输入或输出大小作相应的改变。

  3. 第一步求出后,若为0,则说明要裁减对应的channel,否则置为1,表示保留channel。

参考链接

https://github.com/yihui-he/channel-pruning

模型压缩之Channel Pruning的更多相关文章

  1. 【转载】NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩、机器学习及最优化算法

    原文:NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩.机器学习及最优化算法 导读 AI领域顶会NeurIPS正在加拿大蒙特利尔举办.本文针对实验室关注的几个研究热点,模型压缩.自 ...

  2. 【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法

    论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning 论文地址:https://arxiv.org/ ...

  3. 模型压缩-Learning Efficient Convolutional Networks through Network Slimming

    Zhuang Liu主页:https://liuzhuang13.github.io/ Learning Efficient Convolutional Networks through Networ ...

  4. [论文分享]Channel Pruning via Automatic Structure Search

    authors: Mingbao Lin, Rongrong Ji, etc. comments: IJCAL2020 cite: [2001.08565v3] Channel Pruning via ...

  5. CNN 模型压缩与加速算法综述

    本文由云+社区发表 导语:卷积神经网络日益增长的深度和尺寸为深度学习在移动端的部署带来了巨大的挑战,CNN模型压缩与加速成为了学术界和工业界都重点关注的研究领域之一. 前言 自从AlexNet一举夺得 ...

  6. 论文笔记——Channel Pruning for Accelerating Very Deep Neural Networks

    论文地址:https://arxiv.org/abs/1707.06168 代码地址:https://github.com/yihui-he/channel-pruning 采用方法 这篇文章主要讲诉 ...

  7. 【DMCP】2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks-论文阅读

    DMCP 2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks Shaopeng Guo(sensetime ...

  8. 对抗性鲁棒性与模型压缩:ICCV2019论文解析

    对抗性鲁棒性与模型压缩:ICCV2019论文解析 Adversarial Robustness vs. Model Compression, or Both? 论文链接: http://openacc ...

  9. 模型压缩,模型减枝,tf.nn.zero_fraction,统计0的比例,等。

    我们刚接到一个项目时,一开始并不是如何设计模型,而是去先跑一个现有的模型,看在项目需求在现有模型下面效果怎么样.当现有模型效果不错需要深入挖掘时,仅仅时跑现有模型是不够的,比如,如果你要在嵌入式里面去 ...

随机推荐

  1. Java之同步代码块处理实现Runnable的线程安全问题

    /** * 例子:创建三个窗口卖票,总票数为100张.使用实现Runnable接口的方式 * * 1.问题:卖票过程中,出现了重票.错票 -->出现了线程的安全问题 * 2.问题出现的原因:当某 ...

  2. Python笔记_第三篇_面向对象_6.继承(单继承和多继承)

    1. 概念解释: 继承:有两个类:A类和B类.那么A类就拥有了B类中的属性和方法. * 例如:Object:是所有类的父亲,还可以成为基类或者超类(super()) * 继承者为子类,被继承者成为父类 ...

  3. 量化投资_TB交易开拓者A函数和Q函数详解

    //////////////////A函数详解/////////////// //A函数主要在端口上进行下单操作//////////////// A_AccountID说明 返回当前公式应用的交易帐户 ...

  4. 使用java获取手机号归属地等信息httpClient实现

    java获取手机号归属地 一般想获取手机号归属地等信息个人是无法获取的,但是可以通过调用第三方接口获取,具体百度搜索很多这里例子提供一个淘宝的接口 ,该功能已经发布到网站作为一个在线小工具,拿走不谢: ...

  5. sql的书写顺序

    例:select t.* from (select *  from t_user where isDelete = 1 limit 0,10) t order by t.qq select from ...

  6. P2486 [SDOI2011]染色 区间合并+树链剖分(加深对线段树的理解)

    #include<bits/stdc++.h> using namespace std; ; struct node{ int l,r,cnt,lazy; node(,,,):l(l1), ...

  7. day34-进程

    #进程是程序的运行,程序不运行不产生进程. #1.进程的并行与并发: # 并行:是指两者同时执行,比如赛跑,两人都在不停的往前跑.(资源够用,比如三个线程,四核的cpu) # 并发:是指资源有限的情况 ...

  8. 003.前端开发知识,前端基础CSS(2020-01-07)

    一.CSS初识 CSS通常称为CSS样式表或层叠样式表(级联样式表),主要用于设置HTML页面中的文本内容(字体.大小.对齐方式等).图片的外形(宽高.边框样式.边距等)以及版面的布局等外观显示样式. ...

  9. EXAM-2018-8-3

    EXAM-2018-8-3 D H 喜闻乐见的水题 J lower_bound + upper_bound 一个可以查找第一个大于,另一个可查找第一个不小于. F 找规律+奇偶分析 偶数好找,就是奇数 ...

  10. Leetcode7_整数反转

    题目 给出一个 32 位的有符号整数,你需要将这个整数中每位上的数字进行反转. 示例 1: 输入: 123输出: 321 示例 2: 输入: -123输出: -321 示例 3: 输入: 120输出: ...