diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2fdfe3eb..7a82df57 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,20 +19,12 @@ jobs: run: | python -m pip install --upgrade pip pip install -e ".[dev,faiss]" - # TODO: Proper test infrastructure for tests/ekfac_tests - # - name: Run tests - # run: pytest - # TODO: run pyright on whole codebase - - name: Type Checking bergson/hessians + - name: Run EKFAC tests that work without GPU + run: pytest -sv tests/ekfac_tests + - name: Type Checking uses: jakebailey/pyright-action@v1 with: version: 1.1.406 - working-directory: bergson/hessians - - name: Type Checking tests/ekfac_tests - uses: jakebailey/pyright-action@v1 - with: - version: 1.1.406 - working-directory: tests/ekfac_tests - name: build run: pip wheel --no-deps -w dist . env: diff --git a/bergson/collection.py b/bergson/collection.py index aaa4235c..f5619ee4 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -85,7 +85,7 @@ def callback(name: str, g: torch.Tensor): for indices in tqdm(batches, disable=rank != 0, desc="Building index"): batch = data[indices] - x, y = pad_and_tensor( + x, y, valid_masks = pad_and_tensor( batch["input_ids"], # type: ignore labels=batch.get("labels"), # type: ignore device=model.device, @@ -100,8 +100,7 @@ def callback(name: str, g: torch.Tensor): reduction="none", ).reshape_as(y[:, 1:]) - masks = y[:, 1:] != -100 - denoms = masks.sum(dim=1, dtype=logits.dtype) + denoms = valid_masks.sum(dim=1, dtype=logits.dtype) losses = losses.sum(1).div(denoms) losses.mean().backward() @@ -208,7 +207,7 @@ def adam_update(name: str, g: torch.Tensor): closure=callback, target_modules=target_modules, ): - x, y = pad_and_tensor( + x, y, _ = pad_and_tensor( batch["input_ids"], # type: ignore labels=batch.get("labels", None), # type: ignore device=model.device, diff --git a/bergson/data.py b/bergson/data.py index f41f9980..c03d4c0a 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -351,12 +351,14 @@ def pad_and_tensor( padding_value: int = 0, dtype: torch.dtype | None = torch.long, device: torch.device | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Pad a list of sequences to the same length and convert them to tensors. - Returns a tuple of padded sequences and labels. The labels are the same as the - sequences, but with -100 for the padding positions, which is useful for ignoring - padding in loss calculations. 
+ + Returns: + padded_tokens: Padded input sequences [N, S] + padded_labels: Labels with -100 for padding positions [N, S] + valid_masks: Boolean mask [N, S] where True indicates valid positions """ if labels is None: labels = sequences @@ -370,7 +372,13 @@ def pad_and_tensor( # convert to tensor padded_tokens = torch.tensor(padded, dtype=dtype, device=device) padded_labels = torch.tensor(labels, dtype=dtype, device=device) - return padded_tokens, padded_labels + + # Compute valid_masks: position i is valid if labels[i+1] != -100 + N, S = padded_tokens.shape + valid_masks = torch.zeros(N, S, dtype=torch.bool, device=device) + valid_masks[:, :-1] = padded_labels[:, 1:] != -100 + + return padded_tokens, padded_labels, valid_masks def tokenize(batch: dict, *, args: DataConfig, tokenizer): diff --git a/bergson/distributed.py b/bergson/distributed.py index 60fb9783..9873211a 100644 --- a/bergson/distributed.py +++ b/bergson/distributed.py @@ -94,7 +94,10 @@ def setup_model_and_peft( torch.cuda.manual_seed_all(42) # Common configuration - device_map = {"": f"cuda:{rank}"} if not cfg.fsdp else "cpu" + if cfg.fsdp or not torch.cuda.is_available(): + device_map = "cpu" + else: + device_map = {"": f"cuda:{rank}"} quantization_config = None if cfg.precision in ("int4", "int8"): quantization_config = BitsAndBytesConfig( @@ -202,7 +205,8 @@ def worker_wrapper( setup_processor: bool = True, ): try: - torch.cuda.set_device(rank) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) if cfg.debug: setup_reproducibility() print("DEBUG MODE IS ENABLED: quasi-deterministic training") diff --git a/bergson/hessians/collector.py b/bergson/hessians/collector.py index da75f2f3..4c10e14d 100644 --- a/bergson/hessians/collector.py +++ b/bergson/hessians/collector.py @@ -11,7 +11,7 @@ from torch.utils.hooks import RemovableHandle from bergson.hessians.sharded_computation import ShardedMul -from bergson.utils import assert_type +from bergson.utils import assert_type, get_device @dataclass @@ -37,6 +37,11 @@ class HookCollectorBase(ContextDecorator, ABC): Set of module names to attach hooks to. Should consist only of nn.Linear modules. If None, hooks are attached to all Linear layers in the model. """ + valid_masks: Tensor = None # type: ignore[assignment] + """ + Mask of shape [N, S] indicating which positions are valid. + Must be set via set_valid_masks() before each batch. + """ @staticmethod def discover_targets( @@ -82,6 +87,12 @@ def __post_init__(self): # Allow subclasses to perform custom initialization self.setup() + def set_valid_masks(self, masks: Tensor) -> None: + """ + Set the valid_masks for the current batch. 
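+        The forward/backward hooks use it to select only the valid (non-padding)
+        positions, e.g. ``a[self.valid_masks]`` in CovarianceCollector.forward_hook.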
+ """ + self.valid_masks = masks + def __enter__(self): """Register forward and backward hooks on all target modules.""" for name in self.target_info: @@ -214,8 +225,8 @@ def forward_hook(self, name: str, a: Tensor) -> None: """Compute activation covariance: A^T @ A.""" A_cov_ki = self.A_cov_dict[name] - # Reshape to [N*S, I] - a_bi = a.reshape(-1, a.shape[-1]) + # a: [N, S, I], valid_masks: [N, S] -> select valid positions + a_bi = a[self.valid_masks] # [num_valid, I] # Compute local covariance local_update_ii = a_bi.mT @ a_bi @@ -290,18 +301,20 @@ class LambdaCollector(HookCollectorBase): def setup(self) -> None: """Load eigenvectors and initialize storage.""" + device = get_device(self.rank) + # Load precomputed eigenvectors self.eigen_a = load_file( os.path.join( self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors" ), - device=f"cuda:{self.rank}", + device=device, ) self.eigen_g = load_file( os.path.join( self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors" ), - device=f"cuda:{self.rank}", + device=device, ) # Initialize accumulators diff --git a/bergson/hessians/data_filtering_ekfac.ipynb b/bergson/hessians/data_filtering_ekfac.ipynb index 59a6ae88..c3f38453 100644 --- a/bergson/hessians/data_filtering_ekfac.ipynb +++ b/bergson/hessians/data_filtering_ekfac.ipynb @@ -13,24 +13,17 @@ "metadata": {}, "outputs": [], "source": [ + "import json\n", "import os\n", - "from typing import Literal\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", - "from datasets import load_dataset\n", - "from tqdm.notebook import tqdm\n", - "\n", - "from bergson.data import load_gradients\n", "from safetensors.torch import load_file\n", - "\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", "from scipy.stats import spearmanr\n", "\n", - "import json" + "from bergson.data import load_gradients" ] }, { @@ -373,8 +366,6 @@ ], "source": [ "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from scipy.stats import spearmanr\n", "\n", "# Calculate Spearman correlation\n", "mask = ~(np.isnan(attributions_scores) | np.isnan(attributions_ekfac_scores))\n", @@ -461,9 +452,9 @@ ")\n", "\n", "\n", - "plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n", - "plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n", - "plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n", "plt.xlabel(\"Number of elements removed\")\n", "plt.ylabel('Number of elements in the \"correct\" half')\n", "plt.title(\"EK-FAC, no attn, on train set\")\n", @@ -638,7 +629,6 @@ "# load the saved attributions\n", "import json\n", "\n", - "\n", "all_attributions = {}\n", "\n", "for path in all_query_paths:\n", @@ -849,7 +839,7 @@ ], "source": [ "# plot intersection\n", - "plt.plot(np.array(top_percentages) * len(index), intersection_12, label=f\"Intersection\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersection_12, label=\"Intersection\")\n", "plt.plot(\n", " [0, len(index) // 2, len(index)],\n", " [0, len(index) // 2, len(index)],\n", @@ -1218,9 +1208,9 @@ ")\n", "\n", 
"\n", - "plt.plot(np.array(top_percentages) * len(index), intersections, label=f\"Query without ekfac\")\n", - "plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=f\"Query with ekfac\")\n", - "plt.plot(np.array(top_percentages) * len(index), intersections_random, label=f\"Random baseline\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections, label=\"Query without ekfac\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections_ekfac, label=\"Query with ekfac\")\n", + "plt.plot(np.array(top_percentages) * len(index), intersections_random, label=\"Random baseline\")\n", "plt.xlabel(\"Number of elements removed\")\n", "plt.ylabel('Number of elements in the \"correct\" half')\n", "plt.legend()\n", @@ -1374,9 +1364,10 @@ } ], "source": [ - "import torch\n", "import os\n", "\n", + "import torch\n", + "\n", "# Set the debug flag - this is the correct way\n", "os.environ[\"TORCH_COMPILE_DEBUG\"] = \"1\"\n", "\n", diff --git a/bergson/hessians/ekfac_compute.py b/bergson/hessians/ekfac_compute.py index 3a82afaa..22056939 100644 --- a/bergson/hessians/ekfac_compute.py +++ b/bergson/hessians/ekfac_compute.py @@ -32,6 +32,7 @@ ) from bergson.hessians.logger import get_logger from bergson.hessians.sharded_computation import ShardedMul +from bergson.utils import get_device class EkfacComputer: @@ -92,6 +93,7 @@ def __init__( def compute_covariance(self): cov_collector = CovarianceCollector( self.model.base_model, + target_modules=self.target_modules, dtype=self.dtype, shard_computer=self.shard_computer, rank=self.rank, @@ -106,7 +108,7 @@ def compute_eigendecomposition( """This is Eq. 18 from above reference.""" total_processed = torch.load( os.path.join(self.path, "total_processed_covariances.pt"), - map_location=f"cuda:{self.rank}", + map_location=torch.device(get_device(self.rank)), ) random.seed(0) @@ -224,13 +226,14 @@ def _collector(self, collector, desc: Optional[str] = None): self.batches, disable=self.rank != 0, desc=f"Computing {desc}" ): batch = self.data[sl] - x, y = pad_and_tensor( + x, y, valid_masks = pad_and_tensor( batch["input_ids"], # type: ignore labels=batch.get("labels"), # type: ignore device=self.model.device, ) - total_processed += x.numel() + total_processed += valid_masks.sum() + collector.set_valid_masks(valid_masks) with ( collector, @@ -256,19 +259,19 @@ def _collector(self, collector, desc: Optional[str] = None): probs, num_samples=1, ).flatten() - del probs + flat_mask = valid_masks[:, :-1].flatten() losses = F.cross_entropy( - logits, - sampled_labels, + logits[flat_mask], + sampled_labels[flat_mask], reduction="none", - ).reshape_as(y[:, 1:]) + ) - losses = losses.sum(1) - losses.mean().backward() + losses.sum().backward() self.model.zero_grad() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() if self.cfg.profile: assert isinstance(prof, profile), "Profiler is not set up correctly" diff --git a/bergson/hessians/misaligned_datasets.ipynb b/bergson/hessians/misaligned_datasets.ipynb index 0ac02a97..20960170 100644 --- a/bergson/hessians/misaligned_datasets.ipynb +++ b/bergson/hessians/misaligned_datasets.ipynb @@ -13,23 +13,16 @@ "metadata": {}, "outputs": [], "source": [ + "import json\n", "import os\n", - "from typing import Literal\n", - "import joblib\n", + "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", - "from datasets import load_dataset\n", - "from tqdm.notebook import tqdm\n", - "import json\n", - "from 
datasets import Dataset\n", - "from bergson.data import load_gradients\n", - "from safetensors.torch import load_file\n", - "from sklearn.metrics import roc_auc_score\n", - "import numpy as np\n", - "\n", - "from sklearn.metrics import roc_auc_score, precision_recall_curve, auc\n" + "from datasets import Dataset, load_dataset\n", + "from sklearn.metrics import auc, precision_recall_curve, roc_auc_score\n", + "\n" ] }, { @@ -439,14 +432,14 @@ "print(f\"PR AUC: {pr_auc:.4f}\")\n", "\n", "# Additional metrics for analysis\n", - "print(f\"\\nDataset composition:\")\n", + "print(\"\\nDataset composition:\")\n", "print(f\"Correct examples: {len(sorted_correct_scores)}\")\n", "print(f\"Incorrect examples: {len(sorted_incorrect_scores)}\")\n", "print(f\"Subtle incorrect examples: {len(sorted_subtle_scores)}\")\n", "print(f\"Total examples: {len(all_scores)}\")\n", "print(f\"Problematic ratio: {(len(sorted_incorrect_scores) + len(sorted_subtle_scores)) / len(all_scores):.3f}\")\n", "\n", - "print(f\"\\nScore statistics:\")\n", + "print(\"\\nScore statistics:\")\n", "print(f\"Correct scores - Mean: {sorted_correct_scores.mean():.4f}, Std: {sorted_correct_scores.std():.4f}\")\n", "print(f\"Incorrect scores - Mean: {sorted_incorrect_scores.mean():.4f}, Std: {sorted_incorrect_scores.std():.4f}\")\n", "print(f\"Subtle scores - Mean: {sorted_subtle_scores.mean():.4f}, Std: {sorted_subtle_scores.std():.4f}\")\n" diff --git a/bergson/hessians/sharded_computation.py b/bergson/hessians/sharded_computation.py index ec9e5d1d..d4c90230 100644 --- a/bergson/hessians/sharded_computation.py +++ b/bergson/hessians/sharded_computation.py @@ -8,6 +8,8 @@ from safetensors import safe_open from torch import Tensor +from bergson.utils import get_device + class ShardedMul: def __init__(self, target_info, lambda_damp_factor=0.1): @@ -92,15 +94,14 @@ def _compute_full_matrix( len(files) == self.world_size ), f"Expected {self.world_size} shards, found {len(files)} in {shard_path}" + device = get_device(self.rank) full_matrix = None if not self.dist: full_path_rank = os.path.join( shard_path, "shard_0.safetensors" ) # TODO: Does this work with different CUDA visible devices? - with safe_open( - full_path_rank, framework="pt", device=f"cuda:{self.rank}" - ) as f: + with safe_open(full_path_rank, framework="pt", device=device) as f: full_matrix = f.get_tensor(name) else: @@ -109,9 +110,7 @@ def _compute_full_matrix( shard_path_rank = os.path.join( shard_path, f"shard_{shard_id}.safetensors" ) - with safe_open( - shard_path_rank, framework="pt", device=f"cuda:{self.rank}" - ) as f: + with safe_open(shard_path_rank, framework="pt", device=device) as f: local_matrix = f.get_tensor(name) full_matrix_list.append(local_matrix) diff --git a/bergson/utils.py b/bergson/utils.py index a5974b81..c34aa3b1 100644 --- a/bergson/utils.py +++ b/bergson/utils.py @@ -1,5 +1,6 @@ from typing import Any, Type, TypeVar, cast +import torch from peft import PeftModel from torch import nn from transformers import PreTrainedModel @@ -26,3 +27,11 @@ def get_layer_list(model: PreTrainedModel | PeftModel) -> nn.ModuleList: assert len(candidates) == 1, "Could not find the list of layers." return candidates[0] + + +def get_device(rank: int = 0) -> str: + """Get device string for the given rank. + + Returns "cpu" if CUDA is not available. 
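+
+    For example, get_device(1) returns "cuda:1" on a machine with CUDA
+    available, and "cpu" otherwise.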
+ """ + return f"cuda:{rank}" if torch.cuda.is_available() else "cpu" diff --git a/pyproject.toml b/pyproject.toml index 2c02cd32..5fbabfb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "peft>=0.16.0", "simple-parsing", "torch", - "transformers", + "transformers>=4.57.1", ] version = "0.0.1" [project.optional-dependencies] @@ -40,7 +40,14 @@ faiss = [ bergson = "bergson.__main__:main" [tool.pyright] -include = ["bergson*"] +include = ["bergson*", "tests/ekfac_tests"] +## Temporary excludes before we merge ekfac into the main branch +exclude = [ + "bergson/attributor.py", + "bergson/collection.py", + "bergson/gradients.py", + "bergson/utils.py", +] reportPrivateImportUsage = false [tool.setuptools.packages.find] @@ -53,3 +60,11 @@ lint.ignore = ["E741", "F722"] # Ambiguous variable name and jaxtyping shape a lint.select = ["E", "F", "I"] # Same as Black. line-length = 120 + +[dependency-groups] +dev = [ + "pre-commit>=4.2.0", + "pre-commit-uv>=4.1.5", + "pyright>=1.1.406", + "pytest>=8.4.2", +] diff --git a/tests/ekfac_tests/apply_ekfac_ground_truth.ipynb b/tests/ekfac_tests/apply_ekfac_ground_truth.ipynb index 4aba39fe..bf75b9b7 100644 --- a/tests/ekfac_tests/apply_ekfac_ground_truth.ipynb +++ b/tests/ekfac_tests/apply_ekfac_ground_truth.ipynb @@ -23,35 +23,19 @@ "metadata": {}, "outputs": [], "source": [ - "import gc\n", "import hashlib\n", "import json\n", "import os\n", - "import random\n", - "from contextlib import nullcontext\n", - "from typing import Literal, Optional\n", + "from typing import Literal\n", "\n", - "import numpy as np\n", "import torch\n", - "import torch.distributed as dist\n", - "import torch.nn.functional as F\n", "from datasets import Dataset\n", - "from jaxtyping import Float\n", - "from safetensors import safe_open\n", "from safetensors.torch import load_file, save_file\n", "from torch import Tensor\n", "\n", - "from tqdm.auto import tqdm\n", - "from transformers import PreTrainedModel\n", - "\n", "from bergson.collection import collect_gradients\n", - "from bergson.data import DataConfig, IndexConfig, create_index, load_gradients, pad_and_tensor\n", - "from bergson.distributed import distributed_computing, setup_data_pipeline\n", - "from bergson.gradients import (\n", - " GradientProcessor,\n", - ")\n", - "from bergson.hessians.collector import EkfacCollector\n", - "from bergson.hessians.logger import get_logger" + "from bergson.data import DataConfig, IndexConfig, load_gradients\n", + "from bergson.distributed import distributed_computing, setup_data_pipeline" ] }, { diff --git a/tests/ekfac_tests/compute_ekfac_ground_truth.py b/tests/ekfac_tests/compute_ekfac_ground_truth.py index c4981dd4..71ecda85 100644 --- a/tests/ekfac_tests/compute_ekfac_ground_truth.py +++ b/tests/ekfac_tests/compute_ekfac_ground_truth.py @@ -40,7 +40,7 @@ from bergson.data import DataConfig, IndexConfig, Precision, pad_and_tensor, tokenize from bergson.hessians.utils import TensorDict -from bergson.utils import assert_type +from bergson.utils import assert_type, get_device Batches = list[list[list[int]]] @@ -148,67 +148,100 @@ def allocate_batches_test( # %% -def parse_config() -> tuple[Precision, Optional[str]]: +def parse_config() -> tuple[Precision, str, str, int, bool]: """Parse command-line arguments or return defaults.""" - precision: Precision - output_dir: Optional[str] + parser = argparse.ArgumentParser( + description="Compute EKFAC ground truth for testing" + ) + parser.add_argument( + "--precision", + type=str, + default="fp32", + 
choices=["fp32", "fp16", "bf16", "int4", "int8"], + help="Model precision (default: fp32)", + ) + parser.add_argument( + "-o", + "--output-dir", + type=str, + default=os.path.join( + os.getcwd(), "test_files", "pile_100_examples", "ground_truth" + ), + help="Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)", + ) + parser.add_argument( + "--model-name", + type=str, + default="EleutherAI/Pythia-14m", + help="Model name to use (default: EleutherAI/Pythia-14m)", + ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of workers for simulated distributed computation (default: 1)", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing ground truth data and config", + ) + # For interactive mode (Jupyter/IPython) or no args, use defaults if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"): - parser = argparse.ArgumentParser( - description="Compute EKFAC ground truth for testing" - ) - parser.add_argument( - "--precision", - type=str, - default="fp32", - choices=["fp32", "fp16", "bf16", "int4", "int8"], - help="Model precision (default: fp32)", - ) - parser.add_argument( - "-o", - "--output-dir", - type=str, - default=None, - help="Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)", - ) args = parser.parse_args() - precision = args.precision - output_dir = args.output_dir else: - # Defaults for interactive execution or running without arguments - precision = "fp32" - output_dir = None + args = parser.parse_args([]) # Set random seeds for reproducibility set_all_seeds(42) - return precision, output_dir + return ( + args.precision, + args.output_dir, + args.model_name, + args.world_size, + args.overwrite, + ) if __name__ == "__main__" or TYPE_CHECKING: - precision, output_dir = parse_config() + precision, test_path, model_name, world_size_arg, overwrite_arg = parse_config() # %% def setup_paths_and_config( - precision: Precision, output_dir: Optional[str] = None -) -> tuple[IndexConfig, str, int, torch.device, Any, torch.dtype]: + precision: Precision, + test_path: str, + model_name: str, + world_size: int, + overwrite: bool = False, +) -> tuple[IndexConfig, int, torch.device, Any, torch.dtype]: """Setup paths and configuration object.""" + os.makedirs(test_path, exist_ok=True) + current_path = os.getcwd() parent_path = os.path.join(current_path, "test_files", "pile_100_examples") - if output_dir is not None: - test_path = output_dir - else: - test_path = os.path.join(parent_path, "ground_truth") - os.makedirs(test_path, exist_ok=True) # Configuration cfg = IndexConfig(run_path="") - cfg.model = "EleutherAI/Pythia-14m" + cfg.model = model_name cfg.precision = precision cfg.fsdp = False cfg.data = DataConfig(dataset=os.path.join(parent_path, "data")) + # model_max_length is limited in some models like `roneneldan/TinyStories-1M` + tokenizer = AutoTokenizer.from_pretrained(cfg.model) + if ( + hasattr(tokenizer, "model_max_length") + and tokenizer.model_max_length < cfg.token_batch_size + ): + print( + f"Warning: Got --token-batch-size {cfg.token_batch_size} but {model_name} only supports up to {tokenizer.model_max_length}" + ) + cfg.token_batch_size = tokenizer.model_max_length + data_str = cfg.data.dataset # Create pile-100 dataset if it doesn't exist @@ -220,13 +253,42 @@ def setup_paths_and_config( subset.save_to_disk(data_str) print(f"Generated pile-100 in {data_str}") - # Save config - with 
open(os.path.join(test_path, "index_config.json"), "w") as f: - json.dump(asdict(cfg), f, indent=4) + config_path = os.path.join(test_path, "index_config.json") + if os.path.exists(config_path): + if not overwrite: + # Load existing config and compare + with open(config_path, "r") as f: + existing_cfg_dict = json.load(f) + + new_cfg_dict = asdict(cfg) + + if existing_cfg_dict != new_cfg_dict: + # Show differences for debugging + diffs = [ + f" {k}: {existing_cfg_dict[k]} != {new_cfg_dict[k]}" + for k in new_cfg_dict + if k in existing_cfg_dict + and existing_cfg_dict[k] != new_cfg_dict[k] + ] + raise RuntimeError( + f"Existing config at {config_path} differs from requested config:\n" + + "\n".join(diffs) + + "\n\nUse --overwrite to replace the existing config." + ) + + print(f"Using existing config from {config_path}") + else: + print(f"Overwriting existing config at {config_path}") + with open(config_path, "w") as f: + json.dump(asdict(cfg), f, indent=4) + else: + # Save new config + with open(config_path, "w") as f: + json.dump(asdict(cfg), f, indent=4) # Setup - workers = 8 - device = torch.device("cuda:0") + workers = world_size + device = torch.device(get_device(0)) target_modules = None # Determine dtype @@ -238,16 +300,20 @@ def setup_paths_and_config( case "fp32": dtype = torch.float32 case "int4" | "int8": - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + dtype = ( + torch.bfloat16 + if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) + else torch.float16 + ) case other: raise ValueError(f"Unsupported precision: {other}") - return cfg, test_path, workers, device, target_modules, dtype + return cfg, workers, device, target_modules, dtype if __name__ == "__main__" or TYPE_CHECKING: - cfg, test_path, workers, device, target_modules, dtype = setup_paths_and_config( - precision, output_dir + cfg, workers, device, target_modules, dtype = setup_paths_and_config( + precision, test_path, model_name, world_size_arg, overwrite_arg ) @@ -261,7 +327,7 @@ def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel: print(f"Loading model {cfg.model}...") model = AutoModelForCausalLM.from_pretrained( cfg.model, - device_map="cuda", + device_map="cuda" if torch.cuda.is_available() else "cpu", quantization_config=( BitsAndBytesConfig( load_in_4bit=cfg.precision == "int4", @@ -369,13 +435,14 @@ def compute_covariance( for sl in tqdm(batches, desc=f"Rank {rank} covariances"): batch = data[sl] - x, y = pad_and_tensor( + x, y, valid_masks = pad_and_tensor( batch["input_ids"], labels=batch.get("labels"), device=device, ) - total_processed += x.numel() + total_processed += valid_masks.sum() + collector.set_valid_masks(valid_masks) with collector: logits = model(x).logits @@ -385,12 +452,11 @@ def compute_covariance( reduction="none", ).reshape_as(y[:, 1:]) - losses = losses.sum(1) - losses.mean().backward() + losses.sum().backward() loss_list.append(losses.detach().cpu()) model.zero_grad() - return {"losses": loss_list, "total_processed_rank": total_processed} + return {"losses": loss_list, "total_processed_rank": total_processed.item()} # %% @@ -504,7 +570,8 @@ def combine_covariances_step( print(f"Global processed {total_processed_global} tokens.") gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return total_processed_global @@ -570,7 +637,8 @@ def compute_eigenvectors_step( ) gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return 
eigenvectors_test_path @@ -611,13 +679,14 @@ def compute_eigenvalue_correction_amortized( for sl in tqdm(batches, desc=f"Rank {rank} eigenvalue corrections"): batch = data[sl] - x, y = pad_and_tensor( + x, y, valid_masks = pad_and_tensor( batch["input_ids"], labels=batch.get("labels"), device=device, ) - total_processed += x.numel() + total_processed += valid_masks.sum() + collector.set_valid_masks(valid_masks) with collector: logits = model(x).logits @@ -627,11 +696,10 @@ def compute_eigenvalue_correction_amortized( reduction="none", ).reshape_as(y[:, 1:]) - losses = losses.sum(1) - losses.mean().backward() + losses.sum().backward() model.zero_grad() - return {"total_processed_rank": total_processed} + return {"total_processed_rank": total_processed.item()} # %% diff --git a/tests/ekfac_tests/conftest.py b/tests/ekfac_tests/conftest.py new file mode 100644 index 00000000..2bd6e4ee --- /dev/null +++ b/tests/ekfac_tests/conftest.py @@ -0,0 +1,325 @@ +"""Pytest configuration and fixtures for EKFAC tests.""" + +import json +import os +from typing import Any, Optional + +import pytest +from compute_ekfac_ground_truth import ( + combine_covariances_step, + combine_eigenvalue_corrections_step, + compute_covariances_step, + compute_eigenvalue_corrections_step, + compute_eigenvectors_step, + load_dataset_step, + load_model_step, + setup_paths_and_config, + tokenize_and_allocate_step, +) +from test_utils import set_all_seeds + +from bergson.data import DataConfig, IndexConfig, Precision +from bergson.distributed import distributed_computing +from bergson.hessians.compute_all import compute_all_factors + + +def pytest_addoption(parser) -> None: + """Add custom command-line options for EKFAC tests.""" + parser.addoption( + "--gradient_batch_size", + action="store", + type=int, + default=1, + help="Batch size for gradient computation (default: 1)", + ) + parser.addoption( + "--gradient_path", + action="store", + default=None, + help="Path to the gradient file", + ) + parser.addoption( + "--model_name", + action="store", + type=str, + default="EleutherAI/Pythia-14m", + help="Model name to use for ground truth generation (default: EleutherAI/Pythia-14m)", + ) + parser.addoption( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing run directory", + ) + parser.addoption( + "--precision", + action="store", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "int4", "int8"], + help="Model precision for ground truth generation (default: fp32)", + ) + parser.addoption( + "--sample", + action="store_true", + default=False, + help="Enable sampling mode", + ) + parser.addoption( + "--test_dir", + action="store", + default=None, + help="Directory containing test data. 
If not provided, generates test data using compute_ekfac_ground_truth.py.",
+    )
+    parser.addoption(
+        "--use_fsdp",
+        action="store_true",
+        default=False,
+        help="Use Fully Sharded Data Parallel (FSDP)",
+    )
+    parser.addoption(
+        "--world_size",
+        action="store",
+        type=int,
+        default=1,
+        help="World size for distributed training (default: 1)",
+    )
+
+
+@pytest.fixture(autouse=True)
+def setup_test() -> None:
+    """Setup logic run before each test."""
+    set_all_seeds(seed=42)
+
+
+@pytest.fixture(scope="session")
+def gradient_batch_size(request) -> int:
+    return request.config.getoption("--gradient_batch_size")
+
+
+@pytest.fixture(scope="session")
+def gradient_path(request) -> Optional[str]:
+    return request.config.getoption("--gradient_path")
+
+
+@pytest.fixture(scope="session")
+def model_name(request) -> str:
+    return request.config.getoption("--model_name")
+
+
+@pytest.fixture(scope="session")
+def overwrite(request) -> bool:
+    return request.config.getoption("--overwrite")
+
+
+@pytest.fixture(scope="session")
+def precision(request) -> Precision:
+    return request.config.getoption("--precision")
+
+
+@pytest.fixture(scope="session")
+def sample(request) -> bool:
+    return request.config.getoption("--sample")
+
+
+@pytest.fixture(scope="session")
+def use_fsdp(request) -> bool:
+    return request.config.getoption("--use_fsdp")
+
+
+@pytest.fixture(scope="session")
+def world_size(request) -> int:
+    return request.config.getoption("--world_size")
+
+
+@pytest.fixture(scope="session")
+def test_dir(request, tmp_path_factory) -> str:
+    """Get or create test directory (does not generate ground truth data)."""
+    # Check if test directory was provided
+    test_dir = request.config.getoption("--test_dir")
+    if test_dir is not None:
+        return test_dir
+
+    # Create temporary directory for auto-generated test data
+    tmp_dir = tmp_path_factory.mktemp("ekfac_test_data")
+    return str(tmp_dir)
+
+
+def ground_truth_base_path(test_dir: str) -> str:
+    return os.path.join(test_dir, "ground_truth")
+
+
+@pytest.fixture(scope="session")
+def ground_truth_setup(
+    request, test_dir: str, precision: Precision, overwrite: bool
+) -> dict[str, Any]:
+    # Setup for generation
+    model_name = request.config.getoption("--model_name")
+    world_size = request.config.getoption("--world_size")
+
+    print(f"\n{'='*60}")
+    print("Generating ground truth test data")
+    print(f"Model: {model_name}")
+    print(f"Precision: {precision}")
+    print(f"World size: {world_size}")
+    print(f"{'='*60}\n")
+
+    cfg, workers, device, target_modules, dtype = setup_paths_and_config(
+        precision=precision,
+        test_path=ground_truth_base_path(test_dir),
+        model_name=model_name,
+        world_size=world_size,
+        overwrite=overwrite,
+    )
+
+    model = load_model_step(cfg, dtype)
+    ds = load_dataset_step(cfg)
+    data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers)
+
+    return {
+        "cfg": cfg,
+        "workers": workers,
+        "device": device,
+        "target_modules": target_modules,
+        "dtype": dtype,
+        "model": model,
+        "data": data,
+        "batches_world": batches_world,
+    }
+
+
+@pytest.fixture(scope="session")
+def ground_truth_covariances_path(
+    ground_truth_setup: dict[str, Any], test_dir: str, overwrite: bool
+) -> str:
+    """Ensure ground truth covariances exist and return path."""
+    base_path = ground_truth_base_path(test_dir)
+    covariances_path = os.path.join(base_path, "covariances")
+
+    if os.path.exists(covariances_path) and not overwrite:
+        print("Using existing covariances")
+        return covariances_path
+
+    setup = ground_truth_setup
+    covariance_test_path = 
compute_covariances_step(
+        setup["model"],
+        setup["data"],
+        setup["batches_world"],
+        setup["device"],
+        setup["target_modules"],
+        setup["workers"],
+        base_path,
+    )
+    combine_covariances_step(covariance_test_path, setup["workers"], setup["device"])
+    print("Covariances computed")
+    return covariances_path
+
+
+@pytest.fixture(scope="session")
+def ground_truth_eigenvectors_path(
+    ground_truth_covariances_path: str,
+    ground_truth_setup: dict[str, Any],
+    test_dir: str,
+    overwrite: bool,
+) -> str:
+    """Ensure ground truth eigenvectors exist and return path."""
+    base_path = ground_truth_base_path(test_dir)
+    eigenvectors_path = os.path.join(base_path, "eigenvectors")
+
+    if os.path.exists(eigenvectors_path) and not overwrite:
+        print("Using existing eigenvectors")
+        return eigenvectors_path
+
+    setup = ground_truth_setup
+    compute_eigenvectors_step(base_path, setup["device"], setup["dtype"])
+    print("Eigenvectors computed")
+    return eigenvectors_path
+
+
+@pytest.fixture(scope="session")
+def ground_truth_eigenvalue_corrections_path(
+    ground_truth_eigenvectors_path: str,
+    ground_truth_setup: dict[str, Any],
+    test_dir: str,
+    overwrite: bool,
+) -> str:
+    """Ensure ground truth eigenvalue corrections exist and return path."""
+    base_path = ground_truth_base_path(test_dir)
+    eigenvalue_corrections_path = os.path.join(base_path, "eigenvalue_corrections")
+
+    if os.path.exists(eigenvalue_corrections_path) and not overwrite:
+        print("Using existing eigenvalue corrections")
+        return eigenvalue_corrections_path
+
+    setup = ground_truth_setup
+    eigenvalue_correction_test_path, total_processed_global_lambda = (
+        compute_eigenvalue_corrections_step(
+            setup["model"],
+            setup["data"],
+            setup["batches_world"],
+            setup["device"],
+            setup["target_modules"],
+            setup["workers"],
+            base_path,
+        )
+    )
+    combine_eigenvalue_corrections_step(
+        eigenvalue_correction_test_path,
+        setup["workers"],
+        setup["device"],
+        total_processed_global_lambda,
+    )
+    print("Eigenvalue corrections computed")
+    print("\n=== Ground Truth Computation Complete ===")
+    print(f"Results saved to: {base_path}")
+    return eigenvalue_corrections_path
+
+
+@pytest.fixture(scope="session")
+def ground_truth_path(
+    ground_truth_eigenvalue_corrections_path: str, test_dir: str
+) -> str:
+    """Get ground truth base path with all data guaranteed to exist.
+
+    Depends on ground_truth_eigenvalue_corrections_path to ensure all ground truth data exists.
+    """
+    return ground_truth_base_path(test_dir)
+
+
+@pytest.fixture(scope="session")
+def ekfac_results_path(
+    test_dir: str,
+    ground_truth_path: str,
+    world_size: int,
+    overwrite: bool,
+    use_fsdp: bool,
+    sample: bool,
+) -> str:
+    """Setup EKFAC configuration, run computation if needed, and return results path.
+
+    ground_truth_path fixture ensures all ground truth data exists. 
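+    The distributed computation is rerun only when the results directory is
+    missing or --overwrite is passed.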
+ """ + results_path = os.path.join(test_dir, "run/influence_results") + + # Load configuration + with open(os.path.join(ground_truth_path, "index_config.json"), "r") as f: + cfg_json = json.load(f) + + cfg = IndexConfig(**cfg_json) + cfg.data = DataConfig(**(cfg_json["data"])) + assert isinstance(cfg.fsdp, bool) # for type checker + cfg.run_path = test_dir + "/run" + cfg.debug = True + cfg.fsdp = use_fsdp + cfg.world_size = world_size + cfg.sample = sample + + if os.path.exists(results_path) and not overwrite: + print(f"Using existing EKFAC results in {results_path}.") + else: + print(f"\nRunning EKFAC computation in {results_path}...") + distributed_computing( + cfg=cfg, + worker_fn=compute_all_factors, + ) + print(f"EKFAC computation completed in {results_path}.") + + return results_path diff --git a/tests/ekfac_tests/ground_truth/collector.py b/tests/ekfac_tests/ground_truth/collector.py index 74adf745..037cecb2 100644 --- a/tests/ekfac_tests/ground_truth/collector.py +++ b/tests/ekfac_tests/ground_truth/collector.py @@ -25,7 +25,8 @@ def teardown(self) -> None: pass def forward_hook(self, name: str, a: Tensor) -> None: - a = a.reshape(-1, a.shape[-1]) # [N*S, O] + # a: [N, S, I], valid_masks: [N, S] -> select valid positions + a = a[self.valid_masks] # [num_valid, I] update = a.mT @ a diff --git a/tests/ekfac_tests/run_apply_compute_ekfac.sh b/tests/ekfac_tests/run_apply_compute_ekfac.sh index 3d5117c3..f2572b04 100755 --- a/tests/ekfac_tests/run_apply_compute_ekfac.sh +++ b/tests/ekfac_tests/run_apply_compute_ekfac.sh @@ -1,10 +1,14 @@ #!/bin/bash -# Run all tests -python test_apply_ekfac.py \ +# Run EKFAC application tests + +cd "$(dirname "$0")" + +pytest -s -v \ --test_dir "./test_files/pile_10k_examples" \ --gradient_path "./test_files/pile_10k_examples/test_gradients/proj_dim_0" \ --overwrite \ --use_fsdp \ --world_size 8 \ --gradient_batch_size 10 \ + test_apply_ekfac.py diff --git a/tests/ekfac_tests/run_test_compute_ekfac.sh b/tests/ekfac_tests/run_test_compute_ekfac.sh index 739a0b55..392e46e7 100755 --- a/tests/ekfac_tests/run_test_compute_ekfac.sh +++ b/tests/ekfac_tests/run_test_compute_ekfac.sh @@ -1,8 +1,15 @@ #!/bin/bash -# Run all tests -python test_compute_ekfac.py \ - --test_dir "/root/bergson/test_files/pile_100_examples" \ +# Run EKFAC computation tests + +cd "$(dirname "$0")" + +pytest -s -v \ + --test_dir "./test_files/pile_100_examples" \ --world_size 8 \ --use_fsdp \ - --overwrite + --overwrite \ + test_compute_ekfac.py \ + test_covariance.py \ + test_eigenvectors.py \ + test_eigenvalue_correction.py diff --git a/tests/ekfac_tests/test_apply_ekfac.py b/tests/ekfac_tests/test_apply_ekfac.py index 414cf2e4..8b97fc48 100644 --- a/tests/ekfac_tests/test_apply_ekfac.py +++ b/tests/ekfac_tests/test_apply_ekfac.py @@ -1,7 +1,10 @@ -import argparse +"""Test EKFAC application against ground truth.""" + import json import os +from typing import Optional +import pytest import torch from safetensors.torch import load_file @@ -9,62 +12,70 @@ from bergson.distributed import distributed_computing from bergson.hessians.ekfac_apply import ekfac_apply_worker -parser = argparse.ArgumentParser(description="Run apply EKFAC tests.") - -parser.add_argument( - "--test_dir", - type=str, - help="Directory containing test files.", -) -parser.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing run directory.", -) -parser.add_argument( - "--use_fsdp", - action="store_true", - help="Use Fully Sharded Data Parallel (FSDP).", -) -parser.add_argument( - 
"--world_size", - type=int, - default=8, - help="World size for distributed training.", -) - -parser.add_argument( - "--gradient_path", - type=str, - help="Path to the gradient.", -) - -parser.add_argument( - "--gradient_batch_size", - type=int, - default=1, - help="Batch size for gradient computation.", -) - - -args = parser.parse_args() - - -test_dir = args.test_dir -overwrite = args.overwrite -use_fsdp = args.use_fsdp -world_size = args.world_size - - -ground_truth_path = os.path.join(test_dir, "ground_truth") -run_path = os.path.join(test_dir, "run/influence_results") - - -def test_gradients(run_path, ground_truth_path): + +@pytest.fixture(scope="module") +def ekfac_apply_gradient_path( + test_dir: str, + ground_truth_path: str, + world_size: int, + overwrite: bool, + ekfac_results_path: str, + use_fsdp: bool, + gradient_path: Optional[str], + gradient_batch_size: int, +) -> str: + """Setup EKFAC application configuration and run if needed. + + ground_truth_path fixture ensures all required files exist (covariances, eigenvectors, etc). + ekfac_results_path ensures EKFAC computation has run. + """ + # Load configuration + with open(os.path.join(ground_truth_path, "index_config.json"), "r") as f: + cfg_json = json.load(f) + + if gradient_path is None: + pytest.skip( + "No --gradient-path argument provided, skipping EKFAC application tests." + ) + return "" + + cfg = IndexConfig(**cfg_json) + cfg.data = DataConfig(**(cfg_json["data"])) + cfg.run_path = test_dir + "/run" + cfg.debug = True + cfg.fsdp = use_fsdp + cfg.world_size = world_size + cfg.ekfac = True + cfg.gradient_path = gradient_path + cfg.gradient_batch_size = gradient_batch_size + + results_path = gradient_path + "_ekfac" + + if os.path.exists(results_path) and not overwrite: + print(f"Using existing {results_path}.") + else: + print(f"\nRunning EKFAC application in {results_path}...") + distributed_computing( + cfg=cfg, + worker_fn=ekfac_apply_worker, + setup_data=False, + setup_model=False, + setup_processor=False, + ) + print("EKFAC application completed successfully in {results_path}.") + + return results_path + + +def test_gradients_after_ekfac(test_dir: str, ekfac_apply_gradient_path: str) -> None: + """Test gradients after EKFAC application against ground truth.""" + + ground_truth_path = test_dir + "/test_gradients/gradients_after_ekfac" + ground_truth = load_file( os.path.join(ground_truth_path, "gradients.safetensors"), device="cuda" ) - computed_mmap = load_gradients(run_path) + computed_mmap = load_gradients(ekfac_apply_gradient_path) for k in ground_truth.keys(): ground_truth_tensor = ground_truth[k].to(dtype=torch.float32) @@ -97,58 +108,4 @@ def test_gradients(run_path, ground_truth_path): ) print(f" At {tuple(coords)}: gt={gt_val:.2e}, comp={comp_val:.2e}") - -def main(): - # assert covariances, eigenvalue_corrections, eigenvectors and index_config.json exist - - required_files = [ - "covariances", - "eigenvalue_corrections", - "eigenvectors", - "index_config.json", - ] - - for file_name in required_files: - assert os.path.exists( - os.path.join(ground_truth_path, file_name) - ), f"Missing required file: {file_name}" - - cfg_json = json.load( - open(os.path.join(ground_truth_path, "index_config.json"), "r") - ) - print(cfg_json) - cfg = IndexConfig(**cfg_json) - - cfg.data = DataConfig(**(cfg_json["data"])) - - cfg.run_path = test_dir + "/run" - cfg.debug = True - cfg.fsdp = use_fsdp - cfg.world_size = world_size - cfg.ekfac = True - cfg.gradient_path = args.gradient_path - cfg.gradient_batch_size = 
args.gradient_batch_size - - if not os.path.exists(run_path) or overwrite: - distributed_computing( - cfg=cfg, - worker_fn=ekfac_apply_worker, - setup_data=False, - setup_model=False, - setup_processor=False, - ) - - print("EKFAC application completed successfully.") - else: - print("Using existing run directory.") - - test_gradients( - run_path=cfg.gradient_path + "_ekfac", - ground_truth_path=test_dir + "/test_gradients" + "/gradients_after_ekfac", - ) - - print("\n \n All tests done \n \n") - - -if __name__ == "__main__": - main() + print("\nāœ“ All gradient tests passed\n") diff --git a/tests/ekfac_tests/test_batch_size_invariance.py b/tests/ekfac_tests/test_batch_size_invariance.py new file mode 100644 index 00000000..9d3a696a --- /dev/null +++ b/tests/ekfac_tests/test_batch_size_invariance.py @@ -0,0 +1,50 @@ +"""Test that covariance traces are batch-size invariant after normalization.""" + +import tempfile +from typing import Any + +import torch + +from bergson.data import IndexConfig +from bergson.hessians.ekfac_compute import EkfacComputer +from tests.ekfac_tests.test_utils import load_covariances + + +def test_trace_batch_invariant(ground_truth_setup: dict[str, Any]): + """Normalized covariance traces should be the same regardless of batch size.""" + setup = ground_truth_setup + indices = [ + idx + for worker_batches in setup["batches_world"] + for batch in worker_batches + for idx in batch + ] + + # B=1 vs B=2 batches + batches_b1 = [[i] for i in indices] + batches_b2 = [indices[i : i + 2] for i in range(0, len(indices), 2)] + + def compute_traces(batches): + with tempfile.TemporaryDirectory() as tmpdir: + ekfac = EkfacComputer( + model=setup["model"], + data=setup["data"], + batches=batches, + target_modules=setup["target_modules"], + cfg=IndexConfig(run_path=tmpdir, data=None), + ) + ekfac.compute_covariance() + + A, G, n = load_covariances(tmpdir) + + return ( + sum(v.trace().item() / n for v in A.values()), + sum(v.trace().item() / n for v in G.values()), + ) + + setup["model"].eval() + A1, G1 = compute_traces(batches_b1) + A2, G2 = compute_traces(batches_b2) + + torch.testing.assert_close(A1, A2, rtol=0.01, atol=0) + torch.testing.assert_close(G1, G2, rtol=0.2, atol=0) diff --git a/tests/ekfac_tests/test_compute_ekfac.py b/tests/ekfac_tests/test_compute_ekfac.py index e9788320..c6c814af 100644 --- a/tests/ekfac_tests/test_compute_ekfac.py +++ b/tests/ekfac_tests/test_compute_ekfac.py @@ -1,152 +1,30 @@ -import argparse +"""Test EKFAC computation against ground truth.""" + import json import os import torch -from test_covariance import test_covariances -from test_eigenvalue_correction import test_eigenvalue_correction -from test_eigenvectors import test_eigenvectors -from test_utils import set_all_seeds - -from bergson.data import DataConfig, IndexConfig -from bergson.distributed import distributed_computing -from bergson.hessians.compute_all import compute_all_factors - -parser = argparse.ArgumentParser(description="Run compute EKFAC tests.") - -parser.add_argument( - "--test_dir", - type=str, - help="Directory containing test files.", -) -parser.add_argument( - "--overwrite", - action="store_true", - help="Overwrite existing run directory.", -) -parser.add_argument( - "--use_fsdp", - action="store_true", - help="Use Fully Sharded Data Parallel (FSDP).", -) -parser.add_argument( - "--world_size", - type=int, - default=8, - help="World size for distributed training.", -) -parser.add_argument( - "--sample", - type=bool, - default=False, - help="Batch size for gradient 
computation.", -) - - -args = parser.parse_args() - - -test_dir = args.test_dir -overwrite = args.overwrite -use_fsdp = args.use_fsdp -world_size = args.world_size -ground_truth_path = os.path.join(test_dir, "ground_truth") -run_path = os.path.join(test_dir, "run/influence_results") - - -def test_total_processed_examples(): +def test_total_processed_examples( + ground_truth_covariances_path: str, ekfac_results_path: str +) -> None: + """Test that total processed examples match between ground truth and computed values.""" total_processed_ground_truth_path = os.path.join( - ground_truth_path, "covariances/stats.json" + ground_truth_covariances_path, "stats.json" + ) + total_processed_run_path = os.path.join( + ekfac_results_path, "total_processed_covariances.pt" ) - total_processed_run_path = os.path.join(run_path, "total_processed_covariances.pt") with open(total_processed_ground_truth_path, "r") as f: ground_truth_data = json.load(f) total_processed_ground_truth = ground_truth_data["total_processed_global"] - total_processed_run = torch.load(total_processed_run_path).item() - - equal = total_processed_ground_truth == total_processed_run - - if equal: - print(f"Total processed examples match: {total_processed_ground_truth}") - else: - print( - f"Total processed examples do not match! Ground truth: {total_processed_ground_truth}," - f" Run: {total_processed_run}" - ) - print("-*" * 50) - - -def main(): - # assert covariances, eigenvalue_corrections, eigenvectors and index_config.json exist + total_processed_run = torch.load(total_processed_run_path, weights_only=True).item() - set_all_seeds(seed=42) - required_files = [ - "covariances", - "eigenvalue_corrections", - "eigenvectors", - "index_config.json", - ] - - for file_name in required_files: - assert os.path.exists( - os.path.join(ground_truth_path, file_name) - ), f"Missing required file: {file_name}" - - cfg_json = json.load( - open(os.path.join(ground_truth_path, "index_config.json"), "r") - ) - - cfg = IndexConfig(**cfg_json) - - cfg.data = DataConfig(**(cfg_json["data"])) - assert isinstance(cfg.fsdp, bool) # for the type checker - cfg.run_path = test_dir + "/run" - cfg.debug = True - cfg.fsdp = use_fsdp - cfg.world_size = world_size - cfg.sample = args.sample - - if not os.path.exists(run_path) or overwrite: - distributed_computing( - cfg=cfg, - worker_fn=compute_all_factors, - ) - print("EKFAC computation completed successfully.") - else: - print("Using existing run directory.") - - test_total_processed_examples() - - test_covariances( - run_path=run_path, - ground_truth_path=ground_truth_path, - covariance_type="activation", + assert total_processed_ground_truth == total_processed_run, ( + f"Total processed examples do not match! " + f"Ground truth: {total_processed_ground_truth}, Run: {total_processed_run}" ) - test_covariances( - run_path=run_path, - ground_truth_path=ground_truth_path, - covariance_type="gradient", - ) - - # Currently this tests for close equality, but does not account for sign differences in eigenvectors. TODO: fix. 
- test_eigenvectors( - run_path=run_path, - ground_truth_path=ground_truth_path, - eigenvector_type="activation", - ) - test_eigenvectors( - run_path=run_path, - ground_truth_path=ground_truth_path, - eigenvector_type="gradient", - ) - - test_eigenvalue_correction(ground_truth_path=ground_truth_path, run_path=run_path) - print("\n \n All tests done \n \n") - - -if __name__ == "__main__": - main() + print(f"āœ“ Total processed examples match: {total_processed_ground_truth}") diff --git a/tests/ekfac_tests/test_covariance.py b/tests/ekfac_tests/test_covariance.py index 9824d75b..b0eb6232 100644 --- a/tests/ekfac_tests/test_covariance.py +++ b/tests/ekfac_tests/test_covariance.py @@ -1,42 +1,30 @@ import os -from typing import Literal -import torch +import pytest from safetensors.torch import load_file from bergson.hessians.utils import TensorDict +from tests.ekfac_tests.test_utils import load_sharded_covariances +@pytest.mark.parametrize("covariance_type", ["activation", "gradient"]) def test_covariances( - ground_truth_path, - run_path, - covariance_type: Literal["activation", "gradient"] = "activation", -): + ekfac_results_path: str, + ground_truth_covariances_path: str, + covariance_type: str, +) -> None: + """Test covariances against ground truth.""" + print(f"\nTesting {covariance_type} covariances...") + covariances_ground_truth_path = os.path.join( - ground_truth_path, f"covariances/{covariance_type}_covariance.safetensors" + ground_truth_covariances_path, f"{covariance_type}_covariance.safetensors" ) covariances_run_path = os.path.join( - run_path, f"{covariance_type}_covariance_sharded" + ekfac_results_path, f"{covariance_type}_covariance_sharded" ) - # load ground_truth ground_truth_covariances = TensorDict(load_file(covariances_ground_truth_path)) - - world_size = len(os.listdir(covariances_run_path)) # number of shards - # load run covariances shards and concatenate them - - run_covariances_shards = [ - os.path.join(covariances_run_path, f"shard_{rank}.safetensors") - for rank in range(world_size) - ] - run_covariances_list = [(load_file(shard)) for shard in run_covariances_shards] - run_covariances = {} - for k, v in run_covariances_list[0].items(): - run_covariances[k] = torch.cat( - [shard[k] for shard in run_covariances_list], dim=0 - ) - - run_covariances = TensorDict(run_covariances) + run_covariances = TensorDict(load_sharded_covariances(covariances_run_path)) diff = ( ground_truth_covariances.sub(run_covariances) @@ -52,15 +40,20 @@ def test_covariances( if all(equal_dict.values()): print(f"{covariance_type} covariances match") - else: max_diff = diff.max() - # print keys for which the covariances do not match - print(f"{covariance_type} covariances do not match!") + # Collect error details for assertion message + error_details = [] for k, v in equal_dict.items(): if not v: - print( - f"Covariance {k} does not match with max relative difference " - f"{(100 * max_diff[k]):.3f} % and mean {100 * diff[k].mean()} % !", + error_details.append( + f" {k}: max_rel_diff={(100 * max_diff[k]):.3f}%, " + f"mean={(100 * diff[k].mean()):.3f}%" ) + + error_msg = f"{covariance_type} covariances do not match!\n" + "\n".join( + error_details + ) + assert False, error_msg + print("-*" * 50) diff --git a/tests/ekfac_tests/test_eigenvalue_correction.py b/tests/ekfac_tests/test_eigenvalue_correction.py index 5c3fe8b1..341b075c 100644 --- a/tests/ekfac_tests/test_eigenvalue_correction.py +++ b/tests/ekfac_tests/test_eigenvalue_correction.py @@ -1,25 +1,39 @@ import os +import pytest import 
torch
 from safetensors.torch import load_file
 
 from bergson.hessians.utils import TensorDict
 
 
-def test_eigenvalue_correction(ground_truth_path, run_path):
-    """Test eigenvalue corrections by comparing to ground truth."""
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Numerical precision differences on CPU vs GPU",
+)
+def test_eigenvalue_corrections(
+    ground_truth_eigenvalue_corrections_path: str,
+    ekfac_results_path: str,
+) -> None:
+    """Test eigenvalue corrections against ground truth."""
+    print("\nTesting eigenvalue corrections...")
 
     lambda_ground_truth_path = os.path.join(
-        ground_truth_path, "eigenvalue_corrections/eigenvalue_corrections.safetensors"
+        ground_truth_eigenvalue_corrections_path, "eigenvalue_corrections.safetensors"
     )
-    lambda_run_path = os.path.join(run_path, "eigenvalue_correction_sharded")
+    lambda_run_path = os.path.join(ekfac_results_path, "eigenvalue_correction_sharded")
 
     # load ground_truth
     lambda_ground_truth = TensorDict(load_file(lambda_ground_truth_path))
 
     world_size = len(os.listdir(lambda_run_path))  # number of shards
-    lambda_run_shards_path = [os.path.join(lambda_run_path, f"shard_{rank}.safetensors") for rank in range(world_size)]
-    lambda_list_shards = [(load_file(shard_path)) for shard_path in lambda_run_shards_path]
+    lambda_run_shards_path = [
+        os.path.join(lambda_run_path, f"shard_{rank}.safetensors")
+        for rank in range(world_size)
+    ]
+    lambda_list_shards = [
+        (load_file(shard_path)) for shard_path in lambda_run_shards_path
+    ]
     lambda_run = {}
     for k, v in lambda_list_shards[0].items():
         if len(v.shape) == 0:
@@ -29,8 +43,11 @@ def test_eigenvalue_correction(ground_truth_path, run_path):
 
     lambda_run = TensorDict(lambda_run)
 
-    total_processed_run_path = os.path.join(run_path, "total_processed_lambda_correction.pt")
-    total = torch.load(total_processed_run_path).to(device=lambda_run[list(lambda_run.keys())[0]].device)
+    total_processed_run_path = os.path.join(
+        ekfac_results_path, "total_processed_lambda_correction.pt"
+    )
+    lambda_device = lambda_run[list(lambda_run.keys())[0]].device
+    total = torch.load(total_processed_run_path, map_location=lambda_device)
     lambda_run.div_(total)
     rtol = 1e-10
     equal_dict = lambda_ground_truth.allclose(lambda_run, rtol=rtol)
@@ -38,26 +55,40 @@ def test_eigenvalue_correction(ground_truth_path, run_path):
     if all(equal_dict.values()):
         print("Eigenvalue corrections match!")
     else:
-        print("Eigenvalue corrections do not match!")
-
         diff = lambda_ground_truth.sub(lambda_run).div(lambda_ground_truth).abs()
         max_diff = diff.max()
-        # print keys for which the covariances do not match
+        # Collect error details for assertion message
+        error_details = []
+        has_significant_errors = False
        for k, v in equal_dict.items():
            if not v:
                # Find location of max difference
-                coord = diff[k].argmax()
-                a, b = coord // lambda_ground_truth[k].shape[1], coord % lambda_ground_truth[k].shape[1]
-
-                print(
-                    f"Eigenvalue correction {k} does not match with max relative difference "
-                    f"{(100 * max_diff[k]):.3f} % and mean {100 * diff[k].mean()} % !",
-                )
-                print(
-                    f"Max relative difference at index {a},{b} with ground truth value "
-                    f"{lambda_ground_truth[k][a, b]:.3e} and run value {lambda_run[k][a, b]:.3e}"
+                coord = diff[k].argmax()
+                a, b = (
+                    coord // lambda_ground_truth[k].shape[1],
+                    coord % lambda_ground_truth[k].shape[1],
                 )
-                print("\n")
+                if max_diff[k] < 1e-3:
+                    error_details.append(
+                        f"  {k}: small differences within tolerance (max_rel_diff={(100 * max_diff[k]):.3f}%)"
+                    )
+                else:
+                    has_significant_errors = True
+                    error_details.append(
+                        f" 
{k}: max_rel_diff={(100 * max_diff[k]):.3f}%, " + f"mean={(100 * diff[k].mean()):.3f}%" + ) + error_details.append( + f" at [{a},{b}]: gt={lambda_ground_truth[k][a, b]:.3e}, " + f"run={lambda_run[k][a, b]:.3e}" + ) + + if has_significant_errors: + error_msg = "Eigenvalue corrections do not match!\n" + "\n".join( + error_details + ) + assert False, error_msg + else: + print("āœ“ Eigenvalue corrections: all differences within tolerance") diff --git a/tests/ekfac_tests/test_eigenvectors.py b/tests/ekfac_tests/test_eigenvectors.py index 79eb5f67..09a2a188 100644 --- a/tests/ekfac_tests/test_eigenvectors.py +++ b/tests/ekfac_tests/test_eigenvectors.py @@ -1,21 +1,31 @@ import os -from typing import Literal +import pytest import torch from safetensors.torch import load_file from bergson.hessians.utils import TensorDict +@pytest.mark.parametrize("eigenvector_type", ["activation", "gradient"]) def test_eigenvectors( - ground_truth_path, - run_path, - eigenvector_type: Literal["activation", "gradient"] = "activation", -): + ekfac_results_path: str, + ground_truth_eigenvectors_path: str, + eigenvector_type: str, +) -> None: + """Test eigenvectors against ground truth. + + Note: Currently tests for close equality but does not account for + sign differences in eigenvectors. TODO: fix. + """ + print(f"\nTesting {eigenvector_type} eigenvectors...") + eigenvectors_ground_truth_path = os.path.join( - ground_truth_path, f"eigenvectors/eigenvectors_{eigenvector_type}s.safetensors" + ground_truth_eigenvectors_path, f"eigenvectors_{eigenvector_type}s.safetensors" + ) + eigenvectors_run_path = os.path.join( + ekfac_results_path, f"{eigenvector_type}_eigen_sharded" ) - eigenvectors_run_path = os.path.join(run_path, f"{eigenvector_type}_eigen_sharded") # load ground_truth ground_truth_eigenvectors = TensorDict(load_file(eigenvectors_ground_truth_path)) @@ -41,12 +51,13 @@ def test_eigenvectors( if all(equal_dict.values()): print(f"{eigenvector_type} eigenvectors match!") - else: diff = run_eigenvectors.sub(ground_truth_eigenvectors).abs() max_diff = diff.max() - # print keys for which the covariances do not match - print(f"{eigenvector_type} eigenvectors do not match!") + # Collect error details for assertion message + error_details = [] + has_significant_errors = False + for k, v in equal_dict.items(): if not v: # Find location of max difference @@ -55,13 +66,22 @@ def test_eigenvectors( relative_diff = ( 100 * max_diff[k] / ground_truth_eigenvectors[k][max_diff_idx].abs() ) + if max_diff[k] < 1e-6 and relative_diff < 1e-3: - print(f"Eigenvector {k} small differences within tolerance.") + error_details.append(f" {k}: small differences within tolerance") else: - print( - f"Eigenvalue corrections {k} does not match with absolute difference {max_diff[k]:.3f} and max " - f"rel. difference {relative_diff:.3f} %!" + has_significant_errors = True + error_details.append( + f" {k}: abs_diff={max_diff[k]:.3f}, " + f"rel_diff={relative_diff:.3f}%" ) - print("\n") + if has_significant_errors: + error_msg = f"{eigenvector_type} eigenvectors do not match!\n" + "\n".join( + error_details + ) + assert False, error_msg + else: + print(f"{eigenvector_type} eigenvectors: all differences within tolerance") + print("-*" * 50) diff --git a/tests/ekfac_tests/test_fim_accuracy.py b/tests/ekfac_tests/test_fim_accuracy.py new file mode 100644 index 00000000..77d8f271 --- /dev/null +++ b/tests/ekfac_tests/test_fim_accuracy.py @@ -0,0 +1,172 @@ +""" +Test EKFAC accuracy for computing the Fisher Information Matrix. 
diff --git a/tests/ekfac_tests/test_fim_accuracy.py b/tests/ekfac_tests/test_fim_accuracy.py
new file mode 100644
index 00000000..77d8f271
--- /dev/null
+++ b/tests/ekfac_tests/test_fim_accuracy.py
@@ -0,0 +1,172 @@
+"""
+Test EKFAC accuracy for computing the Fisher Information Matrix.
+
+Compares the K-FAC approximation F_kfac = G āŠ— A against the exact FIM
+computed from per-position gradients on a toy language model.
+"""
+
+import pytest
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from bergson.data import IndexConfig
+from bergson.hessians.ekfac_compute import EkfacComputer
+from bergson.utils import get_device
+from tests.ekfac_tests.test_utils import load_covariances
+from tests.ekfac_tests.toy_model import (
+    ToyDataConfig,
+    ToyLM,
+    ToyLMConfig,
+    generate_batches,
+    generate_dataset,
+)
+
+
+def compute_exact_fim(
+    model: ToyLM,
+    dataset,
+    batches: list[list[int]],
+    device: torch.device,
+    sample: bool,
+) -> tuple[Tensor, Tensor, Tensor, int]:
+    """
+    Compute the exact FIM from per-position gradients for ToyLM.
+
+    Args:
+        sample: If True, sample labels from the model distribution (true FIM).
+            If False, use dataset labels (empirical FIM).
+
+    Returns:
+        F_exact: Exact FIM from per-position gradients
+        A: Activation covariance (normalized)
+        G: Gradient covariance (normalized)
+        n_positions: Total number of valid positions
+    """
+    hidden_size = model.config.hidden_size
+    vocab_size = model.config.vocab_size
+
+    position_grads = []
+    A_sum = torch.zeros(hidden_size, hidden_size, device=device)
+    G_sum = torch.zeros(vocab_size, vocab_size, device=device)
+
+    for batch_indices in batches:
+        for idx in batch_indices:
+            input_ids = torch.tensor(
+                dataset[idx]["input_ids"], device=device
+            ).unsqueeze(0)
+            labels = torch.tensor(dataset[idx]["labels"], device=device)
+
+            hidden = model.model.embed(input_ids)
+            hidden.requires_grad_(True)
+            logits = model.model.linear(hidden)
+
+            for s in range(input_ids.shape[1] - 1):
+                if sample:
+                    # Sample from model distribution (true FIM)
+                    with torch.no_grad():
+                        probs = torch.softmax(logits[0, s].detach(), dim=-1)
+                        target = torch.multinomial(probs, num_samples=1).squeeze()
+                else:
+                    # Use dataset labels (empirical FIM)
+                    target = labels[s + 1]
+
+                loss = F.cross_entropy(logits[0, s], target)
+
+                (g,) = torch.autograd.grad(loss, logits, retain_graph=True)
+                g = g[0, s]
+                a = hidden[0, s].detach()
+
+                position_grads.append(torch.outer(g, a).flatten().detach())
+                A_sum += torch.outer(a, a)
+                G_sum += torch.outer(g.detach(), g.detach())
+
+    n_positions = len(position_grads)
+    grads_tensor = torch.stack(position_grads)
+    F_exact = grads_tensor.T @ grads_tensor / n_positions
+
+    A = A_sum / n_positions
+    A = (A + A.T) / 2
+    G = G_sum / n_positions
+    G = (G + G.T) / 2
+
+    return F_exact, A, G, n_positions
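Why stacking `torch.outer(g, a).flatten()` lines up with `torch.kron(G, A)` later in the test: for a single position the Kronecker factorization is exact under row-major flattening, i.e. vec(g aᵀ) vec(g aᵀ)ᵀ = (g gᵀ) āŠ— (a aᵀ). It only becomes an approximation once both sides are averaged over positions, because the expectation of a Kronecker product is not the Kronecker product of the expectations. A quick self-contained check of the single-position identity:

```python
import torch

g, a = torch.randn(8), torch.randn(4)  # fake per-position gradient / activation
v = torch.outer(g, a).flatten()        # row-major vec(g aᵀ), matching the test

# Exact per position: vec(g aᵀ) vec(g aᵀ)ᵀ == (g gᵀ) āŠ— (a aᵀ)
assert torch.allclose(
    torch.outer(v, v),
    torch.kron(torch.outer(g, g), torch.outer(a, a)),
    atol=1e-6,
)
```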
+@pytest.mark.parametrize(
+    "seq_lengths, num_batches, sample, max_rel_error",
+    [
+        ((512,), 100, False, 0.05),
+        ((512,), 100, True, 0.05),
+        ((4,), 10000, False, 0.05),  # rel_error = ~0.25 without valid_masks logic
+        ((4,), 10000, True, 0.10),  # rel_error = ~0.25 without valid_masks logic
+        ((512, 2), 100, False, 0.05),  # rel_error = ~0.6 without valid_masks logic
+        ((512, 2), 100, True, 0.20),  # rel_error = ~1.2 without valid_masks logic
+    ],
+)
+def test_kfac_fim_accuracy(seq_lengths, num_batches, max_rel_error, sample, tmp_path):
+    """
+    Test that KFAC approximates the FIM within tolerance.
+
+    Args:
+        sample: If True, test the true FIM (sampled labels).
+            If False, test the empirical FIM (dataset labels).
+    """
+    config = ToyDataConfig(
+        vocab_size=8,
+        hidden_size=4,
+        seq_lengths=seq_lengths,
+        num_batches=num_batches,
+    )
+    device = get_device()
+
+    dataset = generate_dataset(config)
+    batches = generate_batches(config)
+
+    model_config = ToyLMConfig(
+        vocab_size=config.vocab_size, hidden_size=config.hidden_size
+    )
+    model = ToyLM(
+        model_config,
+        training_data=dataset,
+        training_batches=batches,
+        device=device,
+    )
+
+    F_exact, A_exact, G_exact, total_processed_exact = compute_exact_fim(
+        model, dataset, batches, device, sample=sample
+    )
+
+    run_path = tmp_path / "run"
+    idx_config = IndexConfig(run_path=run_path, data=None, sample=sample)
+
+    kfac = EkfacComputer(
+        model=model,
+        data=dataset,
+        batches=batches,
+        target_modules={"linear"},
+        cfg=idx_config,
+    )
+
+    kfac.compute_covariance()
+
+    A_dict_kfac, G_dict_kfac, total_processed_kfac = load_covariances(run_path)
+
+    assert total_processed_kfac == total_processed_exact
+
+    A_kfac = list(A_dict_kfac.values())[0].float().to(device) / total_processed_kfac
+    A_kfac = (A_kfac + A_kfac.T) / 2
+    G_kfac = list(G_dict_kfac.values())[0].float().to(device) / total_processed_kfac
+    G_kfac = (G_kfac + G_kfac.T) / 2
+
+    # A and G should be the same when we're not sampling
+    if not sample:
+        torch.testing.assert_close(A_kfac, A_exact, rtol=1e-3, atol=1e-6)
+        torch.testing.assert_close(G_kfac, G_exact, rtol=1e-3, atol=1e-6)
+
+    F_kfac = torch.kron(G_kfac, A_kfac)
+    rel_error = (torch.norm(F_kfac - F_exact) / torch.norm(F_exact)).item()
+
+    assert rel_error <= max_rel_error, (
+        f"KFAC rel_error {rel_error:.4f} exceeds the tolerated max_rel_error "
+        f"{max_rel_error} for seq_lengths={seq_lengths}, "
+        f"num_batches={num_batches}, sample={sample}"
+    )
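For reference, the quantity this test bounds, written out: with a the layer input and g the gradient of the per-position loss with respect to the layer output,

```latex
F = \mathbb{E}\!\left[\operatorname{vec}(g a^{\top})\operatorname{vec}(g a^{\top})^{\top}\right]
\approx \mathbb{E}\!\left[g g^{\top}\right] \otimes \mathbb{E}\!\left[a a^{\top}\right] = G \otimes A,
\qquad
\mathrm{rel\_error} = \frac{\lVert G \otimes A - F \rVert_{F}}{\lVert F \rVert_{F}} \le \mathrm{max\_rel\_error}.
```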
diff --git a/tests/ekfac_tests/test_utils.py b/tests/ekfac_tests/test_utils.py
index d3d6f10b..cd5a8b85 100644
--- a/tests/ekfac_tests/test_utils.py
+++ b/tests/ekfac_tests/test_utils.py
@@ -2,9 +2,57 @@
 import os
 import random
+from pathlib import Path

 import numpy as np
 import torch
+from safetensors.torch import load_file
+
+
+def load_sharded_covariances(sharded_dir: str | Path) -> dict[str, torch.Tensor]:
+    """Load and concatenate sharded covariance files.
+
+    Args:
+        sharded_dir: Directory containing shard_0.safetensors, shard_1.safetensors, etc.
+
+    Returns:
+        Dictionary mapping layer names to concatenated covariance tensors.
+    """
+    sharded_dir = Path(sharded_dir)
+    # Sort numerically by rank so shard_10 does not precede shard_2.
+    shard_files = sorted(
+        sharded_dir.glob("shard_*.safetensors"),
+        key=lambda p: int(p.stem.split("_")[-1]),
+    )
+
+    if not shard_files:
+        raise FileNotFoundError(f"No shard files found in {sharded_dir}")
+
+    shards = [load_file(str(f)) for f in shard_files]
+
+    # Concatenate shards along first dimension
+    result = {}
+    for key in shards[0]:
+        result[key] = torch.cat([shard[key] for shard in shards], dim=0)
+
+    return result
+
+
+def load_covariances(
+    run_path: str | Path,
+) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], int]:
+    """Load activation and gradient covariances from an EKFAC run.
+
+    Args:
+        run_path: Path to the run directory containing influence_results/.
+
+    Returns:
+        Tuple of (activation_covariances, gradient_covariances, total_processed).
+    """
+    run_path = Path(run_path)
+    results_path = run_path / "influence_results"
+
+    A_cov = load_sharded_covariances(results_path / "activation_covariance_sharded")
+    G_cov = load_sharded_covariances(results_path / "gradient_covariance_sharded")
+    total_processed = torch.load(results_path / "total_processed_covariances.pt").item()
+
+    return A_cov, G_cov, total_processed


 def set_all_seeds(seed: int = 42) -> None:
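A minimal usage sketch of the new helpers, assuming a run directory previously populated by `EkfacComputer.compute_covariance()` (the path here is hypothetical):

```python
from tests.ekfac_tests.test_utils import load_covariances

A_cov, G_cov, total = load_covariances("runs/ekfac_demo")
# The stored covariances are running sums over positions; divide by the
# processed-position count before comparing against normalized quantities.
A = {name: cov.float() / total for name, cov in A_cov.items()}
G = {name: cov.float() / total for name, cov in G_cov.items()}
```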
diff --git a/tests/ekfac_tests/toy_model.py b/tests/ekfac_tests/toy_model.py
new file mode 100644
index 00000000..4d6b6aa3
--- /dev/null
+++ b/tests/ekfac_tests/toy_model.py
@@ -0,0 +1,155 @@
+"""
+Toy language model for EKFAC testing.
+
+Provides a minimal transformers-compatible model and dataset generation
+utilities for testing EkfacComputer without loading real models.
+"""
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from datasets import Dataset
+from torch import Tensor
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutput
+
+
+class ToyLMConfig(PretrainedConfig):
+    """Configuration for ToyLM - a minimal language model for testing."""
+
+    model_type = "toy_lm"
+
+    def __init__(
+        self,
+        vocab_size: int = 8,
+        hidden_size: int = 4,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        super().__init__(**kwargs)
+
+
+class ToyLMModule(nn.Module):
+    """The base model (what hooks attach to)."""
+
+    def __init__(self, config: ToyLMConfig):
+        super().__init__()
+        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.linear = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def forward(self, input_ids: Tensor) -> Tensor:
+        hidden = self.embed(input_ids)  # [B, S] -> [B, S, H]
+        return self.linear(hidden)  # [B, S, H] -> [B, S, V]
+
+
+class ToyLM(PreTrainedModel):
+    """Toy language model compatible with EkfacComputer."""
+
+    config_class = ToyLMConfig
+    base_model_prefix = "model"
+
+    def __init__(
+        self,
+        config: ToyLMConfig,
+        *,
+        training_data=None,
+        training_batches: list[list[int]] | None = None,
+        device: torch.device | None = None,
+        num_steps: int = 5000,
+    ):
+        super().__init__(config)
+        self.model = ToyLMModule(config)
+
+        if training_data is not None and training_batches is not None:
+            self._train(training_data, training_batches, device, num_steps)
+
+    def _train(
+        self,
+        dataset,
+        batches: list[list[int]],
+        device: torch.device | None,
+        num_steps: int,
+        lr: float = 0.1,
+    ) -> None:
+        """Train the model to make logits more peaked (like a real LLM)."""
+        import torch.nn.functional as F
+
+        if device is not None:
+            self.to(device)
+
+        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
+
+        step = 0
+        while step < num_steps:
+            for batch_indices in batches:
+                for idx in batch_indices:
+                    input_ids = torch.tensor(
+                        dataset[idx]["input_ids"], device=device
+                    ).unsqueeze(0)
+                    labels = torch.tensor(
+                        dataset[idx]["labels"], device=device
+                    ).unsqueeze(0)
+
+                    logits = self(input_ids).logits
+                    loss = F.cross_entropy(
+                        logits[:, :-1].reshape(-1, logits.size(-1)),
+                        labels[:, 1:].reshape(-1),
+                    )
+
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+                    step += 1
+                    if step >= num_steps:
+                        return
+
+    @property
+    def base_model(self) -> nn.Module:
+        return self.model
+
+    def forward(self, input_ids: Tensor, **kwargs) -> CausalLMOutput:
+        logits = self.model(input_ids)
+        return CausalLMOutput(logits=logits)
+
+
+@dataclass
+class ToyDataConfig:
+    """Configuration for toy data generation."""
+
+    vocab_size: int = 8
+    hidden_size: int = 4
+    seq_lengths: tuple[int, ...] = (2,)
+    num_batches: int = 2000
+
+    @property
+    def max_seq_len(self) -> int:
+        return max(self.seq_lengths)
+
+    @property
+    def batch_size(self) -> int:
+        return len(self.seq_lengths)
+
+
+def generate_dataset(config: ToyDataConfig) -> Dataset:
+    """Generate a HuggingFace Dataset for use with EkfacComputer."""
+    data = {"input_ids": [], "labels": []}
+
+    for _ in range(config.num_batches):
+        for seq_len in config.seq_lengths:
+            input_ids = torch.randint(0, config.vocab_size, (seq_len,)).tolist()
+            data["input_ids"].append(input_ids)
+            data["labels"].append(input_ids)
+
+    return Dataset.from_dict(data)
+
+
+def generate_batches(config: ToyDataConfig) -> list[list[int]]:
+    """Generate batch indices for EkfacComputer."""
+    batch_size = len(config.seq_lengths)
+    return [
+        list(range(i * batch_size, (i + 1) * batch_size))
+        for i in range(config.num_batches)
+    ]
diff --git a/uv.lock b/uv.lock
index 4b816c5b..8afd70ad 100644
--- a/uv.lock
+++ b/uv.lock
@@ -200,6 +200,14 @@ faiss = [
     { name = "faiss-gpu-cu12" },
 ]

+[package.dev-dependencies]
+dev = [
+    { name = "pre-commit" },
+    { name = "pre-commit-uv" },
+    { name = "pyright" },
+    { name = "pytest" },
+]
+
 [package.metadata]
 requires-dist = [
     { name = "accelerate" },
@@ -216,11 +224,19 @@ requires-dist = [
     { name = "pytest", marker = "extra == 'dev'" },
     { name = "simple-parsing" },
     { name = "torch" },
-    { name = "transformers" },
+    { name = "transformers", specifier = ">=4.57.1" },
     { name = "trl", marker = "extra == 'example'" },
 ]
 provides-extras = ["dev", "example", "faiss"]

+[package.metadata.requires-dev]
+dev = [
+    { name = "pre-commit", specifier = ">=4.2.0" },
+    { name = "pre-commit-uv", specifier = ">=4.1.5" },
+    { name = "pyright", specifier = ">=1.1.406" },
+    { name = "pytest", specifier = ">=8.4.2" },
+]
+
 [[package]]
 name = "bitsandbytes"
 version = "0.47.0"
@@ -1448,6 +1464,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/88/74/a88bf1b1efeae488a0c0b7bdf71429c313722d1fc0f377537fbe554e6180/pre_commit-4.2.0-py2.py3-none-any.whl", hash = "sha256:a009ca7205f1eb497d10b845e52c838a98b6cdd2102a6c8e4540e94ee75c58bd", size = 220707, upload-time = "2025-03-18T21:35:19.343Z" },
 ]

+[[package]]
+name = "pre-commit-uv"
+version = "4.1.5"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pre-commit" },
+    { name = "uv" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3d/0c/e6ab71e93d8e78ffa36a1f8b6ce12014679e2b83b401404c12bb2840078f/pre_commit_uv-4.1.5.tar.gz", hash = "sha256:3f40714152b4f4aa484703b8dbfeb9baa0aaedb17207e0012b3561da756d577d", size = 6920, upload-time = "2025-08-27T14:44:40.178Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/f7/c6/747bc58da9f0665c607890c73b349b3934381e312272f584808182655898/pre_commit_uv-4.1.5-py3-none-any.whl", hash = "sha256:f4805e45615b898c4ca6ea37bdb60a05bb7830f986c303a06a378d6b50c3aa9e", size = 5653, upload-time = "2025-08-27T14:44:39.187Z" },
+]
+
 [[package]]
 name = "prompt-toolkit"
 version = "3.0.52"
@@ -2009,27 +2038,27 @@

 [[package]]
 name = "tokenizers"
-version = "0.21.4"
+version = "0.22.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "huggingface-hub" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/c2/2f/402986d0823f8d7ca139d969af2917fefaa9b947d1fb32f6168c509f2492/tokenizers-0.21.4.tar.gz", hash = "sha256:fa23f85fbc9a02ec5c6978da172cdcbac23498c3ca9f3645c5c68740ac007880", size = 351253,
upload-time = "2025-07-28T15:48:54.325Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/c6/fdb6f72bf6454f52eb4a2510be7fb0f614e541a2554d6210e370d85efff4/tokenizers-0.21.4-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:2ccc10a7c3bcefe0f242867dc914fc1226ee44321eb618cfe3019b5df3400133", size = 2863987, upload-time = "2025-07-28T15:48:44.877Z" }, - { url = "https://files.pythonhosted.org/packages/8d/a6/28975479e35ddc751dc1ddc97b9b69bf7fcf074db31548aab37f8116674c/tokenizers-0.21.4-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:5e2f601a8e0cd5be5cc7506b20a79112370b9b3e9cb5f13f68ab11acd6ca7d60", size = 2732457, upload-time = "2025-07-28T15:48:43.265Z" }, - { url = "https://files.pythonhosted.org/packages/aa/8f/24f39d7b5c726b7b0be95dca04f344df278a3fe3a4deb15a975d194cbb32/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b376f5a1aee67b4d29032ee85511bbd1b99007ec735f7f35c8a2eb104eade5", size = 3012624, upload-time = "2025-07-28T13:22:43.895Z" }, - { url = "https://files.pythonhosted.org/packages/58/47/26358925717687a58cb74d7a508de96649544fad5778f0cd9827398dc499/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2107ad649e2cda4488d41dfd031469e9da3fcbfd6183e74e4958fa729ffbf9c6", size = 2939681, upload-time = "2025-07-28T13:22:47.499Z" }, - { url = "https://files.pythonhosted.org/packages/99/6f/cc300fea5db2ab5ddc2c8aea5757a27b89c84469899710c3aeddc1d39801/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c73012da95afafdf235ba80047699df4384fdc481527448a078ffd00e45a7d9", size = 3247445, upload-time = "2025-07-28T15:48:39.711Z" }, - { url = "https://files.pythonhosted.org/packages/be/bf/98cb4b9c3c4afd8be89cfa6423704337dc20b73eb4180397a6e0d456c334/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f23186c40395fc390d27f519679a58023f368a0aad234af145e0f39ad1212732", size = 3428014, upload-time = "2025-07-28T13:22:49.569Z" }, - { url = "https://files.pythonhosted.org/packages/75/c7/96c1cc780e6ca7f01a57c13235dd05b7bc1c0f3588512ebe9d1331b5f5ae/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc88bb34e23a54cc42713d6d98af5f1bf79c07653d24fe984d2d695ba2c922a2", size = 3193197, upload-time = "2025-07-28T13:22:51.471Z" }, - { url = "https://files.pythonhosted.org/packages/f2/90/273b6c7ec78af547694eddeea9e05de771278bd20476525ab930cecaf7d8/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51b7eabb104f46c1c50b486520555715457ae833d5aee9ff6ae853d1130506ff", size = 3115426, upload-time = "2025-07-28T15:48:41.439Z" }, - { url = "https://files.pythonhosted.org/packages/91/43/c640d5a07e95f1cf9d2c92501f20a25f179ac53a4f71e1489a3dcfcc67ee/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:714b05b2e1af1288bd1bc56ce496c4cebb64a20d158ee802887757791191e6e2", size = 9089127, upload-time = "2025-07-28T15:48:46.472Z" }, - { url = "https://files.pythonhosted.org/packages/44/a1/dd23edd6271d4dca788e5200a807b49ec3e6987815cd9d0a07ad9c96c7c2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_armv7l.whl", hash = 
"sha256:1340ff877ceedfa937544b7d79f5b7becf33a4cfb58f89b3b49927004ef66f78", size = 9055243, upload-time = "2025-07-28T15:48:48.539Z" }, - { url = "https://files.pythonhosted.org/packages/21/2b/b410d6e9021c4b7ddb57248304dc817c4d4970b73b6ee343674914701197/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:3c1f4317576e465ac9ef0d165b247825a2a4078bcd01cba6b54b867bdf9fdd8b", size = 9298237, upload-time = "2025-07-28T15:48:50.443Z" }, - { url = "https://files.pythonhosted.org/packages/b7/0a/42348c995c67e2e6e5c89ffb9cfd68507cbaeb84ff39c49ee6e0a6dd0fd2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c212aa4e45ec0bb5274b16b6f31dd3f1c41944025c2358faaa5782c754e84c24", size = 9461980, upload-time = "2025-07-28T15:48:52.325Z" }, - { url = "https://files.pythonhosted.org/packages/3d/d3/dacccd834404cd71b5c334882f3ba40331ad2120e69ded32cf5fda9a7436/tokenizers-0.21.4-cp39-abi3-win32.whl", hash = "sha256:6c42a930bc5f4c47f4ea775c91de47d27910881902b0f20e4990ebe045a415d0", size = 2329871, upload-time = "2025-07-28T15:48:56.841Z" }, - { url = "https://files.pythonhosted.org/packages/41/f2/fd673d979185f5dcbac4be7d09461cbb99751554ffb6718d0013af8604cb/tokenizers-0.21.4-cp39-abi3-win_amd64.whl", hash = "sha256:475d807a5c3eb72c59ad9b5fcdb254f6e17f53dfcbb9903233b0dfa9c943b597", size = 2507568, upload-time = "2025-07-28T15:48:55.456Z" }, + { url = "https://files.pythonhosted.org/packages/bf/33/f4b2d94ada7ab297328fc671fed209368ddb82f965ec2224eb1892674c3a/tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73", size = 3069318, upload-time = "2025-09-19T09:49:11.848Z" }, + { url = "https://files.pythonhosted.org/packages/1c/58/2aa8c874d02b974990e89ff95826a4852a8b2a273c7d1b4411cdd45a4565/tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc", size = 2926478, upload-time = "2025-09-19T09:49:09.759Z" }, + { url = "https://files.pythonhosted.org/packages/1e/3b/55e64befa1e7bfea963cf4b787b2cea1011362c4193f5477047532ce127e/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a", size = 3256994, upload-time = "2025-09-19T09:48:56.701Z" }, + { url = "https://files.pythonhosted.org/packages/71/0b/fbfecf42f67d9b7b80fde4aabb2b3110a97fac6585c9470b5bff103a80cb/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7", size = 3153141, upload-time = "2025-09-19T09:48:59.749Z" }, + { url = "https://files.pythonhosted.org/packages/17/a9/b38f4e74e0817af8f8ef925507c63c6ae8171e3c4cb2d5d4624bf58fca69/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21", size = 3508049, upload-time = "2025-09-19T09:49:05.868Z" }, + { url = "https://files.pythonhosted.org/packages/d2/48/dd2b3dac46bb9134a88e35d72e1aa4869579eacc1a27238f1577270773ff/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214", size = 3710730, upload-time = "2025-09-19T09:49:01.832Z" }, + { url = "https://files.pythonhosted.org/packages/93/0e/ccabc8d16ae4ba84a55d41345207c1e2ea88784651a5a487547d80851398/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f", size = 3412560, upload-time = "2025-09-19T09:49:03.867Z" }, + { url = "https://files.pythonhosted.org/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4", size = 3250221, upload-time = "2025-09-19T09:49:07.664Z" }, + { url = "https://files.pythonhosted.org/packages/d7/a6/2c8486eef79671601ff57b093889a345dd3d576713ef047776015dc66de7/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879", size = 9345569, upload-time = "2025-09-19T09:49:14.214Z" }, + { url = "https://files.pythonhosted.org/packages/6b/16/32ce667f14c35537f5f605fe9bea3e415ea1b0a646389d2295ec348d5657/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446", size = 9271599, upload-time = "2025-09-19T09:49:16.639Z" }, + { url = "https://files.pythonhosted.org/packages/51/7c/a5f7898a3f6baa3fc2685c705e04c98c1094c523051c805cdd9306b8f87e/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a", size = 9533862, upload-time = "2025-09-19T09:49:19.146Z" }, + { url = "https://files.pythonhosted.org/packages/36/65/7e75caea90bc73c1dd8d40438adf1a7bc26af3b8d0a6705ea190462506e1/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390", size = 9681250, upload-time = "2025-09-19T09:49:21.501Z" }, + { url = "https://files.pythonhosted.org/packages/30/2c/959dddef581b46e6209da82df3b78471e96260e2bc463f89d23b1bf0e52a/tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82", size = 2472003, upload-time = "2025-09-19T09:49:27.089Z" }, + { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, ] [[package]] @@ -2175,7 +2204,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.54.1" +version = "4.57.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2189,9 +2218,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/21/6c/4caeb57926f91d943f309b062e22ad1eb24a9f530421c5a65c1d89378a7a/transformers-4.54.1.tar.gz", hash = "sha256:b2551bb97903f13bd90c9467d0a144d41ca4d142defc044a99502bb77c5c1052", size = 9514288, upload-time = "2025-07-29T15:57:22.826Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/68/a39307bcc4116a30b2106f2e689130a48de8bd8a1e635b5e1030e46fcd9e/transformers-4.57.1.tar.gz", hash = "sha256:f06c837959196c75039809636cd964b959f6604b75b8eeec6fdfc0440b89cc55", size = 10142511, upload-time = "2025-10-14T15:39:26.18Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cf/18/eb7578f84ef5a080d4e5ca9bc4f7c68e7aa9c1e464f1b3d3001e4c642fce/transformers-4.54.1-py3-none-any.whl", hash = "sha256:c89965a4f62a0d07009d45927a9c6372848a02ab9ead9c318c3d082708bab529", size = 11176397, upload-time = "2025-07-29T15:57:19.692Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/d3/c16c3b3cf7655a67db1144da94b021c200ac1303f82428f2beef6c2e72bb/transformers-4.57.1-py3-none-any.whl", hash = "sha256:b10d05da8fa67dc41644dbbf9bc45a44cb86ae33da6f9295f5fbf5b7890bd267", size = 11990925, upload-time = "2025-10-14T15:39:23.085Z" }, ] [[package]] @@ -2250,6 +2279,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "uv" +version = "0.9.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/78/291b32fdcc774b8ba4a0f4570af44af6cd34ef7385537d6521c7e3280030/uv-0.9.8.tar.gz", hash = "sha256:99b18bfe92c33c3862b65d74677697e799763e669e0064685f405e7e27517f25", size = 3709979, upload-time = "2025-11-07T20:41:33.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/5d/4db5a4e72f70e15491ca33289092cd127d1220861bc647ebf743ea844cd7/uv-0.9.8-py3-none-linux_armv6l.whl", hash = "sha256:d93a2227d23e81ab3a16c30363559afc483e8aca40ea9343b3f326a9a41718c9", size = 20566439, upload-time = "2025-11-07T20:40:26.268Z" }, + { url = "https://files.pythonhosted.org/packages/e6/76/3ffedb2ba3adf71719996cb4c2660a333d2267503823a02e184a839e1d4e/uv-0.9.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7038a552159f2291dd0d1f4f66a36261b5f3ed5fcd92e2869186f8e910b2c935", size = 19705224, upload-time = "2025-11-07T20:40:31.384Z" }, + { url = "https://files.pythonhosted.org/packages/da/37/7716dd87189a6b062502ea41650eccd2473b6ee54b37cdf6e90a3b1aaa17/uv-0.9.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9f2f3576c4518ff4f15e48dbca70585a513523c4738bc8cc2e48b20fd1190ce3", size = 18213823, upload-time = "2025-11-07T20:40:34.962Z" }, + { url = "https://files.pythonhosted.org/packages/8d/ed/7aa302fac3d6c880df6bdbba3fb6b4d8cded023b1398f99576dcb103051a/uv-0.9.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:50d130c46d97d7f10675ebea8608b7b4722c84b5745cd1bb0c8ae6d7984c05d5", size = 20090145, upload-time = "2025-11-07T20:40:38.842Z" }, + { url = "https://files.pythonhosted.org/packages/72/d2/2539fe7ecf03f5fa3dfcc4c39f59ade412bd1b8e89c9ae026b5a2d7da3dd/uv-0.9.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6df2e16f6df32018047c60bab2c0284868ad5c309addba9183ea2eeb71746bf0", size = 20218906, upload-time = "2025-11-07T20:40:42.189Z" }, + { url = "https://files.pythonhosted.org/packages/f7/29/2923cd822b9a1dc9b99513a00d2102c7ef979ac3001e9541e72a1e7fca07/uv-0.9.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:543693def38fa41b9706aba391111fe8d9dd6be86899d76f9581faf045ac1cb6", size = 21061669, upload-time = "2025-11-07T20:40:47.663Z" }, + { url = "https://files.pythonhosted.org/packages/72/c6/46b9fe190e6fafb6bf04d870ccfd547e69aa79d0448a5c2c5799f1c0850e/uv-0.9.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1b8b5bdcda3e10ea70b618d0609acddc5c725cb58d4caf933030ddedd7c2e98f", size = 22668783, upload-time = "2025-11-07T20:40:51.172Z" }, + { url = "https://files.pythonhosted.org/packages/94/80/ec48165c76f863bbfcb0721aa1543cd3e7209c0cb8fdf89fe3d4e16694e2/uv-0.9.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a4010b3fdabbb3c4f2cf2f7aa3bf6002d00049dcbc54ce0ee5ada32a933b2290", size = 22319178, upload-time 
= "2025-11-07T20:40:54.719Z" }, + { url = "https://files.pythonhosted.org/packages/33/6c/2dbda528a2cd7a87a7363e8a9aad3033bff12c8b071a5e462eb852e704fd/uv-0.9.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75671150d6eb9d5ee829e1fdb8cf86b8e44a66d27cbb996fe807e986c4107b5d", size = 21398576, upload-time = "2025-11-07T20:40:58.509Z" }, + { url = "https://files.pythonhosted.org/packages/90/66/07e7067ace0886212217380b6e809f7dd1fed3d35c34be8d02124a656b17/uv-0.9.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14670bf55ecb5cfd0f3654fbf51c58a21dec3ad8ab531079b3ed8599271dc77b", size = 21346696, upload-time = "2025-11-07T20:41:01.931Z" }, + { url = "https://files.pythonhosted.org/packages/35/98/5b8fad804d17e76a2861c932009b0d34c7d5e3517923a808b168c2d92f2b/uv-0.9.8-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:40253d00c1e900a0a61b132b1e0dd4aa83575cfd5302d3671899b6de29b1ef67", size = 20159753, upload-time = "2025-11-07T20:41:05.51Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e4/32b74e9246e71f27b8710ba44be6bfd8bdcf552dce211cecd4d1061705cc/uv-0.9.8-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:f52c6a99197028a314d4c1825f7ccb696eb9a88b822d2e2f17046266c75e543e", size = 21299928, upload-time = "2025-11-07T20:41:09.285Z" }, + { url = "https://files.pythonhosted.org/packages/b2/35/003035bc2da31cc9925a62b1510a821d701c117cf0327ab0a1df5c83db34/uv-0.9.8-py3-none-musllinux_1_1_armv7l.whl", hash = "sha256:5af28f1645eb3c50fd34a78508792db2d0799816f4eb5f55e1e6e2c724dfb125", size = 20170593, upload-time = "2025-11-07T20:41:12.745Z" }, + { url = "https://files.pythonhosted.org/packages/d7/b4/8c3d7afdc87ef07b51b87646a4c75ee5209b7f9f99a33d54746b7ee0f157/uv-0.9.8-py3-none-musllinux_1_1_i686.whl", hash = "sha256:cdbfadca9522422ab9820f5ada071c9c5c869bcd6fee719d20d91d5ec85b2a7d", size = 20560556, upload-time = "2025-11-07T20:41:16.85Z" }, + { url = "https://files.pythonhosted.org/packages/64/43/6045bb0b69c788620df4750de57319f56a9b5bd02eef56f28af0de25c117/uv-0.9.8-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:87c3b65b6d5fcbdeab199d54c74fbf75de19cb534a690c936c5616478a038576", size = 21530469, upload-time = "2025-11-07T20:41:20.336Z" }, + { url = "https://files.pythonhosted.org/packages/96/a4/8bb8dca265df52abc405161f918225fbf156fc3a16f380a382a5cd52f992/uv-0.9.8-py3-none-win32.whl", hash = "sha256:0f03bc413c933dbf850ad0dc2dba3df6b80c860a5c65cd767add49da19dadef0", size = 19440191, upload-time = "2025-11-07T20:41:23.612Z" }, + { url = "https://files.pythonhosted.org/packages/6c/b6/9a2ed2c1cc86b967de82c20aeee2860f8771adbcf010061359f5406a6bed/uv-0.9.8-py3-none-win_amd64.whl", hash = "sha256:6a01d7cd41953ffac583139b10ad1df004a67c0246a6b694eb5bcdbc8c99deaf", size = 21491715, upload-time = "2025-11-07T20:41:27.181Z" }, + { url = "https://files.pythonhosted.org/packages/95/77/4a8f429c8d89a17a5327e7be8a7f3b72f7422b0acccfc378d424ca6dc0c9/uv-0.9.8-py3-none-win_arm64.whl", hash = "sha256:bb0f8e83c2a2fc5a802e930cc8a7b71ab068180300a3f27ba38037f9fcb3d430", size = 19865491, upload-time = "2025-11-07T20:41:30.62Z" }, +] + [[package]] name = "virtualenv" version = "20.32.0"