之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构。首先,以DecisionTreeParams作为起始,这里存储了跟树相关的最基础的参数,注意它扩展自PredictorParams。接下来为了区分分类器和回归器,提出了TreeClassifierParams和TreeRegressorParams,两者都直接扩展自Params,分别定义了树相关的分类器和回归器所需要的特殊参数。有了这三个trait,我们可以组合成DecisionTreeClassifierParams和DecisionTreeRegressorParams。然后,为了产生Ensemble的模型,提出了TreeEnsembleParams,专门针对树集合的模型,在此基础上,分别为随机森林和GBT定义了RandomForestParams和GBTParams,可以通过混入分类器和回归器,在这俩的基础上生成随机森林分类器、随机森林回归器、GBT分类器、GBT回归器。
        讲完了参数再讲讲模型,树相关的模型只有两种,DecisionTreeModel和TreeEnsembleModel,比较简单就不深入介绍了。
        下面说一下“树”结构的具体表示,这里涉及到Node.scala文件。树由节点组成,这个文件中定义了四种节点,分别是,Node,这是最基础的节点,是其它三种节点的超类;LeafNode,这是叶子节点,InternalNode,这是中间节点。最后一种是LearningNode,这种节点是专门用来在树学习过程中使用的,其余三类节点都是不可变对象,唯有这个LearningNode是可变对象,可以在学习的过程中被不断迭代,在学习结束之后转换成LeafNode和InternalNode。
        接下来再说说数据样本在树学习中的表示形式。由于树学习的特殊性,在正式用数据进行学习之前,需要对数据进行转换和丰富,Spark ML为了方便树学习,提出了很多数据结构。用于树学习的数据样本,通常可以包含离散和连续两种特征,本质上树学习仅支持离散特征,因此在正式学习前都需要对特征进行离散化。离散化后的样本数据可以表示成一个数组,数组中的元素就是该样本在不同特征的取值。因此,一个原始的LabeledPoint就可以被表示为一个TreePoint,它包含两个数据,Label和binFeature(就是数据特征数组)。这样对于单棵树的学习足够了,但是对于随机森林一类的多树的学习还不行,对于随机森林来说,同样的一个数据样本点可能在森林中的任一一颗树上出现(采样),因此在树学习时,我们需要记录这个数据节点在某棵树上有没有出现过,我们用一个数组来表示一个节点在各树上出现的次数,其中数组长度代表森林中树的数量,数组元素代表数据节点在这棵树上出现的次数,因此就形成了BaggedPoint结构。最后,我们知道随机森林学习的过程,实际上就是不断扩展节点的过程,当前这个数据在每颗树上的节点位置,也需要记录,因此提出NodeInCache数据结构,这同样也是个数组,数组长度是森林中树的个数,数组元素代表当前数据在这棵树上的节点编号。某棵树上的数据刚开始都在根节点,随着学习的进行,数据会逐渐从根节点向下走,直到某个节点因为不能再分解,形成叶节点为止。因此这个数组的内容也是在不断迭代的。关于树的结构,这里再多说一句,如果某个节点的编号是id,那么它的左子节点的编号就是id<<1,右子节点编号就是id<<1+1,因此在一棵树中,某个节点的位置和编号是一一对应的。
        做了这么多铺垫,最后终于说到树学习的过程了,具体算法在RandomForest.scala文件中,这个文件是整个tree文件夹的集大成者,其中包含了随机森林模型学习的核心代码。学习的过程可以分为以下几个步骤,第一,创建metadata,在学习之前,我们要先对数据集有一个大致的了解,比如有多少数据、每个数据多少维度等等,因此需要先从原始数据中总结出metadata。第二,寻找切分点,由于数据特征不可能都是离散型,或者即便是离散型,但因为类别太多,需要进一步处理,因此需要在这里对每个特征寻找切分点,然后把每一个特征值放入特定的切分区间内,形成原始的binFeature向量。第三,转换为BaggedPoint,关于BaggedPoint的含义,之前说的很清楚了,主要是为了方便学习和记录,加入了额外的数据结构。第四,初始化NodeInCache,这个也是为了加速训练过程。第五,初始化NodeStack,我们知道在学习随机森林模型时,需要不断扩展已有的树,把节点拆分为左右子节点,那么已有的树由那么多节点,先拆分哪些节点呢?Spark ML的做法是,把待拆分的节点放入一个栈中,每次根据现有内存的容量,从中拿出若干个点用来拆分,然后把拆分好的点再作为中间节点放入栈中。重复这个过程,直到栈中没有节点为止。因此这里的第五步就是初始化这个节点栈。
        第六,就正式进入随机森林学习的大循环了。这个大循环包含三个步骤,1,master端,在NodeStack中选择节点和特征,在随机森林学习中,每个节点的待选特征集是随机的,不一定就是特征全集,所以这里除了要选节点之外,还要选特征集;2,worker端,计算当前worker上各数据针对选出节点的充分统计量,然后把各worker上的充分统计量在某一个worker上汇总,由这个worker通过算法确定切分点,然后把选好的切分标准返回给master;3,master端,收集各worker返回的切分结果,对模型进行迭代,然后在下一轮循环前,把当前的模型push给各个worker。循环的三个步骤中,计算量最大的就是第二个步骤,既要计算充分统计量,又要把各worker上的统计量汇总,这里的计算和通信压力都很大,是随机森林的性能瓶颈所在。因此在使用随机森林模型时,在数据量允许的情况下,尽量不要把executor数量设置的太高,否则通信的成本会很大。
        为什么没讲决策树直接讲了随机森林呢?因为角色数就是只有一棵树的随机森林嘛,简化一下就得到了。
        以上就是Spark ML中树模型的基础内容,讲的还比较烦,有时间的话再专门出一篇博客,详细讲解随机森林学习中的优化方法。还请大家不吝赐教。

