论文解读(AAD)《Knowledge distillation for BERT unsupervised domain adaptation》
Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]
论文信息
论文标题:Knowledge distillation for BERT unsupervised domain adaptation
论文作者:Minho Ryu、Geonseok Lee、Kichun Lee
论文来源:2022 aRxiv
论文地址:download
论文代码:download
视屏讲解:click
1 介绍
出发点:域偏移导致的性能下降;
问题定义:UDA
比较有意思,这篇工作被抄袭了,但是抄袭的家伙还成功发论文了.............
2 相关工作
知识蒸馏 [7,8](KD)最初是一种模型压缩技术,旨在训练一个紧凑的模型(学生),以便将一个训练良好的更大的模型(教师)的知识转移到学生模型[28,29]。KD 可以通过最小化以下目标函数来表示:
$\mathcal{L}_{K D}=t^{2} \sum_{k}-\operatorname{softmax}\left(p_{k}^{T} / t\right) \times \log \left(\operatorname{softmax}\left(p_{k}^{S} / t\right)\right)$
其中,$p^{S}$ 和 $p^{T}$ 分别为学生模型和教师模型的预测,温度值 $t$ 控制着知识转移的程度。
推导过程:
$K L(p \| q)=\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(\frac{p\left(x_{i}\right)}{q(x i)}\right)$
$\begin{array}{l} K L(p \| q)&=\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(p\left(x_{i}\right)\right)-\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(q\left(x_{i}\right)\right)\\&=H(p(x)) -\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(q\left(x_{i}\right)\right)\end{array}$
注意:$P$ 代表着真实分布, $Q$ 代表着模型分布;
注意:学生模型训练时,教师模型的参数是固定的,因此 $H(p(x))$ 为常数,可以去掉;
注意:标准的监督训练,由于使用的是硬标签做监督训练,所以在重复训练的时候容易造成过拟合。由于较大的 $t$ 值产生较软的概率分布,知识蒸馏在结合领域自适应方法可以缓解这一问题。
3 方法
3.1 模型框架

