一、概述

  LightGBM 由微软公司开发,是基于梯度提升框架的高效机器学习算法,属于集成学习中提升树家族的一员。它以决策树为基学习器,通过迭代地训练一系列决策树,不断纠正前一棵树的预测误差,逐步提升模型的预测精度,最终将这些决策树的结果进行整合,输出最终的预测结果。

二、算法原理

1.训练过程

(1) 初始化模型

  首先,初始化一个简单的模型,通常是一个常数模型,记为\(f_0(X)\),其预测值为所有样本真实值的均值(回归任务)或多数类(分类任务),记为\(\hat y_0\)。此时,模型的预测结果与真实值之间存在误差。

(2) 计算梯度和 Hessian矩阵

  对于每个样本,计算损失函数关于当前模型预测值的梯度和 Hessian 矩阵(二阶导数),用以确定模型需要调整的方向和幅度。例如,在均方误差损失函数下,梯度就是预测值与真实值之间的差值。

(3) 构建决策树

  基于计算得到的梯度和 Hessian,构建一棵新的决策树,使用直方图算法等优化技术来加速决策树的构建过程。分裂节点的依据是最大化信息增益或最小化损失函数的减少量。同时,为了防止过拟合,应用一些剪枝策略,如限制树的深度、叶子节点的最小样本数等。

(4) 更新模型

  根据新训练的决策树,更新当前模型。更新公式为

\[\hat Y_1 = \hat Y_0 + \alpha f_1(X)
\]

,也就是将新决策树的输出乘以学习率,加到当前模型的预测值上。其中学习率 (也称为步长),用于控制每棵树对模型更新的贡献程度。学习率较小可以使模型训练更加稳定,但需要更多的迭代次数;学习率较大则可能导致模型收敛过快,甚至无法收敛。

(5) 重复迭代

  重复步骤 (2)–(4)步,不断训练新的决策树并更新模型,直到达到预设的迭代次数、损失函数收敛到一定程度或满足其他停止条件为止。最终,LightGBM 模型由多棵决策树组成,其预测结果是所有决策树预测结果的累加。

算法过程图示

2. 直方图算法

  LightGBM 采用了直方图算法来加速决策树的构建过程。传统的决策树算法在寻找最佳分裂点时,需要遍历所有的特征值,计算量巨大。而 LightGBM 将连续的特征值离散化成 k 个桶(bin),构建一个宽度为 k 的直方图。在寻找最佳分裂点时,只需遍历直方图中的 k 个值,大大减少了计算量,提高了算法的训练速度。

3. 单边梯度采样(GOSS)

  数据集中存在大量梯度较小的样本,这些样本对模型的提升作用较小,但在计算梯度时却占用了大量的计算资源。单边梯度采样(GOSS)根据样本的梯度大小对样本进行采样,保留梯度较大的样本,并对梯度较小的样本进行随机采样,在不影响模型精度的前提下,减少了训练数据量,提高了训练效率。

4. 互斥特征捆绑(EFB)

  在实际数据中,许多特征是相互关联的,存在大量的稀疏特征。互斥特征捆绑(EFB)算法将互斥的特征捆绑在一起,形成一个新的特征,从而减少特征的数量。这样在构建决策树时,就可以减少计算量,提高算法的运行效率。

三、算法优势

1.训练速度快

  得益于直方图算法、GOSS 和 EFB 等技术,LightGBM 在处理大规模数据时,训练速度相比传统的梯度提升算法有显著提升。无论是处理小数据集还是大数据集,都能在较短时间内完成模型训练。

2.内存占用少

  通过直方图算法和特征捆绑等技术,LightGBM 有效减少了数据存储和计算过程中的内存占用。这使得它可以在资源有限的环境下运行,如在个人计算机上处理大规模数据,或者在内存受限的服务器上同时运行多个模型训练任务。

3.可扩展性强

  LightGBM 支持分布式训练,可以充分利用多台计算机的计算资源,加快训练速度。它还支持并行学习,能够同时处理多个特征,进一步提高训练效率。这种强大的可扩展性使得它能够适应各种规模和复杂度的机器学习任务。

4.准确率高

  尽管 LightGBM 在训练过程中采用了多种优化技术来提高效率,但它并没有牺牲模型的准确率。在许多实际应用和机器学习竞赛中,LightGBM 都能取得与其他先进算法相当甚至更优的预测结果。

5.支持多种数据类型和任务

LightGBM 不仅支持常见的数值型和类别型数据,还能处理稀疏数据。同时,它广泛应用于回归、分类、排序等多种机器学习任务,具有很强的通用性。

四、Python实现

(Python 3.11,scikit-learn 1.6.1)

分类情形

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import lightgbm as lgbm
from sklearn import metrics # 生成数据集
X, y = make_classification(n_samples = 1000, n_features = 6, random_state = 42)
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42) # 创建lightGBM分类模型
clf = lgbm.LGBMClassifier(verbosity=-1)
# 训练模型
rclf = clf.fit(X_train, y_train) # 预测
y_pre = rclf.predict(X_test)
# 性能评价
accuracy = metrics.accuracy_score(y_test,y_pre) print('预测结果为:',y_pre)
print('准确率为:',accuracy)

