首页 文章

在pyspark列中查找列表中连续的长度

提问于
浏览
0

我试图在pyspark中解决一个问题,包括收集一个只包含1和0的列表 . 我想在列表中找到连续的数量(在使用collect_list之后) .

# Sample data

sc = SparkContext().getOrCreate()
sqlCtx = SQLContext(sc)

json = [
    {'a': 'tank', 'b': 1},
    {'a': 'tank', 'b': 1}, {'a': 'bank', 'b': 1},
    {'a': 'tank', 'b': 0}, {'a': 'bank', 'b': 0},
    {'a': 'tank', 'b': 1}, {'a': 'bank', 'b': 1},
    {'a': 'tank', 'b': 1}, {'a': 'bank', 'b': 1},
    {'a': 'tank', 'b': 1}, {'a': 'bank', 'b': 1},
    {'a': 'tank', 'b': 1}, {'a': 'bank', 'b': 1},
]

df = sqlCtx.read.json(sc.parallelize(json))
df.show()

# Data looks like 
+----+---+
|   a|  b|
+----+---+
|tank|  1|
|tank|  1|
|bank|  1|
|tank|  0|
|bank|  0|
|tank|  1|
|bank|  1|
|tank|  1|
|bank|  1|
|tank|  1|
|bank|  1|
|tank|  1|
|bank|  1|
+----+---+

df = df.groupBy('a').agg(F.collect_list('b').alias('b'))
# Output looks like
+----+---------------------+
|a   |b                    |
+----+---------------------+
|bank|[1, 0, 1, 1, 1, 1]   |
|tank|[1, 1, 0, 1, 1, 1, 1]|
+----+---------------------+

我想在 collect_list(b) 中计算连续的最大数量,如果可能的话,得到开始和结束的索引 . 我是正确的 .

1 回答

  • 1

    Spark版本2.1及以上版本

    如果你有Spark版本2.1或更高版本,这是一种方法:

    首先使用pyspark.sql.posexplode()将收集的列表与索引一起展开 .

    import pyspark.sql.functions as f
    df = df.select("a", f.posexplode("b").alias("pos", "b"))
    

    接下来使用Window函数创建一个列,该列将指示当前行的值是否与前一行不同 .

    from pyspark.sql import Window
    
    w = Window.partitionBy("a").orderBy("pos")
    df = df.select(
        "*", 
        (f.col("b") != f.lag(f.col("b"), default=0).over(w)).cast("int").alias("change")
    )
    df.show()
    #+----+---+---+------+
    #|   a|pos|  b|change|
    #+----+---+---+------+
    #|bank|  0|  1|     1|
    #|bank|  1|  0|     1|
    #|bank|  2|  1|     1|
    #|bank|  3|  1|     0|
    #|bank|  4|  1|     0|
    #|bank|  5|  1|     0|
    #|tank|  0|  1|     1|
    #|tank|  1|  1|     0|
    #|tank|  2|  0|     1|
    #|tank|  3|  1|     1|
    #|tank|  4|  1|     0|
    #|tank|  5|  1|     0|
    #|tank|  6|  1|     0|
    #+----+---+---+------+
    

    计算该列的累积总和,将 01 分成组 . 然后你可以 groupBy()(a, b, group) 并计算每个组的长度以及开始和结束索引 .

    df = df.select(
        "*",
        f.sum(f.col("change")).over(w.rangeBetween(Window.unboundedPreceding, 0)).alias("group")
    )\
    .groupBy("a", "b", "group")\
    .agg(f.min("pos").alias("start"), f.max("pos").alias("end"), f.count("*").alias("length"))\
    .where(f.col("b") == 1)\
    .drop("group")
    df.show()
    #+----+---+-----+---+------+
    #|   a|  b|start|end|length|
    #+----+---+-----+---+------+
    #|bank|  1|    0|  0|     1|
    #|bank|  1|    2|  5|     4|
    #|tank|  1|    0|  1|     2|
    #|tank|  1|    3|  6|     4|
    #+----+---+-----+---+------+
    

    最后,您可以过滤此DataFrame以查找与 a 列中每个项目的最长长度序列关联的行:

    df = df.withColumn(
        "isMax",
        f.col("length") == f.max(f.col("length")).over(Window.partitionBy("a"))
    )\
    .where(f.col("isMax"))\
    .drop("isMax")
    df.show()
    #+----+---+-----+---+------+
    #|   a|  b|start|end|length|
    #+----+---+-----+---+------+
    #|bank|  1|    2|  5|     4|
    #|tank|  1|    3|  6|     4|
    #+----+---+-----+---+------+
    

    Spark版本1.5及以上版本

    如果你没有 posexplode ,另一种选择是将你的整数数组转换成一个字符串数组,连接它,并在 "0" 上拆分 . 然后爆炸结果数组,并筛选具有最大长度的数组 .

    不幸的是,这种方法并没有给你起点和终点 .

    df.withColumn('b', f.split(f.concat_ws('', f.col('b').cast('array<string>')), '0'))\
        .select('a', f.explode('b').alias('b'))\
        .select('a', f.length('b').alias('length'))\
        .withColumn(
            "isMax",
            f.col('length') == f.max(f.col('length')).over(Window.partitionBy('a'))
        )\
        .where(f.col("isMax"))\
        .drop("isMax")\
        .show()
    #+----+------+
    #|   a|length|
    #+----+------+
    #|bank|     4|
    #|tank|     4|
    #+----+------+
    

相关问题