Spark 决策树--回归模型
package Spark_MLlib import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} case class data_scheam(features:Vector,label:String)
object 决策树__回归模型 {
val spark=SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
def main(args: Array[String]): Unit = {
val data=spark.sparkContext.textFile("file:///home/soyo/桌面/spark编程测试数据/soyo2.txt")
.map(_.split(",")).map(x=>data_schema(Vectors.dense(x().toDouble,x().toDouble,x().toDouble,x().toDouble),x())).toDF()
val labelIndexer=new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
val featuresIndexer=new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories().fit(data)
val labelCoverter=new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val Array(trainData,testData)=data.randomSplit(Array(0.7,0.3))
//决策树回归模型构造设置
val dtRegressor=new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
//构造机器学习工作流
val pipelineRegressor=new Pipeline().setStages(Array(labelIndexer,featuresIndexer,dtRegressor,labelCoverter))
//训练决策树回归模型
val modelRegressor=pipelineRegressor.fit(trainData)
//进行预测
val prediction=modelRegressor.transform(testData)
prediction.show()
//评估决策树回归模型
val evaluatorRegressor=new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse") //setMetricName:设置决定你的度量标准是均方根误差还是均方误差等,值可以为:rmse,mse,r2,mae
val Root_Mean_Squared_Error=evaluatorRegressor.evaluate(prediction)
println("均方根误差为: "+Root_Mean_Squared_Error) val treeModelRegressor=modelRegressor.stages().asInstanceOf[DecisionTreeRegressionModel]
val schema_decisionTree=treeModelRegressor.toDebugString
println("决策树分类模型的结构为: "+schema_decisionTree)
}
}
Spark 源码:关于setMetricName("")
@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol))
val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
case "r2" => metrics.r2
case "mae" => metrics.meanAbsoluteError
}
metric
}
结果:
+-----------------+------+------------+-----------------+----------+--------------+
| features| label|indexedLabel| indexedFeatures|prediction|predictedLabel|
+-----------------+------+------------+-----------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|hadoop| 1.0|[4.6,3.1,1.5,0.2]| 1.0| hadoop|
|[4.6,3.4,1.4,0.3]|hadoop| 1.0|[4.6,3.4,1.4,0.3]| 1.0| hadoop|
|[4.7,3.2,1.3,0.2]|hadoop| 1.0|[4.7,3.2,1.3,0.2]| 1.0| hadoop|
|[4.8,3.0,1.4,0.1]|hadoop| 1.0|[4.8,3.0,1.4,0.1]| 1.0| hadoop|
|[5.1,3.3,1.7,0.5]|hadoop| 1.0|[5.1,3.3,1.7,0.5]| 1.0| hadoop|
|[5.1,3.7,1.5,0.4]|hadoop| 1.0|[5.1,3.7,1.5,0.4]| 1.0| hadoop|
|[5.4,3.9,1.3,0.4]|hadoop| 1.0|[5.4,3.9,1.3,0.4]| 1.0| hadoop|
|[5.5,2.3,4.0,1.3]| spark| 0.0|[5.5,2.3,4.0,1.3]| 0.0| spark|
|[5.5,3.5,1.3,0.2]|hadoop| 1.0|[5.5,3.5,1.3,0.2]| 1.0| hadoop|
|[5.6,2.7,4.2,1.3]| spark| 0.0|[5.6,2.7,4.2,1.3]| 0.0| spark|
|[5.6,3.0,4.1,1.3]| spark| 0.0|[5.6,3.0,4.1,1.3]| 0.0| spark|
|[5.6,3.0,4.5,1.5]| spark| 0.0|[5.6,3.0,4.5,1.5]| 0.0| spark|
|[5.7,2.6,3.5,1.0]| spark| 0.0|[5.7,2.6,3.5,1.0]| 0.0| spark|
|[5.7,4.4,1.5,0.4]|hadoop| 1.0|[5.7,4.4,1.5,0.4]| 1.0| hadoop|
|[5.8,2.7,3.9,1.2]| spark| 0.0|[5.8,2.7,3.9,1.2]| 0.0| spark|
|[5.8,2.7,4.1,1.0]| spark| 0.0|[5.8,2.7,4.1,1.0]| 0.0| spark|
|[5.8,2.8,5.1,2.4]| Scala| 2.0|[5.8,2.8,5.1,2.4]| 2.0| Scala|
|[5.8,4.0,1.2,0.2]|hadoop| 1.0|[5.8,4.0,1.2,0.2]| 1.0| hadoop|
|[5.9,3.0,4.2,1.5]| spark| 0.0|[5.9,3.0,4.2,1.5]| 0.0| spark|
|[5.9,3.0,5.1,1.8]| Scala| 2.0|[5.9,3.0,5.1,1.8]| 2.0| Scala|
|[5.9,3.2,4.8,1.8]| spark| 0.0|[5.9,3.2,4.8,1.8]| 2.0| Scala|
|[6.1,2.6,5.6,1.4]| Scala| 2.0|[6.1,2.6,5.6,1.4]| 2.0| Scala|
|[6.1,2.8,4.0,1.3]| spark| 0.0|[6.1,2.8,4.0,1.3]| 0.0| spark|
|[6.3,2.9,5.6,1.8]| Scala| 2.0|[6.3,2.9,5.6,1.8]| 2.0| Scala|
|[6.3,3.4,5.6,2.4]| Scala| 2.0|[6.3,3.4,5.6,2.4]| 2.0| Scala|
|[6.4,2.7,5.3,1.9]| Scala| 2.0|[6.4,2.7,5.3,1.9]| 2.0| Scala|
|[6.4,3.1,5.5,1.8]| Scala| 2.0|[6.4,3.1,5.5,1.8]| 2.0| Scala|
|[6.4,3.2,4.5,1.5]| spark| 0.0|[6.4,3.2,4.5,1.5]| 0.0| spark|
|[6.5,2.8,4.6,1.5]| spark| 0.0|[6.5,2.8,4.6,1.5]| 0.0| spark|
|[6.5,3.0,5.5,1.8]| Scala| 2.0|[6.5,3.0,5.5,1.8]| 2.0| Scala|
|[6.7,3.0,5.2,2.3]| Scala| 2.0|[6.7,3.0,5.2,2.3]| 2.0| Scala|
|[6.7,3.1,4.7,1.5]| spark| 0.0|[6.7,3.1,4.7,1.5]| 0.0| spark|
|[6.8,3.0,5.5,2.1]| Scala| 2.0|[6.8,3.0,5.5,2.1]| 2.0| Scala|
|[6.9,3.1,5.4,2.1]| Scala| 2.0|[6.9,3.1,5.4,2.1]| 2.0| Scala|
|[7.0,3.2,4.7,1.4]| spark| 0.0|[7.0,3.2,4.7,1.4]| 0.0| spark|
|[7.1,3.0,5.9,2.1]| Scala| 2.0|[7.1,3.0,5.9,2.1]| 2.0| Scala|
|[7.2,3.0,5.8,1.6]| Scala| 2.0|[7.2,3.0,5.8,1.6]| 0.0| spark|
|[7.2,3.2,6.0,1.8]| Scala| 2.0|[7.2,3.2,6.0,1.8]| 2.0| Scala|
|[7.2,3.6,6.1,2.5]| Scala| 2.0|[7.2,3.6,6.1,2.5]| 2.0| Scala|
|[7.4,2.8,6.1,1.9]| Scala| 2.0|[7.4,2.8,6.1,1.9]| 2.0| Scala|
|[7.7,2.6,6.9,2.3]| Scala| 2.0|[7.7,2.6,6.9,2.3]| 2.0| Scala|
|[7.7,2.8,6.7,2.0]| Scala| 2.0|[7.7,2.8,6.7,2.0]| 2.0| Scala|
+-----------------+------+------------+-----------------+----------+--------------+
均方根误差为: 0.43643578047198484
决策树分类模型的结构为: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
If (feature 3 <= 1.7)
If (feature 2 <= 1.9)
Predict: 1.0
Else (feature 2 > 1.9)
If (feature 2 <= 4.9)
If (feature 3 <= 1.6)
Predict: 0.0
Else (feature 3 > 1.6)
Predict: 2.0
Else (feature 2 > 4.9)
If (feature 3 <= 1.5)
Predict: 2.0
Else (feature 3 > 1.5)
Predict: 0.0
Else (feature 3 > 1.7)
Predict: 2.0
Spark 决策树--回归模型的更多相关文章
- Spark 决策树--分类模型
package Spark_MLlib import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{D ...
- Spark机器学习5·回归模型(pyspark)
分类模型的预测目标是:类别编号 回归模型的预测目标是:实数变量 回归模型种类 线性模型 最小二乘回归模型 应用L2正则化时--岭回归(ridge regression) 应用L1正则化时--LASSO ...
- 吴裕雄 python 机器学习——集成学习梯度提升决策树GradientBoostingRegressor回归模型
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets,ensemble from sklear ...
- weka实际操作--构建分类、回归模型
weka提供了几种处理数据的方式,其中分类和回归是平时用到最多的,也是非常容易理解的,分类就是在已有的数据基础上学习出一个分类函数或者构造出一个分类模型.这个函数或模型能够把数据集中地映射到某个给定的 ...
- Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS
Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ...
- 如何在R语言中使用Logistic回归模型
在日常学习或工作中经常会使用线性回归模型对某一事物进行预测,例如预测房价.身高.GDP.学生成绩等,发现这些被预测的变量都属于连续型变量.然而有些情况下,被预测变量可能是二元变量,即成功或失败.流失或 ...
- SPSS数据分析—Poisson回归模型
在对数线性模型中,我们假设单元格频数分布为多项式分布,但是还有一类分类变量分布也是经常用到的,就是Poisson分布. Poisson分布是某件事发生次数的概率分布,用于描述单位时间.单位面积.单位空 ...
- SPSS数据分析—配对Logistic回归模型
Lofistic回归模型也可以用于配对资料,但是其分析方法和操作方法均与之前介绍的不同,具体表现 在以下几个方面1.每个配对组共有同一个回归参数,也就是说协变量在不同配对组中的作用相同2.常数项随着配 ...
- SPSS数据分析—多分类Logistic回归模型
前面我们说过二分类Logistic回归模型,但分类变量并不只是二分类一种,还有多分类,本次我们介绍当因变量为多分类时的Logistic回归模型. 多分类Logistic回归模型又分为有序多分类Logi ...
随机推荐
- VI/VIM 编辑器
[是什么?] VI 是 Unix 操作系统和类 Unix 操作系统中最通用的文本编辑器. VIM 编辑器是从 VI 发展出来的一个性能更强大的文本编辑器.可以主动的以字体颜色辨别语法的正确性,方便程序 ...
- 《 阿Q正传》-鲁迅 词语解释 | 经典语录
词语解释 “太上有立德,其次是立功,其次是立言,虽久不废,此之谓不朽”.-出自<左传>-左丘明(春秋末期) 解释:(1)最上等的是树立德行,其次是建功立业,再其次是创立学说,即使过了很久也 ...
- 获取某一个<tr>中<td>的值
$("#trId").children("td").eq(0).text(}; //当前行的第一个<td>的值 <td>下标 ...
- Organize Your Train part II 字典树(此题专卡STL)
Organize Your Train part II Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 8787 Acce ...
- 创建Django项目(三)——站点管理
2013-08-05 21:01:34| 1.激活管理界面 (1) 修改"mysite\mysite\settings.py"文件,将'django ...
- [bzoj4712]洪水_动态dp
洪水 bzoj-4712 题目大意:给定一棵$n$个节点的有根树.每次询问以一棵节点为根的子树内,选取一些节点使得这个被询问的节点包含的叶子节点都有一个父亲被选中,求最小权值.支持单点修改. 注释:$ ...
- Java并发包——使用新的方式创建线程
Java并发包——使用新的方式创建线程 摘要:本文主要学习了如何使用Java并发包中的类创建线程. 部分内容来自以下博客: https://www.cnblogs.com/dolphin0520/p/ ...
- 创建简单的spring-mvc项目
1.第一步:创建项目 new—>Dynamic Web Project 项目创建成功后,展示如图: 2.第二步:导入springmvc的jar包和common-logging.jar 3.第三步 ...
- java多线程断点下载原理(代码实例演示)
原文:http://www.open-open.com/lib/view/open1423214229232.html 其实多线程断点下载原理,很简单的,那么我们就来先了解下,如何实现多线程的断点下载 ...
- log4j 具体解释
通常,我们都提供一个名为 log4j.properties的文件.在第一次调用到Log4J时,Log4J会在类路径(../web-inf/class/当然也能够放到其他不论什么文件夹.仅仅要该文件夹被 ...