deeplearning搜索空间

搜索空间是神经网络搜索中的一个概念。搜索空间是一系列模型结构的汇集, SANAS主要是利用模拟退火的思想在搜索空间中搜索到一个比较小的模型结构或者一个精度比较高的模型结构。

paddleslim.nas 提供的搜索空间

根据初始模型结构构造搜索空间:

  1. MobileNetV2Space

      MobileNetV2的网络结构
  2. MobileNetV1Space

      MobilNetV1的网络结构
  3. ResNetSpace

      ResNetSpace的网络结构

根据相应模型的block构造搜索空间:

  1. MobileNetV1BlockSpace

      MobileNetV1Block的结构
  2. MobileNetV2BlockSpace

      MobileNetV2Block的结构
  3. ResNetBlockSpace

      ResNetBlock的结构
  4. InceptionABlockSpace

      InceptionABlock的结构
  5. InceptionCBlockSpace

      InceptionCBlock结构

搜索空间使用示例

  1. 使用paddleslim中提供用初始的模型结构来构造搜索空间的话,仅需要指定搜索空间名字即可。例如:如果使用原本的MobileNetV2的搜索空间进行搜索的话,传入SANAS中的configs直接指定为[(‘MobileNetV2Space’)]。
  2. 使用paddleslim中提供的block搜索空间构造搜索空间:

    2.1 使用input_size, output_size和block_num来构造搜索空间。例如:传入SANAS的configs可以指定为[(‘MobileNetV2BlockSpace’,
    {‘input_size’: 224, ‘output_size’: 32, ‘block_num’: 10})]。

    2.2 使用block_mask构造搜索空间。例如:传入SANAS的configs可以指定为[(‘MobileNetV2BlockSpace’,
    {‘block_mask’: [0, 1, 1, 1, 1, 0, 1, 0]})]。

自定义搜索空间(search space)

自定义搜索空间类需要继承搜索空间基类并重写以下几部分:

  1. 初始化的tokens(init_tokens函数),可以设置为自己想要的tokens列表, tokens列表中的每个数字指的是当前数字在相应的搜索列表中的索引。例如本示例中若tokens=[0, 3, 5],则代表当前模型结构搜索到的通道数为[8, 40, 128]。

  2. tokens中每个数字的搜索列表长度(range_table函数),tokens中每个token的索引范围。

  3. 根据tokens产生模型结构(token2arch函数),根据搜索到的tokens列表产生模型结构。

以新增reset block为例说明如何构造自己的search space。自定义的search space不能和已有的search space同名。

### 引入搜索空间基类函数和search space的注册类函数

from .search_space_base import SearchSpaceBase

from .search_space_registry import SEARCHSPACE

import numpy as np

### 需要调用注册函数把自定义搜索空间注册到space space

@SEARCHSPACE.register

### 定义一个继承SearchSpaceBase基类的搜索空间的类函数

class ResNetBlockSpace2(SearchSpaceBase):

def __init__(self, input_size, output_size, block_num, block_mask):

### 定义一些实际想要搜索的内容,例如:通道数、每个卷积的重复次数、卷积核大小等等

### self.filter_num
代表通道数的搜索列表

self.filter_num = np.array([8, 16, 32, 40, 64, 128, 256, 512])

### 定义初始化token,初始化token的长度根据传入的block_num或者block_mask的长度来得到的

def init_tokens(self):

return [0] * 3 * len(self.block_mask)

### 定义tokenindex的取值范围

def range_table(self):

return [len(self.filter_num)] * 3 * len(self.block_mask)

### token转换成模型结构

def token2arch(self, tokens=None):

if tokens == None:

tokens = self.init_tokens()

self.bottleneck_params_list = []

for i in range(len(self.block_mask)):

self.bottleneck_params_list.append(self.filter_num[tokens[i * 3 + 0]],

self.filter_num[tokens[i * 3 + 1]],

self.filter_num[tokens[i * 3 + 2]],

2 if self.block_mask[i] == 1 else 1)

def net_arch(input):

for i, layer_setting in enumerate(self.bottleneck_params_list):

channel_num, stride = layer_setting[:-1], layer_setting[-1]

input = self._resnet_block(input, channel_num, stride, name='resnet_layer{}'.format(i+1))

return input

return net_arch

### 构造具体block的操作

def _resnet_block(self, input, channel_num, stride, name=None):

shortcut_conv = self._shortcut(input, channel_num[2], stride, name=name)

input = self._conv_bn_layer(input=input, num_filters=channel_num[0], filter_size=1, act='relu', name=name + '_conv0')

input = self._conv_bn_layer(input=input, num_filters=channel_num[1], filter_size=3, stride=stride, act='relu', name=name + '_conv1')

input = self._conv_bn_layer(input=input, num_filters=channel_num[2], filter_size=1, name=name + '_conv2')

return fluid.layers.elementwise_add(x=shortcut_conv, y=input, axis=0, name=name+'_elementwise_add')

def _shortcut(self, input, channel_num, stride, name=None):

channel_in = input.shape[1]

if channel_in != channel_num or stride != 1:

return self.conv_bn_layer(input, num_filters=channel_num, filter_size=1, stride=stride, name=name+'_shortcut')

else:

return input

