之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构。首先,以DecisionTreeParams作为起始,这里存储了跟树相关的最基础的参数,注意它扩展自PredictorParams。接下来为了区分分类器和回归器,提出了TreeClassifierParams和TreeRegressorParams,两者都直接扩展自Params,分别定义了树相关的分类器和回归器所需要的特殊参数。有了这三个trait,我们可以组合成DecisionTreeClassifierParams和DecisionTreeRegressorParams。然后,为了产生Ensemble的模型,提出了TreeEnsembleParams,专门针对树集合的模型,在此基础上,分别为随机森林和GBT定义了RandomForestParams和GBTParams,可以通过混入分类器和回归器,在这俩的基础上生成随机森林分类器、随机森林回归器、GBT分类器、GBT回归器。
        讲完了参数再讲讲模型,树相关的模型只有两种,DecisionTreeModel和TreeEnsembleModel,比较简单就不深入介绍了。
        下面说一下“树”结构的具体表示,这里涉及到Node.scala文件。树由节点组成,这个文件中定义了四种节点,分别是,Node,这是最基础的节点,是其它三种节点的超类;LeafNode,这是叶子节点,InternalNode,这是中间节点。最后一种是LearningNode,这种节点是专门用来在树学习过程中使用的,其余三类节点都是不可变对象,唯有这个LearningNode是可变对象,可以在学习的过程中被不断迭代,在学习结束之后转换成LeafNode和InternalNode。
        接下来再说说数据样本在树学习中的表示形式。由于树学习的特殊性,在正式用数据进行学习之前,需要对数据进行转换和丰富,Spark ML为了方便树学习,提出了很多数据结构。用于树学习的数据样本,通常可以包含离散和连续两种特征,本质上树学习仅支持离散特征,因此在正式学习前都需要对特征进行离散化。离散化后的样本数据可以表示成一个数组,数组中的元素就是该样本在不同特征的取值。因此,一个原始的LabeledPoint就可以被表示为一个TreePoint,它包含两个数据,Label和binFeature(就是数据特征数组)。这样对于单棵树的学习足够了,但是对于随机森林一类的多树的学习还不行,对于随机森林来说,同样的一个数据样本点可能在森林中的任一一颗树上出现(采样),因此在树学习时,我们需要记录这个数据节点在某棵树上有没有出现过,我们用一个数组来表示一个节点在各树上出现的次数,其中数组长度代表森林中树的数量,数组元素代表数据节点在这棵树上出现的次数,因此就形成了BaggedPoint结构。最后,我们知道随机森林学习的过程,实际上就是不断扩展节点的过程,当前这个数据在每颗树上的节点位置,也需要记录,因此提出NodeInCache数据结构,这同样也是个数组,数组长度是森林中树的个数,数组元素代表当前数据在这棵树上的节点编号。某棵树上的数据刚开始都在根节点,随着学习的进行,数据会逐渐从根节点向下走,直到某个节点因为不能再分解,形成叶节点为止。因此这个数组的内容也是在不断迭代的。关于树的结构,这里再多说一句,如果某个节点的编号是id,那么它的左子节点的编号就是id<<1,右子节点编号就是id<<1+1,因此在一棵树中,某个节点的位置和编号是一一对应的。
        做了这么多铺垫,最后终于说到树学习的过程了,具体算法在RandomForest.scala文件中,这个文件是整个tree文件夹的集大成者,其中包含了随机森林模型学习的核心代码。学习的过程可以分为以下几个步骤,第一,创建metadata,在学习之前,我们要先对数据集有一个大致的了解,比如有多少数据、每个数据多少维度等等,因此需要先从原始数据中总结出metadata。第二,寻找切分点,由于数据特征不可能都是离散型,或者即便是离散型,但因为类别太多,需要进一步处理,因此需要在这里对每个特征寻找切分点,然后把每一个特征值放入特定的切分区间内,形成原始的binFeature向量。第三,转换为BaggedPoint,关于BaggedPoint的含义,之前说的很清楚了,主要是为了方便学习和记录,加入了额外的数据结构。第四,初始化NodeInCache,这个也是为了加速训练过程。第五,初始化NodeStack,我们知道在学习随机森林模型时,需要不断扩展已有的树,把节点拆分为左右子节点,那么已有的树由那么多节点,先拆分哪些节点呢?Spark ML的做法是,把待拆分的节点放入一个栈中,每次根据现有内存的容量,从中拿出若干个点用来拆分,然后把拆分好的点再作为中间节点放入栈中。重复这个过程,直到栈中没有节点为止。因此这里的第五步就是初始化这个节点栈。
        第六,就正式进入随机森林学习的大循环了。这个大循环包含三个步骤,1,master端,在NodeStack中选择节点和特征,在随机森林学习中,每个节点的待选特征集是随机的,不一定就是特征全集,所以这里除了要选节点之外,还要选特征集;2,worker端,计算当前worker上各数据针对选出节点的充分统计量,然后把各worker上的充分统计量在某一个worker上汇总,由这个worker通过算法确定切分点,然后把选好的切分标准返回给master;3,master端,收集各worker返回的切分结果,对模型进行迭代,然后在下一轮循环前,把当前的模型push给各个worker。循环的三个步骤中,计算量最大的就是第二个步骤,既要计算充分统计量,又要把各worker上的统计量汇总,这里的计算和通信压力都很大,是随机森林的性能瓶颈所在。因此在使用随机森林模型时,在数据量允许的情况下,尽量不要把executor数量设置的太高,否则通信的成本会很大。
        为什么没讲决策树直接讲了随机森林呢?因为角色数就是只有一棵树的随机森林嘛,简化一下就得到了。
        以上就是Spark ML中树模型的基础内容,讲的还比较烦,有时间的话再专门出一篇博客,详细讲解随机森林学习中的优化方法。还请大家不吝赐教。

