
How to extract the decision rules from a scikit-learn decision tree?


Can I extract the underlying decision rules (or "decision paths") from a trained decision tree as a textual list?

Something like:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Thanks for your help.

13 Answers

  • 33

    I believe that this answer is more correct than the other answers here:

    from sklearn.tree import _tree

    def tree_to_code(tree, feature_names):
        """Print the fitted tree as the source of an equivalent Python function."""
        tree_ = tree.tree_
        # Leaf nodes store TREE_UNDEFINED (-2) instead of a feature index.
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        print("def tree({}):".format(", ".join(feature_names)))

        def recurse(node, depth):
            indent = "  " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # Decision node: print the test, then both subtrees.
                name = feature_name[node]
                threshold = tree_.threshold[node]
                print("{}if {} <= {}:".format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                print("{}else:  # if {} > {}".format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf node: print the stored value.
                print("{}return {}".format(indent, tree_.value[node]))

        recurse(0, 1)
    

    This prints out a valid Python function. Here is example output for a tree that tries to return its input, a number between 0 and 10.

    def tree(f0):
      if f0 <= 6.0:
        if f0 <= 1.5:
          return [[ 0.]]
        else:  # if f0 > 1.5
          if f0 <= 4.5:
            if f0 <= 3.5:
              return [[ 3.]]
            else:  # if f0 > 3.5
              return [[ 4.]]
          else:  # if f0 > 4.5
            return [[ 5.]]
      else:  # if f0 > 6.0
        if f0 <= 8.5:
          if f0 <= 7.5:
            return [[ 7.]]
          else:  # if f0 > 7.5
            return [[ 8.]]
        else:  # if f0 > 8.5
          return [[ 9.]]
    

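    Here are some stumbling blocks that I see in other answers:

    • Using tree_.threshold == -2 to decide whether a node is a leaf is not a good idea. What if it is a real decision node with a threshold of -2? Instead, you should look at tree.feature or tree.children_*.

    • The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).

    • There is no need for multiple if statements in the recursive function; one is enough.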
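
The first stumbling block can be illustrated without scikit-learn at all. The sketch below uses small hand-written lists laid out like `tree_.children_left`, `tree_.children_right` and `tree_.threshold`; the arrays and the helper name `is_leaf` are assumptions made up for the illustration, not part of any answer. The root of this toy tree genuinely splits at -2.0, so the threshold test mislabels it as a leaf while the child test does not.

```python
TREE_LEAF = -1  # sentinel scikit-learn stores in children_left/right for leaves


def is_leaf(children_left, children_right, node):
    """A node is a leaf exactly when it has no children."""
    return (children_left[node] == TREE_LEAF
            and children_right[node] == TREE_LEAF)


# Hypothetical three-node tree: a root (node 0) with two leaf children.
children_left = [1, -1, -1]
children_right = [2, -1, -1]
threshold = [-2.0, -2.0, -2.0]  # the root really does split at -2.0

by_children = [is_leaf(children_left, children_right, n) for n in range(3)]
by_threshold = [threshold[n] == -2 for n in range(3)]

print(by_children)   # [False, True, True]
print(by_threshold)  # [True, True, True]  (root wrongly flagged as a leaf)
```

On a fitted estimator the same check would presumably read the arrays from `clf.tree_` instead of the hand-written lists.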

  • 84

    I created my own function to extract the rules from the decision trees created by sklearn:

    import pandas as pd
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    
    # dummy data:
    df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
    
    # create decision tree
    dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
    dt.fit(df.iloc[:, :2], df.dv)
    

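    This function starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents. I call this a node's "lineage". Along the way, I grab the values I need to create if/then/else SAS logic: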

    def get_lineage(tree, feature_names):
         left      = tree.tree_.children_left
         right     = tree.tree_.children_right
         threshold = tree.tree_.threshold
         features  = [feature_names[i] for i in tree.tree_.feature]
    
         # get ids of child nodes
         idx = np.argwhere(left == -1)[:,0]     
    
         def recurse(left, right, child, lineage=None):          
              if lineage is None:
                   lineage = [child]
              if child in left:
                   parent = np.where(left == child)[0].item()
                   split = 'l'
              else:
                   parent = np.where(right == child)[0].item()
                   split = 'r'
    
              lineage.append((parent, split, threshold[parent], features[parent]))
    
              if parent == 0:
                   lineage.reverse()
                   return lineage
              else:
                   return recurse(left, right, parent, lineage)
    
         for child in idx:
              for node in recurse(left, right, child):
                   print(node)
    

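    The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS, which is why I build logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node of a path. All of the preceding tuples combine to describe that node.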

    In [1]: get_lineage(dt, df.columns)
    (0, 'l', 0.5, 'col1')
    1
    (0, 'r', 0.5, 'col1')
    (2, 'l', 4.5, 'col2')
    3
    (0, 'r', 0.5, 'col1')
    (2, 'r', 4.5, 'col2')
    (4, 'l', 2.5, 'col1')
    5
    (0, 'r', 0.5, 'col1')
    (2, 'r', 4.5, 'col2')
    (4, 'r', 2.5, 'col1')
    6
    

    GraphViz output of example tree
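
As a rough illustration of how one such lineage could be turned into a single condition, here is a minimal sketch. The helper name `lineage_to_rule` is hypothetical, and the input is simply copied from the third path in the output above; it assumes the tuple layout `(parent, split, threshold, feature)` followed by the leaf id.

```python
def lineage_to_rule(lineage):
    """Join (parent, split, threshold, feature) tuples into one condition.

    The trailing integer (the terminal node id) is returned separately.
    """
    *steps, leaf_id = lineage
    parts = []
    for _parent, split, threshold, feature in steps:
        # 'l' means the path went left (<=), 'r' means it went right (>).
        op = "<=" if split == "l" else ">"
        parts.append("{} {} {}".format(feature, op, threshold))
    return " and ".join(parts), leaf_id


lineage = [(0, 'r', 0.5, 'col1'), (2, 'r', 4.5, 'col2'),
           (4, 'l', 2.5, 'col1'), 5]
rule, leaf = lineage_to_rule(lineage)
print("if {} then node = {}".format(rule, leaf))
# if col1 > 0.5 and col2 > 4.5 and col1 <= 2.5 then node = 5
```

Because every path is written out in full, no nested blocks are needed in the target language.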

  • 1

    I modified the code submitted by Zelazny7 to print some pseudocode:

    def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value     = tree.tree_.value

        def recurse(left, right, threshold, features, node):
            if threshold[node] != -2:
                # Decision node: open a block for each branch.
                print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                    recurse(left, right, threshold, features, left[node])
                print("} else {")
                if right[node] != -1:
                    recurse(left, right, threshold, features, right[node])
                print("}")
            else:
                # Leaf node: print the stored class counts.
                print("return " + str(value[node]))

        recurse(left, right, threshold, features, 0)
    

    If you call get_code(dt, df.columns) on the same example, you will get:

    if ( col1 <= 0.5 ) {
    return [[ 1.  0.]]
    } else {
    if ( col2 <= 4.5 ) {
    return [[ 0.  1.]]
    } else {
    if ( col1 <= 2.5 ) {
    return [[ 1.  0.]]
    } else {
    return [[ 0.  1.]]
    }
    }
    }
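
The `return [[ 1.  0.]]` lines above are raw entries of `tree_.value`, one weight per class, rather than class labels. A minimal sketch of mapping such an entry to a label follows; the helper name `predicted_class` is hypothetical, and it assumes a single-output classifier whose labels are given in the same order as `clf.classes_`.

```python
def predicted_class(node_value, class_names):
    """Return the label with the largest weight in a leaf's value row."""
    weights = node_value[0]  # single-output tree: one row of per-class weights
    best = max(range(len(weights)), key=lambda i: weights[i])
    return class_names[best]


print(predicted_class([[1.0, 0.0]], [0, 1]))  # 0
print(predicted_class([[0.0, 1.0]], [0, 1]))  # 1
```

Taking the arg-max works whether the row holds sample counts or proportions.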
    
  • 10

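    There is a new DecisionTreeClassifier method, decision_path, in the 0.18.0 release. The developers provide an extensive (well-documented) walkthrough.

    The first section of code in the walkthrough, which prints the tree structure, seems to be fine. However, I modified the code in the second section to interrogate a single sample. My changes are marked with # <--

    Edit: the changes marked with # <-- in the code below have since been incorporated into the walkthrough, after the errors were pointed out in pull requests #8653 and #10951. It is much easier to follow along now.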

    sample_id = 0
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]
    
    print('Rules used to predict sample %s: ' % sample_id)
    for node_id in node_index:
    
        if leave_id[sample_id] == node_id:  # <-- changed != to ==
            #continue # <-- comment out
            print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--
    
        else: # < -- added else to iterate through decision nodes
            if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
                threshold_sign = "<="
            else:
                threshold_sign = ">"
    
            print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
                  % (node_id,
                     sample_id,
                     feature[node_id],
                     X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                     threshold_sign,
                     threshold[node_id]))
    
    Rules used to predict sample 0: 
    decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
    decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
    leaf node 4 reached, no decision here
    

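    Change sample_id to see the decision paths for other samples. I have not asked the developers about these changes; they just seemed more intuitive when working through the example.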
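
The snippet above relies on names (`node_indicator`, `leave_id`, `feature`, `threshold`, `X_test`) that are defined earlier in the walkthrough and are not shown here. The following is a sketch of one plausible setup, assuming the iris data and a small tree; the dataset, the split and the hyperparameters are assumptions for illustration and may differ from what the walkthrough uses.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumed data and model; any fitted DecisionTreeClassifier would do.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

feature = estimator.tree_.feature      # split feature index per node
threshold = estimator.tree_.threshold  # split threshold per node

# Sparse (n_samples, n_nodes) indicator of the nodes each sample visits.
node_indicator = estimator.decision_path(X_test)
# Id of the leaf each sample ends up in.
leave_id = estimator.apply(X_test)

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]
print(node_index, leave_id[sample_id])
```

With these names in place, the loop above can be run as written.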

  • 0

    from io import StringIO
    from sklearn import tree

    out = StringIO()
    # export_graphviz writes the DOT source into the buffer passed as out_file
    tree.export_graphviz(clf, out_file=out)
    print(out.getvalue())
    

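    This shows you the tree as a digraph. Then, clf.tree_.feature and clf.tree_.value are the array of each node's splitting feature and the array of each node's values, respectively. You can find more details in the GitHub source.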
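
To make the parallel-array layout more concrete, here is a small sketch that walks the child arrays with an explicit stack and records each node's depth and whether it is a leaf. The lists are hand-written stand-ins for `clf.tree_.children_left` and `clf.tree_.children_right` (assumed values, chosen to match the seven-node example tree used in the earlier answers), and the function name `node_depths` is hypothetical.

```python
def node_depths(children_left, children_right):
    """Return (depth, is_leaf) lists for every node, walking from the root."""
    n_nodes = len(children_left)
    depth = [0] * n_nodes
    is_leaf = [False] * n_nodes
    stack = [(0, 0)]  # (node id, depth); node 0 is the root
    while stack:
        node, d = stack.pop()
        depth[node] = d
        if children_left[node] != children_right[node]:
            # Split node: both children exist, so visit them one level down.
            stack.append((children_left[node], d + 1))
            stack.append((children_right[node], d + 1))
        else:
            # Leaf: both child entries hold the same sentinel (-1).
            is_leaf[node] = True
    return depth, is_leaf


children_left = [1, -1, 3, -1, 5, -1, -1]
children_right = [2, -1, 4, -1, 6, -1, -1]
depth, is_leaf = node_depths(children_left, children_right)
print(depth)    # [0, 1, 1, 2, 2, 3, 3]
print(is_leaf)  # [False, True, False, True, False, True, True]
```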

  • 0

    The code below is my approach under Anaconda Python 2.7, plus the package "pydot-ng", to produce a PDF file with the decision rules. I hope it helps.

    from sklearn import tree
    
    clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
    clf_ = clf.fit(X, data_y)
    
    feature_names = X.columns
    class_name = clf_.classes_.astype(int).astype(str)
    
    def output_pdf(clf_, name):
        try:
            from StringIO import StringIO  # Python 2
        except ImportError:
            from io import StringIO  # Python 3
        import pydot_ng as pydot
        dot_data = StringIO()
        # Write the DOT description of the tree into an in-memory buffer.
        tree.export_graphviz(clf_, out_file=dot_data,
                             feature_names=feature_names,
                             class_names=class_name,
                             filled=True, rounded=True,
                             special_characters=True,
                             node_ids=True)
        graph = pydot.graph_from_dot_data(dot_data.getvalue())
        graph.write_pdf("%s.pdf" % name)
    
    output_pdf(clf_, name='filename%s'%n)
    

    (Image: the resulting tree graph)

  • 2

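    Since everyone has been so helpful, I will just add a modification to Zelazny7's and Daniele's beautiful solutions. This one is for Python 2.7, with tabs to make it more readable: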

    def get_code(tree, feature_names, tabdepth=0):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value
    
        def recurse(left, right, threshold, features, node, tabdepth=0):
                if (threshold[node] != -2):
                        print '\t' * tabdepth,
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node], tabdepth+1)
                        print '\t' * tabdepth,
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node], tabdepth+1)
                        print '\t' * tabdepth,
                        print "}"
                else:
                        print '\t' * tabdepth,
                        print "return " + str(value[node])
    
        recurse(left, right, threshold, features, 0, tabdepth)
    
  • 11

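    This builds on @paulkernfeld's answer. If you have a dataframe X with your features and a target dataframe y with your responses, and you want to know which y value ended up in which node (and plot it accordingly), you can do the following: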

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from sklearn.tree import _tree

    def tree_to_code(tree, feature_names):
        # Build the source of get_cat(X_tmp), which maps every row to its leaf id.
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for row in range(0, X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[row]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append('{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf: remember the id of the node this row falls into.
                codelines.append('{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines

    mycode = tree_to_code(clf, X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    colors = cm.rainbow(np.linspace(0, 1, 1 + max(set(node_ids))))
    for i in sorted(set(node_ids)):
        plt.plot(y[node_ids2.values == i], 'o', color=colors[i], label=str(i))
    plt.title('y colored by node', fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)  # tagname: the target's name, defined elsewhere
    plt.xticks(rotation=70)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close()
    

    Not the most elegant version, but it does the job...

  • 0

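    Here is a function that prints the rules of a scikit-learn decision tree under Python 3, indenting the conditional blocks to make the structure more readable: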

    def print_decision_tree(tree, feature_names=None, offset_unit='    '):
        '''Plots textual representation of rules of a decision tree
        tree: scikit-learn representation of tree
        feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
        offset_unit: a string of offset of the conditional block'''
    
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        value = tree.tree_.value
        if feature_names is None:
            features  = ['f%d'%i for i in tree.tree_.feature]
        else:
            features  = [feature_names[i] for i in tree.tree_.feature]        
    
        def recurse(left, right, threshold, features, node, depth=0):
                offset = offset_unit*depth
                if (threshold[node] != -2):
                        print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node],depth+1)
                        print(offset+"} else {")
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node],depth+1)
                        print(offset+"}")
                else:
                        print(offset+"return " + str(value[node]))
    
        recurse(left, right, threshold, features, 0,0)
    
  • 2

    I have been going through this, but I needed the rules to be written in this format

    if A>0.4 then if B<0.2 then if C>0.8 then class='X'
    

    So I adapted @paulkernfeld's answer (thanks), and you can customize it to your needs

    from sklearn.tree import _tree

    def tree_to_code(tree, feature_names, Y):
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        pathto=dict()
    
        global k
        k = 0
        def recurse(node, depth, parent):
            global k
            indent = "  " * depth
    
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                s= "{} <= {} ".format( name, threshold, node )
                if node == 0:
                    pathto[node]=s
                else:
                    pathto[node]=pathto[parent]+' & ' +s
    
                recurse(tree_.children_left[node], depth + 1, node)
                s="{} > {}".format( name, threshold)
                if node == 0:
                    pathto[node]=s
                else:
                    pathto[node]=pathto[parent]+' & ' +s
                recurse(tree_.children_right[node], depth + 1, node)
            else:
                k=k+1
                print(k,')',pathto[parent], tree_.value[node])
        recurse(0, 1, 0)
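
For comparison, here is a sketch of the same idea without the global counter: a generator that yields the conditions of every root-to-leaf path, which are then joined into the `if ... then if ... then class=` shape asked for in the question. It runs on hand-written lists standing in for the `tree_` arrays; those lists, the leaf-to-class mapping and the name `iter_rules` are assumptions for illustration only.

```python
def iter_rules(children_left, children_right, feature, threshold, names,
               node=0, path=()):
    """Yield (conditions, leaf_id) for every root-to-leaf path."""
    if children_left[node] == children_right[node]:  # leaf: both are -1
        yield list(path), node
        return
    name, thr = names[feature[node]], threshold[node]
    yield from iter_rules(children_left, children_right, feature, threshold,
                          names, children_left[node],
                          path + ("{}<={}".format(name, thr),))
    yield from iter_rules(children_left, children_right, feature, threshold,
                          names, children_right[node],
                          path + ("{}>{}".format(name, thr),))


# Hypothetical arrays laid out like clf.tree_ for a seven-node tree.
children_left = [1, -1, 3, -1, 5, -1, -1]
children_right = [2, -1, 4, -1, 6, -1, -1]
feature = [0, -2, 1, -2, 0, -2, -2]
threshold = [0.5, -2.0, 4.5, -2.0, 2.5, -2.0, -2.0]
classes = {1: 0, 3: 1, 5: 0, 6: 1}  # leaf id -> class (assumed)

rules = []
for conditions, leaf in iter_rules(children_left, children_right, feature,
                                   threshold, ["col1", "col2"]):
    rules.append("if " + " then if ".join(conditions)
                 + " then class='{}'".format(classes[leaf]))
print("\n".join(rules))
```

Passing the path down as an immutable tuple means each branch sees only its own conditions, so no dictionary of partial paths has to be kept in sync.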
    
  • 2

    I modified Zelazny7's code to generate SQL from the decision tree.

    # SQL from decision tree
    import numpy as np

    def get_lineage(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        le = '<='
        g = '>'
        # get ids of leaf nodes
        idx = np.argwhere(left == -1)[:, 0]

        def recurse(left, right, child, lineage=None):
            if lineage is None:
                lineage = [child]
            if child in left:
                parent = np.where(left == child)[0].item()
                split = 'l'
            else:
                parent = np.where(right == child)[0].item()
                split = 'r'
            lineage.append((parent, split, threshold[parent], features[parent]))
            if parent == 0:
                lineage.reverse()
                return lineage
            else:
                return recurse(left, right, parent, lineage)

        print('case ')
        for j, child in enumerate(idx):
            clause = ' when '
            for node in recurse(left, right, child):
                if not isinstance(node, tuple):
                    continue  # skip the trailing leaf id
                sign = le if node[1] == 'l' else g
                clause = clause + node[3] + sign + str(node[2]) + ' and '
            # drop the trailing 'and ' before closing the clause
            clause = clause[:-4] + ' then ' + str(j)
            print(clause)
        print('else 99 end as clusters')
    
  • 1

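    Apparently, quite a long time ago somebody already tried to add the following function to scikit-learn's official tree export functions (which basically only support export_graphviz)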

    def export_dict(tree, feature_names=None, max_depth=None) :
        """Export a decision tree in dict format.
    

    Here is his full commit:

    https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

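    I am not exactly sure what happened to this commit, but you could also try to use that function.

    I think this warrants a serious documentation request to the good people of scikit-learn to properly document the sklearn.tree.Tree API, which is the underlying tree structure that DecisionTreeClassifier exposes as its attribute tree_.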
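
The remark that the export functions basically only cover `export_graphviz` reflects the library at the time of that answer. Assuming scikit-learn 0.21 or later, `sklearn.tree.export_text` returns the rules as an indented text report; the sketch below shows the call on the iris data, which is chosen here purely for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumed example data and a deliberately shallow tree.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# One "|--- feature <= value" line per split, one "class: ..." line per leaf.
rules_text = export_text(clf, feature_names=list(iris.feature_names))
print(rules_text)
```

The report is meant for reading rather than execution, so the code-generating functions in the other answers remain useful when runnable rules are needed.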

  • 44

    Here is a way to translate the whole tree into a single (not necessarily very human-readable) Python expression using the SKompiler library:

    from skompiler import skompile
    skompile(dtree.predict).to('python/code')
    
