如今,GBDT被广泛运用于互联网行业,他的原理与优点这里就不细说了,网上google一大把。但是,我自认为自己不是一个理论牛人,对GBDT的理论理解之后也做不到从理论举一反三得到更深入的结果。但是学习一个算法,务必要深入细致才能领会到这个算法的精髓。因此,在了解了足够的GBDT理论之后,就需要通过去阅读其源码来深入学习GBDT了。但是,网上有关这类资料甚少,因此,我不得不自己亲自抄刀,索性自己从头学习了一下GBDT源码。幸好,这个算法在机器学习领域中的其它算法还是非常简单的。这里将心得简单分享,欢迎指正。源码可以去GBDT源码下载。

首先,这里需要介绍一下程序中用到的结构体,具体的每一个结构体的内容这里就不再赘述了,源码里面都有。这里只再细说一下每个结构体的作用,当然一些重要的结构体会详细解释。

struct gbdt_model_t:GBDT模型的结构体,也就是最终我们训练得到的由很多棵决策树组成的模型。

typedef struct {
          int* nodestatus;    //!<  
          int* depth;         // 
          int* splitid;       //!< 
          double* splitvalue; //!< 
          int* ndstart;       //!< 节点对应于 Index 的开始位置
          int* ndcount;       //!< 节点内元素的个数
          double* ndavg;      //!< 节点内元素的均值 
         //double* vpredict;
          int* lson;          //!< 左子树
          int* rson;          //!< 右子树
          int nodesize;       //!< 树的节点个数
     }gbdt_tree_t;

struct gbdt_tree_t:当然就代表模型中的一棵树的各种信息了。为了后面能理解,这里需要详细解释一下这个结构体。splitid[k]保存该棵树的第k个结点分裂的feature下标,splitvalue[k]保存该棵树第k个结点的分裂值,nodestatus[k]代表该棵树的第k个结点的状态,如果为GBDT_INTERIOR,代表该结点已分裂,如果为GBDT_TOSPLIT,代表该结点需分裂,如果为GBDT_TERMINAL表示该结点不需再分裂,一般是由于该结点的样本数ndcount[k]少于等于一阈值gbdt_min_node_size;depth[ncur+1]代表左子树的深度,depth[ncur+2]表示右子树的深度,其中ncur的增长步长为2,表示每次+2都相关于跳过当前结点的左子树和右子树,到达下一个结点。ndstart[ncur+1]代表划分到左子树开始样本的下标,ndstart[ncur+2]代表划分到右子树开始样本的下标,其中到底这个下标是代表第几个样本是由index的一个结构保存。ndcount[ncur + 1]代表划分到左子树的样本数量,ndcount[ncur + 2]代表划分到右子树的样本数量。ndavg[ncur+1]代表左子树样本的均值,同理是右子树样本的均值。nodestatus[ncur+1] = GBDT_TOSPLIT表示左子树可分裂。lson[k]=ncur+1表示第k个结点的左子树,同理表示第k个结点的右子树。

gbdt_info_t保存模型配置参数。

typedef struct   
     {    
          int* fea_pool; //!< 随机 feature 候选池
          double* fvalue_list; //!< 以feature i 为拉链的特征值 x_i
          double* fv; //!< 特征值排序用的buffer版本
          double* y_list; //!< 回归的y值集合
          int* order_i; //!< 排序的标号
      } bufset; //!< 训练数据池

bufset代表训练数据池,它保存了训练当前一棵树所用到的一些数据。fea_pool保存了训练数据的特征的下标,循环rand_fea_num(feature随机采样量)次,随机地从fea_pool中选取特征来计算分裂的损失函数(先过的feature不会再选)。fvalue_list保存在当前选择特征fid时,所有采样的样本特征fid对应的值。fv与favlue_list一样。y_list表示采样样本的y值。order_i保存左子树与右子树结点下标。

nodeinfo代表节点的信息。

typedef struct  
     {        
         int bestid; //!< 分裂使用的Feature ID
         double bestsplit; //!< 分裂边界的x值
         int pivot; //!< 分裂边界的数据标号               
     } splitinfo; //!< 分裂的信息

splitinfo代表分裂的信息。pivot代表分裂点在order_i中的下标。bestsplit表示分裂值。bestid表示分裂的feature。

