谈谈模型融合之三 —— GBDT
前言
本来应该是年后就要写的一篇博客,因为考完试后忙了一段时间课设和实验,然后回家后又在摸鱼,就一直没开动。趁着这段时间只能呆在家里来把这些博客补上。在之前的文章中介绍了 Random Forest 和 AdaBoost,这篇文章将介绍介绍在数据挖掘竞赛中,最常用的算法之一 —— GBDT(Gradient Boosting Decision Tree)。
GBDT
原理
GBDT
实际上是 GBM(Gradient Boosting Machine)
中的一种,采用 CART 树作为基学习器,所以称为 GBDT。与 AdaBoost 一样,GBDT 为 Boosting 家族中的一员。其表达式为
\]
其中\(T(x;\Theta_m)\)表示决策树;\(\Theta_m\)为决策树参数;M为树的个数。
这里来回顾下 AdaBoost,AdaBoost 通过不断调整样本权重,使得上一次分类错误的样本权重变大,最后训练出 m 个弱分类器,然后对 m 个弱分类器加权结合形成强分类器。
而 GBDT 又是另一思想,当我们假设前一轮迭代得到的学习器为 \(f_{m-1}(x)\) ,损失函数为 \(L(y, f_{m-1}(x))\) ,那么,本轮迭代的目标就是使损失函数 \(L(y, f_{m-1}(x) + h_m(x))\) 的值尽可能的小。
我们先将损失函数假设为最常用的平方损失
令 \(r = y - f_{m-1}(x)\) ,那么第 m 步的目标就是最小化 \(L(y, f_m(x)) = \frac{1}{2}(y-f_m(x))^2=\frac{1}{2}(r-h_m(x))^2\)
到这里,似乎发现了点什么,我们对损失函数求导,发现:
\]
看出什么来了没?对其取个负号就变成 \(r-h_m(x)\) ,即我们常说的残差
,所以,当为平方损失函数时,弱学习器拟合的目标为之前拟合结果的残差。那到这里代码就很好写了,但是,我们实际中有很多其它的损失函数,而且在很多问题中,这些损失函数比平方损失要好得多。那这时候,如果我们还采用同样的思路,那就没办法像上面那样直接展开并拟合残差了,这时候该怎么办?
这里别忘了,我们最终的目标是使得 \(L(y, f_m(x))\) 最小,那么只需要保证 \(L(y, f_{m-1}(x)+h_m(x))\) 的值比 \(L(y, f_{m-1}(x))\) 小就好了。
即
\]
检验大一高数学的怎么样的时候到了 orz
我们前面说了第 m 轮迭代的损失函数为 \(L(y, f_{m-1}(x) + h_m(x))\) ,换一种形式,写成 \(L(f_{m-1}(x) + h_m(x))\) ,对其进行一阶泰勒展开,得
\]
所以,我们只需使得满足
L'(f_{m-1}(x))h_m(x)<0
\]
那我们的 \(h_m(x)\) 到底要拟合什么呢?别忘了,我们是要求梯度的,在这里我们已知的是 \(L'(f_{m-1}(x))\) ,我们肯定是根据上一次的拟合的结果来拟合这一次的结果,所以,要使得结果最大,自然就是梯度方向。那么 \(h_m(x)=-L'(f_{m-1}(x))\) , 这样原先的 \(r\) 也就变成了梯度。这里如果把损失函数看作是平方损失,我们得到的结果也恰好就是我们所说的残差!!
此时也总算明白了之前面腾讯的时候我说 GBDT 是拟合残差的时候面试官让我再回去重新康康这个算法的原因了。
算法步骤
输入: 训练数据集 \(T = {(x_1, y_1),(x_2, y_2), ..., (x_N, y_N)}, x_i \in X \subset R^n, y_i \in Y \subset R\); 损失函数 L(y,f(x)),树的个数M.
输出: 梯度提升树\(F_M(x)\)
(1) 初始化 \(f_0(x) = argmin_c \Sigma_{i=1}^N L(y_i,c)\).
(2) 对 \(m=1,2,...,M\)
(a) 对\(i =1,2,...,N\),计算, \(r_{mi} = - [\frac{\partial L(y_i, f(x_i))}{\partial f(x_i)}]_{f(x) = F_{m-1}(x)}\).
(b) 拟合残差\(r_{mi}\)学习一个回归树,得到\(f_m(x)\).
(c) 更新\(F_m(x) = F_{m-1}(x) + f_m(x)\).
(3) 得到回归问题提升树 \(F_M(x) = \Sigma_{i=0}^M f_i(x)\)
代码
这里代码是采用了平方损失的方法来写的,且解决的是分类问题
def sigmoid(x):
"""
计算sigmoid函数值
"""
return 1 / (1 + np.exp(-x))
def gbdt_classifier(X, y, M, max_depth=None):
"""
用于分类的GBDT函数
参数:
X: 训练样本
y: 样本标签,y = {0, +1}
M: 使用M个回归树
max_depth: 基学习器CART决策树的最大深度
返回:
F: 生成的模型
"""
# 用0初始化y_reg
y_reg = np.zeros(len(y))
f = []
for m in range(M):
# 计算r
r = y - sigmoid(y_reg)
# 拟合r
# 使用DecisionTreeRegressor,设置树深度为5,random_state=0
f_m = DecisionTreeRegressor(max_depth=5, random_state=0)
# 开始训练
f_m.fit(X, r)
# 更新f
f.append(f_m)
y_reg += f_m.predict(X)
def F(X):
num_X, _ = X.shape
reg = np.zeros((num_X))
for t in f:
reg += t.predict(X)
y_pred_gbdt = sigmoid(reg)
# 以0.5为阈值,得到最终分类结果0或1
one_position = y_pred_gbdt >= 0.5
y_pred_gbdt[one_position] = 1
y_pred_gbdt[~one_position] = 0
return y_pred_gbdt
return F
小节
到这里 GBDT 也就讲完了,从决策树 ID3 开始一直到 GBDT,后面终于要迎来最开始想要梳理的数据挖掘的两大杀器 XGBoost 和 LightGBM 了,下一篇将介绍 XGBoost。
谈谈模型融合之三 —— GBDT的更多相关文章
- 谈谈模型融合之一 —— 集成学习与 AdaBoost
前言 前面的文章中介绍了决策树以及其它一些算法,但是,会发现,有时候使用使用这些算法并不能达到特别好的效果.于是乎就有了集成学习(Ensemble Learning),通过构建多个学习器一起结合来完成 ...
- 在Caffe中实现模型融合
模型融合 有的时候我们手头可能有了若干个已经训练好的模型,这些模型可能是同样的结构,也可能是不同的结构,训练模型的数据可能是同一批,也可能不同.无论是出于要通过ensemble提升性能的目的,还是要设 ...
- Gluon炼丹(Kaggle 120种狗分类,迁移学习加双模型融合)
这是在kaggle上的一个练习比赛,使用的是ImageNet数据集的子集. 注意,mxnet版本要高于0.12.1b2017112. 下载数据集. train.zip test.zip labels ...
- 深度学习模型stacking模型融合python代码,看了你就会使
话不多说,直接上代码 def stacking_first(train, train_y, test): savepath = './stack_op{}_dt{}_tfidf{}/'.format( ...
- 深度学习模型融合stacking
当你的深度学习模型变得很多时,选一个确定的模型也是一个头痛的问题.或者你可以把他们都用起来,就进行模型融合.我主要使用stacking和blend方法.先把代码贴出来,大家可以看一下. import ...
- 模型融合——stacking原理与实现
一般提升模型效果从两个大的方面入手 数据层面:数据增强.特征工程等 模型层面:调参,模型融合 模型融合:通过融合多个不同的模型,可能提升机器学习的性能.这一方法在各种机器学习比赛中广泛应用, 也是在比 ...
- 机器学习 | 从加法模型讲到GBDT算法
作者:JSong, 日期:2017.10.10 集成学习(ensemble learning)通过构建并结合多个学习器来完成学习任务,常可获得比单一学习器显著优越的泛化性能,这对"弱学习器& ...
- 模型融合之blending和stacking
1. blending 需要得到各个模型结果集的权重,然后再线性组合. """Kaggle competition: Predicting a Biological Re ...
- 22(7).模型融合---CatBoost
一.Catboost简介 全称:Gradient Boosting(梯度提升) + Categorical Features(类别型特征) 作者:俄罗斯的搜索巨头Yandex 官方地址 论文链接 | ...
随机推荐
- 对“TD信息树”的使用体验
在本次同2017级学长进行的软件交流会上,我们有幸使用学长们开发的软件与成果,进过27个不尽相同的软件的使用,让我初步意识到了学习软件工程这门学科的实用价值.最终我选择了"TD信息树&quo ...
- Linux 命令整理 vim
Vim 一.官方网站 http://www.vim.org 二.背景 所有的 Unix Like 系统都会内建 vi 文书编辑器,但是在我们编程这里开发使用最多的要数 vim命令了. 三.操作 三种 ...
- Java中的变量、数据类型和运算符
1. java语言是一种强类型的语言,对各种数据类型都有明确的区分,而计算机使用内存来记忆大量运算时需要使用的数据,而当声明一个变量时,即在内存中划分一块空间存储数据,而变量类型决定划分内存空间的大小 ...
- 使用element的upload组件实现一个完整的文件上传功能(下)
本篇文章是<使用element的upload组件实现一个完整的文件上传功能(上)>的续篇. 话不多说,接着上一篇直接开始 一.功能完善—保存表格中每一列的文件列表状态 1.思路 保存表格中 ...
- 掌握这些,ArrayList就不用担心了!
关于ArrayList的学习 ArrayList属于Java基础知识,面试中会经常问到,所以作为一个Java从业者,它是你不得不掌握的一个知识点.
- $bzoj4237$稻草人 $cdq$分治
正解:$cdq$分治 解题报告: 传送门$QwQ$ $umm$总感觉做过这题的亚子,,,? 先把坐标离散化,然后把所有点先按$x$排序$QwQ$,然后用类似平面最近点对的方法,先分别解决$mid$两侧 ...
- 4.eclipse中导入别人用的源代码问题
最近在导入别人用的源代码问题时,出现两个问题: 问题一:提示无法解析导入,如下图: 解决方法:删除项目下的module-info.java文件即可,或者在创建项目时将创建module-info.jav ...
- Ant Design中根据用户交互展示不同的标签
Ant Design中根据用户交互展示不同的标签 Ant Design使用的是React框架,那么我们先看代码: <Fragment> <a onClick={() => th ...
- spring boot使用拦截器
1.编写一个拦截器 首先,我们先编写一个拦截器,和spring mvc方式一样.实现HandlerInterceptor类,代码如下 package com.example.demo.intercep ...
- P2871 [USACO07DEC]手链Charm Bracelet(01背包模板)
题目传送门:P2871 [USACO07DEC]手链Charm Bracelet 题目描述 Bessie has gone to the mall's jewelry store and spies ...