使用torch pruning工具进行结构化剪枝
网络结构定义
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.datasets import CIFAR10
from torchvision import transforms
import numpy as np
import time
class BasicBlock(nn.Module):
    """Standard two-3x3-conv residual block used by ResNet-18/34.

    Output channels = planes * expansion (expansion is 1 here). A 1x1
    projection shortcut is inserted only when the stride or channel
    count changes; otherwise the identity is used.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Projection shortcut only when the residual path changes the
        # spatial size or the number of channels.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + residual)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by ResNet-50+.

    The final 1x1 conv expands channels to planes * expansion (4). A 1x1
    projection shortcut is inserted when the stride or channel count
    changes; otherwise the identity is used.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        # Projection shortcut only when the residual path changes the
        # spatial size or the number of channels.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + residual)
class ResNet(nn.Module):
    """CIFAR-style ResNet assembled from a residual `block` class.

    Args:
        block: residual block class exposing a ``expansion`` attribute
            and a ``(in_planes, planes, stride)`` constructor.
        num_blocks: list of four ints — blocks per stage.
        num_classes: size of the final linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        # Kaiming init for convs; BN/GN scaled to identity at start.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(mod, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block of the stage may downsample; the rest keep stride 1.
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x, out_feature=False):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = F.avg_pool2d(h, 4)
        feature = torch.flatten(h, 1)
        logits = self.linear(feature)
        if out_feature:
            # Return both the logits and the pre-classifier feature vector.
            return logits, feature
        return logits
def ResNet18(num_classes=10):
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


def ResNet50(num_classes=10):
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
speed test
原始模型 ResNet18
剪枝策略: L1Strategy 各Block裁剪比率 [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
比较原始网络,通道不取整,通道按照16倍数取整的推理速度
def measure_inference_time(net, input, repeat=100):
    """Return the average wall-clock time of one forward pass, in seconds.

    Args:
        net: callable model to benchmark.
        input: input batch forwarded on every iteration.
        repeat: number of forward passes to average over.

    Returns:
        Mean seconds per call over ``repeat`` invocations.
    """
    # torch.cuda.synchronize()  # if use cuda uncomment it
    start = time.perf_counter()
    for _ in range(repeat):
        # Bug fix: the original called the global `model` here, silently
        # ignoring the `net` argument — it only worked by accident because
        # the calling script reassigned the global before each call.
        net(input)
    # torch.cuda.synchronize()  # if use cuda uncomment it
    end = time.perf_counter()
    return (end - start) / repeat
def prune_model(model, round_to=1, block_prune_probs=None):
    """Structurally prune the convs of every BasicBlock via L1 criterion.

    Builds a torch_pruning dependency graph on CPU, then for each
    BasicBlock prunes `conv1` and `conv2` (and all layers that depend on
    them — BN, downstream convs, shortcuts) by the block's ratio.

    Args:
        model: a ResNet built from BasicBlock; modified in place.
        round_to: keep pruned channel counts a multiple of this value
            (e.g. 16 for hardware-friendly channel widths).
        block_prune_probs: per-block pruning ratios, consumed one per
            BasicBlock in module-traversal order. Defaults to the
            original 8-entry schedule matching ResNet-18.

    Returns:
        The pruned model (same object that was passed in).
    """
    if block_prune_probs is None:
        # Original hard-coded schedule: prune later stages more aggressively.
        block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    model.cpu()  # dependency graph is traced on CPU with a dummy input
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2, round_to=1):
        # Select the filters with the smallest L1 norm, then execute the
        # full pruning plan so every dependent layer stays consistent.
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount, round_to=round_to)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, BasicBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id], round_to)
            prune_conv(m.conv2, block_prune_probs[blk_id], round_to)
            blk_id += 1
    return model
