Apache Spark源码走读之22 -- 浅谈mllib中线性回归的算法实现
欢迎转载,转载请注明出处,徽沪一郎。
概要
本文简要描述线性回归算法在Spark MLLib中的具体实现,涉及线性回归算法本身及线性回归并行处理的理论基础,然后对代码实现部分进行走读。
线性回归模型
机器学习算法是的主要目的是找到最能够对数据做出合理解释的模型,这个模型是假设函数,一步步的推导基本遵循这样的思路
- 假设函数
- 为了找到最好的假设函数,需要找到合理的评估标准,一般来说使用损失函数来做为评估标准
- 根据损失函数推出目标函数
- 现在问题转换成为如何找到目标函数的最优解,也就是目标函数的最优化
具体到线性回归来说,上述就转换为


梯度下降法
那么如何求得损失函数的最优解,针对最小二乘法来说可以使用梯度下降法。

算法实现


随机梯度下降


正则化

如何解决这些问题呢?可以采用收缩方法(shrinkage method),收缩方法又称为正则化(regularization)。
主要是岭回归(ridge regression)和lasso回归。通过对最小二乘估计加
入罚约束,使某些系数的估计为0。

线性回归的代码实现
上面讲述了一些数学基础,在将这些数学理论用代码来实现的时候,最主要的是把握住相应的假设函数和最优化算法是什么,有没有相应的正则化规则。
对于线性回归,这些都已经明确,分别为
- Y = A*X + B 假设函数
- 随机梯度下降法
- 岭回归或Lasso法,或什么都没有
那么Spark mllib针对线性回归的代码实现也是依据该步骤来组织的代码,其类图如下所示

函数调用路径

