pytorch实现squeezenet
squeezenet是16年发布的一款轻量级网络模型,模型很小,只有4.8M,可用于移动设备,嵌入式设备。
关于squeezenet的原理可自行阅读论文或查找博客,这里主要解读下pytorch对squeezenet的官方实现。
地址:https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py
首先定义fire模块,这是squeezenet的核心所在,降低3X3卷积的数量。
class Fire(nn.Module):
def __init__(self, inplanes, squeeze_planes,
expand1x1_planes, expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)#定义压缩层,1X1卷积
self.squeeze_activation = nn.ReLU(inplace=True)
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,#定义扩展层,1X1卷积
kernel_size=1)
self.expand1x1_activation = nn.ReLU(inplace=True)
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,#定义扩展层,3X3卷积
kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.squeeze_activation(self.squeeze(x))
return torch.cat([
self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x))
], 1)
可以看到首先定义压缩层与两个扩展层,压缩层用的是1X1卷积,扩展层是1X1卷积和3X3卷积的混合使用,网络inference的脉络是先经过压缩层,然后并行经过两个扩展层,最后将扩展层串联。
定义完核心模块,来看网络整体。
class SqueezeNet(nn.Module):
def __init__(self, version=1.0, num_classes=1000):
super(SqueezeNet, self).__init__()
if version not in [1.0, 1.1]:
raise ValueError("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
self.num_classes = num_classes
if version == 1.0:
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(96, 16, 64, 64),
Fire(128, 16, 64, 64),
Fire(128, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 32, 128, 128),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(512, 64, 256, 256),
)
else:
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(64, 16, 64, 64),
Fire(128, 16, 64, 64),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(128, 32, 128, 128),
Fire(256, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256),
)
# Final convolution is initialized differently form the rest
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
final_conv,
nn.ReLU(inplace=True),
nn.AvgPool2d(13, stride=1)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m is final_conv:
init.normal_(m.weight, mean=0.0, std=0.01)
else:
init.kaiming_uniform_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x.view(x.size(0), self.num_classes)
首先依然是定义网络层,在这里有两个版本,差别不大,都是fire模块的堆积,最后经过全局平均池化输出1000类。这里对卷积层采用了不同的初始化策略,我还没仔细研究过,就不说了。
pytorch实现squeezenet的更多相关文章
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- 生产与学术之Pytorch模型导出为安卓Apk尝试记录
生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...
- 深度学习框架PyTorch一书的学习-第六章-实战指南
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...
- 深度学习框架PyTorch一书的学习-第五章-常用工具模块
https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...
- (转)Awesome PyTorch List
Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...
- (转) The Incredible PyTorch
转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...
- PyTorch源码解读之torchvision.models(转)
原文地址:https://blog.csdn.net/u014380165/article/details/79119664 PyTorch框架中有一个非常重要且好用的包:torchvision,该包 ...
- PyTorch深度学习计算机视觉框架
Taylor Guo @ Shanghai - 2018.10.22 - 星期一 PyTorch 资源链接 图像分类 VGG ResNet DenseNet MobileNetV2 ResNeXt S ...
随机推荐
- x264中重要结构体参数解释,参数设置,函数说明 <转>
x264中重要结构体参数解释http://www.usr.cc/thread-51995-1-3.htmlx264参数设置http://www.usr.cc/thread-51996-1-3.html ...
- X—shell的安装以及与Linux的链接(http://www.cnblogs.com/v-weiwang/p/5029559.html)
X—shell作为一种强大的远程操作工具,使我们能够简单的去操作虚拟机,因此呢我们最好是能够在我们的电脑上进行安装. X—shell作为一个工具我们无论什么版本的都可以,在安装的时候呢也特别的简单,但 ...
- [CSS Hack]解決IE6、IE7、IE8、Firefox的瀏覽器相容性問題!
每次調CSS最令人頭痛的就是瀏覽器校正問題,因為每個瀏覽器對CSS的解釋都不太一樣,Firefox本身算是比較照規矩來,處理上比較簡單,但是遇到微軟的IE系列頭就大了,雖然都是IE,但是IE6.IE7 ...
- sqlplus--sqlldr命令参数详解
sqlplus--sqlldr参数详解 sqlldr,Oracle快速导入数据的工具,是sqlplus的指令,不是sql语法里的东西. 一.下面是SQL*LOADER的基本特点:1)能装入不同数据类型 ...
- onRetainNonConfigurationInstance方法状态保存
onRetainNonConfigurationInstance方法作用于ONSAVEINSTANCE类似,但是能保存更多的信息,可以使用getLastNonConfigurationInstance ...
- android tween动画和Frame动画总结
tween 动画有四种 //透明度动画 AlphaAnimation aa = (AlphaAnimation) AnimationUtils.loadAnimation(MainActivity. ...
- 在异步回调中调用MessageBox.Show
public static void Test() { ThreadStart aThreadStart = delegate() { ); MessageBox.Show("Good!&q ...
- [poj1703]Find them, Catch them(种类并查集)
题意:食物链的弱化版本 解题关键:种类并查集,注意向量的合成. $rank$为1代表与父亲对立,$rank$为0代表与父亲同类. #include<iostream> #include&l ...
- C++——static
1.第一条也是最重要的一条:隐藏.(static函数,static变量均可) 所有未加static前缀的全局变量和函数都具有全局可见性:加static前缀的全局变量和函数只有有局部可见性: //a.c ...
- 框架和事务 非常 有用 hibernate和mybatis区别
1****第一章 Hibernate与MyBatis 章 开发对比 开发学习 Hibernate的真正掌握要比Mybatis来得难些.Mybatis框架相对简单很容易上手,但也相对简陋些.个人觉得要用 ...