package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} case class data_scheam(features:Vector,label:String)
object 决策树__回归模型 {
val spark=SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
def main(args: Array[String]): Unit = {
val data=spark.sparkContext.textFile("file:///home/soyo/桌面/spark编程测试数据/soyo2.txt")
.map(_.split(",")).map(x=>data_schema(Vectors.dense(x().toDouble,x().toDouble,x().toDouble,x().toDouble),x())).toDF()
val labelIndexer=new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
val featuresIndexer=new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories().fit(data)
val labelCoverter=new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val Array(trainData,testData)=data.randomSplit(Array(0.7,0.3))
//决策树回归模型构造设置
val dtRegressor=new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
//构造机器学习工作流
val pipelineRegressor=new Pipeline().setStages(Array(labelIndexer,featuresIndexer,dtRegressor,labelCoverter))
//训练决策树回归模型
val modelRegressor=pipelineRegressor.fit(trainData)
//进行预测
val prediction=modelRegressor.transform(testData)
prediction.show()
//评估决策树回归模型
val evaluatorRegressor=new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse") //setMetricName:设置决定你的度量标准是均方根误差还是均方误差等,值可以为:rmse,mse,r2,mae
val Root_Mean_Squared_Error=evaluatorRegressor.evaluate(prediction)
println("均方根误差为: "+Root_Mean_Squared_Error) val treeModelRegressor=modelRegressor.stages().asInstanceOf[DecisionTreeRegressionModel]
val schema_decisionTree=treeModelRegressor.toDebugString
println("决策树分类模型的结构为: "+schema_decisionTree)
}
}
Spark 源码:关于setMetricName("")
@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol)) val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
case "r2" => metrics.r2
case "mae" => metrics.meanAbsoluteError
}
metric
}

结果:

