首页 文章

Spark ML转换预测标签到没有训练DataFrame的字符串

提问于
浏览
0

我在Apache Spark ML(版本2.1.0)中使用NaiveBayes多项分类器来预测一些文本类别 .

问题是如何将预测标签(0.0,1.0,2.0)转换为没有经过训练的DataFrame的字符串 .

我知道可以使用IndexToString,但只有在训练和预测同时进行时它才有用 . 但是,就我而言,它的独立工作 .

代码看起来像
1)TrainingModel.scala:训练模型并将模型保存在文件中 .
2)CategoryPrediction.scala:从文件加载训练的模型并对测试数据进行预测 .

请建议解决方案:

TrainingModel.scala

val trainData: Dataset[LabeledRecord] = spark.read.option("inferSchema", "false")
  .schema(schema).csv("trainingdata1.csv").as[LabeledRecord]

val labelIndexer = new StringIndexer().setInputCol("category").setOutputCol("label").fit(trainData).setHandleInvalid("skip")

val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("words")

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1000)

val rf = new NaiveBayes().setLabelCol("label").setFeaturesCol("features").setModelType("multinomial")

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf))

val model = pipeline.fit(trainData)

model.write.overwrite().save("naivebayesmodel");

CategoryPrediction.scala

val testData: Dataset[PredictLabeledRecord] = spark.read.option("inferSchema", "false")
        .schema(predictSchema).csv("testingdata.csv").as[PredictLabeledRecord]

val model = PipelineModel.load("naivebayesmodel")

val predictions = model.transform(testData)

//      val labelConverter = new IndexToString()
//      .setInputCol("prediction")
//      .setOutputCol("predictedLabelString")
//      .setLabels(trainDataFrameIndexer.labels)    

predictions.select("prediction", "text").show(false)

trainingdata1.csv

category,text
Drama,"a b c d e spark"
Action,"b d"
Horror,"spark f g h"
Thriller,"hadoop mapreduce"

testingdata.csv

text
"a b c d e spark"
"spark f g h"

1 回答

  • 2

    添加一个转换器,将预测类别转换回管道中的标签,如下所示:

    val categoryConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("category")
      .setLabels(labelIndexer.labels)
    
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf, categoryConverter))
    

    这将采用预测并使用labelIndexer将其转换回标签 .

相关问题