SparkMLlib分类算法之决策树学习

(一) 决策树的基本概念

    决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。通过信息增益来筛选出属性的优先性。

缺点:参考网址:http://www.ppvke.com/Blog/archives/25042

1)对连续性的字段比较难预测。
2)对有时间顺序的数据,需要很多预处理的工作。
3)当类别太多时,错误可能就会增加的比较快。
4)一般的算法分类的时候,只是根据一个字段来分类。
 
(二) SparkMLlib对决策树应用
1,数据集的准备:参考:http://www.cnblogs.com/ksWorld/p/6882398.html
2,数据预处理及获取训练集和测试集
val orig_file=sc.textFile("train_nohead.tsv")
//println(orig_file.first())
val data_file=orig_file.map(_.split("\t")).map{
r =>
val trimmed =r.map(_.replace("\"",""))
val lable=trimmed(r.length-1).toDouble
val feature=trimmed.slice(4,r.length-1).map(d => if(d=="?")0.0
else d.toDouble)
LabeledPoint(lable,Vectors.dense(feature))
}
/*特征标准化优化,似乎对决策数没啥影响*/
val vectors=data_file.map(x =>x.features)
val rows=new RowMatrix(vectors)
println(rows.computeColumnSummaryStatistics().variance)//每列的方差
val scaler=new StandardScaler(withMean=true,withStd=true).fit(vectors)//标准化
val scaled_data=data_file.map(point => LabeledPoint(point.label,scaler.transform(point.features)))
.randomSplit(Array(0.7,0.3),11L)//固定seed为11L,确保每次每次实验能得到相同的结果
val data_train=scaled_data(0)
val data_test=scaled_data(1)

3,构建模型及模型评价

