Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,37 @@

import gc
import time
import contextlib
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused import: contextlib is imported but never used anywhere in the file. It should be removed to keep the imports clean.

Suggested change
import contextlib

Copilot uses AI. Check for mistakes.
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# ---------------------------------------------------------------------------
# DDP Setup
# ---------------------------------------------------------------------------
# torchrun exports RANK/LOCAL_RANK/WORLD_SIZE; their absence means a plain
# single-process, single-GPU run.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    dist.init_process_group("nccl")
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    # Pin this process to its own GPU so collective ops don't collide.
    device = torch.device(f'cuda:{ddp_local_rank}')
    torch.cuda.set_device(device)
    master_process = (ddp_rank == 0)
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = torch.device("cuda")

# Silence stdout on non-master ranks so logs are printed exactly once.
if not master_process:
    import builtins
    builtins.print = lambda *args, **kwargs: None

from kernels import get_kernel
cap = torch.cuda.get_device_capability()
Expand Down Expand Up @@ -454,10 +480,9 @@ def step(self):
# ---------------------------------------------------------------------------

t_start = time.time()
# All ranks MUST seed identically before the model is constructed: DDP only
# synchronizes gradients, never initial parameters, so per-rank seeds here
# (e.g. 42 + ddp_rank) would give every rank different starting weights and
# training would diverge immediately. Per-rank data diversity is provided by
# the dataloader sharding below, not by the RNG seed.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
Comment on lines +483 to +484
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: Setting per-rank seeds (42 + ddp_rank) before model initialization (line 509: model.init_weights()) means each DDP rank will initialize the model with different random weights. DDP requires all ranks to start with identical parameters — it only synchronizes gradients, not initial weights. This will cause training to diverge immediately.

The fix is to use the same seed across all ranks for model initialization, and only diverge seeds afterward (before the dataloader) to ensure each rank sees different data. For example, set torch.manual_seed(42) before init_weights(), then set torch.manual_seed(42 + ddp_rank) after model initialization and before the dataloader setup.

Suggested change
torch.manual_seed(42 + ddp_rank)
torch.cuda.manual_seed(42 + ddp_rank)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

Copilot uses AI. Check for mistakes.
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
H100_BF16_PEAK_FLOPS = 989.5e12

Expand Down Expand Up @@ -491,9 +516,9 @@ def build_model_config(depth):
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Global tokens contributed by one forward/backward across ALL ranks.
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN * ddp_world_size
# At least one micro-step even when a single fwd/bwd already exceeds the
# requested total batch size.
grad_accum_steps = max(1, TOTAL_BATCH_SIZE // tokens_per_fwdbwd)
# Round the requested total to a realizable multiple — surface the change
# instead of silently training with a different effective batch size.
_actual_total = grad_accum_steps * tokens_per_fwdbwd
if _actual_total != TOTAL_BATCH_SIZE:
    print(f"Note: TOTAL_BATCH_SIZE rounded {TOTAL_BATCH_SIZE} -> {_actual_total}")
TOTAL_BATCH_SIZE = _actual_total

optimizer = model.setup_optimizer(
unembedding_lr=UNEMBEDDING_LR,
Expand All @@ -504,9 +529,22 @@ def build_model_config(depth):
weight_decay=WEIGHT_DECAY,
)

# Wrap for gradient synchronization, then compile the wrapped module.
# The optimizer was built above from the raw model's parameters, which DDP
# shares with the wrapper, so this ordering is safe.
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")

def ddp_dataloader_wrapper(loader, rank, world_size):
    """Round-robin shard a batch stream across DDP ranks.

    Rank r skips the first r batches, then yields every world_size-th batch,
    so the ranks collectively consume the stream with no overlap. Note every
    rank still *produces* all batches from the underlying loader and discards
    (world_size - 1)/world_size of them — sharding at the source would be
    cheaper on large clusters.

    Args:
        loader: iterator (or any iterable) of batches.
        rank: this process's rank in [0, world_size).
        world_size: total number of ranks.

    Yields:
        This rank's slice of the batches, ending cleanly when the
        underlying loader is exhausted.
    """
    it = iter(loader)  # no-op for iterators; also accepts plain iterables
    # PEP 479: a StopIteration escaping a generator body becomes RuntimeError,
    # so an exhausted loader must be caught explicitly to terminate cleanly.
    try:
        for _ in range(rank):
            next(it)
        while True:
            yield next(it)
            for _ in range(world_size - 1):
                next(it)
    except StopIteration:
        return
Comment on lines +538 to +544
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dataloader wrapper makes every rank produce all batches from the underlying generator (including the expensive best-fit packing work), then discard (world_size - 1) / world_size of them. For example, with 8 GPUs, each rank does 8× the CPU data-loading work but only uses 1/8 of it. This effectively scales CPU cost by world_size and can become a bottleneck on large clusters.

Consider passing the rank/world_size into make_dataloader (or its underlying _document_batches) to shard the parquet files or row groups at the source, so each rank only loads its own slice of the data.

Copilot uses AI. Check for mistakes.

# Shard the batch stream: each rank consumes every world_size-th batch.
train_loader = ddp_dataloader_wrapper(train_loader, ddp_rank, ddp_world_size)

x, y, epoch = next(train_loader)  # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
Expand Down Expand Up @@ -543,6 +581,8 @@ def get_weight_decay(progress):
torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
if ddp:
model._orig_mod.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach()
Expand Down Expand Up @@ -627,3 +667,6 @@ def get_weight_decay(progress):
# Final run summary (non-master ranks have print silenced, so this logs once).
print(f"num_steps: {step}")
print(f"num_params_M: {num_params / 1e6:.1f}")
print(f"depth: {DEPTH}")

# Tear down the NCCL process group so all ranks exit cleanly.
if ddp:
    dist.destroy_process_group()
Loading