Skip to content

Commit ca674e1

Browse files
committed
Update to solve issue sgrvinod#57: wrap the validation loop in torch.no_grad() to avoid CUDA out-of-memory errors.
1 parent b3b8263 commit ca674e1

File tree

1 file changed

+78
-75
lines changed

1 file changed

+78
-75
lines changed

train.py

+78-75
Original file line numberDiff line numberDiff line change
@@ -245,81 +245,84 @@ def validate(val_loader, encoder, decoder, criterion):
245245
references = list() # references (true captions) for calculating BLEU-4 score
246246
hypotheses = list() # hypotheses (predictions)
247247

248-
# Batches
249-
for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
250-
251-
# Move to device, if available
252-
imgs = imgs.to(device)
253-
caps = caps.to(device)
254-
caplens = caplens.to(device)
255-
256-
# Forward prop.
257-
if encoder is not None:
258-
imgs = encoder(imgs)
259-
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
260-
261-
# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
262-
targets = caps_sorted[:, 1:]
263-
264-
# Remove timesteps that we didn't decode at, or are pads
265-
# pack_padded_sequence is an easy trick to do this
266-
scores_copy = scores.clone()
267-
scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
268-
targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
269-
270-
# Calculate loss
271-
loss = criterion(scores, targets)
272-
273-
# Add doubly stochastic attention regularization
274-
loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
275-
276-
# Keep track of metrics
277-
losses.update(loss.item(), sum(decode_lengths))
278-
top5 = accuracy(scores, targets, 5)
279-
top5accs.update(top5, sum(decode_lengths))
280-
batch_time.update(time.time() - start)
281-
282-
start = time.time()
283-
284-
if i % print_freq == 0:
285-
print('Validation: [{0}/{1}]\t'
286-
'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
287-
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
288-
'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
289-
loss=losses, top5=top5accs))
290-
291-
# Store references (true captions), and hypothesis (prediction) for each image
292-
# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
293-
# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
294-
295-
# References
296-
allcaps = allcaps[sort_ind] # because images were sorted in the decoder
297-
for j in range(allcaps.shape[0]):
298-
img_caps = allcaps[j].tolist()
299-
img_captions = list(
300-
map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
301-
img_caps)) # remove <start> and pads
302-
references.append(img_captions)
303-
304-
# Hypotheses
305-
_, preds = torch.max(scores_copy, dim=2)
306-
preds = preds.tolist()
307-
temp_preds = list()
308-
for j, p in enumerate(preds):
309-
temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads
310-
preds = temp_preds
311-
hypotheses.extend(preds)
312-
313-
assert len(references) == len(hypotheses)
314-
315-
# Calculate BLEU-4 scores
316-
bleu4 = corpus_bleu(references, hypotheses)
317-
318-
print(
319-
'\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
320-
loss=losses,
321-
top5=top5accs,
322-
bleu=bleu4))
248+
# explicitly disable gradient calculation to avoid CUDA memory error
249+
# solves issue #57
250+
with torch.no_grad():
251+
# Batches
252+
for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
253+
254+
# Move to device, if available
255+
imgs = imgs.to(device)
256+
caps = caps.to(device)
257+
caplens = caplens.to(device)
258+
259+
# Forward prop.
260+
if encoder is not None:
261+
imgs = encoder(imgs)
262+
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
263+
264+
# Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
265+
targets = caps_sorted[:, 1:]
266+
267+
# Remove timesteps that we didn't decode at, or are pads
268+
# pack_padded_sequence is an easy trick to do this
269+
scores_copy = scores.clone()
270+
scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
271+
targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
272+
273+
# Calculate loss
274+
loss = criterion(scores, targets)
275+
276+
# Add doubly stochastic attention regularization
277+
loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
278+
279+
# Keep track of metrics
280+
losses.update(loss.item(), sum(decode_lengths))
281+
top5 = accuracy(scores, targets, 5)
282+
top5accs.update(top5, sum(decode_lengths))
283+
batch_time.update(time.time() - start)
284+
285+
start = time.time()
286+
287+
if i % print_freq == 0:
288+
print('Validation: [{0}/{1}]\t'
289+
'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
290+
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
291+
'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
292+
loss=losses, top5=top5accs))
293+
294+
# Store references (true captions) and hypotheses (predictions) for each image
295+
# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
296+
# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
297+
298+
# References
299+
allcaps = allcaps[sort_ind] # because images were sorted in the decoder
300+
for j in range(allcaps.shape[0]):
301+
img_caps = allcaps[j].tolist()
302+
img_captions = list(
303+
map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
304+
img_caps)) # remove <start> and pads
305+
references.append(img_captions)
306+
307+
# Hypotheses
308+
_, preds = torch.max(scores_copy, dim=2)
309+
preds = preds.tolist()
310+
temp_preds = list()
311+
for j, p in enumerate(preds):
312+
temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads
313+
preds = temp_preds
314+
hypotheses.extend(preds)
315+
316+
assert len(references) == len(hypotheses)
317+
318+
# Calculate BLEU-4 scores
319+
bleu4 = corpus_bleu(references, hypotheses)
320+
321+
print(
322+
'\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
323+
loss=losses,
324+
top5=top5accs,
325+
bleu=bleu4))
323326

324327
return bleu4
325328

0 commit comments

Comments
 (0)