• 文章转自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄(已授权)
  • 联系方式:微信cyx645016617
  • 论文名称:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”

「前言」:最近好久没更新公众号了,我一不小心陷入了一个误区:我以为自己看的文章足够多了,用之前的风格迁移和GAN的知识来解决一个domain adaptive的问题,一顿乱拳并没有打死老师傅,反而自己累个够呛。然后找到这样一篇不错的DA framework,来认真学习一下章法,假期结束重新用章法组合拳再来会会。

0 综述

不同于以往的对抗模型或者是超像素信息来实现这个领域迁移,本文使用的是对抗生成网络GAN来将两个领域的特征空间拉近。

本文提出的是语义分割的领域自适应算法。论文特别关注的问题是:目标领域没有label

传统的DA方法包含最小化某些可以衡量source和target两个分布的距离函数。两种常见的度量是:

  • 最大均值差(Maximum Mean Discrepancy, MMD)
  • 通过对抗学习,使用DCNN来学习distance metric

本文的主要贡献在于提出了一种基于生成模型的特征空间源分布与目标分布对齐算法。

1 method

从图片中来初步判断,其实是比较好理解的:

  • 首先,我猜测其做域迁移,可能是仿照GAN领域中做风格迁移的办法;
  • 图片中总共有4个网络,F网络应该是特征提取网络,C网络是做分割的网络,G网络是把F提取的特征再还原成原图的网络,D网络是做分类的网络,和一般GAN不同的是,D中做四个分类,是True source,True target, False source, False targe. 类似于把cycleGAN中的两个二分类的discriminator合并了。

2 细节

原始图片定义为\(X\),source domain的图片定义为\(X^s\),target domain的图片定义为\(X^t\).

  • base network. 架构类似于预训练的VGG16,被分成了两个部分:特征提取部分叫做F网络,做像素分割的叫做C网络。
  • G网络是用来从F生成的embedding特征中,重建原始图像的;D网络不仅要分别出图片是否是real or fake,还会做一个分割任务,类似于C网络。这个分割任务仅仅针对source domain,因为target domain不存在标签。

现在我们假定已经准备好了数据和标签\({X^s,Y^s}\):

  • 首先经过F提取出来feature expression,\(F(X^s)\)
  • C网络生成分割的标签\(\hat{Y}^s\)
  • G网络重建图片\(\hat{X}^s\)

基于最近的相关的成功的研究,不再在G的输入中显式的concatenate一个随机变量,而是在Generator中使用dropout layer

3 损失

作者提出了很多的对抗损失:

  • 在一个domin内的损失有:

    • Discriminator损失,分辨src-real和src-fake;
    • Discriminator损失,分辨tgt-real和tgt-fake;
    • Generator损失,让fake source可以被discriminator判断成src-real的损失;
  • 在不同domain的损失:
    • F网络的损失,可以让fake source的输入被判断为real target;
    • F网络的损失,可以让fake target的输入被判断为real source;

除了上面说到的对抗损失外,还有下面的分割损失

  • \(L_{seg}\):在标准分割网络C中的pixel-wise的交叉熵损失;
  • \(L_{aux}\):D网络也会输出一个分割结果,交叉熵损失;
  • \(L_{rec}\):原始图像和重建图像之间的L1损失。

4 训练过程

在每一个iteration中,一个随机的三元组被输入到模型中:\(\{X^s,Y^s,X^t\}\),然后网络按照下面的顺序进行更新参数:

  1. 先更新参数D,更新策略如下:

    • 对于source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
    • 对于target input,用\(L^t_{adv,D}\)

  1. 然后更新G,更新策略如下:

    • 愚弄discriminator的两个loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
    • 重建损失,\(L^s_{rec}\)和\(L^t_{rec}\);

  1. F网络的更新策略如下:

    • F网络的更新是最关键的!(论文中说的)

- 更新F网络是为了实现domain adaptive,$L^s_{adv,F}$是为了混淆fake source 和real target;
- 类似于G-D之间的min-max game,这里是F和D之间的竞争,只不过前者是为了混淆fake和real,后者是为了混淆source domain和target domain;

5 D的设计动机

我们可以发现,这里面的D其实不是传统的GAN中的D,输出不再是单独的一个scalar,表示图片是fake or real的概率

最近有一篇GAN里面提到了,patch discriminator(这个论文恰好之前读过),这个是让D输出的也是一个二位的量,每一个值表示对应patch的fake or real的概率,这个措施极大的提高了G重建的图片的质量,这里继承延伸了patch discriminator的思想,输出的图片是一个pixel-wise的类似分割的结果,每一个像素有四个类别:fake-src,real-src,fake-tgt,real-tgt;