# Speed test: compare inference time of the original ResNet-18 against the
# pruned model, with and without rounding channel counts to multiples of 16.
device = torch.device('cpu') #torch.device('cuda') # or torch.device('cpu')
repeat = 100
# before pruning
model = ResNet18().eval()
# Batch of 16 CIFAR-sized (3x32x32) random inputs used for all timings.
fake_input = torch.randn(16,3,32,32)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model)))
# w/o rounding
model = ResNet18().eval()
# Prune with the default schedule; channel counts are not rounded.
prune_model(model)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model)))
# w/ rounding
model = ResNet18().eval()
# Same pruning, but keep remaining channels a multiple of 16 — typically
# faster on hardware that prefers aligned channel widths.
prune_model(model, round_to=16)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model)))
accuracy test
from cifar_resnet import ResNet18
import cifar_resnet as resnet
def get_dataloader():
    """Build CIFAR-10 train/test DataLoaders (downloads the data if absent).

    The training set uses random-crop and horizontal-flip augmentation;
    the test set only converts images to tensors. Both loaders use
    batch_size=256 and 2 worker processes.

    Returns:
        (train_loader, test_loader) tuple.
    """
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=True, transform=train_tf, download=True),
        batch_size=256, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=False, transform=test_tf, download=True),
        batch_size=256, num_workers=2)
    return train_loader, test_loader
