首页 文章

Pyspark:在窗口内使用udf

提问于
浏览
0

我需要使用Pyspark检测时间序列上的阈值 . 在下面的示例图中,我想检测(通过存储相关的时间戳)参数ALT_STD的每次出现都大于5000然后低于5000 .

ALT_STD vs Time

对于这个简单的情况,我可以运行简单的查询,如

t_start = df.select('timestamp')\
                .filter(df.ALT_STD > 5000)\
                .sort('timestamp')\
                .first()
t_stop = df.select('timestamp')\
               .filter((df.ALT_STD < 5000)\                           
                       & (df.timestamp > t_start.timestamp))\
               .sort('timestamp')\
               .first()

但是,在某些情况下,事件可以是循环的,并且我可能有几条曲线(即ALT_STD的几次将高于或低于5000) . 当然,如果我使用上面的查询,我将只能检测到第一次出现 .

我想我应该使用udf的窗口函数,但我找不到一个有效的解决方案 . 我的猜测是算法应该是这样的:

windowSpec = Window.partitionBy('flight_hash')\
                   .orderBy('timestamp')\
                   .rowsBetween(Window.currentRow, 1)

def detect_thresholds(x):
    if (x['ALT_STD'][current_row]< 5000) and (x['ALT_STD'][next_row] > 5000):
        return x['timestamp'] #Or maybe simply 1
    if (x['ALT_STD'][current_row]> 5000) and (x['ALT_STD'][current_row] > 5000):
    return x['timestamp'] #Or maybe simply 2
    else:
        return 0

import pyspark.sql.functions as F
detect_udf = F.udf(detect_threshold, IntegerType())
df.withColumn('Result', detect_udf(F.Struct('ALT_STD')).over(windowSpec).show()

这样的算法在Pyspark中是否可行?怎么样 ?

post-scriptum:作为旁注,我已经了解了如何使用udf或udf以及内置的sql窗口函数,但不知道如何组合udf和窗口 . 例如:

# This will compute the mean (built-in function)
df.withColumn("Result", F.mean(df['ALT_STD']).over(windowSpec)).show()

# This will also work
divide_udf = F.udf(lambda x: x[0]/1000., DoubleType())
df.withColumn('result', divide_udf(F.struct('timestamp')))

2 回答

  • 0

    感谢user9569772的答案,我发现了 . 他的解决方案不起作用,因为.lag()或.lead()是窗口函数 .

    from pyspark.sql.functions import when
    from pyspark.sql import functions as F
    
    # Define conditions
    det_start = (F.lag(F.col('ALT_STD')).over(windowSpec) < 100)\
              & (F.lead(F.col('ALT_STD'), 0).over(windowSpec) >= 100)
    det_end = (F.lag(F.col('ALT_STD'), 0).over(windowSpec) > 100)\
            & (F.lead(F.col('ALT_STD')).over(windowSpec) < 100)
    
    # Combine conditions with .when() and .otherwise()
    result = (when(det_start, 1)\
           .when(det_end, 2)\
           .otherwise(0))
    
    df.withColumn("phases", result).show()
    
  • 0

    这里不需要udf(并且python udfs不能用作窗口函数) . 只需使用 lead / lagwhen

    from pyspark.sql.functions import col, lag, lead, when
    
    result = (when((col('ALT_STD') < 5000) & (lead(col('ALT_STD'), 1) > 5000), 1)
        .when(col('ALT_STD') > 5000) & (lead(col('ALT_STD'), 1) < 5000), 1)
        .otherwise(0))
    
    df.withColum("result", result)
    

相关问题