• 文章转自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄(已授权)
  • 联系方式:微信cyx645016617
  • 论文名称:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”

「前言」:最近好久没更新公众号了,我一不小心陷入了一个误区:我以为自己看的文章足够多了,用之前的风格迁移和GAN的知识来解决一个domain adaptive的问题,一顿乱拳并没有打死老师傅,反而自己累个够呛。然后找到这样一篇不错的DA framework,来认真学习一下章法,假期结束重新用章法组合拳再来会会。

0 综述

不同于以往的对抗模型或者是超像素信息来实现这个领域迁移,本文使用的是对抗生成网络GAN来将两个领域的特征空间拉近。

本文提出的是语义分割的领域自适应算法。论文特别关注的问题是:目标领域没有label

传统的DA方法包含最小化某些可以衡量source和target两个分布的距离函数。两种常见的度量是:

  • 最大均值差(Maximum Mean Discrepancy, MMD)
  • 通过对抗学习,使用DCNN来学习distance metric

本文的主要贡献在于提出了一种基于生成模型的特征空间源分布与目标分布对齐算法。

1 method

从图片中来初步判断,其实是比较好理解的:

  • 首先,我猜测其做域迁移,可能是仿照GAN领域中做风格迁移的办法;
  • 图片中总共有4个网络,F网络应该是特征提取网络,C网络是做分割的网络,G网络是把F提取的特征再还原成原图的网络,D网络是做分类的网络,和一般GAN不同的是,D中做四个分类,是True source,True target, False source, False targe. 类似于把cycleGAN中的两个二分类的discriminator合并了。

2 细节

原始图片定义为\(X\),source domain的图片定义为\(X^s\),target domain的图片定义为\(X^t\).

  • base network. 架构类似于预训练的VGG16,被分成了两个部分:特征提取部分叫做F网络,做像素分割的叫做C网络。
  • G网络是用来从F生成的embedding特征中,重建原始图像的;D网络不仅要分别出图片是否是real or fake,还会做一个分割任务,类似于C网络。这个分割任务仅仅针对source domain,因为target domain不存在标签。

现在我们假定已经准备好了数据和标签\({X^s,Y^s}\):

  • 首先经过F提取出来feature expression,\(F(X^s)\)
  • C网络生成分割的标签\(\hat{Y}^s\)
  • G网络重建图片\(\hat{X}^s\)

基于最近的相关的成功的研究,不再在G的输入中显式的concatenate一个随机变量,而是在Generator中使用dropout layer

3 损失

作者提出了很多的对抗损失:

  • 在一个domin内的损失有:

    • Discriminator损失,分辨src-real和src-fake;
    • Discriminator损失,分辨tgt-real和tgt-fake;
    • Generator损失,让fake source可以被discriminator判断成src-real的损失;
  • 在不同domain的损失:
    • F网络的损失,可以让fake source的输入被判断为real target;
    • F网络的损失,可以让fake target的输入被判断为real source;

除了上面说到的对抗损失外,还有下面的分割损失

  • \(L_{seg}\):在标准分割网络C中的pixel-wise的交叉熵损失;
  • \(L_{aux}\):D网络也会输出一个分割结果,交叉熵损失;
  • \(L_{rec}\):原始图像和重建图像之间的L1损失。

4 训练过程

在每一个iteration中,一个随机的三元组被输入到模型中:\(\{X^s,Y^s,X^t\}\),然后网络按照下面的顺序进行更新参数:

  1. 先更新参数D,更新策略如下:

    • 对于source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
    • 对于target input,用\(L^t_{adv,D}\)

  1. 然后更新G,更新策略如下:

    • 愚弄discriminator的两个loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
    • 重建损失,\(L^s_{rec}\)和\(L^t_{rec}\);

  1. F网络的更新策略如下:

    • F网络的更新是最关键的!(论文中说的)

- 更新F网络是为了实现domain adaptive,$L^s_{adv,F}$是为了混淆fake source 和real target;
- 类似于G-D之间的min-max game,这里是F和D之间的竞争,只不过前者是为了混淆fake和real,后者是为了混淆source domain和target domain;

5 D的设计动机

我们可以发现,这里面的D其实不是传统的GAN中的D,输出不再是单独的一个scalar,表示图片是fake or real的概率

最近有一篇GAN里面提到了,patch discriminator(这个论文恰好之前读过),这个是让D输出的也是一个二位的量,每一个值表示对应patch的fake or real的概率,这个措施极大的提高了G重建的图片的质量,这里继承延伸了patch discriminator的思想,输出的图片是一个pixel-wise的类似分割的结果,每一个像素有四个类别:fake-src,real-src,fake-tgt,real-tgt;

GAN一般是比较难训练的,尤其是针对大尺度的真实图片数据,一种稳定的方法来训练生成模型的架构是Auxiliary Classifier GAN(ACGAN)(真好,这个论文我之前也看过),简单的说通过增加一个辅助分类损失,可以训练一个更稳定的G,因此这也是为什么D中还会有一个分割损失\(L_{aux}\)

