NASNet : Google Brain经典作,改造搜索空间,性能全面超越人工网络,继续领跑NAS领域 | CVPR 2018
论文将搜索空间从整体网络转化为卷积单元(cell),再按照设定堆叠成新的网络家族NASNet。不仅降低了搜索的复杂度,从原来的28天缩小到4天,而且搜索出来的结构具有扩展性,在小模型和大模型场景下都能使用更少的参数量和计算量来超越人类设计的模型,达到SOTA
来源:【晓飞的算法工程笔记】 公众号
论文: Learning Transferable Architectures for Scalable Image Recognition

Introduction
论文作者在ICLR 2017使用强化学习进行神经网络架构搜索获得了很好的表现,但该搜索方法需要计算资源很多,在CIFAR-10上需要800块GPU搜索28天,几乎不可能在大型数据集上进行搜索。因此,论文提出在代理数据集(proxy dataset)上进行搜索,然后将网络迁移到ImageNet中,主要亮点如下:
- 迁移的基础在于搜索空间的定义,由于常见的网络都是重复的结构堆叠而成的,论文将搜索空间从整个网络改成单元(cell),再按设定将单元堆叠成网络。这样做不仅搜索速度快,而且相对而言,单元结构通用性更高,可迁移
- 论文搜索到的最好结构称为NASNet,达到当时的SOTA,在CIFAR-10提升了2.4%top-1准确率,而迁移到ImageNet提升了1.2%
- 通过堆叠不同数量的单元(cell)以及修改单元中的卷积核数量,可以得到适应各种计算需求的NASNets,最小的NASNet在ImageNet top-1准确率为74.0%,比最好的移动端模型高3.1%
- NASNets学习到的图片特征十分有用,并且能够迁移到其它视觉任务中。Faster-RCNN使用最大的NASNets能直接提高4%,达到SOTA 43.1%mAP
Method

论文的神经网络搜索方法沿用了经典强化学习方法,具体可以看我之前的论文解读。流程如图1,简而言之就是使用RNN来生成网络结构,然后在数据集上进行训练,根据收敛后的准确率对RNN进行权重调整
论文的核心在于定义一个全新的搜索空间,称之为the NASNet search space。论文观察到目前优秀的网络结构,如ResNet和Inception,其实都是重复模块(cell)堆叠而成的,因此可以使用RNN来预测通用的卷积模块,这样的模块可以组合堆叠成一个系列模型,论文主要包含两种单元(cell):
- Normal Cell,卷积单元用来返回相同大小的特征图,
- Reduction Cell,卷积单元用来返回宽高缩小两倍的特征图

图2为CIFAR-10和ImageNet的网络框架,图片输入分别为32x32和299x299,Reduction Cell和Normal Cell可以为相同的结构,但论文发现独立的结构效果更好。当特征图的大小减少时,会手动加倍卷积核数量来大致保持总体特征点数量。另外,单元的重复次数N和初始的卷积核数量都是人工设定的,针对不同的分类问题

单元的结构在搜索空间内定义,首先选取前两个低层单元的输出$h_i$和$h_{i-1}$作为输入,然后the controller RNN预测剩余的卷积单元结构block,单个block预测如图3所示,每个单元(cell)由B个block组合成,每个block包含5个预测步骤,每个步骤由一个softmax分类器来选择对应的操作,block的预测如下:
- Step 1,在$h_i$,$h_{i-1}$和单元中之前的block输出中选择一个作为第一个隐藏层的输入
- Step 2,选择第二个隐藏层的输入,如Step 1
- Step 3,选择用于Step 1中的输入的操作
- Step 4,选择用于Step 2中的输入的操作
- Step 5,选择用于合并Step 3和Step 4输出的操作,并产生新的隐藏层,可供后面的block选择

Step 3和4中选择的操作包含了如上的一些主流的卷积网络操作,而Step 5的合并操作主要包含两种:1) element-wise addition 2) concatenation,最后,所有没有被使用的隐藏层输出会concatenated一起作为单元的输出。the controller RNN总共进行$2\times 5B$次预测,前$5B$作为Normal Cell,而另外$5B$则作为Reduction Cell
在RNN的训练方面,既可以用强化学习也可以用随机搜索,实验发现随机搜索仅比强化学习得到的网络稍微差一点,这意味着:
- NASNet的搜索空间构造得很好,因此随机搜索也能有好的表现
- 随机搜索是个很难打破的baseline
Experiments and Results
The controller RNN使用Proximal Policy Optimization(PPO)进行训练,以global workqueue形式对子网络进行分布式训练,实验总共使用500块P100来训练queue中的网络,整个训练花费4天,相比之前的版本800块K40训练28天,训练加速了7倍以上,效果也更好

