首页 文章

CrossValidator调整火花ML在params上失败“在模型保存中发现了一个无关的Param”

提问于
浏览
0

作为paramGrid的一部分,我在logistic回归中使用regParam运行spark ml交叉验证 .

val paramGrid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.01))
    .build()

 val validator = new CrossValidator()
  .setEstimator(estimator)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

估算器在这里有regParam作为params的一部分 . 保存模型的示例代码:

class MyModelWriter(instance: MyModel[T])extends MLWriter {

  override protected def saveImpl(path: String): Unit = {
    new DefaultParamsWriter(instance).save(path)
    instance.model.save(new Path(path, s"nameOfMofel").toString)
  }
}

Mymodel确实在params中包含了regParam .

MyModel extends HasRegParam

当我调用model.save(path)时,这是我得到的异常:

java.lang.IllegalArgumentException:要求失败:ValidatorParams save要求estimatorParamMaps中的所有Params应用于此ValidatorParams,其Estimator或其Evaluator . 找到了一个无关的Param:logla_2fb5fdbe5012__regParam at scala.Predef $ .require(Predef.scala:224)at org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1 $$ anonfun $ apply $ 1.apply(ValidatorParams . scala:110)at org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1 $$ anonfun $ apply $ 1.apply(ValidatorParams.scala:109)at scala.collection.mutable.ResizableArray $ class.foreach( ResizableArray.scala:59)在org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1.apply(ValidatorParams.scala:109)的scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)在scala.collection.mutable的scala.collection.IndexedSeqOptimized $ class.foreach(IndexedSeqOptimized.scala:33)的org.apache.spark.ml.tuning.ValidatorParams $$ anonfun $ validateParams $ 1.apply(ValidatorParams.scala:108) .ArrayOps $ ofRef.foreach(ArrayOps.scala:186)atg.apache.spark.ml.tuning.ValidatorParams $ .validateParams(ValidatorParams.scala:108)at org.apache.spark.ml.tuning.CrossValidatorMo del $ CrossValidatorModelWriter . (CrossValidator.scala:257)org.apache.spark.ml.tuning.CrossValidatorModel.write(CrossValidator.scala:242)org.apache.spark.ml.util.MLWritable $ class.save(ReadWrite) .scala:157)atg.apache.spark.ml.tuning.CrossValidatorModel.save(CrossValidator.scala:210)at com.criteo.lookalike.sink.Sinks $$ anonfun $ SavePipelineParam1 $ 1.apply(Sinks.scala:111

L105上ValidatorParams.scala的代码说

//检查以确保所有Params都适用于此估算器 . 如果没有,则抛出错误 .

根据这一点,它确保estimatorMap中的参数,即在这种情况下的regParam存在于估计器/评估器中,在这种情况下,确实存在于上面的Mymodel中 .

任何人都可以告诉我,如果我的理解是正确的,如果是的话,可能是什么导致了这个?谢谢 .

1 回答

  • 0

    我刚刚解决了这个确切的错误 .

    添加网格时,请尝试传递Param实例;并且,在实例化Param时,将其与记录的参数类型相匹配,如您在https://spark.apache.org/docs/latest/api/scala/下找到的那样...

    例如,在 RandomForestRegressor 中有 numTrees: IntParam .

    因此我按如下方式构建了param网格......

    val rf = new RandomForestRegressor()
    .{set...()} // (pseudocode)
    
    val numTrees = new IntParam(rf, "numTrees", "Number of trees to train (>= 1) (default = 20)")
    
    // for fun/preference, i make numTrees[Int] increase as does the area of a circle
    val numTreesValues = (for (n <- 3 to 20 by 3) yield (math.Pi * math.pow(n, 2)).toInt)
    
    val paramGrid = new ParamGridBuilder()
    .addGrid(numTrees, numTreesValues)
    .build()
    

    尝试将估算器传递给参数,然后将参数和值传递给 .addGrid

    我的验证器看起来像这样......

    val cv = new CrossValidator()
    .setEstimator(rf)
    .setEstimatorParamMaps(paramGrid)
    .{set...()}
    

相关问题