好了,解释完关键的一些结构体,下面要看懂整个gbdt的流程就非常简单了。这里我就简单的从头至尾叙述一下整个训练的流程。

首先申请分配模型空间gbdt_model,并且计算所有样本在每一维特征上的平均值。假如我们需要训练infbox.tree_num棵树,每一棵的训练流程为:从x_fea_value中采样gbdt_inf.sample_num个样本,index[i]记录了第i个结点所对应的样本集合x_fea_value中的下标,其始终保存了训练本棵树的所有采样样本对应样本空间的下标值,同时,结点的顺序是按该棵树所有结点按广度优先遍历算法遍历的结果的。即当前树gbdt_single_tree只有一个根结点0,其中gbdt_single_tree->nodestatus为GBDT_TOSPLIT,ndstart[0]=0,ndcount[0]=sample_num,ndavg为所有采样样本的y的梯度值均值。下面就是对这个结点进行分裂的过程:首先nodeinfo ninf这个结构体保存了当前分裂结点的一些信息,比如结点中样本开始的下标(指相对于index的下标值,index指向的值才是样本空间中该样本的下标),样本结束下标(同上),样本结点数,样本结点的y的梯度之和等。循环rand_fea_num次,随机采样feature,来计算在该feature分裂的信息增益,计算方式为(左子树样子目标值和的平方均值+右子树目标值和的平方均值-父结点所有样本和的平方均值)。选过的feature就不会再选中来计算信息增益了。利用data_set来保存当前分裂过程所用到的一些信息,包括候选feature池,选中feature对应的采样样本的特征值及其y值。data_set->order_i保存了左右子树对应结点在样本集合中的下标。计算每个feature的信息增益,并取最大的,保存分点信息到spinf中,包括最优分裂值,最优分裂feature。然后,将该结点小于分裂值的结点样本下标与大于分裂值的结点样本下标都保存在data_set->order_i中,nl记录了order_i中右子树开始的位置。更新index数组,将order_i中copy到index中。将nl更新到spinf中。注意index数组从左至右保存了最终分裂的左子树与右子树样本对应在样本空间的下标。

至此,我们找到了这个结点的最优分裂点。gbdt_single_tree->ndstart[1]保存了左孩子的开始下标(指相对于index的下标值,index指向的值才是样本下标),gbdt_single_tree->ndstart[2]保存了右孩子的开始下标,即nl的值。同理,ndcount,depth等也是对就保存了左右孩子信息。gbdt_single_tree->lson[0]=1,gbdt_single_tree->lson[0]=2即表示当前结点0的左子树是1,右子树是2。当前结点分裂完了之后,下一次就同理广度优先算法,对该结点的孩子继续上述步骤。

该棵树分裂完成之后,对每一个样本,都用目前模型(加上分裂完成的这棵树)计算预测值,并且更新每一个样本的残差y_gradient。计算过程:选取当前结点的分裂feature以及分裂值,小于则走左子树,大于则走右子树,直到叶子结点。预测值为shrink*该叶子结点的样本目标值的均值。

训练第二棵树同理,只是训练的样本的目标值变成了前面模型预测结果的残差了。这点就体现在梯度下降的寻优过程。

好了,这里只是简单的对gbdt代码做了说明,当然如果没有看过本文引用的源码,是不怎么能看懂的,如果结合源码来看,就很容易看懂了。总之,个人感觉,只有结合原码来学习gbdt,才真正能体会到事个模型的学习以及树的生成过程。