图4为表现最好的Normal Cell和Reduction Cell的结构,这个结构在CIFAR-10上搜索获得的,然后迁移到ImageNet上。在获得卷积单元后,需要修改几个超参数来构建最终的网络,首先上单元重复数N,其次上初始单元的卷积核数,例如$4@64$为单元重复4次以及初始单元的卷积核数为64
对于搜索的细节可以查看论文的Appendix A,需要注意的是,论文提出DropPath的改进版ScheduledDropPath这一正则化方法。DropPath是在训练时以一定的概率随机丢弃单元的路径(如Figure 4中的黄色框连接的边),但在论文的case中不太奏效。因此,论文改用ScheduledDropPath,在训练过程中线性增加丢弃的概率
Results on CIFAR-10 Image Classification

NASNet-A结合随机裁剪数据增强达到了SOTA
Results on ImageNet Image Classification

论文将在CIFAR-10上学习到的结构迁移到ImageNet上,最大的模型达到了SOTA(82.7%),与SENet的准确率一致,但是参数量大幅减少

图5直观地展示了NASNet家族与其它人工构建网络的对比,NASNet各方面都比人工构建的网络要好

论文也测试了移动端配置的网络准确率,这里要求网络的参数和计算量要足够的小,NASNet依然有很抢眼的表现
Improved features for object detection

论文研究了NASNet在其它视觉任务中的表现,将NASNet作为Faster-RCNN的主干在COCO训练集上进行测试。对比移动端的网络,mAP达到29.6%mAP,提升了5.1%。而使用最好的NASNet,mAP则达到43.1%mAP,提升4.0%mAP。结果表明,NASNet能够提供更丰富且更通用的特征,从而在其它视觉任务也有很好的表现
Efficiency of architecture search methods

论文对比了网络搜索方法的性能,主要是强化学习方法(RL)和随机搜索方法(RS)。对于最好网络,RL搜索到的准确率整体要比RS的高1%,而对于整体表现(比如top-5和top-25),两种方法则比较接近。因此,论文认为尽管RS是可行的搜索策略,但RL在NASNet的搜索空间表现更好
CONCLUSION
论文基于之前使用强化学习进行神经网络架构搜索的研究,将搜索空间从整体网络转化为卷积单元(cell),再按照设定堆叠成新的网络NASNet。这样不仅降低了搜索的复杂度,加速搜索过程,从原来的28天缩小到4天,而且搜索出来的结构具有扩展性,分别在小模型和大模型场景下都能使用更少的参数量和计算量来超越人类设计的模型,达到SOTA
另外,由于搜索空间和模型结构的巧妙设计,使得论文能够将小数据集学习到的结构迁移到大数据集中,通用性更好。而且该网络在目标检测领域的表现也是相当不错的
Appendix NASNet-B & NASNet-C
论文还有另外两种结构NASNet-B和NASNet-C,其搜索空间和方法与NASNet-A有点区别,有兴趣的可以去看看原文的Appendix


如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

