0%

自定义函数

spark自定义函数详解:UDF 与 UDAF 实战指南

在 Spark SQL 中,内置函数(如 sumconcat)可满足多数基础需求,但复杂业务场景往往需要自定义逻辑。Spark 支持两种核心自定义函数:用户定义函数(UDF)用户定义聚合函数(UDAF)。本文将深入解析两者的实现原理、使用场景及实战案例,帮助你灵活扩展 Spark 的数据处理能力。

自定义函数概述

为什么需要自定义函数?

  • 业务个性化:内置函数无法覆盖复杂业务逻辑(如特殊格式转换、自定义指标计算);
  • 代码复用:将常用逻辑封装为函数,避免重复开发;
  • 简化 SQL:用自定义函数替代复杂 SQL 子查询,提升可读性。

自定义函数的类型

  • UDF(User-Defined Function)一行输入 → 一行输出,如字符串拼接、格式转换;
  • UDAF(User-Defined Aggregate Function)多行输入 → 一行输出,如自定义平均值、加权求和。

UDF:用户定义函数(单行处理)

UDF 是最常用的自定义函数类型,用于对单条记录的字段进行转换或计算。其核心是接收一个或多个输入值,返回单个输出值。

UDF 的定义与注册

Scala 实现 UDF

通过 spark.udf.register 方法注册 UDF,支持匿名函数或方法引用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import org.apache.spark.sql.SparkSession  

val spark = SparkSession.builder()
.appName("UDFExample")
.master("local[*]")
.getOrCreate()

// 1. 定义 UDF:给字符串添加前缀
val addPrefix = (name: String, prefix: String) => s"$prefix:$name"

// 2. 注册 UDF(函数名、函数逻辑、返回值类型可选)
spark.udf.register("add_prefix", addPrefix)

// 3. 在 SQL 中使用
spark.sql("CREATE TEMPORARY VIEW users AS SELECT 'Alice' AS name UNION SELECT 'Bob'")
spark.sql("SELECT add_prefix(name, 'user') AS username FROM users").show()
// +----------+
// | username|
// +----------+
// |user:Alice|
// | user:Bob|
// +----------+
UDF 的输入输出类型

UDF 支持多种输入类型(如 StringIntDouble)和复杂类型(如 ArrayMap),需注意类型匹配:

1
2
3
// 示例:计算数组元素之和  
spark.udf.register("array_sum", (arr: Array[Int]) => arr.sum)
spark.sql("SELECT array_sum(array(1,2,3)) AS sum").show() // 输出 6

UDF 的使用场景

  • 数据清洗:格式转换(如日期格式化、大小写转换);
  • 特征工程:提取字段特征(如从 URL 中提取域名);
  • 业务计算:自定义评分、规则判断等。

注意事项

  • 类型安全:UDF 输入输出类型需明确,避免 NullPointerException(可使用 Option 处理空值);
  • 性能优化:复杂 UDF 可能成为瓶颈,建议优先使用内置函数,必要时通过代码生成(Code Generation)优化;
  • 注册作用域spark.udf.register 注册的 UDF 仅在当前 SparkSession 有效。

UDAF:用户定义聚合函数(多行聚合)

UDAF 用于对多行数据进行聚合计算,返回单个结果,如自定义平均值、中位数、TopN 统计等。与 UDF 不同,UDAF 需要处理中间状态(缓冲区)的初始化、更新和合并。

UDAF 的实现原理

UDAF 的核心是通过 Aggregator 抽象类定义聚合逻辑,包含以下关键步骤:

  1. 初始化缓冲区:定义聚合的初始状态;
  2. 更新缓冲区:将每行数据合并到缓冲区;
  3. 合并缓冲区:在分布式场景下合并多个分区的缓冲区;
  4. 计算结果:从最终缓冲区中提取聚合结果。

基于 Aggregator 的 UDAF 实现

示例:自定义平均年龄计算器

需求:计算用户年龄的平均值,支持分布式场景下的分区聚合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import org.apache.spark.sql.{Encoder, Encoders}  
import org.apache.spark.sql.expressions.Aggregator

// 1. 定义缓冲区样例类(存储中间状态:总年龄、总人数)
case class AgeBuffer(totalAge: Long, count: Long)

// 2. 实现 Aggregator[输入类型, 缓冲区类型, 输出类型]
class AvgAgeAggregator extends Aggregator[Long, AgeBuffer, Double] {
// 初始化缓冲区
override def zero: AgeBuffer = AgeBuffer(0L, 0L)

// 单个分区内更新缓冲区(每行数据触发)
override def reduce(buffer: AgeBuffer, age: Long): AgeBuffer = {
AgeBuffer(buffer.totalAge + age, buffer.count + 1)
}

// 合并多个分区的缓冲区(分布式场景)
override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
AgeBuffer(b1.totalAge + b2.totalAge, b1.count + b2.count)
}

