Skip to content

Commit 9f5aec8

Browse files
update
1 parent d3dddad commit 9f5aec8

File tree

2 files changed

+618
-62
lines changed

2 files changed

+618
-62
lines changed

03-neural-network-classification-pytorch.ipynb

+605-61
Large diffs are not rendered by default.

helper_function.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import zipfile
55
import os
66
import matplotlib.pyplot as plt
7+
import torch
8+
import numpy as np
79

810
def performance_metrics(y_true, y_pred):
911
model_accuracy = accuracy_score(y_true, y_pred) * 100
@@ -151,4 +153,14 @@ def plot_decision_boundary(model: torch.nn.Module, x: torch.Tensor, y: torch.Ten
151153
plt.contourf(xx, yy, y_pred, cmap = plt.cm.RdYlBu, alpha = 0.7)
152154
plt.scatter(x[:, 0], x[:, 1], c = y, s = 40, cmap = plt.cm.RdYlBu)
153155
plt.xlim(xx.min(), xx.max())
154-
plt.ylim(yy.min(), yy.max())
156+
plt.ylim(yy.min(), yy.max())
157+
158+
def plot_prediction(train_data, train_label, test_data, test_label, prediction = None):
159+
plt.figure(figsize = (10, 7))
160+
plt.scatter(train_data, train_label, c = "b", s = 4, label = "Training data")
161+
plt.scatter(test_data, test_label, c = "g", s = 4, label = "Testing data")
162+
163+
if predictions is not None:
164+
plt.scatter(test_data, prediction, c = "r", s = 4, label = "Prediction")
165+
166+
plt.legend(prop = {"size" : 14})

0 commit comments

Comments
 (0)