Spark ML源码分析之四 树的更多相关文章

  1. Spark ML源码分析之二 从单机到分布式

            前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spar ...

  2. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

  3. Spark ML源码分析之三 分类器

            前面跟大家扯了这么多废话,终于到具体的机器学习模型了.大部分机器学习的教程,总要从监督学习开始讲起,而监督学习的众多算法当中,又以分类算法最为基础,原因在于分类问题非常的单纯直接,几乎 ...

  4. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  5. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  6. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  7. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  8. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

  9. 【Spark SQL 源码分析系列文章】

    从决定写Spark SQL源码分析的文章,到现在一个月的时间里,陆陆续续差不多快完成了,这里也做一个整合和索引,方便大家阅读,这里给出阅读顺序 :) 第一篇 Spark SQL源码分析之核心流程 第二 ...

随机推荐

  1. 芝麻HTTP:Gerapy的安装

    Gerapy是一个Scrapy分布式管理模块,本节就来介绍一下它的安装方式. 1. 相关链接 GitHub:https://github.com/Gerapy 2. pip安装 这里推荐使用pip安装 ...

  2. R语言实现二分查找法

    二分查找时间复杂度O(h)=O(log2n),具备非常高的效率,用R处理数据时有时候需要用到二分查找法以便快速定位 Rbisect <- function(lst, value){ low=1 ...

  3. 在TextBox中敲击回车执行ASP.NET后台事件

    1.在TextBox中敲击回车执行ASP.NET后台事件   0.说明 页面中有一个用于搜索的TextBox,希望能在输入内容后直接回车开始搜索,而不是手动去点击它旁边的搜索按钮 但因为该TextBo ...

  4. 网页加载进度的实现--JavaScript基础

    总结了一些网页加载进度的实现方式…… 1.定时器实现加载进度 <!DOCTYPE html><html lang="en"><head> < ...

  5. Dubbo 新编程模型之外部化配置

    外部化配置(External Configuration) 在Dubbo 注解驱动例子中,无论是服务提供方,还是服务消费方,均需要转配相关配置Bean: @Bean public Applicatio ...

  6. POJ 2516 Minimum Cost (费用流)

    题面 Dearboy, a goods victualer, now comes to a big problem, and he needs your help. In his sale area ...

  7. javascript ES5、ES6的一些知识

    ES6 标签(空格分隔): ES6 严格模式 "use strict" 注意:严格模式也有作用域,如果在某个函数内部声明的话,只在该函数内部有作用 1) 严格模式下全局变量声明必须 ...

  8. fitnesse - 框架介绍

    fitnesse - 框架介绍 2017-09-29 目录: 1 fitnesse是什么?2 框架介绍3 与junit.testng比较,fitnesse教其他框架有什么优势 1 fitnesse是什 ...

  9. java——对象学习笔记

    1.面向对象(OOP)的三大特性 对象的行为(behavior):可以对对象施加哪些操作,或者可以对对象施加哪些方法. 对象的状态(state):当施加那些方法后,对象如何响应. 对象标识(ident ...

  10. Ubuntu14.04 设置wifi热点

    Ubuntu14.04 设置wifi热点 $ sudo add-apt-repository ppa:nilarimogard/webupd8 $ sudo apt-get update $ sudo ...