Summary

本文提出超越神经架构搜索(NAS)的高效神经架构搜索(ENAS),这是一种经济的自动化模型设计方法,通过强制所有子模型共享权重从而提升了NAS的效率,克服了NAS算力成本巨大且耗时的缺陷,GPU运算时间缩短了1000倍以上。在Penn Treebank数据集上,ENAS实现了55.8的测试困惑度;在CIFAR-10数据集上,其测试误差达到了2.89%,与NASNet不相上下(2.65%的测试误差)

Research Objective 作者的研究目标

设计一种快速有效且耗费资源低的用于自动化网络模型设计的方法。主要贡献是基于NAS方法提升计算效率,使得各个子网络模型共享权重,从而避免低效率的从头训练。

Problem Statement 问题陈述,要解决什么问题?

本文提出的方法是对NAS的改进。NAS存在的问题是它的计算瓶颈,因为NAS是每次将一个子网络训练到收敛,之后得到相应的reward,再将这个reward反馈给RNN controller。但是在下一轮训练子网络时,是从头开始训练,而上一轮的子网络的训练结果并没有利用起来。

另外NAS虽然在每个节点上的operation设计灵活度较高,但是固定了网络的拓扑结构为二叉树。所以ENAS对于网络拓扑结构的设计做了改进,有了更高的灵活性。

Method(s) 解决问题的方法/算法

ENAS算法核心

回顾NAS,可以知道其本质是在一个大的搜索图中找到合适的子图作为模型,也可以理解为使用单个有向无环图(single directed acyclic graph, DAG)来表示NAS的搜索空间。

基于此,ENAS的DAG其实就是NAS搜索空间中所有可能的子模型的叠加。

下图给出了一个通用的DAG示例

如图示,各个节点表示本地运算,边表示信息的流动方向。图中的6个节点包含有多种单向DAG,而红色线标出的DAG则是所选择的的子图。

以该子图为例,节点1表示输入,而节点3和节点6因为是端节点,所以作为输出,一般是将而二者合并求均值后输出。

在讨论ENAS的搜索空间之前,需要介绍的是ENAS的测试数据集分别是CIFAR-10和Penn Treebank,前者需要通过ENAS生成CNN网络,后者则需要生成RNN网络。

所以下面会从生成RNN和生成CNN两个方面来介绍ENAS算法。

1.Design Recurrent Cells

本小节介绍如何从特定的DAG和controller中设计一个递归神经网络的cell(Section 2.1)?

首先假设共有\(N\)个节点,ENAS的controller其实就是一个RNN结构,它用于决定

  • 哪条边需要激活
  • DAG中每个节点需要执行什么样的计算

下图以\(N=4\)为例子展示了如何生成RNN。

假设\(x[t]\)为输入,\(h[t-1]\)表示上一个时刻的输出状态。

  • 节点1:由图可知,controller在节点1上选择的操作是tanh运算,所以有\(h_1=tanh(X_t·W^{(X)}+h_{t-1}·W_1^{(h)})\)
  • 节点2:同理有\(h_2 = ReLU(h_1·W_{2,1}^{(h)})\)
  • 节点3:\(h_3 = ReLU(h_2·W_{3,2}^{(h)})\)
  • 节点4:\(h_4 = ReLU(h_1·W_{4,1}^{(h)})\)
  • 节点3和节点4因为不是其他节点的输入,所以二者的平均值作为输出,即\(h_t=\frac{h_3+h_4}{2}\)

由上面的例子可以看到对于每一组节点\((node_i,node_j),i<j\),都会有对应的权重矩阵\(W_{j,i}^{(h)}\)。因此在ENAS中,所有的recurrent cells其实是在搜索空间中共享这样一组权重的。

但是我们可以很容易知道ENAS的搜索空间是非常庞大的,例如假设共有4中激活参数(tanh,identity,sigmoid,ReLU)可以选择,节点数为N,那么搜索空间大小则为\(4^N * N!\),如果N=12,那么就将近有\(10^{15}\)种模型参数。

2.1 Design Convolutional Networks

本小节解释如何设计卷积结构的搜索空间

回顾上面的Recurrent Cell的设计,我们知道controller RNN在每一个节点会做如下两个决定:a)该节点需要连接前面哪一个节点 b)使用何种激活函数。

而在卷积模型的搜索空间中,controller RNN也会做如下两个觉得:a)该节点需要连接前面哪一个节点 b)使用何种计算操作。

在卷积模型中,(a)决定 (连接哪一个节点) 其实就是skip connections。(b)决定一共有6种选择,分别是3*3和5*5大小的卷积核、3*3和5*5大小的深度可分离卷积核,3*3大小的最大池化和平均池化。

