梯度提升树是一种决策树的集成算法。它通过反复迭代训练决策树来最小化损失函数。决策树类似,梯度提升树具有可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。Spark.ml通过使用现有decision tree工具来实现。


梯度提升树依次迭代训练一系列的决策树。在一次迭代中,算法使用现有的集成来对每个训练实例的类别进行预测,然后将预测结果与真实的标签值进行比较。通过重新标记,来赋予预测结果不好的实例更高的权重。所以,在下次迭代中,决策树会对先前的错误进行修正。


对实例标签进行重新标记的机制由损失函数来指定。每次迭代过程中,梯度迭代树在训练数据上进一步减少损失函数的值。spark.ml为分类问题提供一种损失函数(Log Loss),为回归问题提供两种损失函数(平方误差与绝对误差)。


Spark.ml支持二分类以及回归的随机森林算法,适用于连续特征以及类别特征。不支持多分类问题。

# -*- coding: utf-8 -*-
"""
Created on Wed May 9 09:53:30 2018 @author: admin
""" import numpy as np
import matplotlib.pyplot as plt from sklearn import ensemble
from sklearn import datasets
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error # #############################################################################
# Load data
boston = datasets.load_boston()
X, y = shuffle(boston.data, boston.target, random_state=13)
X = X.astype(np.float32)
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:] # #############################################################################
# Fit regression model
params = {'n_estimators': 500, 'max_depth': 4, 'min_samples_split': 2,
'learning_rate': 0.01, 'loss': 'ls'} #随便指定参数长度,也不用在传参的时候去特意定义一个数组传参
clf = ensemble.GradientBoostingRegressor(**params) clf.fit(X_train, y_train)
mse = mean_squared_error(y_test, clf.predict(X_test))
print("MSE: %.4f" % mse) # #############################################################################
# Plot training deviance # compute test set deviance
test_score = np.zeros((params['n_estimators'],), dtype=np.float64) for i, y_pred in enumerate(clf.staged_predict(X_test)):
test_score[i] = clf.loss_(y_test, y_pred) plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title('Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, clf.train_score_, 'b-',
label='Training Set Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, test_score, 'r-',
label='Test Set Deviance')
plt.legend(loc='upper right')
plt.xlabel('Boosting Iterations')
plt.ylabel('Deviance') # #############################################################################
# Plot feature importance
feature_importance = clf.feature_importances_
# make importances relative to max importance
feature_importance = 100.0 * (feature_importance / feature_importance.max())
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + .5
plt.subplot(1, 2, 2)
plt.barh(pos, feature_importance[sorted_idx], align='center')
plt.yticks(pos, boston.feature_names[sorted_idx])
plt.xlabel('Relative Importance')
plt.title('Variable Importance')
plt.show()

房产数据介绍:

- CRIM     per capita crime rate by town 
- ZN       proportion of residential land zoned for lots over 25,000 sq.ft. 
- INDUS    proportion of non-retail business acres per town 
- CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) 
- NOX      nitric oxides concentration (parts per 10 million) 
- RM       average number of rooms per dwelling 
- AGE      proportion of owner-occupied units built prior to 1940 
- DIS      weighted distances to five Boston employment centres 
- RAD      index of accessibility to radial highways 
- TAX      full-value property-tax rate per $10,000 
- PTRATIO  pupil-teacher ratio by town 
- B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town 
- LSTAT    % lower status of the population 
- MEDV     Median value of owner-occupied homes in $1000'

参考:http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html#sphx-glr-auto-examples-ensemble-plot-gradient-boosting-regression-py

