一、概述

  LightGBM 由微软公司开发,是基于梯度提升框架的高效机器学习算法,属于集成学习中提升树家族的一员。它以决策树为基学习器,通过迭代地训练一系列决策树,不断纠正前一棵树的预测误差,逐步提升模型的预测精度,最终将这些决策树的结果进行整合,输出最终的预测结果。

二、算法原理

1.训练过程

(1) 初始化模型

  首先,初始化一个简单的模型,通常是一个常数模型,记为\(f_0(X)\),其预测值为所有样本真实值的均值(回归任务)或多数类(分类任务),记为\(\hat y_0\)。此时,模型的预测结果与真实值之间存在误差。

(2) 计算梯度和 Hessian矩阵

  对于每个样本,计算损失函数关于当前模型预测值的梯度和 Hessian 矩阵(二阶导数),用以确定模型需要调整的方向和幅度。例如,在均方误差损失函数下,梯度就是预测值与真实值之间的差值。

(3) 构建决策树

  基于计算得到的梯度和 Hessian,构建一棵新的决策树,使用直方图算法等优化技术来加速决策树的构建过程。分裂节点的依据是最大化信息增益或最小化损失函数的减少量。同时,为了防止过拟合,应用一些剪枝策略,如限制树的深度、叶子节点的最小样本数等。

(4) 更新模型

  根据新训练的决策树,更新当前模型。更新公式为

\[\hat Y_1 = \hat Y_0 + \alpha f_1(X)
\]

,也就是将新决策树的输出乘以学习率,加到当前模型的预测值上。其中学习率 (也称为步长),用于控制每棵树对模型更新的贡献程度。学习率较小可以使模型训练更加稳定,但需要更多的迭代次数;学习率较大则可能导致模型收敛过快,甚至无法收敛。

(5) 重复迭代

  重复步骤 (2)–(4)步,不断训练新的决策树并更新模型,直到达到预设的迭代次数、损失函数收敛到一定程度或满足其他停止条件为止。最终,LightGBM 模型由多棵决策树组成,其预测结果是所有决策树预测结果的累加。

算法过程图示

2. 直方图算法

  LightGBM 采用了直方图算法来加速决策树的构建过程。传统的决策树算法在寻找最佳分裂点时,需要遍历所有的特征值,计算量巨大。而 LightGBM 将连续的特征值离散化成 k 个桶(bin),构建一个宽度为 k 的直方图。在寻找最佳分裂点时,只需遍历直方图中的 k 个值,大大减少了计算量,提高了算法的训练速度。

3. 单边梯度采样(GOSS)

  数据集中存在大量梯度较小的样本,这些样本对模型的提升作用较小,但在计算梯度时却占用了大量的计算资源。单边梯度采样(GOSS)根据样本的梯度大小对样本进行采样,保留梯度较大的样本,并对梯度较小的样本进行随机采样,在不影响模型精度的前提下,减少了训练数据量,提高了训练效率。

4. 互斥特征捆绑(EFB)

  在实际数据中,许多特征是相互关联的,存在大量的稀疏特征。互斥特征捆绑(EFB)算法将互斥的特征捆绑在一起,形成一个新的特征,从而减少特征的数量。这样在构建决策树时,就可以减少计算量,提高算法的运行效率。

三、算法优势

1.训练速度快

  得益于直方图算法、GOSS 和 EFB 等技术,LightGBM 在处理大规模数据时,训练速度相比传统的梯度提升算法有显著提升。无论是处理小数据集还是大数据集,都能在较短时间内完成模型训练。

2.内存占用少

  通过直方图算法和特征捆绑等技术,LightGBM 有效减少了数据存储和计算过程中的内存占用。这使得它可以在资源有限的环境下运行,如在个人计算机上处理大规模数据,或者在内存受限的服务器上同时运行多个模型训练任务。

3.可扩展性强

  LightGBM 支持分布式训练,可以充分利用多台计算机的计算资源,加快训练速度。它还支持并行学习,能够同时处理多个特征,进一步提高训练效率。这种强大的可扩展性使得它能够适应各种规模和复杂度的机器学习任务。

4.准确率高

  尽管 LightGBM 在训练过程中采用了多种优化技术来提高效率,但它并没有牺牲模型的准确率。在许多实际应用和机器学习竞赛中,LightGBM 都能取得与其他先进算法相当甚至更优的预测结果。

5.支持多种数据类型和任务

LightGBM 不仅支持常见的数值型和类别型数据,还能处理稀疏数据。同时,它广泛应用于回归、分类、排序等多种机器学习任务,具有很强的通用性。

四、Python实现

(Python 3.11,scikit-learn 1.6.1)

分类情形

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import lightgbm as lgbm
from sklearn import metrics # 生成数据集
X, y = make_classification(n_samples = 1000, n_features = 6, random_state = 42)
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42) # 创建lightGBM分类模型
clf = lgbm.LGBMClassifier(verbosity=-1)
# 训练模型
rclf = clf.fit(X_train, y_train) # 预测
y_pre = rclf.predict(X_test)
# 性能评价
accuracy = metrics.accuracy_score(y_test,y_pre) print('预测结果为:',y_pre)
print('准确率为:',accuracy)

