Spark机器学习库现支持两种接口的API:RDD-based和DataFrame-based,Spark官方网站上说,RDD-based APIs在2.0后进入维护模式,主要的机器学习API是spark-ml包中的DataFrame-based API,并将在3.0后完全移除RDD-based API。

在学习了两周Spark MLlib后,准备转向DataFrame-based接口。由于现有的文档资料均是RDD-based接口,于是便去看了看Spark MLlib的源码。DataFrame-based API 包含在org.apache.spark.ml包中,其中主要的类结构如下:

咱先看一个线性回归的例子examples/ml/LinearRegressionExample.scala,其首先定义了一个LinearRegression的对象:

val lir = new LinearRegression()
.setFeaturesCol("features")
.setLabelCol("label")
.setRegParam(params.regParam)
.setElasticNetParam(params.elasticNetParam)
.setMaxIter(params.maxIter)
.setTol(params.tol)

然后,调用fit方法训练数据,得到一个训练好的模型lirModel,它是一个LinearRegressionModel类的对象。

val lirModel = lir.fit(training)

现在,我们大概可以理清MLlib机器学习的流程,和很多单机机器学习库一样,先定义一个模型并设置好参数,然后训练数据,最后返回一个训练好了的模型。

我们现在在源码中去查看LinearRegression和LinearRegressionModel,其类的依赖关系如下:

LinearRegression是一个Predictor,LinearRegressionModel是一个Model,那么Predictor是学习算法,Model是训练得到的模型。除此之外,还有一类继承自Params的类,这是一个表示参数的类。Predictor 和Model 共享一套参数。

现在用Spark MLlib来完成第一个机器学习例子,数据是我之前放在txt文件里的回归数据,一共550多万条,共13列,第一列是Label,后面是Features。分别演示两种接口,先用旧的接口:

1.读取原始数据:

scala> import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg._
scala> import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.regression._
scala> val raw_data = sc.textFile("data/my/y_x.txt")
raw_data: org.apache.spark.rdd.RDD[String] = data/my/y_x.txt MapPartitionsRDD[1] at textFile at <console>:30

2.转换格式,RDD-based接口以LabeledPoint为输入数据的格式:

scala> val data = raw_data.map{ line =>
| val arr = line.split(' ').map(_.toDouble)
| val label = arr.head
| val features = Vectors.dense(arr.tail)| LabeledPoint(label,features)
| }
data: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[2] at map at <console>:32

3.划分train、test数据集:

scala> val splits = data.randomSplit(Array(0.8, 0.2))
splits: Array[org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]] = Array(MapPartitionsRDD[3] at randomSplit at <console>:34, MapPartitionsRDD[4] at randomSplit at <console>:34)
scala> val train_set = splits(0).cache
train_set: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[3] at randomSplit at <console>:34
scala> val test_set = splits(1).cache
test_set: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = MapPartitionsRDD[4] at randomSplit at <console>:34

4.使用LinearRegressionWithSGD.train训练模型:

scala> val lr = LinearRegressionWithSGD.train(train_set,100,0.0001)
warning: there was one deprecation warning; re-run with -deprecation for details
16/08/26 09:20:44 WARN Executor: 1 block locks were not released by TID = 0:
[rdd_3_0]
lr: org.apache.spark.mllib.regression.LinearRegressionModel = org.apache.spark.mllib.regression.LinearRegressionModel: intercept = 0.0, numFeatures = 12

5.模型评估:

scala> val pred_labels = test_set.map(lp => (lp.label, lr.predict(lp.features)))
pred_labels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[17] at map at <console>:42
scala> val mse = pred_labels.map{case (p,v) => math.pow(p-v,2)}.mean
mse: Double = 0.05104150735910074

再用新的接口:

1.读取原始数据:

