StarGAN论文及代码理解
StarGAN的引入是为了解决多领域间的转换问题的,之前的CycleGAN等只能解决两个领域之间的转换,那么对于含有C个领域转换而言,需要学习C*(C-1)个模型,但StarGAN仅需要学习一个,而且效果很棒,如下:

创新点:为了实现可转换到多个领域,StarGAN加入了一个域的控制信息,类似于CGAN的形式。在网络结构设计上,鉴别器不仅仅需要学习鉴别样本是否真实,还需要对真实图片判断来自哪个域。

整个网络的处理流程如下:
- 将输入图片x和目标生成域c结合喂入到生成网络G来合成fake图片
- 将fake图片和真实图片分别喂入到鉴别器D,D需要判断图片是否真实,还需要判断它来自哪个域
- 与CycleGAN类似,还有一个一致性约束,将生成的fake图片和原始图片的域信息c'结合起来喂入到生成器G要求能输出重建出原始输入图片x
下面分析一下各个部分的损失函数:
一:GAN常见的对抗损失:

二:对于给定的输入图片x和目标域标签c,网络的目标是将x转换成输出图片y,输出图片y能够被归类成目标域c。为了实现这一点就需要鉴别器有判别域的功能。所以作者在D的顶端加了一个额外的域分类器,域分类器loss在优化D和G时都会用到,作者将这一损失分为两个方向,分别用来优化G和D。(这很容易理解,因为如下分析可以看到公式(3)没有办法为D提供训练需要的监督信息)
一个是真实图片的域分类损失用来优化D,另一个是fake图片的域分类损失来优化G。
1)
Dcls(c'|x)代表D对真实图片计算得到的域标签概率分布。这一学习目标将会使得D能够将输入图片x识别为对应的域c',这里的(x,c')是训练集给定的。
2)
fake图片的域分类的损失函数定义如(3),它用来优化G,也就是让G尽力去生成图片让它能够被D分类成目标域c。
三:还有一个重建损失
通过最小化对抗损失与分类损失,G努力尝试做到生成目标域中的现实图片。但是这无法保证学习到的转换只会改变输入图片的域相关的信息而不改变图片内容。所以加上了周期一致性损失:
这里就是将G(x,c)和图片x的原始标签c'结合喂入到G中,将生成的图片和x计算1范数差异。
总体损失:
在实际操作上,作者将对抗损失换成了WGAN的对抗损失:

以上对于单个数据集的训练来说已经足够了,但是现在想想另一个问题,假如我要联合训练多个数据集呢?
举例来说,celebA和RaFD数据集,前者有发色和性别信息,后者有面部表情信息,我能将celebA中的人物改变一下面部表情吗?
一个很简单的想法是如果我原来的域标注信息是5位的onehot编码,现在变长为8位不就可以了。但是这存在一个问题就是celebA中的人其实也有表情,只是没有标注,RaFD其实也有性别区别,但对于网络来说没标记就是未知的。简单扩充域标记信息位是肯定不行的。我们希望网络只关注它有明确信息的那一部分标注。
因此,作者加了一个mask。在联合多个数据集训练时把mask向量也输入到生成器。
以上的ci代表第i个数据集的标签,已知标签ci如果是二进制属性则可以表示为二进制向量,如果为类别属性表示一个onehot。剩下的n-1个则指定为0。m则是一个长度为n的onehot编码。这样网络就会只关注已给定的标签。
论文部分到此结束,下面来分析一下代码:
主要的代码有model.py和solver.py两个。
在model.py中作者创建了生成器G与鉴别器D。
在生成器中先对模型降维缩小为原来4倍,再使用多个残差网络获得等维度输出,接着使用转置卷积放大4倍,最后通过一层尺寸不变的卷积,取tanh作为输出。

另外一个值得注意的是生成器如何将输入图片与目标域c一起结合作为输入的,代码中可以看出就是直接在第四维度上进行拼接(pytorch一般为N*C*H*W,所以看起来是在第二维)。
对于鉴别器,使用conv1的输出代表域的预测概率,conv2的输出代表图片是否为真的判断。这两个的关系是并行的。

Solver.py比较长,挑选重要的部分来解释:
首先是梯度惩罚,这一部分来自WGAN的改善工作,主要是为了满足Lipschitz连续这个WGAN推导中需要的数学约束。

令人疑惑的是分类loss并不都是交叉熵损失,这是因为CelebA的标签是多属性的,不是一个onehot,所以使用了一个多个二分类的形式,而RaFD则是一个onehot。
下面来看看在多个训练集训练时代码上是怎么操作的。
在数据加载上其实还是单个数据集轮流进行操作的,如下:

以上提到在多数据集训练时,我们需要mask向量,mask向量的形成按如下形式进行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪个数据集的标签是已知的。

以生成器为例,计算损失时也是只在输出判断向量中提取该数据集已知的部分进行loss计算。

StarGAN论文及代码理解的更多相关文章
- HHL论文及代码理解(Generalizing A Person Retrieval Model Hetero- and Homogeneously ECCV 2018)
行人再识别Re-ID面临两个特殊的问题: 1)源数据集和目标数据集类别完全不同 2)相机造成的图片差异 因为一般来说传统的域适应问题源域和目标域的类别是相同的,相机之间的不匹配也是造成行人再识别数据集 ...
- Context Encoder论文及代码解读
经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...
- [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合
原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...
- linux io的cfq代码理解
内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...
- 10K+,深度学习论文、代码最全汇总!
我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...
- (转) AI突破性论文及代码实现汇总
本文转自:https://zhuanlan.zhihu.com/p/25191377 AI突破性论文及代码实现汇总 极视角 · 2 天前 What Can AI Do For You? “The bu ...
- 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
秦鼎涛 <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处 <Linux内核分析>MOOC课程http: ...
随机推荐
- Python标准库tempfile的使用总结
Python标准库tempfile的使用总结 临时文件是计算机程序存储临时数据的文件,它的扩展名通常是".temp".本文用于记录使用Python提供的临时文件API解决实际问题的 ...
- OV5640图像采集(一)VGA显示
vga控制器模块 1 引言 项目的背景是采集无人车间现场的工件图像并送往控制间pc端处理,最终实现缺陷检测.项目包括图像采集模块,数据传输模块,上位机,缺陷检测算法等四个部分.其中,图像采集模块又分 ...
- 打基础丨Python图像处理入门知识详解
摘要:本文讲解图像处理基础知识和OpenCV入门函数. 本文分享自华为云社区<[Python图像处理] 一.图像处理基础知识及OpenCV入门函数>,作者: eastmount. 一.图像 ...
- Python图像处理丨OpenCV+Numpy库读取与修改像素
摘要:本篇文章主要讲解 OpenCV+Numpy 图像处理基础知识,包括读取像素和修改像素. 本文分享自华为云社区<[Python图像处理] 二.OpenCV+Numpy库读取与修改像素> ...
- 真实本人亲测Elasticsearch未授权访问漏洞——利用及修复【踩坑指南到脱坑!】
如要转载请注明出处谢谢: https://www.cnblogs.com/vitalemontea/p/16105490.html 1.前言 某天"发现"了个漏洞,咳咳,原本以为这 ...
- jquery 动态 给select赋值
<div class="right_left"> <select id="supply"> <option>请选择供应商&l ...
- MyBatis插件 - 通用mapper
1.简单认识通用mapper 1.1.了解mapper 作用:就是为了帮助我们自动的生成sql语句 [ ps:MyBatis需要编写xxxMapper.xml,而逆向工程是根据entity实体类来进行 ...
- rabbitmq 安装延时队列插件rabbitmq-delayed-message-exchange
1.下载rabbitmq-delayed-message-exchange(注意版本对应) 链接:https://github.com/rabbitmq/rabbitmq-delayed-messag ...
- 使用 Ansible 快速部署 HBase 集群
背景 出于数据安全的考虑,自研了一个低成本的时序数据存储系统,用于存储历史行情数据. 系统借鉴了 InfluxDB 的列存与压缩策略,并基于 HBase 实现了海量存储能力. 由于运维同事缺乏 Had ...
- 一文彻底搞懂JavaScript中的prototype
prototype初步认识 在学习JavaScript中,遇到了prototype,经过一番了解,知道它是可以进行动态扩展的 function Func(){}; var func1 = new Fu ...