作者对residual network进行了改进:加入了gating network,基于上一层的激活值,得到一个二进制的决策0或1,从而继续推断或跳过下一个block。作者还提出了对应的训练方法,集成有监督学习和强化学习,从而克服了skipping不可差分的问题。

1. 概括

难点:skipping决策是不可差分的,那么就无法用基于梯度的优化方法进行学习。

  1. [2,30,31]提出了软近似,但实验发现它们的精度很差。

    We show that the subsequent hard thresholding required to reduce computation results in low accuracy.

  2. [4,23]则提出用强化学习解决硬判决问题,但实验发现它们很脆弱,即精度也很差。

  3. [16,21]还采用了reparametrization技术,但其中的松弛会引入估计误差,导致习得策略欠佳。

训练方法大致分为2步:

  1. 借助reparameterization和soft-max松弛,同时训练网络和门限。

  2. 取消松弛,借助强化学习,继续精炼skipping政策。

实验结果:在CIFAR-10、CIFAR-100、SVHN和ImageNet上,分别能降低50%、37%、86%和30%的计算量。并且,SkipNet也存在一个超参数,可以针对不同计算量约束进行调节。

2. 相关工作

为了实现模型压缩,大多数工作集中在参数稀疏化、滤波器剪枝,向量量化和蒸馏。这些方法的共同问题:

  1. 通常是后处理,即对已经训练好的网络执行的操作。

  2. 并不能根据输入动态调整网络。

还有一些工作[6,8,29]通过提前终止实现这一目标。其中[8]是暂停循环过程,[6]和[29]是提前终止CNN。但本文的SkipNet是跳过而不是提前终止。

还有一些工作[1,22,32]集成了不同计算复杂度的多个模型,并设计决策机制或终止机制。但是这样做严重浪费了存储,并且每个模型并不存在计算共享。

3. 方法细节

基本方法就是:在ResNet的基础上,加入了门限网络。其将上一层的输入映射至0或1,从而跳过或执行下一层。

注意:要求输入、输出的维度相同。而ResNet的块结构正好满足这一要求。或者要采用池化等操作。

门限模块的结构

作者尝试了三种结构:

前两种都是CNN结构,第三种是RNN结构。第一种计算量大,作者只用于浅层网络;第二种计算量小,作者用于超过百层的网络。在后续实验中,作者发现循环网络效果最好,不仅计算量远小,而且精度也高。这归功于其时序学习能力。

训练方法

最简单直接的方法就是用softmax软化(例如Highway Networks),使得网络参数能够差分;而在推导(测试)时再用硬判决。但实验发现其精度很差,原因是其中存在误差。

作者决定在训练阶段保留硬判决。现在我们分析损失函数。假设第\(i\)层的输入是\(\mathbf{X}^i\),门模块是\(G^i(\mathbf{X}^i)\),判决结果是\(g_i\)。\(g_i = 1\)时,该层执行;\(g_i = 0\)时,该层被跳过(输入直接恒等映射至输出)。一共\(N\)层,则总判决为\(\mathbf{g} = \{0, 1\}^N\)。

假设网络每一层参数的集合(包括门模块)为:\(F_{\theta} = [F_{\theta}^1, ..., F_{\theta}^N]\)。在给定\(\mathbf{X}\)和\(\mathbf{g}\)的情况下,损失为:
\[
L_{\theta}(\mathbf{g}, \mathbf{X}) = \mathcal{L}(\hat{y}(\mathbf{X}, F_{\theta}, g), y) - \frac{\alpha}{N} \sum_{i=1}^N (1 - g_i) C_i
\]

前半部分应该是有监督学习中的保真度(fidelity)或者准确度指标之类的【作者没提】,后者惩罚的是计算量。其中\(C_i\)用来调节\(F_i\)的重要性【注意负号】,作者设恒为1。\(\alpha\)是权衡计算量和精度的超参数。

进一步,右半部分可以视为强化学习中的奖励(reward)。

我们的训练目标严格写是这样:
\[
\min \mathcal{J}(\theta) = \min \mathbb{E}_{\mathbf{X}} \mathbb{E}_{\mathbf{g}} L_{\theta}(\mathbf{g}, \mathbf{X})
\]

即:对训练集中的所有样本取统计平均(一般就是平权,因为假设i.i.d.),对所有可能的判决集结果取统计平均,并最终实现 最小化误差的同时 最小化计算量。二者相对重要性由\(\alpha\)调控。

我们也可以看看该训练目标函数的梯度。注意梯度是关于参数\(\theta\)的梯度:

第二步的右半部分是这样的,熟悉RL的同学都很清楚:
\[
\nabla_{\theta} \log p_{\theta}(\mathbf{g} | \mathbf{X}) = \frac{1}{p_{\theta}(\mathbf{g} | \mathbf{X})} \nabla_{\theta} p_{\theta}(\mathbf{g} | \mathbf{X}) =
\]

对于最终结果,左半部分就可以看作监督学习损失函数的梯度,右半部分就可以看作强化学习损失的梯度。其中:
\[
r_i = - [\mathcal{L} - \frac{\alpha}{N} \sum_{j=i}^N R_j]
\]

在实际操作中,我们降低对精度的要求,给前半部分加一个超参数:
\[
r_i = - [\beta \mathcal{L} - \frac{\alpha}{N} \sum_{j=i}^N R_j]
\]

作者设\(\beta = \frac{\alpha}{N}\)或1。

实际上,分两个部分分别训练是不完美的,但是一个折衷的处理方式。作者首先使用监督学习,让网络参数初步收敛。然后再采用强化学习。实验发现,如果直接将上式作为强化学习的激励,那么训练效果会很不好。原因可能是学习的策略过早收敛于垃圾特征。

