首页 文章

有没有办法在决策树的每个叶子下面获取样本?

提问于
浏览
4

我使用数据集训练了决策树 . 现在我想看看哪些样本落在树的哪个叶子下面 .

从这里我想要红色圆圈样本 .

enter image description here

我正在使用Python的Sklearn的决策树实现 .

1 回答

  • 6

    如果您只想要每个样品的叶子,您可以使用

    clf.apply(iris.data)
    

    数组([1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1] 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,14,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,14,5,5 5,5,5,5,10,5,5,5,5,10,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,10,5,5,5,5,5,5,5,5,5,5,5,5,5,5 5,5,16,16,16,16,16,16,6,16,16,16,16,16,16,16,16,16,16,16,16,8,16,16,16, 16,16,16,15,16,16,11,16,16,16,8,8,16,16,16,15,16,16,16,16,16,16,16,16,16, 16,16])

    如果要获取每个节点的所有样本,可以使用计算所有决策路径

    dec_paths = clf.decision_path(iris.data)
    

    然后遍历决策路径,将它们转换为带有 toarray() 的数组,并检查它们是否属于某个节点 . 所有内容都存储在 defaultdict 中,其中键是节点编号,值是样本编号 .

    for d, dec in enumerate(dec_paths):
        for i in range(clf.tree_.node_count):
            if dec.toarray()[0][i]  == 1:
                samples[i].append(d)
    

    Complete code

    import sklearn.datasets
    import sklearn.tree
    import collections
    
    clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
    iris = sklearn.datasets.load_iris()
    clf = clf.fit(iris.data, iris.target)
    
    samples = collections.defaultdict(list)
    dec_paths = clf.decision_path(iris.data)
    
    for d, dec in enumerate(dec_paths):
        for i in range(clf.tree_.node_count):
            if dec.toarray()[0][i]  == 1:
                samples[i].append(d)
    

    Output

    print(samples[13])
    

    [70,126,138]

相关问题