网络结构定义

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.datasets import CIFAR10
from torchvision import transforms
import numpy as np
import time class BasicBlock(nn.Module):
expansion = 1 def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
) def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out class Bottleneck(nn.Module):
expansion = 4 def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes) self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
) def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512*block.expansion, num_classes) for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers) def forward(self, x, out_feature=False):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
feature = out.view(out.size(0), -1)
out = self.linear(feature)
if out_feature == False:
return out
else:
return out,feature def ResNet18(num_classes=10):
return ResNet(BasicBlock, [2,2,2,2], num_classes)
def ResNet50(num_classes=10):
return ResNet(Bottleneck, [3,4,6,3], num_classes)

speed test

原始模型 ResNet18

剪枝策略: L1Strategy 各Block裁剪比率 [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]

比较原始网络,通道不取整,通道按照16倍数取整的推理速度


def measure_inference_time(net, input, repeat=100):
# torch.cuda.synchronize() # if use cuda uncomment it
start = time.perf_counter()
for _ in range(repeat):
model(input)
#torch.cuda.synchronize() # if use cuda uncomment it
end = time.perf_counter()
return (end-start) / repeat def prune_model(model, round_to=1):
model.cpu()
DG = tp.DependencyGraph().build_dependency( model, torch.randn(1, 3, 32, 32) )
def prune_conv(conv, amount=0.2, round_to=1):
#weight = conv.weight.detach().cpu().numpy()
#out_channels = weight.shape[0]
#L1_norm = np.sum( np.abs(weight), axis=(1,2,3))
#num_pruned = int(out_channels * pruned_prob)
#pruning_index = np.argsort(L1_norm)[:num_pruned].tolist() # remove filters with small L1-Norm
strategy = tp.strategy.L1Strategy()
pruning_index = strategy(conv.weight, amount=amount, round_to=round_to)
plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
plan.exec() block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
blk_id = 0
for m in model.modules():
if isinstance( m, BasicBlock ):
prune_conv( m.conv1, block_prune_probs[blk_id], round_to )
prune_conv( m.conv2, block_prune_probs[blk_id], round_to )
blk_id+=1
return model device = torch.device('cpu') #torch.device('cuda') # or torch.device('cpu')
repeat = 100 # before pruning
model = ResNet18().eval()
fake_input = torch.randn(16,3,32,32)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model))) # w/o rounding
model = ResNet18().eval()
prune_model(model)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model))) # w/ rounding
model = ResNet18().eval()
prune_model(model, round_to=16)
print(model)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model)))

accuracy test

from cifar_resnet import ResNet18
import cifar_resnet as resnet def get_dataloader():
train_loader = torch.utils.data.DataLoader(
CIFAR10('./chapter3_data', train=True, transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]), download=True),batch_size=256, num_workers=2)
test_loader = torch.utils.data.DataLoader(
CIFAR10('./chapter3_data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
]),download=True),batch_size=256, num_workers=2)
return train_loader, test_loader def eval(model, test_loader):
correct = 0
total = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
with torch.no_grad():
for i, (img, target) in enumerate(test_loader):
img = img.to(device)
out = model(img)
pred = out.max(1)[1].detach().cpu().numpy()
target = target.cpu().numpy()
correct += (pred==target).sum()
total += len(target)
return correct / total _, test_loader = get_dataloader() # original
previous_ckpt = 'resnet18-round0.pth'
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("before pruning: Acc=%.4f"%(acc)) # w/o rounding
previous_ckpt = 'resnet18-pruning-noround.pth'
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("w/o rounding: Acc=%.4f"%(acc)) # w/ rounding
previous_ckpt = 'resnet18-pruning-round_to16.pth'
model = torch.load( previous_ckpt )
acc = eval(model, test_loader)
print("w/ rounding: Acc=%.4f"%(acc))

