
Spark SQL (v2.0) UDAF in Scala returns an empty string


While working toward a UDAF for a more complex problem of ours, I decided to start with a basic UDAF that simply returns its input column unchanged. Since I am new to Spark SQL and Scala, could someone help me spot my mistake?

Here is the code:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

object MinhashUdaf extends UserDefinedAggregateFunction {

  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", StringType) :: Nil)

  override def bufferSchema: StructType =
    StructType(StructField("shingles", StringType) :: Nil)

  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = "" }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, input.toString())
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

  override def evaluate(buffer: Row): Any = { buffer(0) }
}

Here is the code to run the UDAF above:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

def main(args: Array[String]): Unit = {
  val spark: SparkSession =
    SparkSession.builder.master("local[*]").appName("test").getOrCreate()
  import spark.implicits._

  val df = spark.read.json("people.json")
  df.createOrReplaceTempView("people")
  val sqlDF = spark.sql("SELECT name FROM people")
  sqlDF.show()

  val minhash = df.select(MinhashUdaf(col("name")).as("minhash"))
  minhash.printSchema()
  minhash.show(truncate = false)
}

Since the UDAF returns its input as-is, I expected to get the value of the "name" column for each row. Instead, running the above returns an empty string.

1 Answer


    You did not implement the merge function. Because merge is empty, the final buffer never receives the values accumulated per partition, so evaluate returns the empty string set in initialize.

    With the code below, you can print the column's values as required.

    object MinhashUdaf extends UserDefinedAggregateFunction {

    override def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

    override def bufferSchema: StructType = StructType(StructField("shingles", StringType) :: Nil)

    override def dataType: DataType = StringType

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = "" }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { buffer.update(0, input.get(0)) }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1.update(0, buffer2.get(0)) }

    override def evaluate(buffer: Row): Any = { buffer(0) }
    }
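    To see why merge matters, the UDAF lifecycle can be sketched in plain Scala with no Spark dependency. Spark calls initialize for each partition's buffer, update once per input row, merge to combine partition buffers, and evaluate to produce the result; with an empty merge, the final buffer keeps the value set by initialize, i.e. the empty string. The object and method names below are illustrative only, not part of the Spark API:

    ```scala
    // Minimal sketch of the UDAF call sequence (hypothetical names, no Spark).
    object UdafLifecycleSketch {
      type Buffer = Array[String]

      def initialize(): Buffer = Array("")                      // start empty, as in the UDAF
      def update(buf: Buffer, input: String): Unit = buf(0) = input
      def merge(b1: Buffer, b2: Buffer): Unit = b1(0) = b2(0)   // last-writer-wins, like the answer's merge
      def evaluate(buf: Buffer): String = buf(0)

      // Simulate aggregating rows spread across several partitions.
      def aggregate(partitions: Seq[Seq[String]]): String = {
        val merged = initialize()
        for (part <- partitions) {
          val buf = initialize()
          part.foreach(row => update(buf, row))
          merge(merged, buf)                                    // skip this line and the result stays ""
        }
        evaluate(merged)
      }

      def main(args: Array[String]): Unit =
        println(aggregate(Seq(Seq("Michael", "Andy"), Seq("Justin")))) // prints "Justin"
    }
    ```

    If merge were left as `{}`, `merged` would never change and `evaluate` would return `""`, which is exactly the symptom in the question.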
    
