@@ -87,35 +87,38 @@ def model_fn(
     logits = self._model.apply({'params': params}, inputs)
     return logits, None
 
-  def compute_weighted_cross_entropy(
+  def loss_fn(
       self,
-      logits: spec.Tensor,
-      targets: spec.Tensor,
-      weights: Optional[spec.Tensor] = None,
+      label_batch: spec.Tensor,
+      logits_batch: spec.Tensor,
+      mask_batch: Optional[spec.Tensor] = None,
       label_smoothing: float = 0.0,
   ) -> Dict[str, spec.Tensor]:  # differentiable
-    """Compute weighted cross entropy and entropy for log probs and targets.
+    """Compute weighted cross entropy.
+
     Args:
-      logits: [batch, length, num_classes] float array.
-      targets: categorical targets [batch, length] int array.
-      weights: array of shape [batch, length].
-      label_smoothing: label smoothing constant, used to determine the on and off
-        values.
+      label_batch: categorical targets [batch, length] int array.
+      logits_batch: [batch, length, num_classes] float array.
+      mask_batch: weights array of shape [batch, length].
+      label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target
+        distribution becomes (1 - label_smoothing) for the correct class and
+        label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing).
+
     Returns:
       {'summed': scalar summed loss, 'n_valid_examples': scalar number of
-      valid examples in batch, 'per_example': 1-d array of per-example losses}
+      valid examples in batch, 'per_example': 2-d array of per-example losses}
     """
-    if logits.ndim != targets.ndim + 1:
+    if logits_batch.ndim != label_batch.ndim + 1:
       raise ValueError(
-          f'Incorrect shapes. Got shape {logits.shape} logits and '
-          f'{targets.shape} targets.'
+          f'Incorrect shapes. Got shape {logits_batch.shape} logits and '
+          f'{label_batch.shape} targets.'
       )
     # Compute log probabilities
-    log_probs = jax.nn.log_softmax(logits, axis=-1)
+    log_probs = jax.nn.log_softmax(logits_batch, axis=-1)
     # Extract log probability of the target class
     # Shape: [batch, length]
     target_log_probs = jnp.take_along_axis(
-        log_probs, targets[..., None], axis=-1
+        log_probs, label_batch[..., None], axis=-1
     ).squeeze(-1)
     # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p)
     # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss.
@@ -124,11 +127,11 @@ def compute_weighted_cross_entropy(
     per_example_losses = -1.0 * (
         confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)
     )
-    if weights is not None:
-      per_example_losses = jnp.where(weights, per_example_losses, 0.0)
-      n_valid_examples = weights.sum()
+    if mask_batch is not None:
+      per_example_losses = mask_batch * per_example_losses
+      n_valid_examples = mask_batch.sum()
     else:
-      n_valid_examples = targets.shape[0] * targets.shape[1]
+      n_valid_examples = label_batch.shape[0] * label_batch.shape[1]
     summed_loss = per_example_losses.sum()
     return {
         'summed': summed_loss,
@@ -147,12 +150,9 @@ def _eval_batch(
     logits, _ = self.model_fn(
         params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False
     )
-    # Calculate cross-entropy loss
-    metrics = self.compute_weighted_cross_entropy(
-        logits, batch['targets'], batch['weights']
+    metrics = self.loss_fn(
+        label_batch=batch['targets'], logits_batch=logits, mask_batch=batch['weights']
     )
-    # CRITICAL: Detach tensors to free computation graph and activations
-    # Without this, all intermediate activations are kept in memory!
     return {
         'loss': metrics['summed'],
         'denominator': metrics['n_valid_examples'],
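
For reference, here is a minimal standalone sketch of the masked, label-smoothed cross entropy that the new loss_fn computes. The definitions of confidence and smoothing_term are not visible in this diff, so the sketch assumes confidence = 1 - label_smoothing and smoothing_term = label_smoothing / vocab_size, as implied by the in-code formula comment; smoothed_cross_entropy and the example shapes are illustrative, not part of the PR.

# Sketch only; confidence/smoothing_term definitions are assumed (not shown in the diff).
import jax
import jax.numpy as jnp


def smoothed_cross_entropy(logits_batch, label_batch, mask_batch=None,
                           label_smoothing=0.0):
  vocab_size = logits_batch.shape[-1]
  confidence = 1.0 - label_smoothing
  smoothing_term = label_smoothing / vocab_size
  # Log probabilities over the vocabulary: [batch, length, vocab_size].
  log_probs = jax.nn.log_softmax(logits_batch, axis=-1)
  # Log probability of the target class: [batch, length].
  target_log_probs = jnp.take_along_axis(
      log_probs, label_batch[..., None], axis=-1
  ).squeeze(-1)
  # -(1 - α) * log_p[target] - (α / V) * sum(log_p)
  per_example = -1.0 * (
      confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)
  )
  if mask_batch is not None:
    # Zero out padding positions and count only the valid ones.
    per_example = mask_batch * per_example
    n_valid = mask_batch.sum()
  else:
    n_valid = label_batch.shape[0] * label_batch.shape[1]
  return {'summed': per_example.sum(),
          'n_valid_examples': n_valid,
          'per_example': per_example}


# Illustrative usage: batch of 2 sequences, length 3, vocab of 5.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 5))
labels = jnp.array([[1, 2, 3], [0, 4, 2]])
mask = jnp.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # last token of row 0 is padding
out = smoothed_cross_entropy(logits, labels, mask, label_smoothing=0.1)
mean_loss = out['summed'] / out['n_valid_examples']

Dividing 'summed' by 'n_valid_examples' gives the mean loss over non-padding tokens, which is presumably how the 'loss' and 'denominator' values returned by _eval_batch are combined downstream.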