Learning Efficient Convolutional Networks through Network Slimming

Introduction

This is the first model-compression paper I have read, and probably one of the better-known ones. I have been interested in model compression for a while, so I took some time to read it and implemented the code myself; it turned out to be fairly easy. For the model-compression problem, the paper proposes a pruning method that targets BN layers: the BN layer's scale weights are used to score the input channels, a threshold filters out the channels with low scores, the neurons of those low-scoring channels are simply excluded from the connections, and pruning layer by layer achieves the compression.

Personally, I think the attention mechanisms in common use today could also be used to compute channel scores; there is probably something worth exploring there, although it would surely be task-specific. Later I plan to run my own experiments using an attention mechanism for model pruning.
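As a rough illustration of that idea, an SE-style gate could produce per-channel attention weights that double as importance scores for pruning. This is purely a hypothetical sketch; the `ChannelGate` module and its `reduction` parameter are my own invention, not anything from the paper:

```python
import torch
import torch.nn as nn

class ChannelGate(nn.Module):
    """SE-style channel attention; the learned gate values could serve as channel scores."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # global average pool -> per-channel gate in (0, 1)
        s = x.mean(dim=(2, 3))              # (N, C)
        g = self.fc(s)                      # (N, C)
        return x * g.unsqueeze(-1).unsqueeze(-1), g

gate = ChannelGate(16)
x = torch.randn(2, 16, 8, 8)
y, scores = gate(x)
# averaging the gate over a batch gives one score per channel;
# low-scoring channels would be the pruning candidates
channel_scores = scores.mean(dim=0)
```

Unlike the BN-gamma approach, these scores depend on the input batch, so they would have to be averaged over a representative dataset before pruning.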

Method

The method, as illustrated in the paper's figure, works as follows:

  1. Given the fraction of channels to keep, collect the scale weights of all BN layers and determine the threshold corresponding to that fraction.
  2. Prune the BN layers first, i.e. discard the parameters whose scale weights fall below that threshold.
  3. Prune the convolutional layers (a convolution is usually followed by BN, so the pruned conv weight size is determined by the BN layers before and after it): for each conv weight, drop the entries along the channel dimensions that correspond to channels pruned in the preceding and following BN layers.
  4. Prune the FC layer.

This is hard to explain in words, but it all becomes clear once you look at the code.
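As a warm-up before the full implementation, steps 1 and 2 can be sketched in a few lines. This is a minimal sketch: the toy model, the 50% prune rate, and the randomized gammas (BN scales are all initialized to 1, so I randomize them to mimic a trained network) are made up for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy model standing in for a real trained network
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
)
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.uniform_(0, 1)  # mimic trained, spread-out scale factors

percent = 0.5  # fraction of channels to prune

# step 1: gather all BN scale factors |gamma| and pick a single global threshold
gammas = torch.cat([m.weight.data.abs().flatten()
                    for m in model.modules()
                    if isinstance(m, nn.BatchNorm2d)])
threshold = gammas.sort().values[int(gammas.numel() * percent)]

# step 2: per-layer keep masks; channels below the threshold get pruned
masks = [m.weight.data.abs() > threshold
         for m in model.modules()
         if isinstance(m, nn.BatchNorm2d)]
kept = sum(int(m.sum()) for m in masks)
print(f"kept {kept}/{gammas.numel()} channels")
```

The key point is that the threshold is global across all BN layers, so layers whose channels are uniformly unimportant get pruned more aggressively than others.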

Code

Here is my own implementation.

import torch
import torch.nn as nn
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # four unpadded 3x3 convs shrink a 224x224 input to 216x216,
        # so this max pool acts as global pooling
        self.maxpool = nn.MaxPool2d(216)
        self.fc = nn.Linear(128, 3)

    def forward(self, x):
        x = self.convnet(x)
        x = self.maxpool(x)
        x = x.view(-1, x.size(1))
        return self.fc(x)

if __name__ == "__main__":
    net = Net()
    net_new = Net()
    idxs = []
    idxs.append(list(range(3)))  # all 3 input channels are kept
    for module in net.modules():
        if type(module) is nn.BatchNorm2d:
            weight = module.weight.data.abs()
            n = weight.size(0)
            # rank channels by |gamma| and keep the top 80%
            y, idx = torch.sort(weight, descending=True)
            n = int(0.8 * n)
            idxs.append(idx[:n])
    i = 1
    for module in net_new.modules():
        if type(module) is nn.Conv2d:
            weight = module.weight.data.clone()
            weight = weight[idxs[i], :, :, :]      # prune output channels
            weight = weight[:, idxs[i - 1], :, :]  # prune input channels
            module.bias.data = module.bias.data[idxs[i]]
            module.weight.data = weight
        elif type(module) is nn.BatchNorm2d:
            # keep only the surviving channels' scale, shift and running stats
            module.weight.data = module.weight.data[idxs[i]].clone()
            module.bias.data = module.bias.data[idxs[i]].clone()
            module.running_mean.data = module.running_mean.data[idxs[i]].clone()
            module.running_var.data = module.running_var.data[idxs[i]].clone()
            i += 1
        elif type(module) is nn.Linear:
            module.weight.data = module.weight.data[:, idxs[-1]]
    summary(net_new, (3, 224, 224), device="cpu")
'''
This is the pruning example for VGG; the paper also describes slimming
examples for other network architectures.
'''
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from models import *  # models.py comes with the repo and provides vgg

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=19,
                    help='depth of the vgg')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = vgg(dataset=args.dataset, depth=args.depth)
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

print(model)
total = 0
for m in model.modules():  # iterate over every module of the VGG
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]  # total number of channels across all BN layers

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()  # clone every BN layer's |gamma|
        index += size

y, i = torch.sort(bn)                   # sort all scale factors ascending
thre_index = int(total * args.percent)  # number of channels to prune
thre = y[thre_index]                    # the global pruning threshold

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().cuda()  # 1 where |gamma| > threshold, else 0
        pruned = pruned + mask.shape[0] - torch.sum(mask)  # running count of pruned channels
        m.weight.data.mul_(mask)  # zero out the pruned scales
        m.bias.data.mul_(mask)    # and the corresponding shifts
        cfg.append(int(torch.sum(mask)))  # how many channels this BN layer keeps
        cfg_mask.append(mask.clone())     # this BN layer's keep mask
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned / total  # fraction of channels pruned

print('Pre-processing Successful!')

# simple test of the model after pre-processing prune (BN scales simply set to zero)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

# Make real prune
print(cfg)
newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()

# Tensor.nelement() counts a tensor's elements, e.g. a tensor of shape
# (20, 3, 3, 3) has 20 * 27 = 540 elements -- handy for counting parameters
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

layer_id_in_cfg = 0          # which BN layer we are on
start_mask = torch.ones(3)   # all 3 input channels are kept
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        # np.argwhere returns the indices of all nonzero entries; squeeze flattens them to a vector
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:  # a single index must still be a 1-element array
            idx1 = np.resize(idx1, (1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()  # copy only the kept channels
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1           # move on to the next layer
        start_mask = end_mask.clone()  # this layer's output mask is the next layer's input mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):  # prune the conv layer
        # a convolution is followed by a BN layer
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()  # keep surviving input channels
        w1 = w1[idx1.tolist(), :, :, :].clone()              # keep surviving output channels
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},
           os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)
