 
 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
 from megatron.data.dataset_utils import (
     create_masked_lm_predictions,
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
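+# `UL2ModelType` is assumed to be an enum with (at least) the members
+# ENCODER_DECODER, CAUSAL_DECODER and NON_CAUSAL_DECODER; the two
+# decoder-only members are the ones handled specially below.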
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform distribution by default.
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
                          short_seq_prob, seed)
 
         # Params to store.
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: Which type of UL2 model is used (see `UL2ModelType`).
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: Which UL2 denoising objective each of the other UL2
             configuration entries refers to.
@@ -139,24 +161,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: Unique values to be substituted for every replaced span.
     """
 
+    # Denoiser selection.
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
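+    # For example, the denoiser mixture from the UL2 paper would be expressed
+    # as denoisers=['R', 'R', 'S', 'X', 'X', 'X', 'X'], with one matching
+    # entry per denoiser in `mean_span_lengths` and `mask_ratios`.
+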
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
 
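+    # E.g. with max_num_tokens=512 and masked_lm_prob=0.15, a decoder-only
+    # model truncates its input to floor(512 / 1.15) = 445 tokens.
+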
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
@@ -166,6 +192,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
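+    # E.g. a fractional mean span length of 0.25 over 100 tokens gives
+    # mean_ngrams = 25 and max_ngrams = 49.
+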
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -183,22 +214,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )
 
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-    dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
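+        # The packed layout is now
+        #     tokens: [BOS] <masked input> [SEP] <decoder input>
+        #     labels:       <masked input> [SEP] <span targets>
+        # i.e. `labels` equals `tokens` shifted left by one, for
+        # next-token prediction.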
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
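+            # Non-causal (prefix-LM) variant: the prefix, i.e. everything
+            # before the span targets, attends bidirectionally; only the
+            # targets stay causal.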
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+        dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample
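
For reference, a minimal sketch of what the decoder-only packing above produces on toy token ids. The inlined `make_history_mask` mirrors what the helper imported from `megatron.data.t5_dataset` is assumed to do (a lower-triangular history mask); the ids for BOS/EOS/SEP, the sentinel id 900, and the list shapes returned by `merge_subsequent_masks` are all illustrative assumptions, not the library's actual values.

import numpy as np


def make_history_mask(block):
    # Position i may attend to positions j <= i (causal mask).
    length = block.shape[0]
    arange = np.arange(length)
    return (arange[None, :] <= arange[:, None]).astype(np.int64)


BOS, EOS, SEP = 1, 2, 3
tokens_enc = [10, 11, 900, 13]  # masked input; 900 stands in for a sentinel
tokens_dec_in = [BOS, 900, 12]  # decoder input: BOS, sentinel, span token
labels = [900, 12, EOS]         # decoder input shifted left by one

num_labels = len(labels)
tokens = np.concatenate([
    np.array([BOS], dtype=np.int64),
    tokens_enc,
    np.array([SEP], dtype=np.int64),
    tokens_dec_in[1:],  # BOS was moved to the front of the whole sequence
])
packed_labels = np.concatenate([
    np.array(tokens_enc, dtype=np.int64),
    np.array([SEP], dtype=np.int64),
    labels,
])

loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1  # only the span targets contribute to the loss

dec_mask = make_history_mask(tokens)
dec_mask[:-num_labels, :-num_labels] = 1  # prefix-LM: bidirectional prefix

assert len(tokens) == len(packed_labels)
print(tokens.tolist())         # [1, 10, 11, 900, 13, 3, 900, 12]
print(packed_labels.tolist())  # [10, 11, 900, 13, 3, 900, 12, 2]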