Torch-Pruning工具箱
Torch-Pruning
通道剪枝网络实现加速的工作。
Torch pruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪枝,自动发现层与层剪枝的依赖关系,可以处理Densenet、ResNet和DeepLab
特性
卷积网络通道剪枝 CNNs (e.g. ResNet, DenseNet, Deeplab) 和 Transformers (即 Bert, @horseee贡献代码)
- 网络图跟踪以及依赖关系.
- 支持网络层: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding 和 扩展层.
- 支持操作: split, concatenation, skip connection, flatten, 等等.
- 剪枝策略: Random, L1, L2, 等等.
它是怎样工作的
Torch-Pruning 使用 fake inputs输入网络和torch.jit一样收集网络信息.
dependency graph 用来表示计算图和层之间的关系. 由于裁剪一层会影响若干层 , dependecy会自动传播剪枝到其他层并且保存在PruningPlan.
如果模型中有 torch.split
或者torch.cat
,所有剪枝的indices都会做一些变换的
Conv-Conv:\(n_{i+1}\) oc中减少1个通道,下一个卷积每个通oc通道中ic通道\(n_{i+1}\)少一个
Skip Connection: 需要考虑ic和上一层的oc互相关联,所以这里shortcut
和add
都需要传递这种关联。
依赖关系 | 可视化 | 例子 |
---|---|---|
Conv-Conv | ![]() |
AlexNet |
Conv-FC(Global Pooling or Flatten) | ![]() |
ResNet,VGG |
Skip Connection | ![]() |
ResNet |
Concatenation | ![]() |
DenseNet, ASPP |
Split | ![]() |
torch.chunk |
一个例子
先来看下torchpruning 的流程图:
# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy()
# 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4, round_to=16) # or manually selected pruning_idxs=[2, 6, 9, ...]
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
print(pruning_plan)
# 4. execute this plan (prune the model)
pruning_plan.exec()
print(model)
pruning_plan = DG.get_pruning_plan( pruning_idxs ):
底层剪枝函数
使用一层一层的固定剪枝和上面是等价的
tp.prune_conv( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )
运行结果:
(Conv2d(36, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
3456)
对设备友好的通道对齐剪枝
可以通过设置round_to
参数,下例可以使得通道对16取整(即,16,32,48,64)
strategy = tp.strategy.L1Strategy()
pruning_idxs = strategy(model.conv1.weight, amount=0.2, round_to=16)
本文暂时没有对torch pruning源码进行分析,先学会使用,后续如果有需要、有时间会再进行源码分析
Torch-Pruning工具箱的更多相关文章
- 07_利用pytorch的nn工具箱实现LeNet网络
07_利用pytorch的nn工具箱实现LeNet网络 目录 一.引言 二.定义网络 三.损失函数 四.优化器 五.数据加载和预处理 六.Hub模块简介 七.总结 pytorch完整教程目录:http ...
- EMD分析 Matlab 精华总结 附开源工具箱(全)
前言: 本贴写于2016年12与15日,UK.最近在学习EMD(Empirical Mode Decomposition)和HHT(Hilbert-Huang Transform)多分辨信号处理,FQ ...
- 相机标定简介与MatLab相机标定工具箱的使用(未涉及原理公式推导)
相机标定 一.相机标定的目的 确定空间物体表面某点的三维几何位置与其在图像中对应点之间的相互关系,建立摄像机成像的几何模型,这些几何模型参数就是摄像机参数. 二.通用摄像机模型 世界坐标系.摄像机坐标 ...
- Torch Problems: require some packages doesn't work
I've recently got a problem. require 'cutorch' doesn't work. But it was ok yesterday, although I hav ...
- Torch学习笔记1--Torch简介
Torch是什么 Torch是一个由Lua语言开发的深度学习框架,目前支持Mac OS X 和Ubuntu 12及以上,官网 ,github地址. 具有如下特点: 交互式开发工具 可视化式的工具 第三 ...
- 基于CkEditor实现.net在线开发之路(4)快速布局,工具箱,模板载入,tab选项卡简单说明与使用
上一章给常用的from表单控件属性页面,进行了简单说明和介绍,但是由于是在网页中做界面设计,操作肯定没有桌面应用程序方便,便捷,为了更方便的布局与设计,今天我主要说一下快速布局,工具箱,tab选项卡, ...
- 通用PE工具箱 4.0精简优化版
通用PE工具箱 4.0精简优化版 经用过不少 WinPE 系统,都不是很满意,普遍存在篡改主页.添加广告链接至收藏夹.未经允许安装推广软件等流氓行为,还集成了诸多不常用的工具,令人头疼不已.那么今天给 ...
- 创建 WPF 工具箱控件
创建 WPF 工具箱控件 WPF (Windows Presentation Framework) 工具箱控件模板允许您创建 WPF 控件,会自动添加到 工具箱 安装扩展的安装. 本主题演示如何使用模 ...
- 使用脚本自动配置matlab安装libsvm和随机森林工具箱
前言 支持向量机(SVM)和随机森林 都是用于分类的机器学习算法. 这里我需要对网上的工具箱在matlab中进行配置. 效果演示: 1.双击运行“自动配置.bat” 2.matlab会自动启动,手动配 ...
- .NET开发者必备的工具箱
本文作者Spencer是一名专注于ASP.NET和C#的程序员,他列举了平时工作.在家所使用的大部分开发工具,其中大部分工具都是集中于开发,当然也有一些其它用途的,比如图片处理.文件压缩等. 如果你是 ...
随机推荐
- 小米13Pro一键ROOT秒杀全版本
小米13p专属 通杀全版本 但是必须解开bl锁 小米13pro一键root使用方法: 解锁bl后,不要设置锁屏密码,有的话就取消掉,打开软件,点击安装驱动(管理员) 手机上打开usb调试和usb安装 ...
- 【攻防世界】BadProgrammer
BadProgrammer(原型链污染) 题目来源 攻防世界 NO.GFSJ0986 题目描述 打开网址页面如下,没有什么有用信息 用dirsearch扫一下目录,发现/static../(用御剑扫不 ...
- Flink学习(三) 批流版本的wordcount JAVA版本
Flink 开发环境通常来讲,任何一门大数据框架在实际生产环境中都是以集群的形式运行,而我们调试代码大多数会在本地搭建一个模板工程,Flink 也不例外. Flink 一个以 Java 及 Scala ...
- linux下配置ip为动态获取
点击查看代码 在Linux系统中配置网络接口以动态获取IP地址,通常需要使用DHCP(Dynamic Host Configuration Protocol).大多数现代Linux发行版都默认支持这个 ...
- FastAPI性能优化指南:参数解析与惰性加载
扫描二维码关注或者微信搜一搜:编程智域 前端至全栈交流与成长 探索数千个预构建的 AI 应用,开启你的下一个伟大创意 第一章:参数解析性能原理 1.1 FastAPI请求处理管线 async def ...
- go切片排序
前言 有时候我们需要根据切片中的某个字段进行切片排序,但sort包中只有默认基本类型 int . float64 和 string 的排序,所以我们可以手动实现sort包的 sort.Interfac ...
- cnpm : 无法加载文件 C:\Users\Raytine\AppData\Roaming\npm\cnpm.ps1,因为在此系统上禁止运行脚本。
解决方式: 1.在系统中搜索框 输入 Windos PowerShell 2.点击"管理员身份运行" 3.输入" set-ExecutionPolicy RemoteSi ...
- Laravel 配置连接多个数据库以及如何使用
目录 配置连接 配置 .env 文件 配置 \config\database.php 文件 使用 Schema Query Eloquent 配置连接 配置 .env 文件 /* 这部分是默认的数据库 ...
- 执行Django 的迁移命令报错[1193, Unknown system variable default_storage_engine]
在学习""编写你的第一个 Django 应用程序,第2部分"时候,遇到一个问题. 执行迁移命令 python manage.py makemigrations polls ...
- Tampermonkey 油猴脚本中文手册(出处:https://www.itblogcn.com/article/2233.html)
文章目录 @name @namespace @copyright @version @description @icon, @iconURL, @defaulticon @icon64, @icon6 ...