上一节中介绍了 $ \lambda $ 的计算,lambdaMART就以计算的每个doc的 $\lambda$ 值作为label,训练Regression Tree,并在最后对叶子节点上的样本 $lambda$ 均值还原成 $\gamma$ ,乘以learningRate加到此前的Regression Trees上,更新score,重新对query下的doc按score排序,再次计算deltaNDCG以及 $\lambda$ ,如此迭代下去直至树的数目达到参数设定或者在validation集上不再持续变好(一般实践来说不在模型训练时设置validation集合,因为validation集合一般比训练集合小很多,很容易收敛,达不到效果,不如训练时一步到位,然后另起test集合做结果评估)。

其实Regression Tree的训练很简单,最主要的就是决定如何分裂节点。lambdaMART采用最朴素的最小二乘法,也就是最小化平方误差和来分裂节点:即对于某个选定的feature,选定一个值val,所有<=val的样本分到左子节点,>val的分到右子节点。然后分别对左右两个节点计算平方误差和,并加在一起作为这次分裂的代价。遍历所有feature以及所有可能的分裂点val(每个feature按值排序,每个不同的值都是可能的分裂点),在这些分裂中找到代价最小的。

举个栗子,假设样本只有上一节中计算出 $\lambda$ 的那10个:

 qId=1830 features and lambdas
qId=1830 1:0.003 2:0.000 3:0.000 4:0.000 5:0.003 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(1):-0.495
qId=1830 1:0.026 2:0.125 3:0.000 4:0.000 5:0.027 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(2):-0.206
qId=1830 1:0.001 2:0.000 3:0.000 4:0.000 5:0.001 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(3):-0.104
qId=1830 1:0.189 2:0.375 3:0.333 4:1.000 5:0.196 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(4):0.231
qId=1830 1:0.078 2:0.500 3:0.667 4:0.000 5:0.086 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(5):0.231
qId=1830 1:0.075 2:0.125 3:0.333 4:0.000 5:0.078 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(6):-0.033
qId=1830 1:0.079 2:0.250 3:0.667 4:0.000 5:0.085 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(7):0.240
qId=1830 1:0.148 2:0.000 3:0.000 4:0.000 5:0.148 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(8):0.247
qId=1830 1:0.059 2:0.000 3:0.000 4:0.000 5:0.059 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(9):-0.051
qId=1830 1:0.071 2:0.125 3:0.333 4:0.000 5:0.074 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000 lambda(10):-0.061

上表中除了第一列是qId,最后一列是lambda外,其余都是feature,比如我们选择feature(1)的0.059做分裂点,则左子节点<=0.059的doc有: 1, 2, 3, 9;而>0.059的被安排到右子节点,doc有4, 5, 6, 7, 8, 10。由此左右两个子节点的lambda均值分别为:

$ \bar{\lambda_L}=\frac{\lambda_1+\lambda_2+\lambda_3+\lambda_9}{4}=\frac{-0.495-0.206-0.104-0.051}{4}=-0.214$

$\bar{\lambda_R}=\frac{\lambda_4+\lambda_5+\lambda_6+\lambda_7+\lambda_8+\lambda_{10}}{6}=\frac{0.231+0.231-0.033+0.240+0.247-0.061}{6}=0.143$

继续计算左右子节点的平方误差和:

$s_{L}=\sum_{i\in L}{(\lambda_i-\bar{\lambda_L})^2}=(-0.495+0.214)^2+(-0.206+0.214)^2+(-0.104+0.214)^2+(-0.051+0.214)^2=0.118$

$s_{R}=\sum_{i\in R}{(\lambda_i-\bar{\lambda_R})^2}=(0.231-0.143)^2+(0.231-0.143)^2+(-0.033-0.143)^2+(0.240-0.143)^2+(0.247-0.143)^2+(0.016-0.143)^2=0.083$

因此将feature(1)的0.059的均方差(分裂代价)是:

$Cost_{0.059@feature(1)}=s_{L}+s_{R}=0.118+0.083=0.201$

我们可以像上面那样遍历所有feature的不同值,尝试分裂,计算Cost,最终选择所有可能分裂中最小Cost的那一个作为分裂点。然后将 $s_{L}$ 和 $s_{R}$ 分别作为左右子节点的属性存储起来,并把分裂的样本也分别存储到左右子节点中,然后维护一个队列,始终按平方误差和 s 降序插入新分裂出的节点,每次从该队列头部拿出一个节点(并基于这个节点上的样本)进行分裂(即最大均方差优先分裂),直到树的分裂次数达到参数设定(训练时传入的leaf值,叶子节点的个数与分裂次数等价)。这样我们就训练出了一棵Regression Tree。

上面讲述了一棵树的标准分裂过程,需要多提一点的是,树的分裂还有一个参数设定:叶子节点上的最少样本数,比如我们设定为3,则在feature(1)处,0.001和0.003两个值都不能作为分裂点,因为用它们做分裂点,左子树的样本数分别是1和2,均<3。叶子节点的最少样本数越小,模型则拟合得越好,当然也容易过拟合(over-fitting);反之如果设置得越大,模型则可能欠拟合(under-fitting),实践中可以使用cross validation的办法来寻找最佳的参数设定。

