支持连续变量和类别变量,类别变量就是某个属性有三个值,a,b,c,需要用Feature Transformers中的vectorindexer处理

上来是一堆参数

setMaxDepth:最大树深度

setMaxBins:最大装箱数,为了近似统计变量,比如变量有100个值,我只分成10段去做统计

setMinInstancesPerNode:每个节点最少实例

setMinInfoGain:最小信息增益

setMaxMemoryInMB:最大内存MB单位,这个值越大,一次处理的节点划分就越多

setCacheNodeIds:是否缓存节点id,缓存可以加速深层树的训练

setCheckpointInterval:检查点间隔,就是多少次迭代固化一次

setImpurity:随机森林有三种方式,entropy,gini,variance,回归肯定就是variance

setSubsamplingRate:设置采样率,就是每次选多少比例的样本构成新树

setSeed:采样种子,种子不变,采样结果不变

setNumTrees:设置森林里有多少棵树

setFeatureSubsetStrategy:设置特征子集选取策略,随机森林就是两个随机,构成树的样本随机,每棵树开始分裂的属性是随机的,其他跟决策树区别不大,注释这么写的

* The number of features to consider for splits at each tree node.
   * Supported options:
   *  - "auto": Choose automatically for task://默认策略
   *            If numTrees == 1, set to "all."     //决策树选择所有属性
   *            If numTrees > 1 (forest), set to "sqrt" for classification and //决策森林 分类选择属性数开平方,回归选择三分之一属性
   *              to "onethird" for regression.
   *  - "all": use all features
   *  - "onethird": use 1/3 of the features
   *  - "sqrt": use sqrt(number of features)
   *  - "log2": use log2(number of features) //还有取对数的
   * (default = "auto") 
   *
   * These various settings are based on the following references:
   *  - log2: tested in Breiman (2001)
   *  - sqrt: recommended by Breiman manual for random forests
   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
   *    package.

参数完毕,下面比较重要的是这段代码

val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

这个地比较蛋疼的是dataset.schema($(featuresCol))

/** An alias for [[getOrDefault()]]. */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

这段代码说明了$(featuresCol))只是求出一个字段名,实战中直接data.schema("features") ,data.schema("features")出来的是StructField,

case classStructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extendsProduct with Serializable

StructField包含四个内容,最好知道一下,机器学习代码很多都用

回头说下getCategoricalFeatures,这个方法是识别一个属性是二值变量还是名义变量,例如a,b就是二值变量,a,b,c就是名义变量,最终把属性索引和变量值的个数放到一个map

这个函数的功能和vectorindexer类似,但是一般都用vectorindexer,因为实战中我们大都从sql读数据,sql读出来的数据metadata是空,无法识别二值变量还是名义变量

后面是

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    val strategy =
      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    val trees =
      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
        .map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = oldDataset.first().features.size

new RandomForestRegressionModel(trees, numFeatures)

可以看出还是调的RDD的旧方法,run这个方法是核心有1000多行,后面会详细跟踪,最后返回的是RandomForestRegressionModel,里面有Array[DecisionTreeRegressionModel] ,就是生成的一组决策树模型,也就是决策森林,另外一个是属性数,我们继续看RandomForestRegressionModel

在1.6版本每棵树的权重都是1,里面还有这么一个方法

override protected def transformImpl(dataset: DataFrame): DataFrame = {
    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

可以看到把模型通过广播的形式传给exectors,搞一个udf预测函数,最后通过withColumn把预测数据粘到原数据后面,

注意这个写法dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) ,第一个参数是列名,第二个是计算出来的col,col是列类型,预测方法如下

override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
    // Predict average of tree predictions.
    // Ignore the weights since all are 1.0 for now.
    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
  }

可见预测用的是每个树的跟节点,predictImpl(features)返回这个根节点分配的叶节点,这是一个递归调用的过程,关于如何递归,后面也会拿出来细说,最后再用.prediction方法把所有树预测的结果相加求平均

后面有一个计算各属性重要性的方法

lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)

实现如下

private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
    val totalImportances = new OpenHashMap[Int, Double]()
    trees.foreach { tree =>
      // Aggregate feature importance vector for this tree      先计算每棵树的属性重要性值
      val importances = new OpenHashMap[Int, Double]()
      computeFeatureImportance(tree.rootNode, importances)
      // Normalize importance vector for this tree, and add it to total.
      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
      val treeNorm = importances.map(_._2).sum
      if (treeNorm != 0) {
        importances.foreach { case (idx, impt) =>
          val normImpt = impt / treeNorm
          totalImportances.changeValue(idx, normImpt, _ + normImpt)
        }
      }
    }
    // Normalize importances
    normalizeMapValues(totalImportances)
    // Construct vector
    val d = if (numFeatures != -1) {
      numFeatures
    } else {
      // Find max feature index used in trees
      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
      maxFeatureIndex + 1
    }
    if (d == 0) {
      assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
        s" importance: No splits in forest, but some non-zero importances.")
    }
    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
    Vectors.sparse(d, indices.toArray, values.toArray)
  }

computeFeatureImportance的实现如下

