GBDT的数学原理
一、GBDT的原理
GBDT(Gradient Boosting Decision Tree) 又叫 MART(Multiple Additive Regression Tree),是一种迭代的决策树算法,该算法由多棵决策树组成,所有树的结论累加起来做最终答案。它在被提出之初就和SVM一起被认为是泛化能力较强的算法。
GBDT中的树是回归树(不是分类树),GBDT用来做回归预测,调整后也可以用于分类。
GBDT的思想使其具有天然优势可以发现多种有区分性的特征以及特征组合。业界中,Facebook使用其来自动发现有效的特征、特征组合,来作为LR模型中的特征,以提高CTR预估(Click-Through Rate Prediction)的准确性(详见参考文献5、6);GBDT在淘宝的搜索及预测业务上也发挥了重要作用(详见参考文献7)。
提升树利用加法模型和前向分步算法实现学习的优化过程。当损失函数时平方损失和指数损失函数时,每一步的优化很简单,如平方损失函数学习残差回归树。
但对于一般的损失函数,往往每一步优化没那么容易,如上图中的绝对值损失函数和Huber损失函数。针对这一问题,Freidman提出了梯度提升算法:利用最速下降的近似方法,即利用损失函数的负梯度在当前模型的值,作为回归问题中提升树算法的残差的近似值,拟合一个回归树。(注:鄙人私以为,与其说负梯度作为残差的近似值,不如说残差是负梯度的一种特例)算法如下(截图来自《The Elements of Statistical Learning》):
算法步骤解释:
- 1、初始化,估计使损失函数极小化的常数值,它是只有一个根节点的树,即ganma是一个常数值。
- 2、
(a)计算损失函数的负梯度在当前模型的值,将它作为残差的估计
(b)估计回归树叶节点区域,以拟合残差的近似值
(c)利用线性搜索估计叶节点区域的值,使损失函数极小化
(d)更新回归树 - 3、得到输出的最终模型 f(x)
二、GBDT的参数设置
1、推荐GBDT树的深度:6;(横向比较:DecisionTree/RandomForest需要把树的深度调到15或更高)
2、【问】xgboost/gbdt在调参时为什么树的深度很少就能达到很高的精度?
用xgboost/gbdt在在调参的时候把树的最大深度调成6就有很高的精度了。但是用DecisionTree/RandomForest的时候需要把树的深度调到15或更高。用RandomForest所需要的树的深度和DecisionTree一样我能理解,因为它是用bagging的方法把DecisionTree组合在一起,相当于做了多次DecisionTree一样。但是xgboost/gbdt仅仅用梯度上升法就能用6个节点的深度达到很高的预测精度,使我惊讶到怀疑它是黑科技了。请问下xgboost/gbdt是怎么做到的?它的节点和一般的DecisionTree不同吗?
【答】
(1)Boosting主要关注降低偏差(bais),因此Boosting能基于泛化性能相当弱的学习器构建出很强的集成;Bagging主要关注降低方差(variance),因此它在不剪枝的决策树、神经网络等学习器上效用更为明显。
(2)随机森林(random forest)和GBDT都是属于集成学习(ensemble learning)的范畴。集成学习下有两个重要的策略Bagging和Boosting。
Bagging算法是这样做的:每个分类器都随机从原样本中做有放回的采样,然后分别在这些采样后的样本上训练分类器,然后再把这些分类器组合起来。简单的多数投票一般就可以。其代表算法是随机森林。Boosting的意思是这样,他通过迭代地训练一系列的分类器,每个分类器采用的样本分布都和上一轮的学习结果有关。其代表算法是AdaBoost, GBDT。
(3)其实就机器学习算法来说,其泛化误差可以分解为两部分,偏差(bias)和方差(variance)。这个可由下图的式子导出(这里用到了概率论公式D(X)=E(X^2)-[E(X)]^2)。偏差指的是算法的期望预测与真实预测之间的偏差程度,反应了模型本身的拟合能力;方差度量了同等大小的训练集的变动导致学习性能的变化,刻画了数据扰动所导致的影响。这个有点儿绕,不过你一定知道过拟合。
如下图所示,当模型越复杂时,拟合的程度就越高,模型的训练偏差就越小。但此时如果换一组数据可能模型的变化就会很大,即模型的方差很大。所以模型过于复杂的时候会导致过拟合。
当模型越简单时,即使我们再换一组数据,最后得出的学习器和之前的学习器的差别就不那么大,模型的方差很小。还是因为模型简单,所以偏差会很大。
也就是说,当我们训练一个模型时,偏差和方差都得照顾到,漏掉一个都不行。
对于Bagging算法来说,由于我们会并行地训练很多不同的分类器的目的就是降低这个方差(variance) ,因为采用了相互独立的基分类器多了以后,h的值自然就会靠近.所以对于每个基分类器来说,目标就是如何降低这个偏差(bias),所以我们会采用深度很深甚至不剪枝的决策树。
对于Boosting来说,每一步我们都会在上一轮的基础上更加拟合原数据,所以可以保证偏差(bias),所以对于每个基分类器来说,问题就在于如何选择variance更小的分类器,即更简单的分类器,所以我们选择了深度很浅的决策树。
三、参考文献
1、http://www.jianshu.com/p/005a4e6ac775
GBDT的数学原理的更多相关文章
- 机器学习系列------1. GBDT算法的原理
GBDT算法是一种监督学习算法.监督学习算法需要解决如下两个问题: 1.损失函数尽可能的小,这样使得目标函数能够尽可能的符合样本 2.正则化函数对训练结果进行惩罚,避免过拟合,这样在预测的时候才能够准 ...
- OpenGL坐标变换及其数学原理,两种摄像机交互模型(附源程序)
实验平台:win7,VS2010 先上结果截图(文章最后下载程序,解压后直接运行BIN文件夹下的EXE程序): a.鼠标拖拽旋转物体,类似于OGRE中的“OgreBites::CameraStyle: ...
- RSA加密数学原理
RSA加密数学原理 */--> *///--> *///--> UP | HOME RSA加密数学原理 Table of Contents 1 引言 2 RSA加密解密过程 2.1 ...
- PCA的数学原理
PCA(Principal Component Analysis)是一种常用的数据分析方法.PCA通过线性变换将原始数据变换为一组各维度线性无关的表示,可用于提取数据的主要特征分量,常用于高维 数据的 ...
- PCA数学原理
PCA(Principal Component Analysis)是一种常用的数据分析方法.PCA通过线性变换将原始数据变换为一组各维度线性无关的表示,可用于提取数据的主要特征分量,常用于高维数据的降 ...
- 【机器学习笔记之七】PCA 的数学原理和可视化效果
PCA 的数学原理和可视化效果 本文结构: 什么是 PCA 数学原理 可视化效果 1. 什么是 PCA PCA (principal component analysis, 主成分分析) 是机器学习中 ...
- word2vec 数学原理
word2vec 是 Google 于 2013 年推出的一个用于获取词向量的开源工具包.我们在项目中多次使用到它,但囿于时间关系,一直没仔细探究其背后的原理. 网络上 <word2vec 中的 ...
- 非对称加密技术- RSA算法数学原理分析
非对称加密技术,在现在网络中,有非常广泛应用.加密技术更是数字货币的基础. 所谓非对称,就是指该算法需要一对密钥,使用其中一个(公钥)加密,则需要用另一个(私钥)才能解密. 但是对于其原理大部分同学应 ...
- PCA的数学原理(转)
PCA(Principal Component Analysis)是一种常用的数据分析方法.PCA通过线性变换将原始数据变换为一组各维度线性无关的表示,可用于提取数据的主要特征分量,常用于高维数据的降 ...
随机推荐
- C#实现基于ffmpeg加虹软的人脸识别demo及开发分享
对开发库的C#封装,屏蔽使用细节,可以快速安全的调用人脸识别相关API.具体见github地址.新增对.NET Core的支持,在Linux(Ubuntu下)测试通过.具体的使用例子和Demo详解,参 ...
- libcrypto.so.1.0.0: no version information available
openssl-1.0.1p源码安装后,依赖于openssl.so库的应用报错libcrypto.so.1.0.0: no version information available 解法:1. 创建 ...
- php 两次encodeURI,解决浏览器跳转请求页乱码报错找不到页面的bug
Not Found The requested URL /index.php/XXX/mid/97329240798095910/bname/3000T/D/sname/水泥粉磨/un ...
- sublime和vscode 格式化Json ——两步走
目录 1.问题来源 2.sublime安装插件方式 3.使用方式 4.扩展:对于软件vscode 1.问题来源 最近做数据匹配任务,需要生成很多json文件,但是每个json文件又太大,想要逐字段(k ...
- vs.net2015发布web网站时,提示JsonIgnoreAttribute无法找到的解决办法
产生该问题的原因是因为项目中引用了两个版本的newtonsoft.json.dll,具体解决办法参见: 用记事本打开项目文件(*.csproj) 可以找到在这个文件中,有两处Newtonsoft.Js ...
- sql server中的varchar和Nvarchar有什么区别?
很多开发者进行数据库设计的时候往往并没有太多的考虑char, varchar类型,有的是根本就没注意,因为存储价格变得越来越便宜了,忘记了最开始的一些基本设计理论和原则,这点让我想到了现在的年轻人,大 ...
- RBMQ发布和订阅消息
RBMQ发布和订阅消息 exchange 参考翻译自: RabbitMQ官网 生产者并非将消息直接发送到queue,而是发送到exchange中,具体将消息发送到特定的队列还是多个队列,或者是丢弃,取 ...
- java Arrays常用方法
1. 简介 Arrays类包含用于操作数组的各种方法(例如排序和搜索).此类还包含一个静态工厂,允许将数组视为列表. 如果指定的数组引用为null,则此类中的方法都抛出NullPointerExcep ...
- vmware三种网络模式的工作原理及配置详解
vmware为我们提供了三种网络工作模式,它们分别是:Bridged(桥接模式).NAT(网络地址转换模式).Host-Only(仅主机模式). 打开vmware虚拟机,我们可以在选项栏的“编辑”下的 ...
- 一个狗血的mysql编码错误
执行查询语句总是报错,某个查询语句字段编码错误. 各种修改那个表没用, 最后发现是创建schemas的时候没有加编码 应该由 CREATE SCHEMA new_schema;改为 CREATE SC ...