tensorflow+inceptionv3图像分类网络结构的解析与代码实现

论文链接:论文地址

ResNet传送门:Resnet-cifar10

DenseNet传送门:DenseNet

SegNet传送门:Segnet-segmentation

深度学习的火热,使得越来越多的科研人员投入到其中。而作为各种应用类型的网络基础,图像分类的网络结构有许多,从AlexNet开始,到VGG-Net,到GoogleNet,到ResNet,denseNet等。网络结构在不断地改进,也在不断地趋于稳定。新的单纯地图像分类结构越来越少(可能是分类效果已经达到了一定的需求)。本文主要讲解GoogleNet改进后的Inceptionv3网络结构。其网络结构如下所示:

该网络在ILSVRC 2012的分类挑战上能获得5.6%的top-5 error。在参数量方面远小于VGG-Net,所以能有更块地训练速度以及不错的分类精度。文章中提到了4个通用的网络设计原则。

简单来讲就是:1、不要在网络的一开始使用过大的filter size,这会导致图像信息的丢失;2、高维数据的表示更容易在网络内进行局部处理,添加激活函数可以获得更多的disentangled features (不知道怎么翻译,有知道的大佬可否在评论底下说说?);3、空间聚合可以通过低维嵌入来完成,其表示能力没有太多或任何损失。(这里讲的就是网络中inception模块的分成4个branch最后聚合在一起所使用的原则);4、平衡网络的宽度和深度。

卷积核的分解

文章的核心部分在于其inception modules。而inception modules中又用到了factorization(将的filter size 分解成多个小的filter size),其原理可以用如下的图表示:

假设有一个5x5的feature map,我们可以直接用一个5x5的filter对其做卷积得到1个值,也可以通过两个3x3的filter对其做卷积得到1个值,但相较于前者,后者有更少地参数:3x3x2=18。前者为5x5=25。可以减少的参数量为:(25-18)/25=28%。

在此基础上,论文又提出可以使用使用非对称的卷积核来替代较大的卷积核。如下图所示:

对于一个3x3的卷积核,可以使用一个1x3和一个3x1的组合来替代。一般化地话,可以使用1xn和nx1替代nxn的卷积核。

辅助分类器

辅助分类器即除了主分类器之外,还在网络结构中的某一层,论文中为17x17x768的那一层,添加了一个分支用来做辅助分类。其思想来源于GoogleNet(Going deeper with convolutions)  。

网络尺寸的有效减少

在论文中给出的网络结构中,3xInception和5xInception以及5xInception和2xInception有一个尺寸的减少,其具体实现方法为如下所示:

这里一并给出相关的代码实现:

def inception_grid_reduction_1(input,name=None):

    with tf.variable_scope(name) as scope:
with tf.variable_scope("Branch_0"):
branch_0=conv_inception(input,shape = [1,1,288,384],name = '0a_1x1')
branch_0=conv_inception(branch_0,shape = [3,3,384,384],stride = [1,2,2,1],padding = 'VALID',name = '0b_3x3')
with tf.variable_scope('Branch_1'):
branch_1=conv_inception(input,shape = [1,1,288,64],name = '0b_1x1')
branch_1=conv_inception(branch_1,shape = [3,3,64,96],name = '0b_3x3')
branch_1=conv_inception(branch_1,shape = [3,3,96,96],stride = [1,2,2,1],padding = 'VALID',name = '0c_3x3')
with tf.variable_scope('Branch_2'):
branch_2=tf.nn.max_pool(input,ksize = (1,3,3,1),strides = [1,2,2,1],padding = 'VALID',name = 'maxpool_0a_3x3')
inception_out=tf.concat([branch_0,branch_1,branch_2],3)
c=1 # for debug
return inception_out

其中conv_inception函数定义如下:

