论文信息

论文标题:Generalized Domain Adaptation with Covariate and Label Shift CO-ALignment
论文作者:Shuhan Tan, Xingchao Peng, Kate Saenko
论文来源:ICLR 2020
论文地址:download 
论文代码:download
视屏讲解:click

1 摘要

  提出问题:在已知类的基础上发现新类;

2 介绍

2.1 当前工作

  假设条件标签分布不变 $p(y \mid x)=q(y \mid x)$,只有特征偏移 $p(x) \neq q(x)$,忽略标签偏移 $p(y) \neq q(y)$。

  假设不成立的原因:

    • 场景不同,标签跨域转移 $p(y) \neq q(y)$ 很常见;
    • 如果存在标签偏移,则当前的 UDA 工作性能显著下降;
    • 一个合适的 UDA 方法应该能同时处理协变量偏移和标签偏移;

2.2 本文工作

  本文提出类不平衡域适应 (CDA),需要同时处理 条件特征转移 和 标签转移。

  具体来说,除了协变量偏移假设 $p(x) \neq   q(x)$, $p(y \mid x)=q(y \mid x)$,进一步假设 $p(x \mid y) \neq q(x \mid y)$ 和 $p(y) \neq q(y)$。

  CDA 的主要挑战:

    • 标签偏移阻碍了主流领域自适应方法的有效性,这些方法只能边缘对齐特征分布;
    • 在存在标签偏移的情况下,对齐条件特征分布 $p(x \mid y)$, $q(x \mid y)$ 很困难;
    • 当一个或两个域中的数据在不同类别中分布不均时,很难训练无偏分类器;

  CDA 概述:

  

3 问题定义

  In Class-imbalanced Domain Adaptation, we are given a source domain  $\mathcal{D}_{\mathcal{S}}=   \left\{\left(x_{i}^{s}, y_{i}^{s}\right)_{i=1}^{N_{s}}\right\}$  with  $N_{s}$  labeled examples, and a target domain  $\mathcal{D}_{\mathcal{T}}=\left\{\left(x_{i}^{t}\right)_{i=1}^{N_{t}}\right\}$  with  $N_{t}$  unlabeled examples. We assume that  $p(y \mid x)=q(y \mid x)$  but  $p(x \mid y) \neq   q(x \mid y)$, $p(x) \neq q(x)$ , and  $p(y) \neq q(y)$ . We aim to construct an end-to-end deep neural network which is able to transfer the knowledge learned from  $\mathcal{D}_{\mathcal{S}}$  to  $\mathcal{D}_{\mathcal{T}}$ , and train a classifier  $y=\theta(x)$  which can minimize task risk in target domain  $\epsilon_{T}(\theta)=\operatorname{Pr}_{(x, y) \sim q}[\theta(x) \neq y]$.

4 方法

4.1 整体框架

  

4.2 用于特征转移的基于原型的条件对齐

  目的:对齐 $p(x \mid y)$ 和 $q(x \mid y)$

  步骤:首先使用原型分类器(基于相似度)估计 $p(x \mid y)$ ,然后使用一种 $\text{minimax entropy}$ 算法将其和 $q(x \mid y)$ 对齐;

4.2.1 原型分类器

  原因:基于原型的分类器在少样本学习设置中表现良好,因为在标签偏移的假设下中,某些类别的设置频率可能较低;

# 深层原型分类器
class Predictor_deep_latent(nn.Module):
def __init__(self, in_dim = 1208, num_class = 2, temp = 0.05):
super(Predictor_deep_latent, self).__init__()
self.in_dim = in_dim
self.hid_dim = 512
self.num_class = num_class
self.temp = temp #0.05 self.fc1 = nn.Linear(self.in_dim, self.hid_dim)
self.fc2 = nn.Linear(self.hid_dim, num_class, bias=False) def forward(self, x, reverse=False, eta=0.1):
x = self.fc1(x)
if reverse:
x = GradReverse.apply(x, eta)
feat = F.normalize(x)
logit = self.fc2(feat) / self.temp
return feat, logit

  源域上的样本使用交叉熵做监督训练:

    $\mathcal{L}_{S C}=\mathbb{E}_{(x, y) \in \mathcal{D}_{S}} \mathcal{L}_{c e}(h(x), y)  \quad \quad \quad(1)$

  样本 $x$ 被分类为 $i$ 类的置信度越高,$x$ 的嵌入越接近 $w_i$。因此,在优化上式时,通过将每个样本 $x$ 的嵌入更接近其在 $W$ 中的相应权重向量来减少类内变化。所以,可以将 $w_i$ 视为 $p$ 的代表性数据点(原型) $p(x \mid y=i)$ 。

