StarGAN论文及代码理解
StarGAN的引入是为了解决多领域间的转换问题的,之前的CycleGAN等只能解决两个领域之间的转换,那么对于含有C个领域转换而言,需要学习C*(C-1)个模型,但StarGAN仅需要学习一个,而且效果很棒,如下:

创新点:为了实现可转换到多个领域,StarGAN加入了一个域的控制信息,类似于CGAN的形式。在网络结构设计上,鉴别器不仅仅需要学习鉴别样本是否真实,还需要对真实图片判断来自哪个域。

整个网络的处理流程如下:
- 将输入图片x和目标生成域c结合喂入到生成网络G来合成fake图片
- 将fake图片和真实图片分别喂入到鉴别器D,D需要判断图片是否真实,还需要判断它来自哪个域
- 与CycleGAN类似,还有一个一致性约束,将生成的fake图片和原始图片的域信息c'结合起来喂入到生成器G要求能输出重建出原始输入图片x
下面分析一下各个部分的损失函数:
一:GAN常见的对抗损失:

二:对于给定的输入图片x和目标域标签c,网络的目标是将x转换成输出图片y,输出图片y能够被归类成目标域c。为了实现这一点就需要鉴别器有判别域的功能。所以作者在D的顶端加了一个额外的域分类器,域分类器loss在优化D和G时都会用到,作者将这一损失分为两个方向,分别用来优化G和D。(这很容易理解,因为如下分析可以看到公式(3)没有办法为D提供训练需要的监督信息)
一个是真实图片的域分类损失用来优化D,另一个是fake图片的域分类损失来优化G。
1)
Dcls(c'|x)代表D对真实图片计算得到的域标签概率分布。这一学习目标将会使得D能够将输入图片x识别为对应的域c',这里的(x,c')是训练集给定的。
2)
fake图片的域分类的损失函数定义如(3),它用来优化G,也就是让G尽力去生成图片让它能够被D分类成目标域c。
三:还有一个重建损失
通过最小化对抗损失与分类损失,G努力尝试做到生成目标域中的现实图片。但是这无法保证学习到的转换只会改变输入图片的域相关的信息而不改变图片内容。所以加上了周期一致性损失:
这里就是将G(x,c)和图片x的原始标签c'结合喂入到G中,将生成的图片和x计算1范数差异。
总体损失:
在实际操作上,作者将对抗损失换成了WGAN的对抗损失:

以上对于单个数据集的训练来说已经足够了,但是现在想想另一个问题,假如我要联合训练多个数据集呢?
举例来说,celebA和RaFD数据集,前者有发色和性别信息,后者有面部表情信息,我能将celebA中的人物改变一下面部表情吗?
一个很简单的想法是如果我原来的域标注信息是5位的onehot编码,现在变长为8位不就可以了。但是这存在一个问题就是celebA中的人其实也有表情,只是没有标注,RaFD其实也有性别区别,但对于网络来说没标记就是未知的。简单扩充域标记信息位是肯定不行的。我们希望网络只关注它有明确信息的那一部分标注。
因此,作者加了一个mask。在联合多个数据集训练时把mask向量也输入到生成器。
以上的ci代表第i个数据集的标签,已知标签ci如果是二进制属性则可以表示为二进制向量,如果为类别属性表示一个onehot。剩下的n-1个则指定为0。m则是一个长度为n的onehot编码。这样网络就会只关注已给定的标签。
论文部分到此结束,下面来分析一下代码:
主要的代码有model.py和solver.py两个。
在model.py中作者创建了生成器G与鉴别器D。
在生成器中先对模型降维缩小为原来4倍,再使用多个残差网络获得等维度输出,接着使用转置卷积放大4倍,最后通过一层尺寸不变的卷积,取tanh作为输出。

另外一个值得注意的是生成器如何将输入图片与目标域c一起结合作为输入的,代码中可以看出就是直接在第四维度上进行拼接(pytorch一般为N*C*H*W,所以看起来是在第二维)。
对于鉴别器,使用conv1的输出代表域的预测概率,conv2的输出代表图片是否为真的判断。这两个的关系是并行的。

Solver.py比较长,挑选重要的部分来解释:
首先是梯度惩罚,这一部分来自WGAN的改善工作,主要是为了满足Lipschitz连续这个WGAN推导中需要的数学约束。

令人疑惑的是分类loss并不都是交叉熵损失,这是因为CelebA的标签是多属性的,不是一个onehot,所以使用了一个多个二分类的形式,而RaFD则是一个onehot。
下面来看看在多个训练集训练时代码上是怎么操作的。
在数据加载上其实还是单个数据集轮流进行操作的,如下:

以上提到在多数据集训练时,我们需要mask向量,mask向量的形成按如下形式进行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪个数据集的标签是已知的。

以生成器为例,计算损失时也是只在输出判断向量中提取该数据集已知的部分进行loss计算。

StarGAN论文及代码理解的更多相关文章
- HHL论文及代码理解(Generalizing A Person Retrieval Model Hetero- and Homogeneously ECCV 2018)
行人再识别Re-ID面临两个特殊的问题: 1)源数据集和目标数据集类别完全不同 2)相机造成的图片差异 因为一般来说传统的域适应问题源域和目标域的类别是相同的,相机之间的不匹配也是造成行人再识别数据集 ...
- Context Encoder论文及代码解读
经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...
- [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合
原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...
- linux io的cfq代码理解
内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...
- 10K+,深度学习论文、代码最全汇总!
我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...
- (转) AI突破性论文及代码实现汇总
本文转自:https://zhuanlan.zhihu.com/p/25191377 AI突破性论文及代码实现汇总 极视角 · 2 天前 What Can AI Do For You? “The bu ...
- 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
秦鼎涛 <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处 <Linux内核分析>MOOC课程http: ...
随机推荐
- 企业级 Web 开发的挑战
本文翻译自土牛Halil ibrahim Kalkan的<Mastering ABP Framework>,是系列翻译的起头,适合ABP开发人员或者想对ABP框架进行深入演进的准架构师. ...
- FreeRTOS --(16)资源管理之临界区
转载自 https://blog.csdn.net/zhoutaopower/article/details/107387427 临界区的概念在任何的 SoC 都存在,比如,针对一个寄存器,基本操作为 ...
- MySQL常用数据类型及细节
目录 1 整数类型 1.1 可选属性 1.1.1 M 1.1.2 UNSIGNED 1.1.3 ZEROFILL 2 浮点类型 2.1 精度误差 3 定点数类型 3.1 数据精度说明 3.2 类型介绍 ...
- Linux进程总结
一个执着于技术的公众号 进程 进程,是计算机中的程序关于某数据集合上的一次运行活动,是系统进行资源分配和调度的基本单位,是操作系统结构的基础.它的执行需要系统分配资源创建实体之后,才能进行.举个例子: ...
- 北航内核操作系统-lab1
1.实验目的. 2.实验内容. 2.1Exercise 1.1 请修改 include.mk 文件,使交叉编译器的路径正确.之后执行 make指令,如果配置一切正确,则会在gxemul 目录下生成v ...
- 经过一个多月的等待我有幸成为Spring相关项目的Contributor
给开源项目尤其是Spring这种知名度高的项目贡献代码是比较难的,起码胖哥是这么认为的.有些时候我们的灵感未必契合作者的设计意图,即使你的代码十分优雅. 我曾经给Spring Security提交了一 ...
- 关于 MyBatis-Plus 分页查询的探讨 → count 都为 0 了,为什么还要查询记录?
开心一刻 记得上初中,中午午休的时候,我和哥们躲在厕所里吸烟 听见外面有人进来,哥们猛吸一口,就把烟甩了 进来的是教导主任,问:你们干嘛呢? 哥们鼻孔里一边冒着白烟一边说:我在生气 环境搭建 依赖引入 ...
- nacos 详细介绍(二)
五.nacos的namespace和group namespace:相当于环境,开发环境 测试环境 生产环境 ,每个空间里面的配置是独立的默认的namespace是public, namespace可 ...
- linux篇-修改mysql数据库密码
总是忘记,每次都要查文档,背背背 方法1: 用SET PASSWORD命令 首先登录MySQL. 格式:mysql> set password for 用户名@localhost = passw ...
- ATM+购物车项目流程
目录 需求分析 架构设计 功能实现 搭建文件目录 conf配置文件夹 lib公共功能文件夹 db数据文件夹 interface业务逻辑层文件夹 core表现层文件夹 测试 最外层功能(src.py) ...