@@ -576,29 +576,35 @@ def _process_bucket(bucket_translations):
576
576
577
577
def _score(self, infer_iter):
    """Score the gold targets of every batch yielded by ``infer_iter``.

    Sets ``self.with_scores`` to ``True`` and calls
    ``self.translate_batch(batch, attn_debug=False, scoring=True)`` for
    each batch.

    Args:
        infer_iter: iterable of ``(batch, bucket_idx)`` pairs. ``batch``
            must be indexable by ``"tgtlen"`` and ``"ind_in_bucket"``.
            NOTE(review): batches of one bucket are assumed to arrive
            contiguously -- confirm against the iterator implementation.

    Returns:
        A flat list with one ``[gold_score, gold_log_probs, tgt_length]``
        entry per example. Buckets appear in arrival order and examples
        are ordered by ``ind_in_bucket`` inside each bucket.
        ``gold_log_probs`` is ``None`` unless ``self.return_gold_log_probs``
        is truthy.
    """
    self.with_scores = True
    score_res = []
    # Results of the bucket currently being consumed, keyed by the
    # example's position inside that bucket.
    processed_bucket = {}
    prev_bucket_idx = None

    for batch, bucket_idx in infer_iter:
        if bucket_idx != prev_bucket_idx:
            # A new bucket starts: emit the finished one in index order.
            # Track the id actually observed rather than incrementing a
            # counter, so ids that do not start at 0 or that skip values
            # cannot cause a flush in the middle of a bucket.
            score_res.extend(processed_bucket[k] for k in sorted(processed_bucket))
            processed_bucket = {}
            prev_bucket_idx = bucket_idx

        batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
        # presumably torch tensors, converted to plain Python lists here
        batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
        batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
        batch_inds_in_bucket = batch["ind_in_bucket"]
        if self.return_gold_log_probs:
            batch_gold_log_probs = (
                batch_data["gold_log_probs"].cpu().numpy().tolist()
            )
        else:
            # Placeholder so every result entry keeps the same layout.
            batch_gold_log_probs = [None] * len(batch_gold_scores)

        for i, ind in enumerate(batch_inds_in_bucket):
            processed_bucket[ind] = [
                batch_gold_scores[i],
                batch_gold_log_probs[i],
                batch_tgt_lengths[i],
            ]

    # Emit the last bucket; nothing is left when the iterator was empty.
    if processed_bucket:
        score_res.extend(processed_bucket[k] for k in sorted(processed_bucket))
    return score_res
602
608
603
609
def _align_pad_prediction (self , predictions , bos , pad ):
604
610
"""
0 commit comments