File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -203,7 +203,8 @@ def create_masked_lm_predictions(tokens,
203203 do_permutation = False ,
204204 geometric_dist = False ,
205205 masking_style = "bert" ,
206- sampling_style = SamplingStyle .POISSON ):
206+ sampling_style = SamplingStyle .POISSON ,
207+ prefix_lm = False ):
207208 """Creates the predictions for the masked LM objective.
208209 Note: Tokens here are vocab ids and not text tokens."""
209210 if not isinstance (sampling_style , SamplingStyle ):
@@ -263,6 +264,10 @@ def create_masked_lm_predictions(tokens,
263264 for idx in range (len (cand_indexes )):
264265 ngram_index = []
265266 for n in ngrams :
267+ if prefix_lm :
268+ last_cand_index_index = min (idx + n - 1 , len (cand_indexes ) - 1 )
269+ if cand_indexes [last_cand_index_index ][- 1 ] < len (tokens ) - 1 :
270+ continue
266271 ngram_index .append (cand_indexes [idx :idx + n ])
267272 ngram_indexes .append (ngram_index )
268273
You can’t perform that action at this time.
0 commit comments