GAN一般是比较难训练的,尤其是针对大尺度的真实图片数据,一种稳定的方法来训练生成模型的架构是Auxiliary Classifier GAN(ACGAN)(真好,这个论文我之前也看过),简单的说通过增加一个辅助分类损失,可以训练一个更稳定的G,因此这也是为什么D中还会有一个分割损失\(L_{aux}\)

6 总结

作者提高,每一个组件都提供了关键的信息,不多说了,假期回实验室我要开始用章法组合拳来解决问题了。

域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018的更多相关文章

  1. In machine learning, is more data always better than better algorithms?

    In machine learning, is more data always better than better algorithms? No. There are times when mor ...

  2. 多标记学习--Learning from Multi-Label Data

    传统分类问题,即多类分类问题是,假设每个示例仅具有单个标记,且所有样本的标签类别数|L|大于1,然而,在很多现实世界的应用中,往往存在单个示例同时具有多重标记的情况. 而在多分类问题中,每个样本所含标 ...

  3. Coursera, Big Data 4, Machine Learning With Big Data (week 1/2)

    Week 1 Machine Learning with Big Data KNime - GUI based Spark MLlib - inside Spark CRISP-DM Week 2, ...

  4. 不平衡学习 Learning from Imbalanced Data

    问题: ICC警情数据分类不均,30+分类,最多的分类数据数量1w+条,只有10个类别数量超过1k,大部分分类数量少于100条. 解决办法: 下采样:通过非监督学习,找出每个分类中的异常点,减少数据. ...

  5. R8:Learning paths for Data Science[continuous updating…]

    Comprehensive learning path – Data Science in Python Journey from a Python noob to a Kaggler on Pyth ...

  6. 使用ADMT和PES实现window AD账户跨域迁移-介绍篇

    使用 ADMT 和 pwdmig 实现 window AD 账户跨域迁移系列: 介绍篇 ADMT 安装 PES 的安装 ADMT:迁移组 ADMT:迁移用户 ADMT:计算机迁移 ADMT:报告生成 ...

  7. A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)

    A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...

  8. Overcoming Forgetting in Federated Learning on Non-IID Data

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! 以下是对本文关键部分的摘抄翻译,详情请参见原文. NeurIPS 2019 Workshop on Federated Learning ...

  9. mysql迁移-----拷贝mysql目录/load data/mysqldump/into outfile

    摘要:本文简单介绍了mysql的三种备份,并解答了有一些实际备份中会遇到的问题.备份恢复有三种(除了用从库做备份之外), 直接拷贝文件,load data 和 mysqldump命令.少量数据使用my ...

随机推荐

  1. java与freemarker遍历map

    一.java遍历MAP /** * 1.把key放到一个集合里,遍历key值同时根据key得到值 (推荐) */ Set set =map.keySet(); Iterator it=set.iter ...

  2. 在C#的WPF程序使用XAML实现画线

    在WPF中画直线.新建WPF应用程序,使用XAML画直线.使用X1.Y1两个属性可以设置直线的起点坐标,X2.Y2两个属性则可以设置直线的终点坐标.控制起点/终点坐标就可以实现平行.交错等效果.Str ...

  3. 微信小程序折线图表折线图加区域图

    1.先来个效果图 这里我用的是插件@antv/f2-canvas(安装的方法如下) npm init 此处如果直接使用官方npm install 可能会出现没有node_modules错误 npm i ...

  4. 【转载】几张图轻松理解String.intern()

    出处:https://blog.csdn.net/soonfly/article/details/70147205 在翻<深入理解Java虚拟机>的书时,又看到了2-7的 String.i ...

  5. pytorch(16)损失函数(二)

    5和6是在数据回归中用的较多的损失函数 5. nn.L1Loss 功能:计算inputs与target之差的绝对值 代码: nn.L1Loss(reduction='mean') 公式: \[l_n ...

  6. C语言中字符串详解

    C语言中字符串详解 字符串时是C语言中非常重要的部分,我们从字符串的性质和字符串的创建.程序中字符串的输入输出和字符串的操作来对字符串进行详细的解析. 什么是字符串? C语言本身没有内置的字符串类型, ...

  7. git clone 提速

    将类似于 git clone https://github.com/graykode/nlp-tutorial 的命令改成 https://github.com.cnpmjs.org/graykode ...

  8. windows 下使用vargant 搭建虚拟机服务

    使用vagrant 下载 vagrant[https://www.vagrantup.com/downloads.html] 下载管理工具VirtualBox[https://www.virtualb ...

  9. 卷积神经网络学习笔记——轻量化网络MobileNet系列(V1,V2,V3)

    完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和Mo ...

  10. JavaCV 树莓派打造监控系统平台

    使用树莓派搭建视频监控平台去年就简单的实现了,只不过功能比较简陋,最近抽时间重构了原来的平台. 环境搭建 环境部分,参考旧版的安装及配置: 树莓派搭建视频监控平台 树莓派视频监控平台实现录制归档 框架 ...