Structured pruning with the torch_pruning toolkit
Network definition
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.datasets import CIFAR10
from torchvision import transforms
import numpy as np
import time
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if not out_feature:
            return out
        else:
            return out, feature

def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)

def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)
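A quick shape check (a minimal sketch, assuming the definitions above are in scope) confirms the forward pass: conv1's stride-2 plus the stride-2 stages in layer3 and layer4 bring a 32x32 input down to 4x4 before the 4x4 average pool.

net = ResNet18()
out = net(torch.randn(2, 3, 32, 32))   # a batch of two CIFAR-sized RGB images
print(out.shape)                       # torch.Size([2, 10]), one logit per class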
Speed test
Baseline model: ResNet18.
Pruning strategy: L1Strategy, with per-block pruning ratios [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3].
We compare the inference speed of three variants: the original network, the network pruned without channel rounding, and the network pruned with channel counts rounded to a multiple of 16.
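The selection rule behind L1Strategy can be sketched by hand. The helper below is an illustration only, not torch_pruning's internal implementation, and the exact rounding behavior of round_to is an assumption: it scores each output filter by the L1 norm of its weights and prunes the lowest-scoring ones, optionally keeping a channel count that is a multiple of round_to.

def l1_prune_indices(weight, amount=0.2, round_to=1):
    # weight: conv weight of shape (out_channels, in_channels, kH, kW)
    out_channels = weight.shape[0]
    l1_norm = weight.abs().sum(dim=(1, 2, 3))   # one score per output filter
    n_pruned = int(out_channels * amount)
    if round_to > 1:
        # assumed semantics: keep the largest multiple of round_to
        # that does not exceed the nominal number of kept channels
        n_keep = max(round_to, (out_channels - n_pruned) // round_to * round_to)
        n_pruned = out_channels - n_keep
    return torch.argsort(l1_norm)[:n_pruned].tolist()  # smallest-L1 filters first

For example, with 64 channels and amount=0.2 the nominal plan prunes 12 filters (keeping 52); with round_to=16 the sketch prunes 16 instead, keeping 48 = 3 x 16 channels.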
def measure_inference_time(net, x, repeat=100):
    # torch.cuda.synchronize()  # uncomment when timing on CUDA
    start = time.perf_counter()
    for _ in range(repeat):
        net(x)
    # torch.cuda.synchronize()  # uncomment when timing on CUDA
    end = time.perf_counter()
    return (end - start) / repeat
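As written this timer is only reliable on CPU: CUDA kernels launch asynchronously, so on GPU the two torch.cuda.synchronize() calls must be uncommented, and a few warm-up iterations help exclude one-time initialization costs. A GPU-aware variant might look like the following sketch:

@torch.no_grad()
def measure_inference_time_cuda(net, x, repeat=100, warmup=10):
    for _ in range(warmup):       # warm-up runs to trigger CUDA init and caching
        net(x)
    torch.cuda.synchronize()      # wait for all queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(repeat):
        net(x)
    torch.cuda.synchronize()      # make sure all timed work has actually finished
    end = time.perf_counter()
    return (end - start) / repeat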
def prune_model(model, round_to=1):
    model.cpu()
    # Build the dependency graph from a dummy forward pass; it records which
    # layers must be pruned together (conv -> bn -> downstream conv, shortcuts).
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2, round_to=1):
        # Equivalent manual selection, kept for reference:
        # weight = conv.weight.detach().cpu().numpy()
        # out_channels = weight.shape[0]
        # L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        # num_pruned = int(out_channels * amount)
        # pruning_index = np.argsort(L1_norm)[:num_pruned].tolist()  # remove filters with small L1 norm
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount, round_to=round_to)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    # Per-block pruning ratios: shallow blocks are pruned less aggressively
    # than deep ones (ResNet18 has 8 BasicBlocks in total).
    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blk_id = 0
    for m in model.modules():
        if isinstance(m, BasicBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id], round_to)
            prune_conv(m.conv2, block_prune_probs[blk_id], round_to)
            blk_id += 1
    return model
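A pruning plan can also be inspected before execution: printing it lists every coupled operation the dependency graph pulled in (the paired BatchNorm, the next conv's input channels, any shortcut layers). A minimal sketch, assuming the same v0.x torch_pruning API used above:

model = ResNet18()
DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))
# ask to prune two filters of the first conv in the first BasicBlock
plan = DG.get_pruning_plan(model.layer1[0].conv1, tp.prune_conv, idxs=[0, 1])
print(plan)    # shows every dependent pruning (bn1, conv2's input channels, ...)
# plan.exec()  # nothing is modified until the plan is executed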
device = torch.device('cpu')  # or torch.device('cuda')
repeat = 100
# before pruning
model = ResNet18().eval()
fake_input = torch.randn(16,3,32,32)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model)))
# w/o rounding
model = ResNet18().eval()
prune_model(model)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model)))
# w/ rounding
model = ResNet18().eval()
prune_model(model, round_to=16)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model)))
Accuracy test
# cifar_resnet.py contains the ResNet18 definition shown above
from cifar_resnet import ResNet18
def get_dataloader():
    train_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=True, transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]), download=True), batch_size=256, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        CIFAR10('./chapter3_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=256, num_workers=2)
    return train_loader, test_loader
def eval(model, test_loader):
    correct = 0
    total = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        for i, (img, target) in enumerate(test_loader):
            img = img.to(device)
            out = model(img)
            pred = out.max(1)[1].detach().cpu().numpy()
            target = target.cpu().numpy()
            correct += (pred == target).sum()
            total += len(target)
    return correct / total
_, test_loader = get_dataloader()
# original
previous_ckpt = 'resnet18-round0.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("before pruning: Acc=%.4f"%(acc))
# w/o rounding
previous_ckpt = 'resnet18-pruning-noround.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("w/o rounding: Acc=%.4f"%(acc))
# w/ rounding
previous_ckpt = 'resnet18-pruning-round_to16.pth'
model = torch.load(previous_ckpt)
acc = eval(model, test_loader)
print("w/ rounding: Acc=%.4f"%(acc))