Torch-Pruning

通道剪枝网络实现加速的工作。



Torch pruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪枝,自动发现层与层剪枝的依赖关系,可以处理Densenet、ResNet和DeepLab

特性

卷积网络通道剪枝 CNNs (e.g. ResNet, DenseNet, Deeplab) 和 Transformers (即 Bert, @horseee贡献代码)

  • 网络图跟踪以及依赖关系.
  • 支持网络层: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding 和 扩展层.
  • 支持操作: split, concatenation, skip connection, flatten, 等等.
  • 剪枝策略: Random, L1, L2, 等等.

它是怎样工作的

Torch-Pruning 使用 fake inputs输入网络和torch.jit一样收集网络信息.

dependency graph 用来表示计算图和层之间的关系. 由于裁剪一层会影响若干层 , dependecy会自动传播剪枝到其他层并且保存在PruningPlan.

如果模型中有 torch.split或者torch.cat,所有剪枝的indices都会做一些变换的

Conv-Conv:\(n_{i+1}\) oc中减少1个通道,下一个卷积每个通oc通道中ic通道\(n_{i+1}\)少一个

Skip Connection: 需要考虑ic和上一层的oc互相关联,所以这里shortcutadd都需要传递这种关联。

依赖关系 可视化 例子
Conv-Conv AlexNet
Conv-FC(Global Pooling or Flatten) ResNet,VGG
Skip Connection ResNet
Concatenation DenseNet, ASPP
Split torch.chunk

一个例子

先来看下torchpruning 的流程图:

# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy() # 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) # 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4, round_to=16) # or manually selected pruning_idxs=[2, 6, 9, ...]
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
print(pruning_plan) # 4. execute this plan (prune the model)
pruning_plan.exec() print(model)

pruning_plan = DG.get_pruning_plan( pruning_idxs ):

底层剪枝函数

使用一层一层的固定剪枝和上面是等价的

tp.prune_conv( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )

运行结果:

