StarGAN的引入是为了解决多领域间的转换问题的,之前的CycleGAN等只能解决两个领域之间的转换,那么对于含有C个领域转换而言,需要学习C*(C-1)个模型,但StarGAN仅需要学习一个,而且效果很棒,如下:

创新点:为了实现可转换到多个领域,StarGAN加入了一个域的控制信息,类似于CGAN的形式。在网络结构设计上,鉴别器不仅仅需要学习鉴别样本是否真实,还需要对真实图片判断来自哪个域。

整个网络的处理流程如下:

  1. 将输入图片x和目标生成域c结合喂入到生成网络G来合成fake图片
  2. 将fake图片和真实图片分别喂入到鉴别器D,D需要判断图片是否真实,还需要判断它来自哪个域
  3. 与CycleGAN类似,还有一个一致性约束,将生成的fake图片和原始图片的域信息c'结合起来喂入到生成器G要求能输出重建出原始输入图片x

下面分析一下各个部分的损失函数:

一:GAN常见的对抗损失:

二:对于给定的输入图片x和目标域标签c,网络的目标是将x转换成输出图片y,输出图片y能够被归类成目标域c。为了实现这一点就需要鉴别器有判别域的功能。所以作者在D的顶端加了一个额外的域分类器,域分类器loss在优化D和G时都会用到,作者将这一损失分为两个方向,分别用来优化G和D。(这很容易理解,因为如下分析可以看到公式(3)没有办法为D提供训练需要的监督信息)

一个是真实图片的域分类损失用来优化D,另一个是fake图片的域分类损失来优化G。

1)

