论文将搜索空间从整体网络转化为卷积单元(cell),再按照设定堆叠成新的网络家族NASNet。不仅降低了搜索的复杂度,从原来的28天缩小到4天,而且搜索出来的结构具有扩展性,在小模型和大模型场景下都能使用更少的参数量和计算量来超越人类设计的模型,达到SOTA



来源:【晓飞的算法工程笔记】 公众号

论文: Learning Transferable Architectures for Scalable Image Recognition

Introduction


  论文作者在ICLR 2017使用强化学习进行神经网络架构搜索获得了很好的表现,但该搜索方法需要计算资源很多,在CIFAR-10上需要800块GPU搜索28天,几乎不可能在大型数据集上进行搜索。因此,论文提出在代理数据集(proxy dataset)上进行搜索,然后将网络迁移到ImageNet中,主要亮点如下:

  • 迁移的基础在于搜索空间的定义,由于常见的网络都是重复的结构堆叠而成的,论文将搜索空间从整个网络改成单元(cell),再按设定将单元堆叠成网络。这样做不仅搜索速度快,而且相对而言,单元结构通用性更高,可迁移
  • 论文搜索到的最好结构称为NASNet,达到当时的SOTA,在CIFAR-10提升了2.4%top-1准确率,而迁移到ImageNet提升了1.2%
  • 通过堆叠不同数量的单元(cell)以及修改单元中的卷积核数量,可以得到适应各种计算需求的NASNets,最小的NASNet在ImageNet top-1准确率为74.0%,比最好的移动端模型高3.1%
  • NASNets学习到的图片特征十分有用,并且能够迁移到其它视觉任务中。Faster-RCNN使用最大的NASNets能直接提高4%,达到SOTA 43.1%mAP

Method


  论文的神经网络搜索方法沿用了经典强化学习方法,具体可以看我之前的论文解读。流程如图1,简而言之就是使用RNN来生成网络结构,然后在数据集上进行训练,根据收敛后的准确率对RNN进行权重调整

  论文的核心在于定义一个全新的搜索空间,称之为the NASNet search space。论文观察到目前优秀的网络结构,如ResNet和Inception,其实都是重复模块(cell)堆叠而成的,因此可以使用RNN来预测通用的卷积模块,这样的模块可以组合堆叠成一个系列模型,论文主要包含两种单元(cell):

  • Normal Cell,卷积单元用来返回相同大小的特征图,
  • Reduction Cell,卷积单元用来返回宽高缩小两倍的特征图

  图2为CIFAR-10和ImageNet的网络框架,图片输入分别为32x32和299x299,Reduction Cell和Normal Cell可以为相同的结构,但论文发现独立的结构效果更好。当特征图的大小减少时,会手动加倍卷积核数量来大致保持总体特征点数量。另外,单元的重复次数N和初始的卷积核数量都是人工设定的,针对不同的分类问题

  单元的结构在搜索空间内定义,首先选取前两个低层单元的输出$h_i$和$h_{i-1}$作为输入,然后the controller RNN预测剩余的卷积单元结构block,单个block预测如图3所示,每个单元(cell)由B个block组合成,每个block包含5个预测步骤,每个步骤由一个softmax分类器来选择对应的操作,block的预测如下:

  • Step 1,在$h_i$,$h_{i-1}$和单元中之前的block输出中选择一个作为第一个隐藏层的输入
  • Step 2,选择第二个隐藏层的输入,如Step 1
  • Step 3,选择用于Step 1中的输入的操作
  • Step 4,选择用于Step 2中的输入的操作
  • Step 5,选择用于合并Step 3和Step 4输出的操作,并产生新的隐藏层,可供后面的block选择

  Step 3和4中选择的操作包含了如上的一些主流的卷积网络操作,而Step 5的合并操作主要包含两种:1) element-wise addition 2) concatenation,最后,所有没有被使用的隐藏层输出会concatenated一起作为单元的输出。the controller RNN总共进行$2\times 5B$次预测,前$5B$作为Normal Cell,而另外$5B$则作为Reduction Cell

  在RNN的训练方面,既可以用强化学习也可以用随机搜索,实验发现随机搜索仅比强化学习得到的网络稍微差一点,这意味着:

  • NASNet的搜索空间构造得很好,因此随机搜索也能有好的表现
  • 随机搜索是个很难打破的baseline

