
Scala - Spark: retrieve, for each row of a DataFrame, the name of the column holding the maximum value


I have a DataFrame:

name     column1  column2  column3  column4
first    2        1        2.1      5.4
test     1.5      0.5      0.9      3.7
choose   7        2.9      9.1      2.5

I want a new DataFrame with a column containing, for each row, the name of the column that holds that row's maximum value:

| name   | max_column |
|--------|------------|
| first  | column4    |
| test   | column4    |
| choose | column3    |

Thank you very much for your support.

3 Answers

  • 3

    There may be better ways of writing the UDF, but this should be a working solution:

    val spark: SparkSession = SparkSession.builder.master("local").getOrCreate
    
    //implicits for magic functions like .toDF
    import spark.implicits._
    
    import org.apache.spark.sql.functions.udf
    
    //We hard-code the number of parameters because a UDF doesn't support a variable number of arguments
    val maxval = udf((c1: Double, c2: Double, c3: Double, c4: Double) =>
      if(c1 >= c2 && c1 >= c3 && c1 >= c4)
        "column1"
      else if(c2 >= c1 && c2 >= c3 && c2 >= c4)
        "column2"
      else if(c3 >= c1 && c3 >= c2 && c3 >= c4)
        "column3"
      else
        "column4"
    )
    
    //create schema class
    case class Record(name: String, 
                        column1: Double, 
                        column2: Double, 
                        column3: Double, 
                        column4: Double)
    
    val df = Seq(
      Record("first", 2.0, 1.0, 2.1, 5.4),
      Record("test", 1.5, 0.5, 0.9, 3.7),
      Record("choose", 7.0, 2.9, 9.1, 2.5)
    ).toDF()
    
    df.withColumn("max_column", maxval($"column1", $"column2", $"column3", $"column4"))
      .select("name", "max_column").show
    

    Output

    +------+----------+
    |  name|max_column|
    +------+----------+
    | first|   column4|
    |  test|   column4|
    |choose|   column3|
    +------+----------+
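
    As a possible variant without a UDF (a sketch, not part of the original answer, assuming Spark's built-in `greatest` and `when` column functions), the same column name can be derived by chaining `when` clauses:

    import org.apache.spark.sql.functions.{col, greatest, lit, when}
    
    val cols = Seq("column1", "column2", "column3", "column4")
    val rowMax = greatest(cols.map(col): _*)
    
    //Chain when() clauses: the first column equal to the row maximum wins,
    //which matches the tie-breaking order of the UDF above
    val maxColName = cols.tail.foldLeft(when(col(cols.head) === rowMax, lit(cols.head))) {
      (acc, c) => acc.when(col(c) === rowMax, lit(c))
    }
    
    df.withColumn("max_column", maxColName).select("name", "max_column").show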
    
  • 4

    A small detour through the RDD and `getValuesMap` gets the job done.

    val dfIn = Seq(
      ("first", 2.0, 1.0, 2.1, 5.4),
      ("test", 1.5, 0.5, 0.9, 3.7),
      ("choose", 7.0, 2.9, 9.1, 2.5)
    ).toDF("name","column1","column2","column3","column4")
    

    The simple solution is:

    val dfOut = dfIn.rdd
      .map(r => (
           r.getString(0),
           r.getValuesMap[Double](r.schema.fieldNames.filter(_!="name"))
         ))
      .map{case (n,m) => (n,m.maxBy(_._2)._1)}
      .toDF("name","max_column")
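
    On the sample dfIn, showing the result reproduces the table the question asks for (the same output as the UDF answer above):

    dfOut.show()
    //+------+----------+
    //|  name|max_column|
    //+------+----------+
    //| first|   column4|
    //|  test|   column4|
    //|choose|   column3|
    //+------+----------+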
    

    But if you want all the columns of the original DataFrame back (as in Scala/Spark dataframes: find the column name corresponding to the max), you have to play a little with merging the rows and extending the schema:

    import org.apache.spark.sql.types.{StructType,StructField,StringType}
    import org.apache.spark.sql.Row
    val dfOut = sqlContext.createDataFrame(
      dfIn.rdd
        .map(r => (r, r.getValuesMap[Double](r.schema.fieldNames.drop(1))))
        .map{case (r,m) => Row.merge(r,(Row(m.maxBy(_._2)._1)))},
      dfIn.schema.add(StructField("max_column",StringType))
    )
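
    dfOut now carries every column of dfIn plus the new string column; a quick way to check (not part of the original answer):

    dfOut.printSchema()
    dfOut.select("name", "max_column").show()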
    
  • 0

    I want to post my final solution:

    val maxValAsMap = udf((keys: Seq[String], values: Seq[Any]) => {

        val valueMap: Map[String, Double] = (keys zip values).filter( _._2.isInstanceOf[Double] ).map{
          case (x,y) => (x, y.asInstanceOf[Double])
        }.toMap

        if (valueMap.isEmpty) "not computed" else valueMap.maxBy(_._2)._1
      })

    val finalDf = originalDf.withColumn("max_column", maxValAsMap(keys, values)).select("name", "max_column")
    

    It works very fast.
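
    The snippet leaves `keys` and `values` undefined; one possible way to build them (an assumption, not shown in the original answer, using Spark's `array`, `lit` and `col` functions) is:

    import org.apache.spark.sql.functions.{array, col, lit}
    
    //Hypothetical construction: an array of literal column names and an
    //array of the corresponding value columns, skipping the "name" column
    val valueCols = originalDf.columns.filterNot(_ == "name")
    val keys   = array(valueCols.map(lit(_)): _*)
    val values = array(valueCols.map(col(_)): _*)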
