支持连续变量和类别变量,类别变量就是某个属性有三个值,a,b,c,需要用Feature Transformers中的vectorindexer处理

上来是一堆参数

setMaxDepth:最大树深度

setMaxBins:最大装箱数,为了近似统计变量,比如变量有100个值,我只分成10段去做统计

setMinInstancesPerNode:每个节点最少实例

setMinInfoGain:最小信息增益

setMaxMemoryInMB:最大内存MB单位,这个值越大,一次处理的节点划分就越多

setCacheNodeIds:是否缓存节点id,缓存可以加速深层树的训练

setCheckpointInterval:检查点间隔,就是多少次迭代固化一次

setImpurity:随机森林有三种方式,entropy,gini,variance,回归肯定就是variance

setSubsamplingRate:设置采样率,就是每次选多少比例的样本构成新树

setSeed:采样种子,种子不变,采样结果不变

setNumTrees:设置森林里有多少棵树

setFeatureSubsetStrategy:设置特征子集选取策略,随机森林就是两个随机,构成树的样本随机,每棵树开始分裂的属性是随机的,其他跟决策树区别不大,注释这么写的

* The number of features to consider for splits at each tree node.
   * Supported options:
   *  - "auto": Choose automatically for task://默认策略
   *            If numTrees == 1, set to "all."     //决策树选择所有属性
   *            If numTrees > 1 (forest), set to "sqrt" for classification and //决策森林 分类选择属性数开平方,回归选择三分之一属性
   *              to "onethird" for regression.
   *  - "all": use all features
   *  - "onethird": use 1/3 of the features
   *  - "sqrt": use sqrt(number of features)
   *  - "log2": use log2(number of features) //还有取对数的
   * (default = "auto") 
   *
   * These various settings are based on the following references:
   *  - log2: tested in Breiman (2001)
   *  - sqrt: recommended by Breiman manual for random forests
   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
   *    package.

参数完毕,下面比较重要的是这段代码

val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

这个地比较蛋疼的是dataset.schema($(featuresCol))

/** An alias for [[getOrDefault()]]. */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

这段代码说明了$(featuresCol))只是求出一个字段名,实战中直接data.schema("features") ,data.schema("features")出来的是StructField,

case classStructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extendsProduct with Serializable

StructField包含四个内容,最好知道一下,机器学习代码很多都用

回头说下getCategoricalFeatures,这个方法是识别一个属性是二值变量还是名义变量,例如a,b就是二值变量,a,b,c就是名义变量,最终把属性索引和变量值的个数放到一个map

这个函数的功能和vectorindexer类似,但是一般都用vectorindexer,因为实战中我们大都从sql读数据,sql读出来的数据metadata是空,无法识别二值变量还是名义变量

后面是

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    val strategy =
      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    val trees =
      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
        .map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = oldDataset.first().features.size

new RandomForestRegressionModel(trees, numFeatures)

可以看出还是调的RDD的旧方法,run这个方法是核心有1000多行,后面会详细跟踪,最后返回的是RandomForestRegressionModel,里面有Array[DecisionTreeRegressionModel] ,就是生成的一组决策树模型,也就是决策森林,另外一个是属性数,我们继续看RandomForestRegressionModel

在1.6版本每棵树的权重都是1,里面还有这么一个方法

override protected def transformImpl(dataset: DataFrame): DataFrame = {
    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

可以看到把模型通过广播的形式传给exectors,搞一个udf预测函数,最后通过withColumn把预测数据粘到原数据后面,

注意这个写法dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) ,第一个参数是列名,第二个是计算出来的col,col是列类型,预测方法如下

override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
    // Predict average of tree predictions.
    // Ignore the weights since all are 1.0 for now.
    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
  }

可见预测用的是每个树的跟节点,predictImpl(features)返回这个根节点分配的叶节点,这是一个递归调用的过程,关于如何递归,后面也会拿出来细说,最后再用.prediction方法把所有树预测的结果相加求平均

后面有一个计算各属性重要性的方法

lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)

实现如下

private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
    val totalImportances = new OpenHashMap[Int, Double]()
    trees.foreach { tree =>
      // Aggregate feature importance vector for this tree      先计算每棵树的属性重要性值
      val importances = new OpenHashMap[Int, Double]()
      computeFeatureImportance(tree.rootNode, importances)
      // Normalize importance vector for this tree, and add it to total.
      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
      val treeNorm = importances.map(_._2).sum
      if (treeNorm != 0) {
        importances.foreach { case (idx, impt) =>
          val normImpt = impt / treeNorm
          totalImportances.changeValue(idx, normImpt, _ + normImpt)
        }
      }
    }
    // Normalize importances
    normalizeMapValues(totalImportances)
    // Construct vector
    val d = if (numFeatures != -1) {
      numFeatures
    } else {
      // Find max feature index used in trees
      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
      maxFeatureIndex + 1
    }
    if (d == 0) {
      assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
        s" importance: No splits in forest, but some non-zero importances.")
    }
    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
    Vectors.sparse(d, indices.toArray, values.toArray)
  }

computeFeatureImportance的实现如下

