Paper Information

Title: A Two-Stage Framework with Self-Supervised Distillation For Cross-Domain Text Classification
Authors: Yunlong Feng, Bohan Li, Libo Qin, Xiao Xu, Wanxiang Che
Venue: arXiv 2023

1 Introduction

  Motivation: prior work focuses mainly on extracting domain-invariant or task-agnostic features, while ignoring the domain-aware features present in the target domain that may be useful for the downstream task.

  Contributions:

    • Propose a two-stage learning framework that enables an existing classification model to adapt effectively to the target domain;
    • Introduce self-supervised distillation, which helps the model better capture domain-aware features from the unlabeled data of the target domain;
    • Experiments on the Amazon cross-domain classification benchmark show state-of-the-art (SOTA) results.

2 Related Work

  

  Figure 1(a): illustrates how domain-invariant and domain-aware features relate to the task;

  Figure 1(b): illustrates how masking domain-invariant or domain-aware features affects the prediction:

    • By masking domain-invariant features, the model builds a correlation between the prediction and the domain-aware features;
    • By masking domain-aware features, the model strengthens the relation between the prediction and the domain-invariant features.
PT (Prompt Tuning)

  A text prompt is constructed as follows:

    $\boldsymbol{x}_{\mathrm{p}} = \text{"[CLS] } \boldsymbol{x} \text{. It is [MASK]. [SEP]"} \quad\quad(1)$

  The PLM takes $\boldsymbol{x}_{\mathrm{p}}$ as input and uses the context to fill $\text{[MASK]}$ with a word from the vocabulary; the output word is then mapped to a label in the label space $\mathcal{Y}$.

  The PT objective:

    $\mathcal{L}_{pmt}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)=-\sum_{(\boldsymbol{x}, y) \in \mathcal{D}^{\mathcal{T}}} y \log p_{\theta_{\mathcal{M}}}\left(\hat{y} \mid \boldsymbol{x}_{\mathrm{p}}\right) \quad\quad(2)$

MLM (Masked Language Modeling)

  MLM is used to avoid shortcut learning and to adapt to the target-domain distribution. Specifically, a masked text prompt $\boldsymbol{x}_{\mathrm{pm}}$ is constructed:

    $\boldsymbol{x}_{\mathrm{pm}} = \text{"[CLS] } \boldsymbol{x}_{\mathrm{m}} \text{. It is [MASK]. [SEP]"} \quad\quad(3)$

  The MLM loss is:

    $\mathcal{L}_{mlm}\left(\mathcal{D} ; \theta_{\mathcal{M}}\right)=-\sum_{\boldsymbol{x} \in \mathcal{D}} \sum_{\hat{x} \in m\left(\boldsymbol{x}_{\mathrm{m}}\right)} \frac{\log p_{\theta_{\mathcal{M}}}\left(\hat{x} \mid \boldsymbol{x}_{\mathrm{pm}}\right)}{\operatorname{len}_{m\left(\boldsymbol{x}_{\mathrm{m}}\right)}} \quad\quad(4)$

  where $m\left(\boldsymbol{x}_{\mathrm{m}}\right)$ and $\operatorname{len}_{m\left(\boldsymbol{x}_{\mathrm{m}}\right)}$ denote the masked words in $\boldsymbol{x}_{\mathrm{m}}$ and their count, respectively.

SSKD (Self-Supervised Knowledge Distillation)

  Core idea: enable the model to build a connection between its predictions and the domain-aware features of the target domain.

  Specifically: the model is forced to relate the prediction on $\boldsymbol{x}_{\mathrm{p}}$ to the unmasked words of $\boldsymbol{x}_{\mathrm{pm}}$; KD is performed between the predictions $p_{\theta}\left(y \mid \boldsymbol{x}_{\mathrm{pm}}\right)$ and $p_{\theta}\left(y \mid \boldsymbol{x}_{\mathrm{p}}\right)$:

    $\mathcal{L}_{ssd}\left(\mathcal{D} ; \theta_{\mathcal{M}}\right)=\sum_{\boldsymbol{x} \in \mathcal{D}} KL\left(p_{\theta_{\mathcal{M}}}\left(y \mid \boldsymbol{x}_{\mathrm{pm}}\right) \,\|\, p_{\theta_{\mathcal{M}}}\left(y \mid \boldsymbol{x}_{\mathrm{p}}\right)\right) \quad\quad(5)$

  Note: $\boldsymbol{x}_{\mathrm{pm}}$ may contain domain-invariant features, domain-aware features, or both.

