
PySpark - Calling a subsetting function inside a UDF

I need to find the neighbours of a particular data point in a PySpark DataFrame.

a = spark.createDataFrame([("A", [0,1]), ("B", [5,9]), ("D", [13,5])], ["Letter", "distances"])

I wrote this function, which takes a DataFrame (DB) and finds the data points closest to a fixed point (Q) using Euclidean distance. It filters the data points against some epsilon value (eps) and returns the subset.

from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from scipy.spatial import distance

def rangequery(DB, Q, eps):
    distance_udf = F.udf(lambda x: float(distance.euclidean(x, Q)), FloatType())
    df_neigh = DB.withColumn('euclid_distances', distance_udf(F.col('distances')))
    return df_neigh.filter(df_neigh['euclid_distances'] <= eps)
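
A quick usage sketch (the query point [5, 5] below is an illustrative value, not from the original post):

# Hypothetical call: show every point in `a` within eps = 9 of the point [5, 5]
rangequery(a, [5, 5], 9).show()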

But now I need to run this function for every point in the DataFrame.

So I did the following.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def check_neighbours(distances):
    df = rangequery(a, distances, 9)
    if df.count() >= 1:
        return "Has Neighbours"
    else:
        return "No Neighbours"

udf_neigh = udf(check_neighbours, StringType())
a.withColumn("label", udf_neigh(a["distances"])).show()

I get the following error when I try to run this code.

PicklingError: Could not serialize object: Py4JError: An error occurred while calling o380.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    at py4j.Gateway.invoke(Gateway.java:272)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:745)

1 Answer


    Borrowing from this answer, here is one way to do it. Consider the following example:

    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import FloatType

    # create dummy dataset
    DB = spark.createDataFrame(
        [("A", [0,1]), ("B", [5,9]), ("D", [13,5])],
        ["Letter", "distances"]
    )

    # Define your distance metric as a udf
    from scipy.spatial import distance
    distance_udf = udf(lambda x, y: float(distance.euclidean(x, y)), FloatType())

    # Use crossJoin() to compute pairwise distances
    eps = 9  # distance threshold
    DB.alias("l")\
        .crossJoin(DB.alias("r"))\
        .where(distance_udf(col("l.distances"), col("r.distances")) < eps)\
        .groupBy("l.letter", "l.distances")\
        .count()\
        .withColumn("count", col("count") - 1)\
        .withColumn("label", udf(lambda x: "Has Neighbours" if x >= 1 else "No Neighbours")(col("count")))\
        .sort('letter')\
        .show()
    

    Output:

    +------+---------+-----+--------------+
    |letter|distances|count|         label|
    +------+---------+-----+--------------+
    |     A|   [0, 1]|    0| No Neighbours|
    |     B|   [5, 9]|    1|Has Neighbours|
    |     D|  [13, 5]|    1|Has Neighbours|
    +------+---------+-----+--------------+
    

    The step .withColumn("count", col("count") - 1) is there because we know each row will match itself as a trivial neighbour. (You can remove this line if you prefer.)
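
    As a small variant (a sketch, assuming the Letter values are unique row identifiers), you can instead exclude the self-match in the join condition and skip the subtraction:

    # Alternative sketch: filter out self-pairs in where() instead of count - 1
    DB.alias("l")\
        .crossJoin(DB.alias("r"))\
        .where(
            (col("l.Letter") != col("r.Letter"))
            & (distance_udf(col("l.distances"), col("r.distances")) < eps)
        )\
        .groupBy("l.letter", "l.distances")\
        .count()\
        .show()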

    The code you wrote does not work because, as @user8371915 pointed out:

    you cannot reference a distributed DataFrame inside a udf
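
    As a further hedged sketch (assuming every distances array holds exactly two elements), the Python udf can be avoided entirely by building the Euclidean distance from built-in column functions, which keeps the whole computation inside Spark:

    from pyspark.sql.functions import col, sqrt, pow, when

    # Euclidean distance as a native column expression (no Python udf);
    # assumes each `distances` array has exactly two elements
    dist = sqrt(
        pow(col("l.distances")[0] - col("r.distances")[0], 2)
        + pow(col("l.distances")[1] - col("r.distances")[1], 2)
    )

    DB.alias("l")\
        .crossJoin(DB.alias("r"))\
        .where(dist < eps)\
        .groupBy("l.letter", "l.distances")\
        .count()\
        .withColumn("count", col("count") - 1)\
        .withColumn("label", when(col("count") >= 1, "Has Neighbours").otherwise("No Neighbours"))\
        .sort("letter")\
        .show()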