train->run,run函数的处理逻辑
- 利用最优化算法来求得最优解,optimizer.optimize
- 根据最优解创建相应的回归模型, createModel
runMiniBatchSGD是真正计算Gradient和Loss的地方
def runMiniBatchSGD(
data: RDD[(Double, Vector)],
gradient: Gradient,
updater: Updater,
stepSize: Double,
numIterations: Int,
regParam: Double,
miniBatchFraction: Double,
initialWeights: Vector): (Vector, Array[Double]) = {
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
val numExamples = data.count()
val miniBatchSize = numExamples * miniBatchFraction
// if no data, return initial weights to avoid NaNs
if (numExamples == 0) {
logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
return (initialWeights, stochasticLossHistory.toArray)
}
// Initialize weights as a column vector
var weights = Vectors.dense(initialWeights.toArray)
val n = weights.size
/**
* For the first iteration, the regVal will be initialized as sum of weight squares
* if it's L2 updater; for L1 updater, the same logic is followed.
*/
var regVal = updater.compute(
weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
for (i (c, v) match { case ((grad, loss), (label, features)) =>
val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
(grad1 += grad2, loss1 + loss2)
})
/**
* NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
* and regVal is the regularization value computed in the previous iteration as well.
*/
stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
val update = updater.compute(
weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
weights = update._1
regVal = update._2
}
logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
stochasticLossHistory.takeRight(10).mkString(", ")))
(weights, stochasticLossHistory.toArray)
}
上述代码中最需要引起重视的部分是aggregate函数的使用,先看下aggregate函数的定义
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}
aggregate函数有三个入参,一是初始值ZeroValue,二是seqOp,三为combOp.
- seqOp seqOp会被并行执行,具体由各个executor上的task来完成计算
- combOp combOp则是串行执行, 其中combOp操作在JobWaiter的taskSucceeded函数中被调用
为了进一步加深对aggregate函数的理解,现举一个小小例子。启动spark-shell后,运行如下代码
val z = sc. parallelize (List (1 ,2 ,3 ,4 ,5 ,6),2)
z.aggregate (0)(math.max(_, _), _ + _)
// 运 行 结 果 为 9
res0: Int = 9
仔细观察一下运行时的日志输出, aggregate提交的job由一个stage(stage0)组成,由于整个数据集被分成两个partition,所以为stage0创建了两个task并行处理。
LeastSquareGradient
讲完了aggregate函数的执行过程, 回过头来继续讲组成seqOp的gradient.compute函数。
LeastSquareGradient用来计算梯度和误差,注意cmopute中cumGraident会返回改变后的结果。这里计算公式依据的就是cost-function中的▽Q(w)
class LeastSquaresGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val brzData = data.toBreeze
val brzWeights = weights.toBreeze
val diff = brzWeights.dot(brzData) - label
val loss = diff * diff
val gradient = brzData * (2.0 * diff)
(Vectors.fromBreeze(gradient), loss)
}
override def compute(
data: Vector,
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
val brzData = data.toBreeze
val brzWeights = weights.toBreeze
//dot表示点积,是接受在实数R上的两个向量并返回一个实数标量的二元运算,它的结果是欧几里得空间的标准内积。
//两个向量的点积写作a·b。点乘的结果叫做点积,也称作数量积
val diff = brzWeights.dot(brzData) - label
//下面这句话完成y += a*x
brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze)
diff * diff
}
}
在上述代码中频繁出现breeze相关的函数,你一定会很好奇,这是个什么新鲜玩艺。
说 开 了 其 实 一 点 也 不 稀 奇, 由 于 计 算 中 有 大 量 的 矩 阵(Matrix)及 向量(Vector)计算,为了更好支持和封装这些计算引入了breeze库。
Breeze, Epic及Puck是scalanlp中三大支柱性项目, 具体可参数www.scalanlp.org
正则化过程
根据本次迭代出来的梯度和误差对权重系数进行更新,这个时候就需要用上正则化规则了。也就是下述语句会触发权重系数的更新
val update = updater.compute(
weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
以岭回归为例,看其更新过程的代码实现。
class SquaredL2Updater extends Updater {
override def compute(
weightsOld: Vector,
gradient: Vector,
stepSize: Double,
iter: Int,
regParam: Double): (Vector, Double) = {
// add up both updates from the gradient of the loss (= step) as well as
// the gradient of the regularizer (= regParam * weightsOld)
// w' = w - thisIterStepSize * (gradient + regParam * w)
// w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
val thisIterStepSize = stepSize / math.sqrt(iter)
val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
brzWeights :*= (1.0 - thisIterStepSize * regParam)
brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
val norm = brzNorm(brzWeights, 2.0)
(Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
}
}
结果预测
计算出权重系数(weights)和截距intecept,就可以用来创建线性回归模型LinearRegressionModel,利用模型的predict函数来对观测值进行预测
class LinearRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
}
注意LinearRegression的构造函数需要权重(weights)和截距(intercept)作为入参,对新的变量做出预测需要调用predictPoint
一个完整的示例程序
在spark-shell中执行如下语句来亲自体验一下吧。
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}
// Building the model
val numIterations = 100
val model = LinearRegressionWithSGD.train(parsedData, numIterations)
// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean()
println("training Mean Squared Error = " + MSE)
小结
再次强调,找到对应的假设函数,用于评估的损失函数,最优化求解方法,正则化规则
Apache Spark源码走读之22 -- 浅谈mllib中线性回归的算法实现的更多相关文章
- Apache Spark源码走读之7 -- Standalone部署方式分析
欢迎转载,转载请注明出处,徽沪一郎. 楔子 在Spark源码走读系列之2中曾经提到Spark能以Standalone的方式来运行cluster,但没有对Application的提交与具体运行流程做详细 ...
- Apache Spark源码走读之16 -- spark repl实现详解
欢迎转载,转载请注明出处,徽沪一郎. 概要 之所以对spark shell的内部实现产生兴趣全部缘于好奇代码的编译加载过程,scala是需要编译才能执行的语言,但提供的scala repl可以实现代码 ...
- Apache Spark源码走读之13 -- hiveql on spark实现详解
欢迎转载,转载请注明出处,徽沪一郎 概要 在新近发布的spark 1.0中新加了sql的模块,更为引人注意的是对hive中的hiveql也提供了良好的支持,作为一个源码分析控,了解一下spark是如何 ...
- Apache Spark源码走读之23 -- Spark MLLib中拟牛顿法L-BFGS的源码实现
欢迎转载,转载请注明出处,徽沪一郎. 概要 本文就拟牛顿法L-BFGS的由来做一个简要的回顾,然后就其在spark mllib中的实现进行源码走读. 拟牛顿法 数学原理 代码实现 L-BFGS算法中使 ...
- Apache Spark源码走读之18 -- 使用Intellij idea调试Spark源码
欢迎转载,转载请注明出处,徽沪一郎. 概要 上篇博文讲述了如何通过修改源码来查看调用堆栈,尽管也很实用,但每修改一次都需要编译,花费的时间不少,效率不高,而且属于侵入性的修改,不优雅.本篇讲述如何使用 ...
- Apache Spark源码走读之6 -- 存储子系统分析
欢迎转载,转载请注明出处,徽沪一郎. 楔子 Spark计算速度远胜于Hadoop的原因之一就在于中间结果是缓存在内存而不是直接写入到disk,本文尝试分析Spark中存储子系统的构成,并以数据写入和数 ...
- Apache Spark源码走读之17 -- 如何进行代码跟读
欢迎转载,转载请注明出处,徽沪一郎 概要 今天不谈Spark中什么复杂的技术实现,只稍为聊聊如何进行代码跟读.众所周知,Spark使用scala进行开发,由于scala有众多的语法糖,很多时候代码跟着 ...
- Apache Spark源码走读之5 -- DStream处理的容错性分析
欢迎转载,转载请注明出处,徽沪一郎,谢谢. 在流数据的处理过程中,为了保证处理结果的可信度(不能多算,也不能漏算),需要做到对所有的输入数据有且仅有一次处理.在Spark Streaming的处理机制 ...
- Apache Spark源码走读之11 -- sql的解析与执行
欢迎转载,转载请注明出处,徽沪一郎. 概要 在即将发布的spark 1.0中有一个新增的功能,即对sql的支持,也就是说可以用sql来对数据进行查询,这对于DBA来说无疑是一大福音,因为以前的知识继续 ...
随机推荐
- Spell checker(poj 1035)
题意: 此题是一个字符串的问题,首先要给出一个字典,里面存储了数个单词.而后,给出一个单词,如果字典中存在,那么就输出correct,如果字典中没有,那么就要判断是不是这个单词有错误,错误有3 ...
- commons-fileupload实现文件上传下载
commons-fileupload是Apache提供的一个实现文件上传下载的简单,有效途径,需要commons-io包的支持,本文是一个简单的示例 上传页面,注意设置响应头 <body> ...
- Struts2中配置默认Action
1.当访问的Action不存在时,页面会显示错误信息,可以通过配置默认Action处理用户异常的操作:2.配置方法: 在struts.xml文件中的<package>下添加如下内容: ...
- 一、HTML和CSS基础--网页布局--网页布局基础
W3C标准: 由万维网联盟制定的一系列标准,包括: 结构化标准语言(HTML和XML) 表现标准语言(CSS) 行为标准语言(DOM和ECMAScript) 倡导结构.样式.行为分离. CSS 规定的 ...
- 会员制实现C2B定制有机农产品,被中粮我买投资的良食网这样卖有机生鲜
前几天,中粮我买网战略投资了位于深圳的有机生鲜自营平台良食网,宣布双方将会在供应链上展开合作.然而良食网对大家来说还是比较陌生的,为此36氪专访了良食网的创始人唐忠. 良食网成立于2011年,是一家以 ...
- HTML5标准学习 – DOCTYPE
转自:http://www.cnblogs.com/GrayZhang/archive/2011/03/31/learning-html5-doctype.html 上一篇文章主要讲述了HTML文档的 ...
- IOS杂谈
1 IOS名称是iPhone Operating System 的缩写,原本这个系统名为iPhone OS,意思是iPhone 操作系统. 2 IOS的开发环境是Xcode.Xcode就成为了iPho ...
- 蒟蒻的树形dp记录
POJ2342: 题意:某公司要举办一次晚会,但是为了使得晚会的气氛更加活跃,每个参加晚会的人都不希望在晚会中见到他的直接上司,现在已知每个人的活跃指数和上司关系(当然不可能存在环),求邀请哪些人(多 ...
- Adapter适配器
1.概念 *连接后端数据和前端显示的适配器接口 *数据和UI之间的重要连接 2. ArrayAdapter ArrayAdapter构造器如下: ArrayAdapter(Context con ...
- hdu 1455 Sticks
Sticks Time Limit:1000MS Memory Limit:32768KB 64bit IO Format:%I64d & %I64u Submit Statu ...