GBDT源码剖析的更多相关文章

  1. jQuery之Deferred源码剖析

    一.前言 大约在夏季,我们谈过ES6的Promise(详见here),其实在ES6前jQuery早就有了Promise,也就是我们所知道的Deferred对象,宗旨当然也和ES6的Promise一样, ...

  2. Nodejs事件引擎libuv源码剖析之:高效线程池(threadpool)的实现

    声明:本文为原创博文,转载请注明出处. Nodejs编程是全异步的,这就意味着我们不必每次都阻塞等待该次操作的结果,而事件完成(就绪)时会主动回调通知我们.在网络编程中,一般都是基于Reactor线程 ...

  3. Apache Spark源码剖析

    Apache Spark源码剖析(全面系统介绍Spark源码,提供分析源码的实用技巧和合理的阅读顺序,充分了解Spark的设计思想和运行机理) 许鹏 著   ISBN 978-7-121-25420- ...

  4. 基于mybatis-generator-core 1.3.5项目的修订版以及源码剖析

    项目简单说明 mybatis-generator,是根据数据库表.字段反向生成实体类等代码文件.我在国庆时候,没事剖析了mybatis-generator-core源码,写了相当详细的中文注释,可以去 ...

  5. STL"源码"剖析-重点知识总结

    STL是C++重要的组件之一,大学时看过<STL源码剖析>这本书,这几天复习了一下,总结出以下LZ认为比较重要的知识点,内容有点略多 :) 1.STL概述 STL提供六大组件,彼此可以组合 ...

  6. SpringMVC源码剖析(四)- DispatcherServlet请求转发的实现

    SpringMVC完成初始化流程之后,就进入Servlet标准生命周期的第二个阶段,即“service”阶段.在“service”阶段中,每一次Http请求到来,容器都会启动一个请求线程,通过serv ...

  7. 自己实现多线程的socket,socketserver源码剖析

    1,IO多路复用 三种多路复用的机制:select.poll.epoll 用的多的两个:select和epoll 简单的说就是:1,select和poll所有平台都支持,epoll只有linux支持2 ...

  8. Java多线程9:ThreadLocal源码剖析

    ThreadLocal源码剖析 ThreadLocal其实比较简单,因为类里就三个public方法:set(T value).get().remove().先剖析源码清楚地知道ThreadLocal是 ...

  9. JS魔法堂:mmDeferred源码剖析

    一.前言 avalon.js的影响力愈发强劲,而作为子模块之一的mmDeferred必然成为异步调用模式学习之旅的又一站呢!本文将记录我对mmDeferred的认识,若有纰漏请各位指正,谢谢.项目请见 ...

随机推荐

  1. Android 经典欧美小游戏 guess who

    本来是要做iOS开发的,因为一些世事无常和机缘巧合与测试工作还有安卓系统结下了不解之缘,前不久找到了guess who 源码,又加入了一些自己的元素最终完成了这个简单的小游戏. <?xml ve ...

  2. 文档撰写思路与排版(hadoop)

    这几天在写项目提交的几个报告,写完回想了一下,在写作框架确定与排版上浪费了不少时间,特此总结一下思路. 这个写完回家过年了....emmmm 1. 定好大标题框架,使用自动添加序号,先不着急修改样式 ...

  3. 文科妹学 GitHub 简易教程

      #什么是 Github ?必须要放这张图了!!!<img src="https://pic4.zhimg.com/7c9d3403bf922b1663f56975869c829b_ ...

  4. poj_3253 Fence Repair

    Fence Repair Description Farmer John wants to repair a small length of the fence around the pasture. ...

  5. 试试SQLServer 2014的内存优化表(转载)

    SQL Server2014存储引擎:行存储引擎,列存储引擎,内存引擎 SQL Server 2014中的内存引擎(代号为Hekaton)将OLTP提升到了新的高度. 现在,存储引擎已整合进当前的数据 ...

  6. USB 相关笔记

    1分析已有代码项目 Android从USB声卡录制高质量音频-----使用libusb读取USB声卡数据 github 项目:usbaudio-android-demo usb声卡取数据项目也是参考的 ...

  7. .Net 环境

    更多系统版本下载:https://www.microsoft.com/net/download VSCode :https://code.visualstudio.com/

  8. mac系统默认python3.6

    1. 终端打开.bash_profile文件 终端输入:open ~/.bash_profile   2. 打开.bash_profile文件后在内容最后添加  alias python=" ...

  9. SDN期末

    一.项目描述 负载均衡程序 二.小组分工 组名:我们真的很弱 组员:李佳铭.吴森杰.张岚鑫.薛宇涛.杨凌澜 三.代码演示及过程描述 四.课程总结

  10. python爬虫(二)

    python爬虫之urllib 在python2和python3中的差异 在python2中,urllib和urllib2各有各个的功能,虽然urllib2是urllib的升级版,但是urllib2还 ...