Creating combinations with a UDF in PySpark

This is probably a basic question, but I've been stuck on it for a while now.

I have several column names, and I'm trying to create a list of combinations that pairs up two elements at a time in Spark. Here is the list I'm trying to build the combinations from:

numeric_cols = ["age", "hours-per-week", "fnlwgt"]

I'm using combinations from the itertools module:

from itertools import combinations
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType

def combinations2(x):
    return combinations(x, 2)

udf_combinations2 = udf(combinations2, ArrayType())
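
For reference, calling combinations directly in plain Python gives exactly the pairs I'm after:

list(combinations(numeric_cols, 2))
# [('age', 'hours-per-week'), ('age', 'fnlwgt'), ('hours-per-week', 'fnlwgt')]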

But when I run the line

pairs = udf_combinations2(numeric_cols)

I get the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sg/Downloads/spark/python/pyspark/sql/udf.py", line 179, in wrapper
    return self(*args)
  File "/Users/sg/Downloads/spark/python/pyspark/sql/udf.py", line 159, in __call__
    return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
  File "/Users/sg/Downloads/spark/python/pyspark/sql/column.py", line 66, in _to_seq
    cols = [converter(c) for c in cols]
  File "/Users/sg/Downloads/spark/python/pyspark/sql/column.py", line 66, in <listcomp>
    cols = [converter(c) for c in cols]
  File "/Users/sg/Downloads/spark/python/pyspark/sql/column.py", line 54, in _to_java_column
    "function.".format(col, type(col)))
TypeError: Invalid argument, not a string or column: ['age', 'hours-per-week', 'fnlwgt'] of type <class 'list'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

I don't know how to use the functions mentioned in the last line. Any direction or hints would be great.

Thanks

1 Answer


    First, define the udf properly. Given some example data:

    df = spark.createDataFrame([(1, 2, 3)], ("age", "hours-per-week", "fnlwgt"))
    

    you can define it with a single argument:

    @udf("array<struct<_1: double, _2: double>>")
    def combinations_list(x):
        return combinations(x, 2)
    

    or with varargs:

    @udf("array<struct<_1: double, _2: double>>")
    def combinations_varargs(*x):
        return combinations(list(x), 2)
    

    In both cases you have to declare the type of the output array. Here we use an array of structs of doubles.
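
    For reference, the DDL string used above is equivalent to building the return type explicitly from pyspark.sql.types (a minimal sketch; the pair_type name is just illustrative):

    from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType

    # Same type as "array<struct<_1: double, _2: double>>", spelled out
    pair_type = ArrayType(StructType([
        StructField("_1", DoubleType()),
        StructField("_2", DoubleType()),
    ]))

    Either form is accepted as the returnType of udf.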

    Make sure the input types match the declared output type:

    from pyspark.sql.functions import col
    
    numeric_cols = [
        col(c).cast("double") for c in ["age", "hours-per-week", "fnlwgt"]
    ]
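
    As a quick sanity check (optional; the exact printed column names can vary slightly across Spark versions, but the types should come back as double):

    df.select(*numeric_cols).printSchema()
    # root
    #  |-- age: double (nullable = true)
    #  |-- hours-per-week: double (nullable = true)
    #  |-- fnlwgt: double (nullable = true)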
    

    To call the single-argument version, use array:

    from pyspark.sql.functions import array
    
    df.select(
         combinations_list(array(*numeric_cols)).alias("combinations")
    ).show(truncate=False)
    # +---------------------------------+
    # |combinations                     |
    # +---------------------------------+
    # |[[1.0,2.0], [1.0,3.0], [2.0,3.0]]|
    # +---------------------------------+
    

    To call the varargs variant, unpack the values:

    df.select(
         combinations_varargs(*numeric_cols).alias("combinations")
    ).show(truncate=False)
    # +---------------------------------+
    # |combinations                     |
    # +---------------------------------+
    # |[[1.0,2.0], [1.0,3.0], [2.0,3.0]]|
    # +---------------------------------+
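
    If you'd rather have one pair per row than one array of pairs per row, you can explode the result afterwards (a sketch building on the udfs above):

    from pyspark.sql.functions import explode

    df.select(
        explode(combinations_varargs(*numeric_cols)).alias("pair")
    ).show(truncate=False)
    # +---------+
    # |pair     |
    # +---------+
    # |[1.0,2.0]|
    # |[1.0,3.0]|
    # |[2.0,3.0]|
    # +---------+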
    
