SparkMLlib回归算法之决策树
SparkMLlib回归算法之决策树
(一),决策树概念
1,决策树算法(ID3,C4.5 ,CART)之间的比较:
1,ID3算法在选择根节点和各内部节点中的分支属性时,采用信息增益作为评价标准。信息增益的缺点是倾向于选择取值较多的属性,在有些情况下这类属性可能不会提供太多有价值的信息。
2 ID3算法只能对描述属性为离散型属性的数据集构造决策树,其余两种算法对离散和连续都可以处理
2,C4.5算法实例介绍(参考网址:http://m.blog.csdn.net/article/details?id=44726921)
  



c4.5后剪枝策略:以悲观剪枝为主参考网址:http://www.cnblogs.com/zhangchaoyang/articles/2842490.html
(二) SparkMLlib决策树回归的应用
1,数据集来源及描述:参考http://www.cnblogs.com/ksWorld/p/6891664.html
2,代码实现:
2.1 构建输入数据格式:
val file_bike = "hour_nohead.csv"
val file_tree=sc.textFile(file_bike).map(_.split(",")).map{
x =>
val feature=x.slice(2,x.length-3).map(_.toDouble)
val label=x(x.length-1).toDouble
LabeledPoint(label,Vectors.dense(feature))
}
println(file_tree.first())
val categoricalFeaturesInfo = Map[Int,Int]()
val model_DT=DecisionTree.trainRegressor(file_tree,categoricalFeaturesInfo,"variance",5,32)
2.2 模型评判标准(mse,mae,rmsle)
val predict_vs_train = file_tree.map {
        point => (model_DT.predict(point.features),point.label)
       /* point => (math.exp(model_DT.predict(point.features)), math.exp(point.label))*/
      }
      predict_vs_train.take(5).foreach(println(_))
      /*MSE是均方误差*/
      val mse = predict_vs_train.map(x => math.pow(x._1 - x._2, 2)).mean()
      /* 平均绝对误差(MAE)*/
      val mae = predict_vs_train.map(x => math.abs(x._1 - x._2)).mean()
      /*均方根对数误差(RMSLE)*/
      val rmsle = math.sqrt(predict_vs_train.map(x => math.pow(math.log(x._1 + 1) - math.log(x._2 + 1), 2)).mean())
      println(s"mse is $mse and mae is $mae and rmsle is $rmsle")
/*
mse is 11611.485999495755 and mae is 71.15018786490428 and rmsle is 0.6251152586960916
*/
(三) 改进模型性能和参数调优
1,改变目标量 (对目标值求根号),修改下面语句
LabeledPoint(math.log(label),Vectors.dense(feature))
和
val predict_vs_train = file_tree.map {
/*point => (model_DT.predict(point.features),point.label)*/
point => (math.exp(model_DT.predict(point.features)), math.exp(point.label))
}
/*结果
mse is 14781.575988339053 and mae is 76.41310991122032 and rmsle is 0.6405996100717035
*/
决策树在变换后的性能有所下降
2,模型参数调优
1,构建训练集和测试集
 val file_tree=sc.textFile(file_bike).map(_.split(",")).map{
      x =>
        val feature=x.slice(2,x.length-3).map(_.toDouble)
        val label=x(x.length-1).toDouble
      LabeledPoint(label,Vectors.dense(feature))
        /*LabeledPoint(math.log(label),Vectors.dense(feature))*/
    }
    val tree_orgin=file_tree.randomSplit(Array(0.8,0.2),11L)
    val tree_train=tree_orgin(0)
    val tree_test=tree_orgin(1)
2,调节树的深度参数
val categoricalFeaturesInfo = Map[Int,Int]()
val model_DT=DecisionTree.trainRegressor(file_tree,categoricalFeaturesInfo,"variance",5,32)
/*调节树深度次数*/
val Deep_Results = Seq(1, 2, 3, 4, 5, 10, 20).map { param =>
val model = DecisionTree.trainRegressor(tree_train, categoricalFeaturesInfo,"variance",param,32)
val scoreAndLabels = tree_test.map { point =>
(model.predict(point.features), point.label)
}
val rmsle = math.sqrt(scoreAndLabels.map(x => math.pow(math.log(x._1) - math.log(x._2), 2)).mean)
(s"$param lambda", rmsle)
}
/*深度的结果输出*/
Deep_Results.foreach { case (param, rmsl) => println(f"$param, rmsle = ${rmsl}")}
/*
1 lambda, rmsle = 1.0763369409492645
2 lambda, rmsle = 0.9735820606349874
3 lambda, rmsle = 0.8786984993014815
4 lambda, rmsle = 0.8052113493915528
5 lambda, rmsle = 0.7014036913077335
10 lambda, rmsle = 0.44747906135994925
20 lambda, rmsle = 0.4769214752638845
*/
深度较大的决策树出现过拟合,从结果来看这个数据集最优的树深度大概在10左右
3,调节划分数
/*调节划分数*/
val ClassNum_Results = Seq(2, 4, 8, 16, 32, 64, 100).map { param =>
val model = DecisionTree.trainRegressor(tree_train, categoricalFeaturesInfo,"variance",10,param)
val scoreAndLabels = tree_test.map { point =>
(model.predict(point.features), point.label)
}
val rmsle = math.sqrt(scoreAndLabels.map(x => math.pow(math.log(x._1) - math.log(x._2), 2)).mean)
(s"$param lambda", rmsle)
}
/*划分数的结果输出*/
ClassNum_Results.foreach { case (param, rmsl) => println(f"$param, rmsle = ${rmsl}")}
/*
2 lambda, rmsle = 1.2995002615220668
4 lambda, rmsle = 0.7682777577495858
8 lambda, rmsle = 0.6615110909041817
16 lambda, rmsle = 0.4981237727958235
32 lambda, rmsle = 0.44747906135994925
64 lambda, rmsle = 0.4487531073836407
100 lambda, rmsle = 0.4487531073836407
*/
更多的划分数会使模型变复杂,并且有助于提升特征维度较大的模型性能。划分数到一定程度之后,对性能的提升帮助不大。实际上,由于过拟合的原因会导致测试集的性能变差。可见分类数应在32左右。。
SparkMLlib回归算法之决策树的更多相关文章
- SparkMLlib分类算法之决策树学习
		SparkMLlib分类算法之决策树学习 (一) 决策树的基本概念 决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风 ... 
- SparkMLlib学习分类算法之逻辑回归算法
		SparkMLlib学习分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/51693 ... 
- SparkMLlib分类算法之逻辑回归算法
		SparkMLlib分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/5169383 ... 
- SparkMLlib分类算法之支持向量机
		SparkMLlib分类算法之支持向量机 (一),概念 支持向量机(support vector machine)是一种分类算法,通过寻求结构化风险最小来提高学习机泛化能力,实现经验风险和置信范围的最 ... 
- Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS
		Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ... 
- Spark MLlib架构解析(含分类算法、回归算法、聚类算法和协同过滤)
		Spark MLlib架构解析 MLlib的底层基础解析 MLlib的算法库分析 分类算法 回归算法 聚类算法 协同过滤 MLlib的实用程序分析 从架构图可以看出MLlib主要包含三个部分: 底层基 ... 
- Lasso回归算法: 坐标轴下降法与最小角回归法小结
		前面的文章对线性回归做了一个小结,文章在这: 线性回归原理小结.里面对线程回归的正则化也做了一个初步的介绍.提到了线程回归的L2正则化-Ridge回归,以及线程回归的L1正则化-Lasso回归.但是对 ... 
- 基于Python的函数回归算法验证
		看机器学习看到了回归函数,看了一半看不下去了,看到能用方差进行函数回归,又手痒痒了,自己推公式写代码验证: 常见的最小二乘法是一阶函数回归回归方法就是寻找方差的最小值y = kx + bxi, yiy ... 
- 机器学习之Logistic 回归算法
		1 Logistic 回归算法的原理 1.1 需要的数学基础 我在看机器学习实战时对其中的代码非常费解,说好的利用偏导数求最值怎么代码中没有体现啊,就一个简单的式子:θ= θ - α Σ [( hθ( ... 
随机推荐
- 浏览器访问php脚本通过sendmail用mail函数发送邮件
			前几天做项目遇到这样的一个问题:当某一个结点下有新的文章发表的时候,以邮件的形式通知该结点下的所有用户.这就需要用到邮件发送的功能. 因为项目是php语言做的,所以最简单的方法就是使用php自带的函数 ... 
- 奇葩问题:同样的字符串equal为false
			问题:什么情况下 "同样" 的字符串会不equal呢?例如 "a".equal("a") => false 在你看来,这可能是个 ... 
- 数据库习题(oracle)
			学生表 Student 字段值分别是 Sid ,Sname ,Sage ,Ssex 教师表 Teacher 字段值分别是 Tid ,Tname 课程表 Course 字段值分别是Cid ,Cname ... 
- 使用Microsoft SQL Server Migration Assistant for Oracle迁移数据库
			前言:使用Microsoft SQL Server Migration Assistant for Oracle迁移Oracle数据库到SqlServer数据库. 准备:Oracle11g.SqlSe ... 
- bit ( 比特 )和 Byte(字节)的关系 以及 网速怎么算
			今天来整理一下存储单位和网速的知识. 最近几天家里网不太好,所以就了解了一下网速和电脑的存储单位的关系. 一.存储单位的bit 和 Byte 1.bit(比特) bit也就是我们不一定听说过的比特,大 ... 
- jQuery小测的总结
			1.在div元素中,包含了一个<span>元素,通过has选择器获取<div>元素中的<span>元素的语法是? 提示使用has() 答案: $(div:has(s ... 
- Web性能优化工具WebPageTest(二)——性能数据
			在前一篇<配置>完成后,点击“START TEST”,就可以开始测试,测试需要一段时间. 有时候可能还要排队,如下图所示,测试完成后可查看到测试结果. 一.Summary 1)优化等级 优 ... 
- sql中常见日期获取
			获取当前年月日 --获取当前月份 ,GETDATE())) --获取当前月份的下个月 ,GETDATE())) --获取当前月份的上个月 year()获取年 select year(GETDATE() ... 
- lua 数组
			lua 数组 语法结构 arr = { - } 一维数组 数组的值仍然是数组的, 为多维数组, 否则为一维数组 示例程序 local arr = {1, 2, 3} for i = 1, #arr d ... 
- php 启动过程 - sapi MSHUTDOWN 过程
			php 启动过程 - sapi MSHUTDOWN 过程 概述 当服务器关闭时, 会走到 sapi MSHUTDOWN 过程 注册过程 本次内容是在 php 启动过程 - sapi MINIT 过程 ... 
