How can I get every branch of a sklearn decision tree?

Xerac

I have this decision tree and I want to extract every branch from it. The image shows only part of the tree; the full tree is much larger and does not fit in a single image.

[image: excerpt of the decision tree]

I'm not trying to print the tree's rules like this:

Rules used to predict sample 1400:

decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)

Or like this:

The binary tree structure has 7 nodes and has the following tree structure:

node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
    node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
    node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.

What I want to get instead is:

branch 0: x[4] <= 96.429, x[3] <= 96.429, class=B, gini_score=0.5
branch 1: x[4] <= 96.429, x[3] > 96.429, class=B, gini_score=0.021
branch 2: x[4] > 96.429, x[5] <= 0.283, class=A, gini_score=0.092
branch 4: x[4] > 96.429, x[5] > 0.283, class=A, gini_score=0.01

Basically, I'm trying to get every branch from the root down to a leaf node (the full path), together with its class and gini score. How can I achieve this?

Miguel Trejo

Consider the iris dataset example from the sklearn documentation and follow the steps below.

1. Generate a sample decision tree

Code from the documentation:

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)
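As a quick sanity check (not part of the original answer), `sklearn.tree.export_text` prints the fitted tree's split rules as indented text, which makes it easier to follow the node ids extracted below:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0
)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

# One indented line per node; leaf lines show the predicted class.
print(export_text(clf, feature_names=list(iris.feature_names)))
```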

2. Retrieve the branch paths

First, we retrieve the following values from the tree:

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
impurity = clf.tree_.impurity
value = clf.tree_.value
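One detail worth knowing (an aside, not from the original answer): for leaf nodes, `children_left` and `children_right` are both `-1` (sklearn's `TREE_LEAF` sentinel), and `feature` and `threshold` hold the sentinel value `-2` (`TREE_UNDEFINED`) rather than a real feature index. A minimal check:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, _, y_train, _ = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

is_leaf = clf.tree_.children_left == -1          # TREE_LEAF sentinel
print(np.unique(clf.tree_.feature[is_leaf]))     # leaves hold -2 (TREE_UNDEFINED)
print(np.unique(clf.tree_.threshold[is_leaf]))   # likewise -2.0
```

This is why naively printing a rule for every node in a branch would produce entries like `X[:, -2] <= -2.0` at the leaves.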

Inside `retrieve_branches` we determine which nodes are leaves and iterate from the root node down to the leaves; whenever we reach a leaf node, we return its branch path with a `yield` statement.

def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Retrieve decision tree branches"""

    # A node is a leaf when its left and right children are equal (both -1)
    is_leaves_list = [(False if cl != cr else True) for cl, cr in zip(children_left_list, children_right_list)]

    # Store the branch paths
    paths = []

    for i in range(number_nodes):
        if is_leaves_list[i]:
            # Search for this leaf node among the ends of the previous paths
            end_node = [path[-1] for path in paths]

            # If it is a leaf node, yield the path
            if i in end_node:
                output = paths.pop(np.argwhere(i == np.array(end_node))[0][0])
                yield output

        else:

            # Origin and end nodes
            origin, end_l, end_r = i, children_left_list[i], children_right_list[i]

            # Iterate over previous paths to add nodes
            for index, path in enumerate(paths):
                if origin == path[-1]:
                    paths[index] = path + [end_l]
                    paths.append(path + [end_r])

            # Initialize the paths on the first iteration
            if i == 0:
                paths.append([i, children_left_list[i]])
                paths.append([i, children_right_list[i]])
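An equivalent, arguably simpler formulation (a sketch of my own, not from the original answer) walks the tree depth-first with recursion and yields a root-to-leaf path at each leaf:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, _, y_train, _ = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

def leaf_paths(tree, node=0, path=()):
    """Yield every root-to-leaf path as a list of node ids."""
    path = path + (node,)
    if tree.children_left[node] == -1:   # leaf: both children are -1
        yield list(path)
    else:
        yield from leaf_paths(tree, tree.children_left[node], path)
        yield from leaf_paths(tree, tree.children_right[node], path)

branches = list(leaf_paths(clf.tree_))
print(branches)
```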

To call `retrieve_branches`, just pass `n_nodes`, `children_left`, and `children_right`; the generator keeps an internal list in which the branch paths are stored and updated. The final result looks like this.

all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches
>>> 
[[0, 1],
 [0, 2, 3, 5],
 [0, 2, 3, 6, 7],
 [0, 2, 3, 6, 8],
 [0, 2, 4, 9],
 [0, 2, 4, 10]]
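If what you need is only the path a particular sample follows (rather than all branches), `clf.decision_path` already provides it directly; a minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0
)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

# Sparse indicator matrix: entry (i, j) is nonzero if sample i passes through node j
node_indicator = clf.decision_path(X_test)

# Node ids visited by the first test sample, from root to leaf
sample0 = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
print(sample0.tolist())
```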

3. Path, value and gini of each branch

The decision rules can be obtained from `feature` and `threshold` of `clf.tree_`, and the impurity and class counts at each leaf node from `clf.tree_.impurity` and `clf.tree_.value`.

for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini: {impurity[leaf_index]} at leaf node {leaf_index}')
    print(f'Value: {value[leaf_index]}')
    # Skip the leaf itself: leaves carry the sentinel feature -2 and threshold -2.0
    print(f"Decision Rules: {[f'if X[:, {feature[elem]}] <= {threshold[elem]}' for elem in branch[:-1]]}")
    print(f"---------------------------------------------------------------------------------------\n")
>>>
Branch: 0, Path: [0, 1]
Gini: 0.0 at leaf node 1
Value: [[37.  0.  0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929']
---------------------------------------------------------------------------------------

Branch: 1, Path: [0, 2, 3, 5]
Gini: 0.0 at leaf node 5
Value: [[ 0. 32.  0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869']
---------------------------------------------------------------------------------------

Branch: 2, Path: [0, 2, 3, 6, 7]
Gini: 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858']
---------------------------------------------------------------------------------------

Branch: 3, Path: [0, 2, 3, 6, 8]
Gini: 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858']
---------------------------------------------------------------------------------------

Branch: 4, Path: [0, 2, 4, 9]
Gini: 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284']
---------------------------------------------------------------------------------------

Branch: 5, Path: [0, 2, 4, 10]
Gini: 0.0 at leaf node 10
Value: [[ 0.  0. 35.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284']
---------------------------------------------------------------------------------------
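Note that the rules above print `<=` for every split, without distinguishing whether the branch went left or right. To get the exact one-line-per-branch format the question asks for (with the correct `<=`/`>` direction, plus class and gini per leaf), you can check whether the next node in the path is the parent's left or right child. A sketch under the same iris setup (the formatting details are my own, not from the original answer):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, _, y_train, _ = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)
t = clf.tree_

def leaf_paths(node=0, path=()):
    """Yield every root-to-leaf path as a list of node ids."""
    path = path + (node,)
    if t.children_left[node] == -1:   # leaf
        yield list(path)
    else:
        yield from leaf_paths(t.children_left[node], path)
        yield from leaf_paths(t.children_right[node], path)

lines = []
for i, path in enumerate(leaf_paths()):
    leaf = path[-1]
    rules = []
    for parent, child in zip(path, path[1:]):
        # Left child means the split condition held (<=), right child means it failed (>)
        op = "<=" if child == t.children_left[parent] else ">"
        rules.append(f"x[{t.feature[parent]}] {op} {t.threshold[parent]:.3f}")
    cls = iris.target_names[np.argmax(t.value[leaf])]
    lines.append(f"branch {i}: {', '.join(rules)}, class={cls}, gini_score={t.impurity[leaf]:.3f}")

print("\n".join(lines))
```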
