Skip to content

Commit d399766

Browse files
committed
New dataloader for CIFAR10 and fix save_image function
1 parent 6e3f004 commit d399766

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

utils.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,6 @@
1616
import argparse
1717

1818

19-
# Normalize MNIST dataset.
20-
data_transform = transforms.Compose([
21-
transforms.ToTensor(),
22-
transforms.Normalize((0.1307,), (0.3081,))
23-
])
24-
25-
2619
def one_hot_encode(target, length):
2720
"""Converts batches of class indices to classes of one-hot vectors."""
2821
batch_s = target.size(0)
@@ -45,18 +38,24 @@ def load_mnist(args):
4538
"""Load MNIST dataset.
4639
The data is split and normalized between train and test sets.
4740
"""
41+
# Normalize MNIST dataset.
42+
data_transform = transforms.Compose([
43+
transforms.ToTensor(),
44+
transforms.Normalize((0.1307,), (0.3081,))
45+
])
46+
4847
kwargs = {'num_workers': args.threads,
4948
'pin_memory': True} if args.cuda else {}
5049

51-
print('===> Loading training datasets')
50+
print('===> Loading MNIST training datasets')
5251
# MNIST dataset
5352
training_set = datasets.MNIST(
5453
'./data', train=True, download=True, transform=data_transform)
5554
# Input pipeline
5655
training_data_loader = DataLoader(
5756
training_set, batch_size=args.batch_size, shuffle=True, **kwargs)
5857

59-
print('===> Loading testing datasets')
58+
print('===> Loading MNIST testing datasets')
6059
testing_set = datasets.MNIST(
6160
'./data', train=False, download=True, transform=data_transform)
6261
testing_data_loader = DataLoader(
@@ -65,6 +64,50 @@ def load_mnist(args):
6564
return training_data_loader, testing_data_loader
6665

6766

67+
def load_cifar10(args):
68+
"""Load CIFAR10 dataset.
69+
The data is split and normalized between train and test sets.
70+
"""
71+
# Normalize CIFAR10 dataset.
72+
data_transform = transforms.Compose([
73+
transforms.ToTensor(),
74+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
75+
])
76+
77+
kwargs = {'num_workers': args.threads,
78+
'pin_memory': True} if args.cuda else {}
79+
80+
print('===> Loading CIFAR10 training datasets')
81+
# CIFAR10 dataset
82+
training_set = datasets.CIFAR10(
83+
'./data', train=True, download=True, transform=data_transform)
84+
# Input pipeline
85+
training_data_loader = DataLoader(
86+
training_set, batch_size=args.batch_size, shuffle=True, **kwargs)
87+
88+
print('===> Loading CIFAR10 testing datasets')
89+
testing_set = datasets.CIFAR10(
90+
'./data', train=False, download=True, transform=data_transform)
91+
testing_data_loader = DataLoader(
92+
testing_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)
93+
94+
return training_data_loader, testing_data_loader
95+
96+
97+
def load_data(args):
98+
"""
99+
Load dataset.
100+
"""
101+
dst = args.dataset
102+
103+
if dst == 'mnist':
104+
return load_mnist(args)
105+
elif dst == 'cifar10':
106+
return load_cifar10(args)
107+
else:
108+
raise Exception('Invalid dataset, please check the name of dataset:', dst)
109+
110+
68111
def squash(sj, dim=2):
69112
"""
70113
The non-linear activation used in Capsule.
@@ -132,8 +175,12 @@ def save_image(image, file_name):
132175
Save a given image into an image file
133176
"""
134177
# Check number of channels in an image.
135-
if image.size(1) == 1:
136-
# Grayscale
178+
if image.size(1) == 2:
179+
# 2-channel image
180+
zeros = torch.zeros(image.size(0), 1, image.size(2), image.size(3))
181+
image_tensor = torch.cat([zeros, image.data.cpu()], dim=1)
182+
else:
183+
# Grayscale or RGB image
137184
image_tensor = image.data.cpu() # get Tensor from Variable
138185

139186
vutils.save_image(image_tensor, file_name)

0 commit comments

Comments
 (0)