GBDT梯度提升树算法及官方案例的更多相关文章

  1. 【小白学AI】GBDT梯度提升详解

    文章来自微信公众号:[机器学习炼丹术] 文章目录: 目录 0 前言 1 基本概念 2 梯度 or 残差 ? 3 残差过于敏感 4 两个基模型的问题 0 前言 先缕一缕几个关系: GBDT是gradie ...

  2. GBDT(梯度提升树)scikit-klearn中的参数说明及简汇

    1.GBDT(梯度提升树)概述: GBDT是集成学习Boosting家族的成员,区别于Adaboosting.adaboosting是利用前一次迭代弱学习器的误差率来更新训练集的权重,在对更新权重后的 ...

  3. 一文读懂:GBDT梯度提升

    先缕一缕几个关系: GBDT是gradient-boost decision tree GBDT的核心就是gradient boost,我们搞清楚什么是gradient boost就可以了 GBDT是 ...

  4. 机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第30篇文章,我们今天来聊一个机器学习时代可以说是最厉害的模型--GBDT. 虽然文无第一武无第二,在机器学习领域并没有 ...

  5. GBDT 梯度提升决策树简述

    首先明确一点,gbdt 无论用于分类还是回归一直都是使用的CART 回归树.不会因为我们所选择的任务是分类任务就选用分类树,这里面的核心是因为gbdt 每轮的训练是在上一轮的训练的残差基础之上进行训练 ...

  6. 梯度提升决策树(GBDT)与XGBoost、LightGBM

    今天是周末,之前给自己定了一个小目标:每周都要写一篇博客,不管是关于什么内容的都行,关键在于总结和思考,今天我选的主题是梯度提升树的一些方法,主要从这些方法的原理以及实现过程入手讲解这个问题. 本文按 ...

  7. 机器学习 之梯度提升树GBDT

    目录 1.基本知识点简介 2.梯度提升树GBDT算法 2.1 思路和原理 2.2 梯度代替残差建立CART回归树 1.基本知识点简介 在集成学习的Boosting提升算法中,有两大家族:第一是AdaB ...

  8. 梯度提升树 Gradient Boosting Decision Tree

    Adaboost + CART 用 CART 决策树来作为 Adaboost 的基础学习器 但是问题在于,需要把决策树改成能接收带权样本输入的版本.(need: weighted DTree(D, u ...

  9. R︱Yandex的梯度提升CatBoost 算法(官方述:超越XGBoost/lightGBM/h2o)

    俄罗斯搜索巨头 Yandex 昨日宣布开源 CatBoost ,这是一种支持类别特征,基于梯度提升决策树的机器学习方法. CatBoost 是由 Yandex 的研究人员和工程师开发的,是 Matri ...

随机推荐

  1. yum配置与使用

    yum的配置一般有两种方式,一种是直接配置/etc目录下的yum.conf文件,另外一种是在/etc/yum.repos.d目录下增加.repo文件. 一.yum的配置文件 [main] cached ...

  2. IPFS初窥

    虽然区块链有很多令人兴奋的特性,但是也有其固有的缺点.比如,文件或者长度较长的文本信息就不适合存储在链上.那么如何解决这个缺点呢?一个解决方案就是IPFS(Interplanetary File Sy ...

  3. Design Patterns in Android

    对日常在 Android 中实用设计模式进行一下梳理和总结,文中参考了一些网站和大佬的博客,如 MichaelX(xiong_it) .菜鸟教程.四月葡萄.IAM四十二等,在这里注明下~另外强烈推荐图 ...

  4. 带你学习ES5中新增的方法

    1. ES5中新增了一些方法,可以很方便的操作数组或者字符串,这些方法主要包括以下几个方面 数组方法 字符串方法 对象方法 2. 数组方法 迭代遍历方法:forEach().map().filter( ...

  5. datatables异步获取数据、简单实用

    IKC项目总结 一.认证难题管理模块 1. 如何使用datatables进行获取数据内容 datatables简介:Datatables是一款jquery表格插件.它是一个高度灵活的工具,可以将任何H ...

  6. pika使用报错queue_declare() missing 1 required positional argument: 'queue'

    报错如下截图,使用pika的版本太高导致,重新安装pika==0.10.0解决.

  7. 从0开发3D引擎(十一):使用领域驱动设计,从最小3D程序中提炼引擎(第二部分)

    目录 上一篇博文 本文流程 回顾上文 解释基本的操作 开始实现 准备 建立代码的文件夹结构,约定模块文件的命名规则 模块文件的命名原则 一级和二级文件夹 api_layer的文件夹 applicati ...

  8. Angular 1 深度解析:脏数据检查与 angular 性能优化

    TL;DR 脏检查是一种模型到视图的数据映射机制,由 $apply 或 $digest 触发. 脏检查的范围是整个页面,不受区域或组件划分影响 使用尽量简单的绑定表达式提升脏检查执行速度 尽量减少页面 ...

  9. 前端开发--vue开发部分报错指南

    前期开发过程中 [Vue warn]: Error in render: "TypeError: Cannot read property '0' of undefined". 解 ...

  10. C++常数优化

    #pragma GCC optimize(2) #pragma GCC optimize(3) #pragma GCC target("avx") #pragma GCC opti ...