Tutorial on GoogleNet based image classification 

2018-06-26 15:50:29

本文旨在通过案例来学习 GoogleNet 及其 Inception 结构的定义。针对这种复杂模型的保存以及读取。

1. GoogleNet 的结构:

 class Inception(nn.Module):
def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
super(Inception, self).__init__()
# 1x1 conv branch
self.b1 = nn.Sequential(
nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
nn.BatchNorm2d(kernel_1_x),
nn.ReLU(True),
) # 1x1 conv -> 3x3 conv branch
self.b2 = nn.Sequential(
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
nn.BatchNorm2d(kernel_3_in),
nn.ReLU(True),
nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_3_x),
nn.ReLU(True),
) # 1x1 conv -> 5x5 conv branch
self.b3 = nn.Sequential(
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
nn.BatchNorm2d(kernel_5_in),
nn.ReLU(True),
nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
) # 3x3 pool -> 1x1 conv branch
self.b4 = nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
nn.BatchNorm2d(pool_planes),
nn.ReLU(True),
) def forward(self, x):
y1 = self.b1(x)
y2 = self.b2(x)
y3 = self.b3(x)
y4 = self.b4(x)
return torch.cat([y1,y2,y3,y4], 1)
class GoogLeNet(nn.Module):
def __init__(self):
super(GoogLeNet, self).__init__()
self.pre_layers = nn.Sequential(
nn.Conv2d(3, 192, kernel_size=3, padding=1),
nn.BatchNorm2d(192),
nn.ReLU(True),
) self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) self.avgpool = nn.AvgPool2d(8, stride=1)
self.linear = nn.Linear(1024, 10) def forward(self, x):
x = self.pre_layers(x)
x = self.a3(x)
x = self.b3(x)
x = self.max_pool(x)
x = self.a4(x)
x = self.b4(x)
x = self.c4(x)
x = self.d4(x)
x = self.e4(x)
x = self.max_pool(x)
x = self.a5(x)
x = self.b5(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x

2. 保存和加载模型:

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl') # 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models的更多相关文章

  1. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  2. 图像分类之特征学习ECCV-2010 Tutorial: Feature Learning for Image Classification

    ECCV-2010 Tutorial: Feature Learning for Image Classification Organizers Kai Yu (NEC Laboratories Am ...

  3. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature【枚举二分答案】

    https://codeforces.com/contest/1241/problem/C You are an environmental activist at heart but the rea ...

  4. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature

    链接: https://codeforces.com/contest/1241/problem/C 题意: You are an environmental activist at heart but ...

  5. How to Build Android Applications Based on FFmpeg by An Example

    This is a follow up post of the previous blog How to Build FFmpeg for Android.  You can read the pre ...

  6. 解读(GoogLeNet)Going deeper with convolutions

    (GoogLeNet)Going deeper with convolutions Inception结构 目前最直接提升DNN效果的方法是increasing their size,这里的size包 ...

  7. [论文阅读]Going deeper with convolutions(GoogLeNet)

    本文采用的GoogLenet网络(代号Inception)在2014年ImageNet大规模视觉识别挑战赛取得了最好的结果,该网络总共22层. Motivation and High Level Co ...

  8. Node.js NPM Tutorial: Create, Publish, Extend & Manage

    A module in Node.js is a logical encapsulation of code in a single unit. It's always a good programm ...

  9. Plant Leaves Classification植物叶子分类:基于孪生网络的小样本学习方法

    目录 Abstract Introduction PROPOSED CNN STRUCTURE INITIAL CNN ANALYSIS EXPERIMENTAL STRUCTURE AND ALGO ...

随机推荐

  1. php aes128加密

    //[加密数据]AES 128 ECB模式 public function aesEncrypt($str){ $screct_key = Yii::$app->params['encryptK ...

  2. uva 11354 Bond

    题意: 邦德在逃命!他在一个有N个城市,由M条边连接的道路网中.一条路的危险度被定义为这条路上危险度最大的边的危险度. 现在给出若干个询问,s,t,问从s到t的最小的危险度是多少. 思路: 首先可以证 ...

  3. linux 查看python安装路径,版本号

    一.想要查看ubuntu中安装的python路径 方法一:whereis python     方法二:which python   二.想要查看ubuntu中安装的python版本号 python ...

  4. 使用Oozie中workflow的定时任务重跑hive数仓表的历史分期调度

    在数仓和BI系统的开发和使用过程中会经常出现需要重跑数仓中某些或一段时间内的分区数据,原因可能是:1.数据统计和计算逻辑/口径调整,2.发现之前的埋点数据收集出现错误或者埋点出现错误,3.业务数据库出 ...

  5. 转:SQL Server游标的使用

    使用游标步骤:1.在某个查询的基础上声明游标 --声明游标 declare c_Customers cursor for --查询所有店铺客户的客户编号 下面我们来看游标定义的参数: LOCAL和GL ...

  6. Codeforces 579A. Raising Bacteria

    You are a lover of bacteria. You want to raise some bacteria in a box. Initially, the box is empty. ...

  7. Navicat连接MySQL8.0亲测有效

    今天下了个 MySQL8.0,发现Navicat连接不上,总是报错1251: 原因是MySQL8.0版本的加密方式和MySQL5.0的不一样,连接会报错. 试了很多种方法,终于找到一种可以实现的: 更 ...

  8. Linux centos7 下 svn 服务器搭建

    摘自:https://www.cnblogs.com/mymelon/p/5483215.html 鉴于在搭建时,参考网上很多资料,网上资料在有用的同时,也坑了很多人 本文的目的,也就是想让后继之人在 ...

  9. 怎样从外网访问内网Apache HTTP Server

    本地安装了一个Apache HTTP Server,只能在局域网内访问,怎样从外网也能访问到本地的Apache HTTP Server呢?本文将介绍具体的实现步骤. 1. 准备工作 1.1 安装并启动 ...

  10. TCP/IP协议三次握手与四次挥手

    一.标志位和序号 seq序号 :发送方随机生成的 ack确认序号:ack=seq+1 标志位ACK=1时确认序号有效 SYN标志位:发起一个新连接 ACK标志位:确认序号有效 FIN标志位:断开连接 ...