+-----------------+------+------------+-----------------+----------+--------------+
|         features| label|indexedLabel|  indexedFeatures|prediction|predictedLabel|
+-----------------+------+------------+-----------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|hadoop|         1.0|[4.6,3.1,1.5,0.2]|       1.0|        hadoop|
|[4.6,3.4,1.4,0.3]|hadoop|         1.0|[4.6,3.4,1.4,0.3]|       1.0|        hadoop|
|[4.7,3.2,1.3,0.2]|hadoop|         1.0|[4.7,3.2,1.3,0.2]|       1.0|        hadoop|
|[4.8,3.0,1.4,0.1]|hadoop|         1.0|[4.8,3.0,1.4,0.1]|       1.0|        hadoop|
|[5.1,3.3,1.7,0.5]|hadoop|         1.0|[5.1,3.3,1.7,0.5]|       1.0|        hadoop|
|[5.1,3.7,1.5,0.4]|hadoop|         1.0|[5.1,3.7,1.5,0.4]|       1.0|        hadoop|
|[5.4,3.9,1.3,0.4]|hadoop|         1.0|[5.4,3.9,1.3,0.4]|       1.0|        hadoop|
|[5.5,2.3,4.0,1.3]| spark|         0.0|[5.5,2.3,4.0,1.3]|       0.0|         spark|
|[5.5,3.5,1.3,0.2]|hadoop|         1.0|[5.5,3.5,1.3,0.2]|       1.0|        hadoop|
|[5.6,2.7,4.2,1.3]| spark|         0.0|[5.6,2.7,4.2,1.3]|       0.0|         spark|
|[5.6,3.0,4.1,1.3]| spark|         0.0|[5.6,3.0,4.1,1.3]|       0.0|         spark|
|[5.6,3.0,4.5,1.5]| spark|         0.0|[5.6,3.0,4.5,1.5]|       0.0|         spark|
|[5.7,2.6,3.5,1.0]| spark|         0.0|[5.7,2.6,3.5,1.0]|       0.0|         spark|
|[5.7,4.4,1.5,0.4]|hadoop|         1.0|[5.7,4.4,1.5,0.4]|       1.0|        hadoop|
|[5.8,2.7,3.9,1.2]| spark|         0.0|[5.8,2.7,3.9,1.2]|       0.0|         spark|
|[5.8,2.7,4.1,1.0]| spark|         0.0|[5.8,2.7,4.1,1.0]|       0.0|         spark|
|[5.8,2.8,5.1,2.4]| Scala|         2.0|[5.8,2.8,5.1,2.4]|       2.0|         Scala|
|[5.8,4.0,1.2,0.2]|hadoop|         1.0|[5.8,4.0,1.2,0.2]|       1.0|        hadoop|
|[5.9,3.0,4.2,1.5]| spark|         0.0|[5.9,3.0,4.2,1.5]|       0.0|         spark|
|[5.9,3.0,5.1,1.8]| Scala|         2.0|[5.9,3.0,5.1,1.8]|       2.0|         Scala|
|[5.9,3.2,4.8,1.8]| spark|         0.0|[5.9,3.2,4.8,1.8]|       2.0|         Scala|
|[6.1,2.6,5.6,1.4]| Scala|         2.0|[6.1,2.6,5.6,1.4]|       2.0|         Scala|
|[6.1,2.8,4.0,1.3]| spark|         0.0|[6.1,2.8,4.0,1.3]|       0.0|         spark|
|[6.3,2.9,5.6,1.8]| Scala|         2.0|[6.3,2.9,5.6,1.8]|       2.0|         Scala|
|[6.3,3.4,5.6,2.4]| Scala|         2.0|[6.3,3.4,5.6,2.4]|       2.0|         Scala|
|[6.4,2.7,5.3,1.9]| Scala|         2.0|[6.4,2.7,5.3,1.9]|       2.0|         Scala|
|[6.4,3.1,5.5,1.8]| Scala|         2.0|[6.4,3.1,5.5,1.8]|       2.0|         Scala|
|[6.4,3.2,4.5,1.5]| spark|         0.0|[6.4,3.2,4.5,1.5]|       0.0|         spark|
|[6.5,2.8,4.6,1.5]| spark|         0.0|[6.5,2.8,4.6,1.5]|       0.0|         spark|
|[6.5,3.0,5.5,1.8]| Scala|         2.0|[6.5,3.0,5.5,1.8]|       2.0|         Scala|
|[6.7,3.0,5.2,2.3]| Scala|         2.0|[6.7,3.0,5.2,2.3]|       2.0|         Scala|
|[6.7,3.1,4.7,1.5]| spark|         0.0|[6.7,3.1,4.7,1.5]|       0.0|         spark|
|[6.8,3.0,5.5,2.1]| Scala|         2.0|[6.8,3.0,5.5,2.1]|       2.0|         Scala|
|[6.9,3.1,5.4,2.1]| Scala|         2.0|[6.9,3.1,5.4,2.1]|       2.0|         Scala|
|[7.0,3.2,4.7,1.4]| spark|         0.0|[7.0,3.2,4.7,1.4]|       0.0|         spark|
|[7.1,3.0,5.9,2.1]| Scala|         2.0|[7.1,3.0,5.9,2.1]|       2.0|         Scala|
|[7.2,3.0,5.8,1.6]| Scala|         2.0|[7.2,3.0,5.8,1.6]|       0.0|         spark|
|[7.2,3.2,6.0,1.8]| Scala|         2.0|[7.2,3.2,6.0,1.8]|       2.0|         Scala|
|[7.2,3.6,6.1,2.5]| Scala|         2.0|[7.2,3.6,6.1,2.5]|       2.0|         Scala|
|[7.4,2.8,6.1,1.9]| Scala|         2.0|[7.4,2.8,6.1,1.9]|       2.0|         Scala|
|[7.7,2.6,6.9,2.3]| Scala|         2.0|[7.7,2.6,6.9,2.3]|       2.0|         Scala|
|[7.7,2.8,6.7,2.0]| Scala|         2.0|[7.7,2.8,6.7,2.0]|       2.0|         Scala|
+-----------------+------+------------+-----------------+----------+--------------+

均方根误差为: 0.43643578047198484
决策树分类模型的结构为: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
  If (feature 3 <= 1.7)
   If (feature 2 <= 1.9)
    Predict: 1.0
   Else (feature 2 > 1.9)
    If (feature 2 <= 4.9)
     If (feature 3 <= 1.6)
      Predict: 0.0
     Else (feature 3 > 1.6)
      Predict: 2.0
    Else (feature 2 > 4.9)
     If (feature 3 <= 1.5)
      Predict: 2.0
     Else (feature 3 > 1.5)
      Predict: 0.0
  Else (feature 3 > 1.7)
   Predict: 2.0

