• 文章转自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄(已授权)
  • 联系方式:微信cyx645016617
  • 论文名称:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”

「前言」:最近好久没更新公众号了,我一不小心陷入了一个误区:我以为自己看的文章足够多了,用之前的风格迁移和GAN的知识来解决一个domain adaptive的问题,一顿乱拳并没有打死老师傅,反而自己累个够呛。然后找到这样一篇不错的DA framework,来认真学习一下章法,假期结束重新用章法组合拳再来会会。

0 综述

不同于以往的对抗模型或者是超像素信息来实现这个领域迁移,本文使用的是对抗生成网络GAN来将两个领域的特征空间拉近。

本文提出的是语义分割的领域自适应算法。论文特别关注的问题是:目标领域没有label

传统的DA方法包含最小化某些可以衡量source和target两个分布的距离函数。两种常见的度量是:

  • 最大均值差(Maximum Mean Discrepancy, MMD)
  • 通过对抗学习,使用DCNN来学习distance metric

本文的主要贡献在于提出了一种基于生成模型的特征空间源分布与目标分布对齐算法。

1 method

从图片中来初步判断,其实是比较好理解的:

  • 首先,我猜测其做域迁移,可能是仿照GAN领域中做风格迁移的办法;
  • 图片中总共有4个网络,F网络应该是特征提取网络,C网络是做分割的网络,G网络是把F提取的特征再还原成原图的网络,D网络是做分类的网络,和一般GAN不同的是,D中做四个分类,是True source,True target, False source, False targe. 类似于把cycleGAN中的两个二分类的discriminator合并了。

2 细节

原始图片定义为\(X\),source domain的图片定义为\(X^s\),target domain的图片定义为\(X^t\).

  • base network. 架构类似于预训练的VGG16,被分成了两个部分:特征提取部分叫做F网络,做像素分割的叫做C网络。
  • G网络是用来从F生成的embedding特征中,重建原始图像的;D网络不仅要分别出图片是否是real or fake,还会做一个分割任务,类似于C网络。这个分割任务仅仅针对source domain,因为target domain不存在标签。

现在我们假定已经准备好了数据和标签\({X^s,Y^s}\):

  • 首先经过F提取出来feature expression,\(F(X^s)\)
  • C网络生成分割的标签\(\hat{Y}^s\)
  • G网络重建图片\(\hat{X}^s\)

基于最近的相关的成功的研究,不再在G的输入中显式的concatenate一个随机变量,而是在Generator中使用dropout layer

3 损失

作者提出了很多的对抗损失:

  • 在一个domin内的损失有:

    • Discriminator损失,分辨src-real和src-fake;
    • Discriminator损失,分辨tgt-real和tgt-fake;
    • Generator损失,让fake source可以被discriminator判断成src-real的损失;
  • 在不同domain的损失:
    • F网络的损失,可以让fake source的输入被判断为real target;
    • F网络的损失,可以让fake target的输入被判断为real source;

除了上面说到的对抗损失外,还有下面的分割损失

  • \(L_{seg}\):在标准分割网络C中的pixel-wise的交叉熵损失;
  • \(L_{aux}\):D网络也会输出一个分割结果,交叉熵损失;
  • \(L_{rec}\):原始图像和重建图像之间的L1损失。

4 训练过程

在每一个iteration中,一个随机的三元组被输入到模型中:\(\{X^s,Y^s,X^t\}\),然后网络按照下面的顺序进行更新参数:

  1. 先更新参数D,更新策略如下:

    • 对于source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
    • 对于target input,用\(L^t_{adv,D}\)

  1. 然后更新G,更新策略如下:

    • 愚弄discriminator的两个loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
    • 重建损失,\(L^s_{rec}\)和\(L^t_{rec}\);

  1. F网络的更新策略如下:

    • F网络的更新是最关键的!(论文中说的)

- 更新F网络是为了实现domain adaptive,$L^s_{adv,F}$是为了混淆fake source 和real target;
- 类似于G-D之间的min-max game,这里是F和D之间的竞争,只不过前者是为了混淆fake和real,后者是为了混淆source domain和target domain;

5 D的设计动机

我们可以发现,这里面的D其实不是传统的GAN中的D,输出不再是单独的一个scalar,表示图片是fake or real的概率

最近有一篇GAN里面提到了,patch discriminator(这个论文恰好之前读过),这个是让D输出的也是一个二位的量,每一个值表示对应patch的fake or real的概率,这个措施极大的提高了G重建的图片的质量,这里继承延伸了patch discriminator的思想,输出的图片是一个pixel-wise的类似分割的结果,每一个像素有四个类别:fake-src,real-src,fake-tgt,real-tgt;

GAN一般是比较难训练的,尤其是针对大尺度的真实图片数据,一种稳定的方法来训练生成模型的架构是Auxiliary Classifier GAN(ACGAN)(真好,这个论文我之前也看过),简单的说通过增加一个辅助分类损失,可以训练一个更稳定的G,因此这也是为什么D中还会有一个分割损失\(L_{aux}\)

