首页 文章

使用Spark数据集在Scala中执行类型化连接

提问于
浏览
24

我喜欢Spark数据集,因为它们在编译时给我分析错误和语法错误,并且允许我使用getter而不是硬编码的名称/数字 . 大多数计算都可以使用Dataset的高级API完成 . 例如,通过访问数据集类型对象而不是使用RDD行的数据字段来执行 agg, select, sum, avg, map, filter, or groupBy 操作要简单得多 .

但是,由于缺少连接操作,我读到我可以像这样进行连接

ds1.joinWith(ds2, ds1.toDF().col("key") === ds2.toDF().col("key"), "inner")

但这不是我想要的,因为我更喜欢通过case类接口来做,所以更像这样的东西

ds1.joinWith(ds2, ds1.key === ds2.key, "inner")

现在最好的选择似乎是在case类旁边创建一个对象,并给这个函数提供正确的列名作为String . 所以我会使用第一行代码但是放置一个函数而不是硬编码的列名 . 但那感觉不够优雅..

有人可以告诉我其他选项吗?目标是从实际的列名中抽象出来,最好通过case类的getter工作 .

我正在使用Spark 1.6.1和Scala 2.10

2 回答

  • 24

    观察

    仅当连接条件基于相等运算符时,Spark SQL才能优化连接 . 这意味着我们可以分别考虑等量连接和非等量连接 .

    Equijoin

    通过将 Datasets 映射到(键,值)元组,基于键执行连接以及重新整形结果,可以以类型安全的方式实现Equijoin:

    import org.apache.spark.sql.Encoder
    import org.apache.spark.sql.Dataset
    
    def safeEquiJoin[T, U, K](ds1: Dataset[T], ds2: Dataset[U])
        (f: T => K, g: U => K)
        (implicit e1: Encoder[(K, T)], e2: Encoder[(K, U)], e3: Encoder[(T, U)]) = {
      val ds1_ = ds1.map(x => (f(x), x))
      val ds2_ = ds2.map(x => (g(x), x))
      ds1_.joinWith(ds2_, ds1_("_1") === ds2_("_1")).map(x => (x._1._2, x._2._2))
    }
    

    非等值

    可以使用关系代数运算符表示为R⋈θS=σθ(R×S)并直接转换为代码 .

    Spark 2.0

    启用 crossJoin 并使用 joinWith 与简单相等的谓词:

    spark.conf.set("spark.sql.crossJoin.enabled", true)
    
    def safeNonEquiJoin[T, U](ds1: Dataset[T], ds2: Dataset[U])
                             (p: (T, U) => Boolean) = {
      ds1.joinWith(ds2, lit(true)).filter(p.tupled)
    }
    

    Spark 2.1

    使用 crossJoin 方法:

    def safeNonEquiJoin[T, U](ds1: Dataset[T], ds2: Dataset[U])
        (p: (T, U) => Boolean)
        (implicit e1: Encoder[Tuple1[T]], e2: Encoder[Tuple1[U]], e3: Encoder[(T, U)]) = {
      ds1.map(Tuple1(_)).crossJoin(ds2.map(Tuple1(_))).as[(T, U)].filter(p.tupled)
    }
    

    例子

    case class LabeledPoint(label: String, x: Double, y: Double)
    case class Category(id: Long, name: String)
    
    val points1 = Seq(LabeledPoint("foo", 1.0, 2.0)).toDS
    val points2 = Seq(
      LabeledPoint("bar", 3.0, 5.6), LabeledPoint("foo", -1.0, 3.0)
    ).toDS
    val categories = Seq(Category(1, "foo"), Category(2, "bar")).toDS
    
    safeEquiJoin(points1, categories)(_.label, _.name)
    safeNonEquiJoin(points1, points2)(_.x > _.x)
    

    注意事项

    • 应该注意的是,这些方法与直接 joinWith 应用程序的质量不同,并且需要进行昂贵的转换(与直接 joinWith 可以对数据使用逻辑运算相比) .

    这类似于Spark 2.0 Dataset vs DataFrame中描述的行为 .

    • 如果您不限于Spark SQL API framelessDatasets 提供了有趣的类型安全扩展(截至今天它仅支持Spark 2.0):
    import frameless.TypedDataset
    
    val typedPoints1 = TypedDataset.create(points1)
    val typedPoints2 = TypedDataset.create(points2)
    
    typedPoints1.join(typedPoints2, typedPoints1('x), typedPoints2('x))
    
    • Dataset API在1.6中不稳定所以我觉得在那里使用它没有意义 .

    • 当然,这种设计和描述性名称不是必需的 . 您可以轻松地使用类型类隐式地将此方法添加到 Dataset 并且与内置签名没有冲突,因此两者都可以被称为 joinWith .

  • -1

    另外,对于非类型安全的Spark API,另一个更大的问题是,当你加入两个 Datasets 时,它会给你一个 DataFrame . 然后你丢失原始两个数据集中的类型 .

    val a: Dataset[A]
    val b: Dataset[B]
    
    val joined: Dataframe = a.join(b)
    // what would be great is 
    val joined: Dataset[C] = a.join(b)(implicit func: (A, B) => C)
    

相关问题