def conv_inception(input, shape, stride= [1,1,1,1], activation = True, padding = 'SAME', name = None):
in_channel = shape[2]
out_channel = shape[3]
k_size = shape[0]
with tf.variable_scope(name) as scope:
kernel = _variable('conv_weights', shape = shape)
conv = tf.nn.conv2d(input = input, filter = kernel, strides = stride, padding = padding)
biases = _variable('biases', [out_channel])
bias = tf.nn.bias_add(conv, biases)
if activation is True:
conv_out = tf.nn.relu(bias, name = 'relu')
else:
conv_out = bias
return conv_out

_variable定义如下:

def _variable(name, shape):
"""Helper to create a Variable stored on CPU memory.
Args:
name: name of the variable
shape: list of ints
Returns:
Variable Tensor
"""
with tf.device('/gpu:0'):
var = tf.get_variable(name, shape)
return var

下面给出网络中每一部分的解释以及实现:

文章中的卷积部分就不讲了,基本操作。主要讲讲inception部分怎么做。论文中共用到了三种Inception modules,即传统的inception(如GoogleNet所示),以及使用了非对称分解卷积核的inception,以及加入了filter expanded的inception。先说说传统的,如图所示:

这里Base的input size在网络中对应为35x35x288,有4个分支,其中pool为平均池化-avgpool,最后将4个分支串到一起,其代码实现如下:

def inception_block_tradition(input, name=None):

    with tf.variable_scope(name) as scope:
with tf.variable_scope("Branch_0"):
branch_0=conv_inception(input,shape = [1,1,288,64],name = '0a_1x1')
with tf.variable_scope('Branch_1'):
branch_1=conv_inception(input,shape = [1,1,288,48],name = '0a_1x1')
branch_1=conv_inception(branch_1,shape = [5,5,48,64],name = '0b_5x5')
with tf.variable_scope("Branch_2"):
branch_2=conv_inception(input,shape = [1,1,288,64],name = '0a_1x1')
branch_2=conv_inception(branch_2,shape = [3,3,64,96],name = '0b_3x3')
with tf.variable_scope('Branch_3'):
branch_3=tf.nn.avg_pool(input,ksize = (1,3,3,1),strides = [1,1,1,1],padding = 'SAME',name = 'Avgpool_0a_3x3')
branch_3=conv_inception(branch_3,shape = [1,1,288,64],name = '0b_1x1')
inception_out=tf.concat([branch_0,branch_1,branch_2,branch_3],3)
b=1 # for debug
return inception_out

接下来是使用了非对称分解的Inception moduels,如下图所示:

这里n=7,Base为17x17x768;pool为 3x3 stride为1的avgpool(同上);其代码实现如下:

def inception_block_factorization(input,name=None):

    with tf.variable_scope(name) as scope:
with tf.variable_scope('Branch_0'):
branch_0=conv_inception(input,shape = [1,1,768,192],name = '0a_1x1')
with tf.variable_scope('Branch_1'):
branch_1=conv_inception(input,shape = [1,1,768,128],name = '0a_1x1')
branch_1=conv_inception(branch_1,shape = [1,7,128,128],name = '0b_1x7')
branch_1=conv_inception(branch_1,shape = [7,1,128,128],name = '0c_7x1')
branch_1=conv_inception(branch_1,shape = [1,7,128,128],name = '0d_1x7')
branch_1=conv_inception(branch_1,shape = [7,1,128,192],name = '0e_7x1')
with tf.variable_scope('Branch_2'):
branch_2=conv_inception(input,shape = [1,1,768,128],name = '0a_1x1')
branch_2=conv_inception(branch_2,shape = [1,7,128,128],name = '0b_1x7')
branch_2=conv_inception(branch_2,shape = [7,1,128,192],name = '0c_7x1')
with tf.variable_scope('Branch_3'):
branch_3=tf.nn.avg_pool(input,ksize = (1,3,3,1),strides = [1,1,1,1],padding = 'SAME',name = 'Avgpool_0a_3x3')
branch_3=conv_inception(branch_3,shape = [1,1,768,192],name = '0b_1x1')
inception_out=tf.concat([branch_0,branch_1,branch_2,branch_3],3)
d=1 # for debug
return inception_out

接下来使用了filter expanded的inception,如图所示:

