首页 文章

如何实现ScalaTest FunSuite以避免样板Spark代码和导入含义

提问于
浏览
6

我尝试重构ScalaTest FunSuite测试,以避免样板代码初始化并销毁Spark会话 .

问题是我需要导入隐式函数,但使用前/后方法只能使用变量(var字段),并且导入它是必要的值(val字段) .

The idea is to have a new clean Spark Session every test execution.

我尝试做这样的事情:

import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.{BeforeAndAfter, FunSuite}

object SimpleWithBeforeTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _
  var sc: SparkContext = _
  implicit var sqlContext: SQLContext = _

  before {
    spark = SparkSession.builder
      .master("local")
      .appName("Spark session for testing")
      .getOrCreate()
    sc = spark.sparkContext
    sqlContext = spark.sqlContext
  }

  after {
    spark.sparkContext.stop()
  }

  test("Import implicits inside the test 1") {
    import sqlContext.implicits._

    // Here other stuff
  }

  test("Import implicits inside the test 2") {
    import sqlContext.implicits._

    // Here other stuff
  }

但在 import sqlContext.implicits._ 行我有一个错误

无法解析符号sqlContext

如何解决此问题或如何实现测试类?

2 回答

  • 1

    为spark上下文定义一个新的不可变变量,并在导入implicits之前将var赋值给它 .

    class MyCassTest extends FlatSpec with BeforeAndAfter {
    
      var spark: SparkSession = _
    
      before {
        val sparkConf: SparkConf = new SparkConf()    
        spark = SparkSession.
          builder().
          config(sparkConf).
          master("local[*]").
          getOrCreate()
      }
    
      after {
        spark.stop()
      }
    
      "myFunction()" should "return 1.0 blab bla bla" in {
        val sc = spark
        import sc.implicits._
    
        // assert ...
      }
    }
    
  • 1

    您也可以使用spark-testing-base,它几乎可以处理所有样板代码 .

    这是创作者的a blog post,解释了如何使用它 .

    这是一个来自wiki的简单示例:

    class test extends FunSuite with DatasetSuiteBase {
      test("simple test") {
        val sqlCtx = sqlContext
        import sqlCtx.implicits._
    
        val input1 = sc.parallelize(List(1, 2, 3)).toDS
        assertDatasetEquals(input1, input1) // equal
    
        val input2 = sc.parallelize(List(4, 5, 6)).toDS
        intercept[org.scalatest.exceptions.TestFailedException] {
            assertDatasetEquals(input1, input2) // not equal
        }
      }
    }
    

相关问题