-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataloaders.py
151 lines (118 loc) · 4.34 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import glob
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
def mnist(batch_size=128, num_colors=256, size=28,
          path_to_data='../mnist_data'):
    """Build shuffled train and test dataloaders for MNIST.

    Every image is resized, converted to a tensor and then quantized into
    integer color bins.

    Parameters
    ----------
    batch_size : int
        Number of images per batch, used for both loaders.

    num_colors : int
        Number of colors to quantize images into. Typically 256, but can be
        lower for e.g. binary images.

    size : int
        Size (height and width) of each image. Default is 28 for no resizing.

    path_to_data : string
        Path to MNIST data files.

    Returns
    -------
    tuple
        ``(train_loader, test_loader)``; both are created with
        ``shuffle=True``.
    """
    to_bins = get_quantize_func(num_colors)
    pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Lambda(to_bins),
    ])

    # The training split comes first because it is the only one requested
    # with download=True; the test split is then read from the same path.
    splits = (
        datasets.MNIST(path_to_data, train=True, download=True,
                       transform=pipeline),
        datasets.MNIST(path_to_data, train=False, transform=pipeline),
    )
    train_loader, test_loader = (
        DataLoader(split, batch_size=batch_size, shuffle=True)
        for split in splits
    )
    return train_loader, test_loader
def celeba(batch_size=128, num_colors=256, size=178, crop=178, grayscale=False,
           shuffle=True, path_to_data='../celeba_data'):
    """Build a dataloader of square CelebA images.

    Note original CelebA images have shape (218, 178); with the default
    arguments this dataloader center crops them to (178, 178).

    Parameters
    ----------
    batch_size : int
        Number of images per batch.

    num_colors : int
        Number of colors to quantize images into. Typically 256, but can be
        lower for e.g. binary images.

    size : int
        Size (height and width) of each image.

    crop : int
        Size of center crop. This crop happens *before* the resizing.

    grayscale : bool
        If True converts images to grayscale.

    shuffle : bool
        If True shuffles images.

    path_to_data : string
        Path to CelebA image files.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over a ``CelebADataset`` built from ``path_to_data``.
    """
    to_bins = get_quantize_func(num_colors)

    # Crop first, then resize; the grayscale conversion (when requested) is
    # applied to the resized image, before the tensor conversion.
    steps = [transforms.CenterCrop(crop), transforms.Resize(size)]
    if grayscale:
        steps.append(transforms.Grayscale())
    steps.extend([transforms.ToTensor(), transforms.Lambda(to_bins)])

    dataset = CelebADataset(path_to_data,
                            transform=transforms.Compose(steps))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
class CelebADataset(Dataset):
    """CelebA dataset.

    Every entry directly inside ``path_to_data`` is treated as an image. The
    paths are sorted before subsampling so that the index -> image mapping
    (and the subset selected by ``subsample``) is the same on every run and
    on every filesystem.

    Parameters
    ----------
    path_to_data : string
        Path to CelebA images.

    subsample : int
        Only load every |subsample| number of images.

    transform : None or one of torchvision.transforms instances
        Applied to each opened image before it is returned.
    """
    def __init__(self, path_to_data, subsample=1, transform=None):
        # glob.glob makes no ordering guarantee, so sort before slicing to
        # keep the selected subset deterministic.
        self.img_paths = sorted(glob.glob(path_to_data + '/*'))[::subsample]
        self.transform = transform

    def __len__(self):
        """Return the number of image paths kept after subsampling."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return ``(sample, 0)``."""
        sample_path = self.img_paths[idx]
        sample = Image.open(sample_path)

        if self.transform is not None:
            sample = self.transform(sample)
        # Since there are no labels, return 0 for the "label"
        return sample, 0
def get_quantize_func(num_colors):
    """Create a function that maps image intensities to color-bin indices.

    Parameters
    ----------
    num_colors : int
        Number of bins to quantize image into. Should be between 2 and 256.

    Returns
    -------
    callable
        Function taking a tensor and returning a long tensor of bin indices.
    """
    # Index of the highest bin; inputs are scaled by this value.
    top_bin = num_colors - 1

    def quantize_func(batch):
        """Quantize a float tensor with values in the 0 - 1 range.

        Parameters
        ----------
        batch : torch.Tensor
            Values in 0 - 1 range.

        Returns
        -------
        torch.Tensor
            Long tensor with one bin index per input element.
        """
        if num_colors != 2:
            scaled = batch * top_bin
            return scaled.long()
        # Binary case: values strictly above 0.5 map to 1, the rest to 0.
        above_half = batch > 0.5
        return above_half.long()

    return quantize_func