DARTS代码分析(Pytorch)
最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。
OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
首先定义10个操作,依次解释:
class PoolBN(nn.Module):
"""
AvgPool or MaxPool - BN
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
"""
Args:
pool_type: 'max' or 'avg'
"""
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去
class Identity(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
return x这个表示skip conncet
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise(stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out这个表示将特征图大小变为原来的一半
class DilConv(nn.Module):
""" (Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
5x5 conv => 9x9 receptive field
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
) def forward(self, x):
return self.net(x)深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。
class SepConv(nn.Module):
""" Depthwise separable conv
DilConv(dilation=1) * 2
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
) def forward(self, x):
return self.net(x)深度可分离卷积,由两个上面的深度分组卷积组成
class FacConv(nn.Module):
""" Factorized conv
ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
) def forward(self, x):
return self.net(x)这个表示长方形的卷积,增加了一点特征图的长和宽
class Zero(nn.Module):
def __init__(self, stride):
super().__init__()
self.stride = stride def forward(self, x):
if self.stride == 1:
return x * 0. # re-sizing by stride
return x[:, :, ::self.stride, ::self.stride] * 0.这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变
DARTS代码分析(Pytorch)的更多相关文章
- Android代码分析工具lint学习
1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...
- pmd静态代码分析
在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...
- [Asp.net 5] DependencyInjection项目代码分析-目录
微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...
- [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)
Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...
- 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)
构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...
- STM32启动代码分析 IAR 比较好
stm32启动代码分析 (2012-06-12 09:43:31) 转载▼ 最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...
- 常用 Java 静态代码分析工具的分析与比较
常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...
- SonarQube-5.6.3 代码分析平台搭建使用
python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...
- angular代码分析之异常日志设计
angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...
随机推荐
- PHP类知识----静态属性和方法
<?php class mycoach { public $name="陈培昌"; CONST hisage =; ; private $favorite = "喜 ...
- nginx日志、变量
日志格式类型等 包含两类:access_log error.log log_format log只能在http模块下配置 下图是一个典型error_log配置 warn表示默认日志级别为‘’警告‘’ ...
- Vim常用插件命令手册
此文章记录了,笔者使用的插件中的主要命令. junegunn/vim-plug :PlugInstall 安装插件 :PlugClean 清理插件 :PlugUpgrade 升级插件管理器 :Plug ...
- WCF错误:由于目标计算机积极拒绝,无法连接;127.0.0.1:3456
问题描述 最近Windows打完补丁,原来部署在本机的WCF无法连接:出现如下WCF错误:由于目标计算机积极拒绝,无法连接:127.0.0.1:3456 解决方案 检查一下本机的服务:NetTcpAc ...
- node的http与前端交互示例(入门)
一.目录(node_modules是npm install后新增的) node 和 npm 版本 npm install http 二.node下的index.js var http = requir ...
- 什么是URL百分号编码?
㈠什么是URL 统一资源定位系统(uniform resource locator;URL)是因特网的万维网服务程序上用于指定信息位置的表示方法. ㈡URL编码 url编码是一种浏览器用来打包表单输入 ...
- removeAttr(name)
removeAttr(name) 概述 从每一个匹配的元素中删除一个属性 1.6以下版本在IE6使用JQuery的removeAttr方法删除disabled是无效的.解决的方法就是使用$(" ...
- Comet OJ - Contest #10 鱼跃龙门 exgcd+推导
考试的时候推出来了,但是忘了 $exgcd$ 咋求,成功爆蛋~ 这里给出一个求最小正整数解的模板: ll solve(ll A,ll B,ll C) { ll x,y,g,b,ans; gcd = e ...
- svn的下载与安装,使用,包教包会!!!
svn的安装使用说明 下载svn服务器与搭建 高效开发 — SVN使用教程(客户端与服务端安装详解!带图!带注释!安装客户端与服务端的地址可以看上两个链接) svn安装分为两部分,服务端安装与客户端安 ...
- 无法连接虚拟设备 ide1:0
问题: 启动vmware之后,发现出现无法连接 ide 1:0. 网络查找之后,发现是之前挂载的iso镜像找不到了. 原因: 我把iso镜像放到其他位置. 解决: 指定iso文件的位置. 参考:htt ...