Experiments and Results


  The controller RNN使用Proximal Policy Optimization(PPO)进行训练,以global workqueue形式对子网络进行分布式训练,实验总共使用500块P100来训练queue中的网络,整个训练花费4天,相比之前的版本800块K40训练28天,训练加速了7倍以上,效果也更好

  图4为表现最好的Normal Cell和Reduction Cell的结构,这个结构在CIFAR-10上搜索获得的,然后迁移到ImageNet上。在获得卷积单元后,需要修改几个超参数来构建最终的网络,首先上单元重复数N,其次上初始单元的卷积核数,例如$4@64$为单元重复4次以及初始单元的卷积核数为64

  对于搜索的细节可以查看论文的Appendix A,需要注意的是,论文提出DropPath的改进版ScheduledDropPath这一正则化方法。DropPath是在训练时以一定的概率随机丢弃单元的路径(如Figure 4中的黄色框连接的边),但在论文的case中不太奏效。因此,论文改用ScheduledDropPath,在训练过程中线性增加丢弃的概率

Results on CIFAR-10 Image Classification

  NASNet-A结合随机裁剪数据增强达到了SOTA

Results on ImageNet Image Classification

  论文将在CIFAR-10上学习到的结构迁移到ImageNet上,最大的模型达到了SOTA(82.7%),与SENet的准确率一致,但是参数量大幅减少

  图5直观地展示了NASNet家族与其它人工构建网络的对比,NASNet各方面都比人工构建的网络要好

  论文也测试了移动端配置的网络准确率,这里要求网络的参数和计算量要足够的小,NASNet依然有很抢眼的表现

Improved features for object detection

  论文研究了NASNet在其它视觉任务中的表现,将NASNet作为Faster-RCNN的主干在COCO训练集上进行测试。对比移动端的网络,mAP达到29.6%mAP,提升了5.1%。而使用最好的NASNet,mAP则达到43.1%mAP,提升4.0%mAP。结果表明,NASNet能够提供更丰富且更通用的特征,从而在其它视觉任务也有很好的表现

Efficiency of architecture search methods

  论文对比了网络搜索方法的性能,主要是强化学习方法(RL)和随机搜索方法(RS)。对于最好网络,RL搜索到的准确率整体要比RS的高1%,而对于整体表现(比如top-5和top-25),两种方法则比较接近。因此,论文认为尽管RS是可行的搜索策略,但RL在NASNet的搜索空间表现更好

CONCLUSION


  论文基于之前使用强化学习进行神经网络架构搜索的研究,将搜索空间从整体网络转化为卷积单元(cell),再按照设定堆叠成新的网络NASNet。这样不仅降低了搜索的复杂度,加速搜索过程,从原来的28天缩小到4天,而且搜索出来的结构具有扩展性,分别在小模型和大模型场景下都能使用更少的参数量和计算量来超越人类设计的模型,达到SOTA

  另外,由于搜索空间和模型结构的巧妙设计,使得论文能够将小数据集学习到的结构迁移到大数据集中,通用性更好。而且该网络在目标检测领域的表现也是相当不错的



Appendix NASNet-B & NASNet-C

  论文还有另外两种结构NASNet-B和NASNet-C,其搜索空间和方法与NASNet-A有点区别,有兴趣的可以去看看原文的Appendix





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

