论文信息

论文标题:Adversarial Discriminative Domain Adaptation
论文作者:Eric Tzeng, Judy Hoffman, Kate Saenko, Trevor Darrell
论文来源:CVPR 2017
论文地址:download 
论文代码:download
引用次数:3257

1 简介

  本文主要探讨的是:源域和目标域特征提取器共享参数的必要性。

  源域和目标域特征提取器共享参数的代表——DANN。

2 对抗域适应

  标准监督损失训练源数据:

    $\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{t}\right)=  \mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{t}\right)}-\sum\limits _{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\quad\quad(1)$

  域对抗:首先使得域鉴别器分类准确,即最小化交叉熵损失 $\mathcal{L}_{\operatorname{adv}_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)$:

    $\begin{array}{l}\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)= -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right]\end{array} \quad\quad(2)$

  其次,源映射和目标映射根据一个受约束的对抗性目标进行优化(使得域鉴别器损失最大)。

  域对抗技术的通用公式如下:

    $\begin{array}{l}\underset{D}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right) \\\underset{M_{s}, M_{t}}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right) \\\text { s.t. } & \psi\left(M_{s}, M_{t}\right)\end{array}\quad\quad(3)$

2.1 源域和目标域映射

  

  归结为三个问题:

    • 选择生成式模型还是判别式模型?
    • 针对源域与目标域的映射是否共享参数?
    • 损失函数如何定义?

2.2 Adversarial losses

  回顾DANN 的训练方式:DANN 的梯度反转层优化映射,使鉴别器损失最大化

    $\mathcal{L}_{\text {adv }_{M}}=-\mathcal{L}_{\mathrm{adv}_{D}}\quad\quad(6)$

  这个目标可能有问题,因为在训练的早期,鉴别器快速收敛,导致梯度消失。

  当训练 GANs 时,而不是直接使用 minimax,通常是用带有倒置标签[10]的标准损失函数来训

  回顾 GAN :GAN将优化分为两个独立的目标,一个用于生成器,另一个用于鉴别器。训练生成器的时候,其中 $\mathcal{L}_{\mathrm{adv}_{D}}$ 保持不变,但 $\mathcal{L}_{\mathrm{adv}_{M}}$ 变成:

    $\mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)=-\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] \quad\quad(7)$

  Note:$\mathbf{x}_{t}$ 代表噪声数据,这里是使得噪声数据尽可能迷惑鉴别器。

adversarial_loss = torch.nn.BCELoss()  # 损失函数(二分类交叉熵损失)
generator = Generator() #生成器
discriminator = Discriminator() #鉴别器 optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 生成器优化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 鉴别器优化器 for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Adversarial ground truths
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) #torch.Size([64, 1])
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) #torch.Size([64, 1])
real_imgs = Variable(imgs.type(Tensor)) #torch.Size([64, 1, 28, 28]) 真实数据 # ----------------------> 训练生成器 [生成器使用噪声数据,使得其尽可能为真,迷惑鉴别器]
optimizer_G.zero_grad()
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) #torch.Size([64, 100])
gen_imgs = generator(z) #torch.Size([64, 1, 28, 28])
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step() # ----------------------> 训练鉴别器 [ 尽可能将真实数据和噪声数据区分开]
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()

GAN code

  本文采用的方法类似于  GAN 。

3 对抗性域适应

  与之前方法不同:

  

  本文方法:

  

  首先:Pretrain ,使用源域训练一个分类器;[ 公式 9 第一个子公式]

  其次:Adversarial Adaption

    1. :使用源域和目标域数据,训练一个域鉴别器 Discriminator ,是的鉴别器尽可能区分源域和目标域数据 ;[ 公式 9 第二个子公式]  
    2. :使用目标域数据,训练目标域特征提取器,尽可能使得域鉴别器区分不出目标域样本;[ 公式 9 第三个子公式]  

  最后:Testing,在目标域上做 Eval;

  ADDA对应于以下无约束优化:

    $\begin{array}{l}\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{s}\right) &=&-\mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{s}\right)} \sum_{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right) \\\underset{D}{\text{min}}  \quad\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)&=& -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] \text { - } \mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right] \\\underset{M_{t}}{\text{min}}  \quad \mathcal{L}_{\operatorname{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)&=& -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] \\\end{array} \quad\quad(9)$

    tgt_encoder.train()
