我已经使用scikit-learn训练了一个随机的森林模型,现在我想将它的树结构保存在文本文件中,以便我可以在其他地方使用它 . 根据this link,树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如,左子,右子,它检查的特征,......) . 但是,似乎没有关于每个叶节点对应的类标签的信息!在上面的链接中提供的示例中甚至没有提到它 .
有谁知道scikit-learn决策树结构中存储的类标签在哪里?
看一下sklearn.tree.DecisionTreeClassifier.tree_.value的文档:
from sklearn.datasets import load_iris from sklearn.cross_validation import cross_val_score from sklearn.tree import DecisionTreeClassifier clf = DecisionTreeClassifier(random_state=0) iris = load_iris() clf.fit(iris.data, iris.target) print(clf.classes_) [0, 1, 2] print(clf.tree_.value) [[[ 50. 50. 50.]] [[ 50. 0. 0.]] [[ 0. 50. 50.]] [[ 0. 49. 5.]] [[ 0. 47. 1.]] [[ 0. 47. 0.]] [[ 0. 0. 1.]] [[ 0. 2. 4.]] [[ 0. 0. 3.]] [[ 0. 2. 1.]] [[ 0. 2. 0.]] [[ 0. 0. 1.]] [[ 0. 1. 45.]] [[ 0. 1. 2.]] [[ 0. 0. 2.]] [[ 0. 1. 0.]] [[ 0. 0. 43.]]]
clf.tree_.value "contains the constant prediction value of each node,"( help(clf.tree_) )中的每一行对应索引到索引 clf.classes_ .
clf.tree_.value
help(clf.tree_)
clf.classes_
请参阅this answer了解(几乎没有)更多细节 .
1 回答
看一下sklearn.tree.DecisionTreeClassifier.tree_.value的文档:
clf.tree_.value
"contains the constant prediction value of each node,"(help(clf.tree_)
)中的每一行对应索引到索引clf.classes_
.请参阅this answer了解(几乎没有)更多细节 .