4.2.2 通过 Minimax Entropy 实现条件对齐

  目标域缺少数据标签,所以使用 $\text{Eq.1}$ 获得类原型是不可行的;

  解决办法:

    • 将每个源原型移动到更接近其附近的目标样本;
    • 围绕这个移动的原型聚类目标样本;

  因此,提出 熵极小极大 实现上述两个目标。

  具体来说,对于输入网络的每个样本 $x^{t} \in \mathcal{D}_{\mathcal{T}}$,可以通过下式计算分类器输出的平均熵

    $\mathcal{L}_{H}=\mathbb{E}_{x \in \mathcal{D}_{\mathcal{T}}} H(x)=-\mathbb{E}_{x \in \mathcal{D}_{\mathcal{T}}} \sum_{i=1}^{c} h_{i}(x) \log h_{i}(x)\quad \quad \quad(2)$

  通过在对抗过程中对齐源原型和目标原型来实现条件特征分布对齐:

    • 训练 $C$ 以最大化 $\mathcal{L}_{H}$ ,旨在将原型从源样本移动到邻近的目标样本;
    • 训练 $F$ 来最小化 $\mathcal{L}_{H}$,目的是使目标样本的嵌入更接近它们附近的原型;

4.3 标签转移的类平衡自训练

  由于源标签分布 $p(y)$ 与目标标签分布 $q(y)$ 不同,因此不能保证在 $\mathcal{D}_{\mathcal{S}}$ 上具有低风险的分类器 $C$ 在 $\mathcal{D}_{\mathcal{T}}$ 上具有低错误。 直观地说,如果分类器是用不平衡的源数据训练的,决策边界将由训练数据中最频繁的类别主导,导致分类器偏向源标签分布。 当分类器应用于具有不同标签分布的目标域时,其准确性会降低,因为它高度偏向源域。

  为解决这个问题,本文使用[19]中的方法进行自我训练来估计目标标签分布并细化决策边界。自训练为了细化决策边界,本文建议通过自训练来估计目标标签分布。 我们根据分类器 $C$ 的输出将伪标签 $y$ 分配给所有目标样本。由于还对齐条件特征分布 $p(x \mid y$ 和 $q(x \mid y)$,假设分布高置信度伪标签 $q(y)$ 可以用作目标域的真实标签分布 $q(y)$ 的近似值。 在近似的目标标签分布下用这些伪标记的目标样本训练 $C$,能够减少标签偏移的负面影响。

  为了获得高置信度的伪标签,对于每个类别,本文选择属于该类别的具有最高置信度分数的目标样本的前 $k%$。利用 $h(x)$ 中的最高概率作为分类器对样本 $x$ 的置信度。 具体来说,对于每个伪标记样本 $(x, y)$,如果 $h(x)$ 位于具有相同伪标签的所有目标样本的前 $k%$ 中,将其选择掩码设置为 $m = 1$,否则 $m = 0 $。将伪标记目标集表示为 $\hat{\mathcal{D}}_{T}=\left\{\left(x_{i}^{t}, \hat{y}_{i}^{t}, m_{i}\right)_{i=1}^{N_{t}}\right\}$,利用来自 $\hat{\mathcal{D}}_{T}$ 的输入和伪标签来训练分类器 $C$,旨在细化决策 与目标标签分布的边界。 分类的总损失函数为:

    $\mathcal{L}_{S T}=\mathcal{L}_{S C}+\mathbb{E}_{(x, \hat{y}, m) \in \hat{\mathcal{D}}_{T}} \mathcal{L}_{c e}(h(x), \hat{y}) \cdot m$

  通常,用 $k_{0}=5$ 初始化 $k$,并设置 $k_{\text {step }}=5$,$k_{\max }=30$。

  Note:本文还对源域数据使用了平衡采样的方法,使得分类器不会偏向于某一类。

4.4 训练目标

  总体目标:

    $\begin{array}{l}\hat{C}=\underset{C}{\arg \min } \mathcal{L}_{S T}-\alpha \mathcal{L}_{H} \\\hat{F}=\underset{F}{\arg \min } \mathcal{L}_{S T}+\alpha \mathcal{L}_{H}\end{array}$

5 总结

  略

迁移学习(COAL)《Generalized Domain Adaptation with Covariate and Label Shift CO-ALignment》的更多相关文章

  1. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  2. Domain adaptation:连接机器学习(Machine Learning)与迁移学习(Transfer Learning)

    domain adaptation(域适配)是一个连接机器学习(machine learning)与迁移学习(transfer learning)的新领域.这一问题的提出在于从原始问题(对应一个 so ...

  3. 迁移学习(IIMT)——《Improve Unsupervised Domain Adaptation with Mixup Training》

    论文信息 论文标题:Improve Unsupervised Domain Adaptation with Mixup Training论文作者:Shen Yan, Huan Song, Nanxia ...

  4. 迁移学习(JDDA) 《Joint domain alignment and discriminative feature learning for unsupervised deep domain adaptation》

    论文信息 论文标题:Joint domain alignment and discriminative feature learning for unsupervised deep domain ad ...

  5. 迁移学习(ADDA)《Adversarial Discriminative Domain Adaptation》

    论文信息 论文标题:Adversarial Discriminative Domain Adaptation论文作者:Eric Tzeng, Judy Hoffman, Kate Saenko, Tr ...

  6. 论文阅读 | A Curriculum Domain Adaptation Approach to the Semantic Segmentation of Urban Scenes

    paper链接:https://arxiv.org/pdf/1812.09953.pdf code链接:https://github.com/YangZhang4065/AdaptationSeg 摘 ...

  7. 【论文笔记】Domain Adaptation via Transfer Component Analysis

    论文题目:<Domain Adaptation via Transfer Component Analysis> 论文作者:Sinno Jialin Pan, Ivor W. Tsang, ...

  8. 域适应(Domain adaptation)

    定义 在迁移学习中, 当源域和目标的数据分布不同 ,但两个任务相同时,这种 特殊 的迁移学习 叫做域适应 (Domain Adaptation). Domain adaptation有哪些实现手段呢? ...

  9. Deep Transfer Network: Unsupervised Domain Adaptation

    转自:http://blog.csdn.net/mao_xiao_feng/article/details/54426101 一.Domain adaptation 在开始介绍之前,首先我们需要知道D ...

  10. Domain Adaptation论文笔记

    领域自适应问题一般有两个域,一个是源域,一个是目标域,领域自适应可利用来自源域的带标签的数据(源域中有大量带标签的数据)来帮助学习目标域中的网络参数(目标域中很少甚至没有带标签的数据).领域自适应如今 ...

随机推荐

  1. 源代码管理工具介绍(以GITHUB为例)

    Github:全球最大的社交编程及代码托管网站,可以托管各种git库,并提供一个web界面 1.基本概念 仓库(Repository):用来存放项目代码,每个项目对应一个仓库,多个开源项目则有多个仓库 ...

  2. C#学习之详解C#Break ,Continue, Return

    C#编程语法中break ,continue, return这三个常用的关键字的学习对于我们编程开发是十分有用的,那么本文就向你介绍break ,continue, return具体的语法规范. C# ...

  3. jvm垃圾收集器汇总

    1.吞吐量和延时 吞吐量:吞吐量指的是cpu的利用时间,计算公式是 运行用户代码时间  / (用户代码时间 + 垃圾收集时间),吞吐量越大说明cpu的利用率越大. 延时:延时指的是停顿时间,用户代码不 ...

  4. BeanFactory与FactoryBean区别

    1. BeanFactory BeanFactory,以Factory结尾,表示它是一个工厂类(接口),用于管理Bean的一个工厂.在Spring中,BeanFactory是IOC容器的核心接口,也是 ...

  5. uglifyjs-webpack-plugin配置

    项目使用vuecli3搭建,在vue.config.js文件中进行配置,主要配置了去除线上环境的打印信息. 首先安装插件, 执行命令 npm install uglifyjs-webpack-plug ...

  6. Linux部署JDK教程

    上一次说了windows下的jdk部署,这一次记录下Linux下的jdk部署,恰巧遇到一篇写的很清楚的教程,我就直接转过来啦,哈哈.. 一. 解压安装jdk 在shell终端下进入jdk-6u14-l ...

  7. [极客大挑战 2019]PHP 1

    进入后提示我们网页有备份文件,这边使用爆破工具,网页会down掉 随便随便猜了一下www.zip,成功下载源码 常见的网页备份有 .git ~ .swp .swo .bak .zip 还不知道是什么题 ...

  8. 大数据 Hadoop 的五大优势

    Hadoop与竞争对手相比有哪些优势? 到目前为止,人们可能已经听说过ApacheHadoop.这个名字来源于一只可爱的玩具大象,但Hadoop只不过是一个毛绒玩具.Hadoop是一个开源软件项目,它 ...

  9. Netty ByteBuf 详解

    ByteBuf类:Netty的数据容器 ByteBuf 维护了两个不同的索引:① readerIndex:用于读取:② writerIndex:用于写入:起始位置都从0开始:​名称以 read或者 w ...

  10. fork语句遇见for循环语句

    一.没有automatic的fork-join_none 通常小白会这么写: 代码如下: foreach(a[i]) begin fork repeat(a[i]) #1ns; $display(&q ...