之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构。首先,以DecisionTreeParams作为起始,这里存储了跟树相关的最基础的参数,注意它扩展自PredictorParams。接下来为了区分分类器和回归器,提出了TreeClassifierParams和TreeRegressorParams,两者都直接扩展自Params,分别定义了树相关的分类器和回归器所需要的特殊参数。有了这三个trait,我们可以组合成DecisionTreeClassifierParams和DecisionTreeRegressorParams。然后,为了产生Ensemble的模型,提出了TreeEnsembleParams,专门针对树集合的模型,在此基础上,分别为随机森林和GBT定义了RandomForestParams和GBTParams,可以通过混入分类器和回归器,在这俩的基础上生成随机森林分类器、随机森林回归器、GBT分类器、GBT回归器。
        讲完了参数再讲讲模型,树相关的模型只有两种,DecisionTreeModel和TreeEnsembleModel,比较简单就不深入介绍了。
        下面说一下“树”结构的具体表示,这里涉及到Node.scala文件。树由节点组成,这个文件中定义了四种节点,分别是,Node,这是最基础的节点,是其它三种节点的超类;LeafNode,这是叶子节点,InternalNode,这是中间节点。最后一种是LearningNode,这种节点是专门用来在树学习过程中使用的,其余三类节点都是不可变对象,唯有这个LearningNode是可变对象,可以在学习的过程中被不断迭代,在学习结束之后转换成LeafNode和InternalNode。
        接下来再说说数据样本在树学习中的表示形式。由于树学习的特殊性,在正式用数据进行学习之前,需要对数据进行转换和丰富,Spark ML为了方便树学习,提出了很多数据结构。用于树学习的数据样本,通常可以包含离散和连续两种特征,本质上树学习仅支持离散特征,因此在正式学习前都需要对特征进行离散化。离散化后的样本数据可以表示成一个数组,数组中的元素就是该样本在不同特征的取值。因此,一个原始的LabeledPoint就可以被表示为一个TreePoint,它包含两个数据,Label和binFeature(就是数据特征数组)。这样对于单棵树的学习足够了,但是对于随机森林一类的多树的学习还不行,对于随机森林来说,同样的一个数据样本点可能在森林中的任一一颗树上出现(采样),因此在树学习时,我们需要记录这个数据节点在某棵树上有没有出现过,我们用一个数组来表示一个节点在各树上出现的次数,其中数组长度代表森林中树的数量,数组元素代表数据节点在这棵树上出现的次数,因此就形成了BaggedPoint结构。最后,我们知道随机森林学习的过程,实际上就是不断扩展节点的过程,当前这个数据在每颗树上的节点位置,也需要记录,因此提出NodeInCache数据结构,这同样也是个数组,数组长度是森林中树的个数,数组元素代表当前数据在这棵树上的节点编号。某棵树上的数据刚开始都在根节点,随着学习的进行,数据会逐渐从根节点向下走,直到某个节点因为不能再分解,形成叶节点为止。因此这个数组的内容也是在不断迭代的。关于树的结构,这里再多说一句,如果某个节点的编号是id,那么它的左子节点的编号就是id<<1,右子节点编号就是id<<1+1,因此在一棵树中,某个节点的位置和编号是一一对应的。
        做了这么多铺垫,最后终于说到树学习的过程了,具体算法在RandomForest.scala文件中,这个文件是整个tree文件夹的集大成者,其中包含了随机森林模型学习的核心代码。学习的过程可以分为以下几个步骤,第一,创建metadata,在学习之前,我们要先对数据集有一个大致的了解,比如有多少数据、每个数据多少维度等等,因此需要先从原始数据中总结出metadata。第二,寻找切分点,由于数据特征不可能都是离散型,或者即便是离散型,但因为类别太多,需要进一步处理,因此需要在这里对每个特征寻找切分点,然后把每一个特征值放入特定的切分区间内,形成原始的binFeature向量。第三,转换为BaggedPoint,关于BaggedPoint的含义,之前说的很清楚了,主要是为了方便学习和记录,加入了额外的数据结构。第四,初始化NodeInCache,这个也是为了加速训练过程。第五,初始化NodeStack,我们知道在学习随机森林模型时,需要不断扩展已有的树,把节点拆分为左右子节点,那么已有的树由那么多节点,先拆分哪些节点呢?Spark ML的做法是,把待拆分的节点放入一个栈中,每次根据现有内存的容量,从中拿出若干个点用来拆分,然后把拆分好的点再作为中间节点放入栈中。重复这个过程,直到栈中没有节点为止。因此这里的第五步就是初始化这个节点栈。
        第六,就正式进入随机森林学习的大循环了。这个大循环包含三个步骤,1,master端,在NodeStack中选择节点和特征,在随机森林学习中,每个节点的待选特征集是随机的,不一定就是特征全集,所以这里除了要选节点之外,还要选特征集;2,worker端,计算当前worker上各数据针对选出节点的充分统计量,然后把各worker上的充分统计量在某一个worker上汇总,由这个worker通过算法确定切分点,然后把选好的切分标准返回给master;3,master端,收集各worker返回的切分结果,对模型进行迭代,然后在下一轮循环前,把当前的模型push给各个worker。循环的三个步骤中,计算量最大的就是第二个步骤,既要计算充分统计量,又要把各worker上的统计量汇总,这里的计算和通信压力都很大,是随机森林的性能瓶颈所在。因此在使用随机森林模型时,在数据量允许的情况下,尽量不要把executor数量设置的太高,否则通信的成本会很大。
        为什么没讲决策树直接讲了随机森林呢?因为角色数就是只有一棵树的随机森林嘛,简化一下就得到了。
        以上就是Spark ML中树模型的基础内容,讲的还比较烦,有时间的话再专门出一篇博客,详细讲解随机森林学习中的优化方法。还请大家不吝赐教。

