一种基于均值不等式的Listwise损失函数

1 前言

1.1 Learning to Rank 简介

Learning to Rank (LTR) , 也被叫做排序学习, 是搜索中的重要技术, 其目的是根据候选文档和查询语句的相关性对候选文档进行排序, 或者选取topk文档. 比如在搜索引擎中, 需要根据用户问题选取最相关的搜索结果展示到首页. 下图是搜索引擎的搜索结果

1.2 LTR算法分类

根据损失函数可把LTR分为三种:

  1. Pointwise, 该类型算法将LTR任务作为回归任务来训练, 即尝试训练一个为文档和查询语句的打分器, 然后根据打分进行排序.
  2. Pairwise, 该类型算法的损失函数考虑了两个候选文档, 学习目标是把相关性高的文档排在前面, triplet loss 就属于Pairwise, 它的损失函数是
\[loss = max(0, score_{neg}-score_{pos}+margin)
\]

可以看出该损失函数一次考虑两个候选文档.

3. Listwise, 该类型算法的损失函数会考虑多个候选文档, 这是本文的重点, 下面会详细介绍.

1.3 本文主要内容

本文主要介绍了本人在学习研究过程中发明的一种新的Listwise损失函数, 以及该损失函数的使用效果. 如果读者对LTR任务及其算法还不够熟悉, 建议先去学习LTR相关知识, 同时本人博文自然语言处理中的负样本挖掘 (分类与排序任务中如何选择负样本) 也和本文关系较大, 可以先进行阅读.

2 预备知识

2.1 数学符号定义

\(q\)代表用户搜索问题, 比如"如何成为宇航员", \(D\)代表候选文档集合,\(d^+\)代表和\(q\)相关的文档,\(d^-\)代表和\(q\)不相关的文档, \(d^+_i\)代表第\(i\)个和\(q\)相关的文档, LTR的目标就是根据\(q\)找到最相关的文档\(d\)

2.2 学习目标

本次学习目标是训练一个打分器 scorer, 它可以衡量q和d的相关性, \(scorer(q, d)\)就是相关性分数,分值越大越相关. 当前主流方法下, scorer一般选用深度神经网络模型.

2.3训练数据分类

损失函数不同, 构造训练数据的方法也会不同:

-Pointwise, 可以构造回归数据集, 相关的数据设为1, 不相关设为0.

-Pairwise, 可构造triplet类型的数据集, 形如(\(q,d^+, d^-\))

