最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。

OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}

首先定义10个操作,依次解释:

  • class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
    """
    Args:
    pool_type: 'max' or 'avg'
    """
    super().__init__()
    if pool_type.lower() == 'max':
    self.pool = nn.MaxPool2d(kernel_size, stride, padding)
    elif pool_type.lower() == 'avg':
    self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
    else:
    raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
    out = self.pool(x)
    out = self.bn(out)
    return out

    这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去

  • class Identity(nn.Module):
    def __init__(self):
    super().__init__() def forward(self, x):
    return x

    这个表示skip conncet

  • class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise(stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
    super().__init__()
    self.relu = nn.ReLU()
    self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
    x = self.relu(x)
    out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out

    这个表示将特征图大小变为原来的一半

  • class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
    5x5 conv => 9x9 receptive field
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
    bias=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。

  • class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
    DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,由两个上面的深度分组卷积组成

  • class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
    nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    这个表示长方形的卷积,增加了一点特征图的长和宽

  • class Zero(nn.Module):
    def __init__(self, stride):
    super().__init__()
    self.stride = stride def forward(self, x):
    if self.stride == 1:
    return x * 0. # re-sizing by stride
    return x[:, :, ::self.stride, ::self.stride] * 0.

    这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变

DARTS代码分析(Pytorch)的更多相关文章

  1. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  2. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  3. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  4. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  5. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  6. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

  7. 常用 Java 静态代码分析工具的分析与比较

    常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...

  8. SonarQube-5.6.3 代码分析平台搭建使用

    python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...

  9. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

随机推荐

  1. maven项目pom.xml中parent标签的使用(转)

    原文地址:https://blog.csdn.net/qq_41254677/article/details/81011681 使用maven是为了更好的帮项目管理包依赖,maven的核心就是pom. ...

  2. SpringBoot AOP介绍

    说起spring,我们知道其最核心的两个功能就是AOP(面向切面)和IOC(控制反转),这边文章来总结一下SpringBoot如何整合使用AOP. 一.示例应用场景:对所有的web请求做切面来记录日志 ...

  3. JavaMail应用--通过javamail API实现在代码中发送邮件功能

    JavaMail应用   在日常开发中,可能会引用到发邮件功能,例如在持续集成中,自动化测试运行完毕,自动将测试结果以报表的形式发送邮件给相关人.那么在Java中如何实现发邮件呢? 在java EE ...

  4. web上传整个文件夹

    在Web应用系统开发中,文件上传和下载功能是非常常用的功能,今天来讲一下JavaWeb中的文件上传和下载功能的实现. 先说下要求: PC端全平台支持,要求支持Windows,Mac,Linux 支持所 ...

  5. 判断一个ip地址合法性(基础c,不用库函数)

    #include <stdio.h> int judge(char *strIp); int main() { ]; ) { scanf("%s", a); == ju ...

  6. 【CUDA 基础】3.3 并行性表现

    title: [CUDA 基础]3.3 并行性表现 categories: - CUDA - Freshman tags: - nvprof toc: true date: 2018-04-15 21 ...

  7. MySQL_(Java)【事物操作】使用JDBC模拟银行转账向数据库发起修改请求

    MySQL_(Java)使用JDBC向数据库发起查询请求 传送门 MySQL_(Java)使用JDBC向数据库中插入(insert)数据 传送门 MySQL_(Java)使用JDBC向数据库中删除(d ...

  8. Android_(自动化)自动获取手机电池的剩余电量

    自动获取手机电池的剩余电量 通过使用BroadcastReceiver的特性来获取手机电池的电量,注册BroadcastReceiver时设置的IntentFilter来获取系统发出的Intent.A ...

  9. ssh以及双机互信

    当我们要远程到其他主机上面时就需要使用ssh服务了. 我们就来安装一下sshd服务以及ssh命令的使用方法. 服务安装: 需要安装OpenSSH 四个安装包: 安装包: openssh-5.3p1-1 ...

  10. [CSP-S模拟测试]:D(暴力+剪枝)

    题目传送门(内部题47) 输入格式 第一行一个正整数$n$.第二行$n$个正整数,表示序列$A_i$. 输出格式 一行一个正整数,表示答案. 样例 样例输入: 530 60 20 20 20 样例输出 ...