diff --git a/CHANGELOG.md b/CHANGELOG.md index cafe980a..be107908 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # CHANGELOG +## v0.5.0 (2026-01-08) + +### Features + +- Add optimizer-aware gradients + ([`497edab`](https://github.com/EleutherAI/bergson/commit/497edab8f2ca19d8fcb1d409fbd99452a929584e)) + + ## v0.4.6 (2026-01-06) ### Bug Fixes diff --git a/README.md b/README.md index 07934e76..0423fd2d 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,9 @@ We view attribution as a counterfactual question: **_If we "unlearned" this trai # Announcements +**January 2026** +- [Experimental] Support distributing preconditioners across nodes and devices for VRAM-efficient computation through the GradientCollectorWithDistributedPreconditioners. If you would like this functionality exposed via the CLI please get in touch! https://github.com/EleutherAI/bergson/pull/100 + **October 2025** - Support bias parameter gradients in linear modules: https://github.com/EleutherAI/bergson/pull/54 - Support convolution modules: https://github.com/EleutherAI/bergson/pull/50 diff --git a/bergson/__init__.py b/bergson/__init__.py index 64ff9aab..29d0fe49 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.6" +__version__ = "0.5.0" from .collection import collect_gradients from .collector.gradient_collectors import GradientCollector diff --git a/bergson/__main__.py b/bergson/__main__.py index a4fbd59e..37dcac38 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -14,6 +14,9 @@ def validate_run_path(index_cfg: IndexConfig): """Validate the run path.""" + if index_cfg.distributed.rank != 0: + return + for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]: if not path.exists(): continue diff --git a/bergson/build.py b/bergson/build.py index d1ece8c9..628a54bf 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -12,15 +12,18 @@ from bergson.collection import collect_gradients from bergson.config import IndexConfig from bergson.data import allocate_batches +from bergson.distributed import launch_distributed_run from bergson.utils.utils import assert_type, setup_reproducibility -from bergson.utils.worker_utils import setup_model_and_peft - -from .distributed import launch_distributed_run -from .utils.worker_utils import create_processor, setup_data_pipeline +from bergson.utils.worker_utils import ( + create_processor, + setup_data_pipeline, + setup_model_and_peft, +) def build_worker( rank: int, + local_rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset, @@ -32,6 +35,8 @@ def build_worker( ---------- rank : int Distributed rank / GPU ID for this worker. + local_rank : int + Local rank / GPU ID for this worker on the node. world_size : int Total number of workers participating in the run. cfg : IndexConfig @@ -39,7 +44,7 @@ def build_worker( ds : Dataset | IterableDataset The entire dataset to be indexed. A subset is assigned to each worker. 
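For readers new to the multi-node changes in this patch, the split between a worker's global `rank` and its per-node `local_rank` recurs throughout; a minimal, self-contained sketch of the arithmetic, using hypothetical two-node, four-GPU-per-node values:

```python
# Illustrative only: how global rank, local rank, and world size relate
# for a hypothetical layout of 2 nodes with 4 processes each.
nnode, nproc_per_node = 2, 4
world_size = nnode * nproc_per_node            # 8 workers in total

for node_rank in range(nnode):
    start_rank = node_rank * nproc_per_node    # first global rank on this node
    for local_rank in range(nproc_per_node):
        rank = start_rank + local_rank         # global rank, unique across nodes
        # Workers pin their GPU with the *local* rank (torch.cuda.set_device)
        # and join the process group with the *global* rank.
        print(f"node {node_rank}: cuda:{local_rank} -> rank {rank}/{world_size}")
```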
""" - torch.cuda.set_device(rank) + torch.cuda.set_device(local_rank) # These should be set by the main process if world_size > 1: @@ -49,14 +54,14 @@ def build_worker( dist.init_process_group( "nccl", init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), + device_id=torch.device(f"cuda:{local_rank}"), rank=rank, timeout=timedelta(hours=1), world_size=world_size, ) - model, target_modules = setup_model_and_peft(cfg, rank) - processor = create_processor(model, ds, cfg, rank, target_modules) + model, target_modules = setup_model_and_peft(cfg) + processor = create_processor(model, ds, cfg, target_modules) attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules} @@ -119,6 +124,10 @@ def build(index_cfg: IndexConfig): ds = setup_data_pipeline(index_cfg) - launch_distributed_run("build", build_worker, [index_cfg, ds]) + launch_distributed_run( + "build", build_worker, [index_cfg, ds], index_cfg.distributed + ) - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + rank = index_cfg.distributed.rank + if rank == 0: + shutil.move(index_cfg.partial_run_path, index_cfg.run_path) diff --git a/bergson/collection.py b/bergson/collection.py index dbbf01a1..9f9ef54c 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -24,8 +24,6 @@ def collect_gradients( """ Compute gradients using the hooks specified in the GradientCollector. """ - if attention_cfgs is None: - attention_cfgs = {} collector = GradientCollector( model=model.base_model, # type: ignore cfg=cfg, diff --git a/bergson/collector/dist_preconditioners_gradient_collector.py b/bergson/collector/dist_preconditioners_gradient_collector.py new file mode 100644 index 00000000..41c6dd6e --- /dev/null +++ b/bergson/collector/dist_preconditioners_gradient_collector.py @@ -0,0 +1,414 @@ +import math +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +import torch.nn as nn +from datasets import Dataset, Value +from jaxtyping import Float +from torch import Tensor + +from bergson.collector.collector import HookCollectorBase +from bergson.config import IndexConfig, ReduceConfig +from bergson.data import Builder +from bergson.gradients import ( + AdafactorNormalizer, + AdamNormalizer, + LayerAdapter, +) +from bergson.process_preconditioners import process_preconditioners +from bergson.score.scorer import Scorer +from bergson.utils.utils import assert_type + + +@dataclass(kw_only=True) +class GradientCollectorWithDistributedPreconditioners(HookCollectorBase): + """ + Collects per-sample gradients from model layers and writes them to disk. + Preconditioners are distributed across nodes, and data from each node is + distributed to each preconditioner at every step. This enables the computation + of preconditioners that are too large to fit on a single device. + + - For each forward/backward hook, we compute the the gradient or a low-rank + approximation via random projections, if cfg.projection_dim is set. + - Supports also normalization via Adam or Adafactor normalizers. + - Uses Builder for index construction and gradient saving. + - Also supports Scorer for on-the-fly scoring of gradients. 
+ """ + + data: Dataset + """The dataset being processed.""" + + cfg: IndexConfig + """Configuration for gradient index.""" + + mod_grads: dict = field(default_factory=dict) + """Temporary storage for gradients during a batch, keyed by module name.""" + + reduce_cfg: ReduceConfig | None = None + """Configuration for in-run gradient reduction.""" + + builder: Builder | None = None + """Handles writing gradients to disk. Created in setup() if save_index is True.""" + + scorer: Scorer | None = None + """Optional scorer for computing scores instead of building an index.""" + + def __init__(self, *args, **kwargs): + self.data = assert_type(Dataset, kwargs["data"]) + self.cfg = assert_type(IndexConfig, kwargs["cfg"]) + + self.reduce_cfg = kwargs.get("reduce_cfg", None) + self.builder = kwargs.get("builder", None) + self.scorer = kwargs.get("scorer", None) + self.mod_grads = {} + + # Extract parent class arguments + parent_kwargs = { + k: v + for k, v in kwargs.items() + if k + in { + "model", + "filter_modules", + "target_modules", + "processor", + "attention_cfgs", + } + } + parent_kwargs["filter_modules"] = self.cfg.filter_modules + + super().__init__(*args, **parent_kwargs) + + def setup(self) -> None: + """ + Initialize collector state. + + Sets up a Builder for gradient storage if not using a Scorer. + """ + assert isinstance( + self.model.device, torch.device + ), "Model device is not set correctly" + if self.cfg.include_bias and self.processor.normalizers is not None: + raise NotImplementedError( + "Bias with normalizers not supported yet, " + "consider disabling bias inclusion for now." + ) + + self.owned_modules: set[str] = set() + self.module_to_rank: dict[str, int] = {} + + # TODO: handle more elegantly? + self.save_dtype = ( + torch.float32 if self.model.dtype == torch.float32 else torch.float16 + ) + + self.lo = torch.finfo(self.save_dtype).min + self.hi = torch.finfo(self.save_dtype).max + + self.per_doc_losses = torch.full( + (len(self.data),), + device=self.model.device, + dtype=self.save_dtype, + fill_value=0.0, + ) + + # Compute whether we need to save the index + self.save_index = self.scorer is None and not self.cfg.skip_index + + if self.save_index: + grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()} + self.builder = Builder( + self.cfg.partial_run_path, + self.data, + grad_sizes, + self.save_dtype, + self.reduce_cfg, + ) + else: + self.builder = None + + if dist.is_initialized(): + rank = dist.get_rank() + num_devices = dist.get_world_size() + available_modules = list(self.shapes().keys()) + + num_modules = len(available_modules) + base, remainder = divmod(num_modules, num_devices) + + assert base > 0, "Each rank must own at least one module" + + start_idx = rank * base + min(rank, remainder) + end_idx = start_idx + base + (1 if rank < remainder else 0) + self.owned_modules = set(available_modules[start_idx:end_idx]) + + for i, module_name in enumerate(available_modules): + # Inverse of the start_idx formula + self.module_to_rank[module_name] = ( + min(i // (base + 1), remainder - 1) + if i < remainder * (base + 1) + else remainder + (i - remainder * (base + 1)) // base + ) + + print(f"Rank {rank} owns {len(self.owned_modules)} modules") + else: + self.owned_modules = set(self.shapes().keys()) + + def forward_hook(self, module: nn.Module, a: Float[Tensor, "N S I"]) -> None: + """ + Cache activations for gradient computation with normalizer preprocessing + and compress via random projection if configured. 
+ Stores result in module._inputs for use in backward_hook. + """ + p = self.processor.projection_dim + name = assert_type(str, module._name) + i = getattr(module, LayerAdapter.in_attr(module)) + normalizer = self.processor.normalizers.get(name) + + if isinstance(normalizer, AdamNormalizer): + module._inputs = a + return + if isinstance(normalizer, AdafactorNormalizer): + a_factor = normalizer.col.add(1e-30) + a_factor = a_factor.rsqrt() + a = a * a_factor.type_as(a) # [N, S, I] * [I] → [N, S, I] + + if module._has_bias: + # Append ones to activation for bias term + ones = torch.ones(a.size(0), a.size(1), 1, device=a.device, dtype=a.dtype) + a = torch.cat([a, ones], dim=-1) + i = i + 1 + setattr(module, LayerAdapter.in_attr(module), i) + if p is not None: + a_projection = self.projection(name, p, i, "right", a.device, a.dtype).T + a = a @ a_projection # type: ignore + # set module._inputs to a + module._inputs = a + + @HookCollectorBase.split_attention_heads + def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): + """ + Compute per-sample gradient and store in mod_grads. + + Computes gradient as outer product g.T @ a (again with optional projection and + normalization). + """ + a = module._inputs # [N, S, I/q] + + assert isinstance(a, torch.Tensor), "Activation cache missing for module" + name = assert_type(str, module._name) + p = self.processor.projection_dim + i = getattr(module, LayerAdapter.in_attr(module)) + o = getattr(module, LayerAdapter.out_attr(module)) + normalizer = self.processor.normalizers.get(name) + + if isinstance(normalizer, AdamNormalizer): + full_gradient = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I] + P = normalizer.normalize_(full_gradient) + if p is not None: + g_projection = self.projection(name, p, o, "left", g.device, g.dtype) + a_projection = self.projection(name, p, i, "right", g.device, g.dtype).T + P = g_projection @ P @ a_projection + else: + if isinstance(normalizer, AdafactorNormalizer): + g_factor = normalizer.row.add(1e-30) + g_factor = g_factor.mean().sqrt() * g_factor.rsqrt() + g = g * g_factor.type_as(g) # [N, S, O] * [O] → [N, S, O] + + if p is not None: + g_projection = self.projection(name, p, o, "left", g.device, g.dtype) + g = g @ g_projection.T # [N, S, p] + + P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] + + P = P.flatten(1).clamp_(self.lo, self.hi) + + # Keep gradients in original dtype for preconditioner computation + self.mod_grads[name] = P + + if self.cfg.skip_preconditioners: + if self.save_index: + # Asynchronously move the gradient to CPU and convert to the final dtype + self.mod_grads[name] = P.to( + device="cpu", dtype=self.save_dtype, non_blocking=True + ) + else: + self.mod_grads[name] = P.to(dtype=self.save_dtype) + + del module._inputs + + def process_batch(self, indices: list[int], **kwargs): + """Process collected gradients for a batch and update losses.""" + losses = kwargs.get("losses") + assert losses is not None, "losses must be provided in kwargs" + + # Send gradients to owning ranks and compute outer products there + if not self.cfg.skip_preconditioners: + exchange_preconditioner_gradients( + self.mod_grads, + self.processor.preconditioners, + self.module_to_rank, + self.owned_modules, + self.rank, + ) + + # Convert mod_grads to the right dtype for save_index logic + if self.save_index: + for name in self.mod_grads: + self.mod_grads[name] = self.mod_grads[name].to( + device="cpu", dtype=self.save_dtype, non_blocking=True + ) + else: + for name in self.mod_grads: + self.mod_grads[name] = 
self.mod_grads[name].to( + dtype=self.save_dtype + ) + + if self.builder: + self.builder(indices, self.mod_grads) + if self.scorer: + self.scorer(indices, self.mod_grads) + self.mod_grads.clear() + self.per_doc_losses[indices] = losses.detach().type_as(self.per_doc_losses) + + def teardown(self): + """ + Finalize gradient collection, save results and flush/reduce the Builder. + """ + assert isinstance( + self.cfg, IndexConfig + ), "cfg is required for GradientCollector" # pleasing type checker + if dist.is_initialized(): + dist.reduce(self.per_doc_losses, dst=0) + + grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()} + if self.processor.preconditioners: + process_preconditioners( + self.processor, + self.processor.preconditioners, + len(self.data), + grad_sizes, + self.rank, + ) + + if self.rank == 0: + if self.cfg.drop_columns: + self.data = self.data.remove_columns(["input_ids"]) + + self.data = self.data.add_column( + "loss", + self.per_doc_losses.cpu().numpy(), + feature=Value( + "float16" + if self.save_dtype == torch.float16 + else "float32" # TODO: This is not robust + ), + new_fingerprint="loss", + ) + + self.data.save_to_disk(self.cfg.partial_run_path / "data.hf") + + self.processor.save(self.cfg.partial_run_path) + + # Flush and reduce builder if it exists + if self.builder is not None: + self.builder.flush() + self.builder.dist_reduce() + + +def exchange_preconditioner_gradients( + mod_grads: dict[str, torch.Tensor], + preconditioners: dict[str, torch.Tensor], + module_to_rank: dict[str, int], + owned_modules: set[str], + rank: int, +): + """ + Send gradients to the ranks that own their preconditioners, and accumulate + outer products on the owning ranks. + Each rank sends gradients for modules it doesn't own to the owning ranks, + and receives gradients for modules it owns to compute outer products. + """ + # Process current rank data for owned modules + for name, g in mod_grads.items(): + if name not in owned_modules: + continue + + g = g.float() + if name in preconditioners: + preconditioners[name].addmm_(g.mT, g) + else: + preconditioners[name] = g.mT @ g + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + device = next(iter(mod_grads.values())).device + + module_names = list(mod_grads.keys()) + module_numel = {n: int(mod_grads[n].numel()) for n in module_names} + + current_rank_chunk = torch.empty(0, device=device, dtype=torch.float32) + + # Flatten batch dimension: all to all works on contiguous 1-D tensors + send_chunks = [ + ( + current_rank_chunk + if dest == rank + else torch.cat( + [ + mod_grads[name].flatten() + for name in module_names + if module_to_rank[name] == dest + ] + ) + ) + for dest in range(world_size) + ] + + # --- collective exchange of gradient sizes in order of mod_grads --- + send_sizes = torch.tensor( + [t.numel() for t in send_chunks], device=device, dtype=torch.int64 + ) + recv_sizes = torch.empty_like(send_sizes) + + dist.all_to_all_single(recv_sizes, send_sizes) + + # --- collective exchange of gradient in order of mod_grads --- + send_buf = torch.cat(send_chunks) + recv_buf = torch.empty( + int(recv_sizes.sum().item()), device=device, dtype=torch.float32 + ) + + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_sizes.tolist(), + input_split_sizes=send_sizes.tolist(), + ) + + # Unpack gradients in src-rank order + # Within each src partition, modules are in fixed order. 
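Before the unpacking loop that follows, it may help to see the packing / unpacking contract around the two `dist.all_to_all_single` calls in isolation; a minimal sketch with no process group and hypothetical module names:

```python
import torch

# Hypothetical payload this rank would send to a single destination rank.
mod_grads = {"q_proj": torch.randn(4, 8), "k_proj": torch.randn(4, 6)}
send_order = list(mod_grads)   # fixed module order shared by sender and receiver

# Sender side: flatten and concatenate the gradients destined for one rank.
send_buf = torch.cat([mod_grads[n].flatten() for n in send_order])

# Receiver side: walk the buffer in the same fixed order, restoring each
# module's [batch, features] shape from its known feature dimension.
offset = 0
for name in send_order:
    numel = mod_grads[name].numel()
    flat = send_buf[offset:offset + numel]
    offset += numel
    g = flat.view(-1, mod_grads[name].shape[-1])
    assert torch.equal(g, mod_grads[name])
```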
+ offset = 0 + for src_rank in range(world_size): + part_len = int(recv_sizes[src_rank].item()) + part = recv_buf[offset : offset + part_len] + offset += part_len + + if part_len == 0 or src_rank == rank: + continue + + p = 0 + for name in owned_modules: + n = module_numel[name] + flat = part[p : p + n] + p += n + + feature_dim = mod_grads[name].shape[-1] + g = flat.to(device, non_blocking=True).view(-1, feature_dim).float() + + if name in preconditioners: + preconditioners[name].addmm_(g.mT, g) + else: + preconditioners[name] = g.mT @ g diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 1a459e60..0e8f5f84 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -110,7 +110,6 @@ def setup(self) -> None: # Compute whether we need to save the index self.save_index = self.scorer is None and not self.cfg.skip_index - self.skip_preconditioners = self.cfg.skip_preconditioners if self.save_index: grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()} @@ -193,7 +192,7 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): P = P.flatten(1).clamp_(self.lo, self.hi) - if not self.skip_preconditioners: + if not self.cfg.skip_preconditioners: P = P.float() if name in self.processor.preconditioners: self.processor.preconditioners[name].addmm_(P.mT, P) diff --git a/bergson/config.py b/bergson/config.py index 17c3392e..25bf8328 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -1,7 +1,9 @@ +import os from dataclasses import dataclass from pathlib import Path from typing import Literal +import torch from simple_parsing import field @@ -55,6 +57,56 @@ class AttentionConfig: """Axis index for `num_heads` in the weight matrix.""" +@dataclass +class DistributedConfig: + """Configuration for multi-node preconditioner computation.""" + + nnode: int = 1 + """The number of nodes to use for preconditioner computation.""" + + nproc_per_node: int = field(default_factory=lambda: torch.cuda.device_count()) + """The number of processes per node.""" + + node_rank: int | None = None + """The rank of the current node. If not set Bergson will attempt to infer + it from environment variables.""" + + @property + def _node_rank(self) -> int: + """Get the rank of the node from config or environment variables.""" + if self.node_rank is not None: + return self.node_rank + + if self.nnode == 1: + return 0 + + for var in ("SLURM_NODEID", "GROUP_RANK", "NODE_RANK"): + if var in os.environ: + return int(os.environ[var]) + + raise ValueError("Node rank not found. 
Set it with --node_rank.") + + @property + def world_size(self) -> int: + """Total number of processes across all nodes.""" + return self.nnode * self.nproc_per_node + + @property + def start_rank(self) -> int: + """Starting rank for processes on this node.""" + return self._node_rank * self.nproc_per_node + + @property + def local_rank(self) -> int: + """Local rank of the current process.""" + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def rank(self) -> int: + """Rank of the current process.""" + return self.start_rank + self.local_rank + + @dataclass class IndexConfig: """Config for building the index and running the model/dataset pipeline.""" @@ -145,6 +197,9 @@ class IndexConfig: overwrite: bool = False """Whether to overwrite any existing index in the run path.""" + distributed: DistributedConfig = field(default_factory=DistributedConfig) + """Configuration for multi-node distributed preconditioner computation.""" + @property def partial_run_path(self) -> Path: """Temporary path to use while writing build artifacts.""" @@ -157,6 +212,9 @@ def __post_init__(self): if isinstance(self.attention, dict): self.attention = AttentionConfig(**self.attention) + if isinstance(self.distributed, dict): + self.distributed = DistributedConfig(**self.distributed) + @dataclass class QueryConfig: @@ -253,15 +311,20 @@ class FaissConfig: - "PQ6720": nearest neighbors with vector product quantization to 6720 elements. Reduces memory usage at the cost of accuracy. """ + mmap_index: bool = False """Whether to query the gradients on-disk.""" + max_train_examples: int | None = None """The maximum number of examples to train the index on. If `None`, all examples will be used.""" + batch_size: int = 1024 """The batch size for pre-processing gradients.""" + num_shards: int = 1 """The number of shards to build for an index. Using more shards reduces peak RAM usage.""" + nprobe: int = 10 """The number of FAISS vector clusters to search if using ANN.""" diff --git a/bergson/data.py b/bergson/data.py index bc40582a..dc700bf9 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -28,7 +28,11 @@ def ceildiv(a: int, b: int) -> int: return -(-a // b) # Equivalent to math.ceil(a / b) but faster for integers -def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[list[int]]: +def allocate_batches( + doc_lengths: list[int], + N: int, + seed: int = 42, +) -> list[list[int]]: """ Allocate documents into batches that are then distributed evenly across a fixed number of workers. @@ -41,7 +45,8 @@ def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[lis N : int Hard memory budget per *batch*, expressed as ``max(length in batch) * (# docs in batch) ≤ N``. - + seed : int + Random seed for shuffling batches within each worker's allocation. 
Returns ------- list[list[int]] @@ -446,6 +451,8 @@ def dist_reduce(self): self.grad_buffer.dtype ) + self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu() + def pad_and_tensor( sequences: list[list[int]], diff --git a/bergson/distributed.py b/bergson/distributed.py index 222754d0..7abbf753 100644 --- a/bergson/distributed.py +++ b/bergson/distributed.py @@ -1,11 +1,13 @@ +import os import socket from typing import Any, Callable -import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes +from .config import DistributedConfig + def dist_worker( worker: Callable, @@ -24,34 +26,34 @@ def dist_worker( dist.destroy_process_group() -def launch_distributed_run(process_name: str, worker, const_worker_args: list[Any]): - """ - Launch a distributed multi-process job over all visible CUDA devices. +def launch_distributed_run( + process_name: str, + worker, + const_worker_args: list[Any], + dist_config: DistributedConfig | None = None, +): + if dist_config is None: + dist_config = DistributedConfig() - Parameters - ---------- - process_name : str - Label used by Torch Elastic to tag logs and processes. - worker : Callable - Function that will be executed on every spawned process. It must accept - ``(rank, world_size, *const_worker_args)`` in that order. - const_worker_args : list - Arguments passed verbatim to every worker invocation after ``rank`` and - ``world_size``. These are typically configuration or shared datasets. - """ - world_size = torch.cuda.device_count() - if world_size <= 1: - # Run the worker directly if no distributed training is needed. This is great - # for debugging purposes. - worker(0, 1, *const_worker_args) - else: - # Set up multiprocessing and distributed training - mp.set_sharing_strategy("file_system") + local_world_size = dist_config.nproc_per_node + world_size = dist_config.world_size + start_rank = dist_config.start_rank - # Find an available port for distributed training + # Multi-node environment + if dist_config.nnode > 1: + master_addr = os.environ.get("MASTER_ADDR", "localhost") + master_port = os.environ.get("MASTER_PORT", "29500") + else: + master_addr = "localhost" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) - _, port = s.getsockname() + _, master_port = s.getsockname() + master_port = str(master_port) + + if world_size <= 1: + worker(0, 0, 1, *const_worker_args) + else: + mp.set_sharing_strategy("file_system") ctx = None try: @@ -59,16 +61,18 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An process_name, dist_worker, args={ - i: (worker, i, world_size, *const_worker_args) - for i in range(world_size) + i: (worker, start_rank + i, i, world_size, *const_worker_args) + for i in range(local_world_size) }, envs={ i: { "LOCAL_RANK": str(i), - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(port), + "RANK": str(start_rank + i), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, } - for i in range(world_size) + for i in range(local_world_size) }, logs_specs=DefaultLogsSpecs(), ) diff --git a/bergson/query/query_index.py b/bergson/query/query_index.py index 783644dc..8b4915b2 100644 --- a/bergson/query/query_index.py +++ b/bergson/query/query_index.py @@ -38,12 +38,12 @@ def query( ) tokenizer = AutoTokenizer.from_pretrained(query_cfg.model) model, target_modules = setup_model_and_peft( - query_index_cfg, 0, 
device_map_auto=query_cfg.device_map_auto + query_index_cfg, device_map_auto=query_cfg.device_map_auto ) else: tokenizer = AutoTokenizer.from_pretrained(index_cfg.model) model, target_modules = setup_model_and_peft( - index_cfg, 0, device_map_auto=query_cfg.device_map_auto + index_cfg, device_map_auto=query_cfg.device_map_auto ) ds = load_data_string( diff --git a/bergson/reduce.py b/bergson/reduce.py index 7074c6c9..2a9df92b 100644 --- a/bergson/reduce.py +++ b/bergson/reduce.py @@ -21,6 +21,7 @@ def reduce_worker( rank: int, + local_rank: int, world_size: int, index_cfg: IndexConfig, reduce_cfg: ReduceConfig, @@ -33,6 +34,8 @@ def reduce_worker( ---------- rank : int Distributed rank / GPU ID for this worker. + local_rank : int + Local rank / GPU ID for this worker on the node. world_size : int Total number of workers participating in the run. index_cfg : IndexConfig @@ -42,7 +45,7 @@ def reduce_worker( ds : Dataset | IterableDataset The entire dataset to be indexed. A subset is assigned to each worker. """ - torch.cuda.set_device(rank) + torch.cuda.set_device(local_rank) # These should be set by the main process if world_size > 1: @@ -52,14 +55,14 @@ def reduce_worker( dist.init_process_group( "nccl", init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), + device_id=torch.device(f"cuda:{local_rank}"), rank=rank, - timeout=timedelta(hours=1), + timeout=timedelta(minutes=30), world_size=world_size, ) - model, target_modules = setup_model_and_peft(index_cfg, rank) - processor = create_processor(model, ds, index_cfg, rank, target_modules) + model, target_modules = setup_model_and_peft(index_cfg) + processor = create_processor(model, ds, index_cfg, target_modules) attention_cfgs = { module: index_cfg.attention for module in index_cfg.split_attention_modules @@ -126,6 +129,9 @@ def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): ds = setup_data_pipeline(index_cfg) - launch_distributed_run("reduce", reduce_worker, [index_cfg, reduce_cfg, ds]) + launch_distributed_run( + "reduce", reduce_worker, [index_cfg, reduce_cfg, ds], index_cfg.distributed + ) - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + if index_cfg.distributed.rank == 0: + shutil.move(index_cfg.partial_run_path, index_cfg.run_path) diff --git a/bergson/score/score.py b/bergson/score/score.py index de2506a1..64509335 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -217,6 +217,7 @@ def get_query_ds(score_cfg: ScoreConfig): def score_worker( rank: int, + local_rank: int, world_size: int, index_cfg: IndexConfig, score_cfg: ScoreConfig, @@ -230,6 +231,8 @@ def score_worker( ---------- rank : int Distributed rank / GPU ID for this worker. + local_rank : int + Local rank / GPU ID for this worker on the node. world_size : int Total number of workers participating in the run. index_cfg : IndexConfig @@ -242,7 +245,7 @@ def score_worker( query_grads : dict[str, torch.Tensor] Preprocessed query gradient tensors (often [1, grad_dim]) keyed by module name. 
""" - torch.cuda.set_device(rank) + torch.cuda.set_device(local_rank) # These should be set by the main process if world_size > 1: @@ -252,14 +255,14 @@ def score_worker( dist.init_process_group( "nccl", init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), + device_id=torch.device(f"cuda:{local_rank}"), rank=rank, timeout=timedelta(hours=1), world_size=world_size, ) - model, target_modules = setup_model_and_peft(index_cfg, rank) - processor = create_processor(model, ds, index_cfg, rank, target_modules) + model, target_modules = setup_model_and_peft(index_cfg) + processor = create_processor(model, ds, index_cfg, target_modules) attention_cfgs = { module: index_cfg.attention for module in index_cfg.split_attention_modules @@ -365,7 +368,11 @@ def score_dataset( ) launch_distributed_run( - "score", score_worker, [index_cfg, score_cfg, ds, query_grads] + "score", + score_worker, + [index_cfg, score_cfg, ds, query_grads], + index_cfg.distributed, ) - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + if index_cfg.distributed.rank == 0: + shutil.move(index_cfg.partial_run_path, index_cfg.run_path) diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index 1c41822d..ae48404d 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -63,18 +63,20 @@ def create_processor( model: PreTrainedModel, ds: Dataset | IterableDataset, cfg: IndexConfig, - rank: int, target_modules: set[str] | None = None, ) -> GradientProcessor: """Handle processor creation and normalizer fitting""" + local_rank = cfg.distributed.local_rank + rank = cfg.distributed.rank + processor_path = Path(cfg.processor_path) if (processor_path / "processor_config.json").exists(): - if rank == 0: + if local_rank == 0: print(f"Loading processor from '{cfg.processor_path}'") processor = GradientProcessor.load( processor_path, - map_location=f"cuda:{rank}", + map_location=f"cuda:{local_rank}", ) else: normalizers = create_normalizers(model, ds, cfg, target_modules) @@ -94,10 +96,10 @@ def create_processor( def setup_model_and_peft( cfg: IndexConfig, - rank: int, device_map_auto: bool = False, ) -> tuple[PreTrainedModel, set | None]: """Handle model loading, quantization, FSDP, and PEFT detection""" + local_rank = cfg.distributed.local_rank match cfg.precision: case "bf16": @@ -119,7 +121,7 @@ def setup_model_and_peft( elif cfg.fsdp: device_map = "cpu" else: - device_map = {"": f"cuda:{rank}"} + device_map = {"": f"cuda:{local_rank}"} quantization_config = None if cfg.precision in ("int4", "int8"): diff --git a/pyproject.toml b/pyproject.toml index 013532c3..b0c1ad22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "torch", "transformers<4.56.0" # 4.56.0 increases fp error from operation order ] -version = "0.4.6" +version = "0.5.0" [project.optional-dependencies] dev = [ "pre-commit",