LambdaMART简介——基于Ranklib源码(二 Regression Tree训练)的更多相关文章

  1. LambdaMART简介——基于Ranklib源码(一 lambda计算)

    学习Machine Learning,阅读文献,看各种数学公式的推导,其实是一件很枯燥的事情.有的时候即使理解了数学推导过程,也仍然会一知半解,离自己写程序实现,似乎还有一道鸿沟.所幸的是,现在很多主 ...

  2. Java_io体系之PipedWriter、PipedReader简介、走进源码及示例——14

    Java_io体系之PipedWriter.PipedReader简介.走进源码及示例——14 ——管道字符输出流.必须建立在管道输入流之上.所以先介绍管道字符输出流.可以先看示例或者总结.总结写的有 ...

  3. Java_io体系之BufferedWriter、BufferedReader简介、走进源码及示例——16

    Java_io体系之BufferedWriter.BufferedReader简介.走进源码及示例——16 一:BufferedWriter 1.类功能简介: BufferedWriter.缓存字符输 ...

  4. Java_io体系之RandomAccessFile简介、走进源码及示例——20

    Java_io体系之RandomAccessFile简介.走进源码及示例——20 RandomAccessFile 1.       类功能简介: 文件随机访问流.关心几个特点: 1.他实现的接口不再 ...

  5. AQS源码二探-JUC系列

    本文已在公众号上发布,感谢关注,期待和你交流. AQS源码二探-JUC系列 共享模式 doAcquireShared 这个方法是共享模式下获取资源失败,执行入队和等待操作,等待的线程在被唤醒后也在这个 ...

  6. Unity UGUI图文混排源码(二)

    Unity UGUI图文混排源码(一):http://blog.csdn.net/qq992817263/article/details/51112304 Unity UGUI图文混排源码(二):ht ...

  7. JMeter 源码二次开发函数示例

    JMeter 源码二次开发函数示例 一.JMeter 5.0 版本 实际测试中,依靠jmeter自带的函数已经无法满足我们需求,这个时候就需要二次开发.本次导入的是jmeter 5.0的源码进行实际的 ...

  8. Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练

    Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练 目录 Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练 0x00 摘要 0x01 前文回顾 1.1 上文总体流程图 1 ...

  9. [源码分析] Facebook如何训练超大模型---(1)

    [源码分析] Facebook如何训练超大模型---(1) 目录 [源码分析] Facebook如何训练超大模型---(1) 0x00 摘要 0x01 简介 1.1 FAIR & FSDP 1 ...

随机推荐

  1. java基础语法 数组

    数组是相同数据类型元素的集合   数组本身是引用数据类型,即对象.但是数组可以存储基本数据类型,也可以存储引用数据类型. 在java中如果想要保存一组基本类型的数据,使用数组:如果想保存一组对象或者其 ...

  2. 本地blast的安装

    1 下载程序 在ftp://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/下载 ncbi-blast-2.2.25+-x64-linux.t ...

  3. 5.3 Components — Passing Properties to A Component

    1. 默认情况下,一个组件在它使用的模板范围中没有访问属性. 例如,假想你有一个blog-post组件被用来展示一个blog post: app/templates/components/blog-p ...

  4. vgg_face人脸识别

    最近参考http://blog.csdn.net/hlx371240/article/details/51388022一文,用LFW数据集对vgg_face.caffemodel进行fine-tune ...

  5. 用opencv检测人眼并定位瞳孔位置

    最近的研究要用到定位瞳孔的位置,所以上网搜了下相关的代码.总结如下: 1) 定位瞳孔可以直接使用opencv中的自带的分类器(haarcascade_eye_tree_eyeglasses.xml)来 ...

  6. 自动化运维工具SaltStack安装配置

    SaltStack是一种全新的基础设置管理方式,部署轻松,在几分钟内可运作起来,扩展性好,很容易管理上万台服务器,速度够快,服务器之间秒级通讯.通过部署SaltStack环境,我们可以在成千上万台服务 ...

  7. 20145307陈俊达《网络对抗》Exp2 后门原理与实践

    20145307陈俊达<网络对抗>Exp2 后门原理与实践 基础问题回答 例举你能想到的一个后门进入到你系统中的可能方式? 非正规网站下载的软件 脚本 或者游戏中附加的第三方插件 例举你知 ...

  8. uboot 版本号生成过程

    uboot 版本号生成过程 uboot版本号貌似与实际开发不相关,但是我现在遇到一个bug与版本号关联密切. 这个bug与<uboot dm9000驱动故障>基本上是一样的,但是在上一篇博 ...

  9. Log4j2报错ERROR StatusLogger Unrecognized format specifier

    问题 使用maven-shade-plugin或者maven-assembly-plugin插件把项目打成一个可执行JAR包时,如果你引入了log4j2会出现如下问题: ERROR StatusLog ...

  10. HDU 2457 DNA repair(AC自动机+DP)题解

    题意:给你几个模式串,问你主串最少改几个字符能够使主串不包含模式串 思路:从昨天中午开始研究,研究到现在终于看懂了.既然是多模匹配,我们是要用到AC自动机的.我们把主串放到AC自动机上跑,并保证不出现 ...