Spark ML源码分析之四 树的更多相关文章

  1. Spark ML源码分析之二 从单机到分布式

            前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spar ...

  2. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

  3. Spark ML源码分析之三 分类器

            前面跟大家扯了这么多废话,终于到具体的机器学习模型了.大部分机器学习的教程,总要从监督学习开始讲起,而监督学习的众多算法当中,又以分类算法最为基础,原因在于分类问题非常的单纯直接,几乎 ...

  4. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  5. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  6. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  7. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  8. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

  9. 【Spark SQL 源码分析系列文章】

    从决定写Spark SQL源码分析的文章,到现在一个月的时间里,陆陆续续差不多快完成了,这里也做一个整合和索引,方便大家阅读,这里给出阅读顺序 :) 第一篇 Spark SQL源码分析之核心流程 第二 ...

随机推荐

  1. Docker 小记 — Compose & Swarm

    前言 任何相对完整的应用服务都不可能是由单一的程序来完成支持,计划使用 Docker 来部署的服务更是如此.大型服务需要进行拆分,形成微服务集群方能增强其稳定性和可维护性.本篇随笔将对 Docker ...

  2. Tornado模块

    Tornado 一个轻量级的Web框架 简介 1.Tornado在设计之初就考虑到了性能因素,旨在解决C10K问题,这样的设计使得其成为一个拥有非常高性能的框架.此外,它还拥有处理安全性.用户验证.社 ...

  3. 申请Jetbrain教育帐号,免费使用一年

    JetBrains是一家捷克的软件开发公司.旗下IDE产品有(不限于):(1) IntelliJ,IDEA  Java集成开发工具:(2) PHPStorm,PHP 集成开发工具:(3) PyChar ...

  4. 通过Java WebService接口从服务端下载文件

    一. 前言 本文讲述如何通过webservice接口,从服务端下载文件.报告到客户端.适用于跨系统间的文件交互,传输文件不大的情况(控制在几百M以内).对于这种情况搭建一个FTP环境,增加了系统部署的 ...

  5. hive java编写udf函数

    (一)创建JAVA 代码--例子 package hiveOpt; import org.apache.hadoop.hive.ql.exec.UDF;import org.apache.hadoop ...

  6. JavaScript一看就懂(1)作用域

    函数级作用域 1.函数外声明的变量为全局变量,函数内可以直接访问全局变量: var global_var = 10; //全局变量 function a(){ alert(global_var); / ...

  7. Redis持久化方案

    Redis可以实现数据的持久化存储,即将数据保存到磁盘上. Redis的持久化存储提供两种方式:RDB与AOF.RDB是默认配置.AOF需要手动开启. 默认redis是会以快照的形式将数据持久化到磁盘 ...

  8. 平衡树(Splay)模板

    支持区间操作. 单点操作和区间操作分开使用,需要一起使用需要部分修改. 对应题目FJUTOJ2490 #include<cstdio> #include<cstring> #i ...

  9. Linux-Centos7----安装Python的psutil模块插件

    # wget https://pypi.python.org/packages/source/p/psutil/psutil-2.1.3.tar.gz # tar zxvf psutil-2.1.3. ...

  10. Problem : 1012 ( u Calculate e )

    /*tips:本题只有输入,没有输出,在线测试只检测结果,所以将前面几个结果罗列出来就OK了.为了格式输出问题纠结了半天,最后答案竟然还是错的....所以啊,做题还是得灵活变通.*/ #include ...