Spark MLlib 之 aggregate和treeAggregate从原理到应用
在阅读spark mllib源码的时候,发现一个出镜率很高的函数——aggregate和treeAggregate,比如matrix.columnSimilarities()中。为了好好理解这两个方法的使用,于是整理了本篇内容。
由于treeAggregate是在aggregate基础上的优化版本,因此先来看看aggregate是什么.
更多内容参考我的大数据学习之路
aggregate
先直接看一下代码例子:
import org.apache.spark.sql.SparkSession
object AggregateTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
// 创建rdd,并分成6个分区
val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
// 输出每个分区的内容
rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
Array((s" $index : ${it.toList.mkString(",")}")).toIterator
}).foreach(println)
// 执行agg
val res1 = rdd.aggregate(0)(seqOp, combOp)
}
// 分区内执行的方法,直接加和
def seqOp(s1:Int, s2:Int):Int = {
println("seq: "+s1+":"+s2)
s1 + s2
}
// 在driver端汇总
def combOp(c1: Int, c2: Int): Int = {
println("comb: "+c1+":"+c2)
c1 + c2
}
}
这段代码的主要目的就是为了求和。考虑到spark分区并行计算的特性,在每个分区独立加和,最后再汇总加和。
过程可以参考下面的图片:

首先看一下map阶段,即在每个分区内计算加和。初始情况如蓝色方块所示,内容为:
分区号:里面的内容
如,0分区内的数据为6和8
当执行seqop时,会说先用初始值0开始遍历累加,原理类似如下:
rdd.mapPartitions((it:Iterator)=>{
var sum = init_value // 默认为0
it.foreach(sum + _)
sum
})
因此屏幕上会出现下面的内容,由于分区之间是并行的,所以最后的结果是乱序的:
seq: 0:6
seq: 0:1
seq: 0:3
seq: 1:9
seq: 3:10
seq: 0:2
seq: 0:5
seq: 5:7
seq: 12:12
seq: 0:4
seq: 4:11
seq: 6:8
计算完成后,依次遍历每个分区结果,进行累加:
comb: 0:10
comb: 10:13
comb: 23:2
comb: 25:24
comb: 49:15
comb: 64:14
aggregate的源码也比较简单:
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}
treeAggregate
treeAggregate在aggregate的基础上做了一些优化,因为aggregate是在每个分区计算完成后,把所有的数据拉倒driver端,进行统一的遍历合并,这样如果数据量很大,在driver端可能会OOM。
因此treeAggregate在中间多加了一层合并。
先来看看代码,没有任何的变化:
import org.apache.spark.sql.SparkSession
object TreeAggregateTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
Array(s" $index : ${it.toList.mkString(",")}").toIterator
}).foreach(println)
val res1 = rdd.treeAggregate(0)(seqOp, combOp)
println(res1)
}
def seqOp(s1:Int, s2:Int):Int = {
println("seq: "+s1+":"+s2)
s1 + s2
}
def combOp(c1: Int, c2: Int): Int = {
println("comb: "+c1+":"+c2)
c1 + c2
}
}
输出的结果则发生了变化,首先分区内的操作不变:
3 : 3,10
2 : 2
0 : 6,8
1 : 1,9
4 : 4,11
5 : 5,7,12
seq: 0:3
seq: 0:6
seq: 3:10
seq: 6:8
seq: 0:2
seq: 0:1
seq: 1:9
seq: 0:4
seq: 4:11
seq: 0:5
seq: 5:7
seq: 12:12
...
在合并的时候发生了 变化:
comb: 10:13
comb: 23:24
comb: 14:2
comb: 16:15
comb: 47:31
配合下面的流程图,可以更好的理解:

