Changes from all commits
84 commits
ceb1e04
init branch
LouisYRYJ May 15, 2025
1dc6a53
Merge branch 'main' into approx-unrolling
LouisYRYJ May 15, 2025
8b52f14
EK-FAC running
LouisYRYJ May 22, 2025
1f873c4
checkpoint EKFACs running
LouisYRYJ May 22, 2025
8581b14
Merge branch 'main' into approx-unrolling
LouisYRYJ May 22, 2025
2ac0540
Merge branch 'main' into approx-unrolling
LouisYRYJ May 25, 2025
bc2d8bf
WIP resetting
LouisYRYJ May 26, 2025
b7f1557
averages working
May 26, 2025
07626fc
unrolling pipeline works for small models (otherwise we get OOM)
May 30, 2025
de08537
pre kronfluence vendoring
LouisYRYJ Jun 2, 2025
0d2740f
pipeline with vendored library working
LouisYRYJ Jun 2, 2025
d00f620
removed score utilities from hessian
LouisYRYJ Jun 2, 2025
a692b89
debugging covariance randomness
LouisYRYJ Jun 3, 2025
97ea928
Merge branch 'main' into approx-unrolling
LouisYRYJ Jun 3, 2025
59d15f5
renaming quelle -> bergson
LouisYRYJ Jun 3, 2025
05325bf
Merge branch 'main' into approx-unrolling and testing
LouisYRYJ Jun 4, 2025
8489f44
fsdp testing
LouisYRYJ Jun 5, 2025
16d7022
covariance with hooks working
LouisYRYJ Jun 6, 2025
aa08964
using closure for covariance processing
LouisYRYJ Jun 11, 2025
c9384f6
using fsdp for covariance processing working
LouisYRYJ Jun 11, 2025
cb8466f
refactoring hessians
LouisYRYJ Jun 12, 2025
9b6045d
quick clean up
LouisYRYJ Jun 13, 2025
fce6d0d
merging
LouisYRYJ Jun 13, 2025
ae7bbc0
debugging memory leaks
LouisYRYJ Jun 17, 2025
6651e78
merge main
LouisYRYJ Jun 20, 2025
6e7eaae
ekfac refactoring WIP
LouisYRYJ Jun 21, 2025
3771dc7
EKFAC refactoring WIP
LouisYRYJ Jun 22, 2025
34a2bdf
KFAC done + slow Eigenvalue correction
LouisYRYJ Jun 22, 2025
7eed777
pipeline running
LouisYRYJ Jun 23, 2025
0ede5b6
merge main
LouisYRYJ Jun 24, 2025
24b5820
pipeline WIP
LouisYRYJ Jun 24, 2025
495cc85
pipeline with new set up running
LouisYRYJ Jun 24, 2025
27b8a71
merge main
LouisYRYJ Jun 25, 2025
144bb32
memory efficient pipeline for bigger models WIP
LouisYRYJ Jun 25, 2025
faa5d6b
scaling covariance
LouisYRYJ Jun 26, 2025
d47dd3c
detach grads
LouisYRYJ Jun 26, 2025
b667b90
sharding covariances WIP
LouisYRYJ Jun 27, 2025
1222850
proper saving WIP
LouisYRYJ Jun 29, 2025
7172a14
eigenvectors sharded
LouisYRYJ Jun 30, 2025
7c4e51a
writing and running tests (pipeline running for 7B)
LouisYRYJ Jul 1, 2025
4aaedf8
ekfac tests, 1 device passing
LouisYRYJ Jul 3, 2025
67fa643
clean up WIP
LouisYRYJ Jul 4, 2025
a43df83
refactor + bug fix: .contiguous must be called BEFORE dist.all_reduce
LouisYRYJ Jul 8, 2025
bf14d88
clean up and add README
LouisYRYJ Jul 8, 2025
d141a82
add specification
LouisYRYJ Jul 8, 2025
be41c0b
clean up WIP
LouisYRYJ Jul 8, 2025
086bcbd
merge main
LouisYRYJ Jul 8, 2025
f46c3b6
more clean up
LouisYRYJ Jul 8, 2025
d8600ac
reformatting
LouisYRYJ Jul 8, 2025
c214258
fix cpu nonblocking bug in compute_eigenvector
LouisYRYJ Jul 14, 2025
852855b
sharded matmul refactoring
LouisYRYJ Jul 15, 2025
566fd5e
small fixes
LouisYRYJ Jul 15, 2025
4714af8
rewriting dist WIP
LouisYRYJ Jul 16, 2025
9ac34e6
refactoring distributed done
LouisYRYJ Jul 16, 2025
7b0128f
fix label when prompt exceeds max_token_len
LouisYRYJ Jul 17, 2025
8a90724
attribution with ekfac
LouisYRYJ Jul 17, 2025
f6c6a30
merge main
LouisYRYJ Jul 17, 2025
33138ac
remove path dependency
LouisYRYJ Jul 18, 2025
3a305ee
fix path dependency
LouisYRYJ Jul 18, 2025
1524cca
refactor + added logger + ekfac transform running
LouisYRYJ Jul 18, 2025
7cea24b
big refactor
LouisYRYJ Jul 30, 2025
ad712f3
apply ekfac refactor
LouisYRYJ Jul 31, 2025
63c1343
merging main
LouisYRYJ Jul 31, 2025
a735ad1
reformatting
LouisYRYJ Jul 31, 2025
422157c
(re)move notebooks
LouisYRYJ Jul 31, 2025
a05d95a
minor changes
LouisYRYJ Jul 31, 2025
2528b0e
Remove attribute_results.ipynb from tracking
LouisYRYJ Jul 31, 2025
20e6b92
test ekfac apply passing
LouisYRYJ Aug 5, 2025
8393fe9
clean up
LouisYRYJ Aug 5, 2025
e252d62
apply ekfac + datafiltering
LouisYRYJ Aug 7, 2025
516ebc8
adding peft fsdp test
LouisYRYJ Aug 8, 2025
a203b75
switch fsdp and peft
LouisYRYJ Aug 8, 2025
9e0e87e
fixing peft fsdp interaction
LouisYRYJ Aug 11, 2025
30e45ab
smaller refactor for peft loading
LouisYRYJ Aug 11, 2025
aff9c64
ekfac fix to get same results as kronfluence
LouisYRYJ Aug 15, 2025
64d4eb0
update script + small fix
LouisYRYJ Aug 15, 2025
fa0a4a6
fix script path
LouisYRYJ Aug 15, 2025
8ca6324
running ekfac sweeps
LouisYRYJ Aug 17, 2025
2d13de2
refactoring ekfac computations WIP
LouisYRYJ Aug 27, 2025
c3dc77d
sharded computation moved into different class DONE
LouisYRYJ Aug 27, 2025
6d0f935
distributed friendly logging (only log rank 0)
LouisYRYJ Aug 27, 2025
28b5866
removing processor
LouisYRYJ Sep 3, 2025
644558f
cut normalizers, make ekfac_apply file handling more clear
LouisYRYJ Sep 16, 2025
bff15f2
remove break statement, now attn is included by default
LouisYRYJ Sep 20, 2025
33 changes: 28 additions & 5 deletions .gitignore
@@ -161,23 +161,46 @@ dmypy.json
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# VS Code
.vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# models
*.pt
*.pth
*.safetensors
*.json
*.jsonl
*.txt
*.arrow
*.bin
*.csv

# plots
*.png
*.jpg
*.jpeg
*.gif

# debugging results
*.svg
*.pickle
# Faiss index files
*.faiss
# Local directory for run artifacts
runs/
cache/

wandb/
.vscode/


9 changes: 7 additions & 2 deletions bergson/__main__.py
@@ -1,11 +1,16 @@
from simple_parsing import parse

from .build import build_gradient_dataset
from bergson.distributed import distributed_computing

from .collection import collect_gradients
from .data import IndexConfig


def main():
build_gradient_dataset(parse(IndexConfig))
distributed_computing(
parse(IndexConfig),
worker_fn=collect_gradients,
)


if __name__ == "__main__":
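For orientation, the new entry point hands the parsed `IndexConfig` to a `distributed_computing` dispatcher together with `collect_gradients` as the worker. Below is a rough sketch of what such a dispatcher can look like; it is an assumption for illustration, not bergson's actual `distributed.py`, and the real worker also receives a model, dataset, and `GradientProcessor`, whose preparation is elided here.

```python
import os
from typing import Any, Callable

import torch
import torch.distributed as dist


def distributed_computing(cfg: Any, worker_fn: Callable[..., None]) -> None:
    """Illustrative dispatcher: set up the process group (when launched via
    torchrun, which exports RANK/WORLD_SIZE), run the worker, and tear down."""
    if "RANK" in os.environ and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)
    try:
        worker_fn(cfg)  # the real call site presumably also passes model/data/processor
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
```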
271 changes: 0 additions & 271 deletions bergson/build.py

This file was deleted.

15 changes: 6 additions & 9 deletions bergson/collection.py
@@ -10,7 +10,7 @@
from tqdm.auto import tqdm
from transformers import PreTrainedModel

from .data import create_index, pad_and_tensor
from .data import IndexConfig, create_index, pad_and_tensor
from .gradients import (
AdafactorNormalizer,
AdamNormalizer,
@@ -24,11 +24,10 @@ def collect_gradients(
model: PreTrainedModel,
data: Dataset,
processor: GradientProcessor,
path: str,
*,
batches: list[list[int]] | None = None,
skip_preconditioners: bool = False,
target_modules: set[str] | None = None,
cfg: IndexConfig,
):
"""
Compute projected gradients using a subset of the dataset.
@@ -54,7 +53,7 @@ def callback(name: str, g: torch.Tensor):
mod_grads[name] = g.to(device="cpu", dtype=torch.float16, non_blocking=True)

# Compute the outer product of the flattened gradient
if not skip_preconditioners:
if not cfg.skip_preconditioners:
g = g.float()
preconditioner = preconditioners.get(name, None)
if preconditioner is None:
@@ -73,9 +72,7 @@ def callback(name: str, g: torch.Tensor):
grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}

# Allocate structured space ahead of time for the gradients
grad_buffer = create_index(
path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16
)
grad_buffer = create_index(cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16)

per_doc_losses = torch.full(
(len(data),),
@@ -139,9 +136,9 @@ def callback(name: str, g: torch.Tensor):
feature=Value("float16"),
new_fingerprint="loss",
)
data.save_to_disk(path + "/data.hf")
data.save_to_disk(cfg.run_path + "/data.hf")

processor.save(path)
processor.save(cfg.run_path)

# Make sure the gradients are written to disk
grad_buffer.flush()
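The hunk above routes all paths through `cfg.run_path` and ends by flushing `grad_buffer`, which points at a memory-mapped, structured buffer with one record per example and one fixed-size float16 field per module gradient. The sketch below illustrates that storage pattern only; the module names, sizes, and file name are made up, and it is not bergson's actual `create_index`.

```python
import os

import numpy as np

# Hypothetical per-module gradient sizes and dataset size, for illustration only.
grad_sizes = {"mlp.down_proj": 1024, "attn.q_proj": 512}
num_grads = 8

# One record per example; each field holds that module's flattened gradient in float16.
dtype = np.dtype([(name, np.float16, (size,)) for name, size in grad_sizes.items()])

os.makedirs("runs/example", exist_ok=True)
grad_buffer = np.lib.format.open_memmap(
    "runs/example/gradients.npy", mode="w+", dtype=dtype, shape=(num_grads,)
)

grad_buffer["mlp.down_proj"][0] = np.random.randn(1024).astype(np.float16)
grad_buffer.flush()  # make sure the gradients actually reach disk
```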