Tutorial on GoogleNet based image classification 

2018-06-26 15:50:29

本文旨在通过案例来学习 GoogleNet 及其 Inception 结构的定义。针对这种复杂模型的保存以及读取。

1. GoogleNet 的结构:

 class Inception(nn.Module):
def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
super(Inception, self).__init__()
# 1x1 conv branch
self.b1 = nn.Sequential(
nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
nn.BatchNorm2d(kernel_1_x),
nn.ReLU(True),
) # 1x1 conv -> 3x3 conv branch
self.b2 = nn.Sequential(
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
nn.BatchNorm2d(kernel_3_in),
nn.ReLU(True),
nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_3_x),
nn.ReLU(True),
) # 1x1 conv -> 5x5 conv branch
self.b3 = nn.Sequential(
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
nn.BatchNorm2d(kernel_5_in),
nn.ReLU(True),
nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
) # 3x3 pool -> 1x1 conv branch
self.b4 = nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
nn.BatchNorm2d(pool_planes),
nn.ReLU(True),
) def forward(self, x):
y1 = self.b1(x)
y2 = self.b2(x)
y3 = self.b3(x)
y4 = self.b4(x)
return torch.cat([y1,y2,y3,y4], 1)
class GoogLeNet(nn.Module):
def __init__(self):
super(GoogLeNet, self).__init__()
self.pre_layers = nn.Sequential(
nn.Conv2d(3, 192, kernel_size=3, padding=1),
nn.BatchNorm2d(192),
nn.ReLU(True),
) self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) self.avgpool = nn.AvgPool2d(8, stride=1)
self.linear = nn.Linear(1024, 10) def forward(self, x):
x = self.pre_layers(x)
x = self.a3(x)
x = self.b3(x)
x = self.max_pool(x)
x = self.a4(x)
x = self.b4(x)
x = self.c4(x)
x = self.d4(x)
x = self.e4(x)
x = self.max_pool(x)
x = self.a5(x)
x = self.b5(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x

2. 保存和加载模型:

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl') # 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models的更多相关文章

  1. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  2. 图像分类之特征学习ECCV-2010 Tutorial: Feature Learning for Image Classification

    ECCV-2010 Tutorial: Feature Learning for Image Classification Organizers Kai Yu (NEC Laboratories Am ...

  3. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature【枚举二分答案】

    https://codeforces.com/contest/1241/problem/C You are an environmental activist at heart but the rea ...

  4. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature

    链接: https://codeforces.com/contest/1241/problem/C 题意: You are an environmental activist at heart but ...

  5. How to Build Android Applications Based on FFmpeg by An Example

    This is a follow up post of the previous blog How to Build FFmpeg for Android.  You can read the pre ...

  6. 解读(GoogLeNet)Going deeper with convolutions

    (GoogLeNet)Going deeper with convolutions Inception结构 目前最直接提升DNN效果的方法是increasing their size,这里的size包 ...

  7. [论文阅读]Going deeper with convolutions(GoogLeNet)

    本文采用的GoogLenet网络(代号Inception)在2014年ImageNet大规模视觉识别挑战赛取得了最好的结果,该网络总共22层. Motivation and High Level Co ...

  8. Node.js NPM Tutorial: Create, Publish, Extend & Manage

    A module in Node.js is a logical encapsulation of code in a single unit. It's always a good programm ...

  9. Plant Leaves Classification植物叶子分类:基于孪生网络的小样本学习方法

    目录 Abstract Introduction PROPOSED CNN STRUCTURE INITIAL CNN ANALYSIS EXPERIMENTAL STRUCTURE AND ALGO ...

随机推荐

  1. html5-常用的通用元素

    <!DOCTYPE html><html lang="en"><head>    <meta charset="UTF-8&qu ...

  2. E. Kefa and Watch hash 线段树

    2015-09-28 14:11:36 by opas 这题给的是一个字符串 把其中一些子串给取出来 判断是否是周期为d的字符串  还需要把 其中的一个区间完全变成一个数 ,然后在查询,我们把每个字符 ...

  3. Linux 中常用的命令

    Linux中的常用命令: 终端快捷键: Ctrl + a/Home 切换到命令行开始 Ctrl + e/End 切换到命令行末尾 Ctrl + l 清除屏幕内容,效果等同于clear Ctrl + u ...

  4. 常用bash,autoUserAdd.sh

    #!/bin/bash # auth: xiluhua # date: -- read -p "please input a username:" username [ -z $u ...

  5. Beta分布深入理解

    一些公式 Gamma函数 (1) 贝叶斯公式 (2) 贝叶斯公式计算二项分布概率 现在有一枚未知硬币,我们想要计算抛出后出现正面的概率.我们使用贝叶斯公式计算硬币出现正面的概率.硬币出现正反率的概率和 ...

  6. pip使用简要说明

    一.pip常用命令 安装指定包 pip install SomePackage #最新版本 安装指定包 pip install SomePackage==1.0.4 #指定版本 安装指定包 pip i ...

  7. python 关键字yield解析

    python 关键字yield解析 yield 的作用就是把一个函数变成一个 generator,带有 yield 的函数不再是一个普通函数,Python 解释器会将其视为一个 generator.y ...

  8. 数据库的增、删、改、查 (CURD)

    增改查删可以用CURD来表示  增加:create  修改:update   查找:read      删除:delete 增加create :  insert +表名+values+(信息): in ...

  9. MySQL5.7 开启SSL

    MySQL5.7配置SSL加密的方式比较简单. 生成证书文件 [root@ ~]# bin/mysql_ssl_rsa_setup --datadir=/data/database/mysql [ro ...

  10. linux系统日常维护常用命令

    环境: OS:Red Hat Linux As 5   1.find 11.查找当前目录以及子目录下包含ORA字符的文件 find . -type f|xargs  grep "ORA&qu ...