首页 文章

scikit-learn在哪里保存树结构中每个叶节点的决策标签?

提问于
浏览
6

我已经使用scikit-learn训练了一个随机的森林模型,现在我想将它的树结构保存在文本文件中,以便我可以在其他地方使用它 . 根据this link,树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如,左子,右子,它检查的特征,......) . 但是,似乎没有关于每个叶节点对应的类标签的信息!在上面的链接中提供的示例中甚至没有提到它 .

有谁知道scikit-learn决策树结构中存储的类标签在哪里?

1 回答

  • 3

    看一下sklearn.tree.DecisionTreeClassifier.tree_.value的文档:

    from sklearn.datasets import load_iris
    from sklearn.cross_validation import cross_val_score
    from sklearn.tree import DecisionTreeClassifier
    
    clf = DecisionTreeClassifier(random_state=0)
    iris = load_iris()
    
    clf.fit(iris.data, iris.target)
    
    print(clf.classes_)
    
    [0, 1, 2]
    
    print(clf.tree_.value)
    
    [[[ 50.  50.  50.]]
    
     [[ 50.   0.   0.]]
    
     [[  0.  50.  50.]]
    
     [[  0.  49.   5.]]
    
     [[  0.  47.   1.]]
    
     [[  0.  47.   0.]]
    
     [[  0.   0.   1.]]
    
     [[  0.   2.   4.]]
    
     [[  0.   0.   3.]]
    
     [[  0.   2.   1.]]
    
     [[  0.   2.   0.]]
    
     [[  0.   0.   1.]]
    
     [[  0.   1.  45.]]
    
     [[  0.   1.   2.]]
    
     [[  0.   0.   2.]]
    
     [[  0.   1.   0.]]
    
     [[  0.   0.  43.]]]
    

    clf.tree_.value "contains the constant prediction value of each node,"( help(clf.tree_) )中的每一行对应索引到索引 clf.classes_ .

    请参阅this answer了解(几乎没有)更多细节 .

相关问题