Spark ML源码分析之四 树的更多相关文章

  1. Spark ML源码分析之二 从单机到分布式

            前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spar ...

  2. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

  3. Spark ML源码分析之三 分类器

            前面跟大家扯了这么多废话,终于到具体的机器学习模型了.大部分机器学习的教程,总要从监督学习开始讲起,而监督学习的众多算法当中,又以分类算法最为基础,原因在于分类问题非常的单纯直接,几乎 ...

  4. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  5. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  6. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  7. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  8. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

  9. 【Spark SQL 源码分析系列文章】

    从决定写Spark SQL源码分析的文章,到现在一个月的时间里,陆陆续续差不多快完成了,这里也做一个整合和索引,方便大家阅读,这里给出阅读顺序 :) 第一篇 Spark SQL源码分析之核心流程 第二 ...

随机推荐

  1. 芝麻HTTP:Python爬虫入门之正则表达式

    1.了解正则表达式 正则表达式是对字符串操作的一种逻辑公式,就是用事先定义好的一些特定字符.及这些特定字符的组合,组成一个"规则字符串",这个"规则字符串"用来 ...

  2. .Net Core从命令行读取配置文件

    最近在学习博客园腾飞(jesse)的.Net Core视频教程,收益匪浅,在此作推荐 : http://video.jessetalk.cn/ 言归正传,.Net Core应用程序中如何通过命令行读取 ...

  3. iOS - Core Animation 核心动画

    1.UIView 动画 具体讲解见 iOS - UIView 动画 2.UIImageView 动画 具体讲解见 iOS - UIImageView 动画 3.CADisplayLink 定时器 具体 ...

  4. 原生态的ajax代码

    <script type="text/javascript"> var xmlhttprequest; function GetXmlHttpRequest() { i ...

  5. Think with Google 京东如何玩转TensorFlow?

    2018 年 2 月 6 日,Think with Google 年度峰会在北京召开.在本次峰会上,我们分享了 Google 和我们的合作伙伴在 AI (人工智能) 方面取得的成绩,探讨如何利用人工智 ...

  6. linux jdk 和tomcat环境变量配置

    系统版本:centos6.5版本 java版本:1.8 一.准备工作 1. java -version 检查是否有java环境,没有则需要去安装并配置到环境变量中. 2.下载tomcat包,下载地址: ...

  7. 【转】UML的9种图例解析

    UML图中类之间的关系:依赖,泛化,关联,聚合,组合,实现 类与类图 1) 类(Class)封装了数据和行为,是面向对象的重要组成部分,它是具有相同属性.操作.关系的对象集合的总称. 2) 在系统中, ...

  8. python自动拉取备份压缩包并删除3天前的旧备份

    业务场景,异地机房自动拉取已备份好的tar.gz数据库压缩包,并且只保留3天内的压缩包文件,用python实现 #!/usr/bin/env python import requests,time,o ...

  9. 403 forbidden 错误解决方案

    在本机启动程序,访问手机移动端(wap)的程序时,返回404无法访问,控制台报错403 forbidden,网上找问题所在: [ 以下引用百度知道:https://zhidao.baidu.com/q ...

  10. The method queryForMap(String, Object...) from the type JdbcTemplate refers to the missing type DataAccessException

    Add spring-tx jar of your spring version to your classpath.