一、概述

  GBDT(Gradient Boosting Decision Tree,梯度提升决策树)是集成学习中提升(Boosting)方法的典型代表。它以决策树(通常是 CART 树,即分类回归树)作为弱学习器,通过迭代的方式,不断拟合残差(回归任务)或负梯度(分类任务),逐步构建一系列决策树,最终将这些树的预测结果进行累加,得到最终的预测值。

二、算法原理

1. 梯度下降思想​

  梯度下降是一种常用的优化算法,用于寻找函数的最小值。在 GBDT 中,它扮演着至关重要的角色。假设我们有一个损失函数\(L\left( y,\hat{y} \right)\),其中\(y\)是真实值,\(\hat y\)是预测值。梯度下降的目标就是通过不断调整模型参数,使得损失函数的值最小化。具体来说,每次迭代时,沿着损失函数关于参数的负梯度方向更新参数,以逐步接近最优解。在 GBDT 中,虽然没有显式地更新参数(通过构建多颗决策树来拟合目标),但拟合的目标是损失函数的负梯度,本质上也是利用了梯度下降的思想。

2. 决策树的构建​

  GBDT 使用决策树作为弱学习器。决策树是一种基于树结构的预测模型,它通过对数据特征的不断分裂,将数据划分成不同的子集,每个子集对应树的一个节点。在每个节点上,通过某种准则(如回归任务中的平方误差最小化,分类任务中的基尼指数最小化)选择最优的特征和分裂点,使得划分后的子集在目标变量上更加 “纯净” 或具有更好的区分度。通过递归地进行特征分裂,直到满足停止条件(如达到最大树深度、节点样本数小于阈值等),从而构建出一棵完整的决策树。

3. 迭代拟合的过程​

(1) 初始化模型

  首先,初始化一个简单的模型,通常是一个常数模型,记为\(f_0(X)\) ,其预测值为所有样本真实值的均值(回归任务)或多数类(分类任务),记为\(\hat y_0\)。此时,模型的预测结果与真实值之间存在误差。

(2) 计算残差或负梯度

  在回归任务中,计算每个样本的残差,即真实值\(y_i\)与当前模型预测值\(\hat y_{i,t-1}\)的差值\(r_{i,t}=y_i-\hat y_{i,t-1}\),其中表示迭代的轮数。在分类任务中,计算损失函数关于当前模型预测值的负梯度$$g_{i,t}=-\frac{\vartheta L(y_i,\hat y_{i,t-1})}{\vartheta \hat y_{i,t-1}}$$

(3) 拟合决策树

  使用计算得到的残差(回归任务)或负梯度(分类任务)作为新的目标值,训练一棵新的决策树\(f_t(X)\)。这棵树旨在拟合当前模型的误差,从而弥补当前模型的不足。

(4) 更新模型

  根据新训练的决策树,更新当前模型。更新公式为\(\hat y_{i,t}=\hat y_{i,t-1}+\alpha f_t(x_i)\),其中是学习率(也称为步长),用于控制每棵树对模型更新的贡献程度。学习率较小可以使模型训练更加稳定,但需要更多的迭代次数;学习率较大则可能导致模型收敛过快,甚至无法收敛。

(5) 重复迭代

  重复步骤 (2)–(4)步,不断训练新的决策树并更新模型,直到达到预设的迭代次数、损失函数收敛到一定程度或满足其他停止条件为止。最终,GBDT 模型由多棵决策树组成,其预测结果是所有决策树预测结果的累加。

算法过程图示



  GBDT 算法将梯度下降思想与决策树相结合,通过迭代拟合残差或负梯度,逐步构建一个强大的集成模型。它在处理复杂数据和非线性关系时表现较为出色,在数据挖掘、机器学习等领域得到了广泛的应用。然而,GBDT 也存在一些缺点,如训练时间较长、对异常值较为敏感等,在实际应用中需要根据具体情况进行优化和调整 。

三、Python实现

(环境:Python 3.11,scikit-learn 1.5.1)

分类情形

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import metrics # 生成样本数据
X, y = make_classification(n_samples=1000, n_features=50, n_informative=10, n_redundant=5, random_state=1) # 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1) # 创建GDBT分类模型
gbc = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=1) # 训练模型
gbc.fit(X_train, y_train) # 进行预测
y_pred = gbc.predict(X_test) # 计算准确率
accuracy = metrics.accuracy_score(y_test,y_pred)
print('准确率为:',accuracy)

回归情形

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error # 生成样本数据
X, y = make_regression(n_samples=1000, n_features=10, n_informative=5, noise=0.1, random_state=42) # 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 创建GDBT回归模型
model = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, random_state=42) # 训练模型
model.fit(X_train, y_train) # 在测试集上进行预测
y_pred = model.predict(X_test) # 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"MSE: {mse}")

End.

下载

