论文地址

channel pruning是指给定一个CNN模型,去掉卷积层的某几个输入channel以及相应的卷积核,
并最小化裁剪channel后与原始输出的误差。

可以分两步来解决:

  1. channel selection
    利用LASSO回归裁剪掉多余的channel,求出每个channel的权重,如果为0即是被裁减。
  2. feature map reconstruction
    利用剩下的channel重建输出,直接使用最小平方误差来拟合原始卷积层的输出,求出新的卷积核W。

二、优化目标

2.1 定义优化目标

输入c个channel,输出n个channel,卷积核W的大小是

我们对输入做了采样,假设对每个输入,对channel采样出来N个大小为块

为了把输入channel从c个剪裁到c’个,我们定义最优化的目标为

其中是每个channel的权重向量,如果是0,意味着裁剪当前channel,相应的也被裁减。

2.2 求最优目标

为了最优化目标,分为如下两步

2.2.1 固定W,求

其中,大小是,

这里之所以加上关于的L1正则项,是为了避免所有的都为1,而是让它们趋于0。

2.2.2 固定,求W

利用剩下的channel重建输出,直接求最小平方误差

其中,大小为,
W’也被reshape为。

2.2.3 多分支的情况

论文只考虑了常见的残差网络,设residual分支的输出为,shortcut 分支的输出为。

这里首先在residual分支的第一层前做了channel采样,从而减少计算量(训练过程中做的,即filter layer)。

设为原始的上一层的输出,
那么channel pruning中,residual分支的输出拟合,其中是上一层裁减后的shortcut。

三、实现

实现的时候,不是按照不断迭代第一步和第二步,因为比较耗时。
而是先不断的迭代第一步,直到裁剪剩下的channel个数为c’,然后执行第二步求出最终的W。

3.1 第一步Channel Selection

如何得到LASSO回归的输入:

(1)首先把输入做转置

# (N, c, hw) --> (c, N, hw)
inputs = np.transpose(inputs, [1, 0, 2])

(2)把weigh做转置

# (n, c, hw) --> (c, hw, n)
weights = np.transpose(weights, [1, 2, 0]))

(3)最后两维做矩阵乘法

# (c, N, n), matmul apply dot on the last two dim
outputs = np.matmul(inputs, weights)

(4)把输出做reshape和转置

# (Nn, c)
outputs = np.transpose(outputs.reshape(outputs.shape[0], -1))

LASSO回归的目标值即是对应的Y,大小为

的大小影响了最终为0的个数,为了找出合适的,需要尝试不同的值,直到裁剪剩下的channel个数为为止。

为了找到合适的可以使用二分查找,
或者不断增大直到裁剪剩下的channel个数,然后降序排序取前,剩下的为0。

while True:
coef = solve(alpha)
if sum(coef != 0) < rank:
break
last_alpha = alpha
last_coef = coef
alpha = 4 * alpha + math.log(coef.shape[0])
if not fast:
# binary search until compression ratio is satisfied
left = last_alpha
right = alpha
while True:
alpha = (left + right) / 2
coef = solve(alpha)
if sum(coef != 0) < rank:
right = alpha
elif sum(coef != 0) > rank:
left = alpha
else:
break
else:
last_coef = np.abs(last_coef)
sorted_coef = sorted(last_coef, reverse=True)
rank_max = sorted_coef[rank - 1]
coef = np.array([c if c >= rank_max else 0 for c in last_coef])

3.2 第二步Feature Map Reconstruction

直接利用最小平方误差,求出最终的卷积核。

from sklearn import linear_model
def LinearRegression(input, output):
clf = linear_model.LinearRegression()
clf.fit(input, output)
return clf.coef_, clf.intercept_
pruned_weights, pruned_bias = LinearRegression(input=inputs, output=real_outputs)

3.3 一些细节

  1. 将Relu层和卷积层分离
    因为Relu一般会使用inplace操作来节省内存/显存,如果不分离开,那么得到的卷积层的输出是经过了Relu激活函数计算后的结果。

  2. 每次裁减完一个卷积层后,需要对该层的bottom和top层的输入或输出大小作相应的改变。

  3. 第一步求出后,若为0,则说明要裁减对应的channel,否则置为1,表示保留channel。

参考链接

https://github.com/yihui-he/channel-pruning

