StarGAN的引入是为了解决多领域间的转换问题的,之前的CycleGAN等只能解决两个领域之间的转换,那么对于含有C个领域转换而言,需要学习C*(C-1)个模型,但StarGAN仅需要学习一个,而且效果很棒,如下:

创新点:为了实现可转换到多个领域,StarGAN加入了一个域的控制信息,类似于CGAN的形式。在网络结构设计上,鉴别器不仅仅需要学习鉴别样本是否真实,还需要对真实图片判断来自哪个域。

整个网络的处理流程如下:

  1. 将输入图片x和目标生成域c结合喂入到生成网络G来合成fake图片
  2. 将fake图片和真实图片分别喂入到鉴别器D,D需要判断图片是否真实,还需要判断它来自哪个域
  3. 与CycleGAN类似,还有一个一致性约束,将生成的fake图片和原始图片的域信息c'结合起来喂入到生成器G要求能输出重建出原始输入图片x

下面分析一下各个部分的损失函数:

一:GAN常见的对抗损失:

二:对于给定的输入图片x和目标域标签c,网络的目标是将x转换成输出图片y,输出图片y能够被归类成目标域c。为了实现这一点就需要鉴别器有判别域的功能。所以作者在D的顶端加了一个额外的域分类器,域分类器loss在优化D和G时都会用到,作者将这一损失分为两个方向,分别用来优化G和D。(这很容易理解,因为如下分析可以看到公式(3)没有办法为D提供训练需要的监督信息)

一个是真实图片的域分类损失用来优化D,另一个是fake图片的域分类损失来优化G。

1)

