RF的特征子集选取策略(spark ml)
支持连续变量和类别变量,类别变量就是某个属性有三个值,a,b,c,需要用Feature Transformers中的vectorindexer处理
上来是一堆参数
setMaxDepth:最大树深度
setMaxBins:最大装箱数,为了近似统计变量,比如变量有100个值,我只分成10段去做统计
setMinInstancesPerNode:每个节点最少实例
setMinInfoGain:最小信息增益
setMaxMemoryInMB:最大内存MB单位,这个值越大,一次处理的节点划分就越多
setCacheNodeIds:是否缓存节点id,缓存可以加速深层树的训练
setCheckpointInterval:检查点间隔,就是多少次迭代固化一次
setImpurity:随机森林有三种方式,entropy,gini,variance,回归肯定就是variance
setSubsamplingRate:设置采样率,就是每次选多少比例的样本构成新树
setSeed:采样种子,种子不变,采样结果不变
setNumTrees:设置森林里有多少棵树
setFeatureSubsetStrategy:设置特征子集选取策略,随机森林就是两个随机,构成树的样本随机,每棵树开始分裂的属性是随机的,其他跟决策树区别不大,注释这么写的
* The number of features to consider for splits at each tree node.
* Supported options:
* - "auto": Choose automatically for task://默认策略
* If numTrees == 1, set to "all." //决策树选择所有属性
* If numTrees > 1 (forest), set to "sqrt" for classification and //决策森林 分类选择属性数开平方,回归选择三分之一属性
* to "onethird" for regression.
* - "all": use all features
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features) //还有取对数的
* (default = "auto")
*
* These various settings are based on the following references:
* - log2: tested in Breiman (2001)
* - sqrt: recommended by Breiman manual for random forests
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
* package.
参数完毕,下面比较重要的是这段代码
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
这个地比较蛋疼的是dataset.schema($(featuresCol))
/** An alias for [[getOrDefault()]]. */
protected final def $[T](param: Param[T]): T = getOrDefault(param)
这段代码说明了$(featuresCol))只是求出一个字段名,实战中直接data.schema("features") ,data.schema("features")出来的是StructField,
case classStructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extendsProduct with Serializable
StructField包含四个内容,最好知道一下,机器学习代码很多都用
回头说下getCategoricalFeatures,这个方法是识别一个属性是二值变量还是名义变量,例如a,b就是二值变量,a,b,c就是名义变量,最终把属性索引和变量值的个数放到一个map
这个函数的功能和vectorindexer类似,但是一般都用vectorindexer,因为实战中我们大都从sql读数据,sql读出来的数据metadata是空,无法识别二值变量还是名义变量
后面是
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = oldDataset.first().features.size
new RandomForestRegressionModel(trees, numFeatures)
可以看出还是调的RDD的旧方法,run这个方法是核心有1000多行,后面会详细跟踪,最后返回的是RandomForestRegressionModel,里面有Array[DecisionTreeRegressionModel] ,就是生成的一组决策树模型,也就是决策森林,另外一个是属性数,我们继续看RandomForestRegressionModel
在1.6版本每棵树的权重都是1,里面还有这么一个方法
override protected def transformImpl(dataset: DataFrame): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
可以看到把模型通过广播的形式传给exectors,搞一个udf预测函数,最后通过withColumn把预测数据粘到原数据后面,
注意这个写法dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) ,第一个参数是列名,第二个是计算出来的col,col是列类型,预测方法如下
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}
可见预测用的是每个树的跟节点,predictImpl(features)返回这个根节点分配的叶节点,这是一个递归调用的过程,关于如何递归,后面也会拿出来细说,最后再用.prediction方法把所有树预测的结果相加求平均
后面有一个计算各属性重要性的方法
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
实现如下
private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
val totalImportances = new OpenHashMap[Int, Double]()
trees.foreach { tree =>
// Aggregate feature importance vector for this tree 先计算每棵树的属性重要性值
val importances = new OpenHashMap[Int, Double]()
computeFeatureImportance(tree.rootNode, importances)
// Normalize importance vector for this tree, and add it to total.
// TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
val treeNorm = importances.map(_._2).sum
if (treeNorm != 0) {
importances.foreach { case (idx, impt) =>
val normImpt = impt / treeNorm
totalImportances.changeValue(idx, normImpt, _ + normImpt)
}
}
}
// Normalize importances
normalizeMapValues(totalImportances)
// Construct vector
val d = if (numFeatures != -1) {
numFeatures
} else {
// Find max feature index used in trees
val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
maxFeatureIndex + 1
}
if (d == 0) {
assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
s" importance: No splits in forest, but some non-zero importances.")
}
val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
Vectors.sparse(d, indices.toArray, values.toArray)
}
computeFeatureImportance的实现如下
/**
* Recursive method for computing feature importances for one tree.
* This walks down the tree, adding to the importance of 1 feature at each node.
* @param node Current node in recursion
* @param importances Aggregate feature importances, modified by this method
*/
private[impl] def computeFeatureImportance(
node: Node,
importances: OpenHashMap[Int, Double]): Unit = {
node match {
case n: InternalNode =>
val feature = n.split.featureIndex
val scaledGain = n.gain * n.impurityStats.count
importances.changeValue(feature, scaledGain, _ + scaledGain)
computeFeatureImportance(n.leftChild, importances)
computeFeatureImportance(n.rightChild, importances)
case n: LeafNode =>
// do nothing
}
}
由于属性重要性是由gain概念扩展而来,这里以gain来说明如何计算属性重要性。
这里首先可以看出为什么每次树的调用都回到rootnode的调用,因为要递归的沿着树的层深往下游走,这里游走到叶节点什么也不做,其他分裂节点也就是代码里的InternalNode ,先找到该节点划分的属性索引,然后该节点增益乘以该节点数据量,然后更新属性重要性值,这样继续递归左节点,右节点,直到结束
然后回到featureImportances方法,val treeNorm = importances.map(_._2).sum是把刚才计算的每棵树的属性重要性求和,然后计算每个属性重要性占这棵树总重要性的比值,这样整棵树就搞完了,foreach走完,所有树的属性重要性就累加到totalImportances里了,然后normalizeMapValues(totalImportances)再按刚才的方法算一遍,这样出来的属性值和就为1了,有了属性个数和排好序的属性重要性值,装入向量,就是最终输出的结果
入口方法就这些了
现在我们还有run方法的1000多行,还有如何递归分配节点这两个点需要讲,后面会继续
RF的特征子集选取策略(spark ml)的更多相关文章
- 使用spark ml pipeline进行机器学习
一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...
- spark ml 的例子
一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...
- spark ml pipeline构建机器学习任务
一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...
- Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用
1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...
- Spark ML源码分析之四 树
之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...
- Spark ML机器学习
Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...
- Spark ML 几种 归一化(规范化)方法总结
规范化,有关之前都是用 python写的, 偶然要用scala 进行写, 看到这位大神写的, 那个网页也不错,那个连接图做的还蛮不错的,那天也将自己的博客弄一下那个插件. 本文来源 原文地址:htt ...
- Spark ML Pipeline简介
Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...
- spark org.apache.spark.ml.linalg.DenseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector
在使用 import org.apache.spark.ml.feature.VectorAssembler 转换特征后,想要放入 import org.apache.spark.mllib.clas ...
随机推荐
- Windows下CRF++进行中文人名识别的初次尝试
语料来自1998年1月份人民日报语料 1 语料处理 1.1 原始语料数据格式 语料中,句子已经被分词好,并且在人名后以“/”标注了“nr”表示是人名,其他非人名的分词没有进行标注 1.2 CRF++要 ...
- POJ - 2115C Looooops 扩展欧几里得(做的少了无法一眼看出)
题目大意&&分析: for (variable = A; variable != B; variable += C) statement;这个循环式子表示a+c*n(n为整数)==b是 ...
- HTML5 LocalStorage 本地存储(转)
原文:http://www.cnblogs.com/xiaowei0705/archive/2011/04/19/2021372.html HTML5 LocalStorage 本地存储 说到本地存储 ...
- 新手:Qt之QLabel类的应用
在Qt中,我们不可避免的会用到QLabel类.而Qlabel的强大功能作为程序员的你有多少了解? 下面,跟着我一起在来学习一下吧! 1.添加文本 Qlabel类添加文本有两种方式,一种是直接在实现时添 ...
- WIN10下 VS2017+OpenCv 3.4.1 配置
写篇博客来记录一下opencv在VS中的配置. 一.下载OpenCv安装包 下载的途径有三种: 1.官网下载 但是官网下载真的是贼头大,首先下载好好的突然说下载中断,而且无法恢复,此外,还慢,毕竟外网 ...
- 使用Plant Simulation连接SQL Server
1. 在管理类库中添加ODBC. 2. 在控制面板->管理工具中设置ODBC,添加SQL Server服务. 3. 在plant simulation中将信息流中的ODBC添加到Frame中. ...
- zoj 3460 二分+二分图匹配
不错的思想 /* 大致题意: 用n个导弹发射塔攻击m个目标.每个发射架在某个时刻只能为 一颗导弹服务,发射一颗导弹需要准备t1的时间,一颗导弹从发 射到击中目标的时间与目标到发射架的距离有关.每颗导弹 ...
- 钻牛角尖还是走进死胡同--shell脚本根据名称获得 dubbo 服务的 pid
到了下午,突然觉得坐立不安,可能是因为中午没有休息好.老大不小了还在做页面整合的事情,这是参加工作时就干的工作了.然后突然想去挑战高级一点的缺陷排查,结果一不小心就钻了一个牛角尖.启动 dubbo 服 ...
- FireDAC 下的 Sqlite [5] - 数据的插入、更新、删除
先在空白窗体上添加: TFDConnection.TFDPhysSQLiteDriverLink.TFDGUIxWaitCursor.TFDQuery.TDataSource.TDBGrid(并在设计 ...
- loading加载和layer.js
layer.js中的loading加载 l本篇主要介绍layerjs中的loading加载在实际项目中的应用 1.使用的技术 前端:HTML5+CSS3+JS+layer.js 后端:.net 2.遇 ...