Spark中的CrossValidation

  • Spark中采用是k折交叉验证 (k-fold cross validation)。举个例子,例如10折交叉验证(10-fold cross validation),将数据集分成10份,轮流将其中9份做训练1份做验证,10次的结果的均值作为对算法精度的估计。
  • 10折交叉检验最常见,是因为通过利用大量数据集、使用不同学习技术进行的大量试验,表明10折是获得最好误差估计的恰当选择,而且也有一些理论根据可以证明这一点。但这并非最终结论,争议仍然存在。而且似乎5折或者20折与10折所得出的结果也相差无几。
  • 交叉检验常用于分析模型的泛化能力,提高模型的稳定。相对于手工探索式的参数调试,交叉验证更具备统计学上的意义。
  • 在Spark中,Cross Validation和ParamMap(“参数组合”的Map)结合使用。具体做法是,针对特定的Param组合,CrossValidator计算K (K 折交叉验证)个评估分数的平均值。然后和其它“参数组合”CrossValidator计算结果比较,完成所有的比较后,将最优的“参数组合”挑选出来,这“最优的一组参数”将用在整个训练数据集上重新训练(re-fit),得到最终的Model。
  • 也就是说,通过交叉验证,找到了最佳的”参数组合“,利用这组参数,在整个训练集上可以训练(fit)出一个泛化能力强,误差相对最小的的最佳模型。
  • 很显然,交叉验证计算代价很高,假设有三个参数:参数alpha有3中选择,参数beta有4种选择,参数gamma有4中选择,进行10折计算,那么将进行(3×4×4)×10=480次模型训练。

Spark documnets 原文: 
(1)CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs. 
(2)After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset. 
(3)Using CrossValidator to select from a grid of parameters.Note that cross-validation over a grid of parameters is expensive. E.g., in the example below, the parameter grid has 3 values for hashingTF.numFeatures and 2 values for lr.regParam, and CrossValidator uses 2 folds. This multiplies out to (3×2)×2=12different models being trained. In realistic settings, it can be common to try many more parameters and use more folds (k=3 and k=10 are common). In other words, using CrossValidator can be very expensive. However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.

计算流程

//Spark Version 2.0
package my.spark.ml.practice; import java.io.IOException; import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession; /**ALS算法协同过滤推荐算法
* 使用Spark 2.0 基于Pipeline,ParamMap,CrossValidation
* 对超参数进行调优,并进行模型选择
*/ public class MyCrossValidation {
public static void main(String[] args) throws IOException{
SparkSession spark=SparkSession
.builder()
.appName("myCrossValidation")
.master("local[4]")
.getOrCreate();
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
//加载数据
JavaRDD<Rating> ratingsRDD = spark
.read().textFile("/home/hadoop/spark/spark-2.0.0-bin-hadoop2.6" +
"/data/mllib/als/sample_movielens_ratings.txt").javaRDD()
.map(new Function<String, Rating>() {
public Rating call(String str) {
return Rating.parseRating(str);
}
});
//将整个数据集划分为训练集和测试集
//注意training集将用于Cross Validation,而test集将用于最终模型的评估
//在traning集中,在Croos Validation时将进一步划分为K份,每次留一份作为
//Validation,注意区分:ratings.randomSplit()分出的Test集和K 折留
//下验证的那一份完全不是一个概念,也起着完全不同的作用,一定不要相混淆
Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1]; // Build the recommendation model using ALS on the training data
ALS als=new ALS()
.setMaxIter(8)
.setRank(20).setRegParam(0.8)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating")
.setPredictionCol("predict_rating");
/*
* (1)秩Rank:模型中隐含因子的个数:低阶近似矩阵中隐含特在个数,因子一般多一点比较好,
* 但是会增大内存的开销。因此常在训练效果和系统开销之间进行权衡,通常取值在10-200之间。
* (2)最大迭代次数:运行时的迭代次数,ALS可以做到每次迭代都可以降低评级矩阵的重建误差,
* 一般少数次迭代便能收敛到一个比较合理的好模型。
* 大部分情况下没有必要进行太对多次迭代(10次左右一般就挺好了)
* (3)正则化参数regParam:和其他机器学习算法一样,控制模型的过拟合情况。
* 该值与数据大小,特征,系数程度有关。此参数正是交叉验证需要验证的参数之一。
*/
// Configure an ML pipeline, which consists of one stage
//一般会包含多个stages
Pipeline pipeline=new Pipeline().
setStages(new PipelineStage[] {als});
// We use a ParamGridBuilder to construct a grid of parameters to search over.
ParamMap[] paramGrid=new ParamGridBuilder()
.addGrid(als.rank(),new int[]{5,10,20})
.addGrid(als.regParam(),new double[]{0.05,0.10,0.15,0.20,0.40,0.80})
.build(); // CrossValidator 需要一个Estimator,一组Estimator ParamMaps, 和一个Evaluator.
// (1)Pipeline作为Estimator;
// (2)定义一个RegressionEvaluator作为Evaluator,并将评估标准设置为“rmse”均方根误差
// (3)设置ParamMap
// (4)设置numFolds CrossValidator cv=new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new RegressionEvaluator()
.setLabelCol("rating")
.setPredictionCol("predict_rating")
.setMetricName("rmse"))
.setEstimatorParamMaps(paramGrid)
.setNumFolds(5); // 运行交叉检验,自动选择最佳的参数组合
CrossValidatorModel cvModel=cv.fit(training);
//保存模型
cvModel.save("/home/hadoop/spark/cvModel_als.modle"); //System.out.println("numFolds: "+cvModel.getNumFolds());
//Test数据集上结果评估
Dataset<Row> predictions=cvModel.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")//RMS Error
.setLabelCol("rating")
.setPredictionCol("predict_rating");
Double rmse = evaluator.evaluate(predictions);
System.out.println("RMSE @ test dataset " + rmse);
//Output: RMSE @ test dataset 0.943644792277118
}
}