discriminator.train() # setup criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_tgt = optim.Adam(tgt_encoder.parameters(),lr=params.c_learning_rate,betas=(params.beta1, params.beta2))
optimizer_discriminator = optim.Adam(discriminator.parameters(),lr=params.d_learning_rate,betas=(params.beta1, params.beta2))
len_data_loader = min(len(src_data_loader), len(tgt_data_loader)) #149 for epoch in range(params.num_epochs):
# zip source and target data pair
data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
for step, ((images_src, _), (images_tgt, _)) in data_zip:
# 2.1 训练域鉴别器,使得域鉴别器尽可能的准确
images_src = make_variable(images_src)
images_tgt = make_variable(images_tgt)
discriminator.zero_grad()
feat_src,feat_tgt = src_encoder(images_src) ,tgt_encoder(images_tgt) # 源域特征提取 # 目标域特征提取
feat_concat = torch.cat((feat_src, feat_tgt), 0)
pred_concat = discriminator(feat_concat.detach()) # 域分类结果 label_src = make_variable(torch.ones(feat_src.size(0)).long()) #假设源域的标签为 1
label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long()) #假设目标域域的标签为 0
label_concat = torch.cat((label_src, label_tgt), 0) loss_critic = criterion(pred_concat, label_concat)
loss_critic.backward()
optimizer_discriminator.step() # 域鉴别器优化 pred_cls = torch.squeeze(pred_concat.max(1)[1])
acc = (pred_cls == label_concat).float().mean() # 2.2 train target encoder # 使得目标域特征生成器,尽可能使得域鉴别器区分不出源域和目标域样本
optimizer_discriminator.zero_grad()
optimizer_tgt.zero_grad()
feat_tgt = tgt_encoder(images_tgt)
pred_tgt = discriminator(feat_tgt)
label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long()) #假设目标域域的标签为 1(错误标签),使得域鉴别器鉴别错误
loss_tgt = criterion(pred_tgt, label_tgt)
loss_tgt.backward()
optimizer_tgt.step() # 目标域 encoder 优化

ADDA Code

