支持连续变量和类别变量,类别变量就是某个属性有三个值,a,b,c,需要用Feature Transformers中的vectorindexer处理

上来是一堆参数

setMaxDepth:最大树深度

setMaxBins:最大装箱数,为了近似统计变量,比如变量有100个值,我只分成10段去做统计

setMinInstancesPerNode:每个节点最少实例

setMinInfoGain:最小信息增益

setMaxMemoryInMB:最大内存MB单位,这个值越大,一次处理的节点划分就越多

setCacheNodeIds:是否缓存节点id,缓存可以加速深层树的训练

setCheckpointInterval:检查点间隔,就是多少次迭代固化一次

setImpurity:随机森林有三种方式,entropy,gini,variance,回归肯定就是variance

setSubsamplingRate:设置采样率,就是每次选多少比例的样本构成新树

setSeed:采样种子,种子不变,采样结果不变

setNumTrees:设置森林里有多少棵树

setFeatureSubsetStrategy:设置特征子集选取策略,随机森林就是两个随机,构成树的样本随机,每棵树开始分裂的属性是随机的,其他跟决策树区别不大,注释这么写的

* The number of features to consider for splits at each tree node.
   * Supported options:
   *  - "auto": Choose automatically for task://默认策略
   *            If numTrees == 1, set to "all."     //决策树选择所有属性
   *            If numTrees > 1 (forest), set to "sqrt" for classification and //决策森林 分类选择属性数开平方,回归选择三分之一属性
   *              to "onethird" for regression.
   *  - "all": use all features
   *  - "onethird": use 1/3 of the features
   *  - "sqrt": use sqrt(number of features)
   *  - "log2": use log2(number of features) //还有取对数的
   * (default = "auto") 
   *
   * These various settings are based on the following references:
   *  - log2: tested in Breiman (2001)
   *  - sqrt: recommended by Breiman manual for random forests
   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
   *    package.

参数完毕,下面比较重要的是这段代码

val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

这个地比较蛋疼的是dataset.schema($(featuresCol))

/** An alias for [[getOrDefault()]]. */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

这段代码说明了$(featuresCol))只是求出一个字段名,实战中直接data.schema("features") ,data.schema("features")出来的是StructField,

case classStructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extendsProduct with Serializable

StructField包含四个内容,最好知道一下,机器学习代码很多都用

回头说下getCategoricalFeatures,这个方法是识别一个属性是二值变量还是名义变量,例如a,b就是二值变量,a,b,c就是名义变量,最终把属性索引和变量值的个数放到一个map

这个函数的功能和vectorindexer类似,但是一般都用vectorindexer,因为实战中我们大都从sql读数据,sql读出来的数据metadata是空,无法识别二值变量还是名义变量

后面是

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    val strategy =
      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    val trees =
      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
        .map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = oldDataset.first().features.size

new RandomForestRegressionModel(trees, numFeatures)

可以看出还是调的RDD的旧方法,run这个方法是核心有1000多行,后面会详细跟踪,最后返回的是RandomForestRegressionModel,里面有Array[DecisionTreeRegressionModel] ,就是生成的一组决策树模型,也就是决策森林,另外一个是属性数,我们继续看RandomForestRegressionModel

在1.6版本每棵树的权重都是1,里面还有这么一个方法

override protected def transformImpl(dataset: DataFrame): DataFrame = {
    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

可以看到把模型通过广播的形式传给exectors,搞一个udf预测函数,最后通过withColumn把预测数据粘到原数据后面,

注意这个写法dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) ,第一个参数是列名,第二个是计算出来的col,col是列类型,预测方法如下

override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
    // Predict average of tree predictions.
    // Ignore the weights since all are 1.0 for now.
    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
  }

可见预测用的是每个树的跟节点,predictImpl(features)返回这个根节点分配的叶节点,这是一个递归调用的过程,关于如何递归,后面也会拿出来细说,最后再用.prediction方法把所有树预测的结果相加求平均

后面有一个计算各属性重要性的方法

lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)

实现如下

private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
    val totalImportances = new OpenHashMap[Int, Double]()
    trees.foreach { tree =>
      // Aggregate feature importance vector for this tree      先计算每棵树的属性重要性值
      val importances = new OpenHashMap[Int, Double]()
      computeFeatureImportance(tree.rootNode, importances)
      // Normalize importance vector for this tree, and add it to total.
      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
      val treeNorm = importances.map(_._2).sum
      if (treeNorm != 0) {
        importances.foreach { case (idx, impt) =>
          val normImpt = impt / treeNorm
          totalImportances.changeValue(idx, normImpt, _ + normImpt)
        }
      }
    }
    // Normalize importances
    normalizeMapValues(totalImportances)
    // Construct vector
    val d = if (numFeatures != -1) {
      numFeatures
    } else {
      // Find max feature index used in trees
      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
      maxFeatureIndex + 1
    }
    if (d == 0) {
      assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
        s" importance: No splits in forest, but some non-zero importances.")
    }
    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
    Vectors.sparse(d, indices.toArray, values.toArray)
  }

computeFeatureImportance的实现如下

