首页 文章

Pyspark:将UDF的结果迭代地写回数据帧并不会产生预期的结果

提问于
浏览
0

我仍然是pyspark的新手,我正在尝试评估函数并在UDF的帮助下迭代创建列 . 以下是功能:

def get_temp(df):
    l=['temp1','temp2','temp3']
    s=[0]
    pt = [0]
    start = [0]
    end = [0]
    cummulative_stat = [0]
    for p in xrange(1,4):
        def func(p):
            if p==1:
                pass
            elif p >1:
                start[0] = end[0]
                s[0]=2
                pt[0] =4
            end[0] = start[0] + pt[0] - s[0]
            return end[0]
        func_udf=udf(func,IntegerType())
        df=df.withColumn(l[p-1],func_udf(lit(p)))
    return df
df=get_temp(df)
df.show()

以上结果得出结果:

+---+---+---+-----+-----+-----+
|  a|  b|  c|temp1|temp2|temp3|
+---+---+---+-----+-----+-----+
|  2| 12|  5|    0|    2|    2|
|  8|  5|  7|    0|    4|    4|
|  9|  4|  3|    0|    2|    2|
|  3|  8|  2|    0|    4|    4|
+---+---+---+-----+-----+-----+

预期的结果是:

+---+---+---+-----+-----+-----+
|  a|  b|  c|temp1|temp2|temp3|
+---+---+---+-----+-----+-----+
|  2| 12|  5|    0|    2|    4|
|  8|  5|  7|    0|    2|    4|
|  9|  4|  3|    0|    2|    4|
|  3|  8|  2|    0|    2|    4|
+---+---+---+-----+-----+-----+

如果我单独查看内部函数的输出,结果如预期那样,即:

s=[0]
pt = [0]
start = [0]
end = [0]
cummulative_stat = [0]
for p in xrange(1,4):
    def func():
        if p==1:
            pass
        elif p >1:
            start[0] = end[0]
            s[0]=2
            pt[0] =4
        end[0] = start[0] + pt[0] - s[0]
        return end[0]
    e=func()
    print e

output:
0
2
4

不确定将这些结果从UDF写回到df的正确方法是什么 . 发布的数据帧只是一个示例数据帧,我需要使用for循环,因为在我的原始代码中,我在for循环中调用其他函数(谁的输出取决于迭代器的值) . 例如,请参阅下面的内容:

def get_temp(df):
    l=['temp1','temp2','temp3']
    s=[0]
    pt = [0]
    start = [0]
    end = [0]
    q=[]
    cummulative_stat = [0]
    for p in xrange(1,4):
        def func(p):
            if p < a:
                cummulative_stat[0]=cummulative_stat[0]+52
                pass
            elif p >=a:

                if p==1:
                    pass
                elif p >1:
                    start[0] = end[0]
                    s[0]=2
                    pt[0] =4
                if cummulative_stat and p >1:
                    var1=func2(p,3000)
                    var2=func3(var1)
                    cummulative_stat=np.nan
                else:
                    var1=func2(p,3000)
                    var2=func3(var1)         
                end[0] = start[0] + pt[0] - s[0]
            q.append(end[0],var1,var2)
            return q
        func_udf=udf(func,ArrayType(ArrayType(IntegerType())))
        df=df.withColumn(l[p-1],func_udf(lit(p)))
    return df
df=get_temp(df)
df.show()

我正在使用pyspark 2.2 . 任何帮助深表感谢 . 要创建此数据框:

rdd =  sc.parallelize([(2,12,5),(8,5,7),
                 (9,4,3),
                  (3,8,2)])
df = sqlContext.createDataFrame(rdd, ('a', 'b','c'))
df.show()

1 回答

  • 1

    根据我的理解,查看您的代码是 your next column value depends on the previous one . 如果我的理解是正确的,那么我可以告诉你的udf函数定义放在错误的地方 . 您需要对代码进行细微更改才能使其正常工作 .

    让我们一步一步走

    你已经有了

    +---+---+---+
    |  a|  b|  c|
    +---+---+---+
    |  2| 12|  5|
    |  8|  5|  7|
    |  9|  4|  3|
    |  3|  8|  2|
    +---+---+---+
    

    我们需要一个初始化列,我看到它是0

    from pyspark.sql import functions as F
    from pyspark.sql import types as T
    
    df=df.withColumn('temp0', F.lit(0))
    

    应该是

    +---+---+---+-----+
    |  a|  b|  c|temp0|
    +---+---+---+-----+
    |  2| 12|  5|    0|
    |  8|  5|  7|    0|
    |  9|  4|  3|    0|
    |  3|  8|  2|    0|
    +---+---+---+-----+
    

    我们应该将 udf 函数移动到循环之外

    def func(p, end):
        start = 0
        s = 0
        pt = 0
        if p==1:
            pass
        elif p >1:
            start = end
            s=2
            pt =4
        end = start + pt - s
        return end
    
    func_udf=F.udf(func, T.IntegerType())
    

    并在循环中调用 udf 函数

    def get_temp(df):
        l=['temp1','temp2','temp3']
        for p in xrange(1,4):
            df=df.withColumn(l[p-1],func_udf(F.lit(p), F.col('temp'+str(p-1))))
        return df
    
    df=get_temp(df)
    

    最后删除初始化列

    df=df.drop('temp0')
    

    哪个应该给你你想要的输出

    +---+---+---+-----+-----+-----+
    |  a|  b|  c|temp1|temp2|temp3|
    +---+---+---+-----+-----+-----+
    |  2| 12|  5|    0|    2|    4|
    |  8|  5|  7|    0|    2|    4|
    |  9|  4|  3|    0|    2|    4|
    |  3|  8|  2|    0|    2|    4|
    +---+---+---+-----+-----+-----+
    

    我希望答案是有帮助的

相关问题