feat: Multi-GPU Data Parallel (DDP) for training #117
aniruddhaadak80 wants to merge 1 commit into karpathy:master
Conversation
Pull request overview
This PR adds DistributedDataParallel (DDP) support to train.py, enabling multi-GPU training via torchrun. It sets up the distributed process group, adjusts random seeds per rank, scales gradient accumulation for the world size, wraps the model with DDP, implements a dataloader sharding wrapper, suppresses prints on non-master ranks, and disables gradient synchronization on non-final micro-steps for efficiency.
Changes:
- Add DDP initialization/teardown with NCCL backend, per-rank device assignment, and print suppression on non-master processes
- Scale gradient accumulation steps and total batch size to account for ddp_world_size, and use require_backward_grad_sync to skip the allreduce on intermediate micro-steps
- Implement a generator-based dataloader wrapper to shard training data across DDP ranks by interleaving
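The no-sync pattern from the second bullet can be sketched without torch, using a stub in place of the DDP-wrapped model to show when the flag flips (the stub class and grad_accum_steps value are illustrative, not from the PR):

```python
class StubDDPModel:
    """Stand-in for a DDP-wrapped model: records the sync flag at each backward."""
    def __init__(self):
        self.require_backward_grad_sync = True
        self.sync_log = []

    def backward_step(self):
        # A real DDP model would all-reduce gradients only when the flag is True.
        self.sync_log.append(self.require_backward_grad_sync)

grad_accum_steps = 4  # illustrative value
model = StubDDPModel()
for micro_step in range(grad_accum_steps):
    # Synchronize (all-reduce) only on the final micro-step of the window.
    model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    model.backward_step()

print(model.sync_log)  # → [False, False, False, True]
```

With this pattern, gradients are averaged across ranks once per optimizer step instead of once per micro-step, so accumulation does not multiply network traffic.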
```python
torch.manual_seed(42 + ddp_rank)
torch.cuda.manual_seed(42 + ddp_rank)
```
Critical bug: Setting per-rank seeds (42 + ddp_rank) before model initialization (line 509: model.init_weights()) means each DDP rank will initialize the model with different random weights. DDP requires all ranks to start with identical parameters — it only synchronizes gradients, not initial weights. This will cause training to diverge immediately.
The fix is to use the same seed across all ranks for model initialization, and only diverge seeds afterward (before the dataloader) to ensure each rank sees different data. For example, set torch.manual_seed(42) before init_weights(), then set torch.manual_seed(42 + ddp_rank) after model initialization and before the dataloader setup.
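The ordering this comment asks for can be illustrated with Python's stdlib `random` as a stand-in for torch's RNG (the helper name and list sizes are illustrative):

```python
import random

def init_and_load(ddp_rank, seed_base=42):
    # Same seed on every rank BEFORE "model init" -> identical initial weights.
    random.seed(seed_base)
    fake_init_weights = [random.random() for _ in range(3)]
    # Per-rank seed AFTER init, before the dataloader -> different data per rank.
    random.seed(seed_base + ddp_rank)
    fake_data_order = [random.random() for _ in range(3)]
    return fake_init_weights, fake_data_order

w0, d0 = init_and_load(ddp_rank=0)
w1, d1 = init_and_load(ddp_rank=1)
assert w0 == w1  # identical "model init" across ranks
assert d0 != d1  # distinct data streams per rank
```

The same ordering applies with `torch.manual_seed`: fixed seed before `init_weights()`, per-rank seed before the dataloader.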
```diff
- torch.manual_seed(42 + ddp_rank)
- torch.cuda.manual_seed(42 + ddp_rank)
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
```
```python
import gc
import time
import contextlib
```
Unused import: contextlib is imported but never used anywhere in the file. It should be removed to keep the imports clean.
```diff
- import contextlib
```
```python
def ddp_dataloader_wrapper(loader, rank, world_size):
    for _ in range(rank):
        next(loader)
    while True:
        yield next(loader)
        for _ in range(world_size - 1):
            next(loader)
```
The dataloader wrapper makes every rank produce all batches from the underlying generator (including the expensive best-fit packing work), then discard (world_size - 1) / world_size of them. For example, with 8 GPUs, each rank does 8× the CPU data-loading work but only uses 1/8 of it. This effectively scales CPU cost by world_size and can become a bottleneck on large clusters.
Consider passing the rank/world_size into make_dataloader (or its underlying _document_batches) to shard the parquet files or row groups at the source, so each rank only loads its own slice of the data.
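One way to shard at the source, sketched with a hypothetical list of parquet file paths (the internals of make_dataloader / _document_batches are not shown in this thread, so the function below is an assumption about the shape of the fix):

```python
def shard_files(files, rank, world_size):
    """Give each DDP rank a disjoint, interleaved slice of the input files.

    Each rank then only opens and packs its own files, so CPU data-loading
    cost no longer scales with world_size.
    """
    return files[rank::world_size]

files = [f"shard_{i:04d}.parquet" for i in range(10)]  # hypothetical paths
per_rank = [shard_files(files, r, world_size=4) for r in range(4)]

# The slices are disjoint and together cover every file exactly once:
flat = sorted(f for slice_ in per_rank for f in slice_)
assert flat == sorted(files)
print(per_rank[0])  # → ['shard_0000.parquet', 'shard_0004.parquet', 'shard_0008.parquet']
```

Interleaving by file keeps per-rank work roughly balanced when files are similarly sized; sharding by row group would balance better when file sizes vary widely.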
|
@copilot open a new pull request to apply changes based on the comments in this thread
Summary

Introduces a torch.distributed (DDP) setup to natively run the training script on multi-GPU machines using torchrun.

Key Capabilities

- Converts the train.py architecture into a data-parallel approach.
- Assigns device=cuda:local_rank so each process maps directly to its own GPU in the hardware topology.
- Limits require_backward_grad_sync strictly to the terminal micro-step, greatly increasing throughput with gradient accumulation by not stalling compute on network communication.
- Suppresses output on non-master ranks (builtins.print disabled) to prevent terminal flooding.
- Runs out of the box with zero modifications required inside prepare.py.

Usage
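A typical launch command, assuming a single node with 8 GPUs (the flag values here are illustrative, not taken from the PR — adjust --nproc_per_node to your GPU count):

```shell
# Launch one process per GPU on a single node; torchrun sets RANK,
# LOCAL_RANK, and WORLD_SIZE, which train.py's DDP setup reads.
torchrun --standalone --nproc_per_node=8 train.py
```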