/*训练决策树模型*/
val model_DT=DecisionTree.train(data_train,Algo.Classification,Entropy,maxTreeDepth)
/*决策树的精确度*/
val predectionAndLabeledDT=data_test.map { point =>
val predectLabeled = if (model_DT.predict(point.features) > 0.5) 1.0 else 0.0
(predectLabeled,point.label)
}
val metricsDT=new MulticlassMetrics(predectionAndLabeledDT)
println(metricsDT.accuracy)//0.6273062730627307
/*决策树的PR曲线和AOC曲线*/
val dtMetrics = Seq(model_DT).map{ model =>
val scoreAndLabels = data_test.map { point =>
val score = model.predict(point.features)
(if (score > 0.5) 1.0 else 0.0, point.label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
(model.getClass.getSimpleName, metrics.areaUnderPR, metrics.areaUnderROC)
}
val allMetrics = dtMetrics
allMetrics.foreach{ case (m, pr, roc) =>
println(f"$m, Area under PR: ${pr * 100.0}%2.4f%%, Area under ROC: ${roc * 100.0}%2.4f%%")
}
/*
DecisionTreeModel, Area under PR: 74.2335%, Area under ROC: 62.4326%
*/

4,模型参数调优(可以调解长度和纯度两方面考虑)

4.1 构建调参函数

/*调参函数*/
def trainDTWithParams(input: RDD[LabeledPoint], maxDepth: Int,
impurity: Impurity) = {
DecisionTree.train(input, Algo.Classification, impurity, maxDepth)
}

4.2  调解树的深度评估函数  (提高树的深度可以得到更精确的模型(这和预期一致,因为模型在更大的树深度下会变得更加复杂)。然而树的深度越大,模型对训练数据过拟合程度越严重)

 /*改变深度*/
val dtResultsEntropy = Seq(1, 2, 3, 4, 5, 10, 20).map { param =>
val model = trainDTWithParams(data_train, param, Entropy)
val scoreAndLabels = data_test.map { point =>
val score = model.predict(point.features)
(if (score > 0.5) 1.0 else 0.0, point.label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
(s"$param tree depth", metrics.areaUnderROC)
}
dtResultsEntropy.foreach { case (param, auc) => println(f"$param, " +
f"AUC = ${auc * 100}%2.2f%%") }
/*
1 tree depth, AUC = 58.57%
2 tree depth, AUC = 60.69%
3 tree depth, AUC = 61.40%
4 tree depth, AUC = 61.30%
5 tree depth, AUC = 62.43%
10 tree depth, AUC = 62.26%
20 tree depth, AUC = 60.59%
*/

2,调解纯度参数 (差异不是很明显。。)

 /*改变纯度*/
val dtResultsEntropy = Seq(Gini,Entropy).map { param =>
val model = trainDTWithParams(data_train, 5, param)
val scoreAndLabels = data_test.map { point =>
val score = model.predict(point.features)
(if (score > 0.5) 1.0 else 0.0, point.label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
(s"$param tree depth", metrics.areaUnderROC)
}
dtResultsEntropy.foreach { case (param, auc) => println(f"$param, " +
f"AUC = ${auc * 100}%2.2f%%") }
/*
org.apache.spark.mllib.tree.impurity.Gini$@32d8e58d tree depth, AUC = 62.37%
org.apache.spark.mllib.tree.impurity.Entropy$@1ddba7a0 tree depth, AUC = 62.43%
*/

(三) 交叉验证

1,数据集分类

创建三个数据集:训练集

评估集(类似上述测试集用于模型参数的调优,比如 lambda 和步长)

测试集(不用于模型的训练和参数调优,只用于估计模型在新数据中性能)

2,交叉验证的常用方法

一个流行的方法是 K- 折叠交叉验证,其中数据集被分成 K 个不重叠的部分。用数据中的 K-1 份训练模型,剩下一部分测试模型。而只分训练集和测试集可以看做是 2- 折叠交叉验证。

还有“留一交叉验证”和“随机采样”。更多资料详见 http://en.wikipedia.org/wiki/Cross-validation_(statistics) 。

SparkMLlib分类算法之决策树学习的更多相关文章

  1. Netflix工程总监眼中的分类算法:深度学习优先级最低

    Netflix工程总监眼中的分类算法:深度学习优先级最低 摘要:不同分类算法的优势是什么?Netflix公司工程总监Xavier Amatriain根据奥卡姆剃刀原理依次推荐了逻辑回归.SVM.决策树 ...

  2. SparkMLlib分类算法之支持向量机

    SparkMLlib分类算法之支持向量机 (一),概念 支持向量机(support vector machine)是一种分类算法,通过寻求结构化风险最小来提高学习机泛化能力,实现经验风险和置信范围的最 ...

  3. SparkMLlib回归算法之决策树

    SparkMLlib回归算法之决策树 (一),决策树概念 1,决策树算法(ID3,C4.5 ,CART)之间的比较: 1,ID3算法在选择根节点和各内部节点中的分支属性时,采用信息增益作为评价标准.信 ...

  4. SparkMLlib分类算法之逻辑回归算法

    SparkMLlib分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/5169383 ...

  5. 【分类算法】决策树(Decision Tree)

    (注:本篇博文是对<统计学习方法>中决策树一章的归纳总结,下列的一些文字和图例均引自此书~) 决策树(decision tree)属于分类/回归方法.其具有可读性.可解释性.分类速度快等优 ...

  6. (ZT)算法杂货铺——分类算法之决策树(Decision tree)

    https://www.cnblogs.com/leoo2sk/archive/2010/09/19/decision-tree.html 3.1.摘要 在前面两篇文章中,分别介绍和讨论了朴素贝叶斯分 ...

  7. 数据挖掘分类算法之决策树(zz)

    决策树(Decision tree) 决策树是以实例为基础的归纳学习算法.     它从一组无次序.无规则的元组中推理出决策树表示形式的分类规则.它采用自顶向下的递归方式,在决策树的内部结点进行属性值 ...

  8. 分类算法:决策树(C4.5)(转)

    C4.5是机器学习算法中的另一个分类决策树算法,它是基于ID3算法进行改进后的一种重要算法,相比于ID3算法,改进有如下几个要点: 1)用信息增益率来选择属性.ID3选择属性用的是子树的信息增益,这里 ...

  9. SparkMLlib学习分类算法之逻辑回归算法

    SparkMLlib学习分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/51693 ...

随机推荐

  1. ubuntu下 mysql安装以后无法登陆的的解决方法((ERROR 1698 (28000): Access denied for user 'root'@'localhost'))

    1. 删除mysql sudo apt-get autoremove --purge mysql-server-5.0 sudo apt-get remove mysql-server sudo ap ...

  2. Linux驱动之PCI

    <背景> PCI设备有许多地址配置的寄存器,初始化时这寄存器来配置设备的总线地址,配置好后CPU就可以访问该设备的各项资源了.(提炼:配置总线地址)   <配置寄存器>   ( ...

  3. sublime插件FileHeader使用,自动的添加模板

    sublime插件FileHeader能够自动的监测创建新文件动作,自动的添加模板 下载地址:https://github.com/shiyanhui/FileHeader FileHeader能够自 ...

  4. android viewHolder static 静态

    韩梦飞沙  韩亚飞  313134555@qq.com  yue31313  han_meng_fei_sha 不是静态内部类 会 持有 外部类的 引用.  就像经常自定义的 适配器 类 作为内部类 ...

  5. [BZOJ 4591] 超能粒子炮-改

    Link: 传送门 Solution: 记录一下推$\sum_{i=0}^k C_n^i$的过程: 其实就是将相同的$i/p$合起来算,这样每个里面都是一个可以预处理的子问题 接下来递归下去算即可 T ...

  6. elasticsearch 亿级数据检索案例与原理

    版权说明: 本文章版权归本人及博客园共同所有,转载请标明原文出处( https://www.cnblogs.com/mikevictor07/p/10006553.html ),以下内容为个人理解,仅 ...

  7. jmeter3.3—插件管理器的安装

    一.介绍JMeter Plugins 一直以来, JMeter Plugins 为我们提供了很多高价值的JMeter插件,比如: 用于服务器性能监视的 PerfMon Metrics Collecto ...

  8. php中赋值和引用真真的理解

    php的引用(就是在变量或者函数.对象等前面加上&符号) //最重要就是 删除引用的变量 ,只是引用的变量访问不了,但是内容并没有销毁 在PHP 中引用的意思是:不同的名字访问同一个变量内容. ...

  9. 如何使用 sqlite3 访问 Android 手机的数据库

    如何设置Android手机的sqlite3命令环境 http://www.cnblogs.com/linjiqin/archive/2011/11/28/2266619.html SQLite3 为a ...

  10. HDU 5280 Senior&#39;s Array 最大区间和

    题意:给定n个数.要求必须将当中某个数改为P,求修改后最大的区间和能够为多少. 水题.枚举每一个区间.假设该区间不改动(即改动该区间以外的数),则就为该区间和,若该区间要改动,由于必须改动,所以肯定是 ...