@@ -245,81 +245,84 @@ def validate(val_loader, encoder, decoder, criterion):
245
245
references = list () # references (true captions) for calculating BLEU-4 score
246
246
hypotheses = list () # hypotheses (predictions)
247
247
248
- # Batches
249
- for i , (imgs , caps , caplens , allcaps ) in enumerate (val_loader ):
250
-
251
- # Move to device, if available
252
- imgs = imgs .to (device )
253
- caps = caps .to (device )
254
- caplens = caplens .to (device )
255
-
256
- # Forward prop.
257
- if encoder is not None :
258
- imgs = encoder (imgs )
259
- scores , caps_sorted , decode_lengths , alphas , sort_ind = decoder (imgs , caps , caplens )
260
-
261
- # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
262
- targets = caps_sorted [:, 1 :]
263
-
264
- # Remove timesteps that we didn't decode at, or are pads
265
- # pack_padded_sequence is an easy trick to do this
266
- scores_copy = scores .clone ()
267
- scores , _ = pack_padded_sequence (scores , decode_lengths , batch_first = True )
268
- targets , _ = pack_padded_sequence (targets , decode_lengths , batch_first = True )
269
-
270
- # Calculate loss
271
- loss = criterion (scores , targets )
272
-
273
- # Add doubly stochastic attention regularization
274
- loss += alpha_c * ((1. - alphas .sum (dim = 1 )) ** 2 ).mean ()
275
-
276
- # Keep track of metrics
277
- losses .update (loss .item (), sum (decode_lengths ))
278
- top5 = accuracy (scores , targets , 5 )
279
- top5accs .update (top5 , sum (decode_lengths ))
280
- batch_time .update (time .time () - start )
281
-
282
- start = time .time ()
283
-
284
- if i % print_freq == 0 :
285
- print ('Validation: [{0}/{1}]\t '
286
- 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t '
287
- 'Loss {loss.val:.4f} ({loss.avg:.4f})\t '
288
- 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t ' .format (i , len (val_loader ), batch_time = batch_time ,
289
- loss = losses , top5 = top5accs ))
290
-
291
- # Store references (true captions), and hypothesis (prediction) for each image
292
- # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
293
- # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
294
-
295
- # References
296
- allcaps = allcaps [sort_ind ] # because images were sorted in the decoder
297
- for j in range (allcaps .shape [0 ]):
298
- img_caps = allcaps [j ].tolist ()
299
- img_captions = list (
300
- map (lambda c : [w for w in c if w not in {word_map ['<start>' ], word_map ['<pad>' ]}],
301
- img_caps )) # remove <start> and pads
302
- references .append (img_captions )
303
-
304
- # Hypotheses
305
- _ , preds = torch .max (scores_copy , dim = 2 )
306
- preds = preds .tolist ()
307
- temp_preds = list ()
308
- for j , p in enumerate (preds ):
309
- temp_preds .append (preds [j ][:decode_lengths [j ]]) # remove pads
310
- preds = temp_preds
311
- hypotheses .extend (preds )
312
-
313
- assert len (references ) == len (hypotheses )
314
-
315
- # Calculate BLEU-4 scores
316
- bleu4 = corpus_bleu (references , hypotheses )
317
-
318
- print (
319
- '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n ' .format (
320
- loss = losses ,
321
- top5 = top5accs ,
322
- bleu = bleu4 ))
248
+ # explicitly disable gradient calculation to avoid CUDA memory error
249
+ # fixes issue #57 (CUDA out-of-memory during validation)
250
+ with torch .no_grad ():
251
+ # Batches
252
+ for i , (imgs , caps , caplens , allcaps ) in enumerate (val_loader ):
253
+
254
+ # Move to device, if available
255
+ imgs = imgs .to (device )
256
+ caps = caps .to (device )
257
+ caplens = caplens .to (device )
258
+
259
+ # Forward prop.
260
+ if encoder is not None :
261
+ imgs = encoder (imgs )
262
+ scores , caps_sorted , decode_lengths , alphas , sort_ind = decoder (imgs , caps , caplens )
263
+
264
+ # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
265
+ targets = caps_sorted [:, 1 :]
266
+
267
+ # Remove timesteps that we didn't decode at, or are pads
268
+ # pack_padded_sequence is an easy trick to do this
269
+ scores_copy = scores .clone ()
270
+ scores , _ = pack_padded_sequence (scores , decode_lengths , batch_first = True )
271
+ targets , _ = pack_padded_sequence (targets , decode_lengths , batch_first = True )
272
+
273
+ # Calculate loss
274
+ loss = criterion (scores , targets )
275
+
276
+ # Add doubly stochastic attention regularization
277
+ loss += alpha_c * ((1. - alphas .sum (dim = 1 )) ** 2 ).mean ()
278
+
279
+ # Keep track of metrics
280
+ losses .update (loss .item (), sum (decode_lengths ))
281
+ top5 = accuracy (scores , targets , 5 )
282
+ top5accs .update (top5 , sum (decode_lengths ))
283
+ batch_time .update (time .time () - start )
284
+
285
+ start = time .time ()
286
+
287
+ if i % print_freq == 0 :
288
+ print ('Validation: [{0}/{1}]\t '
289
+ 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t '
290
+ 'Loss {loss.val:.4f} ({loss.avg:.4f})\t '
291
+ 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t ' .format (i , len (val_loader ), batch_time = batch_time ,
292
+ loss = losses , top5 = top5accs ))
293
+
294
+ # Store references (true captions), and hypothesis (prediction) for each image
295
+ # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
296
+ # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
297
+
298
+ # References
299
+ allcaps = allcaps [sort_ind ] # because images were sorted in the decoder
300
+ for j in range (allcaps .shape [0 ]):
301
+ img_caps = allcaps [j ].tolist ()
302
+ img_captions = list (
303
+ map (lambda c : [w for w in c if w not in {word_map ['<start>' ], word_map ['<pad>' ]}],
304
+ img_caps )) # remove <start> and pads
305
+ references .append (img_captions )
306
+
307
+ # Hypotheses
308
+ _ , preds = torch .max (scores_copy , dim = 2 )
309
+ preds = preds .tolist ()
310
+ temp_preds = list ()
311
+ for j , p in enumerate (preds ):
312
+ temp_preds .append (preds [j ][:decode_lengths [j ]]) # remove pads
313
+ preds = temp_preds
314
+ hypotheses .extend (preds )
315
+
316
+ assert len (references ) == len (hypotheses )
317
+
318
+ # Calculate BLEU-4 scores
319
+ bleu4 = corpus_bleu (references , hypotheses )
320
+
321
+ print (
322
+ '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n ' .format (
323
+ loss = losses ,
324
+ top5 = top5accs ,
325
+ bleu = bleu4 ))
323
326
324
327
return bleu4
325
328
0 commit comments