首页 文章

如何在决策树中获取所有基尼指数?

提问于
浏览
2

我使用sklearn做了一个决策树,在这里,在SciKit学习DL包,即 . sklearn.tree.DecisionTreeClassifier().fit(x,y) .

如何获取每个步骤中所有可能节点的gini索引? graphviz 仅给出了具有最低gini索引的节点的gini索引,即用于拆分的节点 .

例如,下面的图片(来自 graphviz )告诉我Pclass_lowVMid权利索引的基尼评分为0.408,但不是该步骤中Pclass_lower或Sex_male的基尼指数 . 我只知道Pclass_lower的Gini指数和Sex_male必须大于(0.408 * 0.7 0),但就是这样 .

2 回答

  • 0

    使用export_graphviz显示所有节点的杂质,至少在版本 0.20.1 中 .

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier, export_graphviz
    from graphviz import Source
    
    data = load_iris()
    X, y = data.data, data.target
    
    clf = DecisionTreeClassifier(max_depth=2, random_state=42)
    clf.fit(X, y)
    
    graph = Source(export_graphviz(clf, out_file=None, feature_names=data.feature_names))
    graph.format = 'png'
    graph.render('dt', view=True);
    

    所有节点的杂质值也可在 treeimpurity 属性中访问 .

    clf.tree_.impurity
    array([0.66666667, 0.        , 0.5       , 0.16803841, 0.04253308])
    
  • 1

    pclass node的gini索引=左节点的gini索引*(左节点的样本数/右节点的左节点的样本数)右节点的gini索引*(左节点的样本数/在左节点处的样品在右边节点处的样品数量)所以这里它将是

    Gini index of pclass = 0 + .408 *(7/10) = 0.2856
    

相关问题