diff --git a/README.md b/README.md index 8459259ab..6ccfbf04e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ By design, training runs for a **fixed 5-minute time budget** (wall clock, exclu ## Quick start -**Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/). +**Requirements:** A single NVIDIA GPU with CUDA access (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/). ```bash @@ -35,7 +35,7 @@ uv run prepare.py uv run train.py ``` -If the above commands all work ok, your setup is working and you can go into autonomous research mode. +If the above commands all work ok, your setup is working and you can go into autonomous research mode. If `train.py` says tokenizer or parquet data is missing, run `uv run prepare.py` first to populate the cache. On unsupported environments, `train.py` and the CUDA-only helpers now stop early with a clear error explaining that this repo expects one visible NVIDIA GPU with CUDA access. ## Running the agent @@ -64,7 +64,7 @@ pyproject.toml — dependencies ## Platform support -This code currently requires that you have a single NVIDIA GPU. In principle it is quite possible to support CPU, MPS and other platforms but this would also bloat the code. I'm not 100% sure that I want to take this on personally right now. People can reference (or have their agents reference) the full/parent nanochat repository that has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernels fallback implementation, generic device support, autodetection, etc.), feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README in some new notable forks section or etc. +This code currently requires that you have a single NVIDIA GPU with CUDA available. The repository now checks this explicitly and fails early with actionable guidance instead of crashing deep inside PyTorch CUDA calls. In principle it is quite possible to support CPU, MPS and other platforms but this would also bloat the code. I'm not 100% sure that I want to take this on personally right now. People can reference (or have their agents reference) the full/parent nanochat repository that has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernels fallback implementation, generic device support, autodetection, etc.), feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README in some new notable forks section or etc. Seeing as there seems to be a lot of interest in tinkering with autoresearch on much smaller compute platforms than an H100, a few extra words. If you're going to try running autoresearch on smaller computers (Macbooks etc.), I'd recommend one of the forks below. On top of this, here are some recommendations for how to tune the defaults for much smaller models for aspiring forks: diff --git a/prepare.py b/prepare.py index 62607b9af..cb1b5e5e6 100644 --- a/prepare.py +++ b/prepare.py @@ -18,6 +18,7 @@ from multiprocessing import Pool import requests +import pyarrow as pa import pyarrow.parquet as pq import rustbpe import tiktoken @@ -151,8 +152,9 @@ def train_tokenizer(): parquet_files = list_parquet_files() if len(parquet_files) < 2: - print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.") - sys.exit(1) + raise RuntimeError( + f"Tokenizer training requires at least 2 data shards (1 train + 1 val), but found {len(parquet_files)} in {DATA_DIR}. Run `uv run prepare.py --num-shards 1` or higher to download enough data." + ) # --- Train with rustbpe --- print("Tokenizer: training BPE tokenizer...") @@ -215,8 +217,18 @@ def __init__(self, enc): @classmethod def from_directory(cls, tokenizer_dir=TOKENIZER_DIR): - with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f: - enc = pickle.load(f) + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.pkl") + if not os.path.exists(tokenizer_path): + raise RuntimeError( + f"Tokenizer not found at {tokenizer_path}. Run `uv run prepare.py` first to download data and train the tokenizer." + ) + try: + with open(tokenizer_path, "rb") as f: + enc = pickle.load(f) + except (pickle.UnpicklingError, EOFError, AttributeError, ValueError, OSError) as err: + raise RuntimeError( + f"Tokenizer at {tokenizer_path} could not be loaded. Run `uv run prepare.py` again to rebuild the tokenizer cache." + ) from err return cls(enc) def get_vocab_size(self): @@ -247,23 +259,61 @@ def decode(self, ids): def get_token_bytes(device="cpu"): path = os.path.join(TOKENIZER_DIR, "token_bytes.pt") - with open(path, "rb") as f: - return torch.load(f, map_location=device) + if not os.path.exists(path): + raise RuntimeError( + f"token_bytes cache not found at {path}. Run `uv run prepare.py` first to build the tokenizer artifacts." + ) + try: + with open(path, "rb") as f: + return torch.load(f, map_location=device) + except (OSError, RuntimeError, pickle.UnpicklingError, EOFError, ValueError) as err: + raise RuntimeError( + f"token_bytes cache at {path} could not be loaded. Run `uv run prepare.py` again to rebuild the tokenizer artifacts." + ) from err + + +def require_cuda_environment(context): + if not torch.cuda.is_available(): + raise RuntimeError( + f"{context} requires a CUDA-enabled NVIDIA GPU, but PyTorch cannot access CUDA. " + "Run this repository on a machine with CUDA available, or use a platform-specific fork listed in README.md." + ) + if torch.cuda.device_count() < 1: + raise RuntimeError( + f"{context} requires one visible NVIDIA GPU, but torch.cuda.device_count() returned 0. " + "Check your NVIDIA driver, CUDA runtime, container GPU passthrough, and CUDA_VISIBLE_DEVICES." + ) def _document_batches(split, tokenizer_batch_size=128): """Infinite iterator over document batches from parquet files.""" parquet_paths = list_parquet_files() - assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first." + if len(parquet_paths) == 0: + raise RuntimeError( + f"No parquet data shards found in {DATA_DIR}. Run `uv run prepare.py` first to download the dataset and train the tokenizer." + ) val_path = os.path.join(DATA_DIR, VAL_FILENAME) if split == "train": parquet_paths = [p for p in parquet_paths if p != val_path] + if len(parquet_paths) == 0: + raise RuntimeError( + f"No training parquet shards found in {DATA_DIR}. Run `uv run prepare.py --num-shards 1` or higher to download training data." + ) else: + if not os.path.exists(val_path): + raise RuntimeError( + f"Validation shard not found at {val_path}. Run `uv run prepare.py` first to download the pinned validation shard." + ) parquet_paths = [val_path] epoch = 1 while True: for filepath in parquet_paths: - pf = pq.ParquetFile(filepath) + try: + pf = pq.ParquetFile(filepath) + except (OSError, ValueError, TypeError, pa.ArrowInvalid, pa.ArrowIOError) as err: + raise RuntimeError( + f"Parquet shard at {filepath} could not be opened. Run `uv run prepare.py` again to rebuild the cached dataset shards." + ) from err for rg_idx in range(pf.num_row_groups): rg = pf.read_row_group(rg_idx) batch = rg.column('text').to_pylist() @@ -279,6 +329,7 @@ def make_dataloader(tokenizer, B, T, split, buffer_size=1000): When no document fits remaining space, crops shortest doc to fill exactly. 100% utilization (no padding). """ + require_cuda_environment("make_dataloader") assert split in ["train", "val"] row_capacity = T + 1 batches = _document_batches(split) @@ -348,6 +399,7 @@ def evaluate_bpb(model, tokenizer, batch_size): are excluded from both sums. Uses fixed MAX_SEQ_LEN so results are comparable across configs. """ + require_cuda_environment("evaluate_bpb") token_bytes = get_token_bytes(device="cuda") val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val") steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN) @@ -368,21 +420,25 @@ def evaluate_bpb(model, tokenizer, batch_size): # --------------------------------------------------------------------------- if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch") - parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.") - parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers") - args = parser.parse_args() - - num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards - - print(f"Cache directory: {CACHE_DIR}") - print() - - # Step 1: Download data - download_data(num_shards, download_workers=args.download_workers) - print() - - # Step 2: Train tokenizer - train_tokenizer() - print() - print("Done! Ready to train.") + try: + parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch") + parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.") + parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers") + args = parser.parse_args() + + num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards + + print(f"Cache directory: {CACHE_DIR}") + print() + + # Step 1: Download data + download_data(num_shards, download_workers=args.download_workers) + print() + + # Step 2: Train tokenizer + train_tokenizer() + print() + print("Done! Ready to train.") + except RuntimeError as err: + print(f"ERROR: {err}", file=sys.stderr) + raise SystemExit(1) diff --git a/train.py b/train.py index 6994fb9bb..76ffce107 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" import gc +import sys import time from dataclasses import dataclass, asdict @@ -16,11 +17,23 @@ import torch.nn as nn import torch.nn.functional as F -from kernels import get_kernel -cap = torch.cuda.get_device_capability() -# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs -repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" -fa3 = get_kernel(repo).flash_attn_interface + +def require_cuda_environment(context): + if not torch.cuda.is_available(): + raise RuntimeError( + f"{context} requires a CUDA-enabled NVIDIA GPU, but PyTorch cannot access CUDA. " + "Install a CUDA-enabled PyTorch build and run on a machine with an available NVIDIA GPU. " + "If you are on CPU, Apple Silicon, or another unsupported platform, use a platform-specific fork instead; see README.md for notable forks." + ) + if torch.cuda.device_count() < 1: + raise RuntimeError( + f"{context} requires one visible NVIDIA GPU, but torch.cuda.device_count() returned 0. " + "Check your NVIDIA driver, CUDA runtime, container GPU passthrough, and CUDA_VISIBLE_DEVICES." + ) + return torch.cuda.get_device_capability() + + +fa3 = None from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb @@ -424,6 +437,119 @@ def step(self): elif group['kind'] == 'muon': self._step_muon(group) +# --------------------------------------------------------------------------- +# Checkpoint utilities +# --------------------------------------------------------------------------- + +CHECKPOINT_DIR = "/tmp/autoresearch_checkpoints" +PRE_EVAL_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "pre_eval.pt") + +# Evaluation retry configuration +MAX_RETRY_ATTEMPTS = 3 # Maximum number of retry attempts on evaluation failure +MIN_BATCH_SIZE = 4 # Minimum batch size to try before giving up + + +def is_oom_error(error): + """Check if the error is an out-of-memory error that might succeed with smaller batch.""" + error_str = str(error).lower() + oom_keywords = [ + "out of memory", "oom", "cuda out of memory", "cudamalloc", + "memory error", "allocate", "allocation", "memory exhausted", + "memoryerror", "outofmemoryerror", "cuda error", + "illegal address", "assertion", "launch failed", + "cudart", "cuda_runtime", "cudaerror", "torch.cuda.OutOfMemoryError", + "runtimeerror: cuda", "allocation failed", "cudaallocate", + "device-side assertion", "view size", "buffer" + ] + return any(keyword in error_str for keyword in oom_keywords) + + +def get_gpu_memory_info(): + """Get current GPU memory usage information in MB.""" + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / 1024 / 1024 + reserved = torch.cuda.memory_reserved() / 1024 / 1024 + max_allocated = torch.cuda.max_memory_allocated() / 1024 / 1024 + return { + "allocated_mb": allocated, + "reserved_mb": reserved, + "max_allocated_mb": max_allocated + } + return None + + +def clear_gpu_memory(): + """Clear GPU memory cache and run garbage collection.""" + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + except Exception as e: + print(f"Warning: torch.cuda.empty_cache() failed: {e}", file=sys.stderr) + try: + gc.collect() + except Exception as e: + print(f"Warning: gc.collect() failed: {e}", file=sys.stderr) + + +def reset_gpu_memory_stats(): + """Reset GPU memory statistics for accurate memory diagnostics.""" + try: + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + except Exception as e: + print(f"Warning: torch.cuda.reset_peak_memory_stats() failed: {e}", file=sys.stderr) + + +def save_checkpoint(model, optimizer, step, total_training_time, smooth_train_loss): + """Save training checkpoint before evaluation.""" + try: + os.makedirs(CHECKPOINT_DIR, exist_ok=True) + torch.save({ + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "step": step, + "total_training_time": total_training_time, + "smooth_train_loss": smooth_train_loss, + }, PRE_EVAL_CHECKPOINT) + print(f"Checkpoint saved to {PRE_EVAL_CHECKPOINT}") + except (OSError, IOError, RuntimeError) as e: + print(f"Warning: Failed to save checkpoint: {e}", file=sys.stderr) + + +def load_checkpoint(model, optimizer): + """Load training checkpoint. Returns (step, total_training_time, smooth_train_loss) or None on failure.""" + if not os.path.exists(PRE_EVAL_CHECKPOINT): + return None + try: + # Clear GPU memory before loading checkpoint + clear_gpu_memory() + + checkpoint = torch.load(PRE_EVAL_CHECKPOINT, map_location="cuda") + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + # Ensure model is in eval mode after loading checkpoint + model.eval() + + print(f"Checkpoint loaded from {PRE_EVAL_CHECKPOINT}") + return checkpoint["step"], checkpoint["total_training_time"], checkpoint["smooth_train_loss"] + except (OSError, RuntimeError, KeyError) as e: + print(f"Warning: Failed to load checkpoint: {e}", file=sys.stderr) + return None + + +def delete_checkpoint(): + """Delete pre-evaluation checkpoint to free disk space.""" + try: + if os.path.exists(PRE_EVAL_CHECKPOINT): + os.remove(PRE_EVAL_CHECKPOINT) + print(f"Checkpoint deleted: {PRE_EVAL_CHECKPOINT}") + except OSError as e: + print(f"Warning: Failed to delete checkpoint: {e}", file=sys.stderr) + + # --------------------------------------------------------------------------- # Hyperparameters (edit these directly, no CLI flags needed) # --------------------------------------------------------------------------- @@ -450,180 +576,298 @@ def step(self): DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM) # --------------------------------------------------------------------------- -# Setup: tokenizer, model, optimizer, dataloader +# Setup + training entrypoint # --------------------------------------------------------------------------- -t_start = time.time() -torch.manual_seed(42) -torch.cuda.manual_seed(42) -torch.set_float32_matmul_precision("high") -device = torch.device("cuda") -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) -H100_BF16_PEAK_FLOPS = 989.5e12 - -tokenizer = Tokenizer.from_directory() -vocab_size = tokenizer.get_vocab_size() -print(f"Vocab size: {vocab_size:,}") - -def build_model_config(depth): - base_dim = depth * ASPECT_RATIO - model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM - num_heads = model_dim // HEAD_DIM - return GPTConfig( - sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size, - n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, - window_pattern=WINDOW_PATTERN, +def main(): + global fa3 + + cap = require_cuda_environment("train.py") + + from kernels import get_kernel + # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + + t_start = time.time() + torch.manual_seed(42) + torch.cuda.manual_seed(42) + torch.set_float32_matmul_precision("high") + device = torch.device("cuda") + autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + H100_BF16_PEAK_FLOPS = 989.5e12 + + tokenizer = Tokenizer.from_directory() + vocab_size = tokenizer.get_vocab_size() + print(f"Vocab size: {vocab_size:,}") + + def build_model_config(depth): + base_dim = depth * ASPECT_RATIO + model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM + num_heads = model_dim // HEAD_DIM + return GPTConfig( + sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size, + n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, + window_pattern=WINDOW_PATTERN, + ) + + config = build_model_config(DEPTH) + print(f"Model config: {asdict(config)}") + + with torch.device("meta"): + model = GPT(config) + model.to_empty(device=device) + model.init_weights() + + param_counts = model.num_scaling_params() + print("Parameter counts:") + for key, value in param_counts.items(): + print(f" {key:24s}: {value:,}") + num_params = param_counts['total'] + num_flops_per_token = model.estimate_flops() + print(f"Estimated FLOPs per token: {num_flops_per_token:e}") + + tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN + assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0 + grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd + + optimizer = model.setup_optimizer( + unembedding_lr=UNEMBEDDING_LR, + embedding_lr=EMBEDDING_LR, + scalar_lr=SCALAR_LR, + adam_betas=ADAM_BETAS, + matrix_lr=MATRIX_LR, + weight_decay=WEIGHT_DECAY, ) -config = build_model_config(DEPTH) -print(f"Model config: {asdict(config)}") - -with torch.device("meta"): - model = GPT(config) -model.to_empty(device=device) -model.init_weights() - -param_counts = model.num_scaling_params() -print("Parameter counts:") -for key, value in param_counts.items(): - print(f" {key:24s}: {value:,}") -num_params = param_counts['total'] -num_flops_per_token = model.estimate_flops() -print(f"Estimated FLOPs per token: {num_flops_per_token:e}") - -tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN -assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0 -grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd - -optimizer = model.setup_optimizer( - unembedding_lr=UNEMBEDDING_LR, - embedding_lr=EMBEDDING_LR, - scalar_lr=SCALAR_LR, - adam_betas=ADAM_BETAS, - matrix_lr=MATRIX_LR, - weight_decay=WEIGHT_DECAY, -) - -model = torch.compile(model, dynamic=False) - -train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train") -x, y, epoch = next(train_loader) # prefetch first batch - -print(f"Time budget: {TIME_BUDGET}s") -print(f"Gradient accumulation steps: {grad_accum_steps}") - -# Schedules (all based on progress = training_time / TIME_BUDGET) - -def get_lr_multiplier(progress): - if progress < WARMUP_RATIO: - return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0 - elif progress < 1.0 - WARMDOWN_RATIO: - return 1.0 - else: - cooldown = (1.0 - progress) / WARMDOWN_RATIO - return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC - -def get_muon_momentum(step): - frac = min(step / 300, 1) - return (1 - frac) * 0.85 + frac * 0.95 - -def get_weight_decay(progress): - return WEIGHT_DECAY * (1 - progress) - -# --------------------------------------------------------------------------- -# Training loop -# --------------------------------------------------------------------------- - -t_start_training = time.time() -smooth_train_loss = 0 -total_training_time = 0 -step = 0 - -while True: - torch.cuda.synchronize() - t0 = time.time() - for micro_step in range(grad_accum_steps): - with autocast_ctx: - loss = model(x, y) - train_loss = loss.detach() - loss = loss / grad_accum_steps - loss.backward() - x, y, epoch = next(train_loader) - - # Progress and schedules - progress = min(total_training_time / TIME_BUDGET, 1.0) - lrm = get_lr_multiplier(progress) - muon_momentum = get_muon_momentum(step) - muon_weight_decay = get_weight_decay(progress) - for group in optimizer.param_groups: - group["lr"] = group["initial_lr"] * lrm - if group['kind'] == 'muon': - group["momentum"] = muon_momentum - group["weight_decay"] = muon_weight_decay - optimizer.step() - model.zero_grad(set_to_none=True) - - train_loss_f = train_loss.item() - - # Fast fail: abort if loss is exploding - if train_loss_f > 100: - print("FAIL") - exit(1) - - torch.cuda.synchronize() - t1 = time.time() - dt = t1 - t0 - - if step > 10: - total_training_time += dt - - # Logging - ema_beta = 0.9 - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) - pct_done = 100 * progress - tok_per_sec = int(TOTAL_BATCH_SIZE / dt) - mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS - remaining = max(0, TIME_BUDGET - total_training_time) - - print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True) - - # GC management (Python's GC causes ~500ms stalls) - if step == 0: - gc.collect() - gc.freeze() - gc.disable() - elif (step + 1) % 5000 == 0: - gc.collect() - - step += 1 - - # Time's up — but only stop after warmup steps so we don't count compilation - if step > 10 and total_training_time >= TIME_BUDGET: - break - -print() # newline after \r training log - -total_tokens = step * TOTAL_BATCH_SIZE - -# Final eval -model.eval() -with autocast_ctx: - val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE) - -# Final summary -t_end = time.time() -startup_time = t_start_training - t_start -steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0 -peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 - -print("---") -print(f"val_bpb: {val_bpb:.6f}") -print(f"training_seconds: {total_training_time:.1f}") -print(f"total_seconds: {t_end - t_start:.1f}") -print(f"peak_vram_mb: {peak_vram_mb:.1f}") -print(f"mfu_percent: {steady_state_mfu:.2f}") -print(f"total_tokens_M: {total_tokens / 1e6:.1f}") -print(f"num_steps: {step}") -print(f"num_params_M: {num_params / 1e6:.1f}") -print(f"depth: {DEPTH}") + model = torch.compile(model, dynamic=False) + + train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train") + x, y, epoch = next(train_loader) # prefetch first batch + + print(f"Time budget: {TIME_BUDGET}s") + print(f"Gradient accumulation steps: {grad_accum_steps}") + + # Schedules (all based on progress = training_time / TIME_BUDGET) + def get_lr_multiplier(progress): + if progress < WARMUP_RATIO: + return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0 + elif progress < 1.0 - WARMDOWN_RATIO: + return 1.0 + else: + cooldown = (1.0 - progress) / WARMDOWN_RATIO + return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC + + def get_muon_momentum(step): + frac = min(step / 300, 1) + return (1 - frac) * 0.85 + frac * 0.95 + + def get_weight_decay(progress): + return WEIGHT_DECAY * (1 - progress) + + # --------------------------------------------------------------------------- + # Training loop + # --------------------------------------------------------------------------- + + t_start_training = time.time() + smooth_train_loss = 0 + total_training_time = 0 + step = 0 + + while True: + torch.cuda.synchronize() + t0 = time.time() + for micro_step in range(grad_accum_steps): + with autocast_ctx: + loss = model(x, y) + train_loss = loss.detach() + loss = loss / grad_accum_steps + loss.backward() + x, y, epoch = next(train_loader) + + # Progress and schedules + progress = min(total_training_time / TIME_BUDGET, 1.0) + lrm = get_lr_multiplier(progress) + muon_momentum = get_muon_momentum(step) + muon_weight_decay = get_weight_decay(progress) + for group in optimizer.param_groups: + group["lr"] = group["initial_lr"] * lrm + if group['kind'] == 'muon': + group["momentum"] = muon_momentum + group["weight_decay"] = muon_weight_decay + optimizer.step() + model.zero_grad(set_to_none=True) + + train_loss_f = train_loss.item() + + # Fast fail: abort if loss is exploding + if train_loss_f > 100: + print("FAIL") + raise SystemExit(1) + + torch.cuda.synchronize() + t1 = time.time() + dt = t1 - t0 + + if step > 10: + total_training_time += dt + + # Logging + ema_beta = 0.9 + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) + pct_done = 100 * progress + tok_per_sec = int(TOTAL_BATCH_SIZE / dt) + mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS + remaining = max(0, TIME_BUDGET - total_training_time) + + print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True) + + # GC management (Python's GC causes ~500ms stalls) + if step == 0: + gc.collect() + gc.freeze() + gc.disable() + elif (step + 1) % 5000 == 0: + gc.collect() + + step += 1 + + # Time's up — but only stop after warmup steps so we don't count compilation + if step > 10 and total_training_time >= TIME_BUDGET: + break + + print() # newline after \r training log + + total_tokens = step * TOTAL_BATCH_SIZE + + # Final eval - save checkpoint before evaluation to preserve training results + # in case evaluation crashes (OOM, CUDA error, etc.) + model.eval() + print("Saving pre-evaluation checkpoint...") + save_checkpoint(model, optimizer, step, total_training_time, smooth_train_loss) + + # Reset GPU memory stats before evaluation to get accurate diagnostics + reset_gpu_memory_stats() + + # Try evaluation with retry mechanism; if it fails (OOM, CUDA error, etc.), + # training results are preserved in the checkpoint and can be recovered + eval_failed = False + final_val_bpb = None + + # First attempt with original batch size + current_batch_size = DEVICE_BATCH_SIZE + attempt = 0 + + while attempt < MAX_RETRY_ATTEMPTS: + attempt += 1 + # Output GPU memory info before each evaluation attempt + mem_info = get_gpu_memory_info() + if mem_info: + print(f"GPU memory before evaluation (attempt {attempt}): allocated={mem_info['allocated_mb']:.1f}MB, reserved={mem_info['reserved_mb']:.1f}MB, max_allocated={mem_info['max_allocated_mb']:.1f}MB") + + try: + with autocast_ctx: + val_bpb = evaluate_bpb(model, tokenizer, current_batch_size) + # Evaluation succeeded + final_val_bpb = val_bpb + if attempt > 1: + print(f"Evaluation succeeded on attempt {attempt} with batch size {current_batch_size}: val_bpb={val_bpb:.6f}") + break + except RuntimeError as eval_error: + error_type = "OOM/Retryable" if is_oom_error(eval_error) else "RuntimeError" + print(f"Evaluation attempt {attempt}/{MAX_RETRY_ATTEMPTS} failed ({error_type}): {eval_error}", file=sys.stderr) + + # Output GPU memory info after failure + mem_info_after = get_gpu_memory_info() + if mem_info_after: + print(f"GPU memory after failure: allocated={mem_info_after['allocated_mb']:.1f}MB, reserved={mem_info_after['reserved_mb']:.1f}MB", file=sys.stderr) + + # Check if we should retry + should_retry = ( + is_oom_error(eval_error) and + current_batch_size > MIN_BATCH_SIZE and + attempt < MAX_RETRY_ATTEMPTS + ) + + if should_retry: + # Wait for GPU memory to be released before retrying + wait_time = 2 * attempt # Progressive wait: 2s, 4s, 6s + print(f"Waiting {wait_time}s for GPU memory to be released...") + time.sleep(wait_time) + + # Reset GPU memory stats for accurate diagnostics on next attempt + reset_gpu_memory_stats() + + # Clear GPU memory before loading checkpoint and retrying + print("Clearing GPU memory before retry...") + clear_gpu_memory() + + # Load checkpoint to recover training state before retry + print("Loading checkpoint to recover training results...") + recovered = load_checkpoint(model, optimizer) + if recovered is not None: + step, total_training_time, smooth_train_loss = recovered + print(f"Recovered: step={step}, training_time={total_training_time:.1f}s") + + # Ensure model is in eval mode for evaluation + model.eval() + + # Decrease batch size for next attempt + current_batch_size = max(MIN_BATCH_SIZE, current_batch_size // 2) + print(f"Retrying evaluation with smaller batch size ({current_batch_size})...") + + # Output GPU memory after checkpoint restore + mem_info_restore = get_gpu_memory_info() + if mem_info_restore: + print(f"GPU memory after checkpoint restore: allocated={mem_info_restore['allocated_mb']:.1f}MB, reserved={mem_info_restore['reserved_mb']:.1f}MB") + else: + # Failed to load checkpoint, cannot retry + print("Failed to load checkpoint, cannot retry evaluation", file=sys.stderr) + eval_failed = True + break + else: + # Cannot retry - either not an OOM error or we've exhausted retries + if not is_oom_error(eval_error): + print("Non-OOM error detected, not retrying", file=sys.stderr) + if current_batch_size <= MIN_BATCH_SIZE: + print(f"Batch size already at minimum ({MIN_BATCH_SIZE}), not retrying", file=sys.stderr) + eval_failed = True + break + + # Clean up checkpoint file after successful evaluation + if final_val_bpb is not None: + delete_checkpoint() + val_bpb = final_val_bpb + + if eval_failed or final_val_bpb is None: + # If evaluation failed, report failure but preserve checkpoint for recovery + print("FAIL: Evaluation crashed after all retry attempts - training results saved in checkpoint") + print(f"Checkpoint available at: {PRE_EVAL_CHECKPOINT}") + print(f"To recover, manually load from checkpoint and re-run evaluation") + raise SystemExit(1) + + # Final summary + t_end = time.time() + startup_time = t_start_training - t_start + steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0 + peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 + + print("---") + print(f"val_bpb: {val_bpb:.6f}") + print(f"training_seconds: {total_training_time:.1f}") + print(f"total_seconds: {t_end - t_start:.1f}") + print(f"peak_vram_mb: {peak_vram_mb:.1f}") + print(f"mfu_percent: {steady_state_mfu:.2f}") + print(f"total_tokens_M: {total_tokens / 1e6:.1f}") + print(f"num_steps: {step}") + print(f"num_params_M: {num_params / 1e6:.1f}") + print(f"depth: {DEPTH}") + + +if __name__ == "__main__": + try: + main() + except RuntimeError as err: + print(f"ERROR: {err}", file=sys.stderr) + raise SystemExit(1)