(Conv2d(36, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
3456)

对设备友好的通道对齐剪枝

可以通过设置round_to参数,下例可以使得通道对16取整(即,16,32,48,64)

strategy = tp.strategy.L1Strategy()
pruning_idxs = strategy(model.conv1.weight, amount=0.2, round_to=16)

本文暂时没有对torch pruning源码进行分析,先学会使用,后续如果有需要、有时间会再进行源码分析

Torch-Pruning工具箱的更多相关文章

  1. 07_利用pytorch的nn工具箱实现LeNet网络

    07_利用pytorch的nn工具箱实现LeNet网络 目录 一.引言 二.定义网络 三.损失函数 四.优化器 五.数据加载和预处理 六.Hub模块简介 七.总结 pytorch完整教程目录:http ...

  2. EMD分析 Matlab 精华总结 附开源工具箱(全)

    前言: 本贴写于2016年12与15日,UK.最近在学习EMD(Empirical Mode Decomposition)和HHT(Hilbert-Huang Transform)多分辨信号处理,FQ ...

  3. 相机标定简介与MatLab相机标定工具箱的使用(未涉及原理公式推导)

    相机标定 一.相机标定的目的 确定空间物体表面某点的三维几何位置与其在图像中对应点之间的相互关系,建立摄像机成像的几何模型,这些几何模型参数就是摄像机参数. 二.通用摄像机模型 世界坐标系.摄像机坐标 ...

  4. Torch Problems: require some packages doesn't work

    I've recently got a problem. require 'cutorch' doesn't work. But it was ok yesterday, although I hav ...

  5. Torch学习笔记1--Torch简介

    Torch是什么 Torch是一个由Lua语言开发的深度学习框架,目前支持Mac OS X 和Ubuntu 12及以上,官网 ,github地址. 具有如下特点: 交互式开发工具 可视化式的工具 第三 ...

  6. 基于CkEditor实现.net在线开发之路(4)快速布局,工具箱,模板载入,tab选项卡简单说明与使用

    上一章给常用的from表单控件属性页面,进行了简单说明和介绍,但是由于是在网页中做界面设计,操作肯定没有桌面应用程序方便,便捷,为了更方便的布局与设计,今天我主要说一下快速布局,工具箱,tab选项卡, ...

  7. 通用PE工具箱 4.0精简优化版

    通用PE工具箱 4.0精简优化版 经用过不少 WinPE 系统,都不是很满意,普遍存在篡改主页.添加广告链接至收藏夹.未经允许安装推广软件等流氓行为,还集成了诸多不常用的工具,令人头疼不已.那么今天给 ...

  8. 创建 WPF 工具箱控件

    创建 WPF 工具箱控件 WPF (Windows Presentation Framework) 工具箱控件模板允许您创建 WPF 控件,会自动添加到 工具箱 安装扩展的安装. 本主题演示如何使用模 ...

  9. 使用脚本自动配置matlab安装libsvm和随机森林工具箱

    前言 支持向量机(SVM)和随机森林 都是用于分类的机器学习算法. 这里我需要对网上的工具箱在matlab中进行配置. 效果演示: 1.双击运行“自动配置.bat” 2.matlab会自动启动,手动配 ...

  10. .NET开发者必备的工具箱

    本文作者Spencer是一名专注于ASP.NET和C#的程序员,他列举了平时工作.在家所使用的大部分开发工具,其中大部分工具都是集中于开发,当然也有一些其它用途的,比如图片处理.文件压缩等. 如果你是 ...

随机推荐

  1. Python实现URL自动转二维码的高效方法

    Python实现URL自动转二维码的高效方法 安装包依赖 pip install qrcode pip install pillow 程序 import qrcode data = "htt ...

  2. axurerp9怎么汉化:Axure RP9 中文激活安装教程

    Axure RP 9是一款一款专业级快速产品原型设计工具,使用它可以让用户快速.高效创建应用软件或Web网站的线框图.流程图.原型和规格说明文档.采用了极简主义的设计,界面布局更加清爽简洁,操作也非常 ...

  3. 浅析Bootstrap中Tab(标签页)的使用方法

    Bootstrap 导航元素使用相同的标记和基类,改变修饰的class,可以在不同的样式间进行切换如".nav-pills"(胶囊式导航)与 ".nav-tabs&quo ...

  4. 傻妞教程——对接QQ机器人go-cqhttp

    原本我懒,用的傻妞QQbot一键安装版,docker的,最近有点问题,索性换了go-cqhttp 安装go-cqhttp: go-cqhttp项目地址:https://github.com/Mrs4s ...

  5. 大数据之路Week08_day02 (Flume架构介绍和安装)

    Flume架构介绍和安装 写在前面在学习一门新的技术之前,我们得知道了解这个东西有什么用?我们可以使用它来做些什么呢?简单来说,flume是大数据日志分析中不能缺少的一个组件,既可以使用在流处理中,也 ...

  6. 【FAQ】HarmonyOS SDK 闭源开放能力 —Push Kit(9)

    1.问题描述: 通过push token向鸿蒙手机推送一条通知,收到通知后,通知右侧不展示图片. 解决方案: 检查一下是否存在图片风控:https://developer.huawei.com/con ...

  7. Hack The Box-代理连接及靶机-Meow-喵呜

    前言 ​ 在第一层,您将获得网络安全渗透测试领域的基本技能.您将首先学习如何匿名连接到各种服务,例如 FTP.SMB.Telnet.Rsync 和 RDP.接下来,您将发现 Nmap 的强大功能,这是 ...

  8. 全国省市区基础数据SQL插入脚本

    整理了一份全国省市区SQL插入脚本,并配上抓取数据读取插入数据库源码,附件下载地址:https://files.cnblogs.com/files/101Love/Region.rar

  9. Git分支命名规范总结

    Git分支命名规范总结 在Git分支命名规范中,通常通过前缀明确区分需求(功能开发)和Bug修复,以下是具体规则及示例: 一.命名规范区分原则 需求分支(Feature) 前缀:feature/ 或 ...

  10. [每日算法 - 华为机试] leetcode463. 岛屿的周长

    入口 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer.https://le ...