/**
   * Recursive method for computing feature importances for one tree.
   * This walks down the tree, adding to the importance of 1 feature at each node.
   * @param node  Current node in recursion
   * @param importances  Aggregate feature importances, modified by this method
   */
  private[impl] def computeFeatureImportance(
      node: Node,
      importances: OpenHashMap[Int, Double]): Unit = {
    node match {
      case n: InternalNode =>
        val feature = n.split.featureIndex
        val scaledGain = n.gain * n.impurityStats.count
        importances.changeValue(feature, scaledGain, _ + scaledGain)
        computeFeatureImportance(n.leftChild, importances)
        computeFeatureImportance(n.rightChild, importances)
      case n: LeafNode =>
        // do nothing
    }
  }

由于属性重要性是由gain概念扩展而来,这里以gain来说明如何计算属性重要性。

这里首先可以看出为什么每次树的调用都回到rootnode的调用,因为要递归的沿着树的层深往下游走,这里游走到叶节点什么也不做,其他分裂节点也就是代码里的InternalNode ,先找到该节点划分的属性索引,然后该节点增益乘以该节点数据量,然后更新属性重要性值,这样继续递归左节点,右节点,直到结束

然后回到featureImportances方法,val treeNorm = importances.map(_._2).sum是把刚才计算的每棵树的属性重要性求和,然后计算每个属性重要性占这棵树总重要性的比值,这样整棵树就搞完了,foreach走完,所有树的属性重要性就累加到totalImportances里了,然后normalizeMapValues(totalImportances)再按刚才的方法算一遍,这样出来的属性值和就为1了,有了属性个数和排好序的属性重要性值,装入向量,就是最终输出的结果

入口方法就这些了

现在我们还有run方法的1000多行,还有如何递归分配节点这两个点需要讲,后面会继续

RF的特征子集选取策略(spark ml)的更多相关文章

  1. 使用spark ml pipeline进行机器学习

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  2. spark ml 的例子

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  3. spark ml pipeline构建机器学习任务

    一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...

  4. Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用

    1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...

  5. Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...

  6. Spark ML机器学习

    Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...

  7. Spark ML 几种 归一化(规范化)方法总结

    规范化,有关之前都是用 python写的,  偶然要用scala 进行写, 看到这位大神写的, 那个网页也不错,那个连接图做的还蛮不错的,那天也将自己的博客弄一下那个插件. 本文来源 原文地址:htt ...

  8. Spark ML Pipeline简介

    Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...

  9. spark org.apache.spark.ml.linalg.DenseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector

    在使用 import org.apache.spark.ml.feature.VectorAssembler 转换特征后,想要放入 import org.apache.spark.mllib.clas ...

随机推荐

  1. streaming优化:spark.default.parallelism调整处理并行度

    官方是这么说的: Cluster resources can be under-utilized if the number of parallel tasks used in any stage o ...

  2. [过程记录]Centos7 下 Hadoop分布式集群搭建

    过程如下: 配置hosts vim /etc/hosts 格式: ip hostname ip hostname 设置免密登陆 首先:每台主机使用ssh命令连接其余主机 ssh 用户名@主机名 提示是 ...

  3. java线程中断的办法

    目录 中断线程相关的方法 中断线程 for循环标记退出 阻塞的退出线程 使用stop()方法停止线程 中断线程相关的方法 中断线程有一些相应的方法,这里列出来一下. 注意,如果是Thread.meth ...

  4. 洛谷P3576 [POI2014]MRO-Ant colony [二分答案,树形DP]

    题目传送门 MRO-Ant colony 题目描述 The ants are scavenging an abandoned ant hill in search of food. The ant h ...

  5. 001.WordPress建站部署

    一 WordPress简介 WordPress是一种使用PHP语言开发的博客平台,用户可以在支持PHP和MySQL数据库的服务器上架设属于自己的网站.也可以把 WordPress当作一个内容管理系统( ...

  6. JAVAEE——SpringBoot入门:简介、微服务、环境准备、helloworld与探究、快速构建项目

    一.Spring Boot 入门 1.Spring Boot 简介 简化Spring应用开发的一个框架: 整个Spring技术栈的一个大整合: J2EE开发的一站式解决方案: 2.微服务 2014,m ...

  7. [ 转载 ] Android JNI(一)——NDK与JNI基础

    Android JNI(一)——NDK与JNI基础 隔壁老李头 关注  4.4 2018.05.09 17:15* 字数 5481 阅读 11468评论 8喜欢 140 本系列文章如下: Androi ...

  8. 试图(View)

    试图是通过命名约定与动作方法想关联的.这个动作方法称为Index,控制器名称为Home; 添加试图,试图名与该试图相关联的动作方法的名称一致.

  9. 快速沃尔什变换与k进制FWT

    这是一篇用来卖萌的文章QAQ 考虑以下三类卷积 \(C_k = \sum \limits_{i \;or\;j = k} A_i * B_j\) \(C_k = \sum \limits_{i\;an ...

  10. Codeforces.810D.Glad to see you!(交互 二分)

    题目链接 \(Description\) 有一个大小为\(k\)的集合\(S\),元素两两不同且在\([1,n]\)内.你可以询问不超过\(60\)次,每次询问你给出\(x,y\),交互库会返回\(\ ...