def _conv_bn_layer(self, input, num_filters, filter_size, stride=1, padding='SAME', act=None, name=None):

conv = fluid.layers.conv2d(input, num_filters, filter_size, stride, name=name+'_conv')

bn = fluid.layers.batch_norm(conv, act=act, name=name+'_bn')

return bn

deeplearning搜索空间的更多相关文章

  1. deeplearning算法优化原理

    deeplearning算法优化原理目录· 量化原理介绍 · 剪裁原理介绍 · 蒸馏原理介绍 · 轻量级模型结构搜索原理介绍 1. Quantization Aware Training量化介绍1.1 ...

  2. deeplearning模型库

    deeplearning模型库 1. 图像分类 数据集:ImageNet1000类 1.1  量化 分类模型Lite时延(ms) 设备 模型类型 压缩策略 armv7 Thread 1 armv7 T ...

  3. DeepLearning之路(三)MLP

    DeepLearning tutorial(3)MLP多层感知机原理简介+代码详解 @author:wepon @blog:http://blog.csdn.net/u012162613/articl ...

  4. DeepLearning之路(二)SoftMax回归

    Softmax回归   1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...

  5. 用中文把玩Google开源的Deep-Learning项目word2vec

    google最近新开放出word2vec项目,该项目使用deep-learning技术将term表示为向量,由此计算term之间的相似度,对term聚类等,该项目也支持phrase的自动识别,以及与t ...

  6. Deeplearning原文作者Hinton代码注解

    [z]Deeplearning原文作者Hinton代码注解 跑Hinton最初代码时看到这篇注释文章,很少细心,待研究... 原文地址:>http://www.cnblogs.com/BeDPS ...

  7. Google开源的Deep-Learning项目word2vec

    用中文把玩Google开源的Deep-Learning项目word2vec   google最近新开放出word2vec项目,该项目使用deep-learning技术将term表示为向量,由此计算te ...

  8. DeepLearning.ai学习笔记(一)神经网络和深度学习--Week3浅层神经网络

    介绍 DeepLearning课程总共五大章节,该系列笔记将按照课程安排进行记录. 另外第一章的前两周的课程在之前的Andrew Ng机器学习课程笔记(博客园)&Andrew Ng机器学习课程 ...

  9. DeepLearning.ai学习笔记汇总

    第一章 神经网络与深度学习(Neural Network & Deeplearning) DeepLearning.ai学习笔记(一)神经网络和深度学习--Week3浅层神经网络 DeepLe ...

随机推荐

  1. 源码篇:ThreadLocal的奇思妙想(万字图文)

    前言 ThreadLocal的文章在网上也有不少,但是看了一些后,理解起来总感觉有绕,而且看了ThreadLocal的源码,无论是线程隔离.类环形数组.弱引用结构等等,实在是太有意思了!我必须也要让大 ...

  2. 播放视频插件swfobject.js与Video Html5

    播放视频的方法: 方法一. 使用HTML5播放 <video src="./files/Clip_480_5sec_6mbps_h264.mp4" width="1 ...

  3. 【SpringBoot】SpringBoot 处理后端返回的小数(全局配置 + 定制化配置)

    一.抛出问题: 现在的项目中,存在这样的几个问题: 问题一.数据库存的数据类型是BigDecimal,或者代码中计算需要返回BigDecimal的值,由于BigDecimal返回给前端可能存在精度丢失 ...

  4. POJ1422 最小路径覆盖

    题意:      一个战场,往战场上投放伞兵,每个伞兵不能后退,只能往前走,问你最少多少个伞兵可以吧所有的点都占领. 思路:      这个题是最小路径覆盖,最小路径覆盖 = n - 最大匹配数,首先 ...

  5. CVE-2011-0104:Microsoft Office Excel 栈溢出漏洞修复分析

    0x01 前言 上一篇讲到了 CVE-2011-0104 漏洞的成因和分析的方法,并没有对修复后的程序做分析.之后在一次偶然的情况下,想看一看是怎么修复的,结果却发现了一些问题 环境:修复后的 EXC ...

  6. composer update -- memory_limit

    compsoer update取消memory_limit限制.取消扩展对于版本的限制 php -d memory_limit=-1 ./composer.phar update --ignore-p ...

  7. 远程连接mysql出现"Can't connect to MySQL server 'Ip' ()"的解决办法

    1.大多是防火墙的问题(参考链接:https://blog.csdn.net/jiezhi2013/article/details/50603366) 2.上面方法不能解决,不造成影响情况下可关闭防火 ...

  8. 用fseek和ftell获取文件的大小

    #include <stdio.h> #include <stdlib.h> #include <unistd.h> int main(int argc,char ...

  9. OO随笔之纠结的第二单元——多线程电梯

    综述 主要任务就是写一个电梯模拟器,读入每一个人的请求然后让电梯把他们送到想去的地方. 从第一次到第三次作业,三次的主要任务都是相同的,但是每次都增加了很多的细节,每次的难度都逐步增长,电梯复杂度和瞎 ...

  10. xxl-job的一些感悟与规范

    后台计划任务设计思路: 日志埋点处理,便于prd排查问题 2种主动job搭配规范(正向job.反查job) 1种消息接收的处理规范,重试机制,返回状态 job开关维度 数据流图 线上暗job-便捷性- ...