最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。

OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}

首先定义10个操作,依次解释:

  • class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
    """
    Args:
    pool_type: 'max' or 'avg'
    """
    super().__init__()
    if pool_type.lower() == 'max':
    self.pool = nn.MaxPool2d(kernel_size, stride, padding)
    elif pool_type.lower() == 'avg':
    self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
    else:
    raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
    out = self.pool(x)
    out = self.bn(out)
    return out

    这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去

  • class Identity(nn.Module):
    def __init__(self):
    super().__init__() def forward(self, x):
    return x

    这个表示skip conncet

  • class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise(stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
    super().__init__()
    self.relu = nn.ReLU()
    self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
    x = self.relu(x)
    out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out

    这个表示将特征图大小变为原来的一半

  • class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
    5x5 conv => 9x9 receptive field
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
    bias=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。

  • class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
    DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,由两个上面的深度分组卷积组成

  • class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
    nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    这个表示长方形的卷积,增加了一点特征图的长和宽

  • class Zero(nn.Module):
    def __init__(self, stride):
    super().__init__()
    self.stride = stride def forward(self, x):
    if self.stride == 1:
    return x * 0. # re-sizing by stride
    return x[:, :, ::self.stride, ::self.stride] * 0.

    这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变

DARTS代码分析(Pytorch)的更多相关文章

  1. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  2. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  3. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  4. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  5. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  6. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

  7. 常用 Java 静态代码分析工具的分析与比较

    常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...

  8. SonarQube-5.6.3 代码分析平台搭建使用

    python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...

  9. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

随机推荐

  1. li元素之间产生间隔

    是因为li标签换行导致的 简单的解决办法是将所有的li标签写到一行(不过实际上一般不会这样做) 或者把ul设置font-size为0,但这样ul中的文字就会消失,所以要记得单独给子元素设置font-s ...

  2. perl 数组变量(Array) 转载

    Perl 变量(2)--数组 原文地址:Perl 变量(2)--数组 作者:飞鸿无痕 二.数组 数组是标量数据的有序列表. 数组可以含任意多个元素.最小的数组可以不含元素,而最大的数组可以占满全部可用 ...

  3. fish-redux快速创建文件夹模板 FishReduxTemplate

    推荐一款插件: 在插件plugins中搜  FishReduxTemplate

  4. JavaScript 小技巧整理

    1.过滤唯一值 Set类型是在ES6中新增的,它类似于数组,但是成员的值都是唯一的,没有重复的值.结合扩展运算符(...)我们可以创建一个新的数组,达到过滤原数组重复值的功能. const array ...

  5. Java当中的集合框架

    Java当中的集合框架 01 在我们班里有50位同学,就有50位对象. // 简书作者:达叔小生 Student[] stus = new Student[20]; 结果来了一位插班生,该同学因为觉得 ...

  6. hive连接hbase

    使用hive连接hbase 前提说明:一个hive表指向一个hbase表,一对一,不能多对一 建立外部表 CREATE EXTERNAL TABLE test_hbase( key string, m ...

  7. Java file.encoding

    1. file.encoding属性的作用 file.encoding 的值是整个程序使用的编码格式. 可以使用  System.out.println(System.getProperty(&quo ...

  8. win10 c++程序打包

    步骤如下: 1. 先动态编译连链接,生成exe: 2. 找到exe依赖的dll文件 使用Process Explore来获取所依赖的dll文件 打开procexp.exe,通过菜单View–Lower ...

  9. 发送http请求和https请求的工具类

    package com.haiyisoft.cAssistant.utils; import java.io.IOException;import java.util.ArrayList; impor ...

  10. MIL/SIL/PIL/HIL/VIL

    MIL:Model in the loop 模型在环,对模型在模型的开发环境下(如SIMULINK)进行仿真,通过输入一系列的测试用例,验证模型是否满足设计的功能需求.验证控制算法模型是否准确地实现了 ...