下图展示了卷积网络的生成示意图。

2.2 Design Convolutional Cell

本文并没有采用直接设计完整的卷积网络的方法,而是先设计小型的模块然后将模块连接以构建完整的网络(Zoph et al., 2018)。

下图展示了这种设计的例子,其中设计了卷积单元和 reduction cell。

接下来将讨论如何利用 ENAS 搜索由这些单元组成的架构。

假设下图的DAG共有\(B\)个节点,其中节点1和节点2是输入,所以controller只需要对剩下的\(B-2\)个节点都要做如下两个决定:a)当前节点需要与那两个节点相连 b)所选择的两个节点需要采用什么样的操作。(可选择的操作有5种:identity(id,相等),大小为3*3或者5*5的separate conv(sep),大小为3*3的最大池化。)

可以看到对于节点3而言,controller采样的需要连接的两个节点都是节点2,两个节点预测的操作分别是sep 5*5和identity。

3.Training ENAS and Deriving Architectures

本小节介绍如何训练ENAS以及如何从ENAS的controller中生成框架结构。(Section 2.2)

controller网络是含有100个隐藏节点的LSTM。LSTM通过softmax分类器做出选择。另外在第一步时controller会接收一个空的embedding作为输入。

在ENAS中共有两组可学习的参数:

  • 子网络模型的共享参数,用\(w\)表示。
  • controller网络(即LSTM网络参数),用\(θ\)表示。

而训练ENAS的步骤主要包含两个交叉阶段:第一部训练子网络的共享参数\(w\);第二个阶段是训练controller的参数\(θ\)。这两个阶段在ENAS的训练过程中交替进行,具体介绍如下:

子网络模型共享参数\(w\)的训练

在这个步骤中,首先固定controller的policy network,即\(π(m;θ)\)。之后对\(w\)使用SGD算法来最小化期望损失函数\(E_{m~π}[L(m;w)]\)。

其中\(L(m;w)\)是标准的交叉熵损失函数:\(m\)表示根据policy network \(π(m;θ)\)生成的模型,然后用这个模型在一组训练数据集上计算得到的损失值。

根据Monte Carlo估计计算梯度公式如下:

\[\nabla_w E_{m-~π}(m;θ)[L(m;w)] ≈ \frac{1}{M} \sum_i^M \nabla_wL(m_i;w) \]

其中上式中的\(m_i\)表示由\(π(m;θ)\)生成的M个模型中的某一个模型。

虽然上式给出了梯度的无偏估计,但是方差比使用SGD得到的梯度的方差大。但是当\(M=1\)时,上式效果还可以。

训练controller参数θ

在这个步骤中,首先固定\(w\),之后通过求解最大化期望奖励\(E_{m~π}[R(m;w)]\)来更新\(θ\)。其中在语言模型实验中\(R(m;w)=\frac{c}{valid\_ppl}\),perplexity是基于小批量验证集计算得到的;在分类模型试验中,\(R(m;w)\)等于基于小批量验证集的准确率。

导出模型架构

首先使用\(π(m,θ)\)生成若干模型。

之后对于每一个采样得到的模型,直接计算其在验证集上得到的奖励。

最后选择奖励最高的模型再次从头训练。

当然如果像NAS那样把所有采样得到的子模型都先从头训练一边,也许会对实验结果有所提升。但是ENAS之所以Efficient,就是因为它不用这么做,原理继续看下文。

Evaluation 评估方法

1.在 Penn Treebank 数据集上训练的语言模型

2.在 CIFAR-10 数据集上的图像分类实验

由上表可以看出,ENAS的最终结果不如NAS,这是因为ENAS没有像NAS那样从训练后的controller中采样多个模型架构,然后从中选出在验证集上表现最好的一个。但是即便效果不如NAS,但是ENAS效果并不差太多,而且训练效率大幅提升。

下图是生成的宏观搜索空间。

ENAS 用了 11.5 个小时来发现合适的卷积单元和 reduction 单元,如下图所示。

Conclusion

ENAS能在Penn Treebank和CIFAR-10两个数据集上得到和NAS差不多的效果,而且训练时间大幅缩短,效率大大提升。

MARSGGBO♥原创







2018-8-7

