ResNet网络,本文获得2016 CVPR best paper,获得了ILSVRC2015的分类任务第一名。

  本篇文章解决了深度神经网络中产生的退化问题(degradation problem。什么是退化问题呢?如下图:

  上图所示,网络随着深度的增加(从20层增加到56层),训练误差和测试误差非但没有降低,反而变大了。然而这种问题的出现并不是因为过拟合(overfitting)。

  照理来说,如果我们有一个浅层的网络,然后我们可以构造一个这样的深层的网络:前面一部分的网络和浅层网络一模一样,后面一部分的网络采用恒等映射(identity mapping),那么,深层网络的产生的误差至少不会比浅层网络的高。但是目前却不能找到一个更好的方法比用刚才的方法构造的网络效果要好。

  于是,作者就提出了deep residual learning framework。结构如下:

其实就是在原来网络的基础上,每隔2层(或者3层,或者更多,这篇文章作者只做了2层和3层)的输出F(x)上再加上之前的输入x。这样做,不会增加额外的参数和计算复杂度,整个网络也可以用SGD方法进行端对端的训练,用目前流行的深度学习库(caffe等)也可以很容易的实现。

这种网络的优点有:

1)  更容易优化(easier to optimize)

2)  can gain accuracy from increased depth,即能够做到网络越深,准确率越高

对于作者提出的网络结构,有2种情形。

1)  当F和x相同维度时,直接相加(element-wise addition),公式如下:

这种方法不会增加网络的参数以及计算复杂度。

2)  当F和x维度不同时,需要先将x做一个变换(linear projection),然后再相加,公式如下:

Ws仅仅用于维度匹配上。

对于x的维度变换,一种是zero-padding,另一种是通过1x1的卷积。

网络结构

测试网络如下:

  基准网络为:基于VGGNet,采用的卷积核为3x3,其中有两个设计原则,1)对于有相同的输出feature map尺寸,filter的个数相同;2)当feature map尺寸减半时,filter的数量加倍。下采样的策略是直接用stride=2的卷积核。网络最后末尾是一个global average pooling layer(不需要参数,参考http://www.cnblogs.com/hejunlin1992/articles/7750759.html)和一个1000的全连接层(后面接softmax)。

  残差网络为:在基准网络的基础上,插入了shortcut connections。当输入输出具有相同尺寸时,identity shortcuts可以直接使用(实线部分),就是公式1;当维度增加时(虚线部分),有以下两种选择:A)仍然采用恒等映射(identity mapping),超出部分的维度使用0填充;B) 利用1x1卷积核来匹配维度,就是公式2。对于上面两种方案,当shortcuts通过两种大小的feature map时,采取A或B方案的同时,stride=2。

实现细节:

  Our implementation for ImageNet follows the practice in [21, 41]. The image is resized with its shorter side randomly sampled in [256, 480] for scale augmentation [41]. A 224×224 crop is randomly sampled from an image or its horizontal flip, with the per-pixel mean subtracted [21]. The standard color augmentation in [21] is used. We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16]. We initialize the weights as in [13] and train all plain/residual nets from scratch. We use SGD with a mini-batch size of 256. The learning rate starts from 0.1 and is divided by 10 when the error plateaus, and the models are trained for up to 60 × 104 iterations. We use a weight decay of 0.0001 and a momentum of 0.9. We do not use dropout [14], following the practice in [16]. In testing, for comparison studies we adopt the standard 10-crop testing [21]. For best results, we adopt the fullyconvolutional form as in [41, 13], and average the scores at multiple scales (images are resized such that the shorter side is in {224, 256, 384, 480, 640}).

实验

  从上图左边,可以看出,plain-34网络不管是训练误差还是验证集上的误差,都要比plain-18要大,由于plain网络采用了BN来训练,并且作者也验证过前向传播或者反向传播中,信号并没有消失,因此说明出现了退化现象(到底为什么会出现这种情况,作者也还在研究之中)。

再看上图右边的残差网络,结合下面的表2,34层的resNet比18层的resNet在训练集和验证集上的误差都要小,说明并没有出现退化现象。34层的resNet与34层的plain网络相比,误差减少了3.5%,说明在深度网络中残差学习是有效的。另外,18层的ResNet与18层的plain网络相比,18层的ResNet训练更快了。

Identity vs. Projection Shortcuts.

上面展示了,恒等映射(identity shortcuts)可以帮助训练。下面我们测试一下projection shortcuts(公式2)的效果。有3种测试方案。A)当维度增加时,使用zero-padding shortcuts,这些所有的shortcuts是没有参数的(与Table2和Fig4右侧的图一致);B) projection shortcuts只用于维度增加的情况,其他情况(输入输出维度一致时)还是使用恒等映射(即公式1);C)所有的shortcus都是projection(即公式2)。测试结果如下表:

从上表中可以看出,方案A,B和C的效果都要比plain-34要好。B比A稍微好一点,这是因为A的用零填充的那几个维度没有进行残差学习。C要B好的多,这是因为引入了额外的参数。但是总体上A,B和C的差别还是比较小的,说明projection shortcuts在解决退化问题中并不是十分重要。因此,为了减少参数和计算量,我们在这篇文章中不使用方案C。