NASNet : Google Brain经典作,改造搜索空间,性能全面超越人工网络,继续领跑NAS领域 | CVPR 2018的更多相关文章
- 告别炼丹,Google Brain提出强化学习助力Neural Architecture Search | ICLR2017
论文为Google Brain在16年推出的使用强化学习的Neural Architecture Search方法,该方法能够针对数据集搜索构建特定的网络,但需要800卡训练一个月时间.虽然论文的思路 ...
- Google 发布的15个 Android 性能优化典范
2015年伊始,Google发布了关于Android性能优化典范的专题,一共16个短视频,每个3-5分钟,帮助开发者创建更快更优秀的Android App.课程专题不仅仅介绍了Android系统中有关 ...
- 用Google Brain的机器学习项目:Magenta,教神经网络学抖音小姐姐作曲。
先上我们要学习的小姐姐 的美照.. 一.配置环境 1.自己配置环境:python,tensorflow,bazel(编译),java.然后下载magenta(https://github.com/te ...
- 老子云AMRT全新三维格式正式上线,其性能全面超越现有的三维数据格式
9月16日,老子云AMRT全新三维格式正式上线,其性能远超现有的三维数据格式.目前已有含国家超算长沙中心.中科院空间所.中车集团等上百家政企事业单位的项目中使用了AMRT格式,大大提升了可视化项目的开 ...
- WebService - 怎样提高WebService性能 大数据量网络传输处理
直接返回DataSet对象 返回DataSet对象用Binary序列化后的字节数组 返回DataSetSurrogate对象用Binary序列化后的字节数组 返回DataSetSurrogate对象用 ...
- 共享式以太网与交换式以太网的性能比较(OPNET网络仿真实验)
一.实验目的 比较共享式以太网和交换式以太网在不同网络规模下的性能. 二.实验方法 使用opnet来创建和模拟网络拓扑,并运行分析其性能. 三.实验内容 3.1 实验设置(网络拓扑.参数设置. ...
- C# 之 提高WebService性能大数据量网络传输处理
1.直接返回DataSet对象 特点:通常组件化的处理机制,不加任何修饰及处理: 优点:代码精减.易于处理,小数据量处理较快: 缺点:大数据量的传递处理慢,消耗网络资源: 建议:当应用系统在内网.专网 ...
- sql性能优化(摘自网络)
索引,索引!!!为经常查询的字段建索引!! 但也不能过多地建索引.insert和delete等改变表记录的操作会导致索引重排,增加数据库负担. 优化目标 1.减少 IO 次数 IO永远是数据库最容易瓶 ...
- VPS性能综合测试(6):UnixBench跑分工具测试
测试时间可能会比较长,请耐心等待.最后UnixBench会详细列出各个测试项目的得分情况,以及VPS性能的综合跑分结果 UinxBench 的使用 使用方法如下: Run [ -q | -v ] [- ...
随机推荐
- UEditor问题整理
网上可以使用的富文本编辑器有很多,但是经过慎(sui)重(shou)思(yi)考(cha),选择了UEditor,毕竟是百度的东西,质量上应该经得起推敲,另外,使用别人的插件,总要去适应别人的编码习惯 ...
- 痞子衡嵌入式:恩智浦i.MX RT1xxx系列MCU启动那些事(11.2)- FlexSPI NOR连接方式大全(RT1060/1064(SIP))
大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是恩智浦i.MX RT1060/1064(SIP)两款MCU的FlexSPI NOR启动的连接方式. 上一篇文章<FlexSPI N ...
- 浅谈ConcurrentDictionary与Dictionary
在.NET4.0之前,如果我们需要在多线程环境下使用Dictionary类,除了自己实现线程同步来保证线程安全外,我们没有其他选择.很多开发人员肯定都实现过类似的线程安全方案,可能是通过创建全新的线程 ...
- mysql插入数据报错一二
上周selenium+phantomjs+python3简单爬取一个网站,往数据库写数据遇到以下两个问题,记录一下: 报错一:Data truncated for column 'update_tim ...
- 基于arduino的红外传感系统
一.作品背景 在这个科技飞速发展的时代,物联网已经成为了我们身边必不可少的技术模块,我这次课程设计做的是一个基于arduino+树莓派+OneNet的红外报警系统,它主要通过识别人或者动物的运动来判断 ...
- JavaScript实现队列结构(Queue)
JavaScript实现队列结构(Queue) 一.队列简介 队列是是一种受限的线性表,特点为先进先出(FIFO:first in first out). 受限之处在于它只允许在表的前端(front) ...
- py基础之有序列表
L =['adam',95.5,'lisa',85,'bart','bart',59]print (L)#list是一种有序的列表,可以使用索引访问每个list中的值print (L[1])#list ...
- 前端每日实战:26# 视频演示如何用不到 50 行 CSS 代码,创作按钮被从纸上掀起的立体效果
效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/KRbXGe 可交互视频教程 此视频 ...
- Linux进程操作记录
关于Gunicorn如何终止进程: 1.用进程树显示主进程PID: pstree -ap | grep gunicorn 2.如果有daemon进程无法用kill -9删除(可能是因为daemon屏蔽 ...
- Java Opencv 实现锐化
§ Laplacian() void cv::Laplacian ( InputArray src, O ...