Commit 2cd5b7a

Improve metric calculation and logging
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 345b314 · commit 2cd5b7a

2 files changed: 54 additions, 41 deletions

src/speculators/train/eagle3/core.py

Lines changed: 37 additions & 12 deletions
@@ -20,6 +20,7 @@ def align_for_step(
     logits: torch.Tensor,  # shape: [1, total_seq_len, draft_vocab_size]
     targets: torch.Tensor,  # shape: [1, total_seq_len, draft_vocab_size]
     loss_mask: torch.Tensor | None,  # shape: [1, total_seq_len]
+    prev_correct: torch.Tensor | None,  # shape: [1, total_seq_len]
     ttt_step: int,
 ):
     # There are no target values for the last ttt_step tokens, so we mask them out
@@ -40,24 +41,38 @@ def align_for_step(
     if loss_mask is not None:
         loss_mask = loss_mask[:, ttt_step:]
         # shape: [1, total_seq_len - ttt_step]
-    return logits, targets, loss_mask
+    if prev_correct is not None:
+        # Align with draft starts
+        prev_correct = prev_correct[:, :-ttt_step] if ttt_step > 0 else prev_correct
+        # shape: [1, total_seq_len - ttt_step]
+    return logits, targets, loss_mask, prev_correct
 
 
 @torch.no_grad()
 def compute_accuracy(
     logits: torch.Tensor,  # shape: [1, total_seq_len - ttt_step, draft_vocab_size]
     targets: torch.Tensor,  # shape: [1, total_seq_len - ttt_step, draft_vocab_size]
     loss_mask: torch.Tensor | None,  # shape: [1, total_seq_len - ttt_step]
+    prev_correct: torch.Tensor | None,  # shape: [1, total_seq_len - ttt_step]
 ):
     # Note: logits, targets, and loss_mask are already aligned for the current ttt_step
     target_tokens = torch.argmax(targets, dim=-1)
     predicted_tokens = torch.argmax(logits, dim=-1)
     # shape: [1, total_seq_len - ttt_step]
 
     correct = predicted_tokens == target_tokens
+    cond_denom: torch.Tensor | int = correct.numel()
+    if prev_correct is not None:
+        cond_denom = prev_correct.sum()
+        # Update prev_correct in place
+        correct = torch.logical_and(prev_correct, correct, out=prev_correct)
     if loss_mask is not None:
         correct = torch.masked_select(correct, loss_mask.to(torch.bool))
-    return correct.float().sum() / (correct.numel() + 1e-5)
+
+    correct_sum = correct.float().sum()
+    full_denom = correct.numel()
+
+    return correct_sum / (full_denom + 1e-5), correct_sum / (cond_denom + 1e-5)
 
 
 def loss_function(
@@ -235,8 +250,13 @@ def forward(
         # shape: [1, total_seq_len, draft_vocab_size]
 
         loss = torch.tensor(0.0, device=device)
+        prev_correct = (
+            loss_mask.clone()
+            if loss_mask is not None
+            else torch.ones(1, total_seq_len, device=device, dtype=torch.bool)
+        )
         draft_tokens = []
-        accuracy_list = []
+        metrics = {}
         for ttt_step in range(ttt_steps):
             with torch.no_grad():
                 input_embeds = self.embed_tokens(input_ids)
@@ -269,12 +289,19 @@ def forward(
             # shape: [1, total_seq_len, draft_vocab_size]
 
             if return_loss:
-                s_logits, s_targets, s_loss_mask = align_for_step(
-                    logits, target_logits, loss_mask, ttt_step
+                s_logits, s_targets, s_loss_mask, s_prev_correct = align_for_step(
+                    logits, target_logits, loss_mask, prev_correct, ttt_step
                 )
                 loss_weight = self.ttt_step_loss_decay**ttt_step
-                loss += loss_weight * loss_function(s_logits, s_targets, s_loss_mask)
-                accuracy_list.append(compute_accuracy(s_logits, s_targets, s_loss_mask))
+                s_loss = loss_weight * loss_function(s_logits, s_targets, s_loss_mask)
+                loss += s_loss
+
+                s_full_acc, s_cond_acc = compute_accuracy(
+                    s_logits, s_targets, s_loss_mask, s_prev_correct
+                )
+                metrics[f"loss_{ttt_step}"] = s_loss.detach().clone()
+                metrics[f"full_acc_{ttt_step}"] = s_full_acc
+                metrics[f"cond_acc_{ttt_step}"] = s_cond_acc
 
             input_ids = torch.argmax(logits, dim=-1)
             draft_tokens.append(input_ids.detach().clone())
@@ -303,11 +330,9 @@ def forward(
             position_ids = position_ids + 1
             # shape: [1, total_seq_len]
 
+        metrics["loss"] = loss.detach().clone()
+
         if return_loss:
-            return (
-                draft_tokens,
-                loss,
-                torch.tensor(accuracy_list, device=device, dtype=torch.float),
-            )
+            return draft_tokens, loss, metrics
         else:
             return draft_tokens

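The updated compute_accuracy returns two values per TTT step: a full accuracy over all loss-masked positions and a conditional accuracy counted only over positions where every earlier step's draft token was already correct. A minimal standalone sketch of the same calculation (illustrative names, loss_mask handling omitted, not part of the commit):

import torch

def full_and_conditional_accuracy(
    predicted_tokens: torch.Tensor,  # [1, seq_len], argmax of the draft logits
    target_tokens: torch.Tensor,     # [1, seq_len], argmax of the target logits
    prev_correct: torch.Tensor,      # [1, seq_len] bool, all earlier ttt_steps correct
):
    correct = predicted_tokens == target_tokens
    # Conditional denominator: positions whose entire prefix of steps was correct
    cond_denom = prev_correct.sum()
    # A position only stays correct if every step so far predicted correctly
    correct = torch.logical_and(prev_correct, correct)
    correct_sum = correct.float().sum()
    full_acc = correct_sum / (correct.numel() + 1e-5)
    cond_acc = correct_sum / (cond_denom + 1e-5)
    return full_acc, cond_acc

# Step 1 counts toward the conditional accuracy only where step 0 was correct.
pred = torch.tensor([[5, 7, 2, 9]])
tgt = torch.tensor([[5, 7, 3, 9]])
prev = torch.tensor([[True, True, True, False]])
full_acc, cond_acc = full_and_conditional_accuracy(pred, tgt, prev)
print(full_acc.item(), cond_acc.item())  # ~0.5 and ~0.667

In the example, the last position is excluded from the conditional denominator because an earlier step already mispredicted there, which roughly mirrors how draft tokens are accepted token-by-token during speculative decoding.
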
src/speculators/train/trainer.py

Lines changed: 17 additions & 29 deletions
@@ -118,7 +118,7 @@ def train_epoch(self, epoch: int):
                 for k, v in batch.items()
             }
 
-            _draft_tokens, loss, draft_accuracies = self.model(
+            _draft_tokens, loss, metrics = self.model(
                 **gpu_batch, **self.config.train_call_kwargs
             )
 
@@ -127,18 +127,13 @@ def train_epoch(self, epoch: int):
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
             self.opt.step()
 
-            loss = loss.detach().clone()
             if self.is_distributed:
-                # Note: this is not needed for training, just for logging
-                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
-                dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG)
+                for v in metrics.values():
+                    dist.reduce(v, dst=0, op=dist.ReduceOp.AVG)
 
-            acc_values = {
-                f"acc_{i}": acc.item() for i, acc in enumerate(draft_accuracies)
-            }
+            metrics = {k: v.item() for k, v in metrics.items()}
             metric_logger.info(
-                {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch},
-                extra={"step": self.global_step},
+                {"train": metrics, "epoch": epoch}, extra={"step": self.global_step}
             )
             self.global_step += 1
 
@@ -152,36 +147,29 @@ def val_epoch(self, epoch: int):
         val_loader = self.val_loader
         if self.local_rank == 0:
             val_loader = tqdm(val_loader, desc=f"Epoch {epoch}")  # type: ignore[assignment]
-        val_loss = torch.zeros(1, device=self.local_rank)
-        val_accuracies = torch.zeros(
-            (), device=self.local_rank
-        )  # initialize to tensor of shape ()
+
+        val_metrics: dict[str, float] = {}
+        num_batches = len(val_loader)
         for batch in val_loader:
             gpu_batch = {
                 k: v.to(self.local_rank) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()
             }
 
-            _draft_tokens, loss, draft_accuracies = self.model(
+            _draft_tokens, _loss, metrics = self.model(
                 **gpu_batch, **self.config.val_call_kwargs
             )
 
             if self.is_distributed:
-                dist.reduce(val_loss, dst=0, op=dist.ReduceOp.AVG)
-                dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG)
-
-            val_loss += loss.detach().clone()
-            # Can't use += here because val_accuracies has shape () on first iteration
-            val_accuracies = val_accuracies + draft_accuracies.detach()
-
-        val_loss /= len(val_loader)
-        val_accuracies /= len(val_loader)
-        acc_values = {
-            f"acc_{i}_epoch": acc.item() for i, acc in enumerate(val_accuracies)
-        }
+                for v in metrics.values():
+                    dist.reduce(v, dst=0, op=dist.ReduceOp.AVG)
+
+            for k, v in metrics.items():
+                val_metrics[k] = val_metrics.get(k, 0.0) + v.item()
+
+        val_metrics = {f"{k}_epoch": v / num_batches for k, v in val_metrics.items()}
         metric_logger.info(
-            {"val": {"loss_epoch": val_loss.item(), **acc_values}, "epoch": epoch},
-            extra={"step": self.global_step},
+            {"val": val_metrics, "epoch": epoch}, extra={"step": self.global_step}
        )
 
     def save_checkpoint(self, epoch: int):

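After this change the trainer reduces each scalar tensor in the metrics dict across ranks and logs one flat record per step. A rough illustration of what a single train-step record could look like (numbers invented, two TTT steps assumed; not taken from the commit):

# Illustrative only: the metrics dict returned by the model holds one scalar
# tensor per key; after .item() a train-step log record resembles this.
example_train_record = {
    "train": {
        "loss_0": 2.31, "full_acc_0": 0.42, "cond_acc_0": 0.42,
        "loss_1": 2.74, "full_acc_1": 0.19, "cond_acc_1": 0.47,
        "loss": 5.05,  # sum of the per-step losses, each already decay-weighted
    },
    "epoch": 0,
}

The validation epoch logs the same keys averaged over batches and suffixed with _epoch (e.g. loss_epoch, full_acc_0_epoch, cond_acc_0_epoch).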