// 从缓冲区计算最终结果
override def finish(buffer: AgeBuffer): Double = {
if (buffer.count == 0) 0.0 else buffer.totalAge.toDouble / buffer.count
}

// 缓冲区的编码器(用于序列化)
override def bufferEncoder: Encoder[AgeBuffer] = Encoders.product

// 输出结果的编码器
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

UDAF 的注册与使用

在 SQL 中使用 UDAF

需将 Aggregator 转换为 SQL 兼容的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 1. 注册 UDAF 为 SQL 函数  
spark.udf.register("avg_age", functions.udaf(new AvgAgeAggregator()))

// 2. 准备测试数据
spark.sql("CREATE TEMPORARY VIEW user_ages AS SELECT 18 AS age UNION SELECT 20 UNION SELECT 22")

// 3. 执行聚合查询
spark.sql("SELECT avg_age(age) AS average_age FROM user_ages").show()
// +-----------+
// |average_age|
// +-----------+
// | 20.0|
// +-----------+
在 DataSet 中使用 UDAF

直接通过 toColumn 方法将 UDAF 作为列函数使用:

1
2
3
4
5
6
7
8
import spark.implicits._  

// 1. 创建 DataSet[Long]
val ageDS = Seq(18L, 20L, 22L).toDS()

// 2. 使用 UDAF
val avgAgeColumn = new AvgAgeAggregator().toColumn.name("average_age")
ageDS.select(avgAgeColumn).show() // 输出同上

复杂 UDAF 示例:加权平均值

需求:计算带权重的平均值(如 sum(age * weight) / sum(weight))。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// 输入类型:(年龄, 权重)  
case class WeightedAge(age: Long, weight: Double)
// 缓冲区:(总年龄×权重, 总权重)
case class WeightedBuffer(sumAgeWeight: Double, sumWeight: Double)

class WeightedAvgAggregator extends Aggregator[WeightedAge, WeightedBuffer, Double] {
override def zero: WeightedBuffer = WeightedBuffer(0.0, 0.0)

override def reduce(buffer: WeightedBuffer, input: WeightedAge): WeightedBuffer = {
WeightedBuffer(
buffer.sumAgeWeight + input.age * input.weight,
buffer.sumWeight + input.weight
)
}

override def merge(b1: WeightedBuffer, b2: WeightedBuffer): WeightedBuffer = {
WeightedBuffer(
b1.sumAgeWeight + b2.sumAgeWeight,
b1.sumWeight + b2.sumWeight
)
}

override def finish(buffer: WeightedBuffer): Double = {
if (buffer.sumWeight == 0) 0.0 else buffer.sumAgeWeight / buffer.sumWeight
}

override def bufferEncoder: Encoder[WeightedBuffer] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// 使用示例
val data = Seq(WeightedAge(18, 0.5), WeightedAge(20, 0.5)).toDS()
data.select(new WeightedAvgAggregator().toColumn.name("weighted_avg")).show() // 输出 19.0

UDAF 的使用场景

  • 自定义统计指标:如中位数、分位数、增长率等;
  • 业务聚合逻辑:如用户活跃度评分、订单金额加权求和;
  • 替代复杂 SQL 聚合:用 UDAF 简化多步骤聚合逻辑。

UDF 与 UDAF 的对比与选择

特性 UDF(用户定义函数) UDAF(用户定义聚合函数)
输入输出 单行输入 → 单行输出 多行输入 → 单行输出
核心接口 直接注册函数(spark.udf.register 继承 Aggregator 抽象类
适用场景 字段转换、格式处理 统计聚合、多记录计算
实现复杂度 简单(函数逻辑) 较复杂(需处理缓冲区和合并)
性能影响 低(单行处理) 中高(需维护中间状态)

选择建议

  • 若需对单条记录进行转换,用 UDF;
  • 若需对多条记录进行聚合统计,用 UDAF;
  • 优先使用内置函数,自定义函数仅在必要时使用(避免性能损耗)。

自定义函数的性能优化

  1. 避免复杂逻辑:UDF/UDAF 应简洁,复杂逻辑可拆分为多个步骤;
  2. 处理空值:用 Optioncoalesce 函数处理 null,避免空指针异常;
  3. 使用编码器优化:UDAF 中通过 Encoders.product 或自定义编码器提升序列化性能;
  4. 测试与 benchmark:对比自定义函数与内置函数的性能,确保优化效果。

欢迎关注我的其它发布渠道

表情 | 预览
快来做第一个评论的人吧~
Powered By Valine
v1.3.10