scala> import org.apache.spark.ml.linalg._
import org.apache.spark.ml.linalg._
scala> import org.apache.spark.ml.regression._
import org.apache.spark.ml.regression._
scala> import org.apache.spark.sql._
import org.apache.spark.sql._
scala> val raw_data = spark.read.text("data/my/y_x.txt")
raw_data: org.apache.spark.sql.DataFrame = [value: string]

2.转换数据

scala> val data = raw_data.rdd.map { case Row(line:String) =>
| val arr = line.split(' ').map(_.toDouble)
| val label = arr.head
| val features = Vectors.dense(arr.tail)
| (label,features)
| }
data: org.apache.spark.rdd.RDD[(Double, org.apache.spark.ml.linalg.Vector)] = MapPartitionsRDD[4] at map at <console>:34

3.划分数据集

scala> val splits = data.randomSplit(Array(0.8, 0.2))
splits: Array[org.apache.spark.rdd.RDD[(Double, org.apache.spark.ml.linalg.Vector)]] = Array(MapPartitionsRDD[5] at randomSplit at <console>:36, MapPartitionsRDD[6] at randomSplit at <console>:36)
scala> val train_set = splits(0).toDS.cache
train_set: org.apache.spark.sql.Dataset[(Double, org.apache.spark.ml.linalg.Vector)] = [_1: double, _2: vector]
scala> val test_set = splits(1).toDS.cache
test_set: org.apache.spark.sql.Dataset[(Double, org.apache.spark.ml.linalg.Vector)] = [_1: double, _2: vector]

4.创建LinearRegression对象,并设置模型参数。这里设置类LabelCol和FeaturesCol列,默认为“label”和“features”,而我们的数据是"_1"和”_2“。

scala> val lir = new LinearRegression
lir: org.apache.spark.ml.regression.LinearRegression = linReg_c4e70a01bcd3
scala> lir.setFeaturesCol("_2")
res0: org.apache.spark.ml.regression.LinearRegression = linReg_c4e70a01bcd3
scala> lir.setLabelCol("_1")
res1: org.apache.spark.ml.regression.LinearRegression = linReg_c4e70a01bcd3

5.训练模型

val model = lir.fit(train_set)
16/08/26 09:45:16 WARN Executor: 1 block locks were not released by TID = 0:
[rdd_9_0]
16/08/26 09:45:16 WARN WeightedLeastSquares: regParam is zero, which might cause numerical instability and overfitting.
model: org.apache.spark.ml.regression.LinearRegressionModel = linReg_c4e70a01bcd3

6.模型评估

scala> val res = model.transform(test_set)
res: org.apache.spark.sql.DataFrame = [_1: double, _2: vector ... 1 more field]
scala> import org.apache.spark.ml.evaluation._
import org.apache.spark.ml.evaluation._
scala> val eva = new RegressionEvaluator
eva: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_8fc6cce63aa9
scala> eva.setLabelCol("_1")
res6: eva.type = regEval_8fc6cce63aa9
scala> eva.setMetricName("mse")
res7: eva.type = regEval_8fc6cce63aa9
scala> eva.evaluate(res)
res8: Double = 0.027933653533088666

