Commit f6a705d

Refactor loss function in LM workloads to unify label handling and improve clarity
1 parent 91988af

File tree: 3 files changed, 91 additions and 69 deletions


algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 25 additions & 25 deletions
@@ -87,35 +87,38 @@ def model_fn(
     logits = self._model.apply({'params': params}, inputs)
     return logits, None

-  def compute_weighted_cross_entropy(
+  def loss_fn(
     self,
-    logits: spec.Tensor,
-    targets: spec.Tensor,
-    weights: Optional[spec.Tensor] = None,
+    label_batch: spec.Tensor,
+    logits_batch: spec.Tensor,
+    mask_batch: Optional[spec.Tensor] = None,
     label_smoothing: float = 0.0,
   ) -> Dict[str, spec.Tensor]:  # differentiable
-    """Compute weighted cross entropy and entropy for log probs and targets.
+    """Compute weighted cross entropy.
+
     Args:
-      logits: [batch, length, num_classes] float array.
-      targets: categorical targets [batch, length] int array.
-      weights: array of shape [batch, length].
-      label_smoothing: label smoothing constant, used to determine the on and off
-        values.
+      label_batch: categorical targets [batch, length] int array.
+      logits_batch: [batch, length, num_classes] float array.
+      mask_batch: weights array of shape [batch, length].
+      label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target
+        distribution becomes a mixture of the one-hot target (weight
+        1 - label_smoothing) and the uniform distribution over the vocabulary
+        (weight label_smoothing). Default is 0.0 (no smoothing).
+
     Returns:
       {'summed': scalar summed loss, 'n_valid_examples': scalar number of
-      valid examples in batch, 'per_example': 1-d array of per-example losses}
+      valid examples in batch, 'per_example': 2-d array of per-example losses}
     """
-    if logits.ndim != targets.ndim + 1:
+    if logits_batch.ndim != label_batch.ndim + 1:
       raise ValueError(
-        f'Incorrect shapes. Got shape {logits.shape} logits and '
-        f'{targets.shape} targets.'
+        f'Incorrect shapes. Got shape {logits_batch.shape} logits and '
+        f'{label_batch.shape} targets.'
       )
     # Compute log probabilities
-    log_probs = jax.nn.log_softmax(logits, axis=-1)
+    log_probs = jax.nn.log_softmax(logits_batch, axis=-1)
     # Extract log probability of the target class
     # Shape: [batch, length]
     target_log_probs = jnp.take_along_axis(
-      log_probs, targets[..., None], axis=-1
+      log_probs, label_batch[..., None], axis=-1
     ).squeeze(-1)
     # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p)
     # The formula follows directly from the definitions of label smoothing and cross-entropy loss.
@@ -124,11 +127,11 @@ def compute_weighted_cross_entropy(
     per_example_losses = -1.0 * (
       confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)
     )
-    if weights is not None:
-      per_example_losses = jnp.where(weights, per_example_losses, 0.0)
-      n_valid_examples = weights.sum()
+    if mask_batch is not None:
+      per_example_losses = mask_batch * per_example_losses
+      n_valid_examples = mask_batch.sum()
     else:
-      n_valid_examples = targets.shape[0] * targets.shape[1]
+      n_valid_examples = label_batch.shape[0] * label_batch.shape[1]
     summed_loss = per_example_losses.sum()
     return {
       'summed': summed_loss,
@@ -147,12 +150,9 @@ def _eval_batch(
     logits, _ = self.model_fn(
       params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False
     )
-    # Calculate cross-entropy loss
-    metrics = self.compute_weighted_cross_entropy(
-      logits, batch['targets'], batch['weights']
+    metrics = self.loss_fn(
+      label_batch=batch['targets'], logits_batch=logits, mask_batch=batch['weights']
     )
-    # CRITICAL: Detach tensors to free computation graph and activations
-    # Without this, all intermediate activations are kept in memory!
     return {
       'loss': metrics['summed'],
       'denominator': metrics['n_valid_examples'],
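For reference, a minimal standalone sketch (the shapes, vocab size, and function names below are illustrative, not from the repository) checking that the factored form used above, -(1 - α)·log_p[target] - (α / vocab_size)·Σ log_p, equals the cross-entropy against explicitly smoothed targets (1 - α)·one_hot + α / vocab_size:

import jax
import jax.numpy as jnp

def smoothed_xent_factored(logits, targets, alpha):
  # Factored form used in the diff above.
  vocab_size = logits.shape[-1]
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  target_log_probs = jnp.take_along_axis(
      log_probs, targets[..., None], axis=-1).squeeze(-1)
  return -((1.0 - alpha) * target_log_probs
           + (alpha / vocab_size) * log_probs.sum(axis=-1))

def smoothed_xent_explicit(logits, targets, alpha):
  # Reference: cross-entropy against an explicitly smoothed target distribution.
  vocab_size = logits.shape[-1]
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  smoothed = (1.0 - alpha) * jax.nn.one_hot(targets, vocab_size) + alpha / vocab_size
  return -(smoothed * log_probs).sum(axis=-1)

key_logits, key_targets = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(key_logits, (2, 5, 11))        # [batch, length, vocab]
targets = jax.random.randint(key_targets, (2, 5), 0, 11)  # [batch, length]
assert jnp.allclose(smoothed_xent_factored(logits, targets, 0.1),
                    smoothed_xent_explicit(logits, targets, 0.1), atol=1e-5)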

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 54 additions & 37 deletions
@@ -133,43 +133,60 @@ def is_output_params(self, param_name: str) -> bool:
     """Return whether the given parameter is an output parameter."""
     return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name

-  # FIXME(rka97): Implement label smoothing
-  def compute_weighted_cross_entropy(
-    self,
-    logits: spec.Tensor,
-    labels: spec.Tensor,
-    weights: spec.Tensor,
-    label_smoothing: float = 0.0,
-  ) -> Dict[str, spec.Tensor]:
-    """Compute cross-entropy loss for language modeling in PyTorch."""
-    vocab_size = logits.size(-1)
-
-    if len(labels.shape) == len(logits.shape):
-      # One-hot labels
-      log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
-      loss = -torch.sum(labels * log_probs, dim=-1)
-    else:
-      # Dense labels
-      loss = torch.nn.functional.cross_entropy(
-        logits.view(-1, vocab_size), labels.view(-1), reduction='none'
+  def loss_fn(
+    self,
+    label_batch: spec.Tensor,
+    logits_batch: spec.Tensor,
+    mask_batch: spec.Tensor,
+    label_smoothing: float = 0.0,
+  ) -> Dict[str, spec.Tensor]:
+    """Compute weighted cross-entropy loss.
+
+    Args:
+      label_batch: Target labels of shape [batch, length] (int).
+      logits_batch: Predicted logits of shape [batch, length, vocab_size] (float).
+      mask_batch: Optional weights of shape [batch, length] (float). Used to mask
+        out padding tokens or weight examples differently. If None, all examples
+        are weighted equally.
+      label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target
+        distribution becomes a mixture of the one-hot target (weight
+        1 - label_smoothing) and the uniform distribution over the vocabulary
+        (weight label_smoothing). Default is 0.0 (no smoothing).
+
+    Returns:
+      Dictionary containing:
+      - 'summed': Scalar tensor with the sum of all weighted losses.
+      - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples.
+      - 'per_example': Tensor of shape [batch, length] with individual losses per example.
+    """
+    vocab_size = logits_batch.size(-1)
+
+    # Compute cross-entropy loss with label smoothing
+    per_example_losses = torch.nn.functional.cross_entropy(
+      logits_batch.view(-1, vocab_size),
+      label_batch.view(-1),
+      reduction='none',
+      label_smoothing=label_smoothing
     )
-    loss = loss.view_as(labels)
-
-    if weights is not None:
-      loss = loss * weights
-
-    n_valid = (
-      weights.sum()
-      if weights is not None
-      else torch.tensor(
-        labels.numel(), dtype=torch.float32, device=labels.device
+    per_example_losses = per_example_losses.view_as(label_batch)
+
+    # Apply weights if provided
+    if mask_batch is not None:
+      per_example_losses = per_example_losses * mask_batch
+
+    # Calculate number of valid examples
+    n_valid_examples = (
+      mask_batch.sum()
+      if mask_batch is not None
+      else torch.tensor(
+        label_batch.numel(), dtype=torch.float32, device=label_batch.device
+      )
     )
-    )
-    return {
-      'summed': loss.sum(),
-      'n_valid_examples': n_valid,
-      'per_example': loss,
-    }
+
+    return {
+      'summed': per_example_losses.sum(),
+      'n_valid_examples': n_valid_examples,
+      'per_example': per_example_losses,
+    }

   def _eval_batch(
     self,
@@ -182,8 +199,8 @@ def _eval_batch(
     logits, _ = self.model_fn(
       params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False
     )
-    metrics = self.compute_weighted_cross_entropy(
-      logits, batch['targets'], batch['weights']
+    metrics = self.loss_fn(
+      label_batch=batch['targets'], logits_batch=logits, mask_batch=batch['weights']
     )
     return {
       'loss': metrics['summed'].detach(),
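The PyTorch version delegates smoothing to torch.nn.functional.cross_entropy while keeping the same masking and denominator convention. A small usage sketch (the toy batch below is made up; only the reshaping, reduction, and dictionary handling mirror the diff) showing how a caller turns the returned values into a mean loss:

import torch
import torch.nn.functional as F

# Toy batch, purely illustrative: batch=2, length=3, vocab_size=7.
logits_batch = torch.randn(2, 3, 7)
label_batch = torch.randint(0, 7, (2, 3))
mask_batch = torch.tensor([[1., 1., 0.],   # last token of example 0 is padding
                           [1., 1., 1.]])

# Same computation as loss_fn above: unreduced, label-smoothed cross-entropy,
# reshaped to [batch, length] and zeroed where the mask is 0.
per_example = F.cross_entropy(
    logits_batch.view(-1, 7),
    label_batch.view(-1),
    reduction='none',
    label_smoothing=0.1,
).view_as(label_batch) * mask_batch

summed = per_example.sum()
n_valid_examples = mask_batch.sum()
mean_loss = summed / n_valid_examples  # masked tokens add neither loss nor count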

algoperf/workloads/lm/workload.py

Lines changed: 12 additions & 7 deletions
@@ -125,6 +125,16 @@ def _build_input_queue(
   ) -> Iterator[Dict[str, Any]]:
     """Build an input queue for the given split."""

+  @abc.abstractmethod
+  def _eval_batch(
+    self,
+    params: spec.ParameterContainer,
+    eval_batch: Dict[str, spec.Tensor],
+    model_state: spec.ModelAuxiliaryState,
+    rng: spec.RandomState,
+  ) -> Dict[str, float]:
+    """Evaluate the model on a single batch."""
+
   def _eval_model_on_split(
     self,
     split: str,
@@ -168,20 +178,15 @@ def _normalize_eval_metrics(
   ) -> Dict[str, float]:
     """Normalize eval metrics."""

+  @abc.abstractmethod
   def loss_fn(
     self,
     label_batch: spec.Tensor,
     logits_batch: spec.Tensor,
     mask_batch: Optional[spec.Tensor] = None,
     label_smoothing: float = 0.0,
   ) -> Dict[str, spec.Tensor]:
-    """Compute cross-entropy loss for language modeling in JAX."""
-    return self.compute_weighted_cross_entropy(
-      logits_batch,
-      label_batch,
-      weights=mask_batch,
-      label_smoothing=label_smoothing,
-    )
+    """Compute cross-entropy loss for language modeling."""

   def is_output_params(self, param_name: str) -> bool:
     """Return whether the given parameter is an output parameter."""