6 总结

作者提高,每一个组件都提供了关键的信息,不多说了,假期回实验室我要开始用章法组合拳来解决问题了。

域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018的更多相关文章

  1. In machine learning, is more data always better than better algorithms?

    In machine learning, is more data always better than better algorithms? No. There are times when mor ...

  2. 多标记学习--Learning from Multi-Label Data

    传统分类问题,即多类分类问题是,假设每个示例仅具有单个标记,且所有样本的标签类别数|L|大于1,然而,在很多现实世界的应用中,往往存在单个示例同时具有多重标记的情况. 而在多分类问题中,每个样本所含标 ...

  3. Coursera, Big Data 4, Machine Learning With Big Data (week 1/2)

    Week 1 Machine Learning with Big Data KNime - GUI based Spark MLlib - inside Spark CRISP-DM Week 2, ...

  4. 不平衡学习 Learning from Imbalanced Data

    问题: ICC警情数据分类不均,30+分类,最多的分类数据数量1w+条,只有10个类别数量超过1k,大部分分类数量少于100条. 解决办法: 下采样:通过非监督学习,找出每个分类中的异常点,减少数据. ...

  5. R8:Learning paths for Data Science[continuous updating…]

    Comprehensive learning path – Data Science in Python Journey from a Python noob to a Kaggler on Pyth ...

  6. 使用ADMT和PES实现window AD账户跨域迁移-介绍篇

    使用 ADMT 和 pwdmig 实现 window AD 账户跨域迁移系列: 介绍篇 ADMT 安装 PES 的安装 ADMT:迁移组 ADMT:迁移用户 ADMT:计算机迁移 ADMT:报告生成 ...

  7. A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)

    A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...

  8. Overcoming Forgetting in Federated Learning on Non-IID Data

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! 以下是对本文关键部分的摘抄翻译,详情请参见原文. NeurIPS 2019 Workshop on Federated Learning ...

  9. mysql迁移-----拷贝mysql目录/load data/mysqldump/into outfile

    摘要:本文简单介绍了mysql的三种备份,并解答了有一些实际备份中会遇到的问题.备份恢复有三种(除了用从库做备份之外), 直接拷贝文件,load data 和 mysqldump命令.少量数据使用my ...

随机推荐

  1. MySql-Day-01

    MySql 能够理解数据库的概念 能够安装MySQL数据库 能够启动,关闭及登录MySQL 能够使用SQL语句操作数据库 能够使用SQL语句操作表结构 能够使用SQL语句进行数据的添加修改和删除的操作 ...

  2. HDOJ-6628(dfs+第k字典序最小差异序列)

    permutation 1 HDOJ-6628 这题使用的暴力深搜,在dfs里面直接从最小的差异开始枚举 注意这里的pre记录前一个数,并且最后答案需要减去排列中最小的数再加一 这里有一个技巧关于求第 ...

  3. 盘点Excel中的那些有趣的“bug”

    本文由葡萄城技术团队原创并首发 转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具.解决方案和服务,赋能开发者. Excel 1.0早在1985年正式进入市场,距今已经有36年了,虽然在推出 ...

  4. MySQL 40题练习题和答案

    2.查询"生物"课程比"物理"课程成绩高的所有学生的学号: 思路:    获取所有有生物课程的人(学号,成绩) - 临时表    获取所有有物理课程的人(学号, ...

  5. 测试平台系列(4) 使用Flask蓝图(blueprint)

    使用Flask蓝图(blueprint) 回顾 先来看一下上一篇的作业吧,使用「logbook」的时候,遇到了时区不对的情况.那么我们怎么去解决这个问题呢? 实际上logbook默认采用的是世界标准时 ...

  6. CVE-2017-7504-JBoss JMXInvokerServlet 反序列化

    漏洞分析 https://paper.seebug.org/312/ 漏洞原理 这是经典的JBoss反序列化漏洞,JBoss在/invoker/JMXInvokerServlet请求中读取了用户传入的 ...

  7. pandas函数的使用

    一.Pandas的数据结构 1.Series Series是一种类似与一维数组的对象,由下面两个部分组成: values:一组数据(ndarray类型) index:相关的数据索引标签 1)Serie ...

  8. [SPOJ2021] Moving Pebbles

    [SPOJ2021] Moving Pebbles 题目大意:给你\(N\)堆\(Stone\),两个人玩游戏. 每次任选一堆,首先拿掉至少一个石头,然后移动任意个石子到任意堆中. 谁不能移动了,谁就 ...

  9. teprunner测试平台部署到Linux系统Docker

    本文是一篇过渡,在进行用例管理模块开发之前,有必要把入门篇开发完成的代码部署到Linux系统Docker中,把部署流程走一遍,这个过程对后端设计有决定性影响. 本地运行 通过在Vue项目执行npm r ...

  10. 微信二维码引擎OpenCV开源研究

    <微信二维码引擎OpenCV开源研究> 一.编译和Test测试        opencv_wechat_qrcode的编译需要同时下载opencv(https://github.com/ ...