我有一些用PySpark编写的代码,我正忙着将它转换为Scala . 它一直很顺利,除了现在我在Scala中使用用户定义的函数 .
python
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import functions as F
spark = SparkSession.builder.master('local[*]').getOrCreate()
a = spark.sparkContext.parallelize([(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)]).toDF(["index"]).withColumn("a1", F.lit(1)).withColumn("a2", F.lit(2)).withColumn("a3", F.lit(3))
a = a.select("index", F.struct(*('a' + str(c) for c in range(1, 4))).alias('a'))
a.show()
def a_to_b(a):
# 1. check if a technical cure exists
b = {}
for i in range(1, 4):
b.update({'b' + str(i): a[i - 1] ** 2})
return b
a_to_b_udf = F.udf(lambda x: a_to_b(x), StructType(list(StructField("b" + str(x), IntegerType()) for x in range(1, 4))))
b = a.select("index", "a", a_to_b_udf(a.a).alias("b"))
b.show()
这会产生:
+-----+-------+
|index| a|
+-----+-------+
| 1|[1,2,3]|
| 2|[1,2,3]|
| 3|[1,2,3]|
| 4|[1,2,3]|
| 5|[1,2,3]|
| 6|[1,2,3]|
| 7|[1,2,3]|
| 8|[1,2,3]|
| 9|[1,2,3]|
| 10|[1,2,3]|
+-----+-------+
和
+-----+-------+-------+
|index| a| b|
+-----+-------+-------+
| 1|[1,2,3]|[1,4,9]|
| 2|[1,2,3]|[1,4,9]|
| 3|[1,2,3]|[1,4,9]|
| 4|[1,2,3]|[1,4,9]|
| 5|[1,2,3]|[1,4,9]|
| 6|[1,2,3]|[1,4,9]|
| 7|[1,2,3]|[1,4,9]|
| 8|[1,2,3]|[1,4,9]|
| 9|[1,2,3]|[1,4,9]|
| 10|[1,2,3]|[1,4,9]|
+-----+-------+-------+
斯卡拉
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
// can ignore if running on spark-shell
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.getOrCreate()
import spark.implicits._
var a = spark.sparkContext.parallelize(1 to 10).toDF("index").withColumn("a1", lit(1)).withColumn("a2", lit(2)).withColumn("a3", lit(3))
// convert a{x} to struct column
a = a.select($"index", struct((1 to 3).map {x => col("a" + x)}.toList:_*).alias("a"))
a.show()
// this is where I am struggling, I have tried supplying a schema, but still get errors
val f = udf((a: Column) => {
Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})
val b = a.select($"index", $"a", f($"a").alias("b"))
// throws the below error
b.show()
我可以show()第一个DataFrame,但是在尝试显示 b
时出现了转换错误 .
错误是:
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to org.apache.spark.sql.Column
at $line23.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:31)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
at org.apache.spark.scheduler.Task.run(Task.scala:85)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
我已经尝试为我的UDF设置架构,就像我在Python中所做的那样,但我仍然得到相同的错误 .
有谁知道如何解决这个问题?我的例子很简单,但在返回结构之前,我需要对UDF做的是很多转换 .
1 回答
我觉得很傻,因为自从周五下午以来我一直在努力 .
从Spark Sql UDF with complex input parameter,
我的问题是我提供给我的函数的
Column
类型 .我应该使用
Row
代替 .