Spark 决策树--回归模型的更多相关文章

  1. Spark 决策树--分类模型

    package Spark_MLlib import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{D ...

  2. Spark机器学习5·回归模型(pyspark)

    分类模型的预测目标是:类别编号 回归模型的预测目标是:实数变量 回归模型种类 线性模型 最小二乘回归模型 应用L2正则化时--岭回归(ridge regression) 应用L1正则化时--LASSO ...

  3. 吴裕雄 python 机器学习——集成学习梯度提升决策树GradientBoostingRegressor回归模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets,ensemble from sklear ...

  4. weka实际操作--构建分类、回归模型

    weka提供了几种处理数据的方式,其中分类和回归是平时用到最多的,也是非常容易理解的,分类就是在已有的数据基础上学习出一个分类函数或者构造出一个分类模型.这个函数或模型能够把数据集中地映射到某个给定的 ...

  5. Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS

    Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ...

  6. 如何在R语言中使用Logistic回归模型

    在日常学习或工作中经常会使用线性回归模型对某一事物进行预测,例如预测房价.身高.GDP.学生成绩等,发现这些被预测的变量都属于连续型变量.然而有些情况下,被预测变量可能是二元变量,即成功或失败.流失或 ...

  7. SPSS数据分析—Poisson回归模型

    在对数线性模型中,我们假设单元格频数分布为多项式分布,但是还有一类分类变量分布也是经常用到的,就是Poisson分布. Poisson分布是某件事发生次数的概率分布,用于描述单位时间.单位面积.单位空 ...

  8. SPSS数据分析—配对Logistic回归模型

    Lofistic回归模型也可以用于配对资料,但是其分析方法和操作方法均与之前介绍的不同,具体表现 在以下几个方面1.每个配对组共有同一个回归参数,也就是说协变量在不同配对组中的作用相同2.常数项随着配 ...

  9. SPSS数据分析—多分类Logistic回归模型

    前面我们说过二分类Logistic回归模型,但分类变量并不只是二分类一种,还有多分类,本次我们介绍当因变量为多分类时的Logistic回归模型. 多分类Logistic回归模型又分为有序多分类Logi ...

随机推荐

  1. Django之Ajax提交

    Ajax 提交数据,页面不刷新 Ajax要引入jQuery Django之Ajax提交 Js实现页面的跳转: location.href = "/url/" $ajax({ url ...

  2. jquery 点击弹框

    <a href="#" class="big-link" data-reveal-id="myModal" data-animatio ...

  3. Thawte SSL Web Server 多域型SSL证书

    Thawte SSL Web Server 多域型SSL证书,最多支持25个域名,需要验证域名所有权和申请单位信息,属于企业验证型SSL证书,提供40位/56位/128位,最高支持256位自适应加密. ...

  4. windows10开启内置ubuntu系统,使用xshell连接

    windows安装配置ubuntu系统内置子系统 官方文档:https://docs.microsoft.com/zh-cn/windows/wsl/about https://www.jianshu ...

  5. 【Tomcat】Tomcat性能分析

    一.预研任务介绍和预研目标 任务介绍: Apache Tomcat是目前较为流行的web服务器,以其技术先进.性能稳定著称,其次它还是一个免费开源的项目. Tomcat性能分析的意义在于能为日常工作中 ...

  6. Pull方式解析XML文件

    package com.pingyijinren.test; import android.content.Intent; import android.os.Handler; import andr ...

  7. Ubuntu 16.04通过Magent搭建Memcached集群(转)

    一.下载Magent 官网:https://code.google.com/archive/p/memagent/downloads 离线版本:(链接: https://pan.baidu.com/s ...

  8. MySQL主从复制搭建教程收集(待实践)

    先收集一下,后续再搭建测试. https://zhangge.net/4019.html http://www.cnblogs.com/jiangwenju/p/6098974.html http:/ ...

  9. Mesos, Marathon, Docker 平台部署记录

    Mesos, Marathon, Docker 平台部署记录 所有组件部署基于Ubuntu 14.04 x64 主机 IP 角色 master 192.168.1.3 Mesos Master, Ma ...

  10. iOS8使用TouchID

    iOS8新增了LocalAuthentication框架,用于TouchID的授权使用.亲測,眼下须要用户的设备支持指纹识别并已设置锁屏,并且实际測试过程中反馈比較慢.不能直接跟第三方账号passwo ...