@@ -584,7 +584,7 @@ def _score(self, infer_iter):
584
584
self .with_scores = True
585
585
scored_bucket = {}
586
586
for batch , bucket_idx in infer_iter :
587
- batch_data = self .translate_batch (batch , attn_debug = False )
587
+ batch_data = self .translate_batch (batch , attn_debug = False , scoring = True )
588
588
batch_gold_scores = batch_data ["gold_score" ].cpu ().numpy ().tolist ()
589
589
batch_tgt_lengths = batch ["tgtlen" ].cpu ().numpy ().tolist ()
590
590
batch_inds_in_bucket = batch ["ind_in_bucket" ]
@@ -1001,8 +1001,9 @@ def _align_forward(self, batch, predictions):
1001
1001
"""
1002
1002
raise NotImplementedError
1003
1003
1004
- def translate_batch (self , batch , attn_debug ):
1004
+ def translate_batch (self , batch , attn_debug , scoring = False ):
1005
1005
"""Translate a batch of sentences."""
1006
+ max_length = 0 if scoring else self .max_length
1006
1007
with torch .no_grad ():
1007
1008
if self .sample_from_topk != 0 or self .sample_from_topp != 0 :
1008
1009
decode_strategy = GreedySearchLM (
@@ -1015,7 +1016,7 @@ def translate_batch(self, batch, attn_debug):
1015
1016
batch_size = len (batch ["srclen" ]),
1016
1017
global_scorer = self .global_scorer ,
1017
1018
min_length = self .min_length ,
1018
- max_length = self . max_length ,
1019
+ max_length = max_length ,
1019
1020
block_ngram_repeat = self .block_ngram_repeat ,
1020
1021
exclusion_tokens = self ._exclusion_idxs ,
1021
1022
return_attention = attn_debug or self .replace_unk ,
@@ -1039,7 +1040,7 @@ def translate_batch(self, batch, attn_debug):
1039
1040
n_best = self .n_best ,
1040
1041
global_scorer = self .global_scorer ,
1041
1042
min_length = self .min_length ,
1042
- max_length = self . max_length ,
1043
+ max_length = max_length ,
1043
1044
return_attention = attn_debug or self .replace_unk ,
1044
1045
block_ngram_repeat = self .block_ngram_repeat ,
1045
1046
exclusion_tokens = self ._exclusion_idxs ,
0 commit comments