Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models
Tutorial on GoogleNet based image classification
2018-06-26 15:50:29
本文旨在通过案例来学习 GoogleNet 及其 Inception 结构的定义。针对这种复杂模型的保存以及读取。
1. GoogleNet 的结构:
class Inception(nn.Module):
def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
super(Inception, self).__init__()
# 1x1 conv branch
self.b1 = nn.Sequential(
nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
nn.BatchNorm2d(kernel_1_x),
nn.ReLU(True),
) # 1x1 conv -> 3x3 conv branch
self.b2 = nn.Sequential(
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
nn.BatchNorm2d(kernel_3_in),
nn.ReLU(True),
nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_3_x),
nn.ReLU(True),
) # 1x1 conv -> 5x5 conv branch
self.b3 = nn.Sequential(
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
nn.BatchNorm2d(kernel_5_in),
nn.ReLU(True),
nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
) # 3x3 pool -> 1x1 conv branch
self.b4 = nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
nn.BatchNorm2d(pool_planes),
nn.ReLU(True),
) def forward(self, x):
y1 = self.b1(x)
y2 = self.b2(x)
y3 = self.b3(x)
y4 = self.b4(x)
return torch.cat([y1,y2,y3,y4], 1)
class GoogLeNet(nn.Module):
def __init__(self):
super(GoogLeNet, self).__init__()
self.pre_layers = nn.Sequential(
nn.Conv2d(3, 192, kernel_size=3, padding=1),
nn.BatchNorm2d(192),
nn.ReLU(True),
) self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) self.avgpool = nn.AvgPool2d(8, stride=1)
self.linear = nn.Linear(1024, 10) def forward(self, x):
x = self.pre_layers(x)
x = self.a3(x)
x = self.b3(x)
x = self.max_pool(x)
x = self.a4(x)
x = self.b4(x)
x = self.c4(x)
x = self.d4(x)
x = self.e4(x)
x = self.max_pool(x)
x = self.a5(x)
x = self.b5(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x
2. 保存和加载模型:
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl') # 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models的更多相关文章
- A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)
A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON ...
- 图像分类之特征学习ECCV-2010 Tutorial: Feature Learning for Image Classification
ECCV-2010 Tutorial: Feature Learning for Image Classification Organizers Kai Yu (NEC Laboratories Am ...
- Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature【枚举二分答案】
https://codeforces.com/contest/1241/problem/C You are an environmental activist at heart but the rea ...
- Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature
链接: https://codeforces.com/contest/1241/problem/C 题意: You are an environmental activist at heart but ...
- How to Build Android Applications Based on FFmpeg by An Example
This is a follow up post of the previous blog How to Build FFmpeg for Android. You can read the pre ...
- 解读(GoogLeNet)Going deeper with convolutions
(GoogLeNet)Going deeper with convolutions Inception结构 目前最直接提升DNN效果的方法是increasing their size,这里的size包 ...
- [论文阅读]Going deeper with convolutions(GoogLeNet)
本文采用的GoogLenet网络(代号Inception)在2014年ImageNet大规模视觉识别挑战赛取得了最好的结果,该网络总共22层. Motivation and High Level Co ...
- Node.js NPM Tutorial: Create, Publish, Extend & Manage
A module in Node.js is a logical encapsulation of code in a single unit. It's always a good programm ...
- Plant Leaves Classification植物叶子分类:基于孪生网络的小样本学习方法
目录 Abstract Introduction PROPOSED CNN STRUCTURE INITIAL CNN ANALYSIS EXPERIMENTAL STRUCTURE AND ALGO ...
随机推荐
- c++引用和指针的彻底理解
★ 相同点: 1. 都是地址的概念: 指针指向一块内存,它的内容是所指内存的地址:引用是某块内存的别名. ★ 区别: 1. 指针是一个实体,而引用仅是个别名: 2. 引用使用时无需解引用(*),指 ...
- Sql server 存储过程批量插入若干数据。
测试时,经常需要生成大量数据来测试系统性能,此功能可以用存储过程快速生成. 1. 随机生成日期 DECLARE @Date_start datetime DECLARE @Date_end datet ...
- codeforces 980B Marlin
题意: 有一个城市有4行n列,n是奇数,有一个村庄在(1,1),村民的活动地点是(4,n): 有一个村庄在(4,1),村民的活动地点是(1,n): 现在要修建k个宾馆,不能修建在边界上,问能否给出一种 ...
- NSOperation、NSOperationQueue(II)
NSOperationQueue 控制串行执行.并发执行 NSOperationQueue 创建的自定义队列同时具有串行.并发功能 这里有个关键属性 maxConcurrentOperationCou ...
- 文件格式(图像 IO 14.3)
文件格式 图片加载性能取决于加载大图的时间和解压小图时间的权衡.很多苹果的文档都说PNG是iOS所有图片加载的最好格式.但这是极度误导的过时信息了. PNG图片使用的无损压缩算法可以比使用JPEG的图 ...
- android使用inject需要注意的地方
android使用inject需要注意的地方1.viewmodel里面添加注解@Inject FavoritesDBManager mFavoritesDBManager; 2.Component里面 ...
- excel vba 数据分析
(Visual Basic Application) VBA(Visual Basic for Application)是Microsoft Office系列软件的内置编程语言,其语法结构与Visua ...
- 介绍Python中6个序列的内置类型
1.Python中6个序列的内置类型分别是什么? Python包含6中内建的序列,即列表.元组.字符串.Unicode字符串.buffer对象和 xrange 对象.序列通用的操作包括:索引.长度.组 ...
- git getting started
2019/4/25-- after committing to blessed. modify dependency file to download file so as to get latest ...
- Pytorch的torch.cat实例
import torch 通过 help((torch.cat)) 可以查看 cat 的用法 cat(seq,dim,out=None) 其中 seq表示要连接的两个序列,以元组的形式给出,例如:se ...