我有一个如下的数据框,我正在尝试获取用户 groupby 名称的 max(sum)。
+-----+-----------------------------+
|name |nt_set |
+-----+-----------------------------+
|Bob |[av:27.0, bcd:29.0, abc:25.0]|
|Alice|[abc:95.0, bcd:55.0] |
|Bob |[abc:95.0, bcd:70.0] |
|Alice|[abc:125.0, bcd:90.0] |
+-----+-----------------------------+
下面是我用来获取用户的 max(sum) 的 udf
val maxfunc = udf((arr: Array[String]) => {
val step1 = arr.map(x => (x.split(":", -1)(0), x.split(":", -1)(1))).groupBy(_._1).mapValues(arr => arr.map(_._2.toInt).sum).maxBy(_._2)
val result = step1._1 + ":" + step1._2
result})
当我运行 udf 时,它抛出以下错误
val c6 = c5.withColumn("max_nt", maxfunc(col("nt_set"))).show(false)
错误:无法执行用户定义的函数($anonfun$1: (array) =>string)
我如何以更好的方式实现这一点,因为我需要在更大的数据集中做到这一点
预期的结果是
expected result:
+-----+-----------------------------+
|name |max_nt |
+-----+-----------------------------+
|Bob |abc:120.0 |
|Alice|abc:220.0 |
+-----+-----------------------------+
您的核心逻辑maxfunc
正常工作,只是它应该处理 post-groupBy 数组列,它是一个嵌套Seq
集合:
val df = Seq(
("Bob", Seq("av:27.0", "bcd:29.0", "abc:25.0")),
("Alice", Seq("abc:95.0", "bcd:55.0")),
("Zack", Seq()),
("Bob", Seq("abc:50.0", null)),
("Bob", Seq("abc:95.0", "bcd:70.0")),
("Alice", Seq("abc:125.0", "bcd:90.0"))
).toDF("name", "nt_set")
import org.apache.spark.sql.functions._
val maxfunc = udf( (ss: Seq[Seq[String]]) => {
val groupedSeq: Map[String, Double] = ss.flatMap(identity).
collect{ case x if x != null => (x.split(":")(0), x.split(":")(1)) }.
groupBy(_._1).mapValues(_.map(_._2.toDouble).sum)
groupedSeq match {
case x if x == Map.empty[String, Double] => ("", -999.0)
case _ => groupedSeq.maxBy(_._2)
}
} )
df.groupBy("name").agg(collect_list("nt_set").as("arr_nt")).
withColumn("max_nt", maxfunc($"arr_nt")).
select($"name", $"max_nt._1".as("max_key"), $"max_nt._2".as("max_val")).
show
// +-----+-------+-------+
// | name|max_key|max_val|
// +-----+-------+-------+
// | Zack| | -999.0|
// | Bob| abc| 170.0|
// |Alice| abc| 220.0|
// +-----+-------+-------+
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句