域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018
- 文章转自:微信公众号「机器学习炼丹术」
- 作者:炼丹兄(已授权)
- 联系方式:微信cyx645016617
- 论文名称:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”
「前言」:最近好久没更新公众号了,我一不小心陷入了一个误区:我以为自己看的文章足够多了,用之前的风格迁移和GAN的知识来解决一个domain adaptive的问题,一顿乱拳并没有打死老师傅,反而自己累个够呛。然后找到这样一篇不错的DA framework,来认真学习一下章法,假期结束重新用章法组合拳再来会会。
0 综述
不同于以往的对抗模型或者是超像素信息来实现这个领域迁移,本文使用的是对抗生成网络GAN来将两个领域的特征空间拉近。
本文提出的是语义分割的领域自适应算法。论文特别关注的问题是:目标领域没有label。
传统的DA方法包含最小化某些可以衡量source和target两个分布的距离函数。两种常见的度量是:
- 最大均值差(Maximum Mean Discrepancy, MMD)
- 通过对抗学习,使用DCNN来学习distance metric
本文的主要贡献在于提出了一种基于生成模型的特征空间源分布与目标分布对齐算法。
1 method
从图片中来初步判断,其实是比较好理解的:
- 首先,我猜测其做域迁移,可能是仿照GAN领域中做风格迁移的办法;
- 图片中总共有4个网络,F网络应该是特征提取网络,C网络是做分割的网络,G网络是把F提取的特征再还原成原图的网络,D网络是做分类的网络,和一般GAN不同的是,D中做四个分类,是True source,True target, False source, False targe. 类似于把cycleGAN中的两个二分类的discriminator合并了。
2 细节
原始图片定义为\(X\),source domain的图片定义为\(X^s\),target domain的图片定义为\(X^t\).
- base network. 架构类似于预训练的VGG16,被分成了两个部分:特征提取部分叫做F网络,做像素分割的叫做C网络。
- G网络是用来从F生成的embedding特征中,重建原始图像的;D网络不仅要分别出图片是否是real or fake,还会做一个分割任务,类似于C网络。这个分割任务仅仅针对source domain,因为target domain不存在标签。
现在我们假定已经准备好了数据和标签\({X^s,Y^s}\):
- 首先经过F提取出来feature expression,\(F(X^s)\)
- C网络生成分割的标签\(\hat{Y}^s\)
- G网络重建图片\(\hat{X}^s\)
基于最近的相关的成功的研究,不再在G的输入中显式的concatenate一个随机变量,而是在Generator中使用dropout layer
3 损失
作者提出了很多的对抗损失:
- 在一个domin内的损失有:
- Discriminator损失,分辨src-real和src-fake;
- Discriminator损失,分辨tgt-real和tgt-fake;
- Generator损失,让fake source可以被discriminator判断成src-real的损失;
- 在不同domain的损失:
- F网络的损失,可以让fake source的输入被判断为real target;
- F网络的损失,可以让fake target的输入被判断为real source;
除了上面说到的对抗损失外,还有下面的分割损失:
- \(L_{seg}\):在标准分割网络C中的pixel-wise的交叉熵损失;
- \(L_{aux}\):D网络也会输出一个分割结果,交叉熵损失;
- \(L_{rec}\):原始图像和重建图像之间的L1损失。
4 训练过程
在每一个iteration中,一个随机的三元组被输入到模型中:\(\{X^s,Y^s,X^t\}\),然后网络按照下面的顺序进行更新参数:
- 先更新参数D,更新策略如下:
- 对于source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
- 对于target input,用\(L^t_{adv,D}\)
- 然后更新G,更新策略如下:
- 愚弄discriminator的两个loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
- 重建损失,\(L^s_{rec}\)和\(L^t_{rec}\);
- F网络的更新策略如下:
- F网络的更新是最关键的!(论文中说的)
- 更新F网络是为了实现domain adaptive,$L^s_{adv,F}$是为了混淆fake source 和real target;
- 类似于G-D之间的min-max game,这里是F和D之间的竞争,只不过前者是为了混淆fake和real,后者是为了混淆source domain和target domain;
5 D的设计动机
我们可以发现,这里面的D其实不是传统的GAN中的D,输出不再是单独的一个scalar,表示图片是fake or real的概率
最近有一篇GAN里面提到了,patch discriminator(这个论文恰好之前读过),这个是让D输出的也是一个二位的量,每一个值表示对应patch的fake or real的概率,这个措施极大的提高了G重建的图片的质量,这里继承延伸了patch discriminator的思想,输出的图片是一个pixel-wise的类似分割的结果,每一个像素有四个类别:fake-src,real-src,fake-tgt,real-tgt;
GAN一般是比较难训练的,尤其是针对大尺度的真实图片数据,一种稳定的方法来训练生成模型的架构是Auxiliary Classifier GAN(ACGAN)(真好,这个论文我之前也看过),简单的说通过增加一个辅助分类损失,可以训练一个更稳定的G,因此这也是为什么D中还会有一个分割损失\(L_{aux}\)
6 总结
作者提高,每一个组件都提供了关键的信息,不多说了,假期回实验室我要开始用章法组合拳来解决问题了。
域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018的更多相关文章
- In machine learning, is more data always better than better algorithms?
In machine learning, is more data always better than better algorithms? No. There are times when mor ...
- 多标记学习--Learning from Multi-Label Data
传统分类问题,即多类分类问题是,假设每个示例仅具有单个标记,且所有样本的标签类别数|L|大于1,然而,在很多现实世界的应用中,往往存在单个示例同时具有多重标记的情况. 而在多分类问题中,每个样本所含标 ...
- Coursera, Big Data 4, Machine Learning With Big Data (week 1/2)
Week 1 Machine Learning with Big Data KNime - GUI based Spark MLlib - inside Spark CRISP-DM Week 2, ...
- 不平衡学习 Learning from Imbalanced Data
问题: ICC警情数据分类不均,30+分类,最多的分类数据数量1w+条,只有10个类别数量超过1k,大部分分类数量少于100条. 解决办法: 下采样:通过非监督学习,找出每个分类中的异常点,减少数据. ...
- R8:Learning paths for Data Science[continuous updating…]
Comprehensive learning path – Data Science in Python Journey from a Python noob to a Kaggler on Pyth ...
- 使用ADMT和PES实现window AD账户跨域迁移-介绍篇
使用 ADMT 和 pwdmig 实现 window AD 账户跨域迁移系列: 介绍篇 ADMT 安装 PES 的安装 ADMT:迁移组 ADMT:迁移用户 ADMT:计算机迁移 ADMT:报告生成 ...
- A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)
A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...
- Overcoming Forgetting in Federated Learning on Non-IID Data
郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! 以下是对本文关键部分的摘抄翻译,详情请参见原文. NeurIPS 2019 Workshop on Federated Learning ...
- mysql迁移-----拷贝mysql目录/load data/mysqldump/into outfile
摘要:本文简单介绍了mysql的三种备份,并解答了有一些实际备份中会遇到的问题.备份恢复有三种(除了用从库做备份之外), 直接拷贝文件,load data 和 mysqldump命令.少量数据使用my ...
随机推荐
- Hexo一键部署到阿里云OSS并设置浏览器缓存
自建博客地址:https://bytelife.net,欢迎访问! 本文为博客自动同步文章,为了更好的阅读体验,建议您移步至我的博客 本文作者: Jeffrey 本文链接: https://bytel ...
- 看完我的笔记不懂也会懂----ECMAscript 567
目录 ECMAscript 567 严格模式 字符串扩展 数值的扩展 Object对象方法扩展 数组的扩展 数组方法的扩展 bind.call.apply用法详解 let const 变量的解构赋值 ...
- hiho一下 第195周 奖券兑换[C solution][Accepted]
时间限制:20000ms 单点时限:1000ms 内存限制:256MB 描述 小Hi在游乐园中获得了M张奖券,这些奖券可以用来兑换奖品. 可供兑换的奖品一共有N件.第i件奖品需要Wi张奖券才能兑换到, ...
- Java 面向对象 01
面向对象·一级 面向对象思想概述 * A:面向过程思想概述 * 第一步 * 第二步 * B:面向对象思想概述 * 找对象(第一步,第二步) * C:举例 * 买煎饼果子 ...
- 04-Spring自定义标签解析
自定义标签的解析 这一篇主要说明自定义标签的解析流程,除了 bean.alias.import.beans之外的标签,都属于自定义标签的范围,自定义标签的解析需要命名空间配合, 获取对应的命名空间 根 ...
- Django 使用 pycharm 创建新的app(可以理解为模块)
创建工程的时候,注意选择Existing interpreter 选择对应的 python 解释器,电脑如果安装有多个版本的 Python 的话,注意python版本的问题, 以上即是创建的项目目录, ...
- 基础篇:java.security框架之签名、加密、摘要及证书
前言 和前端进行数据交互时或者和第三方商家对接时,需要对隐私数据进行加密.单向加密,对称加密,非对称加密,其对应的算法也各式各样.java提供了统一的框架来规范(java.security)安全加密这 ...
- 关于github的使用学习心得
先写先介绍一下如何用github上创建一个项目吧. 用户登录后的界面如上所示.右下角是我们已经建好的库.点击其中任何一个就可以查看相应的库了.如果要新建一个项目的话,就点击Start a projec ...
- Java中的Set集合
Set接口简介 Set接口和List接口一样,同样继承自Collection接口,它与Collection接口中的方法基本一致,并没有对Collection接口进行功能上的扩充,它是比Collecti ...
- mysql建表约束
--mysql建表约束--主键约束它能够唯一确定一张表中的内容,也就是我们通过某个字段添加约束,就可以是的该字段唯一(不重复)且不为空.create table user( id int pr ...