
DecisionTreeClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer

#Load the CSV file into an RDD
    irisData = sc.textFile("/home/infademo/surya/iris.csv")
    irisData.cache()
    irisData.count()

    #Remove the first line (contains headers)
    dataLines = irisData.filter(lambda x: "Sepal" not in x)
    dataLines.count()

    from pyspark.sql import Row
    #Create a Data Frame from the data
    parts = dataLines.map(lambda l: l.split(","))
    irisMap = parts.map(lambda p: Row(SEPAL_LENGTH=float(p[0]),\
                                    SEPAL_WIDTH=float(p[1]), \
                                    PETAL_LENGTH=float(p[2]), \
                                    PETAL_WIDTH=float(p[3]), \
                                    SPECIES=p[4] ))

    # Infer the schema, and register the DataFrame as a table.
    irisDf = sqlContext.createDataFrame(irisMap)
    irisDf.cache()

    #Add a numeric indexer for the label/target column
    from pyspark.ml.feature import StringIndexer
    stringIndexer = StringIndexer(inputCol="SPECIES", outputCol="IND_SPECIES")
    si_model = stringIndexer.fit(irisDf)
    irisNormDf = si_model.transform(irisDf)

    irisNormDf.select("SPECIES","IND_SPECIES").distinct().collect()
    irisNormDf.cache()

    """--------------------------------------------------------------------------
    Perform Data Analytics
    -------------------------------------------------------------------------"""

    #See standard parameters
    irisNormDf.describe().show()

    #Find correlation between predictors and target
    for i in irisNormDf.columns:
        if not( isinstance(irisNormDf.select(i).take(1)[0][0], basestring)) :
            print( "Correlation to Species for ", i, \
                        irisNormDf.stat.corr('IND_SPECIES',i))



    #Transform to a Data Frame for input to Machine Learning
    #Drop columns that are not required (low correlation)

    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.linalg import SparseVector
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.util import MLUtils
    from pyspark.mllib.linalg.distributed import RowMatrix

    def transformToLabeledPoint(row) :
        lp = ( row["SPECIES"], row["IND_SPECIES"], \
                    Vectors.dense([row["SEPAL_LENGTH"],\
                            row["SEPAL_WIDTH"], \
                            row["PETAL_LENGTH"], \
                            row["PETAL_WIDTH"]]))
        return lp




    irisLp = irisNormDf.rdd.map(transformToLabeledPoint)
    irisLpDf = sqlContext.createDataFrame(irisLp,["species","label", "features"])
    irisLpDf.select("species","label","features").show(10)
    irisLpDf.cache()

    """--------------------------------------------------------------------------
    Perform Machine Learning
    -------------------------------------------------------------------------"""
    #Split into training and testing data
    (trainingData, testData) = irisLpDf.randomSplit([0.9, 0.1])
    trainingData.count()
    testData.count()
    testData.collect()

    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    #Create the model
    dtClassifer = DecisionTreeClassifier(maxDepth=2, labelCol="label",\
                    featuresCol="features")

    dtModel = dtClassifer.fit(trainingData)

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/opt/mapr/spark/spark-1.6.1-bin-hadoop2.6/python/pyspark/ml/pipeline.py", line 69, in fit
        return self._fit(dataset)
      File "/opt/mapr/spark/spark-1.6.1-bin-hadoop2.6/python/pyspark/ml/wrapper.py", line 133, in _fit
        java_model = self._fit_java(dataset)
      File "/opt/mapr/spark/spark-1.6.1-bin-hadoop2.6/python/pyspark/ml/wrapper.py", line 130, in _fit_java
        return self._java_obj.fit(dataset._jdf)
      File "/opt/mapr/spark/spark-1.6.1-bin-hadoop2.6/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py", line 813, in __call__
      File "/opt/mapr/spark/spark-1.6.1-bin-hadoop2.6/python/pyspark/sql/utils.py", line 53, in deco
        raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
    pyspark.sql.utils.IllegalArgumentException: u'DecisionTreeClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.'

1 Answer


    According to the Spark 1.6.1 documentation:

    We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the DataFrame which the Decision Tree algorithm can recognize.

    According to the Spark 1.6.1 source code:

    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
      case Some(n: Int) => n
      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
        s" with invalid label column ${$(labelCol)}, without the number of classes" +
        " specified. See StringIndexer.")
        // TODO: Automatically index labels: SPARK-7126
    }
    

    Therefore, you need to apply a StringIndexer to the label column and a VectorIndexer to the features column before passing the data to DecisionTreeClassifier.fit.
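
    A minimal sketch of that fix. The three-row DataFrame and the column names "label" and "features" here are illustrative, not taken from the question. The key point is that the StringIndexer output column must flow directly into the classifier: when the question rebuilds the DataFrame from an RDD via createDataFrame, the ML metadata that StringIndexer attached (including the number of classes) is stripped, which is what triggers the exception.

    ```python
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, VectorAssembler
    from pyspark.ml.classification import DecisionTreeClassifier

    sc = SparkContext.getOrCreate()
    sqlContext = SQLContext(sc)

    # Tiny stand-in for the iris DataFrame built in the question
    df = sqlContext.createDataFrame([
        Row(SEPAL_LENGTH=5.1, SEPAL_WIDTH=3.5, PETAL_LENGTH=1.4,
            PETAL_WIDTH=0.2, SPECIES="setosa"),
        Row(SEPAL_LENGTH=7.0, SEPAL_WIDTH=3.2, PETAL_LENGTH=4.7,
            PETAL_WIDTH=1.4, SPECIES="versicolor"),
        Row(SEPAL_LENGTH=6.3, SEPAL_WIDTH=3.3, PETAL_LENGTH=6.0,
            PETAL_WIDTH=2.5, SPECIES="virginica"),
    ])

    # StringIndexer attaches class-count metadata to its output column
    indexer = StringIndexer(inputCol="SPECIES", outputCol="label")
    # Assemble the numeric columns into a single features vector
    assembler = VectorAssembler(
        inputCols=["SEPAL_LENGTH", "SEPAL_WIDTH", "PETAL_LENGTH", "PETAL_WIDTH"],
        outputCol="features")
    dt = DecisionTreeClassifier(maxDepth=2, labelCol="label",
                                featuresCol="features")

    # Running all stages in one Pipeline preserves the label metadata,
    # so fit() no longer raises IllegalArgumentException
    model = Pipeline(stages=[indexer, assembler, dt]).fit(df)
    ```

    The iris features are all continuous, so the VectorIndexer mentioned above is optional here; it matters when some feature columns are categorical, in which case it would sit between the assembler and the classifier in the pipeline.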
