Relational Learning with Gated and Attentive Neighbor Aggregator for Few-Shot Knowledge Graph Completion 小样本关系学习论文解读
小样本知识图补全——关系学习。利用三元组的邻域信息,提升模型的关系表示学习,来实现小样本的链接预测。主要应用的思想和模型包括:GAT、TransH、SLTM、Model-Agnostic Meta-Learning (MAML)。
论文地址:https://arxiv.org/pdf/2104.13095.pdf
引出
在WIkidata数据集中,有超大约10%的关系只被不超过10个的三元组所包含,所以要用小样本学习来为这些关系扩充实体。尽管我们可以直接学习整个知识图谱的关系,然后对整个知识图谱进行推理和扩充,但是这样计算量太大。论文用关系的邻域作为额外信息来提升关系表示的小样本学习。
方法
在这篇论文中,实体和关系的嵌入在训练和测试阶段都是使用TransE预训练好已知的,论文通过引入邻域信息来进一步发掘关系的表示。另外,论文方法使用了模型不可知元学习(Model-Agnostic Meta Learning, MAML),为了表述清晰,我们先对测试的pipeline进行介绍,然后介绍训练过程。
测试pipeline
对于知识图谱$\mathcal{G}$中的小样本关系$r$,其包含少量的$K$个三元组$\{(h_1,r,t_1),...,(h_K,r,t_K)\}$,我们称这个集合为支撑集(Support set)$S$。模型就是用支撑集作为推理信息,对查询集(Query set)$Q=\{(h_1,r,t_1^?),...,(h_m,r,t^?_m)\}$进行推理(已知关系和一个实体,预测另一个实体)。查询集中三元组头实体$h$或尾实体$t$未知都可。
如论文中Figure 2所示,对于支撑集的每个三元组$(h_i,r,t_i)$,都在$\mathcal{G}$中找出它们邻域的三元组$\{\mathcal{N}_{h_i},\mathcal{N}_{t_i}\}$。然后,通过一个所谓门控注意力邻域聚合器,将$h_i$和$t_i$分别和它们各自的邻域$\mathcal{N}_{h_i}$和$\mathcal{N}_{t_i}$融合,分别得到$h'_i$和$t'_i$。然后拼接得到$s_i$。$s_i$可以看做是$r$在实体$h_i$和$t_i$下,同时融合了它们的邻域的表示。具体操作请看论文中式(1)-式(6)。
然后为了对$r$进行完整的表示,也就是将$r$的所有三元组信息都融合起来,文中用Bi-LSTM和注意力机制将所有的$s_i$进行融合,最后得到$r$的表示$r'$。具体操作请看文中式(7)-式(13)。其实不太了解为什么要用Bi-LSTM,毕竟它们之间没有时序关系。
然后文中使用知识图嵌入方法(KGE)TransH来使表示$r'$和支撑集$S$中实体的嵌入相契合,从而能进行后面的推理步骤。与TransH一样,文中定义一个$P_r$作为$r'$所在超平面的法向量,并使用TransH相关的损失对之前计算出来的$r'$和$P_r$进行优化。也就是说,优化$r'$和$P_r$,使得对于支撑集$S$中所有的$h_i$和$t_i$,有
$(h_i-P_r^Th_iP_r)+r'=(t_i-P_r^Tt_iP_r),\,\,i=1,...,K$
优化方式就是用梯度下降,如文中式(14)-式(19)所示。需要注意的是,论文中TransH映射到超平面的式子写错了,式(14)和式(20)。
最后就是用优化后的$P_r$和$r'$进行在查询集$Q$上的推理。也就是对知识图谱所有实体$e$进行排序,与映射后的嵌入差异$Dif$越小排名越靠前。$Dif$计算方式如下:
$Dif=\|(h_i-P_r^Th_iP_r)+r'-(e-P_r^TeP_r)\|$
其中优化后的$P_r$和$r'$在论文中表示为$P'_r$和$r_m$。另外,文中还出现了一个$P^*_r$。它是模型经过MAML预训练后得到的参数,在下面的训练过程中介绍。
训练过程
训练过程是对计算$r'$所用的模型参数(图2The Global Stage的参数)进行训练,以及对超平面法向量$P_r$进行预训练(MAML)。之所以称之为预训练,是因为$P_r$在测试阶段依然需要使用支撑集进行微调,而聚合器参数如Bi-LSTM在测试阶段则是固定不变的。
训练与测试阶段一样,同样包含支撑集与查询集。训练过程描述如下(论文Algorithm 1):
初始化The Global Stage的模型参数,以及$P_r$。
对于某个关系$r$,设其支撑集和查询集分别为$S_r=\{(h_1,r,t_1),...,(h_K,r,t_K)\}$和$Q_r=\{(h_1,r,t_1^?),...,(h_m,r,t^?_m)\}$。$S_r$经过邻域聚合器得到表示$r'$。然后与测试阶段一样用支撑集实体优化$r'$和$P_r$。最后计算$r'$和$P_r$在查询集$Q_r$上的损失,并用该损失对The Global Stage的模型参数$\theta$(经由$r'$)以及$P_r$进行优化:
$\displaystyle L=\frac{1}{m}\sum\limits_{i=1}^m\|(h_i-P_r^Th_iP_r)+r'-(t_i^?-P_r^Tt_i^?P_r)\|$
$\theta=\theta-\beta\nabla_\theta L$
$P_r=P_r-\beta\nabla_{P_r} L$
值得注意的是,由于执行了二重反向传播,因此在实现时第一次对$r'$和$P_r$的优化需要保存其反向传播的计算图。
使用多个关系的支撑集和查询集对模型进行训练。最终训练完成得到$\theta^*$和$P_r^*$。
实验
实验设置
论文实验使用NELL-One和Wiki-One数据集进行实验,选择包含三元组数量在50到500之间的关系进行训练(支撑集和查询集的size没说),之后使用支撑集大小为1/3/5的关系进行测试。使用平均倒数秩(Mean reciprocal rank,MRR)和Hit@n作为评价指标,都是越高越好。指标计算流程是这样的:
对于某个关系$r$,测试集的查询集是$Q=\{(h_1,r,t^?_1),...,(h_m,r,t^?_m)\}$。对每个待查询三元组$(h_i,r,?)$进行推理,按照相似度得到可能的尾实体排序$\{t^1_i,...,t^w_i\}$。
MRR:获取$t^?_i$在其中的排序$rank_i$,然后将所有$rank_i$的倒数进行平均,得到这个关系$r$的MRR:
$\displaystyle MRR_r=\frac{1}{m}\sum\limits_{i=1}^m\frac{1}{rank_i}$
Hit@n:查看尾实体排序的前$n$个,如果$t^?_i$在其中,则记$Hit_i=1$,否则为$0$,将所有$Hit_i$进行平均,得到这个关系$r$的Hit@n:
$\displaystyle Hit@n_r=\frac{1}{m}\sum\limits_{i=1}^mHit_i$
然后测试集可能包含对多种关系的测试,最终MRR和Hit@n的结果应该是所有$MRR_r$和$Hit@n_r$的平均。
实验结果
做了1/3/5-shot实验、1-N/N-1/N-N实验、消融实验,证明了模型和各个模块的有效性,如表2/3/4所示。
因为论文使用关系的邻域促进模型对关系表示的学习,论文还对邻域给关系表示的贡献作了可视化,如图3所示。权重越大,表示这个邻居对关系的表示贡献越大。
Relational Learning with Gated and Attentive Neighbor Aggregator for Few-Shot Knowledge Graph Completion 小样本关系学习论文解读的更多相关文章
- ICLR 2013 International Conference on Learning Representations深度学习论文papers
ICLR 2013 International Conference on Learning Representations May 02 - 04, 2013, Scottsdale, Arizon ...
- 收藏:左路Deep Learning+右路Knowledge Graph,谷歌引爆大数据
发表于2013-01-18 11:35| 8827次阅读| 来源sina微博 条评论| 作者邓侃 数据分析智能算法机器学习大数据Google 摘要:文章来自邓侃的博客.数据革命迫在眉睫. 各大公司重兵 ...
- Deep Learning 和 Knowledge Graph howto
领军大家: Geoffrey E. Hinton http://www.cs.toronto.edu/~hinton/ 阅读列表: reading lists and survey papers fo ...
- 论文解读《Learning Deep CNN Denoiser Prior for Image Restoration》
CVPR2017的一篇论文 Learning Deep CNN Denoiser Prior for Image Restoration: 一般的,image restoration(IR)任务旨在从 ...
- 自监督学习(Self-Supervised Learning)多篇论文解读(下)
自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...
- 自监督学习(Self-Supervised Learning)多篇论文解读(上)
自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...
- A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)
A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...
- 论文解读(GraphDA)《Data Augmentation for Deep Graph Learning: A Survey》
论文信息 论文标题:Data Augmentation for Deep Graph Learning: A Survey论文作者:Kaize Ding, Zhe Xu, Hanghang Tong, ...
- 论文解读(MGAE)《MGAE: Masked Autoencoders for Self-Supervised Learning on Graphs》
论文信息 论文标题:MGAE: Masked Autoencoders for Self-Supervised Learning on Graphs论文作者:Qiaoyu Tan, Ninghao L ...
- 论文解读(USIB)《Towards Explanation for Unsupervised Graph-Level Representation Learning》
论文信息 论文标题:Towards Explanation for Unsupervised Graph-Level Representation Learning论文作者:Qinghua Zheng ...
随机推荐
- 在 Linux 中找出 CPU 占用高的进程
列出系统中 CPU 占用高的进程列表来确定.我认为只有两种方法能实现:使用 top 命令 和 ps 命令.出于一些理由,我更倾向于用 top 命令而不是 ps 命令.但是两个工具都能达到你要的目的,所 ...
- 如何用Virtualbox搭建一个虚拟机
序言 各位好啊,我是会编程的蜗牛,作为java开发者,我们肯定会接触Linux服务器,除了使用云服务搭建Linux服务器外,我们一般也可以在自己的电脑上安装虚拟机来搭建Linux服务器用于各种功能的验 ...
- PHP + ELK实现日志记录
一个简单的PHP 文件 效果 full.conf文件 流程: 开启logstash服务之后. 在业务代码里面操作函数写入日志.log logstash通过实践戳获取到用户的变更,取出最后一行数据,发送 ...
- Linux基础_3_文件/文件夹权限管理
注:权限遮罩码: 控制用户创建文件和文件夹的默认安全设置,文件默认权限为666-umask的值,文件夹默认权限为777-umask的值. root默认0022,普通用户默认0002. 文件的默认权限不 ...
- vue3+element-plus+登录逻辑token+环境搭建
vue3+element-plus+登录逻辑token环境搭建 安装脚手架工具 1 npm i @vue/cli@4.5.13 -g 验证是否安装成功 1 vue -V # 输出 @vue/cli 4 ...
- $_SERVER['HTTP_USER_AGENT']:在PHP中HTTP_USER_AGENT是用来获取用户的相关信息的,包括用户使用的浏览器,操作系统等信息
在PHP中HTTP_USER_AGENT是用来获取用户的相关信息的,包括用户使用的浏览器,操作系统等信息. 我机器:操作系统:WIN7旗舰版 64操作系统 以下为各个浏览器下$_SERVER['HTT ...
- Java多线程(6):锁与AQS(下)
您好,我是湘王,这是我的博客园,欢迎您来,欢迎您再来- 之前说过,AQS(抽象队列同步器)是Java锁机制的底层实现.既然它这么优秀,是骡子是马,就拉出来溜溜吧. 首先用重入锁来实现简单的累加,就像这 ...
- nodered获取简单的时间
1.添加simpletime 的节点 2. 添加一个inject节点用来每1s循环获取当点的信息 3.添加一个函数节点对simpletime发来的msg进行解析 var payload=msg;var ...
- selenium 添加特殊配置(如不完整 希望各位大神评论告诉我)
options 常用配置 #添加特殊配置 options=webdriver.ChromeOptions() #设置默认编码为utf-8,也就是中文 options.add_argument('lan ...
- JavaScript常用工具函数
检测数据是不是除了symbol外的原始数据 function isStatic(value) { return ( typeof value === 'string' || typeof value ...