论文笔记系列-Efficient Neural Architecture Search via Parameter Sharing的更多相关文章

  1. 论文笔记:Fast Neural Architecture Search of Compact Semantic Segmentation Models via Auxiliary Cells

    Fast Neural Architecture Search of Compact Semantic Segmentation Models via Auxiliary Cells 2019-04- ...

  2. 论文笔记:Progressive Neural Architecture Search

    Progressive Neural Architecture Search 2019-03-18 20:28:13 Paper:http://openaccess.thecvf.com/conten ...

  3. 论文笔记系列-DARTS: Differentiable Architecture Search

    Summary 我的理解就是原本节点和节点之间操作是离散的,因为就是从若干个操作中选择某一个,而作者试图使用softmax和relaxation(松弛化)将操作连续化,所以模型结构搜索的任务就转变成了 ...

  4. (转)Illustrated: Efficient Neural Architecture Search ---Guide on macro and micro search strategies in ENAS

    Illustrated: Efficient Neural Architecture Search --- Guide on macro and micro search strategies in  ...

  5. 论文笔记:DARTS: Differentiable Architecture Search

    DARTS: Differentiable Architecture Search 2019-03-19 10:04:26accepted by ICLR 2019 Paper:https://arx ...

  6. 论文笔记:Progressive Differentiable Architecture Search:Bridging the Depth Gap between Search and Evaluation

    Progressive Differentiable Architecture Search:Bridging the Depth Gap between Search and Evaluation ...

  7. 【论文笔记系列】AutoML:A Survey of State-of-the-art (下)

    [论文笔记系列]AutoML:A Survey of State-of-the-art (上) 上一篇文章介绍了Data preparation,Feature Engineering,Model S ...

  8. Research Guide for Neural Architecture Search

    Research Guide for Neural Architecture Search 2019-09-19 09:29:04 This blog is from: https://heartbe ...

  9. 论文笔记系列-Auto-DeepLab:Hierarchical Neural Architecture Search for Semantic Image Segmentation

    Pytorch实现代码:https://github.com/MenghaoGuo/AutoDeeplab 创新点 cell-level and network-level search 以往的NAS ...

随机推荐

  1. BZOJ1915[USACO 2010 Open Gold 1.Cow Hopscotch]——DP+斜率优化

    题目描述 奶牛们正在回味童年,玩一个类似跳格子的游戏,在这个游戏里,奶牛们在草地上画了一行N个格子,(3 <=N <= 250,000),编号为1..N.就像任何一个好游戏一样,这样的跳格 ...

  2. POJ - 1062(昂贵的聘礼)(有限制的spfa最短路)

    题意:...中文题... 昂贵的聘礼 Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 54350   Accepted: 16 ...

  3. 码云平台IDEA系列的插件使用

    一.IDEA插件安装 file -- setting --  Plugins -- 搜索gitee --  Search in repositories 安装后重启编译器 二.登录并拉取项目 file ...

  4. 【题解】 bzoj2006: [NOI2010]超级钢琴 (ST表+贪心)

    题面戳我 Solution 不会,看的题解 Attention 哇痛苦,一直不会打\(ST\)表,我是真的菜啊qwq 预处理 Log[1]=0;two[0]=1; for(int i=2;i<= ...

  5. 自学Linux Shell9.3-基于Red Hat系统工具包:RPM属性依赖的解决方式-YUM在线升级

    点击返回 自学Linux命令行与Shell脚本之路 9.3-基于Red Hat系统工具包:RPM属性依赖的解决方式-YUM在线升级 本节主要介绍基于Red Had的系统(测试系统centos) yum ...

  6. BZOJ 3526: [Poi2014]Card

    3526: [Poi2014]Card Time Limit: 25 Sec  Memory Limit: 64 MBSubmit: 267  Solved: 191[Submit][Status][ ...

  7. LOCALDB安装和连接

    关于LOCALDB的详细文档说明,包含安装,连接,共享连接等操作  https://technet.microsoft.com/zh-cn/hh510202 目的: 调试程序没有安装 sql serv ...

  8. Linux 常用命令——cat, tac, nl, more, less, head, tail, od

    Drecik学习经验分享 转载请注明出处:http://blog.csdn.net/drecik__/article/details/8453584 1. cat 由第一行开始显示文件内容 2. ta ...

  9. 【洛谷P4113】采花 HH的项链+

    题目大意:静态统计序列区间中出现次数大于等于 2 的颜色数. 题解:类似于HH的项链,只需将 i 和 pre[i] 的关系对应到 pre[i] 和 pre[pre[i]] 的关系即可. 代码如下 #i ...

  10. 【POJ1187】陨石的秘密

    题目大意: 定义一个串:只含有 '( )','[ ]','{ }',3种(6个)字符. 定义 SS 串: 空串是SS表达式. 若A是SS表达式,且A串中不含有中括号和大括号,则(A)是SS表达式. 若 ...