首页 文章

pyspark在udf中使用数据框

提问于
浏览
2

我有两个数据帧 df1

+---+---+----------+
|  n|val| distances|
+---+---+----------+
|  1|  1|0.27308652|
|  2|  1|0.24969208|
|  3|  1|0.21314497|
+---+---+----------+

df2

+---+---+----------+
| x1| x2|         w|
+---+---+----------+
|  1|  2|0.03103427|
|  1|  4|0.19012526|
|  1| 10|0.26805446|
|  1|  8|0.26825935|
+---+---+----------+

我想在 df1 中添加一个名为 gamma 的新列,当 df1.n == df2.x1 OR df1.n == df2.x2 时,它将包含来自 df2w 值的总和

我尝试使用udf,但显然从不同的数据框中选择是行不通的,因为值应该在计算之前确定

gamma_udf = udf(lambda n: float(df2.filter("x1 = %d OR x2 = %d"%(n,n)).groupBy().sum('w').rdd.map(lambda x: x).collect()[0]), FloatType())
df1.withColumn('gamma1', gamma_udf('n'))

有没有办法用 joingroupby 而不使用循环?

1 回答

  • 1

    您无法在 udf 中引用DataFrame . 正如您所提到的,这个问题最好用 join 来解决 .

    IIUC,您正在寻找类似的东西:

    from pyspark.sql import Window
    import pyspark.sql.functions as F
    
    df1.alias("L").join(df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left")\
        .select("L.*", F.sum("w").over(Window.partitionBy("n")).alias("gamma"))\
        .distinct()\
        .show()
    #+---+---+----------+----------+
    #|  n|val| distances|     gamma|
    #+---+---+----------+----------+
    #|  1|  1|0.27308652|0.75747334|
    #|  3|  1|0.21314497|      null|
    #|  2|  1|0.24969208|0.03103427|
    #+---+---+----------+----------+
    

    或者如果您对 pyspark-sql 语法更熟悉,可以注册临时表并执行:

    df1.registerTempTable("df1")
    df2.registerTempTable("df2")
    
    sqlCtx.sql(
        "SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
        "FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
    ).show()
    #+---+---+----------+----------+
    #|  n|val| distances|     gamma|
    #+---+---+----------+----------+
    #|  1|  1|0.27308652|0.75747334|
    #|  3|  1|0.21314497|      null|
    #|  2|  1|0.24969208|0.03103427|
    #+---+---+----------+----------+
    

    Explanation

    在这两种情况下,我们正在做 df1df2 . 这将保留 df1 中的所有行,无论是否匹配 .

    join子句是您在问题中指定的条件 . 因此, df2x1x2 等于 n 的所有行都将被连接 .

    接下来选择左表中的所有行加上我们分组(分区依据) n 并将 w 的值相加 . 对于 n 的每个值,这将获得与连接条件匹配的所有行的总和 .

    最后,我们只返回不同的行以消除重复 .

相关问题