搭配treeAggregate的源码来看一下:
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = withScope {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
if (partitions.length == 0) {
Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
} else {
// 这里都没什么变化,在分区中遍历数据累加
val cleanSeqOp = context.clean(seqOp)
val cleanCombOp = context.clean(combOp)
val aggregatePartition =
(it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
// 关键是这下面的内容 !!!!
// 首先获得当前的分区数
var numPartitions = partiallyAggregated.partitions.length
// 计算合适的并行度,我这里相当于6^(1/2),也就是2.4左右,ceill向上取整后变成3.
// max(3,2)得到最后的结果为3。即每个树的分枝有3个叶子节点
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// 遍历分区,通过对scale取模进行合并计算
// 这里判断一下,当前的分区数是否还够分。如果少于条件值 scale+(p/scale),就停止分区
while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
numPartitions /= scale
val curNumPartitions = numPartitions
// 重新定义分区id,并按照分区id重新分区,执行合并计算
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
(i, iter) => iter.map((i % curNumPartitions, _))
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
// 最后统计结果
partiallyAggregated.reduce(cleanCombOp)
}
}
spark中的应用
// matrix求相似度
def columnSimilarities(threshold: Double): CoordinateMatrix = {
... columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
}
// 统计每一个向量的相关数据,里面包含了min max 等等很多信息
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
updateNumRows(summary.count)
summary
}
了解了treeAggregate之后,后续就可以看matrix的并行求解相似度的源码了!敬请期待吧...
参考
Spark MLlib 之 aggregate和treeAggregate从原理到应用的更多相关文章
- Spark MLlib 之 大规模数据集的相似度计算原理探索
无论是ICF基于物品的协同过滤.UCF基于用户的协同过滤.基于内容的推荐,最基本的环节都是计算相似度.如果样本特征维度很高或者<user, item, score>的维度很大,都会导致无法 ...
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...
- Spark MLlib特征处理:OneHotEncoder OneHot编码 ---原理及实战
http://m.blog.csdn.net/wangpei1949/article/details/53140372 Spark MLlib特征处理:OneHotEncoder OneHot编码 - ...
- Spark MLlib LDA 基于GraphX实现原理及源代码分析
LDA背景 LDA(隐含狄利克雷分布)是一个主题聚类模型,是当前主题聚类领域最火.最有力的模型之中的一个,它能通过多轮迭代把特征向量集合按主题分类.眼下,广泛运用在文本主题聚类中. LDA的开源实现有 ...
- 《Spark MLlib机器学习实践》内容简介、目录
http://product.dangdang.com/23829918.html Spark作为新兴的.应用范围最为广泛的大数据处理开源框架引起了广泛的关注,它吸引了大量程序设计和开发人员进行相 ...
- Apache Spark源码走读之23 -- Spark MLLib中拟牛顿法L-BFGS的源码实现
欢迎转载,转载请注明出处,徽沪一郎. 概要 本文就拟牛顿法L-BFGS的由来做一个简要的回顾,然后就其在spark mllib中的实现进行源码走读. 拟牛顿法 数学原理 代码实现 L-BFGS算法中使 ...
- Spark MLlib LDA 源代码解析
1.Spark MLlib LDA源代码解析 http://blog.csdn.net/sunbow0 Spark MLlib LDA 应该算是比較难理解的,当中涉及到大量的概率与统计的相关知识,并且 ...
- 【原】Learning Spark (Python版) 学习笔记(三)----工作原理、调优与Spark SQL
周末的任务是更新Learning Spark系列第三篇,以为自己写不完了,但为了改正拖延症,还是得完成给自己定的任务啊 = =.这三章主要讲Spark的运行过程(本地+集群),性能调优以及Spark ...
- Spark入门实战系列--8.Spark MLlib(上)--机器学习及SparkMLlib简介
[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .机器学习概念 1.1 机器学习的定义 在维基百科上对机器学习提出以下几种定义: l“机器学 ...
随机推荐
- SonarQube代码质量管理工具安装与使用(sonarqube5.1.2 + sonar-runner-dist-2.4 + MySQL5.x)
1. SonarQube安装(sonarqube5.1.2 + sonar-runner-dist-2.4) 1.1 前提条件 1) 已安装Java环境(version:1.7+) 2) 已安装MyS ...
- Android动画之逐帧动画(FrameAnimation)详解
今天我们就来学习逐帧动画,废话少说直接上效果图如下: 帧动画的实现方式有两种: 一.在res/drawable文件夹下新建animation-list的XML实现帧动画 1.首先在res/drawab ...
- OCM_第十六天课程:Section7 —》GI 及 ASM 安装配置 _安装 GRID 软件/创建和管理 ASM 磁盘组/创建和管理 ASM 实例
注:本文为原著(其内容来自 腾科教育培训课堂).阅读本文注意事项如下: 1:所有文章的转载请标注本文出处. 2:本文非本人不得用于商业用途.违者将承当相应法律责任. 3:该系列文章目录列表: 一:&l ...
- collectd+influxDB+Grafana搭建性能监控平台
网上查看了很多关于环境搭建的文章,都比较久远了很多安装包源都不可用了,今天收集了很多资料组合尝试使用新版本来搭建,故在此记录. 采集数据(collectd)-> 存储数据(influxdb) - ...
- python 全栈开发,Day140(RabbitMQ,基于scrapy-redis实现分布式爬虫)
一.RabbitMQ 队列 在生产者消费模型中,比如去餐馆吃饭的例子.生产者相当于厨师,队列相当于服务员,消费者就是你. 我们必须通过服务员,才能吃饭! 如果队列满了,队列会一直hold住.必须让消费 ...
- python 全栈开发,Day9(函数的初始,返回值,传参,三元运算)
一.函数的初始 比如python没有len()方法,如何求字符串的长度使用for循环 s = 'fdshfeigjoglfkldsja' count = 0 for i in s: count += ...
- 获取table行列
var table =document.getElementById("add_purchaseOrderDetailList_table"); var rows = table. ...
- HashTable、HashMap、ConcurrentHashMap的区别
HashTable是做了同步的,HashMap未考虑同步.所以HashMap在单线程情况下效率较高:HashTable在的多线程情况下,同步操作能保证程序执行的正确性. HashMap是非线程安全的, ...
- Zookeeper笔记(一)初识Zookeeper
为什么需要Zookeeper Zookeeper是一个典型的分布式数据一致性的解决方案,分布式应用程序可以基于它实现诸如数据发布/订阅.负载均衡.命名服务.分布式协调/通知.集群管理.Master选举 ...
- (APIO2014)序列分割
题解: 我也不知道为啥上午上课讲了我昨天看的3题 这题关键在于发现操作顺序无关的 可以发现最终答案是任意两段乘积的和 那这个东西显然是可以dp的 然后可以斜率优化一波 nklongn 另外上课讲的是当 ...