1. Param

  Spark ML使用一个自定义的Map(ParmaMap类型),其实该类内部使用了mutable.Map容器来存储数据。

如下所示其定义:

Class ParamMap private[ml] (private val map.mutable.Map[Param[Any],Any])

  从上述定义可以看出,ParamMap是用一个Map来存储,key为Param[Any],value为Any。这里的value就是用户设置的参数值,而key是对String的封装,对用户来所其实就是字符串。

如上述的tokenizer类,对调用setInputCol方法来设置输入DataFrame的输入列,其内部实现如下所示:

Final val inputCol:Param[String] = new Param[String](this,"inputCol","input column name")

def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]

final def set[T](param:Param[T],value:T):this.type={

set(param->value)

}

2. Transformer

  Transformer类是一个抽象类,为了实现从一个DataFrame转换为另一个DataFrame,其子类只需要实现三个方法即可。如下所示的源码:

abstract class Transformer extends PipelineStage {

/**

* Transforms the dataset with optional parameters

* @param dataset input dataset

* @param firstParamPair the first param pair, overwrite embedded params

* @param otherParamPairs other param pairs, overwrite embedded params

* @return transformed dataset

*/

@Since("2.0.0")

@varargs

def transform(

dataset: Dataset[_],

firstParamPair: ParamPair[_],

otherParamPairs: ParamPair[_]*): DataFrame = {

val map = new ParamMap()

.put(firstParamPair)

.put(otherParamPairs: _*)

transform(dataset, map)

}

/**

* Transforms the dataset with provided parameter map as additional parameters.

* @param dataset input dataset

* @param paramMap additional parameters, overwrite embedded params

* @return transformed dataset

*/

@Since("2.0.0")

def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {

this.copy(paramMap).transform(dataset)

}

/**

* Transforms the input dataset.

*/

@Since("2.0.0")

def transform(dataset: Dataset[_]): DataFrame

override def copy(extra: ParamMap): Transformer

}

  • transform():该方法是用户的API方法,用户直接调用该方法来实现转换;
  • copy():该方法复制了一个Transformer对象;
  • transformSchema:由于Transformer类继承了PipelineStage接口,该接口有这个方法实现。
2.1 HasInputCol和HasOutputCol

  HasInputCol和HasOutputCol都是接口,它们定义了一种协议。若有输入或有输出参数的Transformer,那么就需要实现这个接口。

private[ml] trait HasInputCol extends Params {

final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")

/** @group getParam */

final def getInputCol: String = $(inputCol)

}

private[ml] trait HasOutputCol extends Params {

final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")

setDefault(outputCol, uid + "__output")

/** @group getParam */

final def getOutputCol: String = $(outputCol)

}

2.2 UnaryTransformer

  这个类是一元转换的抽象类,其以一个DataFrame列作为输入,然后经过处理后,产生一个新列增加到输入的DataFrame中。

该类的源码如下所示:

abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]

extends Transformer with HasInputCol with HasOutputCol with Logging {

/** API method*/

def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]

/** API method */

def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]

/**

* Creates the transform function using the given param map. The input param map already takes

* account of the embedded param map. So the param values should be determined solely by the input

* param map.

*/

protected def createTransformFunc: IN => OUT

/**

* Returns the data type of the output column.

*/

protected def outputDataType: DataType

/**

* Validates the input type. Throw an exception if it is invalid.

*/

protected def validateInputType(inputType: DataType): Unit = {}

override def transformSchema(schema: StructType): StructType = {

val inputType = schema($(inputCol)).dataType

validateInputType(inputType)

if (schema.fieldNames.contains($(outputCol))) {

throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")

}

val outputFields = schema.fields :+

StructField($(outputCol), outputDataType, nullable = false)

StructType(outputFields)

}

/** API method */

override def transform(dataset: Dataset[_]): DataFrame = {

transformSchema(dataset.schema, logging = true)

val transformUDF = udf(this.createTransformFunc, outputDataType)

dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))

}

override def copy(extra: ParamMap): T = defaultCopy(extra)

}

1) API method

  该类提供三个API方法,用户通过使用这些方法来实现转换功能,如下所示:

Method

Description

setInputCol

指明输入DataFrame中的哪一列是被处理的,输入参数是Dataframe中存在的列名

setOutputCol

设置新增加列的名字,及对输入的列变换后悔产生一个新列,该方法设置增加新列的列名

transform

用户通过调用该方法实现DataFrame的转换,其实调用该方法是在原来的DataFrame中增加了一个新列,如何增加一个新列,则由createTransformFunc方法来实现。

需要特别说明的是transform方法的最后一条语句,其使用了Dataset的如下方法:

Dataset.withColumn(colName:String, col:Column):DataFrame

该方法的功能是通过在遍历dataset中的每一行,然后每行都增加一列,列名为colName,内容为col。

2) Implement method

  因为UnaryTransformer类是一个抽象类,其没有指明一个输入列如何产生一个新列,这些具体转换工作需要子类来实现。子类需要实现三个方法:

Method

Description

createTransformFunc

该函数实现了如何将一个输入参数变化后产生一个新数据,即可用将其理解为map操作,即inàout.

outputDataType

子类实现该方法的目的是返回一个输出列的数据类型;

validataInputType

验证输入列的类型的合法性。

UnaryTransformer抽象类有7个实现类,用户如果需要自定义转换操作也可以继承该类,然后实现相应的操作即可。如下以Tokenizer类进行介绍,如下所示:

class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)

extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {

//1. 在继承UnaryTransformer类时,指明了createTransformFunc函数的输入参数类型和返回参数类型

@Since("1.2.0")

def this() = this(Identifiable.randomUID("tok"))

//2. 实现了一个输入值如何进行处理,然后将其返回

override protected def createTransformFunc: String => Seq[String] = {

_.toLowerCase.split("\\s")

}

//3. 验证输入参数类型是否合法

override protected def validateInputType(inputType: DataType): Unit = {

require(inputType == StringType, s"Input type must be string type but got $inputType.")

}

//4.返回DataFrame中新增加列的类型

override protected def outputDataType: DataType = new ArrayType(StringType, true)

@Since("1.4.1")

override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)

}

2.3 Model

  通过上述前两节的分析,我们知道在模型训练后,Estimator会生成一个Transformer对象。这种Transformer对象就是Model类的子类,其也是Transformer抽象类的子类。

Model类簇都有特别的功能,其是机器学习模型在训练后的模型,即其能够对输入的DataFrame进行预测,所以都特别有针对性。

3. Estimator

  Estimator就是机器学习中的模型,其在Spark ML中有很多实现子类。不同的学习模型都有不同的实现方式。通过前两节分析,我们了解到Estimator在训练后悔产生一个Transformer,这个Transformer其实是Model类。每种Estimator都对应有一种Model。其类图如图 5所示。

图 5

因为Estimator继承PipelineStage,所以Estimator的实现类需要实现三个方法:

  • copy(extra:ParamMap):实现模型拷贝操作;
  • transformSchema(schema:StructType):实现DataFrame结构的转换;
  • fit(dataset:Dataset[_]):实现模型训练,这个非常重要,是用户的API方法,该方法会返回一个Model实现类。

4. 参考文献

Spark Mllib源码分析的更多相关文章

  1. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  2. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  3. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  4. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  5. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

  6. 【Spark SQL 源码分析系列文章】

    从决定写Spark SQL源码分析的文章,到现在一个月的时间里,陆陆续续差不多快完成了,这里也做一个整合和索引,方便大家阅读,这里给出阅读顺序 :) 第一篇 Spark SQL源码分析之核心流程 第二 ...

  7. spark(1.1) mllib 源码分析(二)-相关系数

    原创文章,转载请注明: 转载自http://www.cnblogs.com/tovin/p/4024733.html 在spark mllib 1.1版本中增加stat包,里面包含了一些统计相关的函数 ...

  8. spark(1.1) mllib 源码分析(一)-卡方检验

    原创文章,转载请注明: 转载自http://www.cnblogs.com/tovin/p/4019131.html 在spark mllib 1.1版本中增加stat包,里面包含了一些统计相关的函数 ...

  9. spark(1.1) mllib 源码分析(三)-朴素贝叶斯

    原创文章,转载请注明: 转载自http://www.cnblogs.com/tovin/p/4042467.html 本文主要以mllib 1.1版本为基础,分析朴素贝叶斯的基本原理与源码 一.基本原 ...

随机推荐

  1. 关于javaSocket中 Software caused connection abort: recv failed问题

    在学习Socket中今天突然遇到了以下这种问题 原来是网路连接出了问题,由于我測试的是远程连接所以是在学校的局域网下,结果非常不稳定,開始还以为怎么了一会连上了一会又出现故障然后把IP地址改为本机的1 ...

  2. Mac下配置mnmp环境

    虽然比较喜欢玩下新语言, 但是php还是常会用到的. lnmp很多人都听过, 但是不能用在Mac上面, 另外还有个mnpp但在osx 10.8.3下面跑不起来.所以自己手动一步步安装, 整理了方便安装 ...

  3. Java内存模型FAQ(一) 什么是内存模型

    原文:http://www.cs.umd.edu/~pugh/java/memoryModel/jsr-133-faq.html 第一章 译者:方腾飞 在多核系统中,处理器一般有一层或者多层的缓存,这 ...

  4. 解决UISlider滑块不灵敏

    由于UI给的thumbImage图片过小,默认UISlider开始拖动的手势范围只有thumbImage的大小之内. 为了解决这个问题需要创建一个子类继承于UISlider.重写其中的方法: - (C ...

  5. php判断页面是否是微信打开的示例(微信打开网页)

    代码如下: $user_agent = $_SERVER['HTTP_USER_AGENT'];if (strpos($user_agent, 'MicroMessenger') === false) ...

  6. Parencodings - poj 1068

      Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 22764   Accepted: 13344 Description L ...

  7. 排序算法 python

    1.先写个原始数组和测试算法是否正确,输出多次,方便计算算法运算的平均值 2.开始第一个最简单的冒泡排序 3.“”选择排序“”,跟冒泡很像,每次选最大/最小,放进新list中. 3.1发现测试test ...

  8. 嵌入式专题: S5PV210 - MPEG4编码

    我想说不台的平台,如tiny210和x210.它们的头文件是有稍微区别的. 我这个是x210下的代码.但都须要注意的是NV12T与NV12的问题,默认要求输入的图片是NV12T,经过调整之后,能够同意 ...

  9. linux授权某个用户对某个目录有读写的权限

    针对特定的某一个用户设置文件或目录权限,用setfacl. 首先打开文件系统的acl功能,在挂载参数添加cal,再保存退出,比如/home分区: vim /etc/fstab /dev/sda2 /h ...

  10. POJ 3253 Fence Repair(哈夫曼树)

    Fence Repair Time Limit: 2000MS   Memory Limit: 65536K Total Submissions: 26167   Accepted: 8459 Des ...