From dbfe29a847629ba8aca968abcb4c36889db47a92 Mon Sep 17 00:00:00 2001 From: "akilsurya.s" Date: Mon, 9 Mar 2026 12:57:13 -0700 Subject: [PATCH] Add optional early stopping to training loop --- train.py | 293 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 227 insertions(+), 66 deletions(-) diff --git a/train.py b/train.py index 6994fb9b..29d8a12c 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,4 @@ +# train.py """ Autoresearch pretraining script. Single-GPU, single-file. Cherry-picked and simplified from nanochat. @@ -208,8 +209,12 @@ def estimate_flops(self): """Estimated FLOPs per token (forward + backward).""" nparams = sum(p.numel() for p in self.parameters()) value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) - nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + - self.resid_lambdas.numel() + self.x0_lambdas.numel()) + nparams_exclude = ( + self.transformer.wte.weight.numel() + + value_embeds_numel + + self.resid_lambdas.numel() + + self.x0_lambdas.numel() + ) h = self.config.n_head q = self.config.n_embd // self.config.n_head t = self.config.sequence_len @@ -228,12 +233,23 @@ def num_scaling_params(self): scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() total = wte + value_embeds + lm_head + transformer_matrices + scalars return { - 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head, - 'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total, + "wte": wte, + "value_embeds": value_embeds, + "lm_head": lm_head, + "transformer_matrices": transformer_matrices, + "scalars": scalars, + "total": total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, - weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): + def setup_optimizer( + self, + unembedding_lr=0.004, + embedding_lr=0.2, + matrix_lr=0.02, + weight_decay=0.0, + adam_betas=(0.8, 0.95), + scalar_lr=0.5, + ): model_dim = self.config.n_embd 
matrix_params = list(self.transformer.h.parameters()) value_embeds_params = list(self.value_embeds.parameters()) @@ -241,30 +257,41 @@ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02 lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + - len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)) + assert len(list(self.parameters())) == ( + len(matrix_params) + + len(embedding_params) + + len(lm_head_params) + + len(value_embeds_params) + + len(resid_params) + + len(x0_params) + ) # Scale LR ∝ 1/√dmodel (tuned at 768 dim) dmodel_lr_scale = (model_dim / 768) ** -0.5 print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}") param_groups = [ - dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), - dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), + dict(kind="adamw", params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind="adamw", params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind="adamw", params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), + dict(kind="adamw", params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0), + 
dict(kind="adamw", params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), ] for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] param_groups.append(dict( - kind='muon', params=group_params, lr=matrix_lr, - momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay, + kind="muon", + params=group_params, + lr=matrix_lr, + momentum=0.95, + ns_steps=5, + beta2=0.95, + weight_decay=weight_decay, )) optimizer = MuonAdamW(param_groups) for group in optimizer.param_groups: group["initial_lr"] = group["lr"] return optimizer - def forward(self, idx, targets=None, reduction='mean'): + def forward(self, idx, targets=None, reduction="mean"): B, T = idx.size() assert T <= self.cos.size(1) cos_sin = self.cos[:, :T], self.sin[:, :T] @@ -284,11 +311,16 @@ def forward(self, idx, targets=None, reduction='mean'): logits = softcap * torch.tanh(logits / softcap) if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), - ignore_index=-1, reduction=reduction) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-1, + reduction=reduction, + ) return loss return logits + # --------------------------------------------------------------------------- # Optimizer (MuonAdamW, single GPU only) # --------------------------------------------------------------------------- @@ -301,6 +333,7 @@ def forward(self, idx, targets=None, reduction='mean'): (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), ] + @torch.compile(dynamic=False, fullgraph=True) def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t): p.mul_(1 - lr_t * wd_t) @@ -312,9 +345,20 @@ def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_ step_size = lr_t / bias1 p.add_(exp_avg / denom, alpha=-step_size) + @torch.compile(dynamic=False, fullgraph=True) -def 
muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer, - momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim): +def muon_step_fused( + stacked_grads, + stacked_params, + momentum_buffer, + second_momentum_buffer, + momentum_t, + lr_t, + wd_t, + beta2_t, + ns_steps, + red_dim, +): # Nesterov momentum momentum = momentum_t.to(stacked_grads.dtype) momentum_buffer.lerp_(stacked_grads, 1 - momentum) @@ -370,28 +414,37 @@ def __init__(self, param_groups): self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") def _step_adamw(self, group): - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad state = self.state[p] if not state: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p) - state['exp_avg_sq'] = torch.zeros_like(p) - state['step'] += 1 - self._adamw_step_t.fill_(state['step']) - self._adamw_lr_t.fill_(group['lr']) - self._adamw_beta1_t.fill_(group['betas'][0]) - self._adamw_beta2_t.fill_(group['betas'][1]) - self._adamw_eps_t.fill_(group['eps']) - self._adamw_wd_t.fill_(group['weight_decay']) - adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'], - self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t, - self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t) + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] += 1 + self._adamw_step_t.fill_(state["step"]) + self._adamw_lr_t.fill_(group["lr"]) + self._adamw_beta1_t.fill_(group["betas"][0]) + self._adamw_beta2_t.fill_(group["betas"][1]) + self._adamw_eps_t.fill_(group["eps"]) + self._adamw_wd_t.fill_(group["weight_decay"]) + adamw_step_fused( + p, + grad, + state["exp_avg"], + state["exp_avg_sq"], + self._adamw_step_t, + self._adamw_lr_t, + self._adamw_beta1_t, + self._adamw_beta2_t, + self._adamw_eps_t, + self._adamw_wd_t, + ) def _step_muon(self, group): - params = group['params'] + params = group["params"] if not params: 
return p = params[0] @@ -408,46 +461,64 @@ def _step_muon(self, group): stacked_params = torch.stack(params) self._muon_momentum_t.fill_(group["momentum"]) self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0) - self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5) + self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5) self._muon_wd_t.fill_(group["weight_decay"]) - muon_step_fused(stacked_grads, stacked_params, - state["momentum_buffer"], state["second_momentum_buffer"], - self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, - self._muon_beta2_t, group["ns_steps"], red_dim) + muon_step_fused( + stacked_grads, + stacked_params, + state["momentum_buffer"], + state["second_momentum_buffer"], + self._muon_momentum_t, + self._muon_lr_t, + self._muon_wd_t, + self._muon_beta2_t, + group["ns_steps"], + red_dim, + ) torch._foreach_copy_(params, list(stacked_params.unbind(0))) @torch.no_grad() def step(self): for group in self.param_groups: - if group['kind'] == 'adamw': + if group["kind"] == "adamw": self._step_adamw(group) - elif group['kind'] == 'muon': + elif group["kind"] == "muon": self._step_muon(group) + # --------------------------------------------------------------------------- # Hyperparameters (edit these directly, no CLI flags needed) # --------------------------------------------------------------------------- # Model architecture -ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO -HEAD_DIM = 128 # target head dimension for attention -WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context +ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO +HEAD_DIM = 128 # target head dimension for attention +WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context # Optimization -TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step -EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam) -UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam) 
-MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon) -SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam) -WEIGHT_DECAY = 0.2 # cautious weight decay for Muon -ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2 -WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup -WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown -FINAL_LR_FRAC = 0.0 # final LR as fraction of initial +TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step +EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam) +UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam) +MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon) +SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam) +WEIGHT_DECAY = 0.2 # cautious weight decay for Muon +ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2 +WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup +WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown +FINAL_LR_FRAC = 0.0 # final LR as fraction of initial # Model size -DEPTH = 8 # number of transformer layers -DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM) +DEPTH = 8 # number of transformer layers +DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM) + +# Optional early stopping +ENABLE_EARLY_STOP = False +EVAL_EVERY = 20 +EARLY_STOP_PATIENCE = 5 +MIN_DELTA = 1e-4 +TARGET_VAL_BPB = None +SAVE_BEST_CHECKPOINT = ENABLE_EARLY_STOP +BEST_CKPT_PATH = "best.pt" # --------------------------------------------------------------------------- # Setup: tokenizer, model, optimizer, dataloader @@ -465,16 +536,22 @@ def step(self): vocab_size = tokenizer.get_vocab_size() print(f"Vocab size: {vocab_size:,}") + def build_model_config(depth): base_dim = depth * ASPECT_RATIO model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM num_heads = model_dim // HEAD_DIM return GPTConfig( - sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size, - n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, + sequence_len=MAX_SEQ_LEN, 
+ vocab_size=vocab_size, + n_layer=depth, + n_head=num_heads, + n_kv_head=num_heads, + n_embd=model_dim, window_pattern=WINDOW_PATTERN, ) + config = build_model_config(DEPTH) print(f"Model config: {asdict(config)}") @@ -487,7 +564,7 @@ def build_model_config(depth): print("Parameter counts:") for key, value in param_counts.items(): print(f" {key:24s}: {value:,}") -num_params = param_counts['total'] +num_params = param_counts["total"] num_flops_per_token = model.estimate_flops() print(f"Estimated FLOPs per token: {num_flops_per_token:e}") @@ -511,6 +588,12 @@ def build_model_config(depth): print(f"Time budget: {TIME_BUDGET}s") print(f"Gradient accumulation steps: {grad_accum_steps}") +if ENABLE_EARLY_STOP: + print( + f"Early stopping enabled: eval_every={EVAL_EVERY}, " + f"patience={EARLY_STOP_PATIENCE}, min_delta={MIN_DELTA}, " + f"target_val_bpb={TARGET_VAL_BPB}" + ) # Schedules (all based on progress = training_time / TIME_BUDGET) @@ -523,13 +606,16 @@ def get_lr_multiplier(progress): cooldown = (1.0 - progress) / WARMDOWN_RATIO return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC + def get_muon_momentum(step): frac = min(step / 300, 1) return (1 - frac) * 0.85 + frac * 0.95 + def get_weight_decay(progress): return WEIGHT_DECAY * (1 - progress) + # --------------------------------------------------------------------------- # Training loop # --------------------------------------------------------------------------- @@ -539,9 +625,17 @@ def get_weight_decay(progress): total_training_time = 0 step = 0 +best_val_bpb = float("inf") +best_step = -1 +bad_evals = 0 +last_val_bpb = None +stopped_early = False +stop_reason = "time_budget" + while True: torch.cuda.synchronize() t0 = time.time() + for micro_step in range(grad_accum_steps): with autocast_ctx: loss = model(x, y) @@ -557,9 +651,10 @@ def get_weight_decay(progress): muon_weight_decay = get_weight_decay(progress) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm - if group['kind'] == 
'muon': + if group["kind"] == "muon": group["momentum"] = muon_momentum group["weight_decay"] = muon_weight_decay + optimizer.step() model.zero_grad(set_to_none=True) @@ -568,7 +663,7 @@ def get_weight_decay(progress): # Fast fail: abort if loss is exploding if train_loss_f > 100: print("FAIL") - exit(1) + raise SystemExit(1) torch.cuda.synchronize() t1 = time.time() @@ -580,13 +675,19 @@ def get_weight_decay(progress): # Logging ema_beta = 0.9 smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) pct_done = 100 * progress tok_per_sec = int(TOTAL_BATCH_SIZE / dt) mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS remaining = max(0, TIME_BUDGET - total_training_time) - print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True) + print( + f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | " + f"lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | " + f"mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", + end="", + flush=True, + ) # GC management (Python's GC causes ~500ms stalls) if step == 0: @@ -596,16 +697,66 @@ def get_weight_decay(progress): elif (step + 1) % 5000 == 0: gc.collect() + # Optional early stopping with periodic validation + if ENABLE_EARLY_STOP and step > 10 and step % EVAL_EVERY == 0: + print() + model.eval() + with torch.no_grad(), autocast_ctx: + val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE) + model.train() + last_val_bpb = val_bpb + + improved = val_bpb < (best_val_bpb - MIN_DELTA) + if improved: + best_val_bpb = val_bpb + best_step = step + bad_evals = 0 + print(f"eval step {step:05d} | val_bpb: {val_bpb:.6f} | 
NEW BEST") + if SAVE_BEST_CHECKPOINT: + torch.save( + { + "model": model.state_dict(), + "config": asdict(config), + "step": step, + "val_bpb": val_bpb, + }, + BEST_CKPT_PATH, + ) + else: + bad_evals += 1 + print( + f"eval step {step:05d} | val_bpb: {val_bpb:.6f} | " + f"best: {best_val_bpb:.6f} | bad_evals: {bad_evals}/{EARLY_STOP_PATIENCE}" + ) + + if TARGET_VAL_BPB is not None and best_val_bpb <= TARGET_VAL_BPB: + stopped_early = True + stop_reason = "target_reached" + break + + if bad_evals >= EARLY_STOP_PATIENCE: + stopped_early = True + stop_reason = "patience_exhausted" + break + step += 1 # Time's up — but only stop after warmup steps so we don't count compilation if step > 10 and total_training_time >= TIME_BUDGET: + stop_reason = "time_budget" break print() # newline after \r training log total_tokens = step * TOTAL_BATCH_SIZE +# Restore best checkpoint if early stopping saved one +if ENABLE_EARLY_STOP and SAVE_BEST_CHECKPOINT and os.path.exists(BEST_CKPT_PATH): + ckpt = torch.load(BEST_CKPT_PATH, map_location=device) + model.load_state_dict(ckpt["model"]) + best_val_bpb = ckpt["val_bpb"] + best_step = ckpt["step"] + # Final eval model.eval() with autocast_ctx: @@ -614,16 +765,26 @@ def get_weight_decay(progress): # Final summary t_end = time.time() startup_time = t_start_training - t_start -steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0 +steady_state_mfu = ( + 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS + if total_training_time > 0 else 0 +) peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 print("---") print(f"val_bpb: {val_bpb:.6f}") +if ENABLE_EARLY_STOP: + print(f"best_val_bpb: {best_val_bpb:.6f}" if best_step >= 0 else "best_val_bpb: n/a") + print(f"best_step: {best_step}" if best_step >= 0 else "best_step: n/a") + print(f"last_val_bpb: {last_val_bpb:.6f}" if 
last_val_bpb is not None else "last_val_bpb: n/a") + print(f"stopped_early: {stopped_early}") + print(f"stop_reason: {stop_reason}") print(f"training_seconds: {total_training_time:.1f}") print(f"total_seconds: {t_end - t_start:.1f}") +print(f"startup_seconds: {startup_time:.1f}") print(f"peak_vram_mb: {peak_vram_mb:.1f}") print(f"mfu_percent: {steady_state_mfu:.2f}") print(f"total_tokens_M: {total_tokens / 1e6:.1f}") print(f"num_steps: {step}") print(f"num_params_M: {num_params / 1e6:.1f}") print(f"depth: {DEPTH}")