备注:程序运行需要定义Rating Class 在下面链接里可以找到: http://spark.apache.org/docs/latest/ml-collaborative-filtering.html

Spark2.0机器学习系列之2:基于Pipeline、交叉验证、ParamMap的模型选择和超参数调优的更多相关文章

  1. Spark2.0机器学习系列之3:决策树

    概述 分类决策树模型是一种描述对实例进行分类的树形结构. 决策树可以看为一个if-then规则集合,具有“互斥完备”性质 .决策树基本上都是 采用的是贪心(即非回溯)的算法,自顶向下递归分治构造. 生 ...

  2. Spark2.0机器学习系列之12: 线性回归及L1、L2正则化区别与稀疏解

    概述 线性回归拟合一个因变量与一个自变量之间的线性关系y=f(x).       Spark中实现了:       (1)普通最小二乘法       (2)岭回归(L2正规化)       (3)La ...

  3. Spark2.0机器学习系列之1: 聚类算法(LDA)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  4. Spark2.0机器学习系列之11: 聚类(幂迭代聚类, power iteration clustering, PIC)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:             (1)K-means             (2)Latent Dirichlet all ...

  5. Spark2.0机器学习系列之10: 聚类(高斯混合模型 GMM)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  6. Spark2.0机器学习系列之9: 聚类(k-means,Bisecting k-means,Streaming k-means)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  7. Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析

    概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树.     GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...

  8. Spark2.0机器学习系列之5:随机森林

    概述 随机森林是决策树的组合算法,基础是决策树,关于决策树和Spark2.0中的代码设计可以参考本人另外一篇博客: http://www.cnblogs.com/itboys/p/8312894.ht ...

  9. Spark2.0机器学习系列之7: MLPC(多层神经网络)

    Spark2.0 MLPC(多层神经网络分类器)算法概述 MultilayerPerceptronClassifier(MLPC)这是一个基于前馈神经网络的分类器,它是一种在输入层与输出层之间含有一层 ...

随机推荐

  1. 树形dp - BNU 39572 Usoperanto

    Usoperanto Problem's Link Mean: 给定n个单词,每个单词可以作为形容词来修饰其他单词. 如果当前单词Wi修饰Wj,那么这个修饰的代价是:Wi~Wj之间的单词的总长度. 你 ...

  2. javascript弹出层-DEMO001

    首先上一张图 看下弹出层的效果 从图中可以看到二部分 一是弹出层 二是遮照层 弹出层:即弹出你要操作的内容 遮照层:遮照住不要操作的内空 实际技术原理主要是 CSS +JS  (z-index是核心) ...

  3. 【转】【iOS测试系列】常用测试小插件的使用

    背景介绍 由于iOS系统的限制,在非越狱的自动化测试中无法实现一些常用的功能,比如不同应用之间来回切换.模拟全局的点击事件等等.但是在越狱的环境下,这些限制就不存在了,我们可以利用各种小插件来实现我们 ...

  4. C语言 百炼成钢23

    /* 题目59:链表如下 typedef struct _LinkList { int data; struct _LinkList*next; } LinkList; 有如下结点数据域 1 2 3 ...

  5. Ubuntu之No module named cv2

    最简单的方法是:pip install opencv-python 另外,从源码安装的方法: 1下载opencv源码:http://opencv.org/releases.html  推荐2.4.13 ...

  6. 第二百五十四节,Bootstrap项目实战--案例

    Bootstrap项目实战--案例 html <!DOCTYPE html> <html lang="zh-cn"> <head> <me ...

  7. css -- 运用@media实现网页自适应中的几个关键分辨率

    经常为不同分辨率设备或不同窗口大小下布局错位而头疼,可以利用@media screen实现网页布局的自适应,但是怎样兼容所有主流设备就成了问题.到底分辨率是多少的时候设置呢? 先看下面的代码,这是从b ...

  8. php -- 可变变量

    有时候使用可变变量名是很方便的.就是说,一个变量的变量名可以动态的设置和使用.一个普通的变量通过声明来设置,例如: <?php $a = 'hello'; ?> 一个可变变量获取了一个普通 ...

  9. 小结:单调栈 & 单调队列

    概要: 对于维护信息具有单调性的性质或者问题可以转化为具有单调性质的模型的题,我们可以考虑用单调栈或单调队列. 技巧及注意: 技巧很多,只要能将问题转化为单调性问题,就好解决了. 当维护固定长度的单调 ...

  10. Swift AVFoundation 二维码扫描和生成

    项目最终不须要支持iOS6了(泪崩),在二维码扫描这一块,可以全然的放弃ZXing库,改用系统的AVFoundation了,拿swift写了个Demo,效果例如以下: github地址:点这里 有关A ...