   Transformer,
 )
 from algoperf.workloads.lm.workload import BaseLmWorkload
+from algoperf.workloads.lm.input_pipeline import get_data_iter
 
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
@@ -37,10 +38,11 @@ def init_model_fn(
     cfg = ModelConfig(
         vocab_size=self._vocab_size,
         seq_len=self._seq_len,
-        dim=512,  # Model dimension
-        expand=4,  # MLP expansion factor
-        n_layers=6,  # Number of transformer layers
-        n_heads=8,  # Number of attention heads
+        dim=self._emb_dim,  # Model dimension
+        expand=self._mlp_dim // self._emb_dim,  # MLP expansion factor
+        # FIXME(rka97): fix expansion factor
+        n_layers=self._n_layers,  # Number of transformer layers
+        n_heads=self._n_heads,  # Number of attention heads
         rmsnorm_eps=1e-6,
         tie_embeddings=True
     )
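
A side note on the expansion-factor change above, since the FIXME flags it: recovering expand from the two widths uses integer division, so an _mlp_dim that is not an exact multiple of _emb_dim is silently truncated. A minimal illustration (the widths below are made up for this sketch, not taken from the workload):

emb_dim, mlp_dim = 512, 1536
expand = mlp_dim // emb_dim  # 3, exact: the intended MLP width is reproduced
emb_dim, mlp_dim = 512, 1408
expand = mlp_dim // emb_dim  # 2, truncated: 384 hidden units are silently dropped
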
@@ -65,7 +67,7 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
 
     del model_state, rng, update_batch_norm, dropout_rate
     model = params
@@ -87,10 +89,8 @@ def _build_input_queue(
       num_batches: Optional[int] = None,
       repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
     """Build an input queue for the given split."""
-    from algoperf.workloads.lm.input_pipeline import get_lm_dataset
     local_batch_size = global_batch_size // N_GPUS
-
-    loader = get_lm_dataset(
+    loader = get_data_iter(
         data_rng=data_rng,
         split=split,
         data_dir=data_dir,
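
The next hunk drops the rank-0 broadcast logic in favor of itertools.islice sharding: each DDP rank starts at its own offset into the shared iterator and strides by the world size. A standalone sketch of that pattern (the dummy loader, rank values, and world size below are illustrative, not the workload's own):

from itertools import islice

def make_loader():
  # Stand-in for the get_data_iter(...) iterator: yields numbered batches.
  for i in range(8):
    yield {'inputs': f'batch-{i}'}

N_GPUS = 2  # assumed world size for this sketch
for rank in range(N_GPUS):
  # Rank 0 sees batches 0, 2, 4, ...; rank 1 sees 1, 3, 5, ...
  shard = islice(make_loader(), rank, None, N_GPUS)
  print(rank, [b['inputs'] for b in shard])
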
@@ -99,100 +99,54 @@ def _build_input_queue(
     )
     if USE_PYTORCH_DDP:
       loader = islice(loader, RANK, None, N_GPUS)
-    seq_len = self._seq_len
-    weights = None
-
     dtype = torch.int32
-    is_train = split == 'train'
-
     for batch in loader:
-      inputs = batch['inputs']
-      targets = batch['targets']
-
-      if USE_PYTORCH_DDP:
-        if not is_train:
-          # During eval, the batch size of the remainder might be different
-          per_device_batch_size = torch.tensor(
-              targets.shape[0], dtype=dtype, device=DEVICE)
-          dist.broadcast(per_device_batch_size, src=0)
-          local_batch_size = per_device_batch_size.item()
-        # Broadcast to all devices
-        # dist.broadcast(inputs, src=0)
-        # dist.broadcast(targets, src=0)
-
-      if weights is None:
-        weights = torch.ones((local_batch_size, seq_len), device=DEVICE)
       batch = {
-        'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype),
-        'targets': torch.tensor(targets, device=DEVICE, dtype=dtype),
-        'weights': weights,
+        'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype),
+        'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64),
+        'weights': None,
       }
       yield batch
 
   def is_output_params(self, param_name: str) -> bool:
     """Return whether the given parameter is an output parameter."""
     return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name
 
-  def _eval_batch(self,
-                  params: spec.ParameterContainer,
-                  batch: Dict[str, spec.Tensor],
-                  model_state: spec.ModelAuxiliaryState,
-                  rng: spec.RandomState) -> spec.Tensor:
-    """Evaluate the model on a single batch."""
-    model = params
-    logits, _ = self.model_fn(
-        model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
-
-    # Handle both one-hot and token ID targets
-    targets = batch['targets']
-    if targets.dim() == 3:  # one-hot
-      loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1))
-    else:  # token IDs
-      # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch.
-      loss = torch.nn.functional.cross_entropy(
-          logits.view(-1, logits.size(-1)),
-          targets.view(-1),
-          reduction='sum'
-      )
-    return loss
-
-  def loss_fn(
-      self,
-      label_batch: spec.Tensor,
-      logits_batch: spec.Tensor,
-      mask_batch: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
+  # FIXME(rka97): Implement label smoothing
+  def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
165117 """Compute cross-entropy loss for language modeling in PyTorch."""
166- vocab_size = logits_batch . shape [ - 1 ]
118+ vocab_size = logits . size ( - 1 )
167119
168- if len (label_batch .shape ) == len (logits_batch .shape ):
120+ if len (labels .shape ) == len (logits .shape ):
169121 # One-hot labels
170- log_probs = torch .nn .functional .log_softmax (logits_batch , dim = - 1 )
171- loss = - torch .sum (label_batch * log_probs , dim = - 1 )
122+ log_probs = torch .nn .functional .log_softmax (logits , dim = - 1 )
123+ loss = - torch .sum (labels * log_probs , dim = - 1 )
172124 else :
173125 # Dense labels
174126 loss = torch .nn .functional .cross_entropy (
175- logits_batch ,
176- label_batch ,
127+ logits . view ( - 1 , vocab_size ) ,
128+ labels . view ( - 1 ) ,
177129 reduction = 'none' )
178- if mask_batch is not None :
179- loss = loss * mask_batch
130+ loss = loss .view_as (labels )
131+
132+ if weights is not None :
133+ loss = loss * weights
180134
181- n_valid = mask_batch .sum () if mask_batch is not None else label_batch . shape [ 0 ]
135+ n_valid = weights .sum () if weights is not None else torch . tensor ( labels . numel (), dtype = torch . float32 , device = labels . device )
182136 return {
183137 'summed' : loss .sum (),
184138 'n_valid_examples' : n_valid ,
185- 'per_example' : loss
139+ 'per_example' : loss ,
186140 }
187141
-  def _normalize_eval_metrics(
-      self, num_examples: int, total_metrics: Dict[str, Any]
-  ) -> Dict[str, float]:
-    """Normalize eval metrics."""
-    del num_examples
-    if USE_PYTORCH_DDP:
-      for metric in total_metrics.values():
-        dist.all_reduce(metric)
-    total_metrics = {k: v.item() for k, v in total_metrics.items()}
-    eval_denominator = total_metrics.pop('denominator')
-    return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)
+  def _normalize_eval_metrics(
+      self, num_examples: int, total_metrics: Dict[str, Any]
+  ) -> Dict[str, float]:
+    """Normalize eval metrics."""
+    del num_examples
+    if USE_PYTORCH_DDP:
+      for metric in total_metrics.values():
+        dist.all_reduce(metric)
+    total_metrics = {k: v.item() for k, v in total_metrics.items()}
+    eval_denominator = total_metrics.pop('denominator')
+    return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)
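
For reference, a standalone sketch of the dense-label branch of compute_weighted_cross_entropy added above, checking that 'summed' divided by 'n_valid_examples' recovers a masked mean cross-entropy. The shapes, the 0/1 padding mask, and the tolerance are assumptions made for this example only:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 4, 16, 32
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
weights = torch.ones(batch, seq_len)
weights[:, -2:] = 0.0  # pretend the last two positions are padding

# Mirror the dense-label branch: per-token loss, reshaped to (batch, seq_len), masked.
loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction='none')
loss = loss.view_as(labels) * weights
summed, n_valid = loss.sum(), weights.sum()

# Masked mean computed directly over the valid positions only.
valid = weights.view(-1) > 0
reference = F.cross_entropy(logits.view(-1, vocab)[valid], labels.view(-1)[valid])
assert torch.allclose(summed / n_valid, reference, atol=1e-5)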