【源码解读】cycleGAN(一):网络
源码地址:https://github.com/aitorzip/PyTorch-CycleGAN

如图所示,cycleGAN的网络结构包括两个生成器G(X->Y)和F(Y->X),两个判别器Dx和Dy
生成器部分:网络整体上经过一个降采样然后上采样的过程,中间是一系列残差块,数目由实际情况确定,根据论文中所说,当输入分辨率为128x128,采用6个残差块,当输入分辨率为256x256甚至更高时,采用9个残差块,其源代码如下,
class Generator(nn.Module):
def __init__(self, input_nc, output_nc, n_residual_blocks=9):
super(Generator, self).__init__() # Initial convolution block
model = [ nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True) ] # Downsampling
in_features = 64
out_features = in_features*2
for _ in range(2):
model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features*2 # Residual blocks
for _ in range(n_residual_blocks):
model += [ResidualBlock(in_features)] # Upsampling
out_features = in_features//2
for _ in range(2):
model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features//2 # Output layer
model += [ nn.ReflectionPad2d(3),
nn.Conv2d(64, output_nc, 7),
nn.Tanh() ] self.model = nn.Sequential(*model) def forward(self, x):
return self.model(x)
其中,值得注意的网络层是nn.ReflectionPad2d和nn.InstanceNorm2d,前者搭配7x7卷积,先在特征图周围以反射的方式补长度,使得卷积后特征图尺寸不变,示例如下,输出结果就是以特征图边界为反射边,向外补充

nn.InstanceNorm2d是相比于batchNorm更加适合图像生成,风格迁移的归一化方法,相比于batchNorm跨样本,单通道统计,InstanceNorm采用单样本,单通道统计,括号中的参数代表通道数。
判别器部分:结构比生成器更加简单,经过5层卷积,通道数缩减为1,最后池化平均,尺寸也缩减为1x1,最最后reshape一下,变为(batchsize,1)
class Discriminator(nn.Module):
def __init__(self, input_nc):
super(Discriminator, self).__init__() # A bunch of convolutions one after another
model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(256, 512, 4, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True) ] # FCN classification layer
model += [nn.Conv2d(512, 1, 4, padding=1)] self.model = nn.Sequential(*model) def forward(self, x):
x = self.model(x)
# Average pooling and flatten
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0])
【源码解读】cycleGAN(一):网络的更多相关文章
- Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)
Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...
- SDWebImage源码解读 之 UIImage+GIF
第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...
- AFNetworking 3.0 源码解读 总结(干货)(下)
承接上一篇AFNetworking 3.0 源码解读 总结(干货)(上) 21.网络服务类型NSURLRequestNetworkServiceType 示例代码: typedef NS_ENUM(N ...
- AFNetworking 3.0 源码解读 总结(干货)(上)
养成记笔记的习惯,对于一个软件工程师来说,我觉得很重要.记得在知乎上看到过一个问题,说是人类最大的缺点是什么?我个人觉得记忆算是一个缺点.它就像时间一样,会自己消散. 前言 终于写完了 AFNetwo ...
- AFNetworking 3.0 源码解读(十)之 UIActivityIndicatorView/UIRefreshControl/UIImageView + AFNetworking
我们应该看到过很多类似这样的例子:某个控件拥有加载网络图片的能力.但这究竟是怎么做到的呢?看完这篇文章就明白了. 前言 这篇我们会介绍 AFNetworking 中的3个UIKit中的分类.UIAct ...
- AFNetworking 3.0 源码解读(九)之 AFNetworkActivityIndicatorManager
让我们的APP像艺术品一样优雅,开发工程师更像是一名匠人,不仅需要精湛的技艺,而且要有一颗匠心. 前言 AFNetworkActivityIndicatorManager 是对状态栏中网络激活那个小控 ...
- AFNetworking 3.0 源码解读(八)之 AFImageDownloader
AFImageDownloader 这个类对写DownloadManager有很大的借鉴意义.在平时的开发中,当我们使用UIImageView加载一个网络上的图片时,其原理就是把图片下载下来,然后再赋 ...
- AFNetworking 3.0 源码解读(七)之 AFAutoPurgingImageCache
这篇我们就要介绍AFAutoPurgingImageCache这个类了.这个类给了我们临时管理图片内存的能力. 前言 假如说我们要写一个通用的网络框架,除了必备的请求数据的方法外,必须提供一个下载器来 ...
- AFNetworking 3.0 源码解读(六)之 AFHTTPSessionManager
AFHTTPSessionManager相对来说比较好理解,代码也比较短.但却是我们平时可能使用最多的类. AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilit ...
- AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilityManager
做ios开发,AFNetworking 这个网络框架肯定都非常熟悉,也许我们平时只使用了它的部分功能,而且我们对它的实现原理并不是很清楚,就好像总是有一团迷雾在眼前一样. 接下来我们就非常详细的来读一 ...
随机推荐
- Python稀疏矩阵运算
import numpy as np import scipy import time import scipy.sparse as sparse t = [1]+[0]*4999 a = scipy ...
- KafKa集群安装详细步骤
最近在使用Spring Cloud进行分布式微服务搭建,顺便对集成KafKa的方案做了一些总结,今天详细介绍一下KafKa集群安装过程: 1. 在根目录创建kafka文件夹(service1.serv ...
- sqli-labs(11)
基于登录点的注入(小编这里傻逼了 可以直接用group_concat函数绕过显示问题我还在用limit绕过) 0X01这里我们的参数就不是在get的方法里面提交的了 我们遇到了全新的问题 那么该怎么 ...
- k8s中pod内dns无法解析的问题
用k8s创建了pod,然后进入pod后,发现在pod中无法解析www.baidu.com,也就是出现了无法解析外面的域名的问题.经过高人指点,做个小总结.操作如下. 一,将CoreDNS 的Confi ...
- 题目1.A乘以B
看我没骗你吧 —— 这是一道你可以在10秒内完成的题:给定两个绝对值不超过100的整数A和B,输出A乘以B的值. 1实验代码 #include<stdio.h>int main(void) ...
- legend3---Homestead常用操作代码
legend3---Homestead常用操作代码 一.总结 一句话总结: 在虚拟机里面改变文件windows里面也会变,在windows里面改变虚拟机里面也会变,所以可以在windows里面编程或者 ...
- 一个”.java”源文件中是否可以包含多个类(不是内部类)?有什么限制
这个是可以的,一个“.java”源文件里面可以包含多个类,但是只允许有一个public类,并且类名必须和文件名一致. 每个编译单元只能有一个public 类.这么做的意思是,每个编译单元只能有一个公开 ...
- 使用collection:分段查询结果集
1.在人员接口书写方法 public List<Employee> getEmpsByDeptId(Integer deptId); 2在人员映射文件中进行配置 <!-- publi ...
- 双三次插值C代码(利用opencv)
双三次插值C代码(利用opencv) phasecubic2.cpp D:\文件及下载相关\文档\Visual Studio 2010\Projects\phasecubic2\phasecubic2 ...
- word中打字会覆盖下一个字
insert键 误按了insert键,此时Word默认为改写模式,输入文本会覆盖后面的内容.