Dcls(c'|x)代表D对真实图片计算得到的域标签概率分布。这一学习目标将会使得D能够将输入图片x识别为对应的域c',这里的(x,c')是训练集给定的。

2)

fake图片的域分类的损失函数定义如(3),它用来优化G,也就是让G尽力去生成图片让它能够被D分类成目标域c。

三:还有一个重建损失

通过最小化对抗损失与分类损失,G努力尝试做到生成目标域中的现实图片。但是这无法保证学习到的转换只会改变输入图片的域相关的信息而不改变图片内容。所以加上了周期一致性损失:

这里就是将G(x,c)和图片x的原始标签c'结合喂入到G中,将生成的图片和x计算1范数差异。

总体损失:

在实际操作上,作者将对抗损失换成了WGAN的对抗损失:

以上对于单个数据集的训练来说已经足够了,但是现在想想另一个问题,假如我要联合训练多个数据集呢?

举例来说,celebA和RaFD数据集,前者有发色和性别信息,后者有面部表情信息,我能将celebA中的人物改变一下面部表情吗?

一个很简单的想法是如果我原来的域标注信息是5位的onehot编码,现在变长为8位不就可以了。但是这存在一个问题就是celebA中的人其实也有表情,只是没有标注,RaFD其实也有性别区别,但对于网络来说没标记就是未知的。简单扩充域标记信息位是肯定不行的。我们希望网络只关注它有明确信息的那一部分标注。

因此,作者加了一个mask。在联合多个数据集训练时把mask向量也输入到生成器。

以上的ci代表第i个数据集的标签,已知标签ci如果是二进制属性则可以表示为二进制向量,如果为类别属性表示一个onehot。剩下的n-1个则指定为0。m则是一个长度为n的onehot编码。这样网络就会只关注已给定的标签。

论文部分到此结束,下面来分析一下代码

主要的代码有model.py和solver.py两个。

在model.py中作者创建了生成器G与鉴别器D。

在生成器中先对模型降维缩小为原来4倍,再使用多个残差网络获得等维度输出,接着使用转置卷积放大4倍,最后通过一层尺寸不变的卷积,取tanh作为输出。

另外一个值得注意的是生成器如何将输入图片与目标域c一起结合作为输入的,代码中可以看出就是直接在第四维度上进行拼接(pytorch一般为N*C*H*W,所以看起来是在第二维)。

对于鉴别器,使用conv1的输出代表域的预测概率,conv2的输出代表图片是否为真的判断。这两个的关系是并行的。

Solver.py比较长,挑选重要的部分来解释:

首先是梯度惩罚,这一部分来自WGAN的改善工作,主要是为了满足Lipschitz连续这个WGAN推导中需要的数学约束。

令人疑惑的是分类loss并不都是交叉熵损失,这是因为CelebA的标签是多属性的,不是一个onehot,所以使用了一个多个二分类的形式,而RaFD则是一个onehot。

下面来看看在多个训练集训练时代码上是怎么操作的。

在数据加载上其实还是单个数据集轮流进行操作的,如下:

以上提到在多数据集训练时,我们需要mask向量,mask向量的形成按如下形式进行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪个数据集的标签是已知的。

以生成器为例,计算损失时也是只在输出判断向量中提取该数据集已知的部分进行loss计算。

StarGAN论文及代码理解的更多相关文章

  1. HHL论文及代码理解(Generalizing A Person Retrieval Model Hetero- and Homogeneously ECCV 2018)

    行人再识别Re-ID面临两个特殊的问题: 1)源数据集和目标数据集类别完全不同 2)相机造成的图片差异 因为一般来说传统的域适应问题源域和目标域的类别是相同的,相机之间的不匹配也是造成行人再识别数据集 ...

  2. Context Encoder论文及代码解读

    经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...

  3. [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合

    原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...

  4. linux io的cfq代码理解

    内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...

  5. 10K+,深度学习论文、代码最全汇总!

    我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...

  6. (转) AI突破性论文及代码实现汇总

    本文转自:https://zhuanlan.zhihu.com/p/25191377 AI突破性论文及代码实现汇总 极视角 · 2 天前 What Can AI Do For You? “The bu ...

  7. 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    秦鼎涛  <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...

  8. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  9. 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处   <Linux内核分析>MOOC课程http: ...

随机推荐

  1. Java包装类,基本的装箱与拆箱

    我的博客 何为包装类 将原始类型和包装类分开以保持简单.当需要一个适合像面向对象编程的类型时就需要包装类.当希望数据类型变得简单时就使用原始类型. 原始类型不能为null,但包装类可以为null.包装 ...

  2. 攻防世界-MISC:2017_Dating_in_Singapore

    这是MISC高手进阶区的题目:题目如下: 点击下载附件一,得到一张pdf图片,除此之外就只有题目给的字符串了,不知道是什么意思(查看了一下WP)原来每一串通过"-"隔开的字符串代表 ...

  3. Python技法:用re模块实现简易tokenizer

    一个简单的tokenizer 分词(tokenization)任务是Python字符串处理中最为常见任务了.我们这里讲解用正则表达式构建简单的表达式分词器(tokenizer),它能够将表达式字符串从 ...

  4. python实现基于smtp发送邮件

    [前言] 在某些项目中,我们需要实现发送邮件的功能,比如: 爬虫结束后,发送邮件通知 定时发送邮件提醒待办事项 某项业务逻辑触发邮件通知 今天我们就分享如何基于smtp借助163邮箱来发送邮件 [实现 ...

  5. line-height和height关系

    如图所示,line-height = font-size + 上下本行距.上下半行距总是相等.font-size居于中间.当font-size值固定时,line-height越大,半行距越大.所以当l ...

  6. Hadoop安装学习(第二天)

    学习任务: 1.对VMnet8进行设置 2.配置主机名,对host文件进行编辑 3.将Hadoop文件以及jdk通过Xshell7传输到Linux系统 4.设置免密登录

  7. Redis快速度特性及为什么支持多线程及应用场景

    转载请注明出处: 目录 1.Redis 访问速度快特性 2.Redis 6.0 为什么支持多线程? 3.Redis可以做什么 3.1.缓存 3.2.排行榜系统 3.3.计数器应用 3.4.社交网络 3 ...

  8. axios的请求参数格式(get、post、put、delete)

    1.get请求方式: axios.get(url[, config]) // [字符拼接型]axios.get(url?id=123&status=0') // 等同于 axios.get(u ...

  9. 【Azure 存储服务】Java Azure Storage SDK V12使用Endpoint连接Blob Service遇见 The Azure Storage endpoint url is malformed

    问题描述 使用Azure Storage Account的共享访问签名(Share Access Signature) 生成的终结点,连接时遇见  The Azure Storage endpoint ...

  10. 测试平台系列(95) 前置条件支持简单的python脚本

    大家好~我是米洛! 我正在从0到1打造一个开源的接口测试平台, 也在编写一套与之对应的教程,希望大家多多支持. 欢迎关注我的公众号米洛的测开日记,获取最新文章教程! 回顾 上一节我们构思了一下怎么去支 ...