Skip to content

Commit 43c3300

Browse files
authored
Fix bucket refilling in _score method of Inference class (#2557)
* fixed score results being overwritten
* fixed bucket refilling in translator._score
1 parent 1c27987 commit 43c3300

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

onmt/translate/translator.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -576,29 +576,35 @@ def _process_bucket(bucket_translations):
576576

577577
def _score(self, infer_iter):
578578
self.with_scores = True
579-
scored_bucket = {}
579+
score_res = []
580+
processed_bucket = {}
581+
prev_bucket_idx = 0
580582
for batch, bucket_idx in infer_iter:
583+
if bucket_idx != prev_bucket_idx:
584+
prev_bucket_idx += 1
585+
score_res += [item for _, item in sorted(processed_bucket.items())]
586+
processed_bucket = {}
581587
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
582588
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
589+
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
590+
batch_inds_in_bucket = batch["ind_in_bucket"]
583591
if self.return_gold_log_probs:
584592
batch_gold_log_probs = (
585593
batch_data["gold_log_probs"].cpu().numpy().tolist()
586594
)
587595
else:
588-
batch_gold_log_probs = None
589-
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
590-
batch_inds_in_bucket = batch["ind_in_bucket"]
591-
for i, _score in enumerate(batch_gold_scores):
592-
log_probs = (
593-
batch_gold_log_probs[i] if self.return_gold_log_probs else None
594-
)
595-
scored_bucket[batch_inds_in_bucket[i]] = (
596-
_score,
597-
log_probs,
596+
batch_gold_log_probs = [
597+
None for i, _ in enumerate(batch_inds_in_bucket)
598+
]
599+
for i, ind in enumerate(batch_inds_in_bucket):
600+
processed_bucket[ind] = [
601+
batch_gold_scores[i],
602+
batch_gold_log_probs[i],
598603
batch_tgt_lengths[i],
599-
)
600-
score_results = [scored_bucket[i] for i in range(len(scored_bucket))]
601-
return score_results
604+
]
605+
if processed_bucket:
606+
score_res += [item for _, item in sorted(processed_bucket.items())]
607+
return score_res
602608

603609
def _align_pad_prediction(self, predictions, bos, pad):
604610
"""

0 commit comments

Comments
 (0)