首页 文章

scikit学习决策树导出graphviz - 决策树中的错误类名

提问于
浏览
1

我在决策树中从“scikit learn / decision tree / export graphviz”得到了错误的类名 . 该计划如下:

import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)

with open("digital.dot", 'w') as f:
    f = tree.export_graphviz(digital_tree, 
                            feature_names=digital_name,
                            class_names=digital_label,
                            filled=True, rounded=True,
                            out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")

plt.imshow(img.imread('digital.png'))
plt.show()

输出如下:

the decision tree

问题是关于叶子中显示的类名 . 例如,如果idx-1为1且idx-2为1,则绿框应标记为“3” . 但是,图像将标签显示为“1” . 有谁能发表你的意见?

2 回答

  • 0

    当您使用DecisionTreeClassifier时,您应该将类标签更改为0,1,2之类的数字

    然后使用:

    classe_names = decision_tree_classifier.classes_
    

    它将按升序为您提供类的标签 . 然后以相同的顺序指定class_label . 它可以是字符串 .

  • 2

    尝试按类别排序类标签,然后再将它们传递给 export_graphviz

相关问题