Deep Transfer Network: Unsupervised Domain Adaptation
转自:http://blog.csdn.net/mao_xiao_feng/article/details/54426101
一、Domain adaptation
在开始介绍之前,首先我们需要知道Domain adaptation的概念。Domain adaptation,我在标题上把它称之为域适应,但是在文中我没有再翻译它,而是保持它的英文原意,这也有助于我们更好的理解它的概念。
Domain adaptation的目标是在某一个训练集上训练的模型,可以应用到另一个相关但不相同的测试集上。
对这个问题列出一个规范的形式:
给出source dataset(源数据,也就是初始的训练集)其中
表示源数据集的数量,它是有标签的。
给出target dataset(目标数据,就是相关域的数据集)其中
表示目标数据集的数量,它是无标签的。
现在Domain adaptation的目标就是使用提供的所有数据训练一个统计模型,从而最小化预测误差,其中
是第i个样本的预测标签,而
就是对应的真实标签,它也是未知的。我们考虑一种情况,source和target dataset之间的边缘分布和条件分布都不相同,即
≠
且
≠
ok,问题就是上述的那样,关于Domain adaptation我们可以再扯一点闲话,所谓边缘分布就是数据在特征空间当中的分布,如果你不理解特征空间这个词,把它理解为数据分布就好。可能还会有人问现实当中数据分布很抽象,你怎么知道几万张图片,它们的分布是怎样的?这个问题是初入坑必须要搞明白的,衡量图像我们也是通过特征(例如,haar特征,梯度,颜色直方图等等),将图像特征量化成数字,分布就能看出来了,所以记住我们讨论分布的前提是我们已经确定用哪种特征来衡量数据。同样条件分布就是某个确定样本的分类概率分布了,如果是二分类问题,那么此条件分布就看作一个伯努利分布,其他情况以此类推。
Domain adaptation有哪些实现手段呢?几乎所有的手段都尝试去学习一个特征转换,使得在转换过后的特征空间上,source dataset和target dataset分布的区分度达到最小。现实世界当中这个问题又分为不同的类型:1)边缘分布相同,条件分布不同且相关2)边缘分布不同且相关,条件分布相同3)边缘分布和条件分布都不同且相关。这几种情况其实可以归纳到迁移学习domain和task的范畴中,以后我会写一篇文章专门对迁移学习和Domain adaptation作整理。
Instance reweighting和subspace learning是Domain adaptation中两种经典的学习策略,前者对source data每一个样本加权,学习一组权使得分布差异最小化,后者则是转换到一个新的共享样本空间上,使得两者的分布相匹配。另外比较重要的的一点是,实际训练当中,“最小化分布差异”这个约束条件是放在目标函数中和最小化误差一起优化的,而不是单独优化。
二、DTN之共享特征抽取层
Deep Transfer Network(这里简称DTN)就是一个用深度网络去做Domain adaptation的理念,这个网络被分为了两种类型的层,共享特征抽取层和判别层。第一层共享特征抽取层用于匹配边缘分布,共享特征抽取层可以是一个多层感知机,如果网络层数为l的话,我们一般会把前l-1层看作共享特征抽取层,而l-1层的输出则是一种分布相近的共享特征,它可用于后面做类别判断。
在这之前我们应该指定一个分布差异的度量标准,这里使用了empirical Maximum Mean Discrepancy(MMD),假设有source dataset和target dataset分别为和
,指定
那么有下式成立:
其中M是MMD矩阵
接下来我们讨论如何进行match,假设W是k*d的投影矩阵,它把d维特征向量x投影到k维上面,然后通过激活函数f做一个非线性变化,得到h:
在经历了l-1层的类似变换以后,假设输出为根据这个输出我们可以列出source dataset和target dataset的边缘分布分别为
和
,在l-1层的输出上,我们规定约束两者的边缘分布差异最小,于是有
和
,同样令
,那么最终的MMD度量如下:
三、DTN之判别层
在判别层,其实和传统的神经网络没多大区别,多用softmax回归来做概率预测,这里也是一样,列出softmax的判别公式:
其中
是最后一层的连接权值,j是总的类数,j为2时退化成逻辑斯蒂回归,这个相信大家都清楚。最后列出以条件概率形式表述的MMD如下:
四、优化步骤
最后加上上述的两个约束条件,目标函数变为
后面的λ和μ都是人为指定的,分别表示了约束的重要性程度,为0时,就退化成了传统神经网络。
优化还是使用的随机梯度下降,上面列出来的两个MMD公式当中,首先我们要确定哪个是变量,这是很重要的,虽然这个问题很弱智。在确定了变量之后,我们可以计算对于,MMD的偏导
,分为两种情况来讨论:
同样的对于,也分为source dataset和target dataset两种情况:
接下来只要求和
整个问题就ok了,我们列出目标函数的梯度求解公式:
看似一切都解决了!但其实别忘了还有一个问题,就是MMD其实是对全部数据集来求的,但是在神经网络中不可能做到这一点,所以训练的时候采取了mini batch的方法,用一个batch来代替数据集的分布,mini batch其实并不会影响实验结果,因为假设将数据集切分为N份,在N上的MMD会大于总的数据集上的MMD,用公式表达如下:
也就是说,只要我们约束了就一定能使得原MMD最小。
最后,总结一下算法的流程:
Deep Transfer Network: Unsupervised Domain Adaptation的更多相关文章
- Unsupervised Domain Adaptation by Backpropagation
目录 概 主要内容 代码 Ganin Y. and Lempitsky V. Unsupervised Domain Adaptation by Backpropagation. ICML 2015. ...
- 论文笔记:Unsupervised Domain Adaptation by Backpropagation
14年9月份挂出来的文章,基本思想就是用对抗训练的方法来学习domain invariant的特征表示.方法也很只管,在网络的某一层特征之后接一个判别网络,负责预测特征所属的domain,而后特征提取 ...
- Unsupervised Domain Adaptation Via Domain Adversarial Training For Speaker Recognition
年域适应挑战(DAC)数据集的实验表明,所提出的方法不仅有效解决了数据集不匹配问题,而且还优于上述无监督域自适应方法.
- Domain Adaptation (3)论文翻译
Abstract The recent success of deep neural networks relies on massive amounts of labeled data. For a ...
- Domain Adaptation (1)选题讲解
1 所选论文 论文题目: <Unsupervised Domain Adaptation with Residual Transfer Networks> 论文信息: NIPS2016, ...
- 【论文笔记】Domain Adaptation via Transfer Component Analysis
论文题目:<Domain Adaptation via Transfer Component Analysis> 论文作者:Sinno Jialin Pan, Ivor W. Tsang, ...
- Domain adaptation:连接机器学习(Machine Learning)与迁移学习(Transfer Learning)
domain adaptation(域适配)是一个连接机器学习(machine learning)与迁移学习(transfer learning)的新领域.这一问题的提出在于从原始问题(对应一个 so ...
- Domain Adaptation论文笔记
领域自适应问题一般有两个域,一个是源域,一个是目标域,领域自适应可利用来自源域的带标签的数据(源域中有大量带标签的数据)来帮助学习目标域中的网络参数(目标域中很少甚至没有带标签的数据).领域自适应如今 ...
- What are the advantages of ReLU over sigmoid function in deep neural network?
The state of the art of non-linearity is to use ReLU instead of sigmoid function in deep neural netw ...
随机推荐
- 13. js延迟加载的方式有哪些
JS延迟加载,也就是等页面加载完成之后再加载 JavaScript 文件. JS延迟加载有助于提高页面加载速度. 一般有以下几种方式: 1)defer 属性 <script src=&q ...
- Hibernate学习笔记(六)—— 查询优化
一.Hibernate的抓取策略 1.1 什么是抓取策略 抓取策略是当应用程序需要在(Hibernate实体对象图的)关联关系间进行导航的时候,Hibernate如何获取关联对象的策略. HIbern ...
- bytes和str之间的转换
1.方法:decode解码(二进制转换成字符串) encode与上相反
- JAVA数据结构--希尔排序
希尔排序通过将比较的全部元素分为几个区域来提升插入排序的性能.这样可以让一个元素可以一次性地朝最终位置前进一大步.然后算法再取越来越小的步长进行排序,算法的最后一步就是普通的插入排序,但是到了这步,需 ...
- 石头剪刀布(2019Wannafly winter camp day3 i) 带权并查集+按秩合并 好题
题目传送门 思路: 按照题意描述,所有y挑战x的关系最后会形成一棵树的结构,n个人的总方案数是 3n 种,假设一个人被挑战(主场作战)a次,挑战别人(客场)b次,那么这个人存活到最后的方案数就是3n* ...
- 115th LeetCode Weekly Contest Prison Cells After N Days
There are 8 prison cells in a row, and each cell is either occupied or vacant. Each day, whether the ...
- Python入门(1)
1.编程语言 机器语言:直接用计算机能听懂的二进制指令去编写程序,需要了解硬件的细节 汇编语言:用英文标签取代二进制指令去编写程序,同样需要了解硬件的细节 高级语言:直接用人类能理解的表达方式去编写程 ...
- display inline-block 间隔
1.如果li横排用display:inline-block; 则li之间不能有间隔 必须连着一起,所以才一般用float:left; .today-wrap{ position: relative; ...
- 搭建Flask+Vue及配置Vue 基础路由
最近一直在看关于Python的东西,准备多学习点东西.以前的项目是用Vue+Java写的,所以试着在升级下系统的前提下.能不能使用Python+Vue做一遍. 选择Flask的原因是不想随大流,并且比 ...
- Android Zygote进程是如何fork一个APP进程的
进程创建流程 不管从桌面启动应用还是应用内启动其它应用,如果这个应用所在进程不存在的话,都需要发起进程通过Binder机制告诉system server进程的AMS system server进程的A ...