Skip to content

Commit 006c4e9

Browse files
committed
Allow Prefix-LM style masked LM
1 parent d8db189 commit 006c4e9

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

megatron/data/dataset_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def create_masked_lm_predictions(tokens,
203203
do_permutation=False,
204204
geometric_dist=False,
205205
masking_style="bert",
206-
sampling_style=SamplingStyle.POISSON):
206+
sampling_style=SamplingStyle.POISSON,
207+
prefix_lm=False):
207208
"""Creates the predictions for the masked LM objective.
208209
Note: Tokens here are vocab ids and not text tokens."""
209210
if not isinstance(sampling_style, SamplingStyle):
@@ -263,6 +264,10 @@ def create_masked_lm_predictions(tokens,
263264
for idx in range(len(cand_indexes)):
264265
ngram_index = []
265266
for n in ngrams:
267+
if prefix_lm:
268+
last_cand_index_index = min(idx + n - 1, len(cand_indexes) - 1)
269+
if cand_indexes[last_cand_index_index][-1] < len(tokens) - 1:
270+
continue
266271
ngram_index.append(cand_indexes[idx:idx + n])
267272
ngram_indexes.append(ngram_index)
268273

0 commit comments

Comments
 (0)