一、原理篇

1.1 温故知新
回归树是GBDT的基础,之前的一篇文章曾经讲过回归树的原理和实现。链接如下:

回归树的原理及Python实现

1.2 预测年龄
仍然以预测同事年龄来举例,从《回归树》那篇文章中我们可以知道,如果需要通过一个常量来预测同事的年龄,平均值是最佳选择之一。

1.3 年龄的残差
我们不妨假设同事的年龄分别为5岁、6岁、7岁,那么同事的平均年龄就是6岁。所以我们用6岁这个常量来预测同事的年龄,即[6, 6, 6]。每个同事年龄的残差 = 年龄 - 预测值 = [5, 6, 7] - [6, 6, 6],所以残差为[-1, 0, 1]

1.4 预测年龄的残差
为了让模型更加准确,其中一个思路是让残差变小。如何减少残差呢?我们不妨对残差建立一颗回归树,然后预测出准确的残差。假设这棵树预测的残差是[-0.9,
0, 0.9],将上一轮的预测值和这一轮的预测值求和,每个同事的年龄 = [6, 6, 6] + [-0.9, 0, 0.9] = [5.1,
6, 6.9],显然与真实值[5, 6, 7]更加接近了, 年龄的残差此时变为[-0.1, 0, 0.1],预测的准确性得到了提升。

1.5 GBDT
重新整理一下思路,假设我们的预测一共迭代3轮 年龄:[5, 6, 7]

第1轮预测:6, 6, 6

第1轮残差:[-1, 0, 1]

第2轮预测:6, 6, 6 + -0.9, 0, 0.9 = [5.1, 6, 6.9]

第2轮残差:[-0.1, 0, 0.1]

第3轮预测:6, 6, 6 + -0.9, 0, 0.9 + -0.08, 0, 0.07 = [5.02, 6, 6.97]

第3轮残差:[-0.08, 0, 0.03]

看上去残差越来越小,而这种预测方式就是GBDT算法。

1.6 公式推导
看到这里,相信您对GBDT已经有了直观的认识。这么做有什么科学依据么,为什么残差可以越来越小呢?前方小段数学公式低能预警。

因此,我们需要通过用第m-1轮残差的均值来得到函数fm,进而优化函数Fm。而回归树的原理就是通过最佳划分区域的均值来进行预测。所以fm可以选用回归树作为基础模型,将初始值,m-1颗回归树的预测值相加便可以预测y。

二、实现篇

Python实现了GBDT回归算法,没有依赖任何第三方库,便于学习和使用。

2.1 导入回归树类
回归树是我之前已经写好的一个类,在之前的文章详细介绍过,代码请参考:

https://github.com/tushushu/imylu/blob/master/imylu/tree/regression_tree.py
from ..tree.regression_tree import RegressionTree

2.2 创建GradientBoostingBase类
初始化,存储回归树、学习率、初始预测值和变换函数。(注:回归不需要做变换,因此函数的返回值等于参数)

class GradientBoostingBase(object):
    def __init__(self):
        self.trees = None
        self.lr = None
        self.init_val = None
        self.fn = lambda x: x

2.3 计算初始预测值
初始预测值即y的平均值。

def _get_init_val(self, y):
    return sum(y) / len(y)

2.4 计算残差

def _get_residuals(self, y, y_hat):
    return [yi - self.fn(y_hat_i) for yi, y_hat_i in zip(y, y_hat)]

2.5 训练模型

训练模型的时候需要注意以下几点:
1. 控制树的最大深度max_depth; 2. 控制分裂时最少的样本量min_samples_split; 3.
训练每一棵回归树的时候要乘以一个学习率lr,防止模型过拟合; 4. 对样本进行抽样的时候要采用有放回的抽样方式。

def fit(self, X, y, n_estimators, lr, max_depth, min_samples_split, subsample=None):
    self.init_val = self._get_init_val(y)     n = len(y)
    y_hat = [self.init_val] * n
    residuals = self._get_residuals(y, y_hat)     self.trees = []
    self.lr = lr
    for _ in range(n_estimators):
        idx = range(n)
        if subsample is not None:
            k = int(subsample * n)
            idx = choices(population=idx, k=k)
        X_sub = [X[i] for i in idx]
        residuals_sub = [residuals[i] for i in idx]
        y_hat_sub = [y_hat[i] for i in idx]         tree = RegressionTree()
        tree.fit(X_sub, residuals_sub, max_depth, min_samples_split)         self._update_score(tree, X_sub, y_hat_sub, residuals_sub)         y_hat = [y_hat_i + lr * res_hat_i for y_hat_i,
                    res_hat_i in zip(y_hat, tree.predict(X))]         residuals = self._get_residuals(y, y_hat)
        self.trees.append(tree)

2.6 预测一个样本

def _predict(self, Xi):
    return self.fn(self.init_val + sum(self.lr * tree._predict(Xi) for tree in self.trees))

2.7 预测多个样本

def predict(self, X):
    return [self._predict(Xi) for Xi in X]

三、效果评估

3.1 main函数

使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。

@run_time
def main():
    print("Tesing the accuracy of GBDT regressor...")     X, y = load_boston_house_prices()     X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=10)     reg = GradientBoostingRegressor()
    reg.fit(X=X_train, y=y_train, n_estimators=4,
            lr=0.5, max_depth=2, min_samples_split=2)     get_r2(reg, X_test, y_test)

3.2 效果展示
最终拟合优度0.851,运行时间5.0秒,效果还算不错~