-Listwise, 可构造这种类型的训练集: (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 一个正例还是多个正例也会影响到损失函数的构造, 本文提出的损失函数是针对多正例多负例的情况.

3 基于均值不等式的Listwise损失函数

3.1 损失函数推导过程

在上一小结我们可以知道,训练集是如下形式 (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 对于一个q, 有n个相关的文档和m个不相关的文档, 那么我们一共可以获取m+n个分值:\((score_1,score_2,...,score_n,...,score_{n+m})\), 我们希望打分器对相关文档打分趋近于正无穷, 对不相关文档打分趋近于负无穷.

对m+n个分值做一个softmax得到\(p_1,p_2,...,p_n,...,p_{n+m}\), 此时\(p_i\)可以看作是第i个候选文档与q相关的概率, 显然我们希望\(p_1,p_2,...,p_n\)越大越好, \(p_{n+1},...,p_{m+n}\)越小越好, 即趋近于0. 因此我们暂时的优化目标是\(\sum_{i=1}^{n}{p_i} \rightarrow 1\).

但是这个优化目标是不合理的, 假设\(p_1=1\), 其他值全为0, 虽然满足了上面的要求, 但这并不是我们想要的. 因为我们不仅希望\(\sum_{i=1}^{n}{p_i} \rightarrow 1\), 还希望相关候选文档的每一个p值都要足够大, 即我们希望: n个候选文档都与q相关的概率是最大的, 所以我们真正的优化目标是:

\[\max(\prod_{i=1}^{n}{p_i} ) , \sum_{i=1}^{n}{p_i} = 1
\]

当前情况下, 损失函数已经可以通过代码实现了, 但是我们还可以做一些化简工作, \(\prod_{i=1}^{n}{p_i}\)是存在最大值的, 根据均值不等式可得:

\[\prod_{i=1}^{n}{p_i} \leq (\frac{\sum_{i=1}^{n}{p_i}}{n})^n
\]

对两边取对数:

\[\sum_{i=1}^{n}{log(p_i)} \leq -nlog(n)
\]

这样是不是感觉清爽多了, 然后我们把它转换成损失函数的形式:

\[loss = -nlog(n) - \sum_{i=1}^{n}{log(p_i)}
\]

所以我们的训练目标就是\(\min{(loss)}\)

3.2 使用pytorch实现该损失函数

在获取到最终的损失函数后, 我们还需要用代码来实现, 实现代码如下:

# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''

该示例代码仅展现了batch_size为1的情况, 在batch_size大于1时, 每一条数据都有不同的m和n, 为了能一起送入模型计算分值, 需要灵活的使用mask. 本人在实际使用该损失函数时,一共使用了两种mask, 分别mask每条数据所有候选文档和每条数据的相关文档, 供大家参考使用.

3.3 效果评估和使用经验

由于评测数据使用的是内部数据, 代码和数据都无法公开, 因此只能对使用效果做简单总结:

  1. 效果优于PointwisePairwise, 但差距不是特别大
  2. 相比Pairwise收敛速度极快, 训练一轮基本就可以达到最佳效果

下面是个人使用经验:

  1. 该损失函数比较占用显存, 实际的batch_size是batch_size*(m+n), 建议显存在12G以上
  2. 负例数量越多,效果越好, 收敛也越快
  3. 用pytorch实现log_softmax时, 不要自己实现, 直接使用torch中的log_softmax函数, 它的效率更高些.
  4. 只有一个正例, 还可以考虑转为分类问题,使用交叉熵做优化, 效果同样较好

4 总结

该损失函数还是比较简单的, 只需要简单的数学知识就可以自行推导, 在实际使用中也取得了较好的效果, 希望也能够帮助到大家. 如果大家有更好的做法欢迎告诉我.

文章可以转载, 但请注明出处:

一种基于均值不等式的Listwise损失函数的更多相关文章

  1. 基于均值坐标(Mean-Value Coordinates)的图像融合算法的优化实现

    目录 1. 概述 2. 实现 2.1. 原理 2.2. 核心代码 2.3. 第二种优化 3. 结果 1. 概述 我在之前的文章<基于均值坐标(Mean-Value Coordinates)的图像 ...

  2. LM-MLC 一种基于完型填空的多标签分类算法

    LM-MLC 一种基于完型填空的多标签分类算法 1 前言 本文主要介绍本人在全球人工智能技术创新大赛[赛道一]设计的一种基于完型填空(模板)的多标签分类算法:LM-MLC,该算法拟合能力很强能感知标签 ...

  3. [信安Presentation]一种基于GPU并行计算的MD5密码解密方法

    -------------------paper--------------------- 一种基于GPU并行计算的MD5密码解密方法 0.abstract1.md5算法概述2.md5安全性分析3.基 ...

  4. <<一种基于δ函数的图象边缘检测算法>>一文算法的实现。

    原始论文下载: 一种基于δ函数的图象边缘检测算法. 这篇论文读起来感觉不像现在的很多论文,废话一大堆,而是直入主题,反倒使人觉得文章的前后跳跃有点大,不过算法的原理已经讲的清晰了.     一.原理 ...

  5. 16种基于 CSS3 & SVG 的创意的弹窗效果

    在去年,我给大家分享了<基于 CSS3 的精美模态窗口效果>,而今天我要与大家分享一些新鲜的想法.风格和趋势变化,要求更加适合现代UI的不同的效果.这组新模态窗口效果包含了一些微妙的动画, ...

  6. tmpfs:一种基于内存的文件系统

    tmpfs是一种基于内存的文件系统, tmpfs有时候使用rm(物理内存),有时候使用swap(磁盘一块区域).根据实际情况进行分配. rm:物理内存.real memery的简称? 真实内存就是电脑 ...

  7. 一种基于重载的高效c#上图片添加文字图形图片的方法

    在做图片监控显示的时候,需要在图片上添加文字,如果用graphics类绘制图片上的字体,实现图像上添加自定义标记,这种方法经验证是可行的,并且在visual c#2005 编程技巧大全上有提到,但是, ...

  8. 一种基于Qt的可伸缩的全异步C/S架构服务器实现(流浪小狗,六篇,附下载地址)

    本文向大家介绍一种基于Qt的伸缩TCP服务实现.该实现针对C/S客户端-服务集群应用需求而搭建.连接监听.数据传输.数据处理均在独立的线程池中进行,根据特定任务不同,可安排负责监听.传输.处理的线程数 ...

  9. 一种基于Qt的可伸缩的全异步C/S架构server实现(一) 综述

    本文向大家介绍一种基于Qt的伸缩TCP服务实现.该实现针对C/Sclient-服务集群应用需求而搭建. 连接监听.传输数据.数据处理均在独立的线程池中进行,依据特定任务不同,可安排负责监听.传输.处理 ...

随机推荐

  1. JavaScript作用域与对象

    1 - 作用域 1.1 作用域概述 通常来说,一段程序代码中所用到的名字并不总是有效和可用的,而限定这个名字的可用性的代码范围就是这个名字的作用域.作用域的使用提高了程序逻辑的局部性,增强了程序的可靠 ...

  2. 力扣Leetcode 33. 搜索旋转排序数组

    33. 搜索旋转排序数组 假设按照升序排序的数组在预先未知的某个点上进行了旋转. ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1,2] ). 搜索一个给定的目标值, ...

  3. jq cdn地址

    百度CDN支持版本2.0.3, 2.0.2, 2.0.1, 2.0.0,1.11.1, 1.10.2, 1.10.1, 1.10.0, 1.9.1, 1.9.0, 1.8.3, 1.8.2, 1.8. ...

  4. akka-streams - 从应用角度学习:basic stream parts

    实际上很早就写了一系列关于akka-streams的博客.但那个时候纯粹是为了了解akka而去学习的,主要是从了解akka-streams的原理为出发点.因为akka-streams是akka系列工具 ...

  5. C015:十进制转8进制

    程序: #include "stdafx.h" #include <string.h> int _tmain(int argc, _TCHAR* argv[]) { i ...

  6. Agumaster 将爬虫取股票名称代号子系统分出来成agumaster_crawler, 两系统通过RabbitMq连接

    agumaster_crawler系统负责启动爬虫取得数据,之后便往队列中推送. agumaster_crawler系统中pom.xml关于RabbitMq的依赖是: <!-- RabbitMq ...

  7. 给EmpMapper开放Restful接口

    本文例程下载:https://files.cnblogs.com/files/xiandedanteng/gatling20200428-3.zip 接口控制器代码如下: 请求url和响应都写在了每个 ...

  8. 20190923-09Linux磁盘分区类 000 017

    df 查看磁盘空间使用情况 df: disk free 空余硬盘 1.基本语法 df  选项 (功能描述:列出文件系统的整体磁盘使用量,检查文件系统的磁盘空间占用情况) 2.选项说明 表1-32 选项 ...

  9. [java]将多个文件压缩成一个zip文件

    此文进阶请见:https://www.cnblogs.com/xiandedanteng/p/12155957.html 方法: package zip; import java.io.Buffere ...

  10. [深入理解JVM虚拟机]第3章-垃圾收集器、内存分配策略

    垃圾收集器 判断对象是否需存活 回收堆 判断对象是否存活: 方法一:引用计数法.对象被引用一次就+1,当为0时回收对象.缺点:无法解决循环引用问题. 方法二:可达性分析算法.记录当前对象是否有和GC ...