-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_support_set.py
More file actions
44 lines (37 loc) · 1.44 KB
/
Copy pathbuild_support_set.py
File metadata and controls
44 lines (37 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import random
from PIL import Image
import torch
import torchvision.transforms as T
# Preprocessing pipeline applied to every image loaded below: resize to a
# fixed 224x224 input, convert the PIL image to a tensor, then normalize each
# of the three RGB channels with the per-channel mean/std constants.
# NOTE(review): these mean/std values match the statistics commonly used for
# ImageNet-pretrained backbones — confirm they suit the model this feeds.
_TARGET_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = T.Compose(
    [
        T.Resize(_TARGET_SIZE),
        T.ToTensor(),
        T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
def load_images_from_folder(folder_path, label, num_samples):
    """Load up to ``num_samples`` randomly chosen images from one folder.

    Every regular file directly inside ``folder_path`` is a candidate.
    Candidates are visited in random order. Each one is opened with PIL,
    converted to RGB and passed through the module-level ``transform``.
    Files that cannot be decoded are skipped and the next candidate is tried.
    The result is therefore shorter than ``num_samples`` only when the folder
    runs out of readable images.

    Args:
        folder_path: Directory to read images from.
        label: Label value recorded once per successfully loaded image.
        num_samples: Maximum number of images to return.

    Returns:
        An ``(images, labels)`` pair of equal-length lists: the transformed
        image tensors, and a list holding ``label`` once per image.

    Raises:
        ValueError: If ``num_samples`` is negative.
        OSError: If ``folder_path`` cannot be listed.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    # Sort before shuffling: os.listdir order is filesystem-dependent, so
    # sorting first makes the selection reproducible under random.seed().
    # Only regular files are candidates; sub-directories can never load.
    candidates = sorted(
        name
        for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    random.shuffle(candidates)

    images, labels = [], []
    for fname in candidates:
        if len(images) >= num_samples:
            break
        path = os.path.join(folder_path, fname)
        try:
            # The context manager closes the file handle. convert() returns a
            # new, fully loaded image that remains valid after the close.
            with Image.open(path) as img:
                image = img.convert("RGB")
            tensor = transform(image)
        except (OSError, ValueError, Image.DecompressionBombError):
            # Unreadable or corrupted file (PIL's UnidentifiedImageError is an
            # OSError): skip it and draw the next candidate instead.
            continue
        images.append(tensor)
        labels.append(label)
    return images, labels
def build_support_set(
    data_dir,
    num_per_class=5,
    *,
    real_subdir="training_real",
    fake_subdir="training_fake",
):
    """Build a small labelled support set of real (0) and fake (1) images.

    Samples up to ``num_per_class`` images from ``<data_dir>/<real_subdir>``
    (label 0) and ``<data_dir>/<fake_subdir>`` (label 1) using
    ``load_images_from_folder``, then stacks them with real images first.

    Args:
        data_dir: Directory containing the two class sub-folders.
        num_per_class: Maximum number of images to sample per class.
        real_subdir: Sub-folder holding label-0 ("real") images.
        fake_subdir: Sub-folder holding label-1 ("fake") images.

    Returns:
        ``(support_x, support_y)``. ``support_x`` is ``torch.stack`` of all
        loaded image tensors (new leading dimension indexes the samples).
        ``support_y`` is a 1-D int64 tensor of the matching labels.

    Raises:
        RuntimeError: If no image could be loaded from either folder.
    """
    real_dir = os.path.join(data_dir, real_subdir)
    fake_dir = os.path.join(data_dir, fake_subdir)
    real_images, real_labels = load_images_from_folder(
        real_dir, label=0, num_samples=num_per_class
    )
    fake_images, fake_labels = load_images_from_folder(
        fake_dir, label=1, num_samples=num_per_class
    )

    all_images = real_images + fake_images
    if not all_images:
        # torch.stack([]) would raise an opaque error; name the folders instead.
        raise RuntimeError(
            f"No images could be loaded from {real_dir!r} or {fake_dir!r}"
        )
    # NOTE(review): a class whose folder has fewer readable files contributes
    # fewer images (possibly none), and no warning is emitted.
    support_x = torch.stack(all_images)
    support_y = torch.tensor(real_labels + fake_labels, dtype=torch.long)
    return support_x, support_y
# Example usage (expects data/training_real/ and data/training_fake/ to exist):
# support_x, support_y = build_support_set("data", num_per_class=5)
# print("Support Set Shape:", support_x.shape, support_y.shape)