SparkMLlib回归算法之决策树

(一),决策树概念

1,决策树算法(ID3,C4.5 ,CART)之间的比较:

  1,ID3算法在选择根节点和各内部节点中的分支属性时,采用信息增益作为评价标准。信息增益的缺点是倾向于选择取值较多的属性,在有些情况下这类属性可能不会提供太多有价值的信息。

  2 ID3算法只能对描述属性为离散型属性的数据集构造决策树,其余两种算法对离散和连续都可以处理

2,C4.5算法实例介绍(参考网址:http://m.blog.csdn.net/article/details?id=44726921

  

c4.5后剪枝策略:以悲观剪枝为主参考网址:http://www.cnblogs.com/zhangchaoyang/articles/2842490.html

(二) SparkMLlib决策树回归的应用

1,数据集来源及描述:参考http://www.cnblogs.com/ksWorld/p/6891664.html

2,代码实现:

  2.1 构建输入数据格式:

val file_bike = "hour_nohead.csv"
val file_tree=sc.textFile(file_bike).map(_.split(",")).map{
x =>
val feature=x.slice(2,x.length-3).map(_.toDouble)
val label=x(x.length-1).toDouble
LabeledPoint(label,Vectors.dense(feature))
}
println(file_tree.first())
val categoricalFeaturesInfo = Map[Int,Int]()
val model_DT=DecisionTree.trainRegressor(file_tree,categoricalFeaturesInfo,"variance",5,32)

  2.2 模型评判标准(mse,mae,rmsle)

val predict_vs_train = file_tree.map {
point => (model_DT.predict(point.features),point.label)
/* point => (math.exp(model_DT.predict(point.features)), math.exp(point.label))*/
}
predict_vs_train.take(5).foreach(println(_))
/*MSE是均方误差*/
val mse = predict_vs_train.map(x => math.pow(x._1 - x._2, 2)).mean()
/* 平均绝对误差(MAE)*/
val mae = predict_vs_train.map(x => math.abs(x._1 - x._2)).mean()
/*均方根对数误差(RMSLE)*/
val rmsle = math.sqrt(predict_vs_train.map(x => math.pow(math.log(x._1 + 1) - math.log(x._2 + 1), 2)).mean())
println(s"mse is $mse and mae is $mae and rmsle is $rmsle")
/*
mse is 11611.485999495755 and mae is 71.15018786490428 and rmsle is 0.6251152586960916
*/

(三) 改进模型性能和参数调优

1,改变目标量 (对目标值求根号),修改下面语句

LabeledPoint(math.log(label),Vectors.dense(feature))

val predict_vs_train = file_tree.map {
/*point => (model_DT.predict(point.features),point.label)*/
point => (math.exp(model_DT.predict(point.features)), math.exp(point.label))
}
/*结果
mse is 14781.575988339053 and mae is 76.41310991122032 and rmsle is 0.6405996100717035
*/

决策树在变换后的性能有所下降

2,模型参数调优

  1,构建训练集和测试集

 val file_tree=sc.textFile(file_bike).map(_.split(",")).map{
x =>
val feature=x.slice(2,x.length-3).map(_.toDouble)
val label=x(x.length-1).toDouble
LabeledPoint(label,Vectors.dense(feature))
/*LabeledPoint(math.log(label),Vectors.dense(feature))*/
}
val tree_orgin=file_tree.randomSplit(Array(0.8,0.2),11L)
val tree_train=tree_orgin(0)
val tree_test=tree_orgin(1)

  2,调节树的深度参数

val categoricalFeaturesInfo = Map[Int,Int]()
val model_DT=DecisionTree.trainRegressor(file_tree,categoricalFeaturesInfo,"variance",5,32)
/*调节树深度次数*/
val Deep_Results = Seq(1, 2, 3, 4, 5, 10, 20).map { param =>
val model = DecisionTree.trainRegressor(tree_train, categoricalFeaturesInfo,"variance",param,32)
val scoreAndLabels = tree_test.map { point =>
(model.predict(point.features), point.label)
}
val rmsle = math.sqrt(scoreAndLabels.map(x => math.pow(math.log(x._1) - math.log(x._2), 2)).mean)
(s"$param lambda", rmsle)
}
/*深度的结果输出*/
Deep_Results.foreach { case (param, rmsl) => println(f"$param, rmsle = ${rmsl}")}
/*
1 lambda, rmsle = 1.0763369409492645
2 lambda, rmsle = 0.9735820606349874
3 lambda, rmsle = 0.8786984993014815
4 lambda, rmsle = 0.8052113493915528
5 lambda, rmsle = 0.7014036913077335
10 lambda, rmsle = 0.44747906135994925
20 lambda, rmsle = 0.4769214752638845
*/

  深度较大的决策树出现过拟合,从结果来看这个数据集最优的树深度大概在10左右

  3,调节划分数

/*调节划分数*/
val ClassNum_Results = Seq(2, 4, 8, 16, 32, 64, 100).map { param =>
val model = DecisionTree.trainRegressor(tree_train, categoricalFeaturesInfo,"variance",10,param)
val scoreAndLabels = tree_test.map { point =>
(model.predict(point.features), point.label)
}
val rmsle = math.sqrt(scoreAndLabels.map(x => math.pow(math.log(x._1) - math.log(x._2), 2)).mean)
(s"$param lambda", rmsle)
}
/*划分数的结果输出*/
ClassNum_Results.foreach { case (param, rmsl) => println(f"$param, rmsle = ${rmsl}")}
/*
2 lambda, rmsle = 1.2995002615220668
4 lambda, rmsle = 0.7682777577495858
8 lambda, rmsle = 0.6615110909041817
16 lambda, rmsle = 0.4981237727958235
32 lambda, rmsle = 0.44747906135994925
64 lambda, rmsle = 0.4487531073836407
100 lambda, rmsle = 0.4487531073836407
*/

  更多的划分数会使模型变复杂,并且有助于提升特征维度较大的模型性能。划分数到一定程度之后,对性能的提升帮助不大。实际上,由于过拟合的原因会导致测试集的性能变差。可见分类数应在32左右。。

SparkMLlib回归算法之决策树的更多相关文章

  1. SparkMLlib分类算法之决策树学习

    SparkMLlib分类算法之决策树学习 (一) 决策树的基本概念 决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风 ...

  2. SparkMLlib学习分类算法之逻辑回归算法

    SparkMLlib学习分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/51693 ...

  3. SparkMLlib分类算法之逻辑回归算法

    SparkMLlib分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/5169383 ...

  4. SparkMLlib分类算法之支持向量机

    SparkMLlib分类算法之支持向量机 (一),概念 支持向量机(support vector machine)是一种分类算法,通过寻求结构化风险最小来提高学习机泛化能力,实现经验风险和置信范围的最 ...

  5. Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS

    Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ...

  6. Spark MLlib架构解析(含分类算法、回归算法、聚类算法和协同过滤)

    Spark MLlib架构解析 MLlib的底层基础解析 MLlib的算法库分析 分类算法 回归算法 聚类算法 协同过滤 MLlib的实用程序分析 从架构图可以看出MLlib主要包含三个部分: 底层基 ...

  7. Lasso回归算法: 坐标轴下降法与最小角回归法小结

    前面的文章对线性回归做了一个小结,文章在这: 线性回归原理小结.里面对线程回归的正则化也做了一个初步的介绍.提到了线程回归的L2正则化-Ridge回归,以及线程回归的L1正则化-Lasso回归.但是对 ...

  8. 基于Python的函数回归算法验证

    看机器学习看到了回归函数,看了一半看不下去了,看到能用方差进行函数回归,又手痒痒了,自己推公式写代码验证: 常见的最小二乘法是一阶函数回归回归方法就是寻找方差的最小值y = kx + bxi, yiy ...

  9. 机器学习之Logistic 回归算法

    1 Logistic 回归算法的原理 1.1 需要的数学基础 我在看机器学习实战时对其中的代码非常费解,说好的利用偏导数求最值怎么代码中没有体现啊,就一个简单的式子:θ= θ - α Σ [( hθ( ...

随机推荐

  1. 浏览器访问php脚本通过sendmail用mail函数发送邮件

    前几天做项目遇到这样的一个问题:当某一个结点下有新的文章发表的时候,以邮件的形式通知该结点下的所有用户.这就需要用到邮件发送的功能. 因为项目是php语言做的,所以最简单的方法就是使用php自带的函数 ...

  2. 奇葩问题:同样的字符串equal为false

    问题:什么情况下 "同样" 的字符串会不equal呢?例如   "a".equal("a")  => false 在你看来,这可能是个 ...

  3. 数据库习题(oracle)

    学生表 Student 字段值分别是 Sid ,Sname ,Sage ,Ssex 教师表 Teacher 字段值分别是 Tid ,Tname 课程表 Course 字段值分别是Cid ,Cname ...

  4. 使用Microsoft SQL Server Migration Assistant for Oracle迁移数据库

    前言:使用Microsoft SQL Server Migration Assistant for Oracle迁移Oracle数据库到SqlServer数据库. 准备:Oracle11g.SqlSe ...

  5. bit ( 比特 )和 Byte(字节)的关系 以及 网速怎么算

    今天来整理一下存储单位和网速的知识. 最近几天家里网不太好,所以就了解了一下网速和电脑的存储单位的关系. 一.存储单位的bit 和 Byte 1.bit(比特) bit也就是我们不一定听说过的比特,大 ...

  6. jQuery小测的总结

    1.在div元素中,包含了一个<span>元素,通过has选择器获取<div>元素中的<span>元素的语法是? 提示使用has() 答案: $(div:has(s ...

  7. Web性能优化工具WebPageTest(二)——性能数据

    在前一篇<配置>完成后,点击“START TEST”,就可以开始测试,测试需要一段时间. 有时候可能还要排队,如下图所示,测试完成后可查看到测试结果. 一.Summary 1)优化等级 优 ...

  8. sql中常见日期获取

    获取当前年月日 --获取当前月份 ,GETDATE())) --获取当前月份的下个月 ,GETDATE())) --获取当前月份的上个月 year()获取年 select year(GETDATE() ...

  9. lua 数组

    lua 数组 语法结构 arr = { - } 一维数组 数组的值仍然是数组的, 为多维数组, 否则为一维数组 示例程序 local arr = {1, 2, 3} for i = 1, #arr d ...

  10. php 启动过程 - sapi MSHUTDOWN 过程

    php 启动过程 - sapi MSHUTDOWN 过程 概述 当服务器关闭时, 会走到 sapi MSHUTDOWN 过程 注册过程 本次内容是在 php 启动过程 - sapi MINIT 过程 ...