3.3 工具函数
本人自定义了一些工具函数,可以在github上查看

https://github.com/tushushu/imylu/blob/master/imylu/utils.py
  1. run_time - 测试函数运行时间

  2. load_boston_house_prices - 加载波士顿房价数据

  3. train_test_split - 拆分训练集、测试集

  4. get_r2 - 计算拟合优度

四、总结

GBDT回归的原理:平均值加回归树

GBDT回归的实现:加加减减for循环

WeChat:https://mp.weixin.qq.com/s/t7kxG9_6ZBIKpuKFZBU0Jg

GBDT回归的原理及Python实现的更多相关文章

  1. 对数几率回归(逻辑回归)原理与Python实现

    目录 一.对数几率和对数几率回归 二.Sigmoid函数 三.极大似然法 四.梯度下降法 四.Python实现 一.对数几率和对数几率回归   在对数几率回归中,我们将样本的模型输出\(y^*\)定义 ...

  2. GBDT原理及利用GBDT构造新的特征-Python实现

    1. 背景 1.1 Gradient Boosting Gradient Boosting是一种Boosting的方法,它主要的思想是,每一次建立模型是在之前建立模型损失函数的梯度下降方向.损失函数是 ...

  3. 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

    梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...

  4. GBDT 算法:原理篇

    本文由云+社区发表 GBDT 是常用的机器学习算法之一,因其出色的特征自动组合能力和高效的运算大受欢迎. 这里简单介绍一下 GBDT 算法的原理,后续再写一个实战篇. 1.决策树的分类 决策树分为两大 ...

  5. 机器学习系列------1. GBDT算法的原理

    GBDT算法是一种监督学习算法.监督学习算法需要解决如下两个问题: 1.损失函数尽可能的小,这样使得目标函数能够尽可能的符合样本 2.正则化函数对训练结果进行惩罚,避免过拟合,这样在预测的时候才能够准 ...

  6. paip.编程语言方法重载实现的原理及python,php,js中实现方法重载

    paip.编程语言方法重载实现的原理及python,php,js中实现方法重载 有些语言,在方法的重载上,形式上不支持函数重载,但可以通过模拟实现.. 主要原理:根据参数个数进行重载,或者使用默认值 ...

  7. MapReduce 原理与 Python 实践

    MapReduce 原理与 Python 实践 1. MapReduce 原理 以下是个人在MongoDB和Redis实际应用中总结的Map-Reduce的理解 Hadoop 的 MapReduce ...

  8. 模拟退火算法SA原理及python、java、php、c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径

    模拟退火算法SA原理及python.java.php.c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径 模拟退火算法(Simulated Annealing,SA)最早的思 ...

  9. 编解码原理,Python默认解码是ascii

    编解码原理,Python默认解码是ascii 首先我们知道,python里的字符默认是ascii码,英文当然没问题啦,碰到中文的时候立马给跪. 不知道你还记不记得,python里打印中文汉字的时候需要 ...

随机推荐

  1. [Xcode 实际操作]五、使用表格-(7)UITableView单元格间隔背景色

    目录:[Swift]Xcode实际操作 本文将演示如何给表格设置间隔的背景颜色. 在项目导航区,打开视图控制器的代码文件[ViewController.swift] import UIKit //首先 ...

  2. 为什么Python如此慢

    Python当前人气暴涨.它在DevOps,数据科学,Web开发和安全领域均有使用. 但是在速度方面没有赢得美誉. 这里有关于Python比较其他语言如,Java, C#, Go, JavaScrip ...

  3. 1、kvm的vnc服务关闭、设置网络模式

    一.kvm的vnc服务关闭 1.关闭虚拟机 virsh shutdown privi-server 2.virsh edit privi-server找到下面内容进行删除 <graphics t ...

  4. 基于SSL加密的vsftpd 服务器搭建和配置

    基于SSL加密的VSFTPD 服务器搭建和配置 1.安装 ubuntu系统:apt-get install vsftp lftp centos系统:yum install -y vsftpd ftp ...

  5. Angular2.0的学习(五)

    第五节课: 1.组件的输入输出属性 2.使用中间人模式传递数据 3.组件生命周期以及Angular的变化发现机制

  6. 大数据量高并发访问SQL优化方法

    保证在实现功能的基础上,尽量减少对数据库的访问次数:通过搜索参数,尽量减少对表的访问行数,最小化结果集,从而减轻网络负担:能够分开的操作尽量分开处理,提高每次的响应速度:在数据窗口使用SQL时,尽量把 ...

  7. chapter07

    // 包和引入// 包也可以像内部类那样嵌套// 包路径不是绝对路径// 包声明链x.y.x并不自动 将中间包x和x.y变成可见// 位于文件顶部不带花括号的包声明在整个文件范围内有效// 包对象可以 ...

  8. Spark操作

    ### scala源码 /* SimpleApp.scala */ import org.apache.spark.SparkContext import org.apache.spark.Spark ...

  9. android 缓存路径

    用程序在运行的过程中如果需要向手机上保存数据,一般是把数据保存在SDcard中的.大部分应用是直接在SDCard的根目录下创建一个文件夹,然后把数据保存在该文件夹中.这样当该应用被卸载后,这些数据还保 ...

  10. {ldelim},{rdelim} - smarty 内建函数

    {ldelim}和{rdelim}用来转义模板的分隔符,缺省为{和}.你也可以用{literal}{/literal}来转义文本块(如Javascript或CSS). 例: {* 在模板外将原样打印分 ...