import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class PatchesDataset(Dataset):
"""
A standard PyTorch definition of Dataset which defines the functions __len__ and __getitem__.
"""
def __init__(self, data_dir, transform):
"""
Store the filenames of the jpgs to use. Specifies transforms to apply on images.
Args:
data_dir: (string) directory containing the dataset
transform: (torchvision.transforms) transformation to apply on image
"""
self.filenames = os.listdir(data_dir)
self.filenames = [os.path.join(data_dir, f) for f in self.filenames]# if f.endswith('.jpg')]
# self.labels = [int(os.path.split(filename)[-1][0]) for filename in self.filenames]
self.transform = transform
def __len__(self):
# return size of dataset
return len(self.filenames)
    def __getitem__(self, idx):
        """
        Fetch the image at index idx from the dataset and apply the transform to it.
        (Label loading is currently disabled; see the commented-out lines.)

        Args:
            idx: (int) index in [0, 1, ..., size_of_dataset - 1]

        Returns:
            image: (Tensor) transformed image
        """
        image = Image.open(self.filenames[idx])  # PIL image
        image = self.transform(image)
        return image  # , self.labels[idx]
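

# A minimal usage sketch of PatchesDataset on its own (the directory path below
# is illustrative only, not a path guaranteed to exist in this repository):
#
#     ds = PatchesDataset("data/train/class0/",
#                         transforms.Compose([transforms.CenterCrop(64),
#                                             transforms.ToTensor()]))
#     first_image = ds[0]  # tensor of shape (C, 64, 64)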


def fetch_dataloader(types, data_dir, params, batch_size, rotation_deg=0, translation=0, scaling=1, shearing_deg=0):
    """
    Fetches a DataLoader object for each type in types from data_dir.

    Args:
        types: (list) contains one or more of 'train', 'validation', 'test' depending on which data is required
        data_dir: (string) directory containing the dataset
        params: (Params) hyperparameters; must provide num_workers and cuda
        batch_size: (int) number of images per batch
        rotation_deg: (float) maximum rotation, in degrees, for the random affine augmentation
        translation: (float) maximum fraction of the image size to translate by, horizontally and vertically
        scaling: (float) upper bound of the random scaling range [1.0, scaling]
        shearing_deg: (float) maximum shear angle, in degrees

    Returns:
        data: (dict) contains a DataLoader object for each type in types
    """
    # define a training image loader that specifies the transforms to apply to the images
    train_transformer = transforms.Compose([
        transforms.CenterCrop(64),
        transforms.RandomAffine(rotation_deg, translate=(translation, translation),
                                scale=(1.0, scaling), shear=shearing_deg),
        # transforms.Resize(64),  # resize the image to 64x64 (remove if images are already 64x64)
        transforms.ToTensor()])  # transform it into a torch tensor
    # loader for evaluation: no random augmentation
    eval_transformer = transforms.Compose([
        transforms.CenterCrop(64),
        transforms.ToTensor()])  # transform it into a torch tensor

    dataloaders = {}
    for split in ['train', 'validation', 'test']:
        if split in types:
            path = os.path.join(data_dir, split, "class0/")
            # use train_transformer for the training data, else eval_transformer (no random augmentation)
            if split == 'train':
                dl = DataLoader(PatchesDataset(path, train_transformer), batch_size=batch_size, shuffle=True,
                                num_workers=params.num_workers,
                                pin_memory=params.cuda)
            else:
                dl = DataLoader(PatchesDataset(path, eval_transformer), batch_size=batch_size, shuffle=False,
                                num_workers=params.num_workers,
                                pin_memory=params.cuda)
            dataloaders[split] = dl

    return dataloaders
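

if __name__ == "__main__":
    # A minimal sketch of how fetch_dataloader might be called. The data
    # directory below and the exact attribute values are assumptions for
    # illustration; the only requirement this module places on `params` is
    # that it exposes `num_workers` and `cuda`.
    from types import SimpleNamespace

    params = SimpleNamespace(num_workers=2, cuda=False)
    loaders = fetch_dataloader(['train', 'validation'], 'data/64x64_patches',
                               params, batch_size=32, rotation_deg=10)
    for batch in loaders['train']:
        print(batch.shape)  # (32, C, 64, 64) for full batches
        break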