使用torch pruning工具进行结构化剪枝的更多相关文章

  1. 软工+C(2017第5期) 工具和结构化

    // 上一篇:Alpha/Beta换人 // 下一篇:最近发展区/脚手架 工具/轮子 软件工程/计算机相关专业的一个特点是会使用到众多的工具,工具的使用是从程序猿进化到程序员的一个关键要素.软件工程师 ...

  2. 软工+C(5): 工具和结构化(重构中, part 1...)

    // 上一篇:Alpha/Beta换人 // 下一篇:最近发展区/脚手架 目录: ** 0x01 讨论:工具/轮子 ** 0x02 讨论:结构/演进 ** 0x03 讨论:行为/活动 ** 0x04 ...

  3. Apache Sqoop 结构化、非结构化数据转换工具

    简介: Apache Sqoop 是一种用于 Apache Hadoop 与关系型数据库之间结构化.非结构化数据转换的工具. 一.安装 MySQL.导入测试数据 1.文档链接:http://www.c ...

  4. Spark如何与深度学习框架协作,处理非结构化数据

    随着大数据和AI业务的不断融合,大数据分析和处理过程中,通过深度学习技术对非结构化数据(如图片.音频.文本)进行大数据处理的业务场景越来越多.本文会介绍Spark如何与深度学习框架进行协同工作,在大数 ...

  5. cuSPARSELt开发NVIDIA Ampere结构化稀疏性

    cuSPARSELt开发NVIDIA Ampere结构化稀疏性 深度神经网络在各种领域(例如计算机视觉,语音识别和自然语言处理)中均具有出色的性能.处理这些神经网络所需的计算能力正在迅速提高,因此有效 ...

  6. WordPress插件--WP BaiDu Submit结构化数据插件又快又全的向百度提交网页

    一.WP BaiDu Submit 简介 WP BaiDu Submit帮助具有百度站长平台链接提交权限的用户自动提交最新文章,以保证新链接可以及时被百度收录. 安装WP BaiDu Submit后, ...

  7. 利用Gson和SharePreference存储结构化数据

    问题的导入 Android互联网产品通常会有很多的结构化数据需要保存,比如对于登录这个流程,通常会保存诸如username.profile_pic.access_token等等之类的数据,这些数据可以 ...

  8. 最近打算体验一下discuz,有不错的结构化数据插件

    提交sitemap是每位站长必做的事情,但是提交到哪里,能不能提交又是另外一回事.国内的话百度是大伙都会盯的蛋糕,BD站长工具也会去注册的,可有些账号sitemap模块一直不能用,或许是等级不够,就像 ...

  9. seo之google rich-snippets丰富网页摘要结构化数据(微数据)实例代码

    seo之google rich-snippets丰富网页摘要结构化数据(微数据)实例代码 网页摘要是搜索引擎搜索结果下的几行字,用户能通过网页摘要迅速了解到网页的大概内容,传统的摘要是纯文字摘要,而结 ...

  10. 【阿里云产品公测】结构化数据服务OTS之JavaSDK初体验

    [阿里云产品公测]结构化数据服务OTS之JavaSDK初体验 作者:阿里云用户蓝色之鹰 一.OTS简单介绍 OTS 是构建在阿里云飞天分布式系统之上的NoSQL数据库服务,提供海量结构化数据的存储和实 ...

随机推荐

  1. 饥荒联机版 mod

    Extra Equip Slots Extra Equip Slots 额外装备槽:护甲.背包.护符各用各的插槽 Wormhole Marks [DST] Wormhole Marks DST] 相通 ...

  2. TV盒子常用的影视APP和直播软件分享合集

    最近自己也在倒腾机顶盒,少不了直播.影视APP,当然只会收集无广告和无会员的版本,文章介绍部分APP,链接里面我会放目前收集的合集,一直会更新. 本文资源下载: 2025.2.18号更新: 包含直播. ...

  3. Kubernetes - [04] 常用命令

    kubectl 语法 kubectl [command] [TYPE] [NAME] [flags] command:指定在一个或多个资源商要执行的操作.例如:create.get.describe. ...

  4. Elasticsearch搜索引擎学习笔记(三)

    索引的一些操作 集群健康 GET /_cluster/health 创建索引 PUT /index_test { "settings": { "index": ...

  5. LaTeX使用记录

    安装与使用 曾在Windows10下装过MikTeX,并配合vscode插件LaTeX Workshop使用过一段时间:这次转到wsl2中,并使用texlive,所以插件的配置json需要小修改 参考 ...

  6. Error: Address already in use

    端口被某个进程占用 使用命令 lsof -i:端口号 然后看到进程号,直接杀掉进程就好 kill -9 进程号

  7. jquery submit 解决多次提交

    jquery submit 解决多次提交 web应用中常见的问题就是多次提交,由于表单提交的延迟,有时几秒或者更长,让用户有机会多次点击提交按钮,从而导致服务器端代码的种种麻烦. 为了解决这个问题,我 ...

  8. 神经网络与模式识别课程报告-卷积神经网络(CNN)算法的应用

    ======================================================================================= 完整的神经网络与模式识别 ...

  9. Windows Server评估版/正式版/数据中心版的来源及转换

    评估版: 从微软评估中心下载的版本,相当于微软提供的试用版,可免费使用一段时间.但该版本无法使用 KMS授权或 MAS 永久授权进行激活. 正式版/数据中心版: 从微软许可证中心下载的版本已标识了GL ...

  10. ANSYS 导出节点的位移数据

    1. 数据保存 确定待提取的节点编号: 获取节点位移变量: 将节点位移变量存储到数组中,用于数据传递: ! 输出对应节点的位移到csv文件 ! 注意同时导入.db和.rst,并切换到/post26模块 ...