Spark2.0机器学习系列之5:随机森林
概述
随机森林是决策树的组合算法,基础是决策树,关于决策树和Spark2.0中的代码设计可以参考本人另外一篇博客:
http://www.cnblogs.com/itboys/p/8312894.html
随机森林Spark中基于Pipeline和DataFrame的代码编写和决策树基本上是一样的,只需要将classifer换一下可以了,其它部分是一模一样的,因此本文不再对代码进行注释分析。
随机森林模型可以快速地被应用到几乎任何的数据科学问题中去,从而使人们能够高效快捷地获得第一组基准测试结果。在各种各样的问题中,随机森林一次又一次地展示出令人难以置信的强大,而与此同时它又是如此的方便实用。
随机森林是决策树模型的组合,是解决分类和回归问题最为成功的机器学习算法之一。组合多个决策树的目的是为了降低overfitting的风险。
随机森林同时也具备了决策树的诸多优点:
- 可以处理类别特征;
- 可以扩张到多分类问题;
- 不需要对特征进行标准化(归一化)处理;
- 能够检测到feature间的互相影响。
另外随机森林能够处理很高维度(feature很多)的数据,并且不用做特征选择,在训练完后,它能够给出哪些features比较重要。
Random forests are ensembles of decision trees. Random forests are one of the most successful machine learning models for classification and regression.
They combine many decision trees in order to reduce the risk of overfitting. Like decision trees, random forests handle categorical features,
extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions.
随机森林算法
(1)算法:因为随机森林算法是对多个决策树分开独立训练的,所以很容易设计成并行算法。
(2)随机森林是每一颗决策树预测结果的组合,因此随机森林算法产生的预测结果降低了预测的方差,提高了在测试数据上的表现。
Random forests train a set of decision trees separately, so the training can be done in parallel.
The algorithm injects randomness Combining the predictions from each tree reduces the variance of the predictions,
improving the performance on test data.
(3)随机森林训练过程中的随机性,包括两个方面,数据的随机性选取,以及待选特征的随机选取。
除了上述的这些随机性,对每一棵决策树的训练是完全相同的方法。
Training The randomness injected into the training process includes:
Subsampling the original dataset on each iteration to get a different training set (a.k.a. bootstrapping).
Considering different random subsets of features to split on at each tree node.
Apart from these randomizations, decision tree training is done in the same way as for individual decision trees.
(4)数据取样:假设我们设定训练集中的样本个数为N,然后通过有重置的重复多次抽样来获得这N个样本,这样的抽样结果将作为我们生成决策树的训练集;
(5)待选特征随机选取:如果有M个输入变量,每个节点都将随机选择m(m小于M)个特定的变量,然后运用这m个变量来确定最佳的分裂点。在决策树的生成过程中,m的值是保持不变的;
(6)预测:通过对所有的决策树进行加总来预测新的数据(在分类时采用多数投票,在回归时采用平均)。
算法缺点:
http://www.cnblogs.com/emanlee/p/4851555.html一文中说到了随机森林算法的优缺点,分析的很详细。他列举的缺点是:
(1)随机森林在解决回归问题时并没有像它在分类中表现的那么好,这是因为它并不能给出一个连续型的输出。当进行回归时,随机森林不能够作出超越训练集数据范围的预测,这可能导致在对某些还有特定噪声的数据进行建模时出现过度拟合。
(2)对于许多统计建模者来说,随机森林给人的感觉像是一个黑盒子——你几乎无法控制模型内部的运行,只能在不同的参数和随机种子之间进行尝试。
Spark2.0中完整的随机森林代码:
关键参数
最重要的,常常需要调试以提高算法效果的有两个参数:numTrees,maxDepth。
- numTrees(决策树的个数):增加决策树的个数会降低预测结果的方差,这样在测试时会有更高的accuracy。训练时间大致与numTrees呈线性增长关系。
- maxDepth:是指森林中每一棵决策树最大可能depth,在决策树中提到了这个参数。更深的一棵树意味模型预测更有力,但同时训练时间更长,也更倾向于过拟合。但是值得注意的是,随机森林算法和单一决策树算法对这个参数的要求是不一样的。随机森林由于是多个决策树预测结果的投票或平均而降低预测结果的方差,因此相对于单一决策树而言,不容易出现过拟合的情况。所以随机森林可以选择比决策树模型中更大的maxDepth。
甚至有的文献说,随机森林的每棵决策树都最大可能地进行生长而不进行剪枝。但是不管怎样,还是建议对maxDepth参数进行一定的实验,看看是否可以提高预测的效果。
另外还有两个参数,subsamplingRate,featureSubsetStrategy一般不需要调试,但是这两个参数也可以重新设置以加快训练,但是值得注意的是可能会影响模型的预测效果(如果需要调试的仔细读下面英文吧)。
We include a few guidelines for using random forests by discussing the various parameters. We omit some decision tree parameters since those are covered in the decision tree guide.
The first two parameters we mention are the most important, and tuning them can often improve performance:
()numTrees: Number of trees in the forest.
Increasing the number of trees will decrease the variance in predictions, improving the model’s test-time accuracy.
Training time increases roughly linearly in the number of trees.
()maxDepth: Maximum depth of each tree in the forest.
Increasing the depth makes the model more expressive and powerful. However, deep trees take longer to train and are also more prone to overfitting.
In general, it is acceptable to train deeper trees when using random forests than when using a single decision tree. One tree is more likely to overfit than a random forest (because of the variance reduction from averaging multiple trees in the forest).
The next two parameters generally do not require tuning. However, they can be tuned to speed up training.
()subsamplingRate: This parameter specifies the size of the dataset used for training each tree in the forest, as a fraction of the size of the original dataset. The default (1.0) is recommended, but decreasing this fraction can speed up training.
()featureSubsetStrategy: Number of features to use as candidates for splitting at each tree node. The number is specified as a fraction or function of the total number of features. Decreasing this number will speed up training, but can sometimes impact performance if too low.
We include a few guidelines for using random forests by discussing the various parameters. We omit some decision tree parameters since those are covered in the decision tree guide.
完整代码
package my.spark.ml.practice.classification; import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession; public class myRandomForest { public static void main(String[] args) {
SparkSession spark=SparkSession
.builder()
.appName("CoFilter")
.master("local[4]")
.config("spark.sql.warehouse.dir",
"file///:G:/Projects/Java/Spark/spark-warehouse" )
.getOrCreate(); String path="C:/Users/user/Desktop/ml_dataset/classify/horseColicTraining2libsvm.txt";
String path2="C:/Users/user/Desktop/ml_dataset/classify/horseColicTest2libsvm.txt";
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); Dataset<Row> training=spark.read().format("libsvm").load(path);
Dataset<Row> test=spark.read().format("libsvm").load(path2);
//采用的数据集是《机器学习实战》一书中所用到一个比较难分数据集:
//从疝气病症预测马的死亡率,加载前用Python将格式转换为
//libsvm格式(比较简单的一种Spark SQL DataFrame输入格式)
//这种格式导入的数据,label和features自动分好了,不需要再做任何转换了。 StringIndexerModel indexerModel=new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(training);
VectorIndexerModel vectorIndexerModel=new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.fit(training);
IndexToString converter=new IndexToString()
.setInputCol("prediction")
.setOutputCol("convertedPrediction")
.setLabels(indexerModel.labels());
//设计了一个简单的循环,对关键参数(决策树的个数)进行分析调试
for (int numOfTrees = ; numOfTrees < ; numOfTrees+=) {
RandomForestClassifier rfclassifer=new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setNumTrees(numOfTrees);
PipelineModel pipeline=new Pipeline().setStages
(new PipelineStage[]
{indexerModel,vectorIndexerModel,rfclassifer,converter})
.fit(training); Dataset<Row> predictDataFrame=pipeline.transform(test); double accuracy=new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy").evaluate(predictDataFrame); System.out.println("numOfTrees "+numOfTrees+" accuracy "+accuracy);
//RandomForestClassificationModel rfmodel=
//(RandomForestClassificationModel) pipeline.stages()[2];
//System.out.println(rfmodel.toDebugString());
}//numOfTree Cycle
}
}
/**
对两个关键参数maxDepth(1-4)和NumTrees(100-1000)组合进行分析:
maxDepth 1 numOfTrees 100 accuracy 0.761
...
maxDepth 1 numOfTrees 500 accuracy 0.791
maxDepth 1 numOfTrees 600 accuracy 0.820
maxDepth 1 numOfTrees 700 accuracy 0.791
...
maxDepth 2 numOfTrees 100 accuracy 0.776
maxDepth 2 numOfTrees 200 accuracy 0.820//最高
maxDepth 2 numOfTrees 300 accuracy 0.805
...
maxDepth 2 numOfTrees 1000 accuracy 0.805
maxDepth 3 numOfTrees 100 accuracy 0.791
...
maxDepth 3 numOfTrees 600 accuracy 0.805
maxDepth 3 numOfTrees 700 accuracy 0.791
maxDepth 3 numOfTrees 800 accuracy 0.820//最高
maxDepth 3 numOfTrees 900 accuracy 0.791
...
与一棵决策树比较
在这个数据集上,单一决策树预测最高0.746,随机森林在这个数据集上效果还是很明显的
(预测accuracy可以稳定在80%左右,最高达到82%)。
python格式转换代码:转换为libsvm格式
fr=open("C:\\Users\\user\\Desktop\\ml_dataset\\classify\\horseColicTest2.txt");
fr2=open("C:\\Users\\user\\Desktop\\ml_dataset\\classify\\horseColicTest2libsvm.txt",'w+');
for line in fr.readlines():
line=line.strip().split("\t")
features=line[:-]
label=(line[-])
fr2.write(label+" ")
for k in range(len(features)):
fr2.write(str(k+))
fr2.write(":")
fr2.write(features[k]+" ")
fr2.write("\n")
fr.close()
fr2.close()
参考:
(1)Spark2.0 文档
(2)http://www.cnblogs.com/emanlee/p/4851555.html
Spark2.0机器学习系列之5:随机森林的更多相关文章
- Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析
概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树. GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...
- Spark2.0机器学习系列之10: 聚类(高斯混合模型 GMM)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之9: 聚类(k-means,Bisecting k-means,Streaming k-means)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之3:决策树
概述 分类决策树模型是一种描述对实例进行分类的树形结构. 决策树可以看为一个if-then规则集合,具有“互斥完备”性质 .决策树基本上都是 采用的是贪心(即非回溯)的算法,自顶向下递归分治构造. 生 ...
- Spark2.0机器学习系列之1: 聚类算法(LDA)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之11: 聚类(幂迭代聚类, power iteration clustering, PIC)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet all ...
- Spark2.0机器学习系列之7: MLPC(多层神经网络)
Spark2.0 MLPC(多层神经网络分类器)算法概述 MultilayerPerceptronClassifier(MLPC)这是一个基于前馈神经网络的分类器,它是一种在输入层与输出层之间含有一层 ...
- Spark2.0机器学习系列之12: 线性回归及L1、L2正则化区别与稀疏解
概述 线性回归拟合一个因变量与一个自变量之间的线性关系y=f(x). Spark中实现了: (1)普通最小二乘法 (2)岭回归(L2正规化) (3)La ...
- Spark2.0机器学习系列之4:Logistic回归及Binary分类(二分问题)结果评估
参数设置 α: 梯度上升算法迭代时候权重更新公式中包含 α : http://blog.csdn.net/lu597203933/article/details/38468303 为了更好理解 α和 ...
随机推荐
- Golang 中操作 Mongo Update 的方法
Golang 和 MongoDB 中的 ISODate 时间交互问题 2018年02月27日 11:28:43 独一无二的小个性 阅读数:357 标签: GolangMongoDB时间交互时间转换 更 ...
- 关于json动态拼接响应数据
在EasyUI http://www.jeasyui.com/demo/main/get_users.php 响应数据如下格式: { "total": "11" ...
- 【cb2】安装终端
虽然xterm轻量,但用起来不爽. sudo apt-get install terminator 其它安装 sudo apt-get install spyder sudo apt-get inst ...
- UDP传输原理及数据分片——学习笔记
TCP传输可靠性是:TCP协议里自己做了设计来保证可靠性. IP报文本身是不可靠的 UDP也是 TCP做了很多复杂的协议设计,来保证可靠性. TCP 面向连接,三次握手,四次挥手 拥塞机制 重传机制 ...
- python3----练习......
# 上行遍历 soup = BeautifulSoup(demo, 'html.parser') for parent in soup.a.parents: if parent is None: pr ...
- lumen 常用辅助函数
optional 函数接收任意参数并允许你访问对象上的属性或调用其方法.如果给定的对象为空,属性或方法调用返回 null return optional($user->address)-> ...
- CodeForces 558C Amr and Chemistry (位运算,数论,规律,枚举)
Codeforces 558C 题意:给n个数字,对每一个数字能够进行两种操作:num*2与num/2(向下取整),求:让n个数相等最少须要操作多少次. 分析: 计算每一个数的二进制公共前缀. 枚举法 ...
- Hibernate_day04--QBC查询
QBC查询 1 使用hql查询需要写hql语句实现,但是使用qbc时候,不需要写语句了,使用方法实现 2 使用qbc时候,操作实体类和属性 3 使用qbc,使用Criteria对象实现 查询所有 1 ...
- Gson 解析JSON数据
{"province":[{"name":"安徽省","city":[{"name":"安 ...
- 多线程sshd爆破程序代码
不多说了,直接上手代码,也没有啥练手的,都是很熟悉的代码,水一篇,方便作为工作的小工具吧.试了一下,配合一个好点的字典,还是可以作为个人小工具使用的. #!/usr/bin/env python # ...