也是4个分支,pool同上。其代码实现如下:

def inception_block_expanded(input,name=None):
with tf.variable_scope(name) as scope:
with tf.variable_scope('Branch_0'):
branch_0=conv_inception(input,shape = [1,1,2048,320],name = '0a_1x1')
with tf.variable_scope('Branch_1'):
branch_1=conv_inception(input,shape = [1,1,2048,448],name = '0a_1x1')
branch_1=conv_inception(branch_1,shape = [3,3,448,384],name = '0b_3x3')
branch_1=tf.concat([conv_inception(branch_1,shape = [1,3,384,384],name = '0c_1x3'),
conv_inception(branch_1,shape = [3,1,384,384],name = '0d_3x1')],3)
with tf.variable_scope('Branch_2'):
branch_2=conv_inception(input,shape = [1,1,2048,384],name = '0a_1x1')
branch_2=tf.concat([conv_inception(branch_2,shape = [1,3,384,384],name = '0b_1x3'),
conv_inception(branch_2,shape = [3,1,384,384],name = '0c_3x1')],3)
with tf.variable_scope('Branch_3'):
branch_3=tf.nn.avg_pool(input,ksize = (1,3,3,1),strides = [1,1,1,1],padding = 'SAME',name = 'Avgpool_0a_3x3')
branch_3=conv_inception(branch_3,shape = [1,1,2048,192],name = '0b_1x1')
inception_out=tf.concat([branch_0,branch_1,branch_2,branch_3],3)
e=1 # for debug
return inception_out

经过上述操作可得到8x8x2048的feature maps,根据论文中的结构,对其做池化操作并加入1x1的卷积得到我们最终需要的1x1xnum_class即可,其实现如下(不唯一):

        with tf.variable_scope('Logits'):
net=tf.nn.avg_pool(net,ksize = [8,8,2048,2048],strides = [1,1,1,1],padding = 'VALID',name = 'Avgpool_1a_8x8') # 1x1x2048
net=tf.nn.dropout(net,keep_prob = dropout_keep_prob,name = 'Dropout_1b')
end_points['PreLogits']=net
#
logits=conv_inception(net,shape = [1,1,2048,num_classes],activation = False,name = 'conv_1c_1x1') end_points['Logits']=logits
end_points['Predictions']=tf.nn.softmax(logits,name = 'Predictions')
return logits,end_points

论文中提及到的优化方法有SGD和RMSProp。可以随便选择,论文中得到的最佳模型为使用了RMSProp的方法。

附上代码下载地址:

tensorflow+inceptionv3

PS:数据集需要自行提供。

