diff --git a/bergson/collector/collector.py b/bergson/collector/collector.py index c94818b6..1601e4ca 100644 --- a/bergson/collector/collector.py +++ b/bergson/collector/collector.py @@ -621,7 +621,7 @@ def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict): if index_cfg.loss_reduction == "mean" else 1.0 ) - if not hessian_cfg.use_dataset_labels: + if hessian_cfg.use_dataset_labels: losses = F.cross_entropy( logits.reshape(-1, logits.size(-1)), y[:, 1:].flatten(), diff --git a/bergson/hessians/kfac.py b/bergson/hessians/kfac.py index ac4aeeee..1ac5187a 100644 --- a/bergson/hessians/kfac.py +++ b/bergson/hessians/kfac.py @@ -69,9 +69,10 @@ def backward_hook(self, module: nn.Module, g: Tensor) -> None: """Compute gradient covariance: G^T @ G.""" name = assert_type(str, module._name) S_cov_po = self.S_cov_dict[name] + mask = self._current_valid_mask - # Reshape to [N*S, O] - g_bo = g.reshape(-1, g.shape[-1]) + # g: [N, S, O], mask: [N, S] -> select valid positions + g_bo = g[mask] # [num_valid, O] # Compute local covariance local_update_oo = g_bo.mT @ g_bo diff --git a/tests/ekfac_tests/compute_ekfac_ground_truth.py b/tests/ekfac_tests/compute_ekfac_ground_truth.py new file mode 100644 index 00000000..6f663002 --- /dev/null +++ b/tests/ekfac_tests/compute_ekfac_ground_truth.py @@ -0,0 +1,787 @@ +# %% +# %load_ext autoreload +# %autoreload 2 + +# %% +"""Compute EKFAC ground truth for testing. + +This script computes ground truth covariance matrices, eigenvectors, and eigenvalue +corrections for EKFAC on a single GPU without sharding. By specifying the number of +workers we can simulate distributed computation. +""" + +import argparse +import builtins +import gc +import json +import os +import sys +from dataclasses import asdict +from typing import TYPE_CHECKING, Any, Optional + +import torch +import torch.nn.functional as F +from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset +from ground_truth.collector import ( + GroundTruthAmortizedLambdaCollector, + GroundTruthCovarianceCollector, +) +from safetensors.torch import load_file, save_file +from test_utils import add_tensor_dicts, set_all_seeds, tensor_dict_to_device +from torch import Tensor +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + PreTrainedModel, +) + +from bergson.config import DataConfig, IndexConfig +from bergson.data import allocate_batches, pad_and_tensor, tokenize +from bergson.hessians.kfac import CovarianceCollector +from bergson.utils.utils import assert_type, get_device + +Precision = str # Type alias for precision strings + +Batches = list[list[list[int]]] + +# %% [markdown] +# ## 0. 
Hyperparameters + + +# %% +def parse_config() -> tuple[Precision, str, str, int, bool]: + """Parse command-line arguments or return defaults.""" + parser = argparse.ArgumentParser( + description="Compute EKFAC ground truth for testing" + ) + parser.add_argument( + "--precision", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "int4", "int8"], + help="Model precision (default: fp32)", + ) + parser.add_argument( + "-o", + "--output-dir", + type=str, + default=os.path.join( + os.getcwd(), "test_files", "pile_100_examples", "ground_truth" + ), + help="Output directory for ground truth results " + "(default: test_files/pile_100_examples/ground_truth)", + ) + parser.add_argument( + "--model-name", + type=str, + default="EleutherAI/Pythia-14m", + help="Model name to use (default: EleutherAI/Pythia-14m)", + ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of workers for simulated distributed computation (default: 1)", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing ground truth data and config", + ) + + # For interactive mode (Jupyter/IPython) or no args, use defaults + if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"): + args = parser.parse_args() + else: + args = parser.parse_args([]) + + # Set random seeds for reproducibility + set_all_seeds(42) + + return ( + args.precision, + args.output_dir, + args.model_name, + args.world_size, + args.overwrite, + ) + + +if __name__ == "__main__" or TYPE_CHECKING: + precision, test_path, model_name, world_size_arg, overwrite_arg = parse_config() + + +# %% +def setup_paths_and_config( + precision: Precision, + test_path: str, + model_name: str, + world_size: int, + overwrite: bool = False, +) -> tuple[IndexConfig, int, torch.device, Any, torch.dtype]: + """Setup paths and configuration object.""" + os.makedirs(test_path, exist_ok=True) + + current_path = os.getcwd() + parent_path = os.path.join(current_path, "test_files", "pile_100_examples") + + # Configuration + cfg = IndexConfig(run_path="", loss_reduction="sum") + cfg.model = model_name + cfg.precision = precision # type: ignore[assignment] + cfg.fsdp = False + cfg.data = DataConfig(dataset=os.path.join(parent_path, "data"), truncation=True) + # Set token_batch_size to a value that fits in GPU memory (8GB GPU) + cfg.token_batch_size = 2048 + + # model_max_length is limited in some models like `roneneldan/TinyStories-1M` + tokenizer = AutoTokenizer.from_pretrained(cfg.model) + if ( + hasattr(tokenizer, "model_max_length") + and tokenizer.model_max_length < cfg.token_batch_size + ): + print( + f"Warning: Got --token-batch-size {cfg.token_batch_size} but " + f"{model_name} only supports up to {tokenizer.model_max_length}" + ) + cfg.token_batch_size = tokenizer.model_max_length + + data_str = cfg.data.dataset + + # Create pile-100 dataset if it doesn't exist + if not os.path.exists(data_str): + full_dataset = load_dataset("NeelNanda/pile-10k", split="train") + assert isinstance(full_dataset, Dataset), "Expected Dataset, got something else" + subset = full_dataset.select(range(100)) + os.makedirs(os.path.dirname(data_str), exist_ok=True) + subset.save_to_disk(data_str) + print(f"Generated pile-100 in {data_str}") + + config_path = os.path.join(test_path, "index_config.json") + if os.path.exists(config_path): + if not overwrite: + # Load existing config and compare + with open(config_path, "r") as f: + existing_cfg_dict = json.load(f) + + new_cfg_dict = asdict(cfg) + + if 
existing_cfg_dict != new_cfg_dict: + # Show differences for debugging + diffs = [ + f" {k}: {existing_cfg_dict[k]} != {new_cfg_dict[k]}" + for k in new_cfg_dict + if k in existing_cfg_dict + and existing_cfg_dict[k] != new_cfg_dict[k] + ] + raise RuntimeError( + f"Existing config at {config_path} differs from requested config:\n" + + "\n".join(diffs) + + "\n\nUse --overwrite to replace the existing config." + ) + + print(f"Using existing config from {config_path}") + else: + print(f"Overwriting existing config at {config_path}") + with open(config_path, "w") as f: + json.dump(asdict(cfg), f, indent=4) + else: + # Save new config + with open(config_path, "w") as f: + json.dump(asdict(cfg), f, indent=4) + + # Setup + workers = world_size + device = torch.device(get_device(0)) + target_modules = None + + # Determine dtype + match cfg.precision: + case "bf16": + dtype = torch.bfloat16 + case "fp16": + dtype = torch.float16 + case "fp32": + dtype = torch.float32 + case "int4" | "int8": + dtype = ( + torch.bfloat16 + if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) + else torch.float16 + ) + case other: + raise ValueError(f"Unsupported precision: {other}") + + return cfg, workers, device, target_modules, dtype + + +if __name__ == "__main__" or TYPE_CHECKING: + cfg, workers, device, target_modules, dtype = setup_paths_and_config( + precision, test_path, model_name, world_size_arg, overwrite_arg + ) + + +# %% [markdown] +# ## 1. Loading model and data + + +# %% +def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel: + """Load the model.""" + print(f"Loading model {cfg.model}...") + model = AutoModelForCausalLM.from_pretrained( + cfg.model, + device_map="cuda" if torch.cuda.is_available() else "cpu", + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=cfg.precision == "int4", + load_in_8bit=cfg.precision == "int8", + bnb_4bit_compute_dtype=dtype, + bnb_4bit_quant_storage=dtype, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + if cfg.precision in ("int4", "int8") + else None + ), + torch_dtype=dtype, + ) + return model + + +if __name__ == "__main__" or TYPE_CHECKING: + model = load_model_step(cfg, dtype) + + +# %% +def load_dataset_step(cfg: IndexConfig) -> Dataset: + """Load and return the dataset.""" + data_str = cfg.data.dataset + print(f"Loading dataset from {data_str}...") + + if data_str.endswith(".csv"): + ds = assert_type(Dataset, Dataset.from_csv(data_str)) + elif data_str.endswith(".json") or data_str.endswith(".jsonl"): + ds = assert_type(Dataset, Dataset.from_json(data_str)) + else: + try: + ds = load_dataset(data_str, split="train") + if isinstance(ds, (DatasetDict, IterableDatasetDict)): + raise NotImplementedError( + "DatasetDicts and IterableDatasetDicts are not supported." 
+                )
+        except ValueError as e:
+            if "load_from_disk" in str(e):
+                ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
+            else:
+                raise e
+
+    assert isinstance(ds, Dataset)
+    return ds
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    ds = load_dataset_step(cfg)
+
+
+# %%
+def tokenize_and_allocate_step(
+    ds: Dataset, cfg: IndexConfig, workers: int
+) -> tuple[Dataset, Batches, Any]:
+    """Tokenize dataset and allocate batches."""
+    tokenizer = AutoTokenizer.from_pretrained(
+        cfg.model, model_max_length=cfg.token_batch_size
+    )
+    ds = ds.map(
+        tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer)
+    )
+    data = ds
+
+    # Allocate batches with the same allocate_batches as EKFAC so that
+    # floating-point accumulation happens in the same order.
+    batches = allocate_batches(doc_lengths=ds["length"], N=cfg.token_batch_size)
+    # Round-robin split across simulated workers; reduces to [batches] when
+    # workers == 1 and keeps the per-rank batch order deterministic.
+    batches_world = [batches[rank::workers] for rank in range(workers)]
+    assert len(batches_world) == workers
+
+    return data, batches_world, tokenizer
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers)
+
+
+# %% [markdown]
+# ## 2. Compute activation and gradient covariance
+
+
+# %%
+def compute_covariance(
+    rank: int,
+    model: PreTrainedModel,
+    data: Dataset,
+    batches_world: Batches,
+    device: torch.device,
+    target_modules: Any,
+    activation_covariances: dict[str, Tensor],
+    gradient_covariances: dict[str, Tensor],
+    ekfac_collector: Optional[CovarianceCollector] = None,
+) -> dict[str, Any]:
+    """Compute activation and gradient covariances for a single worker.
+
+    If ekfac_collector is provided, it will be run simultaneously with the ground
+    truth collector during the same forward/backward passes. This ensures both
+    collectors see exactly the same gradients.
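+
+    A minimal usage sketch (names as defined in this module; the empty dicts
+    are filled in place by the collector hooks):
+
+        activation_covariances: dict[str, Tensor] = {}
+        gradient_covariances: dict[str, Tensor] = {}
+        stats = compute_covariance(
+            rank=0,
+            model=model,
+            data=data,
+            batches_world=batches_world,
+            device=device,
+            target_modules=None,
+            activation_covariances=activation_covariances,
+            gradient_covariances=gradient_covariances,
+        )
+        print(stats["total_processed_rank"], "tokens processed")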
+ """ + total_processed = torch.tensor(0, device=device) + batches = batches_world[rank] + loss_list = [] + + collector = GroundTruthCovarianceCollector( + model=model.base_model, + activation_covariances=activation_covariances, + gradient_covariances=gradient_covariances, + target_modules=target_modules, + ) + + for sl in tqdm(batches, desc=f"Rank {rank} covariances"): + batch = data[sl] + x, y, valid_masks = pad_and_tensor( + batch["input_ids"], + labels=batch.get("labels"), + device=device, + ) + + total_processed += valid_masks.sum() + + # Run both collectors simultaneously during the same forward/backward pass + # This ensures they see exactly the same gradients + with collector.with_batch(valid_masks): + if ekfac_collector is not None: + ekfac_ctx = ekfac_collector.with_batch(valid_masks) + else: + ekfac_ctx = None + + if ekfac_ctx is not None: + ekfac_ctx.__enter__() + try: + # Use same loss computation as EKFAC (fwd_bwd_hessian_factory) + logits = model(x).logits[:, :-1] + losses = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), + y[:, 1:].flatten(), + reduction="none", + ).reshape_as(y[:, 1:]) + # Sum over sequence first, then over batch + # (like EKFAC with loss_reduction="sum") + losses = losses.sum(1) + losses.sum().backward() + loss_list.append(losses.detach().cpu()) + model.zero_grad() + finally: + if ekfac_ctx is not None: + ekfac_ctx.__exit__(None, None, None) + + return {"losses": loss_list, "total_processed_rank": total_processed.item()} + + +# %% +def compute_covariances_step( + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + workers: int, + test_path: str, + ekfac_path: Optional[str] = None, + dtype: torch.dtype = torch.float32, +) -> str: + """Compute covariances for all ranks and save to disk. + + If ekfac_path is provided, also runs the EKFAC CovarianceCollector simultaneously + during the same forward/backward passes. This ensures both collectors see exactly + the same gradients, enabling precise numerical comparison. 
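+
+    On-disk layout written by this step (paths relative to test_path):
+
+        covariances/rank_<r>/activation_covariance.safetensors
+        covariances/rank_<r>/gradient_covariance.safetensors
+        covariances/rank_<r>/stats.json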
+ """ + covariance_test_path = os.path.join(test_path, "covariances") + + # Create EKFAC collector if path is provided + ekfac_collector = None + if ekfac_path is not None: + os.makedirs(ekfac_path, exist_ok=True) + ekfac_collector = CovarianceCollector( + model=model.base_model, + dtype=dtype, + path=ekfac_path, + target_modules=target_modules, + ) + + total_processed_global = 0 + for rank in range(workers): + covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}") + os.makedirs(covariance_test_path_rank, exist_ok=True) + + activation_covariances = {} + gradient_covariances = {} + d = compute_covariance( + rank=rank, + model=model, + data=data, + batches_world=batches_world, + device=device, + target_modules=target_modules, + activation_covariances=activation_covariances, + gradient_covariances=gradient_covariances, + ekfac_collector=ekfac_collector, + ) + + total_processed_global += d["total_processed_rank"] + + save_file( + activation_covariances, + os.path.join( + covariance_test_path_rank, "activation_covariance.safetensors" + ), + ) + save_file( + gradient_covariances, + os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"), + ) + with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f: + json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4) + print(f"Rank {rank} processed {d['total_processed_rank']} tokens.") + + # Finalize EKFAC collector and save total processed + if ekfac_collector is not None and ekfac_path is not None: + ekfac_collector.teardown() + torch.save( + torch.tensor(total_processed_global, device=device), + os.path.join(ekfac_path, "total_processed.pt"), + ) + + return covariance_test_path + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Covariances ===") + covariance_test_path = compute_covariances_step( + model, data, batches_world, device, target_modules, workers, test_path + ) + + +# %% +def combine_covariances_step( + covariance_test_path: str, workers: int, device: torch.device +) -> int: + """Combine covariance results from all ranks.""" + activation_covariances: dict[str, Tensor] = {} + gradient_covariances: dict[str, Tensor] = {} + total_processed_global = 0 + + for rank in range(workers): + covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}") + + with open(os.path.join(covariance_test_path_rank, "stats.json"), "r") as f: + d = json.load(f) + total_processed_global += d["total_processed_rank"] + + activation_covariances_rank = tensor_dict_to_device( + load_file( + os.path.join( + covariance_test_path_rank, "activation_covariance.safetensors" + ) + ), + device, + ) + + gradient_covariances_rank = tensor_dict_to_device( + load_file( + os.path.join( + covariance_test_path_rank, "gradient_covariance.safetensors" + ) + ), + device, + ) + + if not activation_covariances: + activation_covariances = activation_covariances_rank + else: + activation_covariances = add_tensor_dicts( + activation_covariances, activation_covariances_rank + ) + + if not gradient_covariances: + gradient_covariances = gradient_covariances_rank + else: + gradient_covariances = add_tensor_dicts( + gradient_covariances, gradient_covariances_rank + ) + + save_file( + activation_covariances, + os.path.join(covariance_test_path, "activation_covariance.safetensors"), + ) + save_file( + gradient_covariances, + os.path.join(covariance_test_path, "gradient_covariance.safetensors"), + ) + with open(os.path.join(covariance_test_path, "stats.json"), "w") as f: + 
json.dump({"total_processed_global": total_processed_global}, f, indent=4) + print(f"Global processed {total_processed_global} tokens.") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return total_processed_global + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Combining Covariances ===") + total_processed_global = combine_covariances_step( + covariance_test_path, workers, device + ) + + +# %% [markdown] +# ## 3. Compute eigenvalues and eigenvectors + + +# %% +def compute_eigenvectors_step( + test_path: str, device: torch.device, dtype: torch.dtype +) -> str: + """Compute eigenvectors from covariances.""" + covariance_test_path = os.path.join(test_path, "covariances") + eigenvectors_test_path = os.path.join(test_path, "eigenvectors") + os.makedirs(eigenvectors_test_path, exist_ok=True) + + # Load covariances + with open(os.path.join(covariance_test_path, "stats.json"), "r") as f: + d = json.load(f) + total_processed_global = d["total_processed_global"] + + activation_covariances = load_file( + os.path.join(covariance_test_path, "activation_covariance.safetensors") + ) + gradient_covariances = load_file( + os.path.join(covariance_test_path, "gradient_covariance.safetensors") + ) + + eigenvectors_activations = {} + eigenvectors_gradients = {} + + for name in activation_covariances.keys(): + a = activation_covariances[name].to(dtype=torch.float64, device=device) + g = gradient_covariances[name].to(dtype=torch.float64, device=device) + a = (a + a.T).div(2) + g = (g + g.T).div(2) + a.div_(total_processed_global) + g.div_(total_processed_global) + + eigenvalues_a, eigenvectors_a = torch.linalg.eigh(a) + eigenvalues_g, eigenvectors_g = torch.linalg.eigh(g) + print( + f"{name}: eigenvectors_a.sum()={eigenvectors_a.sum()}, " + f"eigenvectors_g.sum()={eigenvectors_g.sum()}" + ) + eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous() + eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous() + + save_file( + eigenvectors_activations, + os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"), + ) + save_file( + eigenvectors_gradients, + os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"), + ) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return eigenvectors_test_path + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Eigenvectors ===") + eigenvectors_test_path = compute_eigenvectors_step(test_path, device, dtype) + + +# %% [markdown] +# ## 4. 
Compute eigenvalue correction + + +# %% +def compute_eigenvalue_correction_amortized( + rank: int, + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + eigenvalue_corrections: dict[str, Tensor], + eigenvectors_activations: dict[str, Tensor], + eigenvectors_gradients: dict[str, Tensor], +) -> dict[str, Any]: + """Compute eigenvalue corrections using the amortized method.""" + total_processed = torch.tensor(0, device=device) + batches = batches_world[rank] + + collector = GroundTruthAmortizedLambdaCollector( + model=model.base_model, + eigenvalue_corrections=eigenvalue_corrections, + eigenvectors_activations=eigenvectors_activations, + eigenvectors_gradients=eigenvectors_gradients, + device=device, + target_modules=target_modules, + ) + + for sl in tqdm(batches, desc=f"Rank {rank} eigenvalue corrections"): + batch = data[sl] + x, y, valid_masks = pad_and_tensor( + batch["input_ids"], + labels=batch.get("labels"), + device=device, + ) + + total_processed += valid_masks.sum() + + with collector.with_batch(valid_masks): + logits = model(x).logits + losses = F.cross_entropy( + logits[:, :-1].reshape(-1, logits.size(-1)), + y[:, 1:].flatten(), + reduction="none", + ).reshape_as(y[:, 1:]) + + losses.sum().backward() + model.zero_grad() + + return {"total_processed_rank": total_processed.item()} + + +# %% +def compute_eigenvalue_corrections_step( + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + workers: int, + test_path: str, +) -> tuple[str, int]: + """Compute eigenvalue corrections for all ranks.""" + eigenvectors_test_path = os.path.join(test_path, "eigenvectors") + eigenvalue_correction_test_path = os.path.join(test_path, "eigenvalue_corrections") + os.makedirs(eigenvalue_correction_test_path, exist_ok=True) + + # Load eigenvectors + eigenvectors_activations = load_file( + os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors") + ) + eigenvectors_gradients = load_file( + os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors") + ) + + total_processed_global = 0 + for rank in range(workers): + eigenvalue_correction_test_path_rank = os.path.join( + eigenvalue_correction_test_path, f"rank_{rank}" + ) + os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True) + + eigenvalue_corrections = {} + d = compute_eigenvalue_correction_amortized( + rank=rank, + model=model, + data=data, + batches_world=batches_world, + device=device, + target_modules=target_modules, + eigenvalue_corrections=eigenvalue_corrections, + eigenvectors_activations=eigenvectors_activations, + eigenvectors_gradients=eigenvectors_gradients, + ) + + save_file( + eigenvalue_corrections, + os.path.join( + eigenvalue_correction_test_path_rank, + "eigenvalue_corrections.safetensors", + ), + ) + with open( + os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w" + ) as f: + json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4) + print(f"Rank {rank} processed {d['total_processed_rank']} tokens.") + total_processed_global += d["total_processed_rank"] + + return eigenvalue_correction_test_path, total_processed_global + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Eigenvalue Corrections ===") + eigenvalue_correction_test_path, total_processed_global_lambda = ( + compute_eigenvalue_corrections_step( + model, data, batches_world, device, target_modules, workers, test_path + ) + ) + + +# %% +def 
combine_eigenvalue_corrections_step( + eigenvalue_correction_test_path: str, + workers: int, + device: torch.device, + total_processed_global: int, +) -> dict[str, Tensor]: + """Combine eigenvalue correction results from all ranks.""" + eigenvalue_corrections: dict[str, Tensor] = {} + + for rank in range(workers): + eigenvalue_correction_test_path_rank = os.path.join( + eigenvalue_correction_test_path, f"rank_{rank}" + ) + + eigenvalue_corrections_rank = tensor_dict_to_device( + load_file( + os.path.join( + eigenvalue_correction_test_path_rank, + "eigenvalue_corrections.safetensors", + ) + ), + device, + ) + + if not eigenvalue_corrections: + eigenvalue_corrections = eigenvalue_corrections_rank + else: + eigenvalue_corrections = add_tensor_dicts( + eigenvalue_corrections, eigenvalue_corrections_rank + ) + + # Divide by total_processed_global + eigenvalue_corrections = { + k: v / total_processed_global for k, v in eigenvalue_corrections.items() + } + save_file( + eigenvalue_corrections, + os.path.join( + eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors" + ), + ) + + return eigenvalue_corrections + + +if __name__ == "__main__" or TYPE_CHECKING: + eigenvalue_corrections = combine_eigenvalue_corrections_step( + eigenvalue_correction_test_path, workers, device, total_processed_global_lambda + ) + print("\n=== Ground Truth Computation Complete ===") + print(f"Results saved to: {test_path}") diff --git a/tests/ekfac_tests/conftest.py b/tests/ekfac_tests/conftest.py new file mode 100644 index 00000000..f374ea15 --- /dev/null +++ b/tests/ekfac_tests/conftest.py @@ -0,0 +1,348 @@ +"""Pytest configuration and fixtures for EKFAC tests.""" + +import os +from typing import Any, Optional + +import pytest +from compute_ekfac_ground_truth import ( + combine_covariances_step, + combine_eigenvalue_corrections_step, + compute_covariances_step, + compute_eigenvalue_corrections_step, + compute_eigenvectors_step, + load_dataset_step, + load_model_step, + setup_paths_and_config, + tokenize_and_allocate_step, +) +from test_utils import set_all_seeds + +Precision = str # Type alias for precision strings + + +def pytest_addoption(parser) -> None: + """Add custom command-line options for EKFAC tests.""" + parser.addoption( + "--model_name", + action="store", + type=str, + default="EleutherAI/Pythia-14m", + help="Model name for ground truth generation (default: EleutherAI/Pythia-14m)", + ) + parser.addoption( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing run directory", + ) + parser.addoption( + "--precision", + action="store", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "int4", "int8"], + help="Model precision for ground truth generation (default: fp32)", + ) + parser.addoption( + "--test_dir", + action="store", + default=None, + help="Directory containing test data. 
If not provided, generates data.", + ) + parser.addoption( + "--world_size", + action="store", + type=int, + default=1, + help="World size for distributed training (default: 1)", + ) + + +@pytest.fixture(autouse=True) +def setup_test() -> None: + """Setup logic run before each test.""" + set_all_seeds(seed=42) + + +@pytest.fixture(scope="session") +def gradient_batch_size(request) -> int: + return request.config.getoption("--gradient_batch_size") + + +@pytest.fixture(scope="session") +def gradient_path(request) -> Optional[str]: + return request.config.getoption("--gradient_path") + + +@pytest.fixture(scope="session") +def model_name(request) -> str: + return request.config.getoption("--model_name") + + +@pytest.fixture(scope="session") +def overwrite(request) -> bool: + return request.config.getoption("--overwrite") + + +@pytest.fixture(scope="session") +def precision(request) -> Precision: + return request.config.getoption("--precision") + + +@pytest.fixture(scope="session") +def use_fsdp(request) -> bool: + return request.config.getoption("--use_fsdp") + + +@pytest.fixture(scope="session") +def world_size(request) -> int: + return request.config.getoption("--world_size") + + +@pytest.fixture(scope="session") +def test_dir(request, tmp_path_factory) -> str: + """Get or create test directory (does not generate ground truth data).""" + # Check if test directory was provided + test_dir = request.config.getoption("--test_dir") + if test_dir is not None: + return test_dir + + # Create temporary directory for auto-generated test data + tmp_dir = tmp_path_factory.mktemp("ekfac_test_data") + return str(tmp_dir) + + +def ground_truth_base_path(test_dir: str) -> str: + return os.path.join(test_dir, "ground_truth") + + +@pytest.fixture(scope="session") +def ground_truth_setup( + request, test_dir: str, precision: Precision, overwrite: bool +) -> dict[str, Any]: + # Setup for generation + model_name = request.config.getoption("--model_name") + world_size = request.config.getoption("--world_size") + + print(f"\n{'='*60}") + print("Generating ground truth test data") + print(f"Model: {model_name}") + print(f"Precision: {precision}") + print(f"World size: {world_size}") + print(f"{'='*60}\n") + + cfg, workers, device, target_modules, dtype = setup_paths_and_config( + precision=precision, + test_path=ground_truth_base_path(test_dir), + model_name=model_name, + world_size=world_size, + overwrite=overwrite, + ) + + model = load_model_step(cfg, dtype) + model.eval() # Disable dropout for deterministic forward passes + ds = load_dataset_step(cfg) + data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers) + + return { + "cfg": cfg, + "workers": workers, + "device": device, + "target_modules": target_modules, + "dtype": dtype, + "model": model, + "data": data, + "batches_world": batches_world, + } + + +@pytest.fixture(scope="session") +def ground_truth_covariances_path( + ground_truth_setup: dict[str, Any], test_dir: str, overwrite: bool +) -> str: + """Ensure ground truth covariances exist and return path.""" + base_path = ground_truth_base_path(test_dir) + covariances_path = os.path.join(base_path, "covariances") + + if os.path.exists(covariances_path) and not overwrite: + print("Using existing covariances") + return covariances_path + + setup = ground_truth_setup + # Reset seeds for deterministic computation (same seed as EKFAC will use) + set_all_seeds(42) + covariance_test_path = compute_covariances_step( + setup["model"], + setup["data"], + setup["batches_world"], + setup["device"], 
+ setup["target_modules"], + setup["workers"], + base_path, + ) + combine_covariances_step(covariance_test_path, setup["workers"], setup["device"]) + print("Covariances computed") + return covariances_path + + +@pytest.fixture(scope="session") +def ground_truth_eigenvectors_path( + ground_truth_covariances_path: str, + ground_truth_setup: dict[str, Any], + test_dir: str, + overwrite: bool, +) -> str: + """Ensure ground truth eigenvectors exist and return path.""" + base_path = ground_truth_base_path(test_dir) + eigenvectors_path = os.path.join(base_path, "eigenvectors") + + if os.path.exists(eigenvectors_path) and not overwrite: + print("Using existing eigenvectors") + return eigenvectors_path + + setup = ground_truth_setup + compute_eigenvectors_step(base_path, setup["device"], setup["dtype"]) + print("Eigenvectors computed") + return eigenvectors_path + + +@pytest.fixture(scope="session") +def ground_truth_eigenvalue_corrections_path( + ground_truth_eigenvectors_path: str, + ground_truth_setup: dict[str, Any], + test_dir: str, + overwrite: bool, +) -> str: + """Ensure ground truth eigenvalue corrections exist and return path.""" + base_path = ground_truth_base_path(test_dir) + eigenvalue_corrections_path = os.path.join(base_path, "eigenvalue_corrections") + + if os.path.exists(eigenvalue_corrections_path) and not overwrite: + print("Using existing eigenvalue corrections") + return eigenvalue_corrections_path + + setup = ground_truth_setup + eigenvalue_correction_test_path, total_processed_global_lambda = ( + compute_eigenvalue_corrections_step( + setup["model"], + setup["data"], + setup["batches_world"], + setup["device"], + setup["target_modules"], + setup["workers"], + base_path, + ) + ) + combine_eigenvalue_corrections_step( + eigenvalue_correction_test_path, + setup["workers"], + setup["device"], + total_processed_global_lambda, + ) + print("Eigenvalue corrections computed") + print("\n=== Ground Truth Computation Complete ===") + print(f"Results saved to: {base_path}") + return eigenvalue_corrections_path + + +@pytest.fixture(scope="session") +def ground_truth_path( + ground_truth_eigenvalue_corrections_path: str, test_dir: str +) -> str: + """Get ground truth base path with all data guaranteed to exist. + + Depends on ground_truth_eigenvalue_corrections_path to ensure all + ground truth data exists. + """ + return ground_truth_base_path(test_dir) + + +@pytest.fixture(scope="session") +def ekfac_results_path( + test_dir: str, + ground_truth_path: str, + ground_truth_setup: dict[str, Any], + overwrite: bool, +) -> str: + """Run EKFAC computation and return results path. + + Uses the same data and batches as ground truth via collect_hessians to ensure + identical batch composition and floating-point accumulation order. 
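+
+    Expected on-disk layout afterwards (illustrative; shard counts vary):
+
+        <test_dir>/run/kfac.part/activation_sharded/shard_*.safetensors
+        <test_dir>/run/kfac.part/gradient_sharded/shard_*.safetensors
+        <test_dir>/run/kfac.part/total_processed.pt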
+ """ + import torch + + from bergson.config import HessianConfig + from bergson.hessians.eigenvectors import compute_eigendecomposition + from bergson.hessians.hessian_approximations import collect_hessians + + # collect_hessians writes to partial_run_path (run_path + ".part") + # We set run_path so partial_run_path points to our desired output location + base_run_path = os.path.join(test_dir, "run/kfac") + results_path = base_run_path + ".part" # Where collect_hessians will write + + if os.path.exists(results_path) and not overwrite: + print(f"Using existing EKFAC results in {results_path}") + return results_path + + setup = ground_truth_setup + cfg = setup["cfg"] + data = setup["data"] + batches = setup["batches_world"][0] # Single worker + target_modules = setup["target_modules"] + dtype = setup["dtype"] + + print(f"\nRunning EKFAC computation in {results_path}...") + + # Reset seeds for determinism (same as used before GT computation) + set_all_seeds(42) + + # Reload model to get fresh state (same as GT does) + model = load_model_step(cfg, dtype) + model.eval() + + cfg.run_path = base_run_path + cfg.partial_run_path.mkdir(parents=True, exist_ok=True) + + hessian_cfg = HessianConfig( + method="kfac", ev_correction=True, use_dataset_labels=True + ) + + # Phase 1: Covariance collection using collect_hessians + collect_hessians( + model=model, + data=data, + index_cfg=cfg, + batches=batches, + target_modules=target_modules, + hessian_cfg=hessian_cfg, + ) + + total_processed = torch.load( + os.path.join(results_path, "total_processed.pt"), + map_location="cpu", + weights_only=False, + ) + + # Phase 2: Eigendecomposition + compute_eigendecomposition( + os.path.join(results_path, "activation_sharded"), + total_processed=total_processed, + ) + compute_eigendecomposition( + os.path.join(results_path, "gradient_sharded"), + total_processed=total_processed, + ) + + # Phase 3: Eigenvalue correction + collect_hessians( + model=model, + data=data, + index_cfg=cfg, + batches=batches, + target_modules=target_modules, + hessian_cfg=hessian_cfg, + ev_correction=True, + ) + + print(f"EKFAC computation completed in {results_path}") + return results_path diff --git a/tests/ekfac_tests/ground_truth/__init__.py b/tests/ekfac_tests/ground_truth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ekfac_tests/ground_truth/collector.py b/tests/ekfac_tests/ground_truth/collector.py new file mode 100644 index 00000000..8cd11be9 --- /dev/null +++ b/tests/ekfac_tests/ground_truth/collector.py @@ -0,0 +1,146 @@ +"""Ground truth collector for EKFAC testing.""" + +from collections.abc import Mapping, MutableMapping +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch import Tensor + +from bergson.collector.collector import HookCollectorBase +from bergson.utils.utils import assert_type + + +@dataclass(kw_only=True) +class GroundTruthCovarianceCollector(HookCollectorBase): + activation_covariances: MutableMapping[str, Tensor] + gradient_covariances: MutableMapping[str, Tensor] + + def setup(self) -> None: + pass + + def teardown(self) -> None: + pass + + def forward_hook(self, module: nn.Module, a: Tensor) -> None: + name = assert_type(str, module._name) + mask = self._current_valid_mask + + # a: [N, S, I], valid_masks: [N, S] -> select valid positions + if mask is not None: + a = a[mask] # [num_valid, I] + else: + a = a.reshape(-1, a.shape[-1]) # [N*S, I] + + update = a.mT @ a + + if name not in self.activation_covariances: + self.activation_covariances[name] = 
update + else: + self.activation_covariances[name].add_(update) + + def backward_hook(self, module: nn.Module, g: Tensor) -> None: + name = assert_type(str, module._name) + mask = self._current_valid_mask + + # g: [N, S, O], valid_masks: [N, S] -> select valid positions + if mask is not None: + g = g[mask] # [num_valid, O] + else: + g = g.reshape(-1, g.shape[-1]) # [N*S, O] + + update = g.mT @ g + + if name not in self.gradient_covariances: + self.gradient_covariances[name] = update + else: + self.gradient_covariances[name].add_(update) + + def process_batch(self, indices: list[int], **kwargs) -> None: + pass + + +@dataclass(kw_only=True) +class GroundTruthNonAmortizedLambdaCollector(HookCollectorBase): + eigenvalue_corrections: MutableMapping[str, Tensor] + eigenvectors_activations: Mapping[str, Tensor] + eigenvectors_gradients: Mapping[str, Tensor] + device: torch.device + + def setup(self) -> None: + self.activation_cache: dict[str, Tensor] = {} + + def teardown(self) -> None: + self.activation_cache.clear() + + def forward_hook(self, module: nn.Module, a: Tensor) -> None: + name = assert_type(str, module._name) + self.activation_cache[name] = a + + def backward_hook(self, module: nn.Module, g: Tensor) -> None: + name = assert_type(str, module._name) + eigenvector_a = self.eigenvectors_activations[name].to(device=self.device) + eigenvector_g = self.eigenvectors_gradients[name].to(device=self.device) + + activation = self.activation_cache[name] # [N, S, I] + gradient = g # [N, S, O] + + gradient = torch.einsum("N S O, N S I -> N S O I", gradient, activation) + + gradient = torch.einsum("N S O I, I J -> N S O J", gradient, eigenvector_a) + gradient = torch.einsum("O P, N S O J -> N S P J", eigenvector_g, gradient) + + gradient = gradient.sum(dim=1) # sum over sequence length + + gradient = gradient**2 + correction = gradient.sum(dim=0) + + if name not in self.eigenvalue_corrections: + self.eigenvalue_corrections[name] = correction + else: + self.eigenvalue_corrections[name].add_(correction) + + def process_batch(self, indices: list[int], **kwargs) -> None: + pass + + +@dataclass(kw_only=True) +class GroundTruthAmortizedLambdaCollector(HookCollectorBase): + eigenvalue_corrections: MutableMapping[str, Tensor] + eigenvectors_activations: Mapping[str, Tensor] + eigenvectors_gradients: Mapping[str, Tensor] + device: torch.device + + def setup(self) -> None: + self.activation_cache: dict[str, Tensor] = {} + + def teardown(self) -> None: + self.activation_cache.clear() + + def forward_hook(self, module: nn.Module, a: Tensor) -> None: + name = assert_type(str, module._name) + self.activation_cache[name] = a + + def backward_hook(self, module: nn.Module, g: Tensor) -> None: + name = assert_type(str, module._name) + eigenvector_a = self.eigenvectors_activations[name].to(device=self.device) + eigenvector_g = self.eigenvectors_gradients[name].to(device=self.device) + + activation = self.activation_cache[name] # [N, S, I] + + transformed_a = torch.einsum("N S I, I J -> N S J", activation, eigenvector_a) + transformed_g = torch.einsum("O P, N S O -> N S P", eigenvector_g, g) + + correction = ( + (torch.einsum("N S O, N S I -> N O I", transformed_g, transformed_a) ** 2) + .sum(dim=0) + .contiguous() + ) + + if name not in self.eigenvalue_corrections: + self.eigenvalue_corrections[name] = correction + else: + self.eigenvalue_corrections[name].add_(correction) + + def process_batch(self, indices: list[int], **kwargs) -> None: + pass diff --git a/tests/ekfac_tests/run_test_compute_ekfac.sh 
b/tests/ekfac_tests/run_test_compute_ekfac.sh new file mode 100755 index 00000000..358f571c --- /dev/null +++ b/tests/ekfac_tests/run_test_compute_ekfac.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Run EKFAC computation tests + +cd "$(dirname "$0")" + +pytest -s -v \ + --test_dir "./test_files/pile_100_examples" \ + --world_size 8 \ + --overwrite \ + test_compute_ekfac.py \ + test_covariance.py \ + test_eigenvectors.py \ + test_eigenvalue_correction.py diff --git a/tests/ekfac_tests/test_batch_size_invariance.py b/tests/ekfac_tests/test_batch_size_invariance.py new file mode 100644 index 00000000..394051fd --- /dev/null +++ b/tests/ekfac_tests/test_batch_size_invariance.py @@ -0,0 +1,105 @@ +"""Test that covariance traces are batch-size invariant after normalization.""" + +import tempfile +from pathlib import Path + +import pytest +import torch + +from bergson.collector.collector import CollectorComputer, fwd_bwd_hessian_factory +from bergson.config import HessianConfig, IndexConfig +from bergson.hessians.kfac import CovarianceCollector +from bergson.utils.utils import get_device +from tests.ekfac_tests.test_utils import load_sharded_covariances, set_all_seeds +from tests.ekfac_tests.toy_model import ( + ToyDataConfig, + ToyLM, + ToyLMConfig, + generate_batches, + generate_dataset, +) + + +@pytest.mark.parametrize( + "seq_lengths, num_batches", + [ + ((16,), 20), # Single sequence length + ((16, 8), 20), # Mixed sequence lengths + ((64,), 10), # Longer sequences + ], +) +def test_trace_batch_invariant(seq_lengths, num_batches, tmp_path): + """Normalized covariance traces should be the same regardless of batch size.""" + set_all_seeds(42) + + config = ToyDataConfig( + vocab_size=16, + hidden_size=8, + seq_lengths=seq_lengths, + num_batches=num_batches, + ) + device = torch.device(get_device()) + + dataset = generate_dataset(config) + batches = generate_batches(config) + + model_config = ToyLMConfig( + vocab_size=config.vocab_size, hidden_size=config.hidden_size + ) + model = ToyLM( + model_config, + training_data=dataset, + training_batches=batches, + device=device, + ) + + # Flatten all indices from batches + indices = [idx for batch in batches for idx in batch] + + # B=1 vs B=2 batches + batches_b1 = [[i] for i in indices] + batches_b2 = [indices[i : i + 2] for i in range(0, len(indices), 2)] + + def compute_traces(batches: list[list[int]]) -> tuple[float, float]: + with tempfile.TemporaryDirectory() as tmpdir: + run_path = Path(tmpdir) / "run" + index_cfg = IndexConfig(run_path=str(run_path), loss_reduction="sum") + + collector = CovarianceCollector( + model=model.base_model, + target_modules={"linear"}, + dtype=torch.float32, + path=str(index_cfg.partial_run_path), + ) + + hessian_cfg = HessianConfig() + computer = CollectorComputer( + model=model, + data=dataset, + batches=batches, + collector=collector, + cfg=index_cfg, + ) + computer.forward_backward = fwd_bwd_hessian_factory(index_cfg, hessian_cfg) + computer.run_with_collector_hooks() + + # Load covariances + A = load_sharded_covariances( + index_cfg.partial_run_path / "activation_sharded" + ) + G = load_sharded_covariances( + index_cfg.partial_run_path / "gradient_sharded" + ) + n = torch.load(index_cfg.partial_run_path / "total_processed.pt").item() + + return ( + sum(v.trace().item() / n for v in A.values()), + sum(v.trace().item() / n for v in G.values()), + ) + + model.eval() + A1, G1 = compute_traces(batches_b1) + A2, G2 = compute_traces(batches_b2) + + torch.testing.assert_close(A1, A2, rtol=1e-2, atol=1e-4) + 
torch.testing.assert_close(G1, G2, rtol=1e-2, atol=1e-4) diff --git a/tests/ekfac_tests/test_compute_ekfac.py b/tests/ekfac_tests/test_compute_ekfac.py new file mode 100644 index 00000000..7e6af684 --- /dev/null +++ b/tests/ekfac_tests/test_compute_ekfac.py @@ -0,0 +1,28 @@ +"""Test EKFAC computation against ground truth.""" + +import json +import os + +import torch + + +def test_total_processed_examples( + ground_truth_covariances_path: str, ekfac_results_path: str +) -> None: + """Test that total processed examples match between ground truth and computed.""" + total_processed_ground_truth_path = os.path.join( + ground_truth_covariances_path, "stats.json" + ) + total_processed_run_path = os.path.join(ekfac_results_path, "total_processed.pt") + + with open(total_processed_ground_truth_path, "r") as f: + ground_truth_data = json.load(f) + total_processed_ground_truth = ground_truth_data["total_processed_global"] + + total_processed_run = torch.load(total_processed_run_path, weights_only=True).item() + + assert total_processed_ground_truth == total_processed_run, ( + f"Total processed examples do not match! " + f"Ground truth: {total_processed_ground_truth}, Run: {total_processed_run}" + ) + print(f"✓ Total processed examples match: {total_processed_ground_truth}") diff --git a/tests/ekfac_tests/test_covariance.py b/tests/ekfac_tests/test_covariance.py new file mode 100644 index 00000000..8aa0bed7 --- /dev/null +++ b/tests/ekfac_tests/test_covariance.py @@ -0,0 +1,56 @@ +import os + +import pytest +import torch +from safetensors.torch import load_file + +from tests.ekfac_tests.test_utils import load_sharded_covariances + + +@pytest.mark.parametrize("covariance_type", ["activation", "gradient"]) +def test_covariances( + ekfac_results_path: str, + ground_truth_covariances_path: str, + covariance_type: str, +) -> None: + """Test covariances against ground truth.""" + print(f"\nTesting {covariance_type} covariances...") + + covariances_ground_truth_path = os.path.join( + ground_truth_covariances_path, f"{covariance_type}_covariance.safetensors" + ) + covariances_run_path = os.path.join( + ekfac_results_path, f"{covariance_type}_sharded" + ) + + ground_truth_covariances = load_file(covariances_ground_truth_path) + run_covariances = load_sharded_covariances(covariances_run_path) + + rtol = 1e-10 + atol = 0 + all_match = True + error_details = [] + + for k in ground_truth_covariances: + gt = ground_truth_covariances[k] + run = run_covariances[k] + + if not torch.allclose(gt, run, rtol=rtol, atol=atol): + all_match = False + diff = (gt - run).abs() + rel_diff = diff / (gt.abs() + 1e-10) + error_details.append( + f" {k}: max_rel_diff={100 * rel_diff.max():.3f}%, " + f"mean={100 * rel_diff.mean():.3f}%" + ) + + if all_match: + print(f"{covariance_type} covariances match within tolerance (rtol={rtol})") + else: + error_msg = ( + f"{covariance_type} covariances do not match (rtol={rtol})!\n" + + "\n".join(error_details) + ) + assert False, error_msg + + print("-*" * 50) diff --git a/tests/ekfac_tests/test_eigenvalue_correction.py b/tests/ekfac_tests/test_eigenvalue_correction.py new file mode 100644 index 00000000..42e08693 --- /dev/null +++ b/tests/ekfac_tests/test_eigenvalue_correction.py @@ -0,0 +1,85 @@ +import os + +import pytest +import torch +from safetensors.torch import load_file + +from tests.ekfac_tests.test_utils import load_sharded_covariances + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Numerical precision differences on CPU vs GPU", +) +def 
test_eigenvalue_corrections( + ground_truth_eigenvalue_corrections_path: str, + ekfac_results_path: str, +) -> None: + """Test eigenvalue corrections against ground truth.""" + print("\nTesting eigenvalue corrections...") + + lambda_ground_truth_path = os.path.join( + ground_truth_eigenvalue_corrections_path, "eigenvalue_corrections.safetensors" + ) + lambda_run_path = os.path.join(ekfac_results_path, "eigenvalue_correction_sharded") + + # load ground_truth + lambda_ground_truth = load_file(lambda_ground_truth_path) + + # load run eigenvalue corrections (sharded) + lambda_run = load_sharded_covariances(lambda_run_path) + + total_processed_run_path = os.path.join(ekfac_results_path, "total_processed.pt") + lambda_device = lambda_run[list(lambda_run.keys())[0]].device + total = torch.load(total_processed_run_path, map_location=lambda_device) + + # Normalize by total + lambda_run = {k: v / total for k, v in lambda_run.items()} + + # Use reasonable tolerance for numerical differences between implementations + # due to float precision, accumulation order, and eigenvector differences + # query_key_value layers can have up to ~10% differences due to eigenvector issues + rtol = 0.12 # 12% relative tolerance + atol = 1e-4 + all_match = True + error_details = [] + has_significant_errors = False + + for k in lambda_ground_truth: + gt = lambda_ground_truth[k] + run = lambda_run[k] + + if not torch.allclose(gt, run, rtol=rtol, atol=atol): + all_match = False + diff = (gt - run).abs() + rel_diff = diff / (gt.abs() + 1e-10) + max_rel_diff = rel_diff.max() + + # Find location of max difference + coord = diff.argmax() + a, b = coord // gt.shape[1], coord % gt.shape[1] + + if max_rel_diff < 0.05: # 5% threshold for reporting + error_details.append( + f" {k}: small differences within tolerance " + f"(max_rel_diff={(100 * max_rel_diff):.3f}%)" + ) + else: + has_significant_errors = True + error_details.append( + f" {k}: max_rel_diff={(100 * max_rel_diff):.3f}%, " + f"mean={(100 * rel_diff.mean()):.3f}%" + ) + error_details.append( + f" at [{a},{b}]: gt={gt[a, b]:.3e}, run={run[a, b]:.3e}" + ) + + if all_match: + print(f"Eigenvalue corrections match within tolerance (rtol={rtol})") + elif has_significant_errors: + error_msg = f"Eigenvalue corrections do not match (rtol={rtol})!\n" + "\n".join( + error_details + ) + assert False, error_msg + else: + print("Eigenvalue corrections: all differences within tolerance") diff --git a/tests/ekfac_tests/test_eigenvectors.py b/tests/ekfac_tests/test_eigenvectors.py new file mode 100644 index 00000000..6587abeb --- /dev/null +++ b/tests/ekfac_tests/test_eigenvectors.py @@ -0,0 +1,63 @@ +import os + +import pytest +import torch +from safetensors.torch import load_file + +from tests.ekfac_tests.test_utils import load_sharded_covariances + + +@pytest.mark.parametrize("eigenvector_type", ["activation", "gradient"]) +def test_eigenvectors( + ekfac_results_path: str, + ground_truth_eigenvectors_path: str, + eigenvector_type: str, +) -> None: + """Test eigenvectors against ground truth.""" + print(f"\nTesting {eigenvector_type} eigenvectors...") + + eigenvectors_ground_truth_path = os.path.join( + ground_truth_eigenvectors_path, f"eigenvectors_{eigenvector_type}s.safetensors" + ) + eigenvectors_run_path = os.path.join( + ekfac_results_path, f"eigen_{eigenvector_type}_sharded" + ) + + # load ground_truth + ground_truth_eigenvectors = load_file(eigenvectors_ground_truth_path) + + # load run eigenvectors (sharded) and concatenate + run_eigenvectors = 
load_sharded_covariances(eigenvectors_run_path) + + rtol = 1e-5 + atol = 1e-7 + all_match = True + error_details = [] + + for k in ground_truth_eigenvectors: + gt = ground_truth_eigenvectors[k] + run = run_eigenvectors[k] + + if not torch.allclose(gt, run, rtol=rtol, atol=atol): + all_match = False + diff = (gt - run).abs() + max_diff_val = diff.max() + + # Find location of max difference + max_diff_flat_idx = torch.argmax(diff) + max_diff_idx = torch.unravel_index(max_diff_flat_idx, diff.shape) + relative_diff = 100 * max_diff_val / (gt[max_diff_idx].abs() + 1e-10) + + error_details.append( + f" {k}: abs_diff={max_diff_val:.2e}, rel_diff={relative_diff:.2e}%" + ) + + if all_match: + print(f"{eigenvector_type} eigenvectors match (rtol={rtol}, atol={atol})") + else: + error_msg = f"{eigenvector_type} eigenvectors do not match!\n" + "\n".join( + error_details + ) + assert False, error_msg + + print("-*" * 50) diff --git a/tests/ekfac_tests/test_fim_accuracy.py b/tests/ekfac_tests/test_fim_accuracy.py new file mode 100644 index 00000000..9be78aeb --- /dev/null +++ b/tests/ekfac_tests/test_fim_accuracy.py @@ -0,0 +1,196 @@ +""" +Test EKFAC accuracy for computing the Fisher Information Matrix. + +Compares the K-FAC approximation F_kfac = G ⊗ A against the exact FIM +computed from per-position gradients on a toy language model. +""" + +from pathlib import Path + +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor + +from bergson.collector.collector import CollectorComputer, fwd_bwd_hessian_factory +from bergson.config import HessianConfig, IndexConfig +from bergson.hessians.kfac import CovarianceCollector +from bergson.utils.utils import get_device +from tests.ekfac_tests.test_utils import load_sharded_covariances, set_all_seeds +from tests.ekfac_tests.toy_model import ( + ToyDataConfig, + ToyLM, + ToyLMConfig, + generate_batches, + generate_dataset, +) + + +def compute_exact_fim( + model: ToyLM, + dataset, + batches: list[list[int]], + device: torch.device, + sample: bool, +) -> tuple[Tensor, Tensor, Tensor, int]: + """ + Compute exact FIM from per-position gradients for ToyLM. + + Args: + sample: If True, sample labels from model distribution (true FIM). + If False, use dataset labels (empirical FIM). 
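+
+    Note (flattening convention, matching the code below): for a linear layer
+    with weight W of shape [O, I], the per-position gradient is
+    dL/dW = outer(g, a), and row-major flattening maps entry (o, i) to index
+    o * I + i. Hence F_exact[o*I + i, o'*I + i'] = E[g_o a_i g_o' a_i'],
+    which K-FAC approximates by G[o, o'] * A[i, i'] = kron(G, A).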
+ + Returns: + F_exact: Exact FIM from per-position gradients + A: Activation covariance (normalized) + G: Gradient covariance (normalized) + n_positions: Total number of valid positions + """ + hidden_size = model.config.hidden_size + vocab_size = model.config.vocab_size + + position_grads = [] + A_sum = torch.zeros(hidden_size, hidden_size, device=device) + G_sum = torch.zeros(vocab_size, vocab_size, device=device) + + for batch_indices in batches: + for idx in batch_indices: + input_ids = torch.tensor( + dataset[idx]["input_ids"], device=device + ).unsqueeze(0) + labels = torch.tensor(dataset[idx]["labels"], device=device) + + hidden = model.model.embed(input_ids) + hidden.requires_grad_(True) + logits = model.model.linear(hidden) + + for s in range(input_ids.shape[1] - 1): + if sample: + # Sample from model distribution (true FIM) + with torch.no_grad(): + probs = torch.softmax(logits[0, s].detach(), dim=-1) + target = torch.multinomial(probs, num_samples=1).squeeze() + else: + # Use dataset labels (empirical FIM) + target = labels[s + 1] + + loss = F.cross_entropy(logits[0, s], target) + + (g,) = torch.autograd.grad(loss, logits, retain_graph=True) + g = g[0, s] + a = hidden[0, s].detach() + + position_grads.append(torch.outer(g, a).flatten().detach()) + A_sum += torch.outer(a, a) + G_sum += torch.outer(g.detach(), g.detach()) + + n_positions = len(position_grads) + grads_tensor = torch.stack(position_grads) + F_exact = grads_tensor.T @ grads_tensor / n_positions + + A = A_sum / n_positions + A = (A + A.T) / 2 + G = G_sum / n_positions + G = (G + G.T) / 2 + + return F_exact, A, G, n_positions + + +@pytest.mark.parametrize( + "seq_lengths, num_batches, sample, max_rel_error", + [ + ((512,), 100, False, 0.05), + ((512,), 100, True, 0.05), + ((4,), 10000, False, 0.05), # rel_error = ~0.25 without valid_masks logic + ((4,), 10000, True, 0.10), # rel_error = ~0.25 without valid_masks logic + ((512, 2), 100, False, 0.05), # rel_error = ~0.6 without valid_masks logic + ((512, 2), 100, True, 0.20), # rel_error = ~1.2 without valid_masks logic + ], +) +def test_kfac_fim_accuracy(seq_lengths, num_batches, max_rel_error, sample, tmp_path): + """ + Test that KFAC approximates the FIM within tolerance. + + Args: + sample: If True, test true FIM (sampled labels). + If False, test empirical FIM (dataset labels). 
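+
+    The comparison metric is the relative Frobenius error
+    ||kron(G_kfac, A_kfac) - F_exact||_F / ||F_exact||_F, asserted to be at
+    most max_rel_error.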
+ """ + set_all_seeds(42) + + config = ToyDataConfig( + vocab_size=8, + hidden_size=4, + seq_lengths=seq_lengths, + num_batches=num_batches, + ) + device = torch.device(get_device()) + + dataset = generate_dataset(config) + batches = generate_batches(config) + + model_config = ToyLMConfig( + vocab_size=config.vocab_size, hidden_size=config.hidden_size + ) + model = ToyLM( + model_config, + training_data=dataset, + training_batches=batches, + device=device, + ) + + F_exact, A_exact, G_exact, total_processed_exact = compute_exact_fim( + model, dataset, batches, device, sample=sample + ) + + run_path = Path(tmp_path) / "run" + index_cfg = IndexConfig(run_path=str(run_path), loss_reduction="sum") + + collector = CovarianceCollector( + model=model.base_model, + target_modules={"linear"}, + dtype=torch.float32, + path=str(index_cfg.partial_run_path), + ) + + hessian_cfg = HessianConfig(use_dataset_labels=not sample) + + computer = CollectorComputer( + model=model, + data=dataset, + batches=batches, + collector=collector, + cfg=index_cfg, + ) + computer.forward_backward = fwd_bwd_hessian_factory(index_cfg, hessian_cfg) + computer.run_with_collector_hooks() + + A_dict_kfac = load_sharded_covariances( + index_cfg.partial_run_path / "activation_sharded" + ) + G_dict_kfac = load_sharded_covariances( + index_cfg.partial_run_path / "gradient_sharded" + ) + total_processed_kfac = torch.load( + index_cfg.partial_run_path / "total_processed.pt" + ).item() + + assert total_processed_kfac == total_processed_exact + + A_kfac = list(A_dict_kfac.values())[0].float().to(device) / total_processed_kfac + A_kfac = (A_kfac + A_kfac.T) / 2 + G_kfac = list(G_dict_kfac.values())[0].float().to(device) / total_processed_kfac + G_kfac = (G_kfac + G_kfac.T) / 2 + + # A and G should be the same when we're not sampling + if not sample: + torch.testing.assert_close(A_kfac, A_exact, rtol=1e-3, atol=1e-6) + torch.testing.assert_close(G_kfac, G_exact, rtol=1e-3, atol=1e-6) + + F_kfac = torch.kron(G_kfac, A_kfac) + rel_error = (torch.norm(F_kfac - F_exact) / torch.norm(F_exact)).item() + + assert rel_error <= max_rel_error, ( + f"KFAC rel_error {rel_error:.4f} greater than tolerated max_rel_error " + f"{max_rel_error} for seq_lengths={seq_lengths}, num_batches={num_batches}, " + f"sample={sample}" + ) diff --git a/tests/ekfac_tests/test_utils.py b/tests/ekfac_tests/test_utils.py new file mode 100644 index 00000000..b649c01f --- /dev/null +++ b/tests/ekfac_tests/test_utils.py @@ -0,0 +1,89 @@ +"""Common utilities for EKFAC tests.""" + +import os +import random +from pathlib import Path + +import numpy as np +import torch +from safetensors.torch import load_file +from torch import Tensor + + +def add_tensor_dicts(a: dict[str, Tensor], b: dict[str, Tensor]) -> dict[str, Tensor]: + """Add two dictionaries of tensors element-wise.""" + assert set(a.keys()) == set(b.keys()), "Keys must match" + return {k: a[k] + b[k] for k in a} + + +def tensor_dict_to_device( + d: dict[str, Tensor], device: str | torch.device +) -> dict[str, Tensor]: + """Move all tensors in a dictionary to the specified device.""" + return {k: v.to(device) for k, v in d.items()} + + +def load_sharded_covariances(sharded_dir: str | Path) -> dict[str, torch.Tensor]: + """Load and concatenate sharded covariance files. + + Args: + sharded_dir: Directory containing shard_0.safetensors, shard_1.safetensors, etc. + + Returns: + Dictionary mapping layer names to concatenated covariance tensors. 
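+
+    Example (path illustrative):
+
+        covs = load_sharded_covariances("run/kfac.part/activation_sharded")
+        shapes = {name: t.shape for name, t in covs.items()}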
+ """ + sharded_dir = Path(sharded_dir) + shard_files = sorted(sharded_dir.glob("shard_*.safetensors")) + + if not shard_files: + raise FileNotFoundError(f"No shard files found in {sharded_dir}") + + shards = [load_file(str(f)) for f in shard_files] + + # Concatenate shards along first dimension + result = {} + for key in shards[0]: + result[key] = torch.cat([shard[key] for shard in shards], dim=0) + + return result + + +def load_covariances( + run_path: str | Path, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], int]: + """Load activation and gradient covariances from an EKFAC run. + + Args: + run_path: Path to the run directory containing influence_results/. + + Returns: + Tuple of (activation_covariances, gradient_covariances, total_processed). + """ + run_path = Path(run_path) + results_path = run_path / "influence_results" + + A_cov = load_sharded_covariances(results_path / "activation_sharded") + G_cov = load_sharded_covariances(results_path / "gradient_sharded") + total_processed = torch.load(results_path / "total_processed.pt").item() + + return A_cov, G_cov, total_processed + + +def set_all_seeds(seed: int = 42) -> None: + """Set all random seeds for reproducibility.""" + # Set all random seeds + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + # Force deterministic behavior + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + # Set environment variables for additional determinism + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" diff --git a/tests/ekfac_tests/toy_model.py b/tests/ekfac_tests/toy_model.py new file mode 100644 index 00000000..34b4b66a --- /dev/null +++ b/tests/ekfac_tests/toy_model.py @@ -0,0 +1,155 @@ +""" +Toy language model for EKFAC testing. + +Provides a minimal transformers-compatible model and dataset generation +utilities for testing EkfacComputer without loading real models. 
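+
+Typical usage (mirroring the tests in this directory):
+
+    config = ToyDataConfig(vocab_size=8, hidden_size=4, seq_lengths=(4,))
+    dataset = generate_dataset(config)
+    batches = generate_batches(config)
+    model = ToyLM(
+        ToyLMConfig(vocab_size=8, hidden_size=4),
+        training_data=dataset,
+        training_batches=batches,
+        device=torch.device("cpu"),
+    )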
+""" + +from dataclasses import dataclass + +import torch +import torch.nn as nn +from datasets import Dataset +from torch import Tensor +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutput + + +class ToyLMConfig(PretrainedConfig): + """Configuration for ToyLM - a minimal language model for testing.""" + + model_type = "toy_lm" + + def __init__( + self, + vocab_size: int = 8, + hidden_size: int = 4, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + super().__init__(**kwargs) + + +class ToyLMModule(nn.Module): + """The base model (what hooks attach to).""" + + def __init__(self, config: ToyLMConfig): + super().__init__() + self.embed = nn.Embedding(config.vocab_size, config.hidden_size) + self.linear = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, input_ids: Tensor) -> Tensor: + hidden = self.embed(input_ids) # [B, S] -> [B, S, H] + return self.linear(hidden) # [B, S, H] -> [B, S, V] + + +class ToyLM(PreTrainedModel): + """Toy language model compatible with EkfacComputer.""" + + config_class = ToyLMConfig + base_model_prefix = "model" + + def __init__( + self, + config: ToyLMConfig, + *, + training_data=None, + training_batches: list[list[int]] | None = None, + device: torch.device | None = None, + num_steps: int = 5000, + ): + super().__init__(config) + self.model = ToyLMModule(config) + + if training_data is not None and training_batches is not None: + self._train(training_data, training_batches, device, num_steps) + + def _train( + self, + dataset, + batches: list[list[int]], + device: torch.device | None, + num_steps: int, + lr: float = 0.1, + ) -> None: + """Train the model to make logits more peaked (like a real LLM).""" + import torch.nn.functional as F + + if device is not None: + nn.Module.to(self, device) + + optimizer = torch.optim.SGD(self.parameters(), lr=lr) + + step = 0 + while step < num_steps: + for batch_indices in batches: + for idx in batch_indices: + input_ids = torch.tensor( + dataset[idx]["input_ids"], device=device + ).unsqueeze(0) + labels = torch.tensor( + dataset[idx]["labels"], device=device + ).unsqueeze(0) + + logits = self(input_ids).logits + loss = F.cross_entropy( + logits[:, :-1].reshape(-1, logits.size(-1)), + labels[:, 1:].reshape(-1), + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + step += 1 + if step >= num_steps: + return + + @property + def base_model(self) -> nn.Module: + return self.model + + def forward(self, input_ids: Tensor, **kwargs) -> CausalLMOutput: + logits = self.model(input_ids) + return CausalLMOutput(logits=logits) + + +@dataclass +class ToyDataConfig: + """Configuration for toy data generation.""" + + vocab_size: int = 8 + hidden_size: int = 4 + seq_lengths: tuple[int, ...] 
= (2,) + num_batches: int = 2000 + + @property + def max_seq_len(self) -> int: + return max(self.seq_lengths) + + @property + def batch_size(self) -> int: + return len(self.seq_lengths) + + +def generate_dataset(config: ToyDataConfig) -> Dataset: + """Generate a HuggingFace Dataset for use with EkfacComputer.""" + data = {"input_ids": [], "labels": []} + + for _ in range(config.num_batches): + for seq_len in config.seq_lengths: + input_ids = torch.randint(0, config.vocab_size, (seq_len,)).tolist() + data["input_ids"].append(input_ids) + data["labels"].append(input_ids) + + return Dataset.from_dict(data) + + +def generate_batches(config: ToyDataConfig) -> list[list[int]]: + """Generate batch indices for EkfacComputer.""" + batch_size = len(config.seq_lengths) + return [ + list(range(i * batch_size, (i + 1) * batch_size)) + for i in range(config.num_batches) + ]
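+
+
+# Worked example (illustrative): ToyDataConfig(seq_lengths=(16, 8), num_batches=2)
+# yields a dataset with row lengths [16, 8, 16, 8] and
+# generate_batches(...) == [[0, 1], [2, 3]], so every batch pairs one long
+# and one short sequence.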