8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,14 @@
# CHANGELOG


## v0.5.0 (2026-01-08)

### Features

- Add optimizer-aware gradients
([`497edab`](https://github.com/EleutherAI/bergson/commit/497edab8f2ca19d8fcb1d409fbd99452a929584e))


## v0.4.6 (2026-01-06)

### Bug Fixes
3 changes: 3 additions & 0 deletions README.md
@@ -15,6 +15,9 @@ We view attribution as a counterfactual question: **_If we "unlearned" this trai

# Announcements

**January 2026**
- [Experimental] Support distributing preconditioners across nodes and devices for VRAM-efficient computation via `GradientCollectorWithDistributedPreconditioners`. If you would like this functionality exposed via the CLI, please get in touch! https://github.com/EleutherAI/bergson/pull/100
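A minimal sketch of how the new collector might be constructed, assuming it mirrors the existing `GradientCollector` in import location and accepts a process group to shard preconditioner state over; the import path, the `process_group` keyword, and the stand-in model are assumptions, not a documented API, and the CLI does not expose this yet:

```python
import torch.distributed as dist
from torch import nn

# Assumed import path, mirroring GradientCollector; check the package for the real location.
from bergson.collector.gradient_collectors import (
    GradientCollectorWithDistributedPreconditioners,
)

model = nn.Linear(16, 16)  # stand-in model for illustration only

# Hypothetical constructor arguments; the real signature may differ.
collector = GradientCollectorWithDistributedPreconditioners(
    model=model,
    process_group=dist.group.WORLD,  # shard preconditioner state across all ranks
)
```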

**October 2025**
- Support bias parameter gradients in linear modules: https://github.com/EleutherAI/bergson/pull/54
- Support convolution modules: https://github.com/EleutherAI/bergson/pull/50
2 changes: 1 addition & 1 deletion bergson/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.4.6"
__version__ = "0.5.0"

from .collection import collect_gradients
from .collector.gradient_collectors import GradientCollector
3 changes: 3 additions & 0 deletions bergson/__main__.py
@@ -14,6 +14,9 @@

def validate_run_path(index_cfg: IndexConfig):
"""Validate the run path."""
if index_cfg.distributed.rank != 0:
return

for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]:
if not path.exists():
continue
29 changes: 19 additions & 10 deletions bergson/build.py
@@ -12,15 +12,18 @@
from bergson.collection import collect_gradients
from bergson.config import IndexConfig
from bergson.data import allocate_batches
from bergson.distributed import launch_distributed_run
from bergson.utils.utils import assert_type, setup_reproducibility
from bergson.utils.worker_utils import setup_model_and_peft

from .distributed import launch_distributed_run
from .utils.worker_utils import create_processor, setup_data_pipeline
from bergson.utils.worker_utils import (
create_processor,
setup_data_pipeline,
setup_model_and_peft,
)


def build_worker(
rank: int,
local_rank: int,
[Inline review comment from a contributor on `local_rank`: "add to doc what this does"]
world_size: int,
cfg: IndexConfig,
ds: Dataset | IterableDataset,
@@ -32,14 +35,16 @@ def build_worker(
----------
rank : int
Distributed rank / GPU ID for this worker.
local_rank : int
Local rank / GPU ID for this worker on the node.
world_size : int
Total number of workers participating in the run.
cfg : IndexConfig
Specifies the model, tokenizer, PEFT adapters, and other settings.
ds : Dataset | IterableDataset
The entire dataset to be indexed. A subset is assigned to each worker.
"""
torch.cuda.set_device(rank)
torch.cuda.set_device(local_rank)

# These should be set by the main process
if world_size > 1:
@@ -49,14 +54,14 @@
dist.init_process_group(
"nccl",
init_method=f"tcp://{addr}:{port}",
device_id=torch.device(f"cuda:{rank}"),
device_id=torch.device(f"cuda:{local_rank}"),
rank=rank,
timeout=timedelta(hours=1),
world_size=world_size,
)

model, target_modules = setup_model_and_peft(cfg, rank)
processor = create_processor(model, ds, cfg, rank, target_modules)
model, target_modules = setup_model_and_peft(cfg)
processor = create_processor(model, ds, cfg, target_modules)

attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

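For context on the `rank` / `local_rank` split introduced in this hunk: the global rank identifies a worker across every node, while the local rank selects which GPU that worker owns on its own node, so `set_device` and the NCCL `device_id` must use the local value. Below is a generic sketch of that convention using torchrun-style environment variables; it is not the launcher code from this PR:

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist

# Typical torchrun environment: RANK is global across nodes, LOCAL_RANK is per node.
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

torch.cuda.set_device(local_rank)  # each process claims one GPU on its node

if world_size > 1:
    dist.init_process_group(
        "nccl",
        device_id=torch.device(f"cuda:{local_rank}"),  # bind NCCL to the local device
        rank=rank,  # identify the process by its global rank
        timeout=timedelta(hours=1),
        world_size=world_size,
    )
```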
@@ -119,6 +124,10 @@ def build(index_cfg: IndexConfig):

ds = setup_data_pipeline(index_cfg)

launch_distributed_run("build", build_worker, [index_cfg, ds])
launch_distributed_run(
"build", build_worker, [index_cfg, ds], index_cfg.distributed
)

shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
rank = index_cfg.distributed.rank
if rank == 0:
shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
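The driver-side change above does two things: it passes the `index_cfg.distributed` settings into `launch_distributed_run`, and it promotes the partial run directory only on rank 0, so multiple processes never race on the same filesystem move. A generic sketch of that finalization pattern follows; it is not this repo's exact helper, and the barrier is illustrative, since in the PR the workers have already finished before the move happens:

```python
import shutil

import torch.distributed as dist


def finalize_run(partial_path: str, final_path: str, rank: int) -> None:
    """Promote a partial run directory, letting only rank 0 touch the filesystem."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # wait until every worker has finished writing
    if rank == 0:
        shutil.move(partial_path, final_path)
```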
2 changes: 0 additions & 2 deletions bergson/collection.py
@@ -24,8 +24,6 @@ def collect_gradients(
"""
Compute gradients using the hooks specified in the GradientCollector.
"""
if attention_cfgs is None:
attention_cfgs = {}
collector = GradientCollector(
model=model.base_model, # type: ignore
cfg=cfg,