-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathCIFAR10_test.py
More file actions
24 lines (21 loc) · 763 Bytes
/
CIFAR10_test.py
File metadata and controls
24 lines (21 loc) · 763 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import sys

import torch
import torchvision
from PIL import Image

from CIFAR10_model import MyCIFAR10
# Classify a single image with a trained MyCIFAR10 model and print the raw
# input tensor, the model's output scores, and the predicted class index.
#
# Usage: python CIFAR10_test.py [IMAGE_PATH]
# Without IMAGE_PATH, the original hard-coded screenshot path is used.
DEFAULT_IMAGE_PATH = "C:/Users/xw112/Pictures/Screenshots/Screenshot 2025-05-21 002600.png"
image_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE_PATH
# Checkpoint holding a state_dict for MyCIFAR10, resolved relative to the
# current working directory.
MODEL_PATH = "CIFAR10_model.pth"

# Open inside a context manager so the file handle is released once the pixels
# are converted. convert("RGB") drops any alpha channel (screenshots are often
# RGBA) and guarantees exactly 3 channels for the model input.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# Resize to 32x32 and convert to a float CHW tensor (ToTensor scales to [0, 1]).
# NOTE(review): this must match the preprocessing used at training time (e.g.
# any Normalize step); the training script is not visible here, so confirm.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image_tensor = transforms(image)

model = MyCIFAR10()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; both the model and the input stay on the CPU below.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # inference mode for layers such as dropout/batch-norm, if any

# Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image_tensor = image_tensor.unsqueeze(0)
with torch.no_grad():
    output = model(image_tensor)
# Index of the highest-scoring class for the single image in the batch.
# NOTE(review): presumably this follows the standard CIFAR-10 label order
# (airplane, automobile, ..., truck); confirm against the training dataset.
pred = torch.argmax(output, dim=1)
print(image_tensor)
print(output)
print(pred.item())