
pyspark: groupBy and then get the max of each group


I want to group by value and then find the maximum value within each group using PySpark. I have the following code, but now I'm a bit stuck on how to extract the maximum.

# some file contains tuples ('user', 'item', 'occurrences')
data_file = sc.textFile('file:///some_file.txt')
# Create the triplet so I can index stuff
data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))
# Group by the user i.e. r[0]
grouped = data_file.groupBy(lambda r: r[0])
# Here is where I am stuck 
group_list = grouped.map(lambda x: (list(x[1]))) #?

This returns something like:

[[(u'u1', u's1', 20), (u'u1', u's2', 5)], [(u'u2', u's3', 5), (u'u2', u's2', 10)]]

I would now like to find the maximum 'occurrences' for each user. The final result after taking the max would be an RDD that looks like this:

[[(u'u1', u's1', 20)], [(u'u2', u's2', 10)]]

where only the maximum triplet is kept for each user in the file. In other words, I want to change the values of the RDD so that each user is left with only the single triplet that has the most occurrences.
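
In plain Python terms, the per-group reduction I'm after would look roughly like this (a sketch over the nested lists shown above, assuming each inner list holds all triplets for one user):

# Sketch of the reduction I want, applied to the nested lists above.
groups = [[(u'u1', u's1', 20), (u'u1', u's2', 5)],
          [(u'u2', u's3', 5), (u'u2', u's2', 10)]]

# Keep only the triplet with the largest occurrence count in each group.
per_user_max = [max(triplets, key=lambda t: t[2]) for triplets in groups]
# [('u1', 's1', 20), ('u2', 's2', 10)]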

2 Answers

  • 11

    There is no need for groupBy here. A simple reduceByKey will do just fine, and most of the time it will be more efficient:

    data_file = sc.parallelize([
       (u'u1', u's1', 20), (u'u1', u's2', 5),
       (u'u2', u's3', 5), (u'u2', u's2', 10)])
    
    max_by_group = (data_file
      .map(lambda x: (x[0], x))  # Convert to a pairwise RDD keyed by user
      # Take the maximum of the two arguments by their last element (the occurrence count),
      # equivalent to:
      # lambda x, y: x if x[-1] > y[-1] else y
      .reduceByKey(lambda x1, x2: max(x1, x2, key=lambda x: x[-1])) 
      .values()) # Drop keys
    
    max_by_group.collect()
    ## [('u2', 's2', 10), ('u1', 's1', 20)]
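
    Plugged back into the file-based pipeline from the question, the same idea would look roughly like this (a sketch, assuming the same SparkContext `sc` and the whitespace-separated 'user item occurrences' format described in the question):

    # Sketch: the reduceByKey approach applied to the question's input file.
    raw = sc.textFile('file:///some_file.txt')
    triplets = (raw
        .map(lambda l: l.split())
        .map(lambda l: (l[0], l[1], float(l[2]))))  # (user, item, occurrences)

    max_by_user = (triplets
        .map(lambda t: (t[0], t))                   # key each triplet by user
        .reduceByKey(lambda a, b: max(a, b, key=lambda t: t[2]))
        .values())                                  # drop the user keys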
    
  • 2

    I think I found a solution:

    from pyspark import SparkContext, SparkConf
    
    def reduce_by_max(rdd):
        """
        Helper function to find the max value in a list of values i.e. triplets. 
        """
        max_val = rdd[0][2]
        the_index = 0
    
        for idx, val in enumerate(rdd):
            if val[2] > max_val:
                max_val = val[2]
                the_index = idx
    
        return rdd[the_index]
    
    conf = SparkConf() \
        .setAppName("Collaborative Filter") \
        .set("spark.executor.memory", "5g")
    sc = SparkContext(conf=conf)
    
    # some file contains tuples ('user', 'item', 'occurrences')
    data_file = sc.textFile('file:///some_file.txt')
    
    # Create the triplet so I can index stuff
    data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))
    
    # Group by the user i.e. r[0]
    grouped = data_file.groupBy(lambda r: r[0])
    
    # Get the values as a list
    group_list = grouped.map(lambda x: (list(x[1]))) 
    
    # Get the max value for each user. 
    max_list = group_list.map(reduce_by_max).collect()
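
    The same result can also be reached a bit more compactly by keeping the (user, triplets) pairs that groupBy produces and taking the per-group max with mapValues (a sketch, equivalent to the reduce_by_max helper above):

    # Sketch: equivalent per-group max using mapValues on the grouped RDD.
    max_list = (grouped
        .mapValues(lambda triplets: max(triplets, key=lambda t: t[2]))  # max per user
        .values()                                                       # drop the user keys
        .collect())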
    