模型压缩之Channel Pruning的更多相关文章

  1. 【转载】NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩、机器学习及最优化算法

    原文:NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩.机器学习及最优化算法 导读 AI领域顶会NeurIPS正在加拿大蒙特利尔举办.本文针对实验室关注的几个研究热点,模型压缩.自 ...

  2. 【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法

    论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning 论文地址:https://arxiv.org/ ...

  3. 模型压缩-Learning Efficient Convolutional Networks through Network Slimming

    Zhuang Liu主页:https://liuzhuang13.github.io/ Learning Efficient Convolutional Networks through Networ ...

  4. [论文分享]Channel Pruning via Automatic Structure Search

    authors: Mingbao Lin, Rongrong Ji, etc. comments: IJCAL2020 cite: [2001.08565v3] Channel Pruning via ...

  5. CNN 模型压缩与加速算法综述

    本文由云+社区发表 导语:卷积神经网络日益增长的深度和尺寸为深度学习在移动端的部署带来了巨大的挑战,CNN模型压缩与加速成为了学术界和工业界都重点关注的研究领域之一. 前言 自从AlexNet一举夺得 ...

  6. 论文笔记——Channel Pruning for Accelerating Very Deep Neural Networks

    论文地址:https://arxiv.org/abs/1707.06168 代码地址:https://github.com/yihui-he/channel-pruning 采用方法 这篇文章主要讲诉 ...

  7. 【DMCP】2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks-论文阅读

    DMCP 2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks Shaopeng Guo(sensetime ...

  8. 对抗性鲁棒性与模型压缩:ICCV2019论文解析

    对抗性鲁棒性与模型压缩:ICCV2019论文解析 Adversarial Robustness vs. Model Compression, or Both? 论文链接: http://openacc ...

  9. 模型压缩,模型减枝,tf.nn.zero_fraction,统计0的比例,等。

    我们刚接到一个项目时,一开始并不是如何设计模型,而是去先跑一个现有的模型,看在项目需求在现有模型下面效果怎么样.当现有模型效果不错需要深入挖掘时,仅仅时跑现有模型是不够的,比如,如果你要在嵌入式里面去 ...

随机推荐

  1. cookie保存

    <!DOCTYPE html> <html lang="en"> <head>     <meta charset="UTF-8 ...

  2. Maven--可选依赖

    假设有这样换一个依赖关系,项目 A 依赖于项目 B,项目 B 依赖于项目 X 和 Y,B 对于 X 和 Y的依赖都是可选依赖: A -> B B -> X(可选) B -> Y(可选 ...

  3. js等于符号的详解

    JavaScript == 与 === 区别 1.对于 string.number 等基础类型,== 和 === 是有区别的 a)不同类型间比较,== 之比较 "转化成同一类型后的值&quo ...

  4. Oauth2.0详解及安全使用

    引言:刚刚参加工作的时候接到的第一个任务就是接入新浪的联合登录功能,当时新浪用的还是oauth1.0协议.接入的时候没有对oauth协议有过多的了解,只是按照开放平台的接入流程进行开发,当时还在想这么 ...

  5. Leetcode——863.二叉树中所有距离为 K 的结点

    给定一个二叉树(具有根结点 root), 一个目标结点 target ,和一个整数值 K . 返回到目标结点 target 距离为 K 的所有结点的值的列表. 答案可以以任何顺序返回. 示例 1: 输 ...

  6. 关于mysql数据库连接异常处理

    tomcat启动错误日志关键信息: 28-Aug-2019 14:22:55.014 SEVERE [localhost-startStop-1] org.apache.catalina.core.C ...

  7. Docker容器化【Dockerfile编写&&搭建与使用Docker私有仓库】

    # Docker 学习目标: 掌握Docker基础知识,能够理解Docker镜像与容器的概念 完成Docker安装与启动 掌握Docker镜像与容器相关命令 掌握Tomcat Nginx 等软件的常用 ...

  8. spark mllib lda 中文分词、主题聚合基本样例

    github https://github.com/cclient/spark-lda-example spark mllib lda example 官方示例较为精简 在官方lda示例的基础上,给合 ...

  9. 关于目录的操作|*|<>|opendir |readdir|unlink|find2perl|rename|readlink|oct()|utime

    #!/usr/bin/perl use strict; use warnings; foreach my $arg(@ARGV) { print "one is $arg\n"; ...

  10. LeetCode No.88,89,90

    No.88 Merge 合并两个有序数组 题目 给定两个有序整数数组 nums1 和 nums2,将 nums2 合并到 nums1 中,使得 num1 成为一个有序数组. 说明: 初始化 nums1 ...