def eval(model, test_loader):
    """Compute top-1 classification accuracy of `model` on `test_loader`.

    Moves the model to GPU when available, switches it to eval mode, and
    runs with gradients disabled.

    Args:
        model: an nn.Module producing (batch, num_classes) logits.
        test_loader: iterable of (images, targets) batches.

    Returns:
        Fraction of correctly classified samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for img, target in test_loader:
            logits = model(img.to(device))
            pred = logits.argmax(dim=1).cpu()
            num_correct += int((pred == target.cpu()).sum())
            num_seen += len(target)
    return num_correct / num_seen
# Accuracy test: evaluate three checkpoints on the CIFAR-10 test set —
# the original model, the pruned model without channel rounding, and the
# pruned model with channels rounded to multiples of 16.
_, test_loader = get_dataloader()
# original
previous_ckpt = 'resnet18-round0.pth'
# NOTE(review): torch.load unpickles a full model object — the defining
# classes must be importable, and only trusted checkpoints should be loaded.
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("before pruning: Acc=%.4f"%(acc))
# w/o rounding
previous_ckpt = 'resnet18-pruning-noround.pth'
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("w/o rounding: Acc=%.4f"%(acc))
# w/ rounding
previous_ckpt = 'resnet18-pruning-round_to16.pth'
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("w/ rounding: Acc=%.4f"%(acc))
使用torch pruning工具进行结构化剪枝的更多相关文章
- 软工+C(2017第5期) 工具和结构化
// 上一篇:Alpha/Beta换人 // 下一篇:最近发展区/脚手架 工具/轮子 软件工程/计算机相关专业的一个特点是会使用到众多的工具,工具的使用是从程序猿进化到程序员的一个关键要素.软件工程师 ...
- 软工+C(5): 工具和结构化(重构中, part 1...)
// 上一篇:Alpha/Beta换人 // 下一篇:最近发展区/脚手架 目录: ** 0x01 讨论:工具/轮子 ** 0x02 讨论:结构/演进 ** 0x03 讨论:行为/活动 ** 0x04 ...
- Apache Sqoop 结构化、非结构化数据转换工具
简介: Apache Sqoop 是一种用于 Apache Hadoop 与关系型数据库之间结构化.非结构化数据转换的工具. 一.安装 MySQL.导入测试数据 1.文档链接:http://www.c ...
- Spark如何与深度学习框架协作,处理非结构化数据
随着大数据和AI业务的不断融合,大数据分析和处理过程中,通过深度学习技术对非结构化数据(如图片.音频.文本)进行大数据处理的业务场景越来越多.本文会介绍Spark如何与深度学习框架进行协同工作,在大数 ...
- cuSPARSELt开发NVIDIA Ampere结构化稀疏性
cuSPARSELt开发NVIDIA Ampere结构化稀疏性 深度神经网络在各种领域(例如计算机视觉,语音识别和自然语言处理)中均具有出色的性能.处理这些神经网络所需的计算能力正在迅速提高,因此有效 ...
- WordPress插件--WP BaiDu Submit结构化数据插件又快又全的向百度提交网页
一.WP BaiDu Submit 简介 WP BaiDu Submit帮助具有百度站长平台链接提交权限的用户自动提交最新文章,以保证新链接可以及时被百度收录. 安装WP BaiDu Submit后, ...
- 利用Gson和SharePreference存储结构化数据
问题的导入 Android互联网产品通常会有很多的结构化数据需要保存,比如对于登录这个流程,通常会保存诸如username.profile_pic.access_token等等之类的数据,这些数据可以 ...
- 最近打算体验一下discuz,有不错的结构化数据插件
提交sitemap是每位站长必做的事情,但是提交到哪里,能不能提交又是另外一回事.国内的话百度是大伙都会盯的蛋糕,BD站长工具也会去注册的,可有些账号sitemap模块一直不能用,或许是等级不够,就像 ...
- seo之google rich-snippets丰富网页摘要结构化数据(微数据)实例代码
seo之google rich-snippets丰富网页摘要结构化数据(微数据)实例代码 网页摘要是搜索引擎搜索结果下的几行字,用户能通过网页摘要迅速了解到网页的大概内容,传统的摘要是纯文字摘要,而结 ...
- 【阿里云产品公测】结构化数据服务OTS之JavaSDK初体验
[阿里云产品公测]结构化数据服务OTS之JavaSDK初体验 作者:阿里云用户蓝色之鹰 一.OTS简单介绍 OTS 是构建在阿里云飞天分布式系统之上的NoSQL数据库服务,提供海量结构化数据的存储和实 ...
随机推荐
- 饥荒联机版 mod
Extra Equip Slots Extra Equip Slots 额外装备槽:护甲.背包.护符各用各的插槽 Wormhole Marks [DST] Wormhole Marks [DST] 相通 ...
- TV盒子常用的影视APP和直播软件分享合集
最近自己也在倒腾机顶盒,少不了直播.影视APP,当然只会收集无广告和无会员的版本,文章介绍部分APP,链接里面我会放目前收集的合集,一直会更新. 本文资源下载: 2025.2.18号更新: 包含直播. ...
- Kubernetes - [04] 常用命令
kubectl 语法 kubectl [command] [TYPE] [NAME] [flags] command:指定在一个或多个资源商要执行的操作.例如:create.get.describe. ...
- Elasticsearch搜索引擎学习笔记(三)
索引的一些操作 集群健康 GET /_cluster/health 创建索引 PUT /index_test { "settings": { "index": ...
- LaTeX使用记录
安装与使用 曾在Windows10下装过MikTeX,并配合vscode插件LaTeX Workshop使用过一段时间:这次转到wsl2中,并使用texlive,所以插件的配置json需要小修改 参考 ...
- Error: Address already in use
端口被某个进程占用 使用命令 lsof -i:端口号 然后看到进程号,直接杀掉进程就好 kill -9 进程号
- jquery submit 解决多次提交
jquery submit 解决多次提交 web应用中常见的问题就是多次提交,由于表单提交的延迟,有时几秒或者更长,让用户有机会多次点击提交按钮,从而导致服务器端代码的种种麻烦. 为了解决这个问题,我 ...
- 神经网络与模式识别课程报告-卷积神经网络(CNN)算法的应用
======================================================================================= 完整的神经网络与模式识别 ...
- Windows Server评估版/正式版/数据中心版的来源及转换
评估版: 从微软评估中心下载的版本,相当于微软提供的试用版,可免费使用一段时间.但该版本无法使用 KMS授权或 MAS 永久授权进行激活. 正式版/数据中心版: 从微软许可证中心下载的版本已标识了GL ...
- ANSYS 导出节点的位移数据
1. 数据保存 确定待提取的节点编号: 获取节点位移变量: 将节点位移变量存储到数组中,用于数据传递: ! 输出对应节点的位移到csv文件 ! 注意同时导入.db和.rst,并切换到/post26模块 ...