梯度提升树是一种决策树的集成算法。它通过反复迭代训练决策树来最小化损失函数。决策树类似,梯度提升树具有可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。Spark.ml通过使用现有decision tree工具来实现。


梯度提升树依次迭代训练一系列的决策树。在一次迭代中,算法使用现有的集成来对每个训练实例的类别进行预测,然后将预测结果与真实的标签值进行比较。通过重新标记,来赋予预测结果不好的实例更高的权重。所以,在下次迭代中,决策树会对先前的错误进行修正。


对实例标签进行重新标记的机制由损失函数来指定。每次迭代过程中,梯度迭代树在训练数据上进一步减少损失函数的值。spark.ml为分类问题提供一种损失函数(Log Loss),为回归问题提供两种损失函数(平方误差与绝对误差)。


Spark.ml支持二分类以及回归的随机森林算法,适用于连续特征以及类别特征。不支持多分类问题。

# -*- coding: utf-8 -*-
"""
Created on Wed May 9 09:53:30 2018 @author: admin
""" import numpy as np
import matplotlib.pyplot as plt from sklearn import ensemble
from sklearn import datasets
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error # #############################################################################
# Load data
boston = datasets.load_boston()
X, y = shuffle(boston.data, boston.target, random_state=13)
X = X.astype(np.float32)
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:] # #############################################################################
# Fit regression model
params = {'n_estimators': 500, 'max_depth': 4, 'min_samples_split': 2,
'learning_rate': 0.01, 'loss': 'ls'} #随便指定参数长度,也不用在传参的时候去特意定义一个数组传参
clf = ensemble.GradientBoostingRegressor(**params) clf.fit(X_train, y_train)
mse = mean_squared_error(y_test, clf.predict(X_test))
print("MSE: %.4f" % mse) # #############################################################################
# Plot training deviance # compute test set deviance
test_score = np.zeros((params['n_estimators'],), dtype=np.float64) for i, y_pred in enumerate(clf.staged_predict(X_test)):
test_score[i] = clf.loss_(y_test, y_pred) plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title('Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, clf.train_score_, 'b-',
label='Training Set Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, test_score, 'r-',
label='Test Set Deviance')
plt.legend(loc='upper right')
plt.xlabel('Boosting Iterations')
plt.ylabel('Deviance') # #############################################################################
# Plot feature importance
feature_importance = clf.feature_importances_
# make importances relative to max importance
feature_importance = 100.0 * (feature_importance / feature_importance.max())
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + .5
plt.subplot(1, 2, 2)
plt.barh(pos, feature_importance[sorted_idx], align='center')
plt.yticks(pos, boston.feature_names[sorted_idx])
plt.xlabel('Relative Importance')
plt.title('Variable Importance')
plt.show()

房产数据介绍:

- CRIM     per capita crime rate by town 
- ZN       proportion of residential land zoned for lots over 25,000 sq.ft. 
- INDUS    proportion of non-retail business acres per town 
- CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) 
- NOX      nitric oxides concentration (parts per 10 million) 
- RM       average number of rooms per dwelling 
- AGE      proportion of owner-occupied units built prior to 1940 
- DIS      weighted distances to five Boston employment centres 
- RAD      index of accessibility to radial highways 
- TAX      full-value property-tax rate per $10,000 
- PTRATIO  pupil-teacher ratio by town 
- B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town 
- LSTAT    % lower status of the population 
- MEDV     Median value of owner-occupied homes in $1000'

参考:http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html#sphx-glr-auto-examples-ensemble-plot-gradient-boosting-regression-py

GBDT梯度提升树算法及官方案例的更多相关文章

  1. 【小白学AI】GBDT梯度提升详解

    文章来自微信公众号:[机器学习炼丹术] 文章目录: 目录 0 前言 1 基本概念 2 梯度 or 残差 ? 3 残差过于敏感 4 两个基模型的问题 0 前言 先缕一缕几个关系: GBDT是gradie ...

  2. GBDT(梯度提升树)scikit-klearn中的参数说明及简汇

    1.GBDT(梯度提升树)概述: GBDT是集成学习Boosting家族的成员,区别于Adaboosting.adaboosting是利用前一次迭代弱学习器的误差率来更新训练集的权重,在对更新权重后的 ...

  3. 一文读懂:GBDT梯度提升

    先缕一缕几个关系: GBDT是gradient-boost decision tree GBDT的核心就是gradient boost,我们搞清楚什么是gradient boost就可以了 GBDT是 ...

  4. 机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第30篇文章,我们今天来聊一个机器学习时代可以说是最厉害的模型--GBDT. 虽然文无第一武无第二,在机器学习领域并没有 ...

  5. GBDT 梯度提升决策树简述

    首先明确一点,gbdt 无论用于分类还是回归一直都是使用的CART 回归树.不会因为我们所选择的任务是分类任务就选用分类树,这里面的核心是因为gbdt 每轮的训练是在上一轮的训练的残差基础之上进行训练 ...

  6. 梯度提升决策树(GBDT)与XGBoost、LightGBM

    今天是周末,之前给自己定了一个小目标:每周都要写一篇博客,不管是关于什么内容的都行,关键在于总结和思考,今天我选的主题是梯度提升树的一些方法,主要从这些方法的原理以及实现过程入手讲解这个问题. 本文按 ...

  7. 机器学习 之梯度提升树GBDT

    目录 1.基本知识点简介 2.梯度提升树GBDT算法 2.1 思路和原理 2.2 梯度代替残差建立CART回归树 1.基本知识点简介 在集成学习的Boosting提升算法中,有两大家族:第一是AdaB ...

  8. 梯度提升树 Gradient Boosting Decision Tree

    Adaboost + CART 用 CART 决策树来作为 Adaboost 的基础学习器 但是问题在于,需要把决策树改成能接收带权样本输入的版本.(need: weighted DTree(D, u ...

  9. R︱Yandex的梯度提升CatBoost 算法(官方述:超越XGBoost/lightGBM/h2o)

    俄罗斯搜索巨头 Yandex 昨日宣布开源 CatBoost ,这是一种支持类别特征,基于梯度提升决策树的机器学习方法. CatBoost 是由 Yandex 的研究人员和工程师开发的,是 Matri ...

随机推荐

  1. Random Forest And Extra Trees

    随机森林 我们对使用决策树随机取样的集成学习有个形象的名字–随机森林. scikit-learn 中封装的随机森林,在决策树的节点划分上,在随机的特征子集上寻找最优划分特征. import numpy ...

  2. STM32F103驱动ADS1118

    ADS1118 作为常用温度测量芯片被越来越多的开发者熟知,TI官方给出的是基于 MSP430 的驱动测试程序,由于 STM32 的普及,闲暇中移植了 MSP430 的 ADS1118 驱动程序到 S ...

  3. FBI今年最重要的任务:招募黑客

    ​ 当FBI(联邦调查局)一次又一次被爆出丑闻的时候,面临着一个又一个的尴尬局面.在这样的情况下,FBI发现了自己的一个巨大问题,那就是以前都依靠隐秘行动和人员的保密性来保证国家的安全,现在必须依靠更 ...

  4. 事务Transaction

    目录 为什么写这系列的文章 事务概念 ACID 并发事务导致的问题 脏读(Dirty Read) 非重复读(Nonrepeatable Read) 幻读(Phantom Reads) 丢失修改(Los ...

  5. Dart 运行速度测评与比较

    引言 Dart 是一门优秀的跨平台语言,尽管生态方面略有欠缺,但无疑作为一门编程语言来说,Dart 是很优美,很健壮的,同时也引入了一些先进的编程范式,值得去学习. 测试内容 现在,我们就来测评一下D ...

  6. iOS中的分类和扩展

    一.什么是分类? 概念:分类(Category)是OC中的特有语法,它是表示一个指向分类的结构体指针.根据下面源码组成可以看到它没有属性列表,原则上是不能添加成员变量(其实可以借助运行时功能,进行关联 ...

  7. crypto-js aes加密解密

    安装 npm install crypto-js --save unit.js import CryptoJS from "crypto-js"; //秘钥 const CRYPT ...

  8. vue+webpack工程环境搭建

    使用Vue-cli脚手架(属于vue全家桶)快速构建一个项目: [1]首先需要安装好node.js; [2]安装webpack,指令$npm install -g webpack; //如果之前有安装 ...

  9. Web安全相关(一):CSRF/XSRF(跨站请求伪造)和XSS(跨站脚本)

    XSS(Cross Site Script):跨站脚本,也就是javascript脚本注入,一般在站点中的富文本框,里面发表文章,留言等表单,这种表单一般是写入数据库,然后再某个页面打开. 防御: 1 ...

  10. 开发RTSP 直播软件 H264 AAC 编码

    上一篇对摄像头预览,拍照做了大概的介绍,现在已经可以拿到视频帧了,在加上 RTSP 实现,就是直播的雏形,当然还要加上一些 WEB 管理和手机平台的支援,就是一整套直播软件. 介绍一些基础概念:RTP ...