3.2 Adversarial adaptation with distillation
Step 1: fine-tune the source encoder and the classifier
使用源域数据进行标准的监督训练,训练 $E_s$ 和 $C$:
$\underset{E_{S}, C}{\text{min}} \; \mathcal{L}_{S}\left(\mathbf{X}_{S}, \mathbf{y}_{S}\right)=\mathbb{E}_{\left(\boldsymbol{x}_{s}, y_{s}\right) \sim\left(\mathbb{X}_{S}, \mathbb{Y}_{S}\right)}-\sum_{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(E_{S}\left(\boldsymbol{x}_{S}\right)\right)$
Step 2: adapt the target encoder via adversarial adaptation with distillation
固定 $E_s$ 的参数,并使用 $E_s$ 初始化 $E_t$ 的参数,接着进行对抗性训练:
$\begin{array}{l}\underset{D}{\text{min}} \; \mathcal{L}_{\text {dis }}\left(\mathbf{X}_{S}, \mathbf{X}_{T}\right)=\mathbb{E}_{\boldsymbol{x}_{s} \sim \mathbb{X}_{S}}-\log D\left(E_{S}\left(\boldsymbol{x}_{s}\right)\right)+\mathbb{E}_{\boldsymbol{x}_{t} \sim \mathbb{X}_{T}}-\log \left(1-D\left(E_{t}\left(\boldsymbol{x}_{t}\right)\right)\right)\\\underset{E_{t}}{\text{min}} \; \mathcal{L}_{g e n} \left(\mathbf{X}_{T}\right)=\mathbb{E}_{\boldsymbol{x}_{t} \sim \mathbb{X}_{T}}-\log D\left(E_{t}\left(\boldsymbol{x}_{t}\right)\right)\end{array}$
然而,由于无法使用类标签,该公式很容易导致灾难性的遗忘,从而导致分类性能下降。对于一个使用大的 $t$ 的知识蒸馏模型,它不仅可以使得对抗性训练稳定,还可以良好的保存类信息。因此,引入了知识蒸馏损失:
$\mathcal{L}_{K D}\left(\mathbf{X}_{S}\right)=t^{2} \times \mathbb{E}_{\boldsymbol{x}_{s} \sim \mathbb{X}_{S}} \sum_{k=1}^{K}-\operatorname{softmax}\left(p_{k}^{S} / t\right) \times \log \left(\operatorname{softmax}\left(p_{k}^{T} / t\right)\right)$
其中,$p^{S}=C\left(E_{S}\left(\boldsymbol{x}_{s}\right)\right)$、$\boldsymbol{p}^{T}=C\left(E_{t}\left(\boldsymbol{x}_{s}\right)\right)$;
因此,训练目标编码器 $E_{t}$ 的最终目标函数变为:
$\underset{E_{t}}{\text{min}} \;\mathcal{L}_{T}\left(\mathbf{X}_{S}, \mathbf{X}_{T}\right)=\mathcal{L}_{\text {gen }}\left(\mathbf{X}_{T}\right)+\mathcal{L}_{K D}\left(\mathbf{X}_{S}\right)$
Step 3: test the target encoder on the target data
使用训练好的目标编码器 $E_t$ 和分类器 $C$ 对用于测试的目标数据情绪极性标签预测如下:
$\hat{y}_{t}=\arg \max C\left(E_{t}\left(\boldsymbol{x}_{t}\right)\right)$
4 实验
跨域情感分析

论文解读(AAD)《Knowledge distillation for BERT unsupervised domain adaptation》的更多相关文章
- 论文解读(CDCL)《Cross-domain Contrastive Learning for Unsupervised Domain Adaptation》
论文信息 论文标题:Cross-domain Contrastive Learning for Unsupervised Domain Adaptation论文作者:Rui Wang, Zuxuan ...
- 论文解读(CDTrans)《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》
论文信息 论文标题:CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation论文作者:Tongkun Xu, Weihu ...
- 论文解读(ToAlign)《ToAlign: Task-oriented Alignment for Unsupervised Domain Adaptation》
论文信息 论文标题:ToAlign: Task-oriented Alignment for Unsupervised Domain Adaptation论文作者:Guoqiang Wei, Cuil ...
- 论文解读(CAN)《Contrastive Adaptation Network for Unsupervised Domain Adaptation》
论文信息 论文标题:Contrastive Adaptation Network for Unsupervised Domain Adaptation论文作者:Guoliang Kang, Lu Ji ...
- 虚假新闻检测(CADM)《Unsupervised Domain Adaptation for COVID-19 Information Service with Contrastive Adversarial Domain Mixup》
论文信息 论文标题:Unsupervised Domain Adaptation for COVID-19 Information Service with Contrastive Adversari ...
- 迁移学习(IIMT)——《Improve Unsupervised Domain Adaptation with Mixup Training》
论文信息 论文标题:Improve Unsupervised Domain Adaptation with Mixup Training论文作者:Shen Yan, Huan Song, Nanxia ...
- 迁移学习(DCCL)《Domain Confused Contrastive Learning for Unsupervised Domain Adaptation》
论文信息 论文标题:Domain Confused Contrastive Learning for Unsupervised Domain Adaptation论文作者:Quanyu Long, T ...
- 迁移学习(TSRP)《Improving Pseudo Labels With Intra-Class Similarity for Unsupervised Domain Adaptation》
论文信息 论文标题:Improving Pseudo Labels With Intra-Class Similarity for Unsupervised Domain Adaptation论文作者 ...
- 迁移学习《Asymmetric Tri-training for Unsupervised Domain Adaptation》
论文信息 论文标题:Asymmetric Tri-training for Unsupervised Domain Adaptation论文作者:Kuniaki Saito, Y. Ushiku, T ...
- 迁移学习《Efficient and Robust Pseudo-Labeling for Unsupervised Domain Adaptation》
论文信息 论文标题:Efficient and Robust Pseudo-Labeling for Unsupervised Domain Adaptation论文作者:Hochang Rhee.N ...
随机推荐
- vue下载附件按钮功能
一.tools文件夹下tools文件中封装下载方法: const iframeId = 'download_file_iframe' function iframeEle (src) { let el ...
- 2023-01-08:小红定义一个仅有r、e、d三种字符的字符串中, 如果仅有一个长度不小于2的回文子串,那么这个字符串定义为“好串“。 给定一个正整数n,输出长度为n的好串有多少个。 结果对10^9
2023-01-08:小红定义一个仅有r.e.d三种字符的字符串中, 如果仅有一个长度不小于2的回文子串,那么这个字符串定义为"好串". 给定一个正整数n,输出长度为n的好串有多少 ...
- 2021-05-27:定义何为step sum?比如680,680+68+6=754,680的step sum叫754。
2021-05-27:定义何为step sum?比如680,680+68+6=754,680的step sum叫754.给定一个整数num,判断它是不是某个数的step sum? 福大大 答案2021 ...
- 【Qt 6】读写剪贴板
剪贴板是个啥就不用多介绍了,最直观的功能是实现应用程序之间数据共享.就是咱们常说的"复制"."粘贴"功能. 在 Qt 中,QClipboard 类提供了相关 A ...
- Linux 创建 Python 虚拟环境
Linux 创建 Python 虚拟环境 0. 前言 网上教程太杂太乱,要么排版不好看,要么讲半天讲不到重点,故做此篇,精简干练. 1. 安装virtualenv 先用pip安装virtualenv第 ...
- [MAUI]模仿Chrome下拉标签页的交互实现
@ 目录 创建粘滞效果的圆控件 贝塞尔曲线绘制圆 创建控件 创建形变 可控形变 形变边界 形变动画 创建手势控件 创建页面布局 更新拖拽物位置 其它细节 项目地址 今天来说说怎样在.NET MAUI中 ...
- Jackson前后端开发模式必备json利器
前言 json是我们现代互联网程序最常用的交互格式,是否你在工作中会遇到前端说字段不一致需要改的需求,是否遇到过数据库字段名与javaBean的规范不同,是否遇到过json与javaBean相互转换时 ...
- Android Studio 引入kotlin 协程
首先保证创建了kotlin项目,然后,两个步骤: 1. 加入dependency,在 build.gradle(:app)中,加入 implementation 'org.jetbrains.kotl ...
- 【技术积累】Python中的NumPy库【一】
NumPy库是什么 NumPy是Python科学计算的核心库之一,用来进行科学计算,数值分析等矩阵运算.主要提供了以下几种功能: 1.多维数组(ndarray)对象,可以进行快速的数值计算和数组操作: ...
- 性能_3 jmeter连接数据库jdbc(sql server举例)
一.下载第三方工具包驱动数据库 1. 因为JMeter本身没有提供链接数据库的功能,所以我们需要借助第三方的工具包来实现. (有这个jar包之后,jmeter可以发起jdbc请求,没有这个jar包, ...