def print_decision_tree(tree, feature_names=None, offset_unit=' '):
'''Plots textual representation of rules of a decision tree
tree: scikit-learn representation of tree
feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
offset_unit: a string of offset of the conditional block'''
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
print(offset+"return " + str(value[node]))
recurse(left, right, threshold, features, 0,0)
2
我一直在经历这个,但我需要以这种格式编写规则
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
所以我调整了@paulkernfeld(谢谢)的答案,你可以根据自己的需要进行定制
def tree_to_code(tree, feature_names, Y):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
pathto=dict()
global k
k = 0
def recurse(node, depth, parent):
global k
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
s= "{} <= {} ".format( name, threshold, node )
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_left[node], depth + 1, node)
s="{} > {}".format( name, threshold)
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_right[node], depth + 1, node)
else:
k=k+1
print(k,')',pathto[parent], tree_.value[node])
recurse(0, 1, 0)
2
修改了Zelazny7的代码,用于从决策树中获取SQL .
# SQL from decision tree
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
le='<='
g ='>'
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
print 'case '
for j,child in enumerate(idx):
clause=' when '
for node in recurse(left, right, child):
if len(str(node))<3:
continue
i=node
if i[1]=='l': sign=le
else: sign=g
clause=clause+i[3]+sign+str(i[2])+' and '
clause=clause[:-4]+' then '+str(j)
print clause
print 'else 99 end as clusters'
13 回答
我相信这个答案比其他答案更正确:
这打印出一个有效的Python函数 . 以下是尝试返回其输入的树的示例输出,该数字介于0和10之间 .
以下是我在其他答案中看到的一些绊脚石:
使用
tree_.threshold == -2
来判断一个节点是否是一个叶子是一个真正的决策节点,其阈值为-2?相反,你应该看tree.feature
或tree.children_*
.行
features = [feature_names[i] for i in tree_.feature]
与我的sklearn版本崩溃,因为tree.tree_.feature
的某些值为-2(特别是对于叶节点) .递归函数中不需要多个if语句,只需一个就可以了 .
我创建了自己的函数来从sklearn创建的决策树中提取规则:
此函数首先从节点(在子数组中由-1标识)开始,然后以递归方式查找父节点 . 我将此称为节点的“谱系” . 一路上,我 grab 了我需要创建的值,如果/ then / else SAS逻辑:
下面的元组集包含创建SAS if / then / else语句所需的一切 . 我不喜欢在SAS中使用
do
块,这就是我创建描述节点整个路径的逻辑的原因 . 元组之后的单个整数是路径中终端节点的ID . 所有前面的元组组合起来创建该节点 .我修改了Zelazny7提交的代码来打印一些伪代码:
如果你在同一个例子上调用
get_code(dt, df.columns)
,你将获得:在0.18.0版本中有一个新的DecisionTreeClassifier方法
decision_path
. 开发人员提供了广泛的(记录良好的)walkthrough .打印树结构的演练中的第一部分代码似乎没问题 . 但是,我修改了第二部分中的代码以询问一个样本 . 我的更改用
# <--
表示Edit 在拉取请求#8653和#10951中指出错误后,以下代码中由
# <--
标记的更改已在演练链接中更新 . 现在跟进起来要容易得多 .更改
sample_id
以查看其他样本的决策路径 . 我没有向开发人员询问这些变化,在完成示例时似乎更直观 .你可以看到一个有向图树 . 然后,
clf.tree_.feature
和clf.tree_.value
分别是分裂特征和节点值数组的节点阵列 . 您可以参考github source中的更多详细信息 .下面的代码是我在anaconda python 2.7下的方法,加上包名“pydot-ng”来制作带有决策规则的PDF文件 . 我希望它有所帮助 .
a tree graphy show here
只是因为每个人都非常乐于助人,我只想对Zelazny7和Daniele的漂亮解决方案进行修改 . 这个是用于python 2.7,带有选项卡使其更具可读性:
这 Build 在@paulkernfeld的答案之上 . 如果你有一个带有你的特征的数据框X和一个带有你的共鸣的目标数据框y你想知道哪个y值在哪个节点结束(以及相应地绘制它的 Ant )你可以做以下事情:
不是最优雅的版本,但它做的工作......
这是一个函数,在python 3下打印scikit-learn决策树的规则,并使用条件块的偏移量来使结构更具可读性:
我一直在经历这个,但我需要以这种格式编写规则
所以我调整了@paulkernfeld(谢谢)的答案,你可以根据自己的需要进行定制
修改了Zelazny7的代码,用于从决策树中获取SQL .
显然很久以前有人已经决定尝试将以下功能添加到官方scikit的树导出功能(基本上只支持export_graphviz)
这是他的全部承诺:
https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py
不完全确定这个评论发生了什么 . 但您也可以尝试使用该功能 .
我认为这需要向scikit的优秀人员提供严格的文档请求 - 学习如何正确地记录
sklearn.tree.Tree
API,这是DecisionTreeClassifier
作为其属性tree_
公开的基础树结构 .这是一种使用SKompiler库将整个树转换为单个(不一定是人类可读的)python表达式的方法: