Spark SQL ships with a rich set of built-in functions, but real business scenarios can be complex and the built-in functions may not cover every requirement, so Spark SQL also lets developers register their own user-defined functions.

UDF: an ordinary row-level function that takes one or more arguments and returns a single value, e.g. len(), isnull().

UDAF: an aggregate function that takes a group of values and returns one aggregated result, e.g. max(), avg(), sum().
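To make the difference concrete, here is a minimal, hedged sketch: it assumes a temp view user(id, name, achieve) and the strLength UDF and avg_format UDAF that are registered later in this post. A UDF produces one output per input row, while a UDAF collapses a whole group of rows into a single value.

// Row-level UDF: one result per row.
sparkSession.sql("select id, strLength(name) as length from user").show();
// UDAF: one result per group of rows.
sparkSession.sql("select name, avg_format(achieve) as avg_achieve from user group by name").show();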

Writing a UDF in Spark

The example below uses the pre-2.0 entry point (SQLContext) and shows a UDF with a single input argument and a single return value.

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        JavaRDD<String> javaRDD = javaSparkContext.parallelize(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Choose UDF1, UDF2, ... up to UDF22 according to the number of input arguments of the UDF.
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();

        javaSparkContext.stop();
    }
}

Output:

+---+--------+------+
| id| name|length|
+---+--------+------+
| 1|zhangsan| 8|
| 2| lisi| 4|
| 3| wangwu| 6|
| 4| zhaoliu| 7|
+---+--------+------+
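The registered UDF is not limited to SQL strings; it can also be called through the DataFrame API. A minimal sketch, assuming the ds Dataset and the strLength registration from the listing above, plus a static import of org.apache.spark.sql.functions.*:

// Same UDF invoked via the DataFrame API instead of a SQL string.
Dataset<Row> viaApi = ds.select(col("id"), col("name"), callUDF("strLength", col("name")).as("length"));
viaApi.show();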

The example above showed a UDF with a single input and a single output. The next example uses the Spark 2.0 API (SparkSession) and additionally implements a UDF with three inputs and one output.

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Choose UDF1, UDF2, ... up to UDF22 according to the number of input arguments of the UDF.
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();
        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}

Person.java

package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable {
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public static Person parsePerson(String line) {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}

Note that a UDF is registered on the session globally: it only needs to be registered once and can then be referenced by any number of subsequent queries.
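Since UDF1, UDF2, ... are single-method interfaces, the registration can also be written as a Java 8 lambda. A minimal sketch under that assumption, using the sparkSession and the user temp view from the example above (the name strLength2 is only illustrative):

// Registered once on the SparkSession, the function is visible to every query that follows.
sparkSession.udf().register("strLength2", (UDF1<String, Integer>) s -> s.length(), DataTypes.IntegerType);
sparkSession.sql("select id, strLength2(name) as length from user").show();
sparkSession.sql("select name, strLength2(name) as length from user").show();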

Output:

+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+

+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1-zhangsan|
| 2| lisi| 4| 2-lisi|
| 3| wangwu| 6| 3-wangwu|
| 4| zhaoliu| 7| 4-zhaoliu|
+---+--------+------+----------+

==========================================

+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+

+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1+zhangsan|
| 2| lisi| 4| 2+lisi|
| 3| wangwu| 6| 3+wangwu|
| 4| zhaoliu| 7| 4+zhaoliu|
+---+--------+------+----------+

If you work through the two examples above carefully, that is essentially all there is to writing and using UDFs.

Writing a UDAF in Spark

A custom aggregate function must extend UserDefinedAggregateFunction. The definition of this abstract class, with the key members annotated, is shown below:

package org.apache.spark.sql.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
  // Defines the data types of the input arguments.
  def inputSchema: StructType

  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
  // Defines the data types of the intermediate values held in the aggregation buffer.
  def bufferSchema: StructType

  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
  // Defines the data type of the final aggregation result.
  def dataType: DataType

  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
  // Determinism flag: if true, the same input always yields the same output.
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  // Sets the initial value of the aggregation buffer. The contract: merging two initial buffers
  // must again produce the initial buffer. For example, if the initial value were 1 and merge added
  // the two buffers, merging two initial buffers would give 2, which breaks the contract; that is
  // why the initial value is also called the "zero value".
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
  // Updates the buffer with one input row; conceptually similar to combineByKey.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
  // Merges two buffers by folding buffer2 into buffer1; used when combining the partial aggregates
  // of different partitions, conceptually similar to reduceByKey.
  // Note that this method has no return value: the merged result must be written back into buffer1.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
  // Computes and returns the final aggregation result.
  def evaluate(buffer: Row): Any

  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
  // Aggregates over all input values.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
  // Aggregates over the distinct input values only.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * :: Experimental ::
 * A [[Row]] representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}

An aggregate function that computes the average of a single column:

package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType().add("myinput", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the buffer's initial ("zero") value; merging two initial buffers must again yield the initial buffer.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // slot 0 holds mycnt, a Long zero
        buffer.update(1, 0d); // slot 1 holds mysum, a Double zero
    }

    /**
     * Combine within a partition.
     */
    // Updates the buffer with one input row, similar to combineByKey.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1);                    // row count + 1
        buffer.update(1, buffer.getDouble(1) + input.getDouble(0)); // accumulate the input value
    }

    /**
     * Merge across partitions (MutableAggregationBuffer extends Row).
     */
    // Merges buffer2 into buffer1, similar to reduceByKey.
    // Note there is no return value: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge running sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // Compute the average, formatted to two decimal places.
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
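To see how the four callbacks interact, take zhangsan's four scores from the test data below (80, 87, 88, 96) and suppose they land in two partitions: each partition starts from the zero buffer (0, 0.0), update turns them into (2, 167.0) and (2, 184.0), merge folds the two buffers into (4, 351.0), and evaluate returns 351.0 / 4 = 87.75, which is exactly the value that appears in the output further down.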

The following shows how to use the custom UDAF:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80",
                "2,zhangsan,History,87",
                "3,zhangsan,Chinese,88",
                "4,zhangsan,Chemistry,96",
                "5,lisi,English,70",
                "6,lisi,Chinese,74",
                "7,lisi,History,75",
                "8,lisi,Chemistry,77",
                "9,lisi,Physics,79",
                "10,lisi,Biology,82",
                "11,wangwu,English,96",
                "12,wangwu,Chinese,98",
                "13,wangwu,History,91",
                "14,zhaoliu,English,68",
                "15,zhaoliu,Chinese,66"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve = Double.parseDouble(fields[3]);
                return RowFactory.create(id, name, subject, achieve);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new SimpleAvg();
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
        rows2.show();
    }
}

Output:

+---+--------+---------+-------+
| id| name| subject|achieve|
+---+--------+---------+-------+
| 1|zhangsan| English| 80.0|
| 2|zhangsan| History| 87.0|
| 3|zhangsan| Chinese| 88.0|
| 4|zhangsan|Chemistry| 96.0|
| 5| lisi| English| 70.0|
| 6| lisi| Chinese| 74.0|
| 7| lisi| History| 75.0|
| 8| lisi|Chemistry| 77.0|
| 9| lisi| Physics| 79.0|
| 10| lisi| Biology| 82.0|
| 11| wangwu| English| 96.0|
| 12| wangwu| Chinese| 98.0|
| 13| wangwu| History| 91.0|
| 14| zhaoliu| English| 68.0|
| 15| zhaoliu| Chinese| 66.0|
+---+--------+---------+-------+

+--------+-----------------+
| name| avg_achieve|
+--------+-----------------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi|76.16666666666667|
+--------+-----------------+

+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi| 76.17|
+--------+-----------+
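The UDAF does not have to be called through SQL: the apply(Column...) method shown in the abstract class earlier turns it into a Column for the DataFrame API. A minimal sketch, assuming the ds Dataset and the udaf instance (SimpleAvg) from the listing above plus a static import of org.apache.spark.sql.functions.col:

// Equivalent to: select name, avg_format(achieve) as avg_achieve from user group by name
Dataset<Row> viaApi = ds.groupBy(col("name")).agg(udaf.apply(col("achieve")).as("avg_achieve"));
viaApi.show();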

A UDAF that sums multiple columns within each row and then averages those per-row sums:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilAvg(2);
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
        rows2.show();
    }
}

The code above builds a Dataset with the columns id, name, subject, achieve1 and achieve2; MutilAvg sums the specified columns within each row and then averages those per-row sums across each group.

MutilAvg.java (the UDAF):

package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutilAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;

    public MutilAvg(int columnSize) {
        this.columnSize = columnSize;
    }

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType();
        for (int i = 0; i < columnSize; i++) {
            // StructType is immutable: add(...) returns a new StructType, so reassign the result.
            structType = structType.add("myinput" + i, DataTypes.DoubleType);
        }
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the buffer's initial ("zero") value; merging two initial buffers must again yield the initial buffer.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // slot 0 holds mycnt, a Long zero
        buffer.update(1, 0d); // slot 1 holds mysum, a Double zero
    }

    /**
     * Combine within a partition.
     */
    // Updates the buffer with one input row, similar to combineByKey.
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1); // row count + 1

        // One input row carries several columns, so first sum the columns of the same row.
        Double currentLineSumValue = 0d;
        for (int i = 0; i < columnSize; i++) {
            currentLineSumValue += input.getDouble(i);
        }
        buffer.update(1, buffer.getDouble(1) + currentLineSumValue); // accumulate the per-row sum
    }

    /**
     * Merge across partitions (MutableAggregationBuffer extends Row).
     */
    // Merges buffer2 into buffer1, similar to reduceByKey.
    // Note there is no return value: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge row counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge running sums
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // Compute the average, formatted to two decimal places.
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}

Output:

+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+

+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+

+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+

A UDAF that computes the maximum of each column separately and then returns the largest of those per-column maxima:

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89",
                "2,zhangsan,History,87,88",
                "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95",
                "5,lisi,English,70,75",
                "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80",
                "8,lisi,Chemistry,77,70",
                "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83",
                "11,wangwu,English,96,84",
                "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92",
                "14,zhaoliu,English,68,80",
                "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilMax(2, 0);
        sparkSession.udf().register("max_vals", udaf);

        Dataset<Row> rows1 = sparkSession.sql(""
                + "select name,max(achieve) as max_achieve "
                + "from "
                + "("
                + "select name,max(achieve1) achieve from user group by name "
                + "union all "
                + "select name,max(achieve2) achieve from user group by name "
                + ") t10 "
                + "group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
        rows2.show();
    }
}

The code above builds a Dataset with the columns id, name, subject, achieve1 and achieve2; MutilMax first computes the maximum of each specified column within the group and then returns the largest of those per-column maxima as the result.

MutilMax.java (the UDAF):

package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MutilMax extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;
    private Double defaultValue;

    public MutilMax(int columnSize, double defaultValue) {
        this.columnSize = columnSize;
        this.defaultValue = defaultValue;
    }

    @Override
    public StructType inputSchema() {
        List<StructField> inputFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
        }
        StructType inputSchema = DataTypes.createStructType(inputFields);
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        List<StructField> bufferFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
        }
        StructType bufferSchema = DataTypes.createStructType(bufferFields);
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    // Sets the buffer's initial ("zero") value; merging two initial buffers must again yield the initial buffer.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        for (int i = 0; i < this.columnSize; i++) {
            buffer.update(i, 0d);
        }
    }

    /**
     * Combine within a partition.
     */
    // Updates the buffer with one input row, keeping the per-column maximum seen so far (similar to combineByKey).
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > input.getDouble(i)) {
                buffer.update(i, buffer.getDouble(i));
            } else {
                buffer.update(i, input.getDouble(i));
            }
        }
    }

    /**
     * Merge across partitions (MutableAggregationBuffer extends Row).
     */
    // Merges buffer2 into buffer1, keeping the larger value per column (similar to reduceByKey).
    // Note there is no return value: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer1.getDouble(i) > buffer2.getDouble(i)) {
                buffer1.update(i, buffer1.getDouble(i));
            } else {
                buffer1.update(i, buffer2.getDouble(i));
            }
        }
    }

    // Computes and returns the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        // Find the largest of the per-column maxima; fall back to defaultValue if nothing larger was found.
        Double max = Double.MIN_VALUE;
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > max) {
                max = buffer.getDouble(i);
            }
        }
        if (max == Double.MIN_VALUE) {
            max = this.defaultValue;
        }
        return max;
    }
}

Output:

+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+

+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+

+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
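For this particular case the custom UDAF is not strictly necessary: Spark SQL's built-in greatest() function computes a per-row maximum across columns, so the same result can be obtained without the union query. A minimal alternative sketch, assuming the same user temp view:

// Per-row maximum across the two columns, then the usual max() per group.
sparkSession.sql("select name, max(greatest(achieve1, achieve2)) as max_achieve from user group by name").show();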

Writing an Aggregator (Agg) function in Spark

Implementing an avg function:

Step 1: define an Average class that holds the count and the sum;

import java.io.Serializable;

public class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }

    public Average() {
    }

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }
}

Step 2: define an Employee class that holds the employee's name and salary;

import java.io.Serializable;

public class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getSalary() {
        return salary;
    }

    public void setSalary(long salary) {
        this.salary = salary;
    }

    public Employee() {
    }

    public Employee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
}

Step 3: define the Aggregator that averages the employees' salaries;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

Step 4: call the aggregator from Spark and verify the result.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;

public class SparkClient {
    public static void main(String[] args) {
        final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<Employee> employeeList = new ArrayList<Employee>();
        employeeList.add(new Employee("Michael", 3000L));
        employeeList.add(new Employee("Andy", 4500L));
        employeeList.add(new Employee("Justin", 3500L));
        employeeList.add(new Employee("Berta", 4000L));

        JavaRDD<Employee> rows = ctx.parallelize(employeeList);
        Dataset<Employee> ds = spark.createDataFrame(rows, Employee.class).map(new MapFunction<Row, Employee>() {
            @Override
            public Employee call(Row row) throws Exception {
                return new Employee(row.getString(0), row.getLong(1));
            }
        }, Encoders.bean(Employee.class));
        ds.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // +-------+------+

        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // |        3750.0|
        // +--------------+
    }
}

Output:

+-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+

+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
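The example above computes a single global average over the whole Dataset. The same TypedColumn also works per group through the typed API; a minimal sketch, assuming the ds Dataset and averageSalary column from the listing above (scala.Tuple2 needs an extra import, KeyValueGroupedDataset is already covered by org.apache.spark.sql.*):

// Group the typed Dataset by employee name and apply the typed aggregator within each group.
KeyValueGroupedDataset<String, Employee> byName =
        ds.groupByKey((MapFunction<Employee, String>) Employee::getName, Encoders.STRING());
Dataset<Tuple2<String, Double>> avgByName = byName.agg(averageSalary);
avgByName.show();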

