源码地址:https://github.com/aitorzip/PyTorch-CycleGAN

如图所示,cycleGAN的网络结构包括两个生成器G(X->Y)和F(Y->X),两个判别器Dx和Dy

生成器部分:网络整体上经过一个降采样然后上采样的过程,中间是一系列残差块,数目由实际情况确定,根据论文中所说,当输入分辨率为128x128,采用6个残差块,当输入分辨率为256x256甚至更高时,采用9个残差块,其源代码如下,

class Generator(nn.Module):
def __init__(self, input_nc, output_nc, n_residual_blocks=9):
super(Generator, self).__init__() # Initial convolution block
model = [ nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True) ] # Downsampling
in_features = 64
out_features = in_features*2
for _ in range(2):
model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features*2 # Residual blocks
for _ in range(n_residual_blocks):
model += [ResidualBlock(in_features)] # Upsampling
out_features = in_features//2
for _ in range(2):
model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features//2 # Output layer
model += [ nn.ReflectionPad2d(3),
nn.Conv2d(64, output_nc, 7),
nn.Tanh() ] self.model = nn.Sequential(*model) def forward(self, x):
return self.model(x)

其中,值得注意的网络层是nn.ReflectionPad2d和nn.InstanceNorm2d,前者搭配7x7卷积,先在特征图周围以反射的方式补长度,使得卷积后特征图尺寸不变,示例如下,输出结果就是以特征图边界为反射边,向外补充

nn.InstanceNorm2d是相比于batchNorm更加适合图像生成,风格迁移的归一化方法,相比于batchNorm跨样本,单通道统计,InstanceNorm采用单样本,单通道统计,括号中的参数代表通道数。

判别器部分:结构比生成器更加简单,经过5层卷积,通道数缩减为1,最后池化平均,尺寸也缩减为1x1,最最后reshape一下,变为(batchsize,1)

class Discriminator(nn.Module):
def __init__(self, input_nc):
super(Discriminator, self).__init__() # A bunch of convolutions one after another
model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(256, 512, 4, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True) ] # FCN classification layer
model += [nn.Conv2d(512, 1, 4, padding=1)] self.model = nn.Sequential(*model) def forward(self, x):
x = self.model(x)
# Average pooling and flatten
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0])

【源码解读】cycleGAN(一):网络的更多相关文章

  1. Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)

    Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...

  2. SDWebImage源码解读 之 UIImage+GIF

    第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...

  3. AFNetworking 3.0 源码解读 总结(干货)(下)

    承接上一篇AFNetworking 3.0 源码解读 总结(干货)(上) 21.网络服务类型NSURLRequestNetworkServiceType 示例代码: typedef NS_ENUM(N ...

  4. AFNetworking 3.0 源码解读 总结(干货)(上)

    养成记笔记的习惯,对于一个软件工程师来说,我觉得很重要.记得在知乎上看到过一个问题,说是人类最大的缺点是什么?我个人觉得记忆算是一个缺点.它就像时间一样,会自己消散. 前言 终于写完了 AFNetwo ...

  5. AFNetworking 3.0 源码解读(十)之 UIActivityIndicatorView/UIRefreshControl/UIImageView + AFNetworking

    我们应该看到过很多类似这样的例子:某个控件拥有加载网络图片的能力.但这究竟是怎么做到的呢?看完这篇文章就明白了. 前言 这篇我们会介绍 AFNetworking 中的3个UIKit中的分类.UIAct ...

  6. AFNetworking 3.0 源码解读(九)之 AFNetworkActivityIndicatorManager

    让我们的APP像艺术品一样优雅,开发工程师更像是一名匠人,不仅需要精湛的技艺,而且要有一颗匠心. 前言 AFNetworkActivityIndicatorManager 是对状态栏中网络激活那个小控 ...

  7. AFNetworking 3.0 源码解读(八)之 AFImageDownloader

    AFImageDownloader 这个类对写DownloadManager有很大的借鉴意义.在平时的开发中,当我们使用UIImageView加载一个网络上的图片时,其原理就是把图片下载下来,然后再赋 ...

  8. AFNetworking 3.0 源码解读(七)之 AFAutoPurgingImageCache

    这篇我们就要介绍AFAutoPurgingImageCache这个类了.这个类给了我们临时管理图片内存的能力. 前言 假如说我们要写一个通用的网络框架,除了必备的请求数据的方法外,必须提供一个下载器来 ...

  9. AFNetworking 3.0 源码解读(六)之 AFHTTPSessionManager

    AFHTTPSessionManager相对来说比较好理解,代码也比较短.但却是我们平时可能使用最多的类. AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilit ...

  10. AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilityManager

    做ios开发,AFNetworking 这个网络框架肯定都非常熟悉,也许我们平时只使用了它的部分功能,而且我们对它的实现原理并不是很清楚,就好像总是有一团迷雾在眼前一样. 接下来我们就非常详细的来读一 ...

随机推荐

  1. jquery last 选择器 语法

    jquery last 选择器 语法 作用: :last 选择器选取最后一个元素.最常见的用法:与其他元素一起使用,选取指定组合中的最后一个元素(就像上面的例子). 语法:$(":last& ...

  2. mysql LAST()函数 语法

    mysql LAST()函数 语法 作用:返回指定的字段中最后一个记录的值. 语法:SELECT LAST(column_name) FROM table_name 注释:可使用 ORDER BY 语 ...

  3. BZOJ 4289: PA2012 Tax Dijkstra + 查分

    Description 给出一个N个点M条边的无向图,经过一个点的代价是进入和离开这个点的两条边的边权的较大值,求从起点1到点N的最小代价.起点的代价是离开起点的边的边权,终点的代价是进入终点的边的边 ...

  4. hdu 1166 线段树 区间求和 +单点更新 CD模板

    题目链接 敌兵布阵 Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others)Total S ...

  5. adaptiveThreshold(自适应阈值)

    void adaptiveThreshold(InputArray src, OutputArray dst, double maxValue, int adaptiveMethod, int thr ...

  6. BZOJ 4422 Cow Confinement (线段树、DP、扫描线、差分)

    题目链接: https://www.lydsy.com/JudgeOnline/problem.php?id=4422 我真服了..这题我能调一天半,最后还是对拍拍出来的...脑子还是有病啊 题解: ...

  7. 3D Computer Grapihcs Using OpenGL - 13 优化矩阵

    上节说过矩阵是可以结合的,而且相乘是按照和应用顺序相反的顺序进行的.我们之前初始化translationMatrix和rotationMatrix的时候,第一个参数都是使用的一个初始矩阵 glm::m ...

  8. 【转】毛虫算法——尺取法

    转自http://www.myexception.cn/program/1839999.html 妹子满分~~~~ 毛毛虫算法--尺取法 有这么一类问题,需要在给的一组数据中找到不大于某一个上限的&q ...

  9. @清晰掉 c语言三"巨头" const:volatile:static

    const: 1.如果把const放在变量类型前,说明这个变量的值是保持不变的(即为常量),改变量必须在定义时初始化,初始化后对她的任何赋值都是非法的. 2.当指针或是引用指向一个常量时,必须在类型名 ...

  10. system系统调用返回值判断命令是否执行成功

    system函数对返回值的处理,涉及3个阶段: 阶段1:创建子进程等准备工作.如果失败,返回-1. 阶段2:调用/bin/sh拉起shell脚本,如果拉起失败或者shell未正常执行结束(参见备注1) ...