1.UDAF定义

spark中的UDF(UserDefinedFunction)大家都不会陌生, UDF其实就是将一个普通的函数, 包装为可以按 操作DataFrame中指定Columns的函数.

例如, 对某一列的所有元素进行+1操作, 它对应mapreduce操作中的map操作. 这种操作有的主要特点是:

  • 行与行之间的操作是独立的, 可以非常方便的并行计算
  • 每一行的操作完成后, map的任务就完成了, 直接将结果返回就行, 它是一种”无状态的“

但是UDAF(UserDefinedAggregateFunction)则不同, 由于存在聚合(Aggregate)操作, 它对应mapreduce操作中的reduce操作. SparkSQL中有很多现成的聚合函数, 常用的sum, count, avg等等都是. 这种操作的主要特点是:

  • 每一轮reduce之间可以是并行, 但是多轮reduce的执行是串行的, 下一轮依靠前一轮的结果, 它是一种“有状态的”, 需要记录中间的计算结果

分析上图, 96 => (96, 1)这一步是一个map操作, 给每个样本添加一个1, 表示它的数量. 它们之间的计算是独立的, 也不影响数据的行数. 然后(96, 1)和(54, 1)求和, 得到(150, 2), 它是一轮reduce的其中一个中间结果, 等三个中间结果都结束了, 才能继续后续的reduce, 得到最终的reduce结果(303, 6), 因此完整的reduce需要记录并不断更新中间结果.

2.向量平均(average pooling)

向量平均是个很常用的操作, 比如我们现在有1000个64维的向量, 想要求这1000个点的中心点. 通常来说我们不会用64列float column去存储一个向量, 因此无法使用原生的avg函数.

下面介绍如何自定义一个avgvector函数, 去处理array[float] column的平均值计算问题. 通过这个例子学会如何在spark下实现自定义的聚合函数

2.1 average的并行化

average算法非常简单, 求个和, 然后除以样本个数就好了. 它的并行化也很好理解

  • reduce的过程只进行sum的累积和样本数num的累积, 在最后一步将sum/num

因此我们的在reduce的过程中, 需要时刻记录当前task处理的样本的个数, 和它们的和.

由于这样的原因, 不像UDF只需要定义一个函数就可以, UDAF通常需要定义一个类, 用来保存中间结果

2.2 代码实现

// 从基类UserDefinedAggregateFunction继承
class VectorMean64 extends UserDefinedAggregateFunction {
// 定义输入的格式
// 这个函数将会处理的那一列的数据类型, 因为是64维的向量, 因此是Array[Float]
override def inputSchema: org.apache.spark.sql.types.StructType =
StructType(StructField("vector", ArrayType(FloatType)) :: Nil) // 这个就是上面提到的状态
// 在reduce过程中, 需要记录的中间结果. vector_count即为已经统计的向量个数, 而vector_sum即为已经统计的向量的和
override def bufferSchema: StructType =
StructType(
StructField("vector_count", IntegerType) ::
StructField("vector_sum", ArrayType(FloatType)) :: Nil) // 最终的输出格式
// 既然是求平均, 最后当然还是一个向量, 依然是Array[Float]
override def dataType: DataType = ArrayType(FloatType) override def deterministic: Boolean = true // 初始化
// buffer的格式即为bufferSchema, 因此buffer(0)就是向量个数, 初始化当然是0, buffer(1)为向量和, 初始化为零向量
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
buffer(1) = Array.fill[Float](64)(0).toSeq
} // 定义reduce的更新操作: 如何根据一行新数据, 更新一个聚合buffer的中间结果
// 一行数据是一个向量, 因此需要将count+1, 然后sum+新向量
// addTwoEmb为向量相加的基本实现
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getInt(0) + 1 val inputVector = input.getAs[Seq[Float]](0)
buffer(1) = addTwoEmb(buffer.getAs[Seq[Float]](1), inputVector)
} // 定义reduce的merge操作: 两个buffer结果合并到其中一个bufer上
// 两个buffer各自统计的样本个数相加; 两个buffer各自的sum也相加
// 注意: 为什么buffer1和buffer2的数据类型不一样?一个是MutableAggregationBuffer, 一个是Row
// 因为: 在将所有中间task的结果进行reduce的过程中, 两两合并时是将一个结果合到另外一个上面, 因此一个是mutable的, 它们两者的schema其实是一样的, 都对应bufferSchema
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
buffer1(1) = addTwoEmb(buffer1.getAs[Seq[Float]](1), buffer2.getAs[Seq[Float]](1))
} // 最终的结果, 依赖最终的buffer中的数据计算的到, 就是将sum/count
override def evaluate(buffer: Row): Any = {
val result = buffer.getAs[Seq[Float]](1).toArray
val count = buffer.getInt(0)
for (i <- result.indices) {
result(i) /= (count + 1)
}
result.toSeq
} // 向量相加
private def addTwoEmb(emb1: Seq[Float], emb2: Seq[Float]): Seq[Float] = {
val result = Array.fill[Float](emb1.length)(0)
for (i <- emb1.indices) {
result(i) = emb1(i) + emb2(i)
}
result.toSeq
}

解释可以参考上面的代码注释. 核心就是定义四个模块:

  • 中间结果的格式 - bufferSchema
  • 将一行数据更新到中间结果buffer中 - update
  • 将两个中间结果buffer合并 - merge
  • 从最后的buffer计算需要的结果 - evaluate

2.3 使用