回归情形

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import lightgbm as lgbm
from sklearn.metrics import mean_squared_error # 生成回归数据集
X, y = make_regression(n_samples = 1000, n_features = 6, random_state = 42)
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42) # 创建lightGBM回归模型
model = lgbm.LGBMRegressor(verbosity=-1)
# 训练模型
model.fit(X_train, y_train) # 进行预测
y_pred = model.predict(X_test) # 计算均方误差评估模型性能
mse = mean_squared_error(y_test, y_pred) print(f"均方误差: {mse}")

End.

下载

LightGBM算法原理及Python实现的更多相关文章

  1. 深入学习主成分分析(PCA)算法原理(Python实现)

    一:引入问题 首先看一个表格,下表是某些学生的语文,数学,物理,化学成绩统计: 首先,假设这些科目成绩不相关,也就是说某一科目考多少分与其他科目没有关系,那么如何判断三个学生的优秀程度呢?首先我们一眼 ...

  2. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  3. KNN算法原理(python代码实现)

    kNN(k-nearest neighbor algorithm)算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性 ...

  4. (数据科学学习手札13)K-medoids聚类算法原理简介&Python与R的实现

    前几篇我们较为详细地介绍了K-means聚类法的实现方法和具体实战,这种方法虽然快速高效,是大规模数据聚类分析中首选的方法,但是它也有一些短板,比如在数据集中有脏数据时,由于其对每一个类的准则函数为平 ...

  5. PageRank算法原理与Python实现

    一.什么是pagerank PageRank的Page可是认为是网页,表示网页排名,也可以认为是Larry Page(google 产品经理),因为他是这个算法的发明者之一,还是google CEO( ...

  6. 【机器学习】:Kmeans均值聚类算法原理(附带Python代码实现)

    这个算法中文名为k均值聚类算法,首先我们在二维的特殊条件下讨论其实现的过程,方便大家理解. 第一步.随机生成质心 由于这是一个无监督学习的算法,因此我们首先在一个二维的坐标轴下随机给定一堆点,并随即给 ...

  7. 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现

    笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...

  8. BP算法从原理到python实现

    BP算法从原理到实践 反向传播算法Backpropagation的python实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 博主接触深度学习已经一段时间,近期在与别人进行讨论时,发现自 ...

  9. 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

    梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...

  10. 模拟退火算法SA原理及python、java、php、c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径

    模拟退火算法SA原理及python.java.php.c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径 模拟退火算法(Simulated Annealing,SA)最早的思 ...

随机推荐

  1. 使用PySide6/PyQt6实现Python跨平台GUI框架的开发

    在前面的<Python开发>中主要介绍了FastAPI的后端Python开发,以及基于WxPython的跨平台GUI的开发过程,由于PySide6/PyQt6 在GUI的用途上也有很大的优 ...

  2. 在OCI上快速静默安装23ai数据库

    拿到同事帮忙申请好的OCI环境[OEL 8.10]后,开始安装23ai数据库用于后续测试,本文选择快速静默安装模式. OCI环境都是opc用户登录的,执行高权限的操作均需要用到sudo命令. 首先创建 ...

  3. 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人

    一.Windows 版 DeepSeek-R1.Ollama 与 AnythingLLM 介绍及核心使用场景‌ ‌一.组件功能与定位‌ ‌DeepSeek-R1‌ ‌模型特性‌:支持 ‌FP16 计算 ...

  4. 【Unit2】电梯调度(多线程设计)-作业总结

    第一次作业 1.1 题目概述 5座楼,每座楼单电梯,类型相同,请求不跨楼层 1.2 个人处理思路 红色加粗为线程类,绿色块为临界区(共享对象) /...鄙人还在加班加点的赶制中.qwq./ 1.3 B ...

  5. Chrome 134 版本新特性

    Chrome 134 版本新特性 一.Chrome 134 版本浏览器更新 1. 在桌面和 iOS 设备上使用 Google Lens 进行屏幕搜索 Chrome 版本 适用平台 发布进度 Chrom ...

  6. Mysql join算法深入浅出

    导语 联表查询在日常的数据库设计中非常的常见,但是联表查询可能会带来性能问题,为了调优.避免设计出有性能问题的SQL,在explain命令中,会显示用的是哪个join算法,学习一下join过程是非常有 ...

  7. AXUI一个面向设计的UI前端框架,好用

    以下是官方介绍: ax的中文意义是:斧子,读音[aeks],取其攻击力强.简单实用之意为本前端框架命名.本团队开发了诸多网站项目,使用了许多常见的前端框架,结合实际项目经验,借鉴了同行的经验,特自主开 ...

  8. 冒泡排序(LOW)

    博客地址:https://www.cnblogs.com/zylyehuo/ # _*_coding:utf-8_*_ import random def bubble_sort(li): for i ...

  9. http状态码413,并提示Request Entity Too Large的解决办法

    使用wordpress的用户经常遇到的问题,就是在后台上传多媒体文件的时候,发现文件大小是有限制的,通常是2M.如图: 如果上传的文件超过2M,服务端返回的状态码会是413,同时提示上传失败.实际上, ...

  10. Sql Server执行情况

    --- 1.查找目前SQL Server所执行的SQL语法,并展示资源情况: SELECT s2.dbid , DB_NAME(s2.dbid) AS [数据库名] , --s1.sql_handle ...