-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTask_3.py
More file actions
117 lines (98 loc) · 3.68 KB
/
Task_3.py
File metadata and controls
117 lines (98 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import joblib
import numpy as np
# 1) LOAD
# The UCI bank.csv export is semicolon-delimited.
df = pd.read_csv("bank.csv", sep=';')

# 2) TARGET & BASIC CLEANING
# Encode the target: 'yes' -> 1, every other value -> 0.
df['y'] = (df['y'] == 'yes').astype(int)

# IMPORTANT: 'duration' is known only after the call is finished, so keeping it
# would be data leakage. errors='ignore' makes the drop a no-op when the column
# is absent, so no explicit membership check is needed.
df = df.drop(columns=['duration'], errors='ignore')

# 3) FEATURES / TYPES
y = df['y']
X = df.drop(columns=['y'])

# Select numeric columns by kind rather than by exact width so int32/float32 or
# nullable numeric columns are not missed.
numeric_features = X.select_dtypes(include='number').columns.tolist()

# Every remaining column is treated as categorical. Matching only on 'object'
# would leave out columns stored as 'category', bool, or a dedicated string
# dtype, and the ColumnTransformer would then drop them without any warning.
categorical_features = [col for col in X.columns if col not in numeric_features]
# 4) PREPROCESSING
# Numeric columns: missing values are replaced by the column median.
# No scaling step is applied.
numeric_pipe = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="median")),
])

# Categorical columns: missing values are replaced by the most frequent value,
# then the column is one-hot encoded.
# NOTE(review): no `missing_values` argument is passed, so the imputer uses its
# library default; a literal 'unknown' string is presumably kept as an ordinary
# category rather than imputed -- confirm this is the intended treatment.
categorical_steps = [
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("onehot", OneHotEncoder(handle_unknown="ignore")),
]
categorical_pipe = Pipeline(steps=categorical_steps)

# Route each column group through its own pipeline. The transformer names
# "num" and "cat" are looked up again later when the tree is plotted.
preprocess = ColumnTransformer(transformers=[
    ("num", numeric_pipe, numeric_features),
    ("cat", categorical_pipe, categorical_features),
])
# 5) SPLIT
# Hold out 20% of the rows for testing; stratify on the target so both parts
# keep the same class balance, and fix the seed for reproducibility.
split_parts = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
X_train, X_test, y_train, y_test = split_parts

# 6) MODEL
# Depth and leaf-size limits keep the tree small enough to stay readable.
tree_params = {
    "criterion": "gini",
    "max_depth": 8,
    "min_samples_split": 20,
    "min_samples_leaf": 10,
    "random_state": 42,
}
clf = Pipeline(steps=[
    ("prep", preprocess),
    ("model", DecisionTreeClassifier(**tree_params)),
])

# 7) TRAIN
# Fits the preprocessing and the classifier together on the training part only.
clf.fit(X_train, y_train)
# 8) EVALUATE
# Score the held-out rows. zero_division=0 makes precision/recall/F1 report 0
# instead of warning when a denominator is zero.
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
prec = precision_score(y_test, y_pred, zero_division=0)
rec = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)

# Emit the summary as one block of text; the lines are the same as printing
# each of them separately.
summary_lines = [
    "\n=== Metrics on Test Set ===",
    f"Accuracy : {acc:.3f}",
    f"Precision: {prec:.3f}",
    f"Recall : {rec:.3f}",
    f"F1-score : {f1:.3f}\n",
]
print("\n".join(summary_lines))

# Per-class breakdown, labelling class 0 as "No" and class 1 as "Yes".
print(classification_report(y_test, y_pred, target_names=["No", "Yes"]))
# 9) CONFUSION MATRIX (saved)
# Rows of `cm` are the actual classes, columns the predicted classes.
cm = confusion_matrix(y_test, y_pred)
class_labels = ["No", "Yes"]

fig, ax = plt.subplots(figsize=(5, 4))
ax.imshow(cm, interpolation='nearest')
ax.set_title("Confusion Matrix")
ax.set_xticks([0, 1])
ax.set_xticklabels(class_labels)
ax.set_yticks([0, 1])
ax.set_yticklabels(class_labels)

# Write each count in the centre of its cell; x is the column, y is the row.
for (row, col), count in np.ndenumerate(cm):
    ax.text(col, row, str(count), ha='center', va='center', fontsize=12)

ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
fig.tight_layout()
fig.savefig("confusion_matrix.png", dpi=200)
# Close the figure so it does not stay registered with pyplot.
plt.close(fig)
# 10) OPTIONAL: TREE PLOT (feature names after one-hot)
# Rebuild the column names seen by the tree: the numeric columns first, then
# the one-hot columns, matching the transformer order in the preprocessor.
fitted_prep = clf.named_steps["prep"]
onehot = fitted_prep.named_transformers_["cat"]["onehot"]
cat_names = onehot.get_feature_names_out(categorical_features)

# Use a plain list of str rather than an ndarray built with np.r_: some
# scikit-learn releases validate `feature_names` as a list and reject arrays.
feat_names = list(numeric_features) + list(cat_names)

# Access the trained DecisionTree inside the pipeline.
tree_est = clf.named_steps["model"]

fig2, ax2 = plt.subplots(figsize=(14, 8))
plot_tree(
    tree_est,
    feature_names=feat_names,
    class_names=["No", "Yes"],
    filled=True,
    max_depth=3,  # draw only the top levels so the figure stays readable
    fontsize=8,
    ax=ax2,
)
ax2.set_title("Decision Tree (Top Levels)")
fig2.tight_layout()
fig2.savefig("decision_tree.png", dpi=200)
# Close the figure so it does not stay registered with pyplot.
plt.close(fig2)
# 11) SAVE MODEL
# Persist the whole fitted pipeline (preprocessing + tree) so it can be
# reloaded and applied to raw rows later.
model_path = "dt_bank.pkl"
joblib.dump(clf, model_path)
print("Saved: confusion_matrix.png, decision_tree.png, dt_bank.pkl")