14 changes: 3 additions & 11 deletions .github/workflows/build.yml
@@ -19,20 +19,12 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e ".[dev,faiss]"
# TODO: Proper test infrastructure for tests/ekfac_tests
# - name: Run tests
# run: pytest
# TODO: run pyright on whole codebase
- name: Type Checking bergson/hessians
- name: Run EKFAC tests that work without GPU
run: pytest -sv tests/ekfac_tests
- name: Type Checking
uses: jakebailey/pyright-action@v1
with:
version: 1.1.406
working-directory: bergson/hessians
- name: Type Checking tests/ekfac_tests
uses: jakebailey/pyright-action@v1
with:
version: 1.1.406
working-directory: tests/ekfac_tests
- name: build
run: pip wheel --no-deps -w dist .
env:
7 changes: 3 additions & 4 deletions bergson/collection.py
@@ -85,7 +85,7 @@ def callback(name: str, g: torch.Tensor):

for indices in tqdm(batches, disable=rank != 0, desc="Building index"):
batch = data[indices]
x, y = pad_and_tensor(
x, y, valid_masks = pad_and_tensor(
batch["input_ids"], # type: ignore
labels=batch.get("labels"), # type: ignore
device=model.device,
@@ -100,8 +100,7 @@ def callback(name: str, g: torch.Tensor):
reduction="none",
).reshape_as(y[:, 1:])

masks = y[:, 1:] != -100
denoms = masks.sum(dim=1, dtype=logits.dtype)
denoms = valid_masks.sum(dim=1, dtype=logits.dtype)
losses = losses.sum(1).div(denoms)
losses.mean().backward()

@@ -208,7 +207,7 @@ def adam_update(name: str, g: torch.Tensor):
closure=callback,
target_modules=target_modules,
):
x, y = pad_and_tensor(
x, y, _ = pad_and_tensor(
batch["input_ids"], # type: ignore
labels=batch.get("labels", None), # type: ignore
device=model.device,
18 changes: 13 additions & 5 deletions bergson/data.py
@@ -351,12 +351,14 @@ def pad_and_tensor(
padding_value: int = 0,
dtype: torch.dtype | None = torch.long,
device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Pad a list of sequences to the same length and convert them to tensors.
Returns a tuple of padded sequences and labels. The labels are the same as the
sequences, but with -100 for the padding positions, which is useful for ignoring
padding in loss calculations.

Returns:
padded_tokens: Padded input sequences [N, S]
padded_labels: Labels with -100 for padding positions [N, S]
valid_masks: Boolean mask [N, S] where True indicates valid positions
"""
if labels is None:
labels = sequences
@@ -370,7 +372,13 @@
# convert to tensor
padded_tokens = torch.tensor(padded, dtype=dtype, device=device)
padded_labels = torch.tensor(labels, dtype=dtype, device=device)
return padded_tokens, padded_labels

# Compute valid_masks: position i is valid if labels[i+1] != -100
N, S = padded_tokens.shape
valid_masks = torch.zeros(N, S, dtype=torch.bool, device=device)
valid_masks[:, :-1] = padded_labels[:, 1:] != -100

return padded_tokens, padded_labels, valid_masks


def tokenize(batch: dict, *, args: DataConfig, tokenizer):
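A note on the semantics of the new third return value: `valid_masks[:, i]` is True exactly when the label at position `i + 1` is a real token, i.e. when the logits at position `i` are scored against a non-padding next token. A minimal sketch of that relationship on toy tensors (values invented for illustration; only the names `padded_labels` and `valid_masks` come from the diff):

```python
import torch

# Toy labels: -100 marks padding, as in pad_and_tensor above.
padded_labels = torch.tensor([[5, 7, -100, -100],
                              [2, 3,    4,    9]])

# Position i is valid iff the label at i + 1 is not -100.
valid_masks = torch.zeros_like(padded_labels, dtype=torch.bool)
valid_masks[:, :-1] = padded_labels[:, 1:] != -100
print(valid_masks)
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False]])

# Per-sequence valid-token counts agree with the old expression
# `(y[:, 1:] != -100).sum(1)` that collection.py used before this change.
print(valid_masks.sum(dim=1))  # tensor([1, 3])
```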
8 changes: 6 additions & 2 deletions bergson/distributed.py
@@ -94,7 +94,10 @@ def setup_model_and_peft(
torch.cuda.manual_seed_all(42)

# Common configuration
device_map = {"": f"cuda:{rank}"} if not cfg.fsdp else "cpu"
if cfg.fsdp or not torch.cuda.is_available():
device_map = "cpu"
else:
device_map = {"": f"cuda:{rank}"}
quantization_config = None
if cfg.precision in ("int4", "int8"):
quantization_config = BitsAndBytesConfig(
@@ -202,7 +205,8 @@ def worker_wrapper(
setup_processor: bool = True,
):
try:
torch.cuda.set_device(rank)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
if cfg.debug:
setup_reproducibility()
print("DEBUG MODE IS ENABLED: quasi-deterministic training")
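The same CUDA-or-CPU fallback recurs in the files below as a `get_device` helper imported from `bergson.utils`. Its body is not part of this diff; a plausible sketch, assuming it mirrors the `device_map` logic above (only the name and the `get_device(rank)` call shape appear in the diff):

```python
import torch

def get_device(rank: int) -> str:
    # Hypothetical implementation: pin each worker to its own GPU,
    # falling back to CPU so GPU-less CI (see build.yml) still runs.
    return f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
```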
23 changes: 18 additions & 5 deletions bergson/hessians/collector.py
@@ -11,7 +11,7 @@
from torch.utils.hooks import RemovableHandle

from bergson.hessians.sharded_computation import ShardedMul
from bergson.utils import assert_type
from bergson.utils import assert_type, get_device


@dataclass
@@ -37,6 +37,11 @@ class HookCollectorBase(ContextDecorator, ABC):
Set of module names to attach hooks to. Should consist only of nn.Linear modules.
If None, hooks are attached to all Linear layers in the model.
"""
valid_masks: Tensor = None # type: ignore[assignment]
"""
Mask of shape [N, S] indicating which positions are valid.
Must be set via set_valid_masks() before each batch.
"""

@staticmethod
def discover_targets(
Expand Down Expand Up @@ -82,6 +87,12 @@ def __post_init__(self):
# Allow subclasses to perform custom initialization
self.setup()

def set_valid_masks(self, masks: Tensor) -> None:
"""
Set the valid_masks for the current batch.
"""
self.valid_masks = masks

def __enter__(self):
"""Register forward and backward hooks on all target modules."""
for name in self.target_info:
@@ -214,8 +225,8 @@ def forward_hook(self, name: str, a: Tensor) -> None:
"""Compute activation covariance: A^T @ A."""
A_cov_ki = self.A_cov_dict[name]

# Reshape to [N*S, I]
a_bi = a.reshape(-1, a.shape[-1])
# a: [N, S, I], valid_masks: [N, S] -> select valid positions
a_bi = a[self.valid_masks] # [num_valid, I]

# Compute local covariance
local_update_ii = a_bi.mT @ a_bi
@@ -290,18 +301,20 @@ class LambdaCollector(HookCollectorBase):

def setup(self) -> None:
"""Load eigenvectors and initialize storage."""
device = get_device(self.rank)

# Load precomputed eigenvectors
self.eigen_a = load_file(
os.path.join(
self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors"
),
device=f"cuda:{self.rank}",
device=device,
)
self.eigen_g = load_file(
os.path.join(
self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors"
),
device=f"cuda:{self.rank}",
device=device,
)

# Initialize accumulators
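The `forward_hook` change is the substantive one here: the old flat reshape let padding positions contribute to the activation covariance, while boolean indexing with `valid_masks` drops them before the `A^T A` accumulation. A self-contained comparison on toy shapes (all values invented; only the indexing pattern comes from the diff):

```python
import torch

N, S, I = 2, 4, 3
torch.manual_seed(0)
a = torch.randn(N, S, I)  # activations [N, S, I]
valid_masks = torch.tensor([[True, False, False, False],
                            [True, True,  True,  False]])

# Old: every position, padding included, enters the covariance.
a_all = a.reshape(-1, I)              # [N*S, I]
cov_with_padding = a_all.mT @ a_all   # [I, I]

# New: boolean indexing keeps only valid positions.
a_valid = a[valid_masks]              # [num_valid, I] -> [4, 3]
cov_valid_only = a_valid.mT @ a_valid

print(a_valid.shape)                                     # torch.Size([4, 3])
print(torch.allclose(cov_with_padding, cov_valid_only))  # False
```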
31 changes: 11 additions & 20 deletions bergson/hessians/data_filtering_ekfac.ipynb
@@ -13,24 +13,17 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Literal\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from datasets import load_dataset\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from bergson.data import load_gradients\n",
"from safetensors.torch import load_file\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.stats import spearmanr\n",
"\n",
"import json"
"from bergson.data import load_gradients"
]
},
{
@@ -373,8 +366,6 @@
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.stats import spearmanr\n",
"\n",
"# Calculate Spearman correlation\n",
"mask = ~(np.isnan(attributions_scores) | np.isnan(attributions_ekfac_scores))\n",
@@ -461,9 +452,9 @@
")\n",
"\n",
"\n",
"plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n",
"plt.xlabel(\"Number of elements removed\")\n",
"plt.ylabel('Number of elements in the \"correct\" half')\n",
"plt.title(\"EK-FAC, no attn, on train set\")\n",
@@ -638,7 +629,6 @@
"# load the saved attributions\n",
"import json\n",
"\n",
"\n",
"all_attributions = {}\n",
"\n",
"for path in all_query_paths:\n",
@@ -849,7 +839,7 @@
],
"source": [
"# plot intersection\n",
"plt.plot(np.array(top_percentages) * len(index), intersection_12, label=f\"Intersection\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersection_12, label=\"Intersection\")\n",
"plt.plot(\n",
" [0, len(index) // 2, len(index)],\n",
" [0, len(index) // 2, len(index)],\n",
@@ -1218,9 +1208,9 @@
")\n",
"\n",
"\n",
"plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n",
"plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n",
"plt.xlabel(\"Number of elements removed\")\n",
"plt.ylabel('Number of elements in the \"correct\" half')\n",
"plt.legend()\n",
@@ -1374,9 +1364,10 @@
}
],
"source": [
"import torch\n",
"import os\n",
"\n",
"import torch\n",
"\n",
"# Set the debug flag - this is the correct way\n",
"os.environ[\"TORCH_COMPILE_DEBUG\"] = \"1\"\n",
"\n",
23 changes: 13 additions & 10 deletions bergson/hessians/ekfac_compute.py
@@ -32,6 +32,7 @@
)
from bergson.hessians.logger import get_logger
from bergson.hessians.sharded_computation import ShardedMul
from bergson.utils import get_device


class EkfacComputer:
@@ -92,6 +93,7 @@ def __init__(
def compute_covariance(self):
cov_collector = CovarianceCollector(
self.model.base_model,
target_modules=self.target_modules,
dtype=self.dtype,
shard_computer=self.shard_computer,
rank=self.rank,
@@ -106,7 +108,7 @@
"""This is Eq. 18 from above reference."""
total_processed = torch.load(
os.path.join(self.path, "total_processed_covariances.pt"),
map_location=f"cuda:{self.rank}",
map_location=torch.device(get_device(self.rank)),
)

random.seed(0)
@@ -224,13 +226,14 @@ def _collector(self, collector, desc: Optional[str] = None):
self.batches, disable=self.rank != 0, desc=f"Computing {desc}"
):
batch = self.data[sl]
x, y = pad_and_tensor(
x, y, valid_masks = pad_and_tensor(
batch["input_ids"], # type: ignore
labels=batch.get("labels"), # type: ignore
device=self.model.device,
)

total_processed += x.numel()
total_processed += valid_masks.sum()
collector.set_valid_masks(valid_masks)

with (
collector,
@@ -256,19 +259,19 @@
probs,
num_samples=1,
).flatten()

del probs

flat_mask = valid_masks[:, :-1].flatten()
losses = F.cross_entropy(
logits,
sampled_labels,
logits[flat_mask],
sampled_labels[flat_mask],
reduction="none",
).reshape_as(y[:, 1:])
)

losses = losses.sum(1)
losses.mean().backward()
losses.sum().backward()
self.model.zero_grad()
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()

if self.cfg.profile:
assert isinstance(prof, profile), "Profiler is not set up correctly"
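Two coupled changes in the sampled-label loss: invalid positions are now filtered out with `flat_mask` before the cross-entropy rather than masked afterwards, and `losses.sum().backward()` replaces the per-sequence mean, so each valid token contributes equally regardless of sequence length. A toy check of the masking step (shapes and values invented; the `flat_mask` construction follows the diff):

```python
import torch
import torch.nn.functional as F

N, S, V = 2, 5, 11
torch.manual_seed(0)
logits = torch.randn(N * (S - 1), V)               # flattened next-token logits
sampled_labels = torch.randint(V, (N * (S - 1),))  # stand-in for sampled labels
valid_masks = torch.tensor([[True, True, False, False, False],
                            [True, True, True,  True,  False]])

# Drop the last column (no next token to predict), then flatten to
# align with the [N*(S-1), V] logits.
flat_mask = valid_masks[:, :-1].flatten()

# New formulation: only valid positions reach the loss at all.
loss = F.cross_entropy(
    logits[flat_mask], sampled_labels[flat_mask], reduction="none"
).sum()

# Identical to computing per-token losses first and masking afterwards.
per_token = F.cross_entropy(logits, sampled_labels, reduction="none")
assert torch.allclose(loss, per_token[flat_mask].sum())
```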
21 changes: 7 additions & 14 deletions bergson/hessians/misaligned_datasets.ipynb
@@ -13,23 +13,16 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Literal\n",
"import joblib\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from datasets import load_dataset\n",
"from tqdm.notebook import tqdm\n",
"import json\n",
"from datasets import Dataset\n",
"from bergson.data import load_gradients\n",
"from safetensors.torch import load_file\n",
"from sklearn.metrics import roc_auc_score\n",
"import numpy as np\n",
"\n",
"from sklearn.metrics import roc_auc_score, precision_recall_curve, auc\n"
"from datasets import Dataset, load_dataset\n",
"from sklearn.metrics import auc, precision_recall_curve, roc_auc_score\n",
"\n"
]
},
{
@@ -439,14 +432,14 @@
"print(f\"PR AUC: {pr_auc:.4f}\")\n",
"\n",
"# Additional metrics for analysis\n",
"print(f\"\\nDataset composition:\")\n",
"print(\"\\nDataset composition:\")\n",
"print(f\"Correct examples: {len(sorted_correct_scores)}\")\n",
"print(f\"Incorrect examples: {len(sorted_incorrect_scores)}\")\n",
"print(f\"Subtle incorrect examples: {len(sorted_subtle_scores)}\")\n",
"print(f\"Total examples: {len(all_scores)}\")\n",
"print(f\"Problematic ratio: {(len(sorted_incorrect_scores) + len(sorted_subtle_scores)) / len(all_scores):.3f}\")\n",
"\n",
"print(f\"\\nScore statistics:\")\n",
"print(\"\\nScore statistics:\")\n",
"print(f\"Correct scores - Mean: {sorted_correct_scores.mean():.4f}, Std: {sorted_correct_scores.std():.4f}\")\n",
"print(f\"Incorrect scores - Mean: {sorted_incorrect_scores.mean():.4f}, Std: {sorted_incorrect_scores.std():.4f}\")\n",
"print(f\"Subtle scores - Mean: {sorted_subtle_scores.mean():.4f}, Std: {sorted_subtle_scores.std():.4f}\")\n"