Tutorial on GoogleNet based image classification 

2018-06-26 15:50:29

This post walks through an example to learn how GoogLeNet and its Inception module are defined, and how to save and load a complex model like this one.

1. GoogLeNet architecture:

```python
import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
        super(Inception, self).__init__()
        # 1x1 conv branch
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
            nn.BatchNorm2d(kernel_1_x),
            nn.ReLU(True),
        )
        # 1x1 conv -> 3x3 conv branch
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
            nn.BatchNorm2d(kernel_3_in),
            nn.ReLU(True),
            nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_3_x),
            nn.ReLU(True),
        )
        # 1x1 conv -> 5x5 conv branch; the 5x5 conv is implemented as two
        # stacked 3x3 convs, which cover the same 5x5 receptive field
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
            nn.BatchNorm2d(kernel_5_in),
            nn.ReLU(True),
            nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(True),
            nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(True),
        )
        # 3x3 max pool -> 1x1 conv branch
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(True),
        )

    def forward(self, x):
        # Every branch keeps H and W unchanged, so the four outputs can be
        # concatenated along dim 1 (the channel dimension)
        y1 = self.b1(x)
        y2 = self.b2(x)
        y3 = self.b3(x)
        y4 = self.b4(x)
        return torch.cat([y1, y2, y3, y4], 1)
```
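Because `torch.cat(..., 1)` stacks the branch outputs along the channel dimension, the number of channels an Inception block produces is just the sum of the last conv widths of its four branches: `kernel_1_x + kernel_3_x + kernel_5_x + pool_planes`. The reduction widths `kernel_3_in` and `kernel_5_in` only affect the inside of a branch. The helper below (a hypothetical name, not part of the original code) mirrors the constructor's signature so the same arguments can be pasted in:

```python
def inception_out_channels(in_planes, kernel_1_x, kernel_3_in, kernel_3_x,
                           kernel_5_in, kernel_5_x, pool_planes):
    """Channels produced by torch.cat([y1, y2, y3, y4], 1) in Inception.

    in_planes, kernel_3_in and kernel_5_in are accepted only to match the
    constructor's signature; they do not change the output width.
    """
    return kernel_1_x + kernel_3_x + kernel_5_x + pool_planes


# Arguments of the first block used in GoogLeNet: Inception(192, 64, 96, 128, 16, 32, 32)
print(inception_out_channels(192, 64, 96, 128, 16, 32, 32))  # 256
```

That 256 is exactly the `in_planes` passed to the next block, `Inception(256, ...)`.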
```python
class GoogLeNet(nn.Module):
    def __init__(self):
        super(GoogLeNet, self).__init__()
        # Stem: a single 3x3 conv. The sizes below fit a 32x32 RGB input
        # with 10 classes (e.g. CIFAR-10)
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        # Halves H and W; it has no parameters, so one module is reused twice
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        # Reduces the 8x8 feature map to 1x1, leaving 1024 features per image
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.linear = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.pre_layers(x)
        x = self.a3(x)
        x = self.b3(x)
        x = self.max_pool(x)
        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)
        x = self.max_pool(x)
        x = self.a5(x)
        x = self.b5(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1024)
        x = self.linear(x)
        return x
```
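The hard-coded `in_planes` values only work if each block's input equals the previous block's output, and `nn.AvgPool2d(8, stride=1)` only reduces to 1x1 if the feature map is 8x8 at that point. The sketch below checks both with plain arithmetic, assuming a 32x32 input (the post does not name the dataset; the 10 output classes suggest CIFAR-10). `BLOCKS`, `out_size` and `check_googlenet` are hypothetical helper names, and the size formula is the standard floor-mode one for `Conv2d`/`MaxPool2d`/`AvgPool2d` with dilation 1.

```python
# Constructor arguments copied from GoogLeNet.__init__, in forward() order.
BLOCKS = [
    ("a3", (192, 64, 96, 128, 16, 32, 32)),
    ("b3", (256, 128, 128, 192, 32, 96, 64)),
    ("a4", (480, 192, 96, 208, 16, 48, 64)),
    ("b4", (512, 160, 112, 224, 24, 64, 64)),
    ("c4", (512, 128, 128, 256, 24, 64, 64)),
    ("d4", (512, 112, 144, 288, 32, 64, 64)),
    ("e4", (528, 256, 160, 320, 32, 128, 128)),
    ("a5", (832, 256, 160, 320, 32, 128, 128)),
    ("b5", (832, 384, 192, 384, 48, 128, 128)),
]


def out_size(size, kernel, stride, padding):
    # Output height/width of a conv or pooling layer (dilation=1, floor mode).
    return (size + 2 * padding - kernel) // stride + 1


def check_googlenet(input_size=32, stem_channels=192):
    # pre_layers uses a 3x3 conv with padding=1, so it keeps the input size.
    channels, size = stem_channels, input_size
    for name, (in_planes, k1, _k3_in, k3, _k5_in, k5, pool) in BLOCKS:
        assert in_planes == channels, f"{name}: expects {in_planes}, gets {channels}"
        channels = k1 + k3 + k5 + pool  # width of torch.cat along dim 1
        if name in ("b3", "e4"):  # self.max_pool follows b3 and e4 in forward()
            size = out_size(size, kernel=3, stride=2, padding=1)
    size = out_size(size, kernel=8, stride=1, padding=0)  # self.avgpool
    return channels, size


print(check_googlenet())  # (1024, 1)
```

With `input_size=224` the map reaching `self.avgpool` would be 56x56, so the pool leaves 49x49 and the flattened vector no longer has 1024 entries; `nn.Linear(1024, 10)` would then raise a shape error. This version therefore targets small 32x32 images rather than the 224x224 ImageNet inputs used in the GoogLeNet paper.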

2. Saving and loading the model:

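```python
import torch
model_object = GoogLeNet()  # the network defined above
# Save and load the entire model (the class definition must be importable when loading)
torch.save(model_object, 'model.pkl')
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects whole-model files
model = torch.load('model.pkl', weights_only=False)  # only for files you trust
# Save and load only the model parameters (recommended)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
```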
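A minimal round trip of the recommended `state_dict` approach, sketched under two assumptions: a tiny stand-in `nn.Sequential` (hypothetical `make_model`) replaces `GoogLeNet` so it runs quickly, and the file goes to a temporary directory. A `state_dict` holds only tensors, so the architecture must be rebuilt before loading. `map_location="cpu"` lets a checkpoint saved on a GPU load on a CPU-only machine, and `eval()` makes BatchNorm use its stored running statistics for inference.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for GoogLeNet(); any nn.Module behaves the same way.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(True),
    )


def save_and_reload(model, path):
    # Save only the parameters and buffers.
    torch.save(model.state_dict(), path)
    # Rebuild the architecture, then copy the saved tensors into it.
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # BatchNorm now uses running statistics
    return restored


model = make_model()
with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload(model, os.path.join(tmp, "params.pkl"))

x = torch.randn(2, 3, 16, 16)
model.eval()
print(torch.allclose(model(x), restored(x)))  # True
```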
