diff --git a/train.py b/train.py
index 6994fb9b..8d8dbfc8 100644
--- a/train.py
+++ b/train.py
@@ -10,11 +10,37 @@ import gc
 import time
+import contextlib
 from dataclasses import dataclass, asdict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+
+# ---------------------------------------------------------------------------
+# DDP Setup
+# ---------------------------------------------------------------------------
+ddp = int(os.environ.get('RANK', -1)) != -1
+if ddp:
+    dist.init_process_group("nccl")
+    ddp_rank = int(os.environ['RANK'])
+    ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
+    device = torch.device(f'cuda:{ddp_local_rank}')
+    torch.cuda.set_device(device)
+    master_process = (ddp_rank == 0)
+else:
+    ddp_rank = 0
+    ddp_local_rank = 0
+    ddp_world_size = 1
+    master_process = True
+    device = torch.device("cuda")
+
+if not master_process:
+    import builtins
+    builtins.print = lambda *args, **kwargs: None
 
 from kernels import get_kernel
 
 cap = torch.cuda.get_device_capability()
 
@@ -454,10 +480,9 @@ def step(self):
 # ---------------------------------------------------------------------------
 
 t_start = time.time()
-torch.manual_seed(42)
-torch.cuda.manual_seed(42)
+torch.manual_seed(42 + ddp_rank)
+torch.cuda.manual_seed(42 + ddp_rank)
 torch.set_float32_matmul_precision("high")
-device = torch.device("cuda")
 autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
 
 H100_BF16_PEAK_FLOPS = 989.5e12
@@ -491,9 +516,9 @@ def build_model_config(depth):
 num_flops_per_token = model.estimate_flops()
 print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
 
-tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
-assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
-grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
+tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN * ddp_world_size
+grad_accum_steps = max(1, TOTAL_BATCH_SIZE // tokens_per_fwdbwd)
+TOTAL_BATCH_SIZE = grad_accum_steps * tokens_per_fwdbwd
 
 optimizer = model.setup_optimizer(
     unembedding_lr=UNEMBEDDING_LR,
@@ -504,9 +529,22 @@ def build_model_config(depth):
     weight_decay=WEIGHT_DECAY,
 )
 
+if ddp:
+    model = DDP(model, device_ids=[ddp_local_rank])
 model = torch.compile(model, dynamic=False)
 
 train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
+
+def ddp_dataloader_wrapper(loader, rank, world_size):
+    for _ in range(rank):
+        next(loader)
+    while True:
+        yield next(loader)
+        for _ in range(world_size - 1):
+            next(loader)
+
+train_loader = ddp_dataloader_wrapper(train_loader, ddp_rank, ddp_world_size)
+
 x, y, epoch = next(train_loader)  # prefetch first batch
 
 print(f"Time budget: {TIME_BUDGET}s")
@@ -543,6 +581,8 @@ def get_weight_decay(progress):
     torch.cuda.synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
+        if ddp:
+            model._orig_mod.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
         with autocast_ctx:
             loss = model(x, y)
             train_loss = loss.detach()
@@ -627,3 +667,6 @@ def get_weight_decay(progress):
 print(f"num_steps: {step}")
 print(f"num_params_M: {num_params / 1e6:.1f}")
 print(f"depth: {DEPTH}")
+
+if ddp:
+    dist.destroy_process_group()