 
 from __future__ import print_function
 import argparse
-import sys
-import time
 
 import torch
 import torch.optim as optim
@@ -57,8 +55,7 @@ def train(model, data_loader, optimizer, epoch):
         optimizer.step()
 
         if batch_idx % args.log_interval == 0:
-            mesg = '{}\tEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                time.ctime(),
+            mesg = 'Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch,
                 batch_idx * len(data),
                 len(data_loader.dataset),
@@ -87,7 +84,7 @@ def test(model, data_loader):
     for data, target in data_loader:
         target_indices = target
         target_one_hot = utils.one_hot_encode(
-            target_indices, length=model.digits.num_units)
+            target_indices, length=model.digits.num_unit)
 
         data, target = Variable(data, volatile=True), Variable(target_one_hot)
 
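The hunk above relies on utils.one_hot_encode to turn a batch of class indices into one-hot target vectors whose length is the number of digit capsules (model.digits.num_unit after the rename). The helper's implementation is not part of this diff, so the snippet below is only a minimal sketch of the behavior assumed here, written against current PyTorch rather than the Variable-era API used in the file.

import torch

def one_hot_encode(target_indices, length):
    # Assumed behavior: indices of shape (batch,) -> float one-hot matrix of shape (batch, length).
    one_hot = torch.zeros(target_indices.size(0), length)
    one_hot.scatter_(1, target_indices.view(-1, 1), 1.0)
    return one_hot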
@@ -133,12 +130,12 @@ def main():
                         default=128, help='testing batch size. default=128')
     parser.add_argument('--loss-threshold', type=float, default=0.0001,
                         help='stop training if loss goes below this threshold. default=0.0001')
-    parser.add_argument("--log-interval", type=int, default=1,
-                        help='number of images after which the training loss is logged, default is 1')
-    parser.add_argument('--cuda', action='store_true',
-                        help='set it to 1 for running on GPU, 0 for CPU')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='how many batches to wait before logging training status, default=10')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training, default=false')
     parser.add_argument('--threads', type=int, default=4,
-                        help='number of threads for data loader to use')
+                        help='number of threads for data loader to use, default=4')
     parser.add_argument('--seed', type=int, default=42,
                         help='random seed for training. default=42')
     parser.add_argument('--num-conv-channel', type=int, default=256,
@@ -149,20 +146,18 @@ def main():
                         default=1152, help='primary unit size. default=1152')
     parser.add_argument('--output-unit-size', type=int,
                         default=16, help='output unit size. default=16')
+    parser.add_argument('--num-routing', type=int,
+                        default=3, help='number of routing iteration. default=3')
 
     args = parser.parse_args()
 
     print(args)
 
     # Check GPU or CUDA is available
-    cuda = args.cuda
-    if cuda and not torch.cuda.is_available():
-        print(
-            "ERROR: No GPU/cuda is not available. Try running on CPU or run without --cuda")
-        sys.exit(1)
+    args.cuda = not args.no_cuda and torch.cuda.is_available()
 
     torch.manual_seed(args.seed)
-    if cuda:
+    if args.cuda:
         torch.cuda.manual_seed(args.seed)
 
     # Load data
@@ -174,10 +169,11 @@ def main():
                 num_primary_unit=args.num_primary_unit,
                 primary_unit_size=args.primary_unit_size,
                 output_unit_size=args.output_unit_size,
-                cuda=args.cuda)
+                num_routing=args.num_routing,
+                cuda_enabled=args.cuda)
 
-    if cuda:
-        model = model.cuda()
+    if args.cuda:
+        model.cuda()
 
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
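The new --num-routing flag is forwarded to the network as num_routing and sets how many routing-by-agreement iterations the digit-capsule layer runs (the CapsNet paper uses 3). The routing code itself is untouched by this diff, so the following is only a generic sketch of the loop that the parameter controls, with tensor names and shapes chosen for illustration rather than taken from this repository.

import torch
import torch.nn.functional as F

def squash(s, eps=1e-8):
    # Scale vectors so short ones shrink toward 0 and long ones approach unit length.
    norm_sq = (s ** 2).sum(dim=-1, keepdim=True)
    return (norm_sq / (1.0 + norm_sq)) * s / torch.sqrt(norm_sq + eps)

def route(u_hat, num_routing=3):
    # u_hat: prediction vectors, shape (batch, num_in_caps, num_out_caps, out_dim).
    b = torch.zeros(u_hat.size(0), u_hat.size(1), u_hat.size(2))  # routing logits
    for _ in range(num_routing):
        c = F.softmax(b, dim=2)                       # coupling coefficients per input capsule
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)      # weighted vote over input capsules
        v = squash(s)                                 # output capsules, shape (batch, num_out_caps, out_dim)
        b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)  # reward votes that agree with the output
    return v

More iterations sharpen the coupling coefficients at extra compute cost, which is why 3 is the usual default. Note also that after this change GPU use is opt-out: the script runs on CUDA whenever it is available unless --no-cuda is passed.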