GBDT算法原理及Python实现的更多相关文章

  1. GBDT算法原理深入解析

    GBDT算法原理深入解析 标签: 机器学习 集成学习 GBM GBDT XGBoost 梯度提升(Gradient boosting)是一种用于回归.分类和排序任务的机器学习技术,属于Boosting ...

  2. 深入学习主成分分析(PCA)算法原理(Python实现)

    一:引入问题 首先看一个表格,下表是某些学生的语文,数学,物理,化学成绩统计: 首先,假设这些科目成绩不相关,也就是说某一科目考多少分与其他科目没有关系,那么如何判断三个学生的优秀程度呢?首先我们一眼 ...

  3. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  4. KNN算法原理(python代码实现)

    kNN(k-nearest neighbor algorithm)算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性 ...

  5. (数据科学学习手札13)K-medoids聚类算法原理简介&Python与R的实现

    前几篇我们较为详细地介绍了K-means聚类法的实现方法和具体实战,这种方法虽然快速高效,是大规模数据聚类分析中首选的方法,但是它也有一些短板,比如在数据集中有脏数据时,由于其对每一个类的准则函数为平 ...

  6. PageRank算法原理与Python实现

    一.什么是pagerank PageRank的Page可是认为是网页,表示网页排名,也可以认为是Larry Page(google 产品经理),因为他是这个算法的发明者之一,还是google CEO( ...

  7. 【机器学习】:Kmeans均值聚类算法原理(附带Python代码实现)

    这个算法中文名为k均值聚类算法,首先我们在二维的特殊条件下讨论其实现的过程,方便大家理解. 第一步.随机生成质心 由于这是一个无监督学习的算法,因此我们首先在一个二维的坐标轴下随机给定一堆点,并随即给 ...

  8. 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现

    笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...

  9. GBDT 算法:原理篇

    本文由云+社区发表 GBDT 是常用的机器学习算法之一,因其出色的特征自动组合能力和高效的运算大受欢迎. 这里简单介绍一下 GBDT 算法的原理,后续再写一个实战篇. 1.决策树的分类 决策树分为两大 ...

  10. 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

    梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...

随机推荐

  1. nginx 编译安装及配置解析

    一.编译安装 安装插件 安装需要的组件 yum -y install gcc gcc-c++ pcre pcre-devel zlib zlib-devel openssl openssl-devel ...

  2. 腾讯元宝接入 DeepSeek R1 模型,支持深度思考 + 联网搜索,好用不卡机!

    前言 腾讯元宝AI产品于2025年2月13日在应用商店发布更新,正式接入了DeepSeek R1模型,并宣布该模型已联网.满血上线,DeepSeek+腾讯混元,好用不卡机. 腾讯元宝介绍 腾讯元宝是依 ...

  3. Johnson 全源负权最短路径算法详解

    Floyd-Warshall算法可以求解出图内任意两点的最短路径,适用于稠密图,但时间复杂度为 \(O(n³)\):Dijkstra算法求解单源最短路径的时间复杂度为 \(O(m + n log n) ...

  4. 库卡机器人维修常见报警代码KSS故障修复

    对库卡机器人工作中一些细节和一些安全的措施有所了解才能防患于未然.库卡机器人故障排查可通过观察KUKA机械手报警代码得知,故障代码以及原因有: --kuka机械臂提示文字KSS 0121:电流过大 原 ...

  5. AD22 solder层开窗不见了什么原因?PCB阻焊层不显示阻焊。

    做回来的板子焊盘没有开窗! 3D上看也是没有的! 上网查了,在设计->规则->Mask->SolderMaskExpansion 将盖油勾选去掉即可. END

  6. WPF .Net Core 3.1遇到Satellite Assemblies无法正常加载的处理

    1.原因 加载的时候没有调取 AssemblyLoadContext.Default 2.解决方案: 在程序启动的时候,手动调用 /// <summary> /// Interaction ...

  7. Java01 - Scanner对象

    简介 之前我们学的基本语法中并没有实现程序和人的交互,但是Java给我们提供了这样一个工具类,我们可以获取用户的输入.java.util.Scanner是Java5的新特征,我们可以通过Scanner ...

  8. Hive - [01] 概述

    一.Hive是什么 是Facebook开源,用于解决海量结构化日志的数据统计工具. 是基于Hadoop的一个数据仓库工具,可以将结构化的数据文件映射为一张表,并提供类SQL查询功能. Hive处理的数 ...

  9. SSM - 狂神的项目示例

    出于对狂神的崇拜,总结SSM项目. 基本介绍 项目分层 基本介绍 项目名称:ssmbuild 介绍:通过书籍管理系统实现一个简单的SSM项目,可以作为其他Java Web项目的借鉴. 主要功能模块:查 ...

  10. 大模型知识引擎 LKE 新手入门指南:官方文档难懂?看这篇就够了

    昨天简单体验了一下大模型知识引擎(LKE),总体来说,虽然其功能方面还有一定欠缺,但在一些特定领域,特别是RAG(Retrieval-Augmented Generation)技术的应用上,还是表现出 ...