import argparse


-# Normalize MNIST dataset.
-data_transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.1307,), (0.3081,))
-])
-
-
def one_hot_encode(target, length):
2720 """Converts batches of class indices to classes of one-hot vectors."""
    batch_s = target.size(0)
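
# The remainder of one_hot_encode falls outside this hunk. A minimal sketch of
# the usual scatter_-based construction, assuming the function returns a
# [batch_s, length] matrix with a 1 at each target index (illustrative name,
# not necessarily this repo's exact implementation):
import torch

def one_hot_encode_sketch(target, length):
    batch_s = target.size(0)
    one_hot = torch.zeros(batch_s, length)
    # scatter_ writes 1.0 at column target[i] of row i.
    one_hot.scatter_(1, target.long().view(batch_s, 1), 1.0)
    return one_hot

# one_hot_encode_sketch(torch.tensor([2, 0]), 4) -> [[0, 0, 1, 0], [1, 0, 0, 0]]
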
@@ -45,18 +38,24 @@ def load_mnist(args):
4538 """Load MNIST dataset.
4639 The data is split and normalized between train and test sets.
4740 """
41+ # Normalize MNIST dataset.
42+ data_transform = transforms .Compose ([
43+ transforms .ToTensor (),
44+ transforms .Normalize ((0.1307 ,), (0.3081 ,))
45+ ])
46+
4847 kwargs = {'num_workers' : args .threads ,
4948 'pin_memory' : True } if args .cuda else {}
5049
51- print ('===> Loading training datasets' )
50+ print ('===> Loading MNIST training datasets' )
5251 # MNIST dataset
5352 training_set = datasets .MNIST (
5453 './data' , train = True , download = True , transform = data_transform )
5554 # Input pipeline
5655 training_data_loader = DataLoader (
5756 training_set , batch_size = args .batch_size , shuffle = True , ** kwargs )
5857
59- print ('===> Loading testing datasets' )
58+ print ('===> Loading MNIST testing datasets' )
6059 testing_set = datasets .MNIST (
6160 './data' , train = False , download = True , transform = data_transform )
6261 testing_data_loader = DataLoader (
@@ -65,6 +64,50 @@ def load_mnist(args):
    return training_data_loader, testing_data_loader


+def load_cifar10(args):
+    """Load CIFAR10 dataset.
+    The data is split into train and test sets and normalized.
+    """
+    # Normalize CIFAR10 dataset.
+    data_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+
+    kwargs = {'num_workers': args.threads,
+              'pin_memory': True} if args.cuda else {}
+
+    print('===> Loading CIFAR10 training datasets')
+    # CIFAR10 dataset
+    training_set = datasets.CIFAR10(
+        './data', train=True, download=True, transform=data_transform)
+    # Input pipeline
+    training_data_loader = DataLoader(
+        training_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+
+    print('===> Loading CIFAR10 testing datasets')
+    testing_set = datasets.CIFAR10(
+        './data', train=False, download=True, transform=data_transform)
+    testing_data_loader = DataLoader(
+        testing_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)
+
+    return training_data_loader, testing_data_loader
+
+
+def load_data(args):
+    """
+    Load dataset.
+    """
+    dst = args.dataset
+
+    if dst == 'mnist':
+        return load_mnist(args)
+    elif dst == 'cifar10':
+        return load_cifar10(args)
+    else:
+        raise ValueError('Invalid dataset, please check the dataset name: {}'.format(dst))
+
+
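
# Usage sketch (illustrative values, not the project's defaults): load_data only
# needs an args object carrying the fields read above -- dataset, threads, cuda,
# batch_size and test_batch_size.
import argparse

args = argparse.Namespace(dataset='cifar10', threads=4, cuda=False,
                          batch_size=128, test_batch_size=128)
train_loader, test_loader = load_data(args)
images, labels = next(iter(train_loader))  # images: [128, 3, 32, 32] for CIFAR10
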
def squash(sj, dim=2):
    """
    The non-linear activation used in Capsule.
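
# The squash body falls outside the hunks shown here. As an illustrative sketch
# (an assumption, not necessarily this repo's exact implementation), the standard
# CapsNet squashing nonlinearity v = (|s|^2 / (1 + |s|^2)) * (s / |s|) is:
import torch

def squash_sketch(sj, dim=2):
    sj_norm_sq = (sj ** 2).sum(dim=dim, keepdim=True)
    # Epsilon keeps the norm non-zero for all-zero capsule vectors.
    sj_norm = torch.sqrt(sj_norm_sq + 1e-9)
    return (sj_norm_sq / (1.0 + sj_norm_sq)) * (sj / sj_norm)
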
@@ -132,8 +175,12 @@ def save_image(image, file_name):
    Save a given image into an image file
    """
    # Check number of channels in an image.
-    if image.size(1) == 1:
-        # Grayscale
+    if image.size(1) == 2:
+        # 2-channel image
+        zeros = torch.zeros(image.size(0), 1, image.size(2), image.size(3))
+        image_tensor = torch.cat([zeros, image.data.cpu()], dim=1)
+    else:
+        # Grayscale or RGB image
        image_tensor = image.data.cpu()  # get Tensor from Variable

    vutils.save_image(image_tensor, file_name)
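
# Usage sketch (illustrative tensors; assumes a PyTorch version where plain
# tensors expose .data, e.g. >= 0.4): the 2-channel branch pads a zero channel
# so vutils.save_image can write the batch as 3-channel images.
import torch

two_channel = torch.rand(16, 2, 28, 28)   # e.g. a 2-channel reconstruction batch
save_image(two_channel, 'two_channel.png')

gray = torch.rand(16, 1, 28, 28)          # grayscale and RGB batches pass through unchanged
save_image(gray, 'gray.png')
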