Paper Information

Title: A Two-Stage Framework with Self-Supervised Distillation For Cross-Domain Text Classification
Authors: Yunlong Feng, Bohan Li, Libo Qin, Xiao Xu, Wanxiang Che
Source: arXiv 2023

1 Introduction

  Motivation: prior work focuses on extracting domain-invariant or task-agnostic features, while ignoring the domain-aware features that exist in the target domain and can be useful for the downstream task;

  Contributions:

    • Propose a two-stage learning framework that enables an existing classification model to adapt effectively to the target domain;
    • Introduce self-supervised distillation, which helps the model better capture domain-aware features from unlabeled target-domain data;
    • Experiments on the Amazon cross-domain classification benchmark show state-of-the-art results.

2 Related Work

  

  Figure 1(a): illustrates how domain-invariant features and domain-aware features relate to the task;

  Figure 1(b): illustrates how masking domain-invariant or domain-aware features relates to the prediction:

    • By masking domain-invariant features, the model builds a correlation between the prediction and the domain-aware features;
    • By masking domain-aware features, the model strengthens the relation between the prediction and the domain-invariant features.
PT (Prompt Tuning)

  A text prompt is constructed as follows:

    $\boldsymbol{x}_{\mathrm{p}}=\text { "[CLS] } \boldsymbol{x} \text {. It is [MASK]. [SEP]"}   \quad\quad(1)$

  The $\text{PLM}$ takes $\boldsymbol{x}_{\mathrm{p}}$ as input and uses the context to fill $\text{[MASK]}$ with a word from the vocabulary; the output word is then mapped to a label in $\mathcal{Y}$.

  The PT objective:

    $\mathcal{L}_{pmt}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)=-\sum_{\boldsymbol{x}, y \in \mathcal{D}^{\mathcal{T}}} y \log p_{\theta_{\mathcal{M}}}\left(\hat{y} \mid \boldsymbol{x}_{\mathrm{p}}\right)$
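
  As a rough illustration (a sketch, not the authors' released code), this prompt-classification step could look like the following with HuggingFace Transformers; the backbone checkpoint and the binary sentiment verbalizer ("terrible"/"great") are assumptions:

```python
# Illustrative sketch of prompt-based classification (assumptions noted above).
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # assumed backbone
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

verbalizer = {0: "terrible", 1: "great"}  # assumed label-word mapping
label_token_ids = torch.tensor(
    [tokenizer.convert_tokens_to_ids(w) for w in verbalizer.values()]
)

def label_logits(text: str) -> torch.Tensor:
    # Build x_p = "[CLS] x. It is [MASK]. [SEP]"; [CLS]/[SEP] are added by the tokenizer.
    prompt = f"{text}. It is {tokenizer.mask_token}."
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    mask_pos = (enc.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
    vocab_logits = model(**enc).logits[0, mask_pos]  # vocabulary logits at [MASK]
    return vocab_logits[label_token_ids]             # keep only the verbalizer words

# L_pmt: cross-entropy between the verbalizer distribution and the gold label.
loss_pmt = torch.nn.functional.cross_entropy(
    label_logits("The battery lasts for days").unsqueeze(0), torch.tensor([1])
)
```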

MLM

  $\text{MLM}$ is used to avoid shortcut learning and to adapt to the target-domain distribution. Specifically, a masked text prompt $\boldsymbol{x}_{\mathrm{pm}}$ is constructed:

    $\boldsymbol{x}_{\mathrm{pm}}=\text { "[CLS] } \boldsymbol{x}_{\mathrm{m}} \text {. It is [MASK]. [SEP]"}$

  The MLM loss is:

    $\mathcal{L}_{mlm}\left(\mathcal{D} ; \theta_{\mathcal{M}}\right)=-\sum_{\boldsymbol{x} \in \mathcal{D}} \sum_{\hat{x} \in m\left(\boldsymbol{x}_{\mathrm{m}}\right)} \frac{\log p_{\theta_{\mathcal{M}}}\left(\hat{x} \mid \boldsymbol{x}_{\mathrm{pm}}\right)}{\operatorname{len}_{m\left(\boldsymbol{x}_{\mathrm{m}}\right)}}$

  where $m\left(\boldsymbol{x}_{\mathrm{m}}\right)$ and $\operatorname{len}_{m\left(\boldsymbol{x}_{\mathrm{m}}\right)}$ denote the masked tokens in $\boldsymbol{x}_{\mathrm{m}}$ and their count, respectively;
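
  A minimal sketch of the masking step, assuming a HuggingFace-style tokenizer and the paper's 30% masking ratio (the helper name is ours); with `-100` labels, the model's built-in MLM loss averages over the masked tokens only, which matches the $\operatorname{len}_{m\left(\boldsymbol{x}_{\mathrm{m}}\right)}$ normalization:

```python
# Illustrative masking helper for building x_pm (not the authors' code).
import torch

def mask_tokens(input_ids: torch.Tensor, tokenizer, mask_prob: float = 0.3):
    # Randomly replace mask_prob of the non-special tokens with [MASK];
    # labels keep the original ids at masked positions and -100 elsewhere.
    labels = input_ids.clone()
    special = torch.tensor(
        tokenizer.get_special_tokens_mask(
            input_ids.tolist(), already_has_special_tokens=True
        ),
        dtype=torch.bool,
    )
    masked = (torch.rand(input_ids.shape) < mask_prob) & ~special
    labels[~masked] = -100                       # ignored by the MLM loss
    corrupted = input_ids.clone()
    corrupted[masked] = tokenizer.mask_token_id
    return corrupted, labels

# Usage sketch: mask the sentence x first, then wrap it in the prompt template;
# loss_mlm = model(input_ids=corrupted.unsqueeze(0), labels=labels.unsqueeze(0)).loss
```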

SSKD

  Core idea: enable the model to build a connection between its predictions and the domain-aware features of the target domain;

  Concretely, the model is forced to connect the prediction for $\boldsymbol{x}_{\mathrm{p}}$ with the unmasked words of $\boldsymbol{x}_{\mathrm{pm}}$: $\text{KD}$ is performed between the predictions $p_{\theta}\left(y \mid \boldsymbol{x}_{\mathrm{pm}}\right)$ and $p_{\theta}\left(y \mid \boldsymbol{x}_{\mathrm{p}}\right)$:

    $\mathcal{L}_{ssd}\left(\mathcal{D} ; \theta_{\mathcal{M}}\right)=\sum_{\boldsymbol{x} \in \mathcal{D}} \mathrm{KL}\left(p_{\theta_{\mathcal{M}}}\left(y \mid \boldsymbol{x}_{\mathrm{pm}}\right) \,\|\, p_{\theta_{\mathcal{M}}}\left(y \mid \boldsymbol{x}_{\mathrm{p}}\right)\right)$

  Note: $\boldsymbol{x}_{\mathrm{pm}}$ may retain domain-invariant features, domain-aware features, or both;
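
  Assuming both predictions are distributions over the verbalizer words, the distillation term can be sketched as follows (the helper name is ours):

```python
# Illustrative SSKD term: KL( p(y | x_pm) || p(y | x_p) ).
import torch
import torch.nn.functional as F

def sskd_loss(logits_pm: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
    # F.kl_div(input, target) computes KL(target || dist-of-input), so the
    # unmasked-prompt prediction enters in log-space as `input`.
    log_p_unmasked = F.log_softmax(logits_p, dim=-1)   # p(y | x_p)
    p_masked = F.softmax(logits_pm, dim=-1)            # p(y | x_pm)
    return F.kl_div(log_p_unmasked, p_masked, reduction="batchmean")
```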

3 Method

Stage 1: Learn from the source domain

  

  Procedure:

    • First, we calculate the classification loss on the source-domain sentences and update the parameters with it, as shown in line 5 of Algorithm 1.
    • Then we mask the same sentences and compute the masked language modeling loss to update the parameters, as depicted in line 8 of Algorithm 1. The model parameters are updated by these two losses in turn.

  Objective:

    $\begin{array}{l}\mathcal{L}_{1}^{\prime}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)=\alpha \mathcal{L}_{pmt}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) \\ \mathcal{L}_{1}^{\prime \prime}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)=\beta \mathcal{L}_{mlm}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)\end{array}$
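
  Read as code, one epoch of Stage 1 might look like the sketch below; `source_loader`, `optimizer`, and the two loss helpers are hypothetical wrappers around the earlier sketches:

```python
# Stage 1 sketch (illustrative): two separate updates per batch.
alpha, beta = 1.0, 0.6          # weights from the implementation details
for batch in source_loader:     # labeled source-domain sentences
    # Classification loss on the prompted batch (line 5 of Algorithm 1).
    loss_cls = alpha * prompt_classification_loss(model, batch)   # hypothetical helper
    loss_cls.backward()
    optimizer.step(); optimizer.zero_grad()

    # MLM loss on the masked version of the same sentences (line 8 of Algorithm 1).
    loss_mlm = beta * masked_prompt_mlm_loss(model, batch)        # hypothetical helper
    loss_mlm.backward()
    optimizer.step(); optimizer.zero_grad()
```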

Stage 2: Adapt to the target domain

  

  Procedure:

    • First, we sample labeled data from the source domain $\mathcal{D}_{S}^{\mathcal{T}}$ and calculate the sentiment classification loss; the model parameters are updated with this loss in line 5 of Algorithm 2.
    • Next, we sample unlabeled data from the target domain $\mathcal{D}_{T}$, mask it, and perform masked language modeling together with self-supervised distillation against the prediction on the unmasked prompt.

  Objective:

    $\begin{aligned}\mathcal{L}_{2}^{\prime}\left(\mathcal{D}_{S}^{\mathcal{T}}, \mathcal{D}_{T} ; \theta_{\mathcal{M}}\right) &=\alpha \mathcal{L}_{pmt}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) \\ \mathcal{L}_{2}^{\prime \prime}\left(\mathcal{D}_{S}^{\mathcal{T}}, \mathcal{D}_{T} ; \theta_{\mathcal{M}}\right) &=\beta\left(\mathcal{L}_{mlm}\left(\mathcal{D}_{T} ; \theta_{\mathcal{M}}\right)+\mathcal{L}_{ssd}\left(\mathcal{D}_{T} ; \theta_{\mathcal{M}}\right)\right)\end{aligned}$
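
  Analogously, a Stage 2 step could be sketched as below; the loaders and helpers are again hypothetical stand-ins:

```python
# Stage 2 sketch (illustrative): source classification plus target MLM + SSKD.
alpha, beta = 0.5, 0.5
for src_batch, tgt_batch in zip(source_loader, target_loader):
    # L2': classification loss on labeled source data (line 5 of Algorithm 2).
    loss_src = alpha * prompt_classification_loss(model, src_batch)
    loss_src.backward()
    optimizer.step(); optimizer.zero_grad()

    # L2'': MLM and self-supervised distillation on unlabeled target data.
    loss_mlm = masked_prompt_mlm_loss(model, tgt_batch)
    loss_ssd = sskd_loss(
        masked_prompt_label_logits(model, tgt_batch),  # logits for x_pm (hypothetical)
        prompt_label_logits(model, tgt_batch),         # logits for x_p  (hypothetical)
    )
    loss_tgt = beta * (loss_mlm + loss_ssd)
    loss_tgt.backward()
    optimizer.step(); optimizer.zero_grad()
```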

Algorithm

  

4 Experiments

Dataset

  Amazon reviews dataset

  

Baselines
  • $\text{R-PERL }$(2020): Use BERT for cross-domain text classification with pivot-based fine-tuning.
  • $\text{DAAT}$ (2020): Use BERT post training for cross-domain text classification with adversarial training.
  • $\text{p+CFd}$ (2020): Use XLM-R for cross-domain text classification with class-aware feature self-distillation (CFd).
  • $\text{SENTIX}_{\text{Fix}}$ (2020): Pre-train a sentiment-aware language model by several pretraining tasks.
  • $\text{UDALM}$ (2021): Fine-tuning with a mixed classification and MLM loss on domain-adapted PLMs.
  • $\text{AdSPT}$ (2022): Soft prompt tuning with an adversarial training objective on vanilla PLMs.
Implementation Details
  • During Stage 1, we train 10 epochs with batch size 4 and early stopping (patience = 3) on the accuracy metric. The optimizer is AdamW with learning rate $1 \times 10^{-5}$, and we halve the learning rate every 3 epochs. We set $\alpha=1.0$, $\beta=0.6$ for Eq. 6.
  • During Stage 2, we train 10 epochs with batch size 4 and early stopping (patience = 3) on the mixed classification and masked language modeling loss. The optimizer is AdamW with learning rate $1 \times 10^{-6}$ and no learning rate decay. We set $\alpha=0.5$, $\beta=0.5$ for Eq. 7.
  • In addition, for the masked language modeling objective and the self-supervised distillation objective, we randomly replace 30% of the tokens with [MASK], and the maximum sequence length is set to 512 by truncating inputs. In particular, we randomly sample an equal number of unlabeled target-domain examples every epoch during Stage 2.
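
  For reference, the Stage 1 optimizer and schedule above map onto standard PyTorch pieces; this is a sketch of our reading, with `run_stage1_epoch` as a hypothetical placeholder:

```python
# AdamW at 1e-5 with the learning rate halved every 3 epochs (Stage 1).
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

for epoch in range(10):
    run_stage1_epoch(model, optimizer)  # hypothetical epoch loop
    scheduler.step()                    # gamma=0.5 halves the lr every 3 epochs
```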

Single-source domain adaptation on Amazon reviews

  

Multi-source domain adaptation on Amazon reviews

  

Ablation experiments

  

Case Study 

  

  

Generality Study

  