Deeper Bottleneck Architectures

下面描述用于ImageNet的更深的网络结构。考虑到训练时间,我们将下图中左侧的网络改成右侧的网络。1x1的卷积的作用是减少和增加(恢复)维度,使得3x3的卷积核的个数可以减少。

50-layer ResNet

将34层中的每个2-layer block替换为3-layer bottleneck block,就得到了一个50层的ResNet。我们使用B方案来增加维度。

101-layer and 152-layer ResNets

我们使用更多的3-layer bottleneck block,就得到了101层和152层的ResNet。152层的网络深度很深,但是参数量却比VGG16/19要小。

模型之间的比对

如下图所示,下图中的ResNet是使用多个不同深度的ResNet,综合结果而成的,效果非常好。

[论文阅读] Deep Residual Learning for Image Recognition(ResNet)的更多相关文章

  1. 论文笔记——Deep Residual Learning for Image Recognition

    论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...

  2. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  3. Deep Residual Learning for Image Recognition (ResNet)

    目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...

  4. Deep Residual Learning for Image Recognition这篇文章

    作者:何凯明等,来自微软亚洲研究院: 这篇文章为CVPR的最佳论文奖:(conference on computer vision and pattern recognition) 在神经网络中,常遇 ...

  5. Deep Residual Learning for Image Recognition论文笔记

    Abstract We present a residual learning framework to ease the training of networks that are substant ...

  6. Deep Residual Learning for Image Recognition

    Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun           Microsoft Research {kahe, v-xiangz, v-sh ...

  7. Deep Residual Learning for Image Recognition(残差网络)

    深度在神经网络中有及其重要的作用,但越深的网络越难训练. 随着深度的增加,从训练一开始,梯度消失或梯度爆炸就会阻止收敛,normalized initialization和intermediate n ...

  8. 【网络结构】Deep Residual Learning for Image Recognition(ResNet) 论文解析

    目录 0. 论文链接 1. 概述 2. 残差学习 3. Identity Mapping by shortcuts 4. Network Architectures 5. 训练细节 6. 实验 @ 0 ...

  9. Paper | Deep Residual Learning for Image Recognition

    目录 1. 故事 2. 残差学习网络 2.1 残差块 2.2 ResNet 2.3 细节 3. 实验 3.1 短连接网络与plain网络 3.2 Projection解决短连接维度不匹配问题 3.3 ...

随机推荐

  1. 8.String StringBuffer StringBuilder

    http://www.cnblogs.com/xiohao/p/4271140.html

  2. 使用mescroll来实现移动端页面上拉刷新, 下拉加载更多功能

    * mescroll请参考官方文档 1. 使用mescroll实现下拉滑动的效果: (仅仅效果, 有的页面不需要刷新数据, 只要你能下拉就行) 代码如下: var mescroll = new MeS ...

  3. Java线程sleep,yield,join,wait方法详解

    1.sleep() 当一个线程调用sleep方法后,他就会放弃cpu,转到阻塞队列,sleep(long millis)方法是Thread类中的静态方法,millis参数设定线程睡眠的时间,毫秒为单位 ...

  4. JQ 判断 浏览器打开的设备类型

    <script> $(document).ready(function(){ var ua = navigator.userAgent; var ipad = ua.match(/(iPa ...

  5. 蚂蚁金服安全实验室首次同时亮相BlackHat Asia 以及CanSecWest国际安全舞台

    近期,蚂蚁金服巴斯光年安全实验室(以下简称AFLSLab)同时中稿BlackHat Asia黑帽大会的文章以及武器库,同时在北美的CanSecWest安全攻防峰会上首次中稿Android安全领域的漏洞 ...

  6. 如何正确使用Java异常处理机制

    文章来源:leaforbook - 如何正确使用Java异常处理机制作者:士别三日 第一节 异常处理概述 第二节 Java异常处理类 2.1 Throwable 2.1.1 Throwable有五种构 ...

  7. java开发常用技术

    基础部分 1. 线程和进程的区别 线程三个基本状态:就绪.执行.阻塞 线程五个基本操作:创建.就绪.运行.阻塞.终止 进程四种形式:主从式.会话式.消息或邮箱机制.共享存储区方式 进程是具有一定功能的 ...

  8. iOS移动端直连数据库

    一个可以直接连接服务器MySQL的工具包(极不安全,如非特殊需求,不推荐使用) 这种直接连接服务器数据的方式是极为不安全的,但因为我们这个项目特殊情况,只在局域网内使用, 且只有一个pad对一台设备进 ...

  9. JS常见操作,日期操作,字符串操作,表单验证等

    复制代码 //第一篇博文,希望大家多多支持 /***** BasePage.js 公共的 脚本文件 部分方法需引用jquery库 *****/ //#region 日期操作 //字符串转化为时间. f ...

  10. Java基础学习笔记四 Java基础语法

    数组 数组的需求 现在需要统计某公司员工的工资情况,例如计算平均工资.最高工资等.假设该公司有50名员工,用前面所学的知识完成,那么程序首先需要声明50个变量来分别记住每位员工的工资,这样做会显得很麻 ...