NASNet : Google Brain经典作,改造搜索空间,性能全面超越人工网络,继续领跑NAS领域 | CVPR 2018的更多相关文章

  1. 告别炼丹,Google Brain提出强化学习助力Neural Architecture Search | ICLR2017

    论文为Google Brain在16年推出的使用强化学习的Neural Architecture Search方法,该方法能够针对数据集搜索构建特定的网络,但需要800卡训练一个月时间.虽然论文的思路 ...

  2. Google 发布的15个 Android 性能优化典范

    2015年伊始,Google发布了关于Android性能优化典范的专题,一共16个短视频,每个3-5分钟,帮助开发者创建更快更优秀的Android App.课程专题不仅仅介绍了Android系统中有关 ...

  3. 用Google Brain的机器学习项目:Magenta,教神经网络学抖音小姐姐作曲。

    先上我们要学习的小姐姐 的美照.. 一.配置环境 1.自己配置环境:python,tensorflow,bazel(编译),java.然后下载magenta(https://github.com/te ...

  4. 老子云AMRT全新三维格式正式上线,其性能全面超越现有的三维数据格式

    9月16日,老子云AMRT全新三维格式正式上线,其性能远超现有的三维数据格式.目前已有含国家超算长沙中心.中科院空间所.中车集团等上百家政企事业单位的项目中使用了AMRT格式,大大提升了可视化项目的开 ...

  5. WebService - 怎样提高WebService性能 大数据量网络传输处理

    直接返回DataSet对象 返回DataSet对象用Binary序列化后的字节数组 返回DataSetSurrogate对象用Binary序列化后的字节数组 返回DataSetSurrogate对象用 ...

  6. 共享式以太网与交换式以太网的性能比较(OPNET网络仿真实验)

      一.实验目的 比较共享式以太网和交换式以太网在不同网络规模下的性能. 二.实验方法 使用opnet来创建和模拟网络拓扑,并运行分析其性能. 三.实验内容 3.1   实验设置(网络拓扑.参数设置. ...

  7. C# 之 提高WebService性能大数据量网络传输处理

    1.直接返回DataSet对象 特点:通常组件化的处理机制,不加任何修饰及处理: 优点:代码精减.易于处理,小数据量处理较快: 缺点:大数据量的传递处理慢,消耗网络资源: 建议:当应用系统在内网.专网 ...

  8. sql性能优化(摘自网络)

    索引,索引!!!为经常查询的字段建索引!! 但也不能过多地建索引.insert和delete等改变表记录的操作会导致索引重排,增加数据库负担. 优化目标 1.减少 IO 次数 IO永远是数据库最容易瓶 ...

  9. VPS性能综合测试(6):UnixBench跑分工具测试

    测试时间可能会比较长,请耐心等待.最后UnixBench会详细列出各个测试项目的得分情况,以及VPS性能的综合跑分结果 UinxBench 的使用 使用方法如下: Run [ -q | -v ] [- ...

随机推荐

  1. Jackie's blog

    介绍使用winmm.h进行音频流的获取   首先需要包含以下引用对象 #include <Windows.h>#include "mmsystem.h"#pragma ...

  2. markdoen语法

    # 标题1 ## 标题2 ### 标题3 #### 标题4 ##### 标题5 ###### 标题6 1. 有序列表1 2. 有序列表2 <!--more--> + 无序列表 * 无序列表 ...

  3. 在 React Native 中使用 moment.js 無法載入語系檔案

    moment.js 是很常見的日期時間 library,友善的 API 與極佳的執行效率是它的兩大賣點.例如 (new Date()).getFullYear(),如果使用 moment.js 我可以 ...

  4. c#或者C#.net中的“ToolTip”是“System.Windows.Forms.ToolTip”和“DevComponents.DotNetBar.ToolTip”之间的不明确的引用

    “ToolTip”是“System.Windows.Forms.ToolTip”和“DevComponents.DotNetBar.ToolTip”之间的不明确的引用 ,在编程时,有时候会编译出现不明 ...

  5. Vue.js——学习笔记(一)

    Vue-自学笔记 Vue (读音 /vjuː/,类似于 view) 是一套用于构建用户界面的渐进式框架.与其它大型框架不同的是,Vue 被设计为可以自底向上逐层应用.Vue 的核心库只关注视图层,不仅 ...

  6. 达拉草201771010105《面向对象程序设计(java)》第七周学习总结

    达拉草201771010105<面向对象程序设计(java)>第七周学习总结 实验七继承附加实验 实验时间 2018-10-11 1.实验目的与要求 (1)进一步理解4个成员访问权限修饰符 ...

  7. 彻底消灭if-else嵌套

    一.背景 1.1 反面教材 不知大家有没遇到过像横放着的金字塔一样的if-else嵌套: if (true) { if (true) { if (true) { if (true) { if (tru ...

  8. 《Deep Learning of Graph Matching》论文阅读

    1. 论文概述 论文首次将深度学习同图匹配(Graph matching)结合,设计了end-to-end网络去学习图匹配过程. 1.1 网络学习的目标(输出) 是两个图(Graph)之间的相似度矩阵 ...

  9. c语言之单向链表

    0x00 什么是链表 链表可以说是一种最为基础的数据结构了,而单向链表更是基础中的基础.链表是由一组元素以特定的顺序组合或链接在一起的,不同元素之间在逻辑上相邻,但是在物理上并不一定相邻.在维护一组数 ...

  10. 7-5 jmu-python-分段函数1 (10 分)

    本题目要求计算下列分段函数f(x)的值(x为从键盘输入的一个任意实数): 输入格式: 直接输入一个实数给 x,没有其他任何附加字符. 输出格式: 在一行中按“f(x)=result”的格式输出,其中x ...