3 Method

Stage 1: Learn from the source domain

  

  Procedure:

    • First, we calculate the classification loss of those sentences and update the parameters with it, as shown in line 5 of Algorithm 1.
    • Then we mask the same sentences and calculate the masked language modeling loss to update the parameters, as depicted in line 8 of Algorithm 1. The model parameters are updated by these two losses together.

  Objective (Eq. 6):

    $\begin{aligned}\mathcal{L}_{1}^{\prime}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) &= \alpha \mathcal{L}_{pmt}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) \\ \mathcal{L}_{1}^{\prime\prime}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) &= \beta \mathcal{L}_{mlm}\left(\mathcal{D}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right)\end{aligned} \quad\quad(6)$

Stage 2: Adapt to the target domain

  

  Procedure:

    • First, we sample labeled data from the source domain $\mathcal{D}_{S}^{\mathcal{T}}$ and calculate the sentiment classification loss; the model parameters are updated with this loss in line 5 of Algorithm 2.
    • Next, we sample unlabeled data from the target domain $\mathcal{D}_{T}$, mask it, and perform masked language modeling and self-supervised distillation against the prediction on the unmasked prompt.

  Objective (Eq. 7):

    $\begin{aligned}\mathcal{L}_{2}^{\prime}\left(\mathcal{D}_{S}^{\mathcal{T}}, \mathcal{D}_{T} ; \theta_{\mathcal{M}}\right) &= \alpha \mathcal{L}_{pmt}\left(\mathcal{D}_{S}^{\mathcal{T}} ; \theta_{\mathcal{M}}\right) \\ \mathcal{L}_{2}^{\prime\prime}\left(\mathcal{D}_{S}^{\mathcal{T}}, \mathcal{D}_{T} ; \theta_{\mathcal{M}}\right) &= \beta\left(\mathcal{L}_{mlm}\left(\mathcal{D}_{T} ; \theta_{\mathcal{M}}\right) + \mathcal{L}_{ssd}\left(\mathcal{D}_{T} ; \theta_{\mathcal{M}}\right)\right)\end{aligned} \quad\quad(7)$

Algorithm

  

4 Experiments

Dataset

  Amazon reviews dataset

  

Baselines
  • R-PERL (2020): uses BERT for cross-domain text classification with pivot-based fine-tuning.
  • DAAT (2020): uses BERT post-training for cross-domain text classification with adversarial training.
  • p+CFd (2020): uses XLM-R for cross-domain text classification with class-aware feature self-distillation (CFd).
  • $\text{SENTIX}_{\text{Fix}}$ (2020): pre-trains a sentiment-aware language model via several pretraining tasks.
  • UDALM (2021): fine-tuning with a mixed classification and MLM loss on domain-adapted PLMs.
  • AdSPT (2022): soft prompt tuning with an adversarial training objective on vanilla PLMs.
Implementation Details
  • During Stage 1, we train 10 epochs with batch size 4 and early stopping (patience = 3) on the accuracy metric. The optimizer is AdamW with learning rate $1 \times 10^{-5}$, and we halve the learning rate every 3 epochs (see the sketch after this list). We set $\alpha=1.0$, $\beta=0.6$ for Eq. 6.
  • During Stage 2, we train 10 epochs with batch size 4 and early stopping (patience = 3) on the mixed classification and masked language modeling loss. The optimizer is AdamW with learning rate $1 \times 10^{-6}$ and no learning rate decay. We set $\alpha=0.5$, $\beta=0.5$ for Eq. 7.
  • In addition, for the masked language modeling objective and the self-supervised distillation objective, we randomly replace 30% of the tokens with [MASK], and the maximum sequence length is set to 512 by truncating the inputs. In particular, we randomly sample an equal number of unlabeled examples from the target domain every epoch during Stage 2.

Single-source domain adaptation on Amazon reviews

  

Multi-source domain adaptation on Amazon reviews

  

Ablation Experiments

  

Case Study 

  

  

Generality Study

  
