之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构。首先,以DecisionTreeParams作为起始,这里存储了跟树相关的最基础的参数,注意它扩展自PredictorParams。接下来为了区分分类器和回归器,提出了TreeClassifierParams和TreeRegressorParams,两者都直接扩展自Params,分别定义了树相关的分类器和回归器所需要的特殊参数。有了这三个trait,我们可以组合成DecisionTreeClassifierParams和DecisionTreeRegressorParams。然后,为了产生Ensemble的模型,提出了TreeEnsembleParams,专门针对树集合的模型,在此基础上,分别为随机森林和GBT定义了RandomForestParams和GBTParams,可以通过混入分类器和回归器,在这俩的基础上生成随机森林分类器、随机森林回归器、GBT分类器、GBT回归器。
        讲完了参数再讲讲模型,树相关的模型只有两种,DecisionTreeModel和TreeEnsembleModel,比较简单就不深入介绍了。
        下面说一下“树”结构的具体表示,这里涉及到Node.scala文件。树由节点组成,这个文件中定义了四种节点,分别是,Node,这是最基础的节点,是其它三种节点的超类;LeafNode,这是叶子节点,InternalNode,这是中间节点。最后一种是LearningNode,这种节点是专门用来在树学习过程中使用的,其余三类节点都是不可变对象,唯有这个LearningNode是可变对象,可以在学习的过程中被不断迭代,在学习结束之后转换成LeafNode和InternalNode。
        接下来再说说数据样本在树学习中的表示形式。由于树学习的特殊性,在正式用数据进行学习之前,需要对数据进行转换和丰富,Spark ML为了方便树学习,提出了很多数据结构。用于树学习的数据样本,通常可以包含离散和连续两种特征,本质上树学习仅支持离散特征,因此在正式学习前都需要对特征进行离散化。离散化后的样本数据可以表示成一个数组,数组中的元素就是该样本在不同特征的取值。因此,一个原始的LabeledPoint就可以被表示为一个TreePoint,它包含两个数据,Label和binFeature(就是数据特征数组)。这样对于单棵树的学习足够了,但是对于随机森林一类的多树的学习还不行,对于随机森林来说,同样的一个数据样本点可能在森林中的任一一颗树上出现(采样),因此在树学习时,我们需要记录这个数据节点在某棵树上有没有出现过,我们用一个数组来表示一个节点在各树上出现的次数,其中数组长度代表森林中树的数量,数组元素代表数据节点在这棵树上出现的次数,因此就形成了BaggedPoint结构。最后,我们知道随机森林学习的过程,实际上就是不断扩展节点的过程,当前这个数据在每颗树上的节点位置,也需要记录,因此提出NodeInCache数据结构,这同样也是个数组,数组长度是森林中树的个数,数组元素代表当前数据在这棵树上的节点编号。某棵树上的数据刚开始都在根节点,随着学习的进行,数据会逐渐从根节点向下走,直到某个节点因为不能再分解,形成叶节点为止。因此这个数组的内容也是在不断迭代的。关于树的结构,这里再多说一句,如果某个节点的编号是id,那么它的左子节点的编号就是id<<1,右子节点编号就是id<<1+1,因此在一棵树中,某个节点的位置和编号是一一对应的。
        做了这么多铺垫,最后终于说到树学习的过程了,具体算法在RandomForest.scala文件中,这个文件是整个tree文件夹的集大成者,其中包含了随机森林模型学习的核心代码。学习的过程可以分为以下几个步骤,第一,创建metadata,在学习之前,我们要先对数据集有一个大致的了解,比如有多少数据、每个数据多少维度等等,因此需要先从原始数据中总结出metadata。第二,寻找切分点,由于数据特征不可能都是离散型,或者即便是离散型,但因为类别太多,需要进一步处理,因此需要在这里对每个特征寻找切分点,然后把每一个特征值放入特定的切分区间内,形成原始的binFeature向量。第三,转换为BaggedPoint,关于BaggedPoint的含义,之前说的很清楚了,主要是为了方便学习和记录,加入了额外的数据结构。第四,初始化NodeInCache,这个也是为了加速训练过程。第五,初始化NodeStack,我们知道在学习随机森林模型时,需要不断扩展已有的树,把节点拆分为左右子节点,那么已有的树由那么多节点,先拆分哪些节点呢?Spark ML的做法是,把待拆分的节点放入一个栈中,每次根据现有内存的容量,从中拿出若干个点用来拆分,然后把拆分好的点再作为中间节点放入栈中。重复这个过程,直到栈中没有节点为止。因此这里的第五步就是初始化这个节点栈。
        第六,就正式进入随机森林学习的大循环了。这个大循环包含三个步骤,1,master端,在NodeStack中选择节点和特征,在随机森林学习中,每个节点的待选特征集是随机的,不一定就是特征全集,所以这里除了要选节点之外,还要选特征集;2,worker端,计算当前worker上各数据针对选出节点的充分统计量,然后把各worker上的充分统计量在某一个worker上汇总,由这个worker通过算法确定切分点,然后把选好的切分标准返回给master;3,master端,收集各worker返回的切分结果,对模型进行迭代,然后在下一轮循环前,把当前的模型push给各个worker。循环的三个步骤中,计算量最大的就是第二个步骤,既要计算充分统计量,又要把各worker上的统计量汇总,这里的计算和通信压力都很大,是随机森林的性能瓶颈所在。因此在使用随机森林模型时,在数据量允许的情况下,尽量不要把executor数量设置的太高,否则通信的成本会很大。
        为什么没讲决策树直接讲了随机森林呢?因为角色数就是只有一棵树的随机森林嘛,简化一下就得到了。
        以上就是Spark ML中树模型的基础内容,讲的还比较烦,有时间的话再专门出一篇博客,详细讲解随机森林学习中的优化方法。还请大家不吝赐教。

