-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathCART.py
87 lines (77 loc) · 2.88 KB
/
CART.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import pickle
from config import model_save_road
class Cart:
    """Binary CART-style decision tree trained on a numeric numpy matrix.

    The input is a 2-D numpy array whose last column holds 0/1 labels and
    whose other columns are features. ``buildTree`` splits on column ``k`` at
    recursion depth ``k`` (column 0 at the root, column 1 one level down, ...),
    and a node does not record which column it tested.

    Tree format produced by ``buildTree``:
      * leaf:          ``["YES"]`` or ``["NO"]``
      * internal node: ``[threshold, left_subtree, right_subtree]``, where the
        rows are sorted by the tested column, the left subtree is built from
        the first part of that order and the right subtree from the rest.
    """

    # x: a 2-D numpy data matrix (features..., label).
    def __init__(self, x):
        self.data = x
        self.row = np.size(x, 0)  # number of rows
        self.col = np.size(x, 1)  # number of columns (features + label)
        self.choose_col = 0
        self.choose_row = 0
        self.model_road = "BTree.pickle"  # file used by save_model / load_model
        self.model = []

    def gini(self, p):
        """Return the Gini impurity of a label column ``p`` containing only 0s and 1s.

        The impurity is measured relative to ``len(p)`` -- the subset being
        evaluated -- not the size of the full training set. An empty column
        has impurity 0.0.
        """
        n = len(p)
        if n == 0:
            return 0.0
        yes = np.sum(p)
        no = n - yes
        return 1 - (yes / n) ** 2 - (no / n) ** 2

    @staticmethod
    def _best_split(labels):
        """Return the split index ``i`` minimising weighted Gini of ``labels[:i+1]`` / ``labels[i+1:]``.

        ``labels`` must contain at least two 0/1 values, in the row order the
        split is taken over. Uses prefix sums so every candidate is scored in
        one O(n) pass; ties resolve to the smallest ``i``.
        """
        n = len(labels)
        left_n = np.arange(1, n)               # size of labels[:i+1] for i = 0..n-2
        right_n = n - left_n                   # size of labels[i+1:]
        left_yes = np.cumsum(labels)[:-1]      # positives in labels[:i+1]
        right_yes = np.sum(labels) - left_yes  # positives in labels[i+1:]
        p_left = left_yes / left_n
        p_right = right_yes / right_n
        gini_left = 1 - p_left ** 2 - (1 - p_left) ** 2
        gini_right = 1 - p_right ** 2 - (1 - p_right) ** 2
        weighted = (left_n * gini_left + right_n * gini_right) / n
        return int(np.argmin(weighted))

    def buildTree(self, data_for_every_tree, k):
        """Recursively build the subtree for ``data_for_every_tree`` at depth ``k``.

        Args:
            data_for_every_tree: 2-D array of rows; the last column is the 0/1 label.
            k: index of the feature column this level splits on.

        Returns:
            ``["YES"]`` / ``["NO"]`` for a leaf, otherwise
            ``[threshold, left_subtree, right_subtree]``.
        """
        n, m = np.shape(data_for_every_tree)
        g = np.sum(data_for_every_tree[:, m - 1])  # number of positive labels
        # Stop: every feature column has been used -> majority vote (ties -> "NO").
        if k + 1 == self.col:
            return ["YES"] if g > n - g else ["NO"]
        # Stop: the node is already pure (also covers an empty node).
        if g == 0:
            return ["NO"]
        if g == n:
            return ["YES"]
        # Sort rows by the column being split on so each candidate split is a
        # prefix/suffix partition of the rows. Here 0 < g < n, so n >= 2.
        data_for_every_tree = data_for_every_tree[np.argsort(data_for_every_tree[:, k])]
        values = data_for_every_tree[:, k]
        pos = self._best_split(data_for_every_tree[:, m - 1])
        # NOTE(review): if values[pos] == values[pos + 1] the threshold equals that
        # value and cannot separate the tied rows; how ties are routed depends on
        # the prediction code, which is not visible here -- confirm.
        threshold = (values[pos] + values[pos + 1]) / 2
        # Rows 0..pos go left and pos+1..n-1 go right, exactly the partition that
        # was scored, so no row is dropped.
        less_data = data_for_every_tree[:pos + 1, :].copy()
        greater_data = data_for_every_tree[pos + 1:, :].copy()
        return [
            threshold,
            self.buildTree(less_data, k + 1),
            self.buildTree(greater_data, k + 1),
        ]

    def train_model(self):
        """Build the tree from ``self.data`` and store it in ``self.model``."""
        self.model = self.buildTree(self.data, 0)

    def save_model(self, model=None):
        """Pickle ``model`` (default: ``self.model``) to ``self.model_road``."""
        if model is None:
            model = self.model
        with open(self.model_road, 'wb') as f:
            pickle.dump(model, f)

    def load_model(self):
        """Unpickle and return the model stored at ``self.model_road``.

        ``pickle.load`` can execute arbitrary code: only load files you trust.
        """
        with open(self.model_road, 'rb') as f:
            return pickle.load(f)
def load_model(path=None):
    """Load and return a pickled model.

    Args:
        path: file to read. Defaults to ``model_save_road`` from ``config``
            (resolved at call time).

    Note:
        ``pickle.load`` can execute arbitrary code: only load files you trust.
    """
    if path is None:
        path = model_save_road
    with open(path, 'rb') as f:
        return pickle.load(f)