/**
   * Recursive method for computing feature importances for one tree.
   * This walks down the tree, adding to the importance of 1 feature at each node.
   * @param node  Current node in recursion
   * @param importances  Aggregate feature importances, modified by this method
   */
  private[impl] def computeFeatureImportance(
      node: Node,
      importances: OpenHashMap[Int, Double]): Unit = {
    node match {
      case n: InternalNode =>
        val feature = n.split.featureIndex
        val scaledGain = n.gain * n.impurityStats.count
        importances.changeValue(feature, scaledGain, _ + scaledGain)
        computeFeatureImportance(n.leftChild, importances)
        computeFeatureImportance(n.rightChild, importances)
      case n: LeafNode =>
        // do nothing
    }
  }

由于属性重要性是由gain概念扩展而来,这里以gain来说明如何计算属性重要性。

这里首先可以看出为什么每次树的调用都回到rootnode的调用,因为要递归的沿着树的层深往下游走,这里游走到叶节点什么也不做,其他分裂节点也就是代码里的InternalNode ,先找到该节点划分的属性索引,然后该节点增益乘以该节点数据量,然后更新属性重要性值,这样继续递归左节点,右节点,直到结束

然后回到featureImportances方法,val treeNorm = importances.map(_._2).sum是把刚才计算的每棵树的属性重要性求和,然后计算每个属性重要性占这棵树总重要性的比值,这样整棵树就搞完了,foreach走完,所有树的属性重要性就累加到totalImportances里了,然后normalizeMapValues(totalImportances)再按刚才的方法算一遍,这样出来的属性值和就为1了,有了属性个数和排好序的属性重要性值,装入向量,就是最终输出的结果

入口方法就这些了

现在我们还有run方法的1000多行,还有如何递归分配节点这两个点需要讲,后面会继续

RF的特征子集选取策略(spark ml)的更多相关文章

  1. 使用spark ml pipeline进行机器学习

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  2. spark ml 的例子

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  3. spark ml pipeline构建机器学习任务

    一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...

  4. Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用

    1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...

  5. Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...

  6. Spark ML机器学习

    Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...

  7. Spark ML 几种 归一化(规范化)方法总结

    规范化,有关之前都是用 python写的,  偶然要用scala 进行写, 看到这位大神写的, 那个网页也不错,那个连接图做的还蛮不错的,那天也将自己的博客弄一下那个插件. 本文来源 原文地址:htt ...

  8. Spark ML Pipeline简介

    Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...

  9. spark org.apache.spark.ml.linalg.DenseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector

    在使用 import org.apache.spark.ml.feature.VectorAssembler 转换特征后,想要放入 import org.apache.spark.mllib.clas ...

随机推荐

  1. muduo学习笔记(二)Reactor关键结构

    目录 muduo学习笔记(二)Reactor关键结构 Reactor简述 什么是Reactor Reactor模型的优缺点 poll简述 poll使用样例 muduo Reactor关键结构 Chan ...

  2. Python异常处理回顾与总结

    1 引言 在我们调试程序时,经常不可避免地出现意料之外的情况,导致程序不得不停止运行,然后提示大堆提示信息,大多是这种情况都是由异常引起的.异常的出现一方面是因为写代码时粗心导致的语法错误,这种错误在 ...

  3. 循序渐进学.Net Core Web Api开发系列【4】:前端访问WebApi

    系列目录 循序渐进学.Net Core Web Api开发系列目录 本系列涉及到的源码下载地址:https://github.com/seabluescn/Blog_WebApi 一.概述 前一篇文章 ...

  4. Ubuntu 安装Chrome

    apt方式安装Chrome 1.添加密匙 wget -q -O - https://dl.google.com/linux/linux_signing_key.pub | sudo apt-key a ...

  5. BZOJ3518 : 点组计数

    若直线的斜率为0或者不存在斜率,则有$nC(m,3)+mC(n,3)$种方案.若直线的斜率不为0,只需考虑斜率为正的情况,最后答案再乘以2即可.枚举两个点的坐标差,设$t=\min(n,m)$,则有: ...

  6. hdu 5775 Bubble Sort 树状数组

    Bubble Sort 题目连接: http://acm.hdu.edu.cn/showproblem.php?pid=5775 Description P is a permutation of t ...

  7. PHP 图像居中裁剪函数

    图像居中裁减的大致思路: 1.首先将图像进行缩放,使得缩放后的图像能够恰好覆盖裁减区域.(imagecopyresampled — 重采样拷贝部分图像并调整大小) 2.将缩放后的图像放置在裁减区域中间 ...

  8. AtomicReference,AtomicStampedReference与AtomicMarkableReference的区别

    AtomicReference 通过volatile和Unsafe提供的CAS函数实现原子操作. 自旋+CAS的无锁操作保证共享变量的线程安全 value是volatile类型,这保证了:当某线程修改 ...

  9. Centos 安装Percona Toolkit工具集

    1.下载 下载地址:   https://www.percona.com/downloads/percona-toolkit/LATEST/ [root@bogon ~]# wget https:// ...

  10. HTML5中的跨文档消息传递

    跨文档消息传送(cross-document messaging),有时候也简称为XDM,指的是来自不同域的页面间传递消息.例如,www.w3cmm.com域中的一个页面与一个位于内嵌框架中的p2p. ...