# feat: Multi-GPU Data Parallel (DDP) for training #117
base: master
Changes from all commits
@@ -10,11 +10,37 @@

```python
import gc
import time
import contextlib
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# ---------------------------------------------------------------------------
# DDP Setup
# ---------------------------------------------------------------------------
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    dist.init_process_group("nccl")
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = torch.device(f'cuda:{ddp_local_rank}')
    torch.cuda.set_device(device)
    master_process = (ddp_rank == 0)
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = torch.device("cuda")

if not master_process:
    import builtins
    builtins.print = lambda *args, **kwargs: None

from kernels import get_kernel
cap = torch.cuda.get_device_capability()
```
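The branch above keys off the environment variables that `torchrun` sets for every worker (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`); a plain `python train.py` run has none of them and falls back to single-process mode. A minimal, CUDA-free sketch of the same detection logic (the function name `ddp_env` is hypothetical, not from the PR):

```python
def ddp_env(env):
    """Mirror the script's DDP detection: torchrun sets RANK/LOCAL_RANK/WORLD_SIZE
    for each worker; with none of them present we fall back to a single process."""
    ddp = int(env.get('RANK', -1)) != -1
    if ddp:
        rank = int(env['RANK'])
        local_rank = int(env['LOCAL_RANK'])
        world_size = int(env['WORLD_SIZE'])
    else:
        rank, local_rank, world_size = 0, 0, 1
    # last element: is this the master process (rank 0)?
    return ddp, rank, local_rank, world_size, rank == 0

print(ddp_env({}))  # (False, 0, 0, 1, True) — single-process fallback
# worker 1 of `torchrun --nproc_per_node=2 train.py`:
print(ddp_env({'RANK': '1', 'LOCAL_RANK': '1', 'WORLD_SIZE': '2'}))
# (True, 1, 1, 2, False)
```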
@@ -454,10 +480,9 @@ def step(self):

```diff
 # ---------------------------------------------------------------------------

 t_start = time.time()
-torch.manual_seed(42)
-torch.cuda.manual_seed(42)
+torch.manual_seed(42 + ddp_rank)
+torch.cuda.manual_seed(42 + ddp_rank)
```
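Offsetting the base seed by the rank gives each worker its own deterministic RNG stream, which is what you want for rank-local randomness such as data shuffling. A small illustration of the idea using the standard library (torch-free; `rank_rng` is a hypothetical helper, not from the PR):

```python
import random

def rank_rng(base_seed, rank):
    # per-rank stream, analogous to torch.manual_seed(42 + ddp_rank)
    return random.Random(base_seed + rank)

# each of 4 ranks draws from its own stream: all draws differ...
draws = [rank_rng(42, r).random() for r in range(4)]
assert len(set(draws)) == 4

# ...yet each rank's stream is reproducible across restarts
assert rank_rng(42, 3).random() == rank_rng(42, 3).random()
```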
**Copilot AI** (Mar 10, 2026) commented on lines +483 to +484, with a suggested change:

```suggestion
torch.manual_seed(42)
torch.cuda.manual_seed(42)
```
The dataloader wrapper makes every rank produce all batches from the underlying generator (including the expensive best-fit packing work), then discard `(world_size - 1) / world_size` of them. For example, with 8 GPUs each rank does 8× the CPU data-loading work but uses only 1/8 of it. This effectively scales CPU cost by `world_size` and can become a bottleneck on large clusters.

Consider passing the rank and world size into `make_dataloader` (or its underlying `_document_batches`) to shard the parquet files or row groups at the source, so each rank only loads its own slice of the data.
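The source-level sharding suggested above can be sketched as a round-robin split over file paths (the function name `shard_files` is hypothetical, not from the PR):

```python
def shard_files(paths, rank, world_size):
    """Round-robin sharding: rank r keeps files r, r + world_size, r + 2*world_size, ...
    Every file is read by exactly one rank, so per-rank CPU work drops by world_size."""
    return paths[rank::world_size]

files = [f"part-{i:05d}.parquet" for i in range(10)]
# with 4 ranks, rank 1 reads only parts 1, 5, 9
print(shard_files(files, 1, 4))
# ['part-00001.parquet', 'part-00005.parquet', 'part-00009.parquet']
```

Sharding row groups instead of whole files follows the same pattern and balances better when file sizes are uneven.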
**Unused import:** `contextlib` is imported but never used anywhere in the file. It should be removed to keep the imports clean.