Dcls(c'|x)代表D对真实图片计算得到的域标签概率分布。这一学习目标将会使得D能够将输入图片x识别为对应的域c',这里的(x,c')是训练集给定的。

2)

fake图片的域分类的损失函数定义如(3),它用来优化G,也就是让G尽力去生成图片让它能够被D分类成目标域c。

三:还有一个重建损失

通过最小化对抗损失与分类损失,G努力尝试做到生成目标域中的现实图片。但是这无法保证学习到的转换只会改变输入图片的域相关的信息而不改变图片内容。所以加上了周期一致性损失:

这里就是将G(x,c)和图片x的原始标签c'结合喂入到G中,将生成的图片和x计算1范数差异。

总体损失:

在实际操作上,作者将对抗损失换成了WGAN的对抗损失:

以上对于单个数据集的训练来说已经足够了,但是现在想想另一个问题,假如我要联合训练多个数据集呢?

举例来说,celebA和RaFD数据集,前者有发色和性别信息,后者有面部表情信息,我能将celebA中的人物改变一下面部表情吗?

一个很简单的想法是如果我原来的域标注信息是5位的onehot编码,现在变长为8位不就可以了。但是这存在一个问题就是celebA中的人其实也有表情,只是没有标注,RaFD其实也有性别区别,但对于网络来说没标记就是未知的。简单扩充域标记信息位是肯定不行的。我们希望网络只关注它有明确信息的那一部分标注。

因此,作者加了一个mask。在联合多个数据集训练时把mask向量也输入到生成器。

以上的ci代表第i个数据集的标签,已知标签ci如果是二进制属性则可以表示为二进制向量,如果为类别属性表示一个onehot。剩下的n-1个则指定为0。m则是一个长度为n的onehot编码。这样网络就会只关注已给定的标签。

论文部分到此结束,下面来分析一下代码

主要的代码有model.py和solver.py两个。

在model.py中作者创建了生成器G与鉴别器D。

在生成器中先对模型降维缩小为原来4倍,再使用多个残差网络获得等维度输出,接着使用转置卷积放大4倍,最后通过一层尺寸不变的卷积,取tanh作为输出。

另外一个值得注意的是生成器如何将输入图片与目标域c一起结合作为输入的,代码中可以看出就是直接在第四维度上进行拼接(pytorch一般为N*C*H*W,所以看起来是在第二维)。

对于鉴别器,使用conv1的输出代表域的预测概率,conv2的输出代表图片是否为真的判断。这两个的关系是并行的。

Solver.py比较长,挑选重要的部分来解释:

首先是梯度惩罚,这一部分来自WGAN的改善工作,主要是为了满足Lipschitz连续这个WGAN推导中需要的数学约束。

令人疑惑的是分类loss并不都是交叉熵损失,这是因为CelebA的标签是多属性的,不是一个onehot,所以使用了一个多个二分类的形式,而RaFD则是一个onehot。

下面来看看在多个训练集训练时代码上是怎么操作的。

在数据加载上其实还是单个数据集轮流进行操作的,如下:

以上提到在多数据集训练时,我们需要mask向量,mask向量的形成按如下形式进行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪个数据集的标签是已知的。

以生成器为例,计算损失时也是只在输出判断向量中提取该数据集已知的部分进行loss计算。

StarGAN论文及代码理解的更多相关文章

  1. HHL论文及代码理解(Generalizing A Person Retrieval Model Hetero- and Homogeneously ECCV 2018)

    行人再识别Re-ID面临两个特殊的问题: 1)源数据集和目标数据集类别完全不同 2)相机造成的图片差异 因为一般来说传统的域适应问题源域和目标域的类别是相同的,相机之间的不匹配也是造成行人再识别数据集 ...

  2. Context Encoder论文及代码解读

    经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...

  3. [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合

    原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...

  4. linux io的cfq代码理解

    内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...

  5. 10K+,深度学习论文、代码最全汇总!

    我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...

  6. (转) AI突破性论文及代码实现汇总

    本文转自:https://zhuanlan.zhihu.com/p/25191377 AI突破性论文及代码实现汇总 极视角 · 2 天前 What Can AI Do For You? “The bu ...

  7. 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    秦鼎涛  <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...

  8. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  9. 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处   <Linux内核分析>MOOC课程http: ...

随机推荐

  1. SerialPort-4.0.+ 使用说明(Java版本)

    SerialPort-4.0.+ 项目官网 Kotlin版本使用说明 介绍 SerialPort 是一个开源的对 Android 蓝牙串口通信的轻量封装库,轻松解决了构建自己的串口调试APP的复杂程度 ...

  2. css3 做出顶边倾斜的 梯形 div

    效果图: <html> <head> <meta charset="utf-8"> <title>顶边倾斜的div梯形</ti ...

  3. 网络协议之:Domain name service DNS详解

    目录 简介 DNS的功能 DNS的组成 域名空间Domain name space Name servers DNS的工作流程 DNS资源记录 DNS消息的结构 总结 简介 现在是互联网的世界,大家从 ...

  4. kernel UAF && tty_struct

    kernel UAF && 劫持tty_struct ciscn2017_babydriver exp1 fork进程时会申请堆来存放cred.cred结构大小为0xA8.修改cred ...

  5. redis数据结构附录

    引言 本次对上一次的数据结构知识进行补充,主要有redis数据结构的相关应用场景和内存相关知识 引用计数-内存 redis中的对象回收机制是采用引用计数的方式,首先我们可以通过redis对象结构体代码 ...

  6. zipper题解

    -请奆佬们洁身自好,好好打代码从我做起 - 题目大意: 给三个字符串,判断C字符串是否由A B字符串顺序组成, 题意分析: 很容易想到的是,A的长度加上B的长度为C的长度 其实进一步想,这 提供了一个 ...

  7. 蓝桥杯Web练习题:多个斜线开始的路径重定向问题

    多个斜线开始的路径重定向问题 需求说明 在 vue-router v3.5.2 版本代码中存在一个 Bug,一个以多个斜线(///)开始的路径实际上可能会重定向到另一个域.这是因为 cleanPath ...

  8. python之函数的进阶

    1.名称空间: 定义:用来存放名字的(变量,函数名,类名,引入的模块名) 分类: 内置名称空间:python解释器提供好的一些内置内容 全局名称空间:py文件中自己写的变量 局部名称空间:执行函数时会 ...

  9. 哈工大软件构造Lab1(2022)

    目录 一.实验目标概述 二.实验环境配置 1.安装编写java程序的IDE--IntelliJ IDEA 2.安装Git 3.安装Junit 4.GitHub Lab1仓库的URL地址 三.实验过程 ...

  10. drools中Fact的equality modes

    一.equality modes介绍 在drools中存在如下2种equality modes. 1.identity模式 identity:这是默认的情况.drools引擎使用IdentityHas ...