@@ -177,11 +177,13 @@ def test(model, data_loader, num_train_batches, epoch, writer):
177177 # Get the reconstructed images of the last batch
178178 if args .use_reconstruction_loss :
179179 reconstruction = model .decoder (output , target )
180- image_width = 28 # MNIST digit image width
181- image_height = 28 # MNIST digit image height
182- image_channel = 1 # MNIST digit image channel
180+ # Input image size and number of channel.
181+ # By default (MNIST), images are 28x28 pixels with 1 channel (grayscale).
182+ image_width = args .input_width
183+ image_height = args .input_height
184+ image_channel = args .num_conv_in_channel
183185 recon_img = reconstruction .view (- 1 , image_channel , image_width , image_height )
184- assert recon_img .size () == torch .Size ([batch_size , 1 , 28 , 28 ])
186+ assert recon_img .size () == torch .Size ([batch_size , image_channel , image_width , image_height ])
185187
186188 # Save the image into file system
187189 utils .save_image (recon_img , 'results/recons_image_test_{}_{}.png' .format (epoch , global_step ))
@@ -264,6 +266,11 @@ def main():
264266 help = 'use an additional reconstruction loss. default=True' )
265267 parser .add_argument ('--regularization-scale' , type = float , default = 0.0005 ,
266268 help = 'regularization coefficient for reconstruction loss. default=0.0005' )
269+ parser .add_argument ('--dataset' , help = 'the name of dataset (mnist, cifar10)' , default = 'mnist' )
270+ parser .add_argument ('--input-width' , type = int ,
271+ default = 28 , help = 'input image width to the convolution. default=28 for MNIST' )
272+ parser .add_argument ('--input-height' , type = int ,
273+ default = 28 , help = 'input image height to the convolution. default=28 for MNIST' )
267274
268275 args = parser .parse_args ()
269276
@@ -278,7 +285,7 @@ def main():
278285 torch .cuda .manual_seed (args .seed )
279286
280287 # Load data
281- train_loader , test_loader = utils .load_mnist (args )
288+ train_loader , test_loader = utils .load_data (args )
282289
283290 # Build Capsule Network
284291 print ('===> Building model' )
@@ -291,6 +298,8 @@ def main():
291298 num_routing = args .num_routing ,
292299 use_reconstruction_loss = args .use_reconstruction_loss ,
293300 regularization_scale = args .regularization_scale ,
301+ input_width = args .input_width ,
302+ input_height = args .input_height ,
294303 cuda_enabled = args .cuda )
295304
296305 if args .cuda :
@@ -307,12 +316,14 @@ def main():
307316 for name , param in model .named_parameters ():
308317 print ('{}: {}' .format (name , list (param .size ())))
309318
310- # CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnet.
319+ # CapsNet has:
320+ # - 8.2M parameters and 6.8M parameters without the reconstruction subnet on MNIST.
321+ # - 11.8M parameters and 8.0M parameters without the reconstruction subnet on CIFAR10.
311322 num_params = sum ([param .nelement () for param in model .parameters ()])
312323
313324 # The coupling coefficients c_ij are not included in the parameter list,
314- # we need to add them manually, which is 1152 * 10 = 11520.
315- print ('\n Total number of parameters: {}\n ' .format (num_params + 11520 ))
325+ # we need to add them manually, which is 1152 * 10 = 11520 (on MNIST) or 2048 * 10 = 20480 (on CIFAR10).
326+ print ('\n Total number of parameters: {}\n ' .format (num_params + ( 11520 if args . dataset == 'mnist' else 20480 ) ))
316327
317328 # Optimizer
318329 optimizer = optim .Adam (model .parameters (), lr = args .lr )
0 commit comments