@@ -20,6 +20,7 @@ def align_for_step(
2020 logits : torch .Tensor , # shape: [1, total_seq_len, draft_vocab_size]
2121 targets : torch .Tensor , # shape: [1, total_seq_len, draft_vocab_size]
2222 loss_mask : torch .Tensor | None , # shape: [1, total_seq_len]
23+ prev_correct : torch .Tensor | None , # shape: [1, total_seq_len]
2324 ttt_step : int ,
2425):
2526 # There are no target values for the last ttt_step tokens, so we mask them out
@@ -40,24 +41,38 @@ def align_for_step(
4041 if loss_mask is not None :
4142 loss_mask = loss_mask [:, ttt_step :]
4243 # shape: [1, total_seq_len - ttt_step]
43- return logits , targets , loss_mask
44+ if prev_correct is not None :
45+ # Align with draft starts
46+ prev_correct = prev_correct [:, :- ttt_step ] if ttt_step > 0 else prev_correct
47+ # shape: [1, total_seq_len - ttt_step]
48+ return logits , targets , loss_mask , prev_correct
4449
4550
4651@torch .no_grad ()
4752def compute_accuracy (
4853 logits : torch .Tensor , # shape: [1, total_seq_len - ttt_step, draft_vocab_size]
4954 targets : torch .Tensor , # shape: [1, total_seq_len - ttt_step, draft_vocab_size]
5055 loss_mask : torch .Tensor | None , # shape: [1, total_seq_len - ttt_step]
56+ prev_correct : torch .Tensor | None , # shape: [1, total_seq_len - ttt_step]
5157):
5258 # Note: logits, targets, and loss_mask are already aligned for the current ttt_step
5359 target_tokens = torch .argmax (targets , dim = - 1 )
5460 predicted_tokens = torch .argmax (logits , dim = - 1 )
5561 # shape: [1, total_seq_len - ttt_step]
5662
5763 correct = predicted_tokens == target_tokens
64+ cond_denom : torch .Tensor | int = correct .numel ()
65+ if prev_correct is not None :
66+ cond_denom = prev_correct .sum ()
67+ # Update prev_correct in place
68+ correct = torch .logical_and (prev_correct , correct , out = prev_correct )
5869 if loss_mask is not None :
5970 correct = torch .masked_select (correct , loss_mask .to (torch .bool ))
60- return correct .float ().sum () / (correct .numel () + 1e-5 )
71+
72+ correct_sum = correct .float ().sum ()
73+ full_denom = correct .numel ()
74+
75+ return correct_sum / (full_denom + 1e-5 ), correct_sum / (cond_denom + 1e-5 )
6176
6277
6378def loss_function (
@@ -235,8 +250,13 @@ def forward(
235250 # shape: [1, total_seq_len, draft_vocab_size]
236251
237252 loss = torch .tensor (0.0 , device = device )
253+ prev_correct = (
254+ loss_mask .clone ()
255+ if loss_mask is not None
256+ else torch .ones (1 , total_seq_len , device = device , dtype = torch .bool )
257+ )
238258 draft_tokens = []
239- accuracy_list = []
259+ metrics = {}
240260 for ttt_step in range (ttt_steps ):
241261 with torch .no_grad ():
242262 input_embeds = self .embed_tokens (input_ids )
@@ -269,12 +289,19 @@ def forward(
269289 # shape: [1, total_seq_len, draft_vocab_size]
270290
271291 if return_loss :
272- s_logits , s_targets , s_loss_mask = align_for_step (
273- logits , target_logits , loss_mask , ttt_step
292+ s_logits , s_targets , s_loss_mask , s_prev_correct = align_for_step (
293+ logits , target_logits , loss_mask , prev_correct , ttt_step
274294 )
275295 loss_weight = self .ttt_step_loss_decay ** ttt_step
276- loss += loss_weight * loss_function (s_logits , s_targets , s_loss_mask )
277- accuracy_list .append (compute_accuracy (s_logits , s_targets , s_loss_mask ))
296+ s_loss = loss_weight * loss_function (s_logits , s_targets , s_loss_mask )
297+ loss += s_loss
298+
299+ s_full_acc , s_cond_acc = compute_accuracy (
300+ s_logits , s_targets , s_loss_mask , s_prev_correct
301+ )
302+ metrics [f"loss_{ ttt_step } " ] = s_loss .detach ().clone ()
303+ metrics [f"full_acc_{ ttt_step } " ] = s_full_acc
304+ metrics [f"cond_acc_{ ttt_step } " ] = s_cond_acc
278305
279306 input_ids = torch .argmax (logits , dim = - 1 )
280307 draft_tokens .append (input_ids .detach ().clone ())
@@ -303,11 +330,9 @@ def forward(
303330 position_ids = position_ids + 1
304331 # shape: [1, total_seq_len]
305332
333+ metrics ["loss" ] = loss .detach ().clone ()
334+
306335 if return_loss :
307- return (
308- draft_tokens ,
309- loss ,
310- torch .tensor (accuracy_list , device = device , dtype = torch .float ),
311- )
336+ return draft_tokens , loss , metrics
312337 else :
313338 return draft_tokens
0 commit comments