Structured pruning with the torch-pruning tool
Network architecture definition
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.datasets import CIFAR10
from torchvision import transforms
import numpy as np
import time
class BasicBlock(nn.Module):
    """Standard two-conv residual block used by ResNet18/34."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Projection shortcut when the spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck block used by ResNet50 and deeper."""
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        # Kaiming init for convs, constant init for norm layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of each stage downsamples
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if not out_feature:
            return out
        else:
            return out, feature

def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
Speed test
Base model: ResNet18
Pruning strategy: L1Strategy, with per-block pruning ratios [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
We compare the inference speed of the original network, the pruned network without channel rounding, and the pruned network with channel counts rounded to multiples of 16.
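Rounding matters because many convolution kernels run fastest when channel counts align with the hardware's preferred vector width. The exact rounding semantics depend on the torch-pruning version, so the following is only an assumed illustration of how round_to changes the kept-channel count:

# Hypothetical illustration (assumed semantics, not torch-pruning's actual code):
# the number of kept channels is rounded down to a multiple of `round_to`,
# but never below `round_to` itself.
def rounded_keep(n_channels, amount, round_to=16):
    keep = n_channels - int(n_channels * amount)
    return max(round_to, (keep // round_to) * round_to)

print(rounded_keep(64, 0.1))  # 48 under these assumptions (vs. 58 without rounding)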
def measure_inference_time(net, x, repeat=100):
    # torch.cuda.synchronize()  # uncomment when timing on CUDA
    start = time.perf_counter()
    for _ in range(repeat):
        net(x)
    # torch.cuda.synchronize()  # uncomment when timing on CUDA
    end = time.perf_counter()
    return (end - start) / repeat
def prune_model(model, round_to=1):
    model.cpu()
    # Build the dependency graph so coupled layers (BN, shortcuts, the next conv)
    # are pruned consistently with each conv we prune
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2, round_to=1):
        # Equivalent manual selection: remove the filters with the smallest L1 norm
        #   weight = conv.weight.detach().cpu().numpy()
        #   L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        #   pruning_index = np.argsort(L1_norm)[:int(weight.shape[0] * amount)].tolist()
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount, round_to=round_to)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blk_id = 0
    for m in model.modules():
        if isinstance(m, BasicBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id], round_to)
            prune_conv(m.conv2, block_prune_probs[blk_id], round_to)
            blk_id += 1
    return model
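To see what the dependency graph actually does, it helps to print a single pruning plan before executing it. This sketch assumes the same legacy torch-pruning API used above; the plan's printed contents may differ across library versions:

# Sketch: inspect one pruning plan without executing it.
net = ResNet18()
DG = tp.DependencyGraph().build_dependency(net, torch.randn(1, 3, 32, 32))
strategy = tp.strategy.L1Strategy()
idxs = strategy(net.layer1[0].conv1.weight, amount=0.1)
plan = DG.get_pruning_plan(net.layer1[0].conv1, tp.prune_conv, idxs)
print(plan)  # lists the coupled ops (BN, the following conv, ...) pruned together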
device = torch.device('cpu')  # or torch.device('cuda')
repeat = 100

# before pruning
model = ResNet18().eval()
fake_input = torch.randn(16, 3, 32, 32)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d" % (inference_time_before_pruning, tp.utils.count_params(model)))

# w/o rounding
model = ResNet18().eval()
prune_model(model)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d" % (inference_time_without_rounding, tp.utils.count_params(model)))

# w/ rounding
model = ResNet18().eval()
prune_model(model, round_to=16)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d" % (inference_time_with_rounding, tp.utils.count_params(model)))
Accuracy test
# The ResNet definition above is assumed to live in cifar_resnet.py
from cifar_resnet import ResNet18
def get_dataloader():
    train_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=True, transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]), download=True), batch_size=256, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=256, num_workers=2)
    return train_loader, test_loader
def eval(model, test_loader):
    correct = 0
    total = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        for i, (img, target) in enumerate(test_loader):
            img = img.to(device)
            out = model(img)
            pred = out.max(1)[1].detach().cpu().numpy()
            target = target.cpu().numpy()
            correct += (pred == target).sum()
            total += len(target)
    return correct / total
_, test_loader = get_dataloader()
# original (torch.load restores a full model object here; a plain state_dict
# would not match the pruned architectures below)
previous_ckpt = 'resnet18-round0.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("before pruning: Acc=%.4f" % (acc))

# w/o rounding
previous_ckpt = 'resnet18-pruning-noround.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("w/o rounding: Acc=%.4f" % (acc))

# w/ rounding
previous_ckpt = 'resnet18-pruning-round_to16.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("w/ rounding: Acc=%.4f" % (acc))