
"""UL2-style dataset."""

+import math
+
import numpy as np

from megatron import get_tokenizer
from megatron.data.dataset_utils import (
    create_masked_lm_predictions,
    get_samples_mapping,
    SamplingStyle
)
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
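+    # Anything that is not an encoder-decoder counts as decoder-only here,
+    # i.e. presumably both the causal and the non-causal (prefix-LM)
+    # decoder variants of `UL2ModelType`.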
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER


class UL2Dataset(T5Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):

        if denoiser_ratios is None:
            # Uniform
@@ -49,6 +69,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
        # Params to store.
        self.name = name
        self.seed = seed
+        self.model_type = model_type
        self.denoiser_ratios = [
            denoiser_ratio / sum(denoiser_ratios)
            for denoiser_ratio in denoiser_ratios
@@ -116,21 +137,21 @@ def __getitem__(self, idx):
                                     self.vocab_id_to_token_dict,
                                     self.cls_ids, self.sep_id,
                                     self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
                          max_seq_length, max_seq_length_dec,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
    """Build training sample.

    Arguments:
@@ -144,6 +165,7 @@ def build_training_sample(sample, target_seq_length,
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
+        model_type: Which UL2 model architecture is used (encoder-decoder,
+            causal decoder-only, or non-causal/prefix-LM decoder-only).
        denoiser_ratios: Probability of each denoising objective to be selected.
        denoisers: What type of UL2 denoising objective the other UL2
            configurations refer to.
@@ -158,24 +180,28 @@ def build_training_sample(sample, target_seq_length,
        sentinel_tokens: unique value to be substituted for every replaced span
    """

+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
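+    # In UL2 terms, 'R' is the regular span-corruption denoiser, 'X' the
+    # extreme denoiser and 'S' the sequential (prefix-LM) denoiser; the
+    # sampling setup further below branches on these letters.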
+
    assert target_seq_length <= max_seq_length

    # flatten sentences into one list
    tokens = [token for sentence in sample for token in sentence]

-    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
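+        # E.g. with max_num_tokens = 512 and masked_lm_prob = 0.5 this keeps
+        # floor(512 / 1.5) = 341 input tokens, leaving room for the targets
+        # that get appended to the same (decoder-only) sequence.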
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]

    # Prepend objective token.
    cls_id = cls_ids.get(denoiser)
@@ -185,6 +211,11 @@ def build_training_sample(sample, target_seq_length,

    # Masking.
    max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
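+    # A fractional `mean_span_lengths` entry is interpreted as a fraction of
+    # the (truncated) sequence length; the span sampler is then capped at
+    # `2 * mean_ngrams - 1`, presumably so the average sampled span length
+    # stays close to `mean_ngrams`.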
+
    if denoiser == 'R' or denoiser == 'X':
        sampling_style = SamplingStyle.NORMAL
        prefix_lm = False
@@ -202,22 +233,64 @@ def build_training_sample(sample, target_seq_length,
        sampling_style=sampling_style, prefix_lm=prefix_lm,
    )

-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
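+        # Resulting layout: tokens = [BOS] <input with sentinels> [SEP]
+        # <decoder input>, labels = <input with sentinels> [SEP] <targets>,
+        # i.e. `labels` holds the next-token targets for `tokens`.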
+
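+        # Only the last `num_labels` positions, whose labels are the
+        # denoising targets, contribute to the loss.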
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
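+        # `make_history_mask` presumably yields a causal (lower-triangular)
+        # attention mask; for the non-causal (prefix-LM) variant the prefix
+        # part is additionally allowed to attend bidirectionally.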
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
    return train_sample