/**
   * Recursive method for computing feature importances for one tree.
   * This walks down the tree, adding to the importance of 1 feature at each node.
   * @param node  Current node in recursion
   * @param importances  Aggregate feature importances, modified by this method
   */
  private[impl] def computeFeatureImportance(
      node: Node,
      importances: OpenHashMap[Int, Double]): Unit = {
    node match {
      case n: InternalNode =>
        val feature = n.split.featureIndex
        val scaledGain = n.gain * n.impurityStats.count
        importances.changeValue(feature, scaledGain, _ + scaledGain)
        computeFeatureImportance(n.leftChild, importances)
        computeFeatureImportance(n.rightChild, importances)
      case n: LeafNode =>
        // do nothing
    }
  }

由于属性重要性是由gain概念扩展而来,这里以gain来说明如何计算属性重要性。

这里首先可以看出为什么每次树的调用都回到rootnode的调用,因为要递归的沿着树的层深往下游走,这里游走到叶节点什么也不做,其他分裂节点也就是代码里的InternalNode ,先找到该节点划分的属性索引,然后该节点增益乘以该节点数据量,然后更新属性重要性值,这样继续递归左节点,右节点,直到结束

然后回到featureImportances方法,val treeNorm = importances.map(_._2).sum是把刚才计算的每棵树的属性重要性求和,然后计算每个属性重要性占这棵树总重要性的比值,这样整棵树就搞完了,foreach走完,所有树的属性重要性就累加到totalImportances里了,然后normalizeMapValues(totalImportances)再按刚才的方法算一遍,这样出来的属性值和就为1了,有了属性个数和排好序的属性重要性值,装入向量,就是最终输出的结果

入口方法就这些了

现在我们还有run方法的1000多行,还有如何递归分配节点这两个点需要讲,后面会继续

RF的特征子集选取策略(spark ml)的更多相关文章

  1. 使用spark ml pipeline进行机器学习

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  2. spark ml 的例子

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  3. spark ml pipeline构建机器学习任务

    一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...

  4. Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用

    1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...

  5. Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...

  6. Spark ML机器学习

    Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...

  7. Spark ML 几种 归一化(规范化)方法总结

    规范化,有关之前都是用 python写的,  偶然要用scala 进行写, 看到这位大神写的, 那个网页也不错,那个连接图做的还蛮不错的,那天也将自己的博客弄一下那个插件. 本文来源 原文地址:htt ...

  8. Spark ML Pipeline简介

    Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...

  9. spark org.apache.spark.ml.linalg.DenseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector

    在使用 import org.apache.spark.ml.feature.VectorAssembler 转换特征后,想要放入 import org.apache.spark.mllib.clas ...

随机推荐

  1. 【LOJ】#6432. 「PKUSC2018」真实排名

    题解 简单分析一下,如果这个选手成绩是0,直接输出\(\binom{n}{k}\) 如果这个选手的成绩没有被翻倍,那么找到大于等于它的数(除了它自己)有a个,翻倍后不大于它的数有b个,那么就从这\(a ...

  2. hadoop2.6.4集群的搭建

    hadoop集群搭建(亲自操作成功步骤!值得信赖!) 1.1集群简介 hadoop的核心组件: HDFS(分布式文件系统) YARN(运算资源调度系统) MapReduce(分布式运算编程框架) HA ...

  3. 000 Ajax介紹

    1.介紹 2.应用 3.优点 4.缺点

  4. android弹出对话框

    我们在平时做开发的时候,免不了会用到各种各样的对话框,相信有过其他平台开发经验的朋友都会知道,大部分的平台都只提供了几个最简单的实现,如果我们想实现自己特定需求的对话框,大家可能首先会想到,通过继承等 ...

  5. jsonp的理解

    众所周知:在开发过程中,有时候需要客户端从服务器接收或向服务器发送一些数据:如果使用普通的ajax,则会遇到跨域访问无权限的问题. 要解决这个问题,就需要了解一下jsonp了: 1. ajax请求普通 ...

  6. android 启动 service 的两种方式,及什么时候用哪个 android 什么时候用bindService

    韩梦飞沙  韩亚飞  313134555@qq.com  yue31313  han_meng_fei_sha android  什么时候用bindService ============ 启动方式有 ...

  7. BZOJ 3253 Fence Repair 哈夫曼树 水题

    http://poj.org/problem?id=3253 这道题约等于合并果子,但是通过这道题能够看出来哈夫曼树是什么了. #include<cstdio> #include<c ...

  8. BZOJ.2882.工艺(后缀自动机 最小表示 map)

    题目链接 BZOJ 洛谷 SAM求字符串的最小循环表示. 因为从根节点出发可以得到所有子串,所以每次找字典序最小的一个出边走即可.因为长度问题把原串再拼接在后面一次. 需要用map存转移.复杂度O(n ...

  9. 【BZOJ-3681】Arietta 网络流 + 线段树合并

    3681: Arietta Time Limit: 20 Sec  Memory Limit: 64 MBSubmit: 182  Solved: 70[Submit][Status][Discuss ...

  10. VSCode换行符

    如果要显示换行符:\r\n 如果要替换显示出来的\n,替换上要用正则表达式,然后使用\r\n. 如果要直接换行,\n