Spark ML源码分析之四 树的更多相关文章

  1. Spark ML源码分析之二 从单机到分布式

            前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spar ...

  2. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

  3. Spark ML源码分析之三 分类器

            前面跟大家扯了这么多废话,终于到具体的机器学习模型了.大部分机器学习的教程,总要从监督学习开始讲起,而监督学习的众多算法当中,又以分类算法最为基础,原因在于分类问题非常的单纯直接,几乎 ...

  4. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  5. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  6. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  7. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  8. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

  9. 【Spark SQL 源码分析系列文章】

    从决定写Spark SQL源码分析的文章,到现在一个月的时间里,陆陆续续差不多快完成了,这里也做一个整合和索引,方便大家阅读,这里给出阅读顺序 :) 第一篇 Spark SQL源码分析之核心流程 第二 ...

随机推荐

  1. Linux显示USB设备

    Linux显示USB设备 youhaidong@youhaidong-ThinkPad-Edge-E545:~$ lsusb -tv /: Bus 08.Port 1: Dev 1, Class=ro ...

  2. DirectX:函数可以连接任意两个filter

    函数可以连接任意两个filter HRESULT ConnectFilters( IBaseFilter *pSrc, IBaseFilter *pDest ) { IPin *pIn = 0; IP ...

  3. EntityFramework Core 2.0 Explicitly Compiled Query(显式编译查询)

    前言 EntityFramework Core 2.0引入了显式编译查询,在查询数据时预先编译好LINQ查询便于在请求数据时能够立即响应.显式编译查询提供了高可用场景,通过使用显式编译的查询可以提高查 ...

  4. 【转】Nginx的启动、停止与重启

    Nginx的启动.停止与重启 启动 启动代码格式:nginx安装目录地址 -c nginx配置文件地址 例如: [root@LinuxServer sbin]# /usr/local/nginx/sb ...

  5. 小程序for循环中通过index实现单个点击事件

    <!--xml--> <view class='content3-list' wx:for="{{listItems}}" > <view class ...

  6. Spring AOP介绍

    1.介绍 AOP(面向切面编程)对OOP(面向对象编程)是一种补充,它提供了另一种程序结构的思路.OOP的模块单元是class,而AOP的模块单元是aspect.Spring中一个关键的组件是AOP框 ...

  7. 创建文本节点createTextNode

    <!DOCTYPE HTML> <html> <head> <meta http-equiv="Content-Type" content ...

  8. 配置Tomcat线程参数maxThreads、acceptCount

    一.配置Tomcat/conf/server.xml修改配置 <Connector port="8080" protocol="org.apache.coyote. ...

  9. Bzoj2946:[POI2000] 最长公共子串

    题面 求多个串的最长公共子串 Sol 套路,拼在一起,二分答案+后缀数组判定 把大于答案的\(height\)分组,然后计算出一个组内是否有所有串的后缀 由于串只有\(5\)个开个桶就好了 # inc ...

  10. MySQL根据出生日期计算年龄的五种方法比较

    方法一 SELECT DATE_FORMAT(FROM_DAYS(TO_DAYS(NOW())-TO_DAYS(birthday)), '%Y')+0 AS age 方法一,作者也说出了缺陷,就是当日 ...