tensorflow+inceptionv3图像分类网络结构的解析与代码实现的更多相关文章

  1. 【深度学习系列】用PaddlePaddle和Tensorflow进行图像分类

    上个月发布了四篇文章,主要讲了深度学习中的"hello world"----mnist图像识别,以及卷积神经网络的原理详解,包括基本原理.自己手写CNN和paddlepaddle的 ...

  2. 【学习笔记】Tensorflow+Inception-v3训练自己的数据

    导读 喵喵的,一个大坑.本文分为吐槽和干货两部分. 一.吐槽 大周末的,被导师扣下加班,嗨气,谁叫本狗子太弱鸡呢,看起来很简单的任务倒腾了两天还没完,不扣你扣谁? 自己刚接到微调Inception-v ...

  3. 一步步教你为网站开发Android客户端---HttpWatch抓包,HttpClient模拟POST请求,Jsoup解析HTML代码,动态更新ListView

    本文面向Android初级开发者,有一定的Java和Android知识即可. 文章覆盖知识点:HttpWatch抓包,HttpClient模拟POST请求,Jsoup解析HTML代码,动态更新List ...

  4. 分析和解析PHP代码的7大工具

    PHP已成为时下最热门的编程语言之一,然而却有许多PHP程序员苦恼找不到合适的工具来帮助自己分析和解析PHP代码.今天小编就为大家介绍几个非常不错的工具,来帮助程序员们提高自己的工作效率,一起来看看吧 ...

  5. 用phpQuery像jquery一样解析html代码

    简介 如何在php中方便地解析html代码,估计是每个phper都会遇到的问题.用phpQuery就可以让php处理html代码像jQuery一样方便. 项目地址:https://code.googl ...

  6. Android JSON 解析关键代码

    Android Json 解析其实还是蛮重要的知识点,为什么这么说呢,因为安卓通信大部分的协议都是使用 json 的方式传输,我知道以前大部分是使用的 xml ,但是时代在发展社会在进步,json 成 ...

  7. CVE-2012-0003 Microsoft Windows Media Player ‘winmm.dll’ MIDI文件解析远程代码执行漏洞 分析

    [CNNVD]Microsoft Windows Media Player ‘winmm.dll’ MIDI文件解析远程代码执行漏洞(CNNVD-201201-110)    Microsoft Wi ...

  8. 使用python解析C代码

    我有一个巨大的C文件(~100k行),我需要能够解析.主要是我需要能够从其定义中获取有关每个结构的各个字段的详细信息(如结构中每个字段的字段名称和类型).是否有一个好的(开源,我可以在我的代码中使用) ...

  9. 实现迁徙学习-《Tensorflow 实战Google深度学习框架》代码详解

    为了实现迁徙学习,首先是数据集的下载 #利用curl下载数据集 curl -o flower_photos.tgz http://download.tensorflow.org/example_ima ...

随机推荐

  1. 一道Oracle子查询小练习

    一道Oracle子查询小练习   昨天晚上躺在床上看Oracle(最近在学习这个),室友说出个题目让我试试.题目如下: 有如下表结构,请选择出成绩为前三名的人的信息(如果成绩相同,则算并列),表名为t ...

  2. MySQL基础知识 数据库 数据表

    1.数据库结构 库 表 数据 2. sql(structured query language)结构化查询语言 管理数据库 管理表 管理数据 3.数据库 增删改查 增 create database  ...

  3. 大数据之hadoop集群安全模式

    集群安全模式1.概述(1)NameNode启动 NameNode启动时,首先将镜像文件(Fsimage)载入内存,并执行编辑日志(Edits)中的各项操作.-旦在内存中成功建立文件系统元数据的影像,则 ...

  4. [NOI2007]生成树计数环形版

    NOI2007这道题人类进化更完全之后出现了新的做法 毕姥爷题解: 于是毕姥爷出了一道环形版的这题(test0814),让我们写这个做法 环形的情况下,k=5的时候是162阶递推. 求这个递推可以用B ...

  5. 用连接池链接redis

    package com.itheima.utils; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; i ...

  6. RPC远程过程调用实例详解

    1.创建IDL文件,定义接口. IDL文件可以由uuidgen.exe创建. 首先找到系统中uuidgen.exe的位置,如:C:\Program Files\Microsoft Visual Stu ...

  7. VS2010-MFC(对话框:一般属性页对话框的创建及显示)

    转自:http://www.jizhuomi.com/software/169.html 属性页对话框包括向导对话框和一般属性页对话框两类,上一节演示了如何创建并显示向导对话框,本节将继续介绍一般属性 ...

  8. sql stuff拼接字符串的用法

    要把图2显示成图1的方法:要用到stuff函数,并且图1显示的时间有所截断. 图2sql,只是很普通的sql ), SKSJ, )=' order by SKSJ 图1sql,用了stuff拼接 ), ...

  9. POJ-1260-Pearls-dp+理解题意

    In Pearlania everybody is fond of pearls. One company, called The Royal Pearl, produces a lot of jew ...

  10. 左神算法书籍《程序员代码面试指南》——3_05Morris遍历二叉树的神级方法【★★★★★】

    [问题]介绍一种时间复杂度O(N),额外空间复杂度O(1)的二叉树的遍历方式,N为二叉树的节点个数无论是递归还是非递归,避免不了额外空间为O(h),h 为二叉树的高度使用morris遍历,即利用空节点 ...