import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from ipywidgets import interact


def visualize_tree(estimator, X, y, boundaries=True,
                   xlim=None, ylim=None, ax=None):
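    """Fit ``estimator`` on (X, y) and plot its two-dimensional decision regions.

    The training points are drawn as a scatter plot, class predictions on a
    200x200 grid over the plot limits are shown as filled contours, and, if
    ``boundaries`` is True, the axis-aligned splits stored in
    ``estimator.tree_`` are drawn recursively as black lines.
    """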
    ax = ax or plt.gca()

    # Plot the training points
    ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap='viridis',
               clim=(y.min(), y.max()), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    if xlim is None:
        xlim = ax.get_xlim()
    if ylim is None:
        ylim = ax.get_ylim()

    # fit the estimator
    estimator.fit(X, y)
    xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
                         np.linspace(*ylim, num=200))
    Z = estimator.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    n_classes = len(np.unique(y))
    Z = Z.reshape(xx.shape)
    contours = ax.contourf(xx, yy, Z, alpha=0.3,
                           levels=np.arange(n_classes + 1) - 0.5,
                           cmap='viridis', clim=(y.min(), y.max()),
                           zorder=1)

    ax.set(xlim=xlim, ylim=ylim)

    # Plot the decision boundaries
    def plot_boundaries(i, xlim, ylim):
        if i >= 0:
            tree = estimator.tree_

            if tree.feature[i] == 0:
                ax.plot([tree.threshold[i], tree.threshold[i]], ylim, '-k', zorder=2)
                plot_boundaries(tree.children_left[i],
                                [xlim[0], tree.threshold[i]], ylim)
                plot_boundaries(tree.children_right[i],
                                [tree.threshold[i], xlim[1]], ylim)

            elif tree.feature[i] == 1:
                ax.plot(xlim, [tree.threshold[i], tree.threshold[i]], '-k', zorder=2)
                plot_boundaries(tree.children_left[i], xlim,
                                [ylim[0], tree.threshold[i]])
                plot_boundaries(tree.children_right[i], xlim,
                                [tree.threshold[i], ylim[1]])

    if boundaries:
        plot_boundaries(0, xlim, ylim)


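# Example usage of visualize_tree (a minimal sketch, not part of the original
# helpers; make_blobs is assumed only to provide an illustrative 2D dataset):
#
#     from sklearn.datasets import make_blobs
#     X, y = make_blobs(n_samples=300, centers=4,
#                       random_state=0, cluster_std=1.0)
#     visualize_tree(DecisionTreeClassifier(max_depth=4, random_state=0), X, y)
#     plt.show()
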
def plot_tree_interactive(X, y):
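    """Interactively fit and visualize a decision tree of selectable depth.

    Builds an ipywidgets control (via ``interact``) that refits a
    ``DecisionTreeClassifier`` on (X, y) and redraws its decision regions
    whenever the chosen ``max_depth`` changes.
    """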
    def interactive_tree(depth=5):
        clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
        visualize_tree(clf, X, y)

    return interact(interactive_tree, depth=[1, 5])


def randomized_tree_interactive(X, y):
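    """Interactively fit a deep decision tree on random 75% subsets of (X, y).

    Each setting of ``random_state`` shuffles the data, fits a
    ``DecisionTreeClassifier(max_depth=15)`` on the first 75% of the shuffled
    points, and plots its decision regions on fixed axis limits, making it
    easy to see how much the fitted model varies between subsets.
    """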
    N = int(0.75 * X.shape[0])

    xlim = (X[:, 0].min(), X[:, 0].max())
    ylim = (X[:, 1].min(), X[:, 1].max())

    def fit_randomized_tree(random_state=0):
        clf = DecisionTreeClassifier(max_depth=15)
        i = np.arange(len(y))
        rng = np.random.RandomState(random_state)
        rng.shuffle(i)
        visualize_tree(clf, X[i[:N]], y[i[:N]], boundaries=False,
                       xlim=xlim, ylim=ylim)

    interact(fit_randomized_tree, random_state=[0, 100])
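

# Example notebook usage of the interactive helpers (again a sketch, not part
# of the original module; assumes a Jupyter environment with ipywidgets
# enabled and the toy X, y data from the sketch above visualize_tree):
#
#     plot_tree_interactive(X, y)        # widget controlling max_depth
#     randomized_tree_interactive(X, y)  # widget controlling the random subset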