梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
http://blog.csdn.net/liulingyuan6/article/details/53426350
梯度迭代树
算法简介:
梯度提升树是一种决策树的集成算法。它通过反复迭代训练决策树来最小化损失函数。决策树类似,梯度提升树具有可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。Spark.ml通过使用现有decision tree工具来实现。
梯度提升树依次迭代训练一系列的决策树。在一次迭代中,算法使用现有的集成来对每个训练实例的类别进行预测,然后将预测结果与真实的标签值进行比较。通过重新标记,来赋予预测结果不好的实例更高的权重。所以,在下次迭代中,决策树会对先前的错误进行修正。
对实例标签进行重新标记的机制由损失函数来指定。每次迭代过程中,梯度迭代树在训练数据上进一步减少损失函数的值。spark.ml为分类问题提供一种损失函数(Log Loss),为回归问题提供两种损失函数(平方误差与绝对误差)。
Spark.ml支持二分类以及回归的随机森林算法,适用于连续特征以及类别特征。
*注意梯度提升树目前不支持多分类问题。
参数:
checkpointInterval:
类型:整数型。
含义:设置检查点间隔(>=1),或不设置检查点(-1)。
featuresCol:
类型:字符串型。
含义:特征列名。
impurity:
类型:字符串型。
含义:计算信息增益的准则(不区分大小写)。
labelCol:
类型:字符串型。
含义:标签列名。
lossType:
类型:字符串型。
含义:损失函数类型。
maxBins:
类型:整数型。
含义:连续特征离散化的最大数量,以及选择每个节点分裂特征的方式。
maxDepth:
类型:整数型。
含义:树的最大深度(>=0)。
maxIter:
类型:整数型。
含义:迭代次数(>=0)。
minInfoGain:
类型:双精度型。
含义:分裂节点时所需最小信息增益。
minInstancesPerNode:
类型:整数型。
含义:分裂后自节点最少包含的实例数量。
predictionCol:
类型:字符串型。
含义:预测结果列名。
rawPredictionCol:
类型:字符串型。
含义:原始预测。
seed:
类型:长整型。
含义:随机种子。
subsamplingRate:
类型:双精度型。
含义:学习一棵决策树使用的训练数据比例,范围[0,1]。
stepSize:
类型:双精度型。
含义:每次迭代优化步长。
示例:
下面的例子导入LibSVM格式数据,并将之划分为训练数据和测试数据。使用第一部分数据进行训练,剩下数据来测试。训练之前我们使用了两种数据预处理方法来对特征进行转换,并且添加了元数据到DataFrame。
Scala:
- import org.apache.spark.ml.Pipeline
- import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
- import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
- // Load and parse the data file, converting it to a DataFrame.
- val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
- // Index labels, adding metadata to the label column.
- // Fit on whole dataset to include all labels in index.
- val labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data)
- // Automatically identify categorical features, and index them.
- // Set maxCategories so features with > 4 distinct values are treated as continuous.
- val featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data)
- // Split the data into training and test sets (30% held out for testing).
- val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
- // Train a GBT model.
- val gbt = new GBTClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10)
- // Convert indexed labels back to original labels.
- val labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels)
- // Chain indexers and GBT in a Pipeline.
- val pipeline = new Pipeline()
- .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))
- // Train model. This also runs the indexers.
- val model = pipeline.fit(trainingData)
- // Make predictions.
- val predictions = model.transform(testData)
- // Select example rows to display.
- predictions.select("predictedLabel", "label", "features").show(5)
- // Select (prediction, true label) and compute test error.
- val evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("accuracy")
- val accuracy = evaluator.evaluate(predictions)
- println("Test Error = " + (1.0 - accuracy))
- val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
- println("Learned classification GBT model:\n" + gbtModel.toDebugString)
Java:
- import org.apache.spark.ml.Pipeline;
- import org.apache.spark.ml.PipelineModel;
- import org.apache.spark.ml.PipelineStage;
- import org.apache.spark.ml.classification.GBTClassificationModel;
- import org.apache.spark.ml.classification.GBTClassifier;
- import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
- import org.apache.spark.ml.feature.*;
- import org.apache.spark.sql.Dataset;
- import org.apache.spark.sql.Row;
- import org.apache.spark.sql.SparkSession;
- // Load and parse the data file, converting it to a DataFrame.
- Dataset<Row> data = spark
- .read()
- .format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
- // Index labels, adding metadata to the label column.
- // Fit on whole dataset to include all labels in index.
- StringIndexerModel labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data);
- // Automatically identify categorical features, and index them.
- // Set maxCategories so features with > 4 distinct values are treated as continuous.
- VectorIndexerModel featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data);
- // Split the data into training and test sets (30% held out for testing)
- Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
- Dataset<Row> trainingData = splits[0];
- Dataset<Row> testData = splits[1];
- // Train a GBT model.
- GBTClassifier gbt = new GBTClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10);
- // Convert indexed labels back to original labels.
- IndexToString labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels());
- // Chain indexers and GBT in a Pipeline.
- Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});
- // Train model. This also runs the indexers.
- PipelineModel model = pipeline.fit(trainingData);
- // Make predictions.
- Dataset<Row> predictions = model.transform(testData);
- // Select example rows to display.
- predictions.select("predictedLabel", "label", "features").show(5);
- // Select (prediction, true label) and compute test error.
- MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("accuracy");
- double accuracy = evaluator.evaluate(predictions);
- System.out.println("Test Error = " + (1.0 - accuracy));
- GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]);
- System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
Python:
- from pyspark.ml import Pipeline
- from pyspark.ml.classification import GBTClassifier
- from pyspark.ml.feature import StringIndexer, VectorIndexer
- from pyspark.ml.evaluation import MulticlassClassificationEvaluator
- # Load and parse the data file, converting it to a DataFrame.
- data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
- # Index labels, adding metadata to the label column.
- # Fit on whole dataset to include all labels in index.
- labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
- # Automatically identify categorical features, and index them.
- # Set maxCategories so features with > 4 distinct values are treated as continuous.
- featureIndexer =\
- VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
- # Split the data into training and test sets (30% held out for testing)
- (trainingData, testData) = data.randomSplit([0.7, 0.3])
- # Train a GBT model.
- gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)
- # Chain indexers and GBT in a Pipeline
- pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])
- # Train model. This also runs the indexers.
- model = pipeline.fit(trainingData)
- # Make predictions.
- predictions = model.transform(testData)
- # Select example rows to display.
- predictions.select("prediction", "indexedLabel", "features").show(5)
- # Select (prediction, true label) and compute test error
- evaluator = MulticlassClassificationEvaluator(
- labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
- accuracy = evaluator.evaluate(predictions)
- print("Test Error = %g" % (1.0 - accuracy))
- gbtModel = model.stages[2]
- print(gbtModel) # summary only
梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)的更多相关文章
- 梯度提升树GBDT算法
转自https://zhuanlan.zhihu.com/p/29802325 本文对Boosting家族中一个重要的算法梯度提升树(Gradient Boosting Decison Tree, 简 ...
- 朴素贝叶斯算法原理及Spark MLlib实例(Scala/Java/Python)
朴素贝叶斯 算法介绍: 朴素贝叶斯法是基于贝叶斯定理与特征条件独立假设的分类方法. 朴素贝叶斯的思想基础是这样的:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,在没有其它可用信息下,我 ...
- 三种文本特征提取(TF-IDF/Word2Vec/CountVectorizer)及Spark MLlib调用实例(Scala/Java/python)
https://blog.csdn.net/liulingyuan6/article/details/53390949
- 机器学习 之梯度提升树GBDT
目录 1.基本知识点简介 2.梯度提升树GBDT算法 2.1 思路和原理 2.2 梯度代替残差建立CART回归树 1.基本知识点简介 在集成学习的Boosting提升算法中,有两大家族:第一是AdaB ...
- 梯度提升树(GBDT)原理小结(转载)
在集成学习值Adaboost算法原理和代码小结(转载)中,我们对Boosting家族的Adaboost算法做了总结,本文就对Boosting家族中另一个重要的算法梯度提升树(Gradient Boos ...
- scikit-learn 梯度提升树(GBDT)调参小结
在梯度提升树(GBDT)原理小结中,我们对GBDT的原理做了总结,本文我们就从scikit-learn里GBDT的类库使用方法作一个总结,主要会关注调参中的一些要点. 1. scikit-learn ...
- GBDT算法原理深入解析
GBDT算法原理深入解析 标签: 机器学习 集成学习 GBM GBDT XGBoost 梯度提升(Gradient boosting)是一种用于回归.分类和排序任务的机器学习技术,属于Boosting ...
- 机器学习(七)—Adaboost 和 梯度提升树GBDT
1.Adaboost算法原理,优缺点: 理论上任何学习器都可以用于Adaboost.但一般来说,使用最广泛的Adaboost弱学习器是决策树和神经网络.对于决策树,Adaboost分类用了CART分类 ...
- scikit-learn 梯度提升树(GBDT)调参笔记
在梯度提升树(GBDT)原理小结中,我们对GBDT的原理做了总结,本文我们就从scikit-learn里GBDT的类库使用方法作一个总结,主要会关注调参中的一些要点. 1. scikit-learn ...
随机推荐
- 基于拖放布局的 Twitter Bootstrap 网站生成器
简单的几个拖放操作就能做出漂亮的 Twitter Bootstrap 网站?是的,LayoutIt 是一个 Twitter Bootstrap 界面生成器,能够帮助你快速制作出网站和界面模型,同时能够 ...
- mybatis 设置新增数据后返回自增主键
主要是注解@Options起作用,语句如下: @Insert({ "INSERT INTO application_open_up ( " + "app_open_hos ...
- jieba gensim 用法
简单的问答已经实现了,那么问题也跟着出现了,我不能确定问题一定是"你叫什么名字",也有可能是"你是谁","你叫啥"之类的,这就引出了人工智能 ...
- Android忘记锁屏密码如何进入手机?
Android忘记锁屏密码如何进入手机? 1.关闭手机 2.进入recovery模式(即恢复模式,记住不是挖煤模式.进入恢复模式不同手机有不同方法,三星的话安主页键,关机键和音量+(或-键), ...
- I/O多路复用 select poll epoll
I/O多路复用指:通过一种机制,可以监视多个描述符,一旦某个描述符就绪(一般是读就绪或者写就绪),能够通知程序进行相应的读写操作. select select最早于1983年出现在4.2BSD中,它通 ...
- webDAV
wiki webDAV client - java https://github.com/lookfirst/sardine https://www.cnblogs.com/xgjblog/p/383 ...
- 制作签名jar放置到前端资源目录下
给jar包打签名keytool -genkey -keystore myKeystore -alias jwstest查看签名信息jarsigner -keystore myKeystore data ...
- 《汇编语言 基于x86处理器》第六章条件处理部分的代码
▶ 书中第六章的程序,使用了条件判断和跳转来实现一些功能 ● 代码,查找数组首个非零值 INCLUDE Irvine32.inc .data intArray SWORD , , , , , , , ...
- <转载> MySQL 架构 http://www.cnblogs.com/winner-0715/p/6863802.html
1.MySQL整体逻辑架构 我们先下图看看MySQL整体逻辑架构(MySQL’s Logical Architecture) 图1 第一层,即最上一层,所包含的服务并不是MySQL所独有的技术.它们都 ...
- Angular.js实现分页
一.编写angularJS实现的分页js(网上搜)和样式表(pagination),并在页面引入 二.编写变量和方法 //分页控件控制 $scope.paginationConf={ currentP ...