6 总结

作者提高,每一个组件都提供了关键的信息,不多说了,假期回实验室我要开始用章法组合拳来解决问题了。

域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018的更多相关文章

  1. In machine learning, is more data always better than better algorithms?

    In machine learning, is more data always better than better algorithms? No. There are times when mor ...

  2. 多标记学习--Learning from Multi-Label Data

    传统分类问题,即多类分类问题是,假设每个示例仅具有单个标记,且所有样本的标签类别数|L|大于1,然而,在很多现实世界的应用中,往往存在单个示例同时具有多重标记的情况. 而在多分类问题中,每个样本所含标 ...

  3. Coursera, Big Data 4, Machine Learning With Big Data (week 1/2)

    Week 1 Machine Learning with Big Data KNime - GUI based Spark MLlib - inside Spark CRISP-DM Week 2, ...

  4. 不平衡学习 Learning from Imbalanced Data

    问题: ICC警情数据分类不均,30+分类,最多的分类数据数量1w+条,只有10个类别数量超过1k,大部分分类数量少于100条. 解决办法: 下采样:通过非监督学习,找出每个分类中的异常点,减少数据. ...

  5. R8:Learning paths for Data Science[continuous updating…]

    Comprehensive learning path – Data Science in Python Journey from a Python noob to a Kaggler on Pyth ...

  6. 使用ADMT和PES实现window AD账户跨域迁移-介绍篇

    使用 ADMT 和 pwdmig 实现 window AD 账户跨域迁移系列: 介绍篇 ADMT 安装 PES 的安装 ADMT:迁移组 ADMT:迁移用户 ADMT:计算机迁移 ADMT:报告生成 ...

  7. A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)

    A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...

  8. Overcoming Forgetting in Federated Learning on Non-IID Data

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! 以下是对本文关键部分的摘抄翻译,详情请参见原文. NeurIPS 2019 Workshop on Federated Learning ...

  9. mysql迁移-----拷贝mysql目录/load data/mysqldump/into outfile

    摘要:本文简单介绍了mysql的三种备份,并解答了有一些实际备份中会遇到的问题.备份恢复有三种(除了用从库做备份之外), 直接拷贝文件,load data 和 mysqldump命令.少量数据使用my ...

随机推荐

  1. 前端问题录——在导入模块时使用'@'时提示"Modile is not installed"

    前情提要 为了尽可能解决引用其他模块时路径过长的问题,通常会在 vue.config.js 文件中为 src 目录配置一个别名 '@' configureWebpack: { resolve: { a ...

  2. Redis缓存中的常见问题

    缓存穿透:是指查询一个Redis和数据库中都不存在的数据. 问题:查询一个Redis和数据库中都不存在的数据,大量请求去访问数据库,导致数据库宕机. 解决办法: 1.根据id查询,如果id是自增的,将 ...

  3. Elasticsearch--Logstash定时同步MySQL数据到Elasticsearch

    新地址体验:http://www.zhouhong.icu/post/139 一.Logstash介绍 Logstash是elastic技术栈中的一个技术.它是一个数据采集引擎,可以从数据库采集数据到 ...

  4. Python高级——多任务编程之线程

    转: Python高级--多任务编程之线程 文章目录 线程概念 1. 线程的介绍 2. 线程的概念 3. 线程的作用 多线程的使用 1. 导入线程模块 2. 线程类Thread参数说明 3. 启动线程 ...

  5. Java的特性和优势以及不同版本的分类,jdk,jre,jvm的联系与区别,javadoc的生成

    Java 1.Java的特性和优势 Write Once,Run Anywhere 简单性 面向对象 可移植性 高性能 分布式 动态性 多线程 安全性 健壮性 2.Java的三大版本 JavaSE:标 ...

  6. WPF 基础 - 资源

    为了避免丢失和损坏,编译器允许我们把外部文件编译进程序主体.成为程序主体不可分割的一部分,这就是传统意义上的程序资源,即二进制资源: WPF 的四个等级资源: 数据库里的数据 (仓库) 资源文件 (行 ...

  7. 开源项目renren-fast-vue开发环境部署(前端部分)

    开源项目renren-fast-vue开发环境部署(前端部分) 说明:renren-fast是一个开源的基于springboot的前后端分离手脚架,当前版本是3.0 开发文档需要付费,官方的开发环境部 ...

  8. P1055_ISBN号码(JAVA语言)

    题目描述 每一本正式出版的图书都有一个ISBN号码与之对应,ISBN码包括9位数字.1位识别码和3位分隔符, 其规定格式如x-xxx-xxxxx-x,其中符号-就是分隔符(键盘上的减号), 最后一位是 ...

  9. 生成元(JAVA语言)

    package 第三章; import java.util.Scanner; public class 生成元 { public static void main(String[] args) { / ...

  10. [Kick Start] 2021 Round A

    题目:2021 Round-A . K-Goodness String 签到题,计算当前字符串的 K-Goodness Score ,然后与给出的 K 做差即可. #include <iostr ...