实验有几个有趣的发现:

  1. 简单的样本(跳过层数多)偏亮,清晰,对比度高:

  2. 越大尺度的图像平均需要块越多(可能因为感受野不够):

  3. 前面层和后面层被跳过比较频繁,中间层跳过率很低。

  4. 有监督预训练为强化学习提供了很好的起点。

  5. 在计算量相同的情况下,硬判决的精度远高于软判决。

4. 总结

优点:不同于提前退出,这种方法比较新。

不足:每一层或块的输入、输出维度必须相同,否则无法执行跳过判决(跳过或执行的输出维度必须得一致)。或者需要池化等额外操作。

Paper | SkipNet: Learning Dynamic Routing in Convolutional Networks的更多相关文章

  1. FlowNet: Learning Optical Flow with Convolutional Networks

    作者:嫩芽33出处:http://www.cnblogs.com/nenya33/p/7122701.html 版权:本文版权归作者和博客园共有 转载:欢迎转载,但未经作者同意,必须保留此段声明:必须 ...

  2. 论文翻译——Character-level Convolutional Networks for Text Classification

    论文地址 Abstract Open-text semantic parsers are designed to interpret any statement in natural language ...

  3. (原)DropBlock A regularization method for convolutional networks

    转载请注明出处: https://www.cnblogs.com/darkknightzh/p/9985027.html 论文网址: https://arxiv.org/abs/1810.12890 ...

  4. 论文笔记:Learning Dynamic Memory Networks for Object Tracking

    Learning Dynamic Memory Networks for Object Tracking  ECCV 2018Updated on 2018-08-05 16:36:30 Paper: ...

  5. Paper | Densely Connected Convolutional Networks

    目录 黄高老师190919在北航的报告听后感 故事背景 网络结构 Dense block DenseNet 过渡层 成长率 瓶颈层 细节 实验 发表在2017 CVPR. 摘要 Recent work ...

  6. Hinton's paper Dynamic Routing Between Capsules 的 Tensorflow , Keras ,Pytorch实现

    Tensorflow 实现 A Tensorflow implementation of CapsNet(Capsules Net) in Hinton's paper Dynamic Routing ...

  7. Deep Learning 33:读论文“Densely Connected Convolutional Networks”-------DenseNet 简单理解

    一.读前说明 1.论文"Densely Connected Convolutional Networks"是现在为止效果最好的CNN架构,比Resnet还好,有必要学习一下它为什么 ...

  8. How to do Deep Learning on Graphs with Graph Convolutional Networks

    翻译: How to do Deep Learning on Graphs with Graph Convolutional Networks 什么是图卷积网络 图卷积网络是一个在图上进行操作的神经网 ...

  9. 模型压缩-Learning Efficient Convolutional Networks through Network Slimming

    Zhuang Liu主页:https://liuzhuang13.github.io/ Learning Efficient Convolutional Networks through Networ ...

随机推荐

  1. 洛谷 P2657 (数位DP)

    ### 洛谷 P2657 题目链接 ### 题目大意:给你一个数的范围 [A,B] ,问你这段区间内,有几个数满足如下条件: 1.两个相邻数位上的数的差值至少为 2 . 2.不包含前导零. 很简单的数 ...

  2. Dubbo+Zookeeper实现简单的远程方法调用示例

    1. Dubbo介绍 示例代码:Github 1.1 RPC Remote Procedure Call:远程过程调用 1.2 Dubbo架构 Subscribe 订阅:签署:赞成 Monitor 监 ...

  3. Mysql优化之Explain查询计划查看

    我们经常说到mysql优化,优化中一种常见的方式就是对于经常查询的字段创建索引.那么mysql中有哪些索引类型呢? 一.索引分类1.普通索引:即一个索引只包含单个列,一个表可以有多个单列索引 2.唯一 ...

  4. MySQL(10)---自定义函数

    MySQL(10)---自定义函数 之前讲过存储过程,存储过程和自定义函数还是非常相似的,其它的可以认为和存储过程是一样的,比如含义,优点都可以按存储过程的优点来理解. 存储过程相关博客: 1.MyS ...

  5. C#关闭多线程程序

    Process[] processes = System.Diagnostics.Process.GetProcesses(); //获得所有进程 foreach (Process p in proc ...

  6. C# read dll config

    public static SqlConnection GetSqlConnection() { Configuration myDllConfig = ConfigurationManager.Op ...

  7. 《Head First C#》外星人入侵WPF编写源码

    目录 引言 前期工作 只要努力没有什么困难可以难倒你,加油骚年! @(外星人入侵(WPF编写)) 引言 自学的C#,看了几本教材讲的都是程序代码,网上找的也有视屏,但都比较老了.只会打些代码为不晓得为 ...

  8. C# 校验并转换 16 进制字符串到字节数组

    问题 最近在进行硬件上位机开发的时候,经常会遇到将 16 进制字符串转换为 byte[] 的情况,除了这种需求以外,还需要判定一个字符串是否是有效的 16 进制数据. 解决 字符串转 byte[] 的 ...

  9. 湖南省web应用软件(中慧杯)

    湖南省web应用软件 写这篇博客已经是比完赛的第四天了,我还记得那天下着小雨.我们早早的到了比赛的现场抽检机器,在比赛前一天我很是激动.我还记得我们从学校,去株洲的时候我们的领导来给我加油,特别是我的 ...

  10. sql 小全

    前些日子sql用到哪里写到哪里,乱七八糟,今天整理了一下,以作备份(虽然开通博客已经八个月了,但是今天还是第一次发表博文,好紧张啊~~) --2014.08.27号整理sql语句 1:进入数据库 us ...