// 注册一下, 使其可以在Spark SQL中使用
spark.udf.register("avgVector64", new VectorMean64)
spark.sql("""
|select group_id, avgVector64(embedding) as avg_embedding
|from embedding_table_name
|group by group_id
""".stripMargin) // 当然不注册也可以用, 只是不能在SQL中用, 可以直接用来操作DataFrame
val avgVector64 = new VectorMean64
val df = spark.sql("select group_id, embedding from embedding_table_name")
df.groupBy("group_id").agg(avgVector64(col("embedding")))

参考

https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html

Spark UDAF实现举例 -- average pooling的更多相关文章

  1. 深度学习方法(十):卷积神经网络结构变化——Maxout Networks,Network In Network,Global Average Pooling

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 最近接下来几篇博文会回到神经网络结构 ...

  2. 深度拾遗(06) - 1X1卷积/global average pooling

    什么是1X1卷积 11的卷积就是对上一层的多个feature channels线性叠加,channel加权平均. 只不过这个组合系数恰好可以看成是一个11的卷积.这种表示的好处是,完全可以回到模型中其 ...

  3. Global Average Pooling Layers for Object Localization

    For image classification tasks, a common choice for convolutional neural network (CNN) architecture ...

  4. 深度学习基础系列(十)| Global Average Pooling是否可以替代全连接层?

    Global Average Pooling(简称GAP,全局池化层)技术最早提出是在这篇论文(第3.2节)中,被认为是可以替代全连接层的一种新技术.在keras发布的经典模型中,可以看到不少模型甚至 ...

  5. Network in Network(2013),1x1卷积与Global Average Pooling

    目录 写在前面 mlpconv layer实现 Global Average Pooling 网络结构 参考 博客:blog.shinelee.me | 博客园 | CSDN 写在前面 <Net ...

  6. spark UDAF

    感谢我的同事 李震给我讲解UDAF 网上找到的大部分都只有代码,但是缺少讲解,官网的的API有讲解,但是看不太明白.我还是自己记录一下吧,或许对其他人有帮助. 接下来以一个求几何平均数的例子来说明如何 ...

  7. 理解Spark SQL(三)—— Spark SQL程序举例

    上一篇说到,在Spark 2.x当中,实际上SQLContext和HiveContext是过时的,相反是采用SparkSession对象的sql函数来操作SQL语句的.使用这个函数执行SQL语句前需要 ...

  8. 自定义spark UDAF

    官网链接 样例代码: import java.util.ArrayList; import java.util.List; import org.apache.spark.sql.Dataset; i ...

  9. 转:Spark User Defined Aggregate Function (UDAF) using Java

    Sometimes the aggregate functions provided by Spark are not adequate, so Spark has a provision of ac ...

随机推荐

  1. Redis分布式锁—SETNX+Lua脚本实现篇

    前言 平时的工作中,由于生产环境中的项目是需要部署在多台服务器中的,所以经常会面临解决分布式场景下数据一致性的问题,那么就需要引入分布式锁来解决这一问题. 针对分布式锁的实现,目前比较常用的就如下几种 ...

  2. 冰河开源了全网首个完全开源的分布式全局有序序列号(分布式ID)框架!!

    写在前面 mykit-serial框架的设计参考了李艳鹏大佬开源的vesta框架,并彻底重构了vesta框架,借鉴了雪花算法(SnowFlake)的思想,并在此基础上进行了全面升级和优化.支持嵌入式( ...

  3. PyQt(Python+Qt)学习随笔:QLineEdit行编辑器功能详解

    专栏:Python基础教程目录 专栏:使用PyQt开发图形界面Python应用 专栏:PyQt入门学习 老猿Python博文目录 一.概述 QLineEdit部件是一个单行文本编辑器,支持撤消和重做. ...

  4. PHP代码审计分段讲解(9)

    22 弱类型整数大小比较绕过 <?php error_reporting(0); $flag = "flag{test}"; $temp = $_GET['password' ...

  5. PHP代码审计分段讲解(2)

    03 多重加密 源代码为: <?php include 'common.php'; $requset = array_merge($_GET, $_POST, $_SESSION, $_COOK ...

  6. SQL数据库优化的六种方法

    SQL命令因为语法简单.操作高效受到了很多用户的欢迎.但是,SQL命令的效率受到不同的数据库功能的限制,特别是在计算时间方面,再加上语言的高效率也不意味着优化会更容易,所以每个数据库都需要依据实际情况 ...

  7. python的数据缓存

    Python的数据缓存 python 的内置数据类型,数值型,字符串,列表,字典等都会有自己的对象缓存池, 这样做的好处是,避免了频繁的申请内存,释放内存,这样会极大的降低应用程序的运行速度,还会造成 ...

  8. 浏览器小程序(Browser Applet)闪亮登场

    2017 年 1 月 9 日,微信小程序横空出世.随后,支付宝小程序.今日头条小程序.百度智能小程序.360小程序等纷纷推出,自此国内软件功能扩展领域进入到了小程序时代,小程序为丰富其宿主软件的功能和 ...

  9. MySQL锁(一)全局锁:如何做全库的逻辑备份?

    数据库锁设计的初衷是处理并发问题,这也是数据库与文件系统的最大区别. 根据加锁的范围,MySQL里大致可以分为三种锁:全局锁.表锁和行锁.接下来我们会分三讲来介绍这三种锁,今天要讲的是全局锁. 全局锁 ...

  10. 【原创】WPF TreeView带连接线样式的优化(WinFrom风格)

    一.前言 之前查找WPF相关资料的时候,发现国外网站有一个TreeView控件的样式,是WinFrom风格的,样式如下,文章链接:https://www.codeproject.com/tips/67 ...