迁移学习(ADDA)《Adversarial Discriminative Domain Adaptation》的更多相关文章

  1. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  2. Domain adaptation:连接机器学习(Machine Learning)与迁移学习(Transfer Learning)

    domain adaptation(域适配)是一个连接机器学习(machine learning)与迁移学习(transfer learning)的新领域.这一问题的提出在于从原始问题(对应一个 so ...

  3. 迁移学习(Transformer),面试看这些就够了!(附代码)

    1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...

  4. 中文NER的那些事儿2. 多任务,对抗迁移学习详解&代码实现

    第一章我们简单了解了NER任务和基线模型Bert-Bilstm-CRF基线模型详解&代码实现,这一章按解决问题的方法来划分,我们聊聊多任务学习,和对抗迁移学习是如何优化实体识别中边界模糊,垂直 ...

  5. 论文阅读 | A Curriculum Domain Adaptation Approach to the Semantic Segmentation of Urban Scenes

    paper链接:https://arxiv.org/pdf/1812.09953.pdf code链接:https://github.com/YangZhang4065/AdaptationSeg 摘 ...

  6. 【论文笔记】Domain Adaptation via Transfer Component Analysis

    论文题目:<Domain Adaptation via Transfer Component Analysis> 论文作者:Sinno Jialin Pan, Ivor W. Tsang, ...

  7. 域适应(Domain adaptation)

    定义 在迁移学习中, 当源域和目标的数据分布不同 ,但两个任务相同时,这种 特殊 的迁移学习 叫做域适应 (Domain Adaptation). Domain adaptation有哪些实现手段呢? ...

  8. Deep Transfer Network: Unsupervised Domain Adaptation

    转自:http://blog.csdn.net/mao_xiao_feng/article/details/54426101 一.Domain adaptation 在开始介绍之前,首先我们需要知道D ...

  9. Domain Adaptation论文笔记

    领域自适应问题一般有两个域,一个是源域,一个是目标域,领域自适应可利用来自源域的带标签的数据(源域中有大量带标签的数据)来帮助学习目标域中的网络参数(目标域中很少甚至没有带标签的数据).领域自适应如今 ...

  10. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

随机推荐

  1. $_SERVER['HTTP_USER_AGENT']:在PHP中HTTP_USER_AGENT是用来获取用户的相关信息的,包括用户使用的浏览器,操作系统等信息

    在PHP中HTTP_USER_AGENT是用来获取用户的相关信息的,包括用户使用的浏览器,操作系统等信息. 我机器:操作系统:WIN7旗舰版 64操作系统 以下为各个浏览器下$_SERVER['HTT ...

  2. Linux正则表达式与grep

    bash是什么 bash是一个命令处理器,运行在文本窗口中,并能执行用户直接输入的命令 bash还能从文件中读取linxu命令,称之为脚本 bash支持通配符.管道.命令替换.条件判断等逻辑控制语句 ...

  3. Java多线程-ThreadPool线程池(三)

    开完一趟车完整的过程是启动.行驶和停车,但老司机都知道,真正费油的不是行驶,而是长时间的怠速.频繁地踩刹车等动作.因为在速度切换的过程中,发送机要多做一些工作,当然就要多费一些油. 而一个Java线程 ...

  4. pod(八):pod的调度——将 Pod 指派给节点

    目录 一.系统环境 二.前言 三.pod的调度 3.1 pod的调度概述 3.2 pod自动调度 3.2.1 创建3个主机端口为80的pod 3.3 使用nodeName 字段指定pod运行在哪个节点 ...

  5. FHE学习笔记 #2 多项式环

    https://en.wikipedia.org/wiki/Polynomial_ring https://zhuanlan.zhihu.com/p/419266064 这篇知乎文章讲的比较透彻,但是 ...

  6. AGC007C Pushing Balls —— 期望的神题

    Problem Link 题意: 序列上按顺序交错有 \(n\) 个球和 \(n+1\) 个洞,即 \(hole_1,ball_1,hole_2,ball_2,\dots,ball_n,hole_{n ...

  7. Mybatis:解决调用带有集合类型形参的mapper方法时,集合参数为空或null的问题

    此文章有问题,待修改! 使用Mybatis时,有时需要批量增删改查,这时就要向mapper方法中传入集合类型(List或Set)参数,下面是一个示例. // 该文件不完整,只展现关键部分 @Mappe ...

  8. C++初阶(命名空间+缺省参数+const总结+引用总结+内联函数+auto关键字)

    命名空间 概述 在C/C++中,变量.函数和后面要学到的类都是大量存在的,这些变量.函数和类的名称将都存在于全局作用域中,可能会导致很多冲突.使用命名空间的目的是对标识符的名称进行本地化,以避免命名冲 ...

  9. MVC开发单元测试小工具 —— 搜寻还没写单元测试的方法

    方法比较笨,有更好的建议可以提. 写这个工具呢,因为要写单元测试,保证代码质量,方便修改维护.一切为了自己方便.当然这个算是个人开发的项目 1.MVC中控制器建立个基类(这个光明正大的抄袭的),控制往 ...

  10. C温故补缺(二):volatile

    volatile 参考:CSDN volatile也是一个类型修饰符,被其修饰的变量意味着可以被某些编译器未知的因素修改,如操作系统,硬件,线程等. 当遇到volatile修饰的变量时,编译器对访问该 ...