Writing a Spark UDAF in Scala to return an Array type as output


I have a DataFrame as follows -

val myDF = Seq(
(1,"A",100),
(1,"E",300),
(1,"B",200),
(2,"A",200),
(2,"C",300),
(2,"D",100)
).toDF("id","channel","time")

myDF.show()

+---+-------+----+
| id|channel|time|
+---+-------+----+
|  1|      A| 100|
|  1|      E| 300|
|  1|      B| 200|
|  2|      A| 200|
|  2|      C| 300|
|  2|      D| 100|
+---+-------+----+

For each `id`, I want the channels sorted by `time` in ascending order. I want to implement a UDAF for this logic.

I want to call this UDAF like this -

scala> myDF.createOrReplaceTempView("myDF")
scala> spark.sql("""select id , myUDAF(id, channel, time) from myDF group by id""").show()

The output DataFrame should look like -

+---+-------+
| id|channel|
+---+-------+
|  1|[A,B,E]|
|  2|[D,A,C]|
+---+-------+

I am trying to write the UDAF but I am not able to complete it -

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._



class myUDAF extends UserDefinedAggregateFunction {

    // These are the input fields for the aggregate function
    override def inputSchema : org.apache.spark.sql.types.StructType =
        StructType(
            StructField("id" , IntegerType) ::
            StructField("channel", StringType) ::
            StructField("time", IntegerType) :: Nil
        )

    // These are the internal fields kept for computing the aggregate
    // output
    override def bufferSchema : StructType =
        StructType(
            StructField("Sequence", ArrayType(StringType)) :: Nil
        )

    // This is the output type of the aggregate function
    override def dataType : DataType = ArrayType(StringType)

    // no comments here
    override def deterministic : Boolean = true

    // initialize
    override def initialize(buffer: MutableAggregationBuffer) : Unit = {
        buffer(0) = Seq.empty[String]
    }

}

Please help.

3 Answers

  • 1

    Do it this way (no need to define your own UDF):

    import org.apache.spark.sql.functions._

    myDF.groupBy("id")
      .agg(sort_array(collect_list(  // NOTE: sorts based on the first element of the struct
             struct("time", "channel"))).as("stuff"))
      .select("id", "stuff.channel")
      .show(false)
    
    +---+---------+
    |id |channel  |
    +---+---------+
    |1  |[A, B, E]|
    |2  |[D, A, C]|
    +---+---------+
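
    Since the question asks for a SQL-callable form, the same aggregation can also be expressed in pure Spark SQL. The sketch below assumes Spark 2.4+ (for the higher-order `transform` function) and a temp view name of my own choosing:

```scala
// Register the DataFrame so it is visible to SQL (view name is illustrative)
myDF.createOrReplaceTempView("myDF")

// sort_array orders the array of structs by the first field (time),
// then transform keeps only the channel field of each struct
spark.sql("""
  select id,
         transform(sort_array(collect_list(struct(time, channel))),
                   x -> x.channel) as channel
  from myDF
  group by id
""").show(false)
```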
    
  • -2

    I would not write a UDAF for this. In my experience, UDAFs are rather slow, especially with complex types. I would use the collect_list + UDF approach:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions._

    val sortByTime = udf((rws:Seq[Row]) => rws.sortBy(_.getInt(0)).map(_.getString(1)))
    
    myDF
      .groupBy($"id")
      .agg(collect_list(struct($"time",$"channel")).as("channel"))
      .withColumn("channel", sortByTime($"channel"))
      .show()
    
    +---+---------+
    | id|  channel|
    +---+---------+
    |  1|[A, B, E]|
    |  2|[D, A, C]|
    +---+---------+
    
  • 3

    A simpler way without a UDF.

    import org.apache.spark.sql.functions._
    myDF.orderBy($"time".asc).groupBy($"id").agg(collect_list($"channel") as "channel").show()
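
    For reference, the UDAF from the question can also be finished. The following is one possible sketch, not the only way to do it: the class name `SortedChannels` and the two-parallel-array buffer layout are my own choices, targeting the Spark 2.x `UserDefinedAggregateFunction` API. Note the grouping key (`id`) does not need to be an input column.

```scala
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Collects (time, channel) pairs into two parallel arrays and sorts
// by time only once, at the end, in evaluate().
class SortedChannels extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField("channel", StringType) ::
    StructField("time", IntegerType) :: Nil)

  override def bufferSchema: StructType = StructType(
    StructField("times", ArrayType(IntegerType)) ::
    StructField("channels", ArrayType(StringType)) :: Nil)

  override def dataType: DataType = ArrayType(StringType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Int]
    buffer(1) = Seq.empty[String]
  }

  // Append one row's time and channel to the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer(0) = buffer.getSeq[Int](0) :+ input.getInt(1)
      buffer(1) = buffer.getSeq[String](1) :+ input.getString(0)
    }

  // Concatenate partial buffers from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Int](0) ++ buffer2.getSeq[Int](0)
    buffer1(1) = buffer1.getSeq[String](1) ++ buffer2.getSeq[String](1)
  }

  // Zip channels with their times, sort by time, keep only the channels
  override def evaluate(buffer: Row): Any =
    buffer.getSeq[String](1).zip(buffer.getSeq[Int](0)).sortBy(_._2).map(_._1)
}
```

    After `spark.udf.register("myUDAF", new SortedChannels)` it can be called from SQL as `myUDAF(channel, time)`.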
    