回归情形

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import lightgbm as lgbm
from sklearn.metrics import mean_squared_error # 生成回归数据集
X, y = make_regression(n_samples = 1000, n_features = 6, random_state = 42)
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42) # 创建lightGBM回归模型
model = lgbm.LGBMRegressor(verbosity=-1)
# 训练模型
model.fit(X_train, y_train) # 进行预测
y_pred = model.predict(X_test) # 计算均方误差评估模型性能
mse = mean_squared_error(y_test, y_pred) print(f"均方误差: {mse}")

End.

下载

LightGBM算法原理及Python实现的更多相关文章

  1. 深入学习主成分分析(PCA)算法原理(Python实现)

    一:引入问题 首先看一个表格,下表是某些学生的语文,数学,物理,化学成绩统计: 首先,假设这些科目成绩不相关,也就是说某一科目考多少分与其他科目没有关系,那么如何判断三个学生的优秀程度呢?首先我们一眼 ...

  2. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  3. KNN算法原理(python代码实现)

    kNN(k-nearest neighbor algorithm)算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性 ...

  4. (数据科学学习手札13)K-medoids聚类算法原理简介&Python与R的实现

    前几篇我们较为详细地介绍了K-means聚类法的实现方法和具体实战,这种方法虽然快速高效,是大规模数据聚类分析中首选的方法,但是它也有一些短板,比如在数据集中有脏数据时,由于其对每一个类的准则函数为平 ...

  5. PageRank算法原理与Python实现

    一.什么是pagerank PageRank的Page可是认为是网页,表示网页排名,也可以认为是Larry Page(google 产品经理),因为他是这个算法的发明者之一,还是google CEO( ...

  6. 【机器学习】:Kmeans均值聚类算法原理(附带Python代码实现)

    这个算法中文名为k均值聚类算法,首先我们在二维的特殊条件下讨论其实现的过程,方便大家理解. 第一步.随机生成质心 由于这是一个无监督学习的算法,因此我们首先在一个二维的坐标轴下随机给定一堆点,并随即给 ...

  7. 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现

    笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...

  8. BP算法从原理到python实现

    BP算法从原理到实践 反向传播算法Backpropagation的python实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 博主接触深度学习已经一段时间,近期在与别人进行讨论时,发现自 ...

  9. 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

    梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...

  10. 模拟退火算法SA原理及python、java、php、c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径

    模拟退火算法SA原理及python.java.php.c++语言代码实现TSP旅行商问题,智能优化算法,随机寻优算法,全局最短路径 模拟退火算法(Simulated Annealing,SA)最早的思 ...

随机推荐

  1. 流程控制之do while循环

    语法 do {    //代码语句}while(布尔表达式): 与while的区别 while是先判断再执行,do while是先执行再判断 循环体至少会被执行一次 实例1: package com. ...

  2. P9869 [NOIP2023] 三值逻辑 题解

    NOIP2023 T2 三值逻辑 题解 题面 思路 乍一看好像很并查集,而且不太难,但是, 注意到:按顺序运行这 \(m\) 条语句 事情并没有那么简单. 比如说如下情况: x1:=T x2:=x1 ...

  3. CF1837E Play Fixing 题解

    首先来考虑什么情况方案数为 \(0\): 可以确定,在某一层中,两个原本都能晋级的队伍比赛: 可以确定,在某一层中,两个原本都不能晋级的队伍比赛. 发现假如写出每一场比赛及其胜者,可以形成一棵树形结构 ...

  4. 函数static的作用

    限制作用域和保持状态 ‌函数static的作用主要体现在限制作用域和保持状态两个方面.‌‌1 限制作用域 ‌静态全局变量‌:在全局变量前加上static关键字,该变量就被定义成为一个静态全局变量.这种 ...

  5. OSAL架构

    OSAL操作系统最多可以支持16个任务,由任务功耗管理PwrMgr_task_state变量可知,而OSAL每个任务最多只能支持16个事件处理,理论上最大可以执行256个事件处理. 对于一些运算能力不 ...

  6. android studio真机调试华为手机

    背景 近来开发一个视频通话App,需要在华为手机上调试,按网上一顿操作,开启了USB调试之后,发现手机连上电脑后,android studio没反应,在此记录下解决方法.调试的手机型号是华为 nova ...

  7. Golang 实现本地持久化缓存

    // Copyright (c) 2024 LiuShuKu // Project Name : balance // Author : liushuku@yeah.net package cache ...

  8. QT5笔记: 25. 非模态的自定义对话框

    窗口对象为QDialog 显示方法为 show(); locateCell->show(); 可以通过public方法或者信号槽机制获取非模态窗口的信息 例子:非模态窗口,为主窗口数据输入吧 v ...

  9. Scala样例类及底层实现伴生对象

    package com.wyh.day01 /** * 样例类的使用 * 1.使用case修饰类 * 2.不需要写构造方法,getter,setter方法,toString方法 * 3.直接通过对象名 ...

  10. 免费的编程连字等宽字体:Fira Code

    免费的编程连字等宽字体:Fira Code 介绍和特征 介绍 Fira 是 Mozilla 公司 主推的字体系列.Fira Code 专为写程序而生,开源免费.除了具有等宽等基本属性外,还加入了编程连 ...