Spark机器学习笔记一的更多相关文章

  1. spark机器学习笔记01

     1)外部数据源 val distFile1 = sc.textFile("data.txt") //本地当前目录下文件 val distFile2 =sc.textFile(& ...

  2. spark学习笔记总结-spark入门资料精化

    Spark学习笔记 Spark简介 spark 可以很容易和yarn结合,直接调用HDFS.Hbase上面的数据,和hadoop结合.配置很容易. spark发展迅猛,框架比hadoop更加灵活实用. ...

  3. Spark学习笔记0——简单了解和技术架构

    目录 Spark学习笔记0--简单了解和技术架构 什么是Spark 技术架构和软件栈 Spark Core Spark SQL Spark Streaming MLlib GraphX 集群管理器 受 ...

  4. Spark机器学习· 实时机器学习

    Spark机器学习 1 在线学习 模型随着接收的新消息,不断更新自己:而不是像离线训练一次次重新训练. 2 Spark Streaming 离散化流(DStream) 输入源:Akka actors. ...

  5. Spark学习笔记之SparkRDD

    Spark学习笔记之SparkRDD 一.   基本概念 RDD(resilient distributed datasets)弹性分布式数据集. 来自于两方面 ①   内存集合和外部存储系统 ②   ...

  6. 机器学习笔记:Gradient Descent

    机器学习笔记:Gradient Descent http://www.cnblogs.com/uchihaitachi/archive/2012/08/16/2642720.html

  7. Spark机器学习 Day2 快速理解机器学习

    Spark机器学习 Day2 快速理解机器学习 有两个问题: 机器学习到底是什么. 大数据机器学习到底是什么. 机器学习到底是什么 人正常思维的过程是根据历史经验得出一定的规律,然后在当前情况下根据这 ...

  8. Spark机器学习 Day1 机器学习概述

    Spark机器学习 Day1 机器学习概述 今天主要讨论个问题:Spark机器学习的本质是什么,其内部构成到底是什么. 简单来说,机器学习是数据+算法. 数据 在Spark中做机器学习,肯定有数据来源 ...

  9. Spark机器学习之协同过滤算法

    Spark机器学习之协同过滤算法 一).协同过滤 1.1 概念 协同过滤是一种借助"集体计算"的途径.它利用大量已有的用户偏好来估计用户对其未接触过的物品的喜好程度.其内在思想是相 ...

随机推荐

  1. Java 8 新特性终极版

    声明:本文翻译自Java 8 Features Tutorial – The ULTIMATE Guide,翻译过程中发现并发编程网已经有同学翻译过了:Java 8 特性 – 终极手册,我还是坚持自己 ...

  2. 如何让程序(如java Hello)只启动一次?

    如何让程序(如java Hello)只启动一次? 摘自http://bbs.csdn.net/topics/50488704 总结一下,关于让Java程序只运行一个实例的问题,其实质是JVM之间通信的 ...

  3. CSS3 动画触发事件

    @keyframes mymove { 0% {top:0px; left:0px; background:red;} 25% {top:0px; left:100px; background:blu ...

  4. 【转】WCF、WebAPI、WCFREST、WebService之间的区别

    在.net平台下,有大量的技术让你创建一个HTTP服务,像Web Service,WCF,现在又出了Web API.在.net平台下,你有很多的选择来构建一个HTTP Services.我分享一下我对 ...

  5. Linux中图形界面和文本模式相互切换

    1.默认开机进入文本模式 如果想让开机自动进纯文本模式, 修改/etc/inittab 找到其中的 id:5:initdefault: 这行指示启动时的运行级是5,也就是图形模式 改成3就是文本模式了 ...

  6. iOS 8 自适应 Cell

    在使用 table view 的时侯经常会遇到这样的需求:table view 的 cell 中的内容是动态的,导致在开发的时候不知道一个 cell 的高度具体是多少,所以需要提供一个计算 cell ...

  7. idmap_ad — Samba's idmap_ad Backend for Winbind《转载》

    Name idmap_ad — Samba's idmap_ad Backend for Winbind DESCRIPTION The idmap_ad plugin provides a way ...

  8. python - 类的特殊成员

    class Foo: #构造方法 def __init__(self,name,age): pass self.name = name self.age = age def __str__(self) ...

  9. EF执行存储过程(带输出参数)

    1.不含动态sql.带输出参数存储过程调用实例 1.存储过程代码:   2.EF自动生成代码(包括对应ObjectResult的实体模型): 3.调用存储过程代码实例: 总结: ObjectParam ...

  10. FileShare文件读写锁解决“文件XXX正由另一进程使用,因此该进程无法访问此文件”(转)

    开发过程中,我们往往需要大量与文件交互,读文件,写文件已成家常便饭,本地运行完美,但一上到投产环境,往往会出现很多令人措手不及的意外,或开发中的烦恼,因此,我对普通的C#文件操作做了一次总结,问题大部 ...