Spark中决策树源码分析
1.Example
使用Spark MLlib中决策树分类器API,训练出一个决策树模型,使用Python开发。
"""
Decision Tree Classification Example.
"""
from __future__ import print_function
from pyspark import SparkContext
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
if __name__ == "__main__":
sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
# 加载和解析数据文件为RDD
dataPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/sample_libsvm_data.txt"
print(dataPath)
data = MLUtils.loadLibSVMFile(sc,dataPath)
# 将数据集分割为训练数据集和测试数据集
(trainingData,testData) = data.randomSplit([0.7,0.3])
print("train data count: " + str(trainingData.count()))
print("test data count : " + str(testData.count()))
# 训练决策树分类器
# categoricalFeaturesInfo 为空,表示所有的特征均为连续值
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
impurity='gini', maxDepth=5, maxBins=32)
# 测试数据集上预测
predictions = model.predict(testData.map(lambda x: x.features))
# 打包真实值与预测值
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
# 统计预测错误的样本的频率
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
print('Decision Tree Test Error = %5.3f%%'%(testErr*100))
print("Decision Tree Learned classifiction tree model : ")
print(model.toDebugString())
# 保存和加载训练好的模型
modelPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/myDecisionTreeClassificationModel"
model.save(sc, modelPath)
sameModel = DecisionTreeModel.load(sc, modelPath)
2.决策树源码分析
决策树分类器API为DecisionTree.trainClassifier,进入源码分析。
源码文件所在路径为,spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala。
@Since("1.1.0")
def trainClassifier(
input: RDD[LabeledPoint],
numClasses: Int,
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
categoricalFeaturesInfo)
}
训练出一个分类器,然后调用了train方法。
@Since("1.0.0")
def train(
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).run(input)
}
train方法首先将模型类型(分类或者回归)、信息增益指标、决策树深度、分类数目、最大切分箱子数等参数封装为Strategy,然后新建一个DecisionTree对象,并调用run方法。
@Since("1.0.0")
class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
extends Serializable with Logging {
/**
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
@Since("1.0.0")
def this(strategy: Strategy) = this(strategy, seed = 0)
strategy.assertValid()
/**
* Method to train a decision tree model over an RDD
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
val rfModel = rf.run(input)
rfModel.trees(0)
}
}
run方法中首先新建一个RandomForest对象,将strategy、决策树数目设置为1,子集选择策略为"all"传递给RandomForest对象,然后调用RandomForest中的run方法,最后返回随机森林模型中的第一棵决策树。
也就是,决策树模型使用了随机森林模型进行训练,将决策树数目设置为1,然后将随机森林模型中的第一棵决策树作为结果,返回作为决策树训练模型。
3.随机森林源码分析
随机森林的源码文件所在路径为,spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala。
private class RandomForest (
private val strategy: Strategy,
private val numTrees: Int,
featureSubsetStrategy: String,
private val seed: Int)
extends Serializable with Logging {
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
|| Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
|| Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
featureSubsetStrategy, seed.toLong, None)
new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
}
在该文件开头,通过"import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}"将ml中的RandomForest引入,重新命名为NewRandomForest。
在RandomForest.run方法中,首先新建NewRandomForest模型,并调用该类的run方法,然后将生成的trees作为新建RandomForestModel的入参。
NewRandomForest,源码文件所在路径为,spark-1.6/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala。
由于涉及代码量较大,因此无法将代码展开,run方法主要有如下调用。
run方法
--->1. val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees,featureSubsetStrategy) # 对输入数据建立元数据
--->2. val splits = findSplits(retaggedInput, metadata, seed) # 对元数据中的特征进行切分
--->2.1 计算采样率,对输入样本进行采样
--->2.2 findSplitsBySorting(sampledInput, metadata, continuousFeatures) # 对采样后的样本中的特征进行切分
--->2.2.1 val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) # 针对连续型特征
--->2.2.2 val categories = extractMultiClassCategories(splitIndex + 1, featureArity) # 针对分类型特征,且特征无序
--->2.2.3 Array.empty[Split] # 针对分类型特征,且特征有序,训练时直接构造即可
--->3. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) # 将输入数据转换为树形数据
--->3.1 input.map { x => TreePoint.labeledPointToTreePoint(x, thresholds, featureArity) # 将LabeledPoint数据转换为TreePoint数据
--->3.2 arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) # 在(labeledPoint,feature)中找出一个离散值
--->4. val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,withReplacement, seed) # 对输入数据进行采样
--->4.1 convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) #有放回采样
--->4.2 convertToBaggedRDDWithoutSampling(input) # 样本数为1,采样率为100%
--->4.3 convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) # 无放回采样
--->5. val (nodesForGroup, treeToNodeToIndexInfo) = RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage,metadata, rng) # 取得每棵树所有需要切分的结点
--->5.1 val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)} # 如果需要子采样,选择特征子集
--->5.2 val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L # 计算添加这个结点之后,是否有足够的内存
--->6. RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) # 找出最优切分点
--->6.1 val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) #找出每个结点最好的切分
--->7. new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, strategy.getNumClasses) # 返回决策树分类模型
4.Reference
spark mllib中的随机森林算法,实现源码以及使用介绍
Spark MLlib - Decision Tree源码分析
Spark中决策树源码分析的更多相关文章
- 【原】Spark中Client源码分析(二)
继续前一篇的内容.前一篇内容为: Spark中Client源码分析(一)http://www.cnblogs.com/yourarebest/p/5313006.html DriverClient中的 ...
- 【原】Spark中Master源码分析(二)
继续上一篇的内容.上一篇的内容为: Spark中Master源码分析(一) http://www.cnblogs.com/yourarebest/p/5312965.html 4.receive方法, ...
- 【原】 Spark中Worker源码分析(二)
继续前一篇的内容.前一篇内容为: Spark中Worker源码分析(一)http://www.cnblogs.com/yourarebest/p/5300202.html 4.receive方法, r ...
- 【原】Spark中Master源码分析(一)
Master作为集群的Manager,对于集群的健壮运行发挥着十分重要的作用.下面,我们一起了解一下Master是听从Client(Leader)的号召,如何管理好Worker的吧. 1.家当(静态属 ...
- 【原】Spark中Client源码分析(一)
在Spark Standalone中我们所谓的Client,它的任务其实是由AppClient和DriverClient共同完成的.AppClient是一个允许app(Client)和Spark集群通 ...
- 【原】 Spark中Worker源码分析(一)
Worker作为对于Spark集群的健壮运行起着举足轻重的作用,作为Master的奴隶,每15s向Master告诉自己还活着,一旦主人(Master>有了任务(Application),立马交给 ...
- Spark Scheduler模块源码分析之TaskScheduler和SchedulerBackend
本文是Scheduler模块源码分析的第二篇,第一篇Spark Scheduler模块源码分析之DAGScheduler主要分析了DAGScheduler.本文接下来结合Spark-1.6.0的源码继 ...
- Spark Scheduler模块源码分析之DAGScheduler
本文主要结合Spark-1.6.0的源码,对Spark中任务调度模块的执行过程进行分析.Spark Application在遇到Action操作时才会真正的提交任务并进行计算.这时Spark会根据Ac ...
- Spark RPC框架源码分析(一)简述
Spark RPC系列: Spark RPC框架源码分析(一)运行时序 Spark RPC框架源码分析(二)运行时序 Spark RPC框架源码分析(三)运行时序 一. Spark rpc框架概述 S ...
随机推荐
- 原生AJAX封装
var ajaxHelper = { /*1.0 浏览器兼容的方式创建异步对象*/ makeXHR: function () { //声明异步对象变量 var xmlHttp = false; //声 ...
- 我听说 C...
我听说在 c 语言的世界里,goto 和异常处理都是声名狼藉的东西,而我认为它们在一起就能化解各自的问题.
- pythonchallenge 解谜 Level 6
第六关地址 http://www.pythonchallenge.com/pc/def/channel.html 和前几关一样,首先看网页源码吧.反正不看也没办法... <html>< ...
- 验证mongodb副本集并实现自动切换primary~记录过程
接 验证mongodb主从复制过程 1.创建数据目录 同 验证mongodb主从复制过程 的实验一样,本次实验也是采用直接指定启动参数来启动mongodb数据库,本次实验我们需要启动三个数据库,为了与 ...
- WPF整理-Mutex确保Application单例运行
有时我们不希望我们的WPF应用程序可以同时运行有多个实例,当我们试图运行第二个实例的时候,已经运行的实例也应该弹出来. 我们可以用Mutex来实现 打开App.xaml.cs,在App类中添加如下内容 ...
- DM 多路径存储
DM多路径存储 系统环境:RHEL5.4 small install selinux and iptables disabled主机规划:主机网卡软件station133eth0: 192.168. ...
- 已经过事务处理的 MSMQ 绑定(转载)
https://msdn.microsoft.com/zh-cn/biztalk/ms751493 本示例演示如何使用消息队列 (MSMQ) 执行已经过事务处理的排队通信. 注意 本主题的末尾介绍了此 ...
- ASP.NET Core 十种方式扩展你的 Views
原文地址:http://asp.net-hacker.rocks/2016/02/18/extending-razor-views.html 作者:Jürgen Gutsch 翻译:杨晓东(Savor ...
- CYQ.Data 支持WPF相关的数据控件绑定(2013-08-09)
事件的结果 经过多天的思考及忙碌的开发及测试,CYQ.Data 终于在UI上全面支持WPF,至此,CYQ.Data 已经可以方便支持wpf的开发,同时,框架仍保留最低.net framework2.0 ...
- [Xamarin] 使用Webview 來做APP (转帖)
有時候,企業要求的沒有這麼多,他原本可能官方網站就已經有支援Mobile Web Design 他只需要原封不動的開發一個APP 也或是,他只是要型錄型,或是問卷調查的型的APP,這時候透過類似像if ...