最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。

OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}

首先定义10个操作,依次解释:

  • class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
    """
    Args:
    pool_type: 'max' or 'avg'
    """
    super().__init__()
    if pool_type.lower() == 'max':
    self.pool = nn.MaxPool2d(kernel_size, stride, padding)
    elif pool_type.lower() == 'avg':
    self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
    else:
    raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
    out = self.pool(x)
    out = self.bn(out)
    return out

    这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去

  • class Identity(nn.Module):
    def __init__(self):
    super().__init__() def forward(self, x):
    return x

    这个表示skip conncet

  • class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise(stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
    super().__init__()
    self.relu = nn.ReLU()
    self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
    x = self.relu(x)
    out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out

    这个表示将特征图大小变为原来的一半

  • class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
    5x5 conv => 9x9 receptive field
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
    bias=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。

  • class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
    DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,由两个上面的深度分组卷积组成

  • class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
    nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    这个表示长方形的卷积,增加了一点特征图的长和宽

  • class Zero(nn.Module):
    def __init__(self, stride):
    super().__init__()
    self.stride = stride def forward(self, x):
    if self.stride == 1:
    return x * 0. # re-sizing by stride
    return x[:, :, ::self.stride, ::self.stride] * 0.

    这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变

DARTS代码分析(Pytorch)的更多相关文章

  1. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  2. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  3. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  4. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  5. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  6. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

  7. 常用 Java 静态代码分析工具的分析与比较

    常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...

  8. SonarQube-5.6.3 代码分析平台搭建使用

    python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...

  9. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

随机推荐

  1. SpringBoot入门系列:第五篇 JPA mysql(转)

    一,准备工作,建立spring-boot-sample-mysql工程1.http://start.spring.io/ A.Artifact中输入spring-boot-sample-mysql B ...

  2. 线程安全 Vs 非线程安全

    线程安全:多线程访问时,采用了加锁机制,当一个线程读取数据时,其他线程不能访问直到该线程读取完毕.不会出现数据不一致或者脏数据. 非线程安全:不提供数据保护,可能出现其他线程访问时更改数据而该线程得到 ...

  3. EntityManager的merge()方法

    EntityManager的merge()方法相当于hibernate中session的saveOrUpdate()方法: 用于实体的插入和更新操作:

  4. .Net中手动实现AOP

    序言 资料 https://www.cnblogs.com/farb/p/AopImplementationTypes.html

  5. 利用msyqlfont + plsql 客户端 完成msyql数据向oracle的转移

    方法一: 1.这是mysqlfont 连接工具 ,选中表右键点击 输出->csv文件 2.选择导出的文件为ANSI型,因为csv文件excel打开的默认编码方式为ANSI这样可以防止中文在exc ...

  6. The Semantics of Constructors(拷贝构造函数之编译背后的行为)

    本文是 Inside The C++ Object Model's Chapter 2  的部分读书笔记. 有三种情况,需要拷贝构造函数: 1)object直接为另外一个object的初始值 2)ob ...

  7. MessagePack Java Jackson Dataformat - POJO 的序列化和反序列化

    在本测试代码中,我们定义了一个 POJO 类,名字为 MessageData,你可以访问下面的链接找到有关这个类的定义. https://github.com/cwiki-us-demo/serial ...

  8. noi.ac NOI挑战营模拟赛1-5

    注:因为博主是个每次考试都爆零垫底的菜鸡,所以此篇博客很有可能咕咕咕 (指只贴AC代码不写题解的......如果我真的不会做的话,就不能怪我了qwqwq) Day1 T1 swap 23pts 从一个 ...

  9. 2.1 MATLAB的数据类型

    2.1 MATLAB的数据类型 每种数据类型都是以矩阵的形式存在的 数据类型:数值型.字符型.元胞型.结构体.函数句柄 数值型包含:双精度类型.单精度类型.整型 支持不同数据的转换 2.1.1 变量与 ...

  10. 石川es6课程---11、json

    石川es6课程---11.json 一.总结 一句话总结: ` 感觉更方便了一点,增加了一些简写 ` key-value 一样时可以简写:console.log({ a,b}}) ` 里面函数可以简写 ...