4
4
import zipfile
5
5
import os
6
6
import matplotlib .pyplot as plt
7
+ import torch
8
+ import numpy as np
7
9
8
10
def performance_metrics (y_true , y_pred ):
9
11
model_accuracy = accuracy_score (y_true , y_pred ) * 100
@@ -151,4 +153,14 @@ def plot_decision_boundary(model: torch.nn.Module, x: torch.Tensor, y: torch.Ten
151
153
plt .contourf (xx , yy , y_pred , cmap = plt .cm .RdYlBu , alpha = 0.7 )
152
154
plt .scatter (x [:, 0 ], x [:, 1 ], c = y , s = 40 , cmap = plt .cm .RdYlBu )
153
155
plt .xlim (xx .min (), xx .max ())
154
- plt .ylim (yy .min (), yy .max ())
156
+ plt .ylim (yy .min (), yy .max ())
157
+
158
+ def plot_prediction (train_data , train_label , test_data , test_label , prediction = None ):
159
+ plt .figure (figsize = (10 , 7 ))
160
+ plt .scatter (train_data , train_label , c = "b" , s = 4 , label = "Training data" )
161
+ plt .scatter (test_data , test_label , c = "g" , s = 4 , label = "Testing data" )
162
+
163
+ if predictions is not None :
164
+ plt .scatter (test_data , prediction , c = "r" , s = 4 , label = "Prediction" )
165
+
166
+ plt .legend (prop = {"size" : 14 })
0 commit comments