diff --git a/bergson/build.py b/bergson/build.py index 36d8f023..c5992245 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -157,6 +157,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableD else: # Convert each shard to a Dataset then collect its gradients buf, shard_id = [], 0 + total_tokens_processed = 0 def flush(): nonlocal buf, shard_id @@ -180,6 +181,19 @@ def flush(): shard_id += 1 for ex in tqdm(ds, desc="Collecting gradients"): + # Check if adding this example would exceed max_tokens + if cfg.max_tokens is not None: + example_tokens = ex.get("length", 0) + if total_tokens_processed + example_tokens > cfg.max_tokens: + # Flush current buffer and stop processing + flush() + print( + f"Reached max_tokens limit ({cfg.max_tokens}). " + f"Processed {total_tokens_processed} tokens." + ) + break + total_tokens_processed += example_tokens + buf.append(ex) if len(buf) == cfg.stream_shard_size: flush() @@ -229,6 +243,33 @@ def build_gradient_dataset(cfg: IndexConfig): new_fingerprint="advantage", # type: ignore ) + # Apply max_tokens filtering if specified + if cfg.max_tokens is not None: + if isinstance(ds, Dataset): + # For non-streaming datasets, filter based on cumulative token count + total_tokens = 0 + indices_to_keep = [] + + for i, length in enumerate(ds["length"]): + if total_tokens + length <= cfg.max_tokens: + total_tokens += length + indices_to_keep.append(i) + else: + break + + if indices_to_keep: + ds = ds.select(indices_to_keep) + print( + f"Filtered dataset to {len(indices_to_keep)} examples " + f"with {total_tokens} tokens (target: {cfg.max_tokens})" + ) + else: + raise ValueError(f"No examples fit within max_tokens={cfg.max_tokens}") + else: + # For streaming datasets, max_tokens filtering is handled in the worker + # function during the streaming processing loop + pass + world_size = torch.cuda.device_count() if world_size <= 1: # Run the worker directly if no distributed training is needed. This is great diff --git a/bergson/data.py b/bergson/data.py index 29d5cb31..70be5b26 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -113,6 +113,9 @@ class IndexConfig: head_cfgs: dict[str, HeadConfig] = field(default_factory=dict) """Configuration for each attention module to be split into head matrices.""" + max_tokens: int | None = None + """Maximum number of tokens to process. 
If None, process all available tokens.""" + def ceildiv(a: int, b: int) -> int: """Ceiling division of two integers.""" diff --git a/examples/benchmark_bergson.py b/examples/benchmark_bergson.py new file mode 100644 index 00000000..a6cca308 --- /dev/null +++ b/examples/benchmark_bergson.py @@ -0,0 +1,431 @@ +"""Utilities for benchmarking Bergson influence analysis scaling (in-memory reduce + score).""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import textwrap +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from datasets import Dataset, load_dataset +from torch.distributed.fsdp import fully_shard +from transformers import AutoModelForCausalLM, AutoTokenizer + +from bergson.collection import pad_and_tensor +from bergson.gradients import GradientCollector, GradientProcessor +from bergson.utils import assert_type, get_layer_list + +# Import from same directory + +from examples.benchmark_common import ( + MODEL_SPECS, ModelSpec, DEFAULT_DATASET, format_tokens, parse_tokens, timestamp +) + +SCHEMA_VERSION = 1 +DEFAULT_TRAIN_SPLIT = "train" +DEFAULT_EVAL_SPLIT = "validation" + + +@dataclass +class RunRecord: + schema_version: int + status: str + model_key: str + model_name: str + params: float + train_tokens: int + eval_tokens: int + dataset: str + train_split: str + eval_split: str + batch_size: int + max_length: int + reduce_seconds: float | None # Time to collect training gradients + score_seconds: float | None # Time to compute inner products + total_runtime_seconds: float | None + start_time: str + end_time: str + run_path: str + notes: str | None + error: str | None + + +def ensure_run_path( + base: Path, + spec: ModelSpec, + train_tokens: int, + eval_tokens: int, + tag: str | None, +) -> Path: + train_label = format_tokens(train_tokens) + eval_label = format_tokens(eval_tokens) + run_tag = tag or datetime.utcnow().strftime("%Y%m%d-%H%M%S") + path = base / spec.key / f"{train_label}-{eval_label}-{run_tag}" + path.mkdir(parents=True, exist_ok=True) + return path + + +def save_record(path: Path, record: RunRecord) -> None: + with open(path / "benchmark.json", "w", encoding="utf-8") as fh: + json.dump(asdict(record), fh, indent=2) + + +def cmd_run(args: argparse.Namespace) -> None: + if args.model not in MODEL_SPECS: + raise ValueError(f"Unknown model '{args.model}'") + spec = MODEL_SPECS[args.model] + train_tokens = parse_tokens(args.train_tokens) + eval_tokens = parse_tokens(args.eval_tokens) + + # Enable FSDP for larger models (>= 1B parameters) or if explicitly requested + use_fsdp = args.fsdp or (spec.params >= 1_000_000_000) + + print( + f"Running Bergson benchmark for {args.model} with {train_tokens} train " + f"and {eval_tokens} eval tokens" + ) + + run_root = Path(args.run_root).resolve() + run_root.mkdir(parents=True, exist_ok=True) + run_path = ( + Path(args.run_path).resolve() + if args.run_path + else ensure_run_path(run_root, spec, train_tokens, eval_tokens, args.tag) + ) + + start_wall = timestamp() + start = time.perf_counter() + status = "success" + error_message: str | None = None + reduce_time: float | None = None + score_time: float | None = None + train_grads_flat: dict[str, torch.Tensor] = {} + + try: + # Set up distributed training if FSDP is enabled + # FSDP requires process group initialization even for single 
GPU + rank = 0 + world_size = 1 + if use_fsdp: + # Initialize process group for FSDP (even for single GPU) + if not dist.is_initialized(): + # Set environment variables for single-process initialization + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method="env://", + rank=0, + world_size=1, + ) + print("Initialized process group for FSDP (single GPU)") + else: + print("Process group already initialized") + + # Load model and tokenizer + # For FSDP, load to CPU first, then wrap with FSDP + device_map = "cpu" if use_fsdp else "auto" + model = AutoModelForCausalLM.from_pretrained( + spec.hf_id, torch_dtype=torch.bfloat16, device_map=device_map + ) + + if not use_fsdp: + model.cuda() + else: + # Move to GPU 0 and wrap with FSDP + model = model.cuda() + + # Wrap model with FSDP + embed = model.get_input_embeddings() + model.requires_grad_(False) # Freeze the model + embed.requires_grad_(True) # Make sure backward hooks are called though + + # Shard each individual transformer layer + for layer in get_layer_list(model): + fully_shard(layer) + + # Shard the entire model + fully_shard(model) + + print("Model wrapped with FSDP") + + tokenizer = AutoTokenizer.from_pretrained(spec.hf_id) + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(batch): + encoded = tokenizer.batch_encode_plus( + batch["text"], + return_tensors="pt", + padding=True, + truncation=True, + max_length=args.max_length, + ) + # Add labels for loss computation + encoded["labels"] = encoded["input_ids"].clone() + return encoded + + # Load datasets + train_dataset = assert_type( + Dataset, load_dataset(args.dataset, split=args.train_split) + ) + + # Estimate examples needed based on token count + max_length = args.max_length or 512 + train_examples_needed = max(1, train_tokens // max_length) + eval_examples_needed = max(1, eval_tokens // max_length) + + # Select enough examples + total_needed = train_examples_needed + eval_examples_needed + train_dataset = train_dataset.select(range(min(total_needed, len(train_dataset)))) + + eval_dataset = train_dataset.select( + range(train_examples_needed, train_examples_needed + eval_examples_needed) + ) + train_dataset = train_dataset.select(range(train_examples_needed)) + + train_dataset = train_dataset.map(tokenize, batched=True) + eval_dataset = eval_dataset.map(tokenize, batched=True) + + train_dataset.set_format( + type="torch", columns=["input_ids", "attention_mask", "labels"] + ) + eval_dataset.set_format( + type="torch", columns=["input_ids", "attention_mask", "labels"] + ) + + # Create processor (no normalization, no preconditioners, no projection) + processor = GradientProcessor( + normalizers={}, # No normalization + projection_dim=None, # No projection + reshape_to_square=False, + projection_type="rademacher", + ) + + # REDUCE PHASE: Collect training gradients in-memory + print("Collecting training gradients (reduce phase)...") + reduce_start = time.perf_counter() + + train_grads = defaultdict(list) + + def train_callback(name: str, g: torch.Tensor): + # Flatten and store gradients in-memory + # No normalization, no preconditioning as per requirements + train_grads[name].append(g.flatten(1).cpu()) + + train_collector = GradientCollector( + model.base_model, + train_callback, + processor, + ) + + # Process training data in batches + for i in range(0, 
len(train_dataset), args.batch_size): + batch_indices = list(range(i, min(i + args.batch_size, len(train_dataset)))) + batch_items = [train_dataset[j] for j in batch_indices] + + # Extract and convert to lists for pad_and_tensor + input_ids_list = [item["input_ids"].cpu().tolist() if isinstance(item["input_ids"], torch.Tensor) else item["input_ids"] for item in batch_items] + labels_list = [item["labels"].cpu().tolist() if isinstance(item.get("labels"), torch.Tensor) else item.get("labels", item["input_ids"]) for item in batch_items] + + # Get the device + device = next(model.parameters()).device + + x, y = pad_and_tensor( + input_ids_list, + labels=labels_list, + device=device, + ) + + with train_collector: + logits = model(x).logits[:, :-1] + losses = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), + y[:, 1:].flatten(), + reduction="none", + ).reshape_as(y[:, 1:]) + # Mean reduction per example + masks = y[:, 1:] != -100 + denoms = masks.sum(dim=1, dtype=logits.dtype) + losses = losses.sum(1) / denoms + losses.mean().backward() + + model.zero_grad() + torch.cuda.synchronize() + + # Concatenate all training gradients + train_grads_flat = { + name: torch.cat(grads, dim=0) for name, grads in train_grads.items() + } + del train_grads + + reduce_time = time.perf_counter() - reduce_start + print(f"Reduce phase completed in {reduce_time:.2f} seconds") + print(f"Training gradients shape: {[(k, v.shape) for k, v in train_grads_flat.items()]}") + + # SCORE PHASE: Compute inner products with test gradients + print("Computing influence scores (score phase)...") + score_start = time.perf_counter() + + all_scores = [] + + for i, example in enumerate(eval_dataset): + # Get the device + device = next(model.parameters()).device + + input_ids = example["input_ids"].unsqueeze(0).to(device) + labels = example["labels"].unsqueeze(0).to(device) + + # Collect test gradient + test_grads = {} + + def test_callback(name: str, g: torch.Tensor): + test_grads[name] = g.flatten(1).cpu() + + test_collector = GradientCollector( + model.base_model, + test_callback, + processor, + ) + + with test_collector: + logits = model(input_ids).logits[:, :-1] + loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), + labels[:, 1:].flatten(), + reduction="mean", + ) + loss.backward() + + model.zero_grad() + torch.cuda.synchronize() + + # Compute inner products (no normalization, no preconditioning) + # Sum across all modules + scores = torch.zeros(len(train_dataset), device="cpu") + for name in test_grads: + if name in train_grads_flat: + # Inner product: test_grad @ train_grads^T + scores += (test_grads[name] @ train_grads_flat[name].T).squeeze(0) + + all_scores.append(scores) + + if i >= args.max_eval_examples - 1: + break + + score_time = time.perf_counter() - score_start + print(f"Score phase completed in {score_time:.2f} seconds") + print(f"Computed scores for {len(all_scores)} test examples") + + except Exception as exc: # noqa: BLE001 + status = "error" + error_message = repr(exc) + import traceback + traceback.print_exc() + + runtime = time.perf_counter() - start + end_wall = timestamp() + + record = RunRecord( + schema_version=SCHEMA_VERSION, + status=status, + model_key=spec.key, + model_name=spec.hf_id, + params=spec.params, + train_tokens=train_tokens, + eval_tokens=eval_tokens, + dataset=args.dataset, + train_split=args.train_split, + eval_split=args.eval_split, + batch_size=args.batch_size, + max_length=args.max_length or 512, + reduce_seconds=reduce_time, + score_seconds=score_time, + 
total_runtime_seconds=runtime, + start_time=start_wall, + end_time=end_wall, + run_path=str(run_path), + notes=args.notes, + error=error_message, + ) + save_record(run_path, record) + + print(json.dumps(asdict(record), indent=2)) + + # Clean up process group if we initialized it + if use_fsdp and dist.is_initialized(): + dist.destroy_process_group() + + if status != "success": + sys.exit(1) + + +def load_records(root: Path) -> list[RunRecord]: + records: list[RunRecord] = [] + for meta in root.rglob("benchmark.json"): + try: + with open(meta, "r", encoding="utf-8") as fh: + payload = json.load(fh) + records.append(RunRecord(**payload)) + except Exception as exc: # noqa: BLE001 + print(f"Warning: failed to read {meta}: {exc}", file=sys.stderr) + return records + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Benchmark Bergson influence analysis scaling (in-memory)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent( + """Examples: + python -m examples.benchmark_bergson run pythia-14m 1M 100K + python -m examples.benchmark_bergson run pythia-70m 5M 500K""" + ), + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + run_parser = subparsers.add_parser( + "run", help="Execute a single Bergson benchmark run" + ) + run_parser.add_argument("model", help="Key for the model to benchmark") + run_parser.add_argument( + "train_tokens", help="Target training tokens (e.g. 1M, 10M)" + ) + run_parser.add_argument( + "eval_tokens", help="Target evaluation tokens (e.g. 100K, 1M)" + ) + run_parser.add_argument("--batch-size", type=int, default=4) + run_parser.add_argument("--max-length", type=int, default=512) + run_parser.add_argument("--max-eval-examples", type=int, default=10) + run_parser.add_argument("--dataset", default=DEFAULT_DATASET) + run_parser.add_argument("--train-split", default=DEFAULT_TRAIN_SPLIT) + run_parser.add_argument("--eval-split", default=DEFAULT_EVAL_SPLIT) + run_parser.add_argument("--run-root", default="runs/bergson-scaling") + run_parser.add_argument("--run-path") + run_parser.add_argument("--tag") + run_parser.add_argument("--notes") + run_parser.add_argument( + "--fsdp", + action="store_true", + help="Enable FSDP (automatically enabled for models >= 1B parameters)", + ) + run_parser.set_defaults(func=cmd_run) + + args = parser.parse_args(argv) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_common.py b/examples/benchmark_common.py new file mode 100644 index 00000000..e9a3e9e5 --- /dev/null +++ b/examples/benchmark_common.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from datetime import datetime + +DEFAULT_DATASET = "EleutherAI/SmolLM2-135M-10B" + +@dataclass(frozen=True) +class ModelSpec: + key: str + hf_id: str + params: float + + +MODEL_SPECS: dict[str, ModelSpec] = { + "pythia-14m": ModelSpec("pythia-14m", "EleutherAI/pythia-14m", 14_000_000), + "pythia-70m": ModelSpec("pythia-70m", "EleutherAI/pythia-70m", 70_000_000), + "pythia-160m": ModelSpec("pythia-160m", "EleutherAI/pythia-160m", 160_000_000), + "pythia-1b": ModelSpec("pythia-1b", "EleutherAI/pythia-1b", 1_000_000_000), + "pythia-6.9b": ModelSpec("pythia-6.9b", "EleutherAI/pythia-6.9b", 6_900_000_000), + "pythia-12b": ModelSpec("pythia-12b", "EleutherAI/pythia-12b", 12_000_000_000), +} + +def timestamp() -> str: + return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" + +def format_tokens(tokens: int) -> str: + if tokens >= 1_000_000_000: + value = 
tokens / 1_000_000_000 + suffix = "B" + elif tokens >= 1_000_000: + value = tokens / 1_000_000 + suffix = "M" + elif tokens >= 1_000: + value = tokens / 1_000 + suffix = "K" + else: + return str(tokens) + if value.is_integer(): + return f"{int(value)}{suffix}" + return f"{value:.2f}{suffix}" + + +def parse_tokens(value: str) -> int: + text = value.strip().lower().replace(",", "") + if text.endswith("tokens"): + text = text[:-6] + if not text: + raise ValueError("empty token spec") + + suffixes = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000} + unit = 1 + if text[-1] in suffixes: + unit = suffixes[text[-1]] + text = text[:-1] + number = float(text) + return int(number * unit) diff --git a/examples/benchmark_dattri.py b/examples/benchmark_dattri.py new file mode 100644 index 00000000..1ca646ff --- /dev/null +++ b/examples/benchmark_dattri.py @@ -0,0 +1,295 @@ +"""Utilities for benchmarking Dattri influence analysis scaling.""" + +from __future__ import annotations + +import argparse +import json +import sys +import textwrap +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +from datasets import Dataset, load_dataset +from dattri.algorithm.base import BaseInnerProductAttributor +from dattri.task import AttributionTask +from transformers import AutoModelForCausalLM, AutoTokenizer + +from bergson.utils import assert_type + +# Import from same directory +from examples.benchmark_common import ( + MODEL_SPECS, ModelSpec, DEFAULT_DATASET, format_tokens, parse_tokens, timestamp +) + +SCHEMA_VERSION = 1 +DEFAULT_TRAIN_SPLIT = "train" +DEFAULT_EVAL_SPLIT = "validation" + + +@dataclass +class RunRecord: + schema_version: int + status: str + model_key: str + model_name: str + params: float + train_tokens: int + eval_tokens: int + dataset: str + train_split: str + eval_split: str + batch_size: int + max_length: int + runtime_seconds: float | None + start_time: str + end_time: str + run_path: str + notes: str | None + error: str | None + + +def ensure_run_path( + base: Path, + spec: ModelSpec, + train_tokens: int, + eval_tokens: int, + tag: str | None, +) -> Path: + train_label = format_tokens(train_tokens) + eval_label = format_tokens(eval_tokens) + run_tag = tag or datetime.utcnow().strftime("%Y%m%d-%H%M%S") + path = base / spec.key / f"{train_label}-{eval_label}-{run_tag}" + path.mkdir(parents=True, exist_ok=True) + return path + + +def save_record(path: Path, record: RunRecord) -> None: + with open(path / "benchmark.json", "w", encoding="utf-8") as fh: + json.dump(asdict(record), fh, indent=2) + + +def cmd_run(args: argparse.Namespace) -> None: + if args.model not in MODEL_SPECS: + raise ValueError(f"Unknown model '{args.model}'") + spec = MODEL_SPECS[args.model] + train_tokens = parse_tokens(args.train_tokens) + eval_tokens = parse_tokens(args.eval_tokens) + print( + f"Running Dattri benchmark for {args.model} with {train_tokens} train " + f"and {eval_tokens} eval tokens" + ) + + run_root = Path(args.run_root).resolve() + run_root.mkdir(parents=True, exist_ok=True) + run_path = ( + Path(args.run_path).resolve() + if args.run_path + else ensure_run_path(run_root, spec, train_tokens, eval_tokens, args.tag) + ) + + start_wall = timestamp() + start = time.perf_counter() + status = "success" + error_message: str | None = None + + try: + # Load model and tokenizer + model = AutoModelForCausalLM.from_pretrained( + spec.hf_id, torch_dtype=torch.bfloat16, device_map="auto" + ) + 
model.cuda() + + tokenizer = AutoTokenizer.from_pretrained(spec.hf_id) + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(batch): + return tokenizer.batch_encode_plus( + batch["text"], + return_tensors="pt", + padding=True, + truncation=True, + max_length=args.max_length, + ) + + # Load datasets + train_dataset = assert_type( + Dataset, load_dataset(args.dataset, split=args.train_split) + ) + + # Estimate examples needed based on token count + # We'll sample until we have enough tokens + max_length = args.max_length or 512 + train_examples_needed = max(1, train_tokens // max_length) + eval_examples_needed = max(1, eval_tokens // max_length) + + # Select enough examples + total_needed = train_examples_needed + eval_examples_needed + train_dataset = train_dataset.select(range(min(total_needed, len(train_dataset)))) + + eval_dataset = train_dataset.select( + range(train_examples_needed, train_examples_needed + eval_examples_needed) + ) + train_dataset = train_dataset.select(range(train_examples_needed)) + + train_dataset = train_dataset.map(tokenize, batched=True) + eval_dataset = eval_dataset.map(tokenize, batched=True) + + train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + + def collate_fn(batch): + # Dattri expects tuples of (input_ids, labels) where labels = input_ids for language modeling + # Keep on CPU - dattri will handle device placement + input_ids = torch.stack([item["input_ids"] for item in batch]) + labels = input_ids.clone() # For language modeling, labels are the same as input_ids + return (input_ids, labels) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + ) + test_loader = torch.utils.data.DataLoader( + eval_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + ) + + # Get model device + model_device = next(model.parameters()).device + + def loss_func(params, data_target_pair): + x, y = data_target_pair + # Ensure data is on the same device as model + if isinstance(x, torch.Tensor) and x.device != model_device: + x = x.to(model_device) + if isinstance(y, torch.Tensor) and y.device != model_device: + y = y.to(model_device) + # functional_call returns a tuple for transformers models, extract logits + output = torch.func.functional_call(model, params, (x,)) + if isinstance(output, tuple): + logits = output[0] # First element is logits + else: + logits = output.logits if hasattr(output, 'logits') else output + shift_logits = logits[:, :-1].contiguous() + shift_labels = y[:, 1:].contiguous() + loss = nn.CrossEntropyLoss()( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + return loss + + # Create task + task = AttributionTask( + loss_func=loss_func, + model=model, + checkpoints=model.state_dict(), + ) + + # Create attributor and cache + # Try to set device if BaseInnerProductAttributor supports it + try: + attributor = BaseInnerProductAttributor(task=task, device="cuda") + except TypeError: + # Device parameter not supported, use default + attributor = BaseInnerProductAttributor(task=task) + print("Caching training data...") + attributor.cache(train_loader) + + # Compute attributions + print("Computing attributions...") + with torch.no_grad(): + scores = attributor.attribute(train_loader, test_loader) + + except Exception as exc: # noqa: BLE001 + status = "error" + error_message = repr(exc) + import traceback + traceback.print_exc() + + runtime = 
time.perf_counter() - start + end_wall = timestamp() + + record = RunRecord( + schema_version=SCHEMA_VERSION, + status=status, + model_key=spec.key, + model_name=spec.hf_id, + params=spec.params, + train_tokens=train_tokens, + eval_tokens=eval_tokens, + dataset=args.dataset, + train_split=args.train_split, + eval_split=args.eval_split, + batch_size=args.batch_size, + max_length=args.max_length or 512, + runtime_seconds=runtime, + start_time=start_wall, + end_time=end_wall, + run_path=str(run_path), + notes=args.notes, + error=error_message, + ) + save_record(run_path, record) + + print(json.dumps(asdict(record), indent=2)) + + if status != "success": + sys.exit(1) + + +def load_records(root: Path) -> list[RunRecord]: + records: list[RunRecord] = [] + for meta in root.rglob("benchmark.json"): + try: + with open(meta, "r", encoding="utf-8") as fh: + payload = json.load(fh) + records.append(RunRecord(**payload)) + except Exception as exc: # noqa: BLE001 + print(f"Warning: failed to read {meta}: {exc}", file=sys.stderr) + return records + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Benchmark Dattri influence analysis scaling", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent( + """Examples: + python examples/benchmark_dattri.py run pythia-14m 1M 100K + python examples/benchmark_dattri.py run pythia-70m 5M 500K""" + ), + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + run_parser = subparsers.add_parser( + "run", help="Execute a single Dattri benchmark run" + ) + run_parser.add_argument("model", help="Key for the model to benchmark") + run_parser.add_argument( + "train_tokens", help="Target training tokens (e.g. 1M, 10M)" + ) + run_parser.add_argument( + "eval_tokens", help="Target evaluation tokens (e.g. 
100K, 1M)" + ) + run_parser.add_argument("--batch-size", type=int, default=4) + run_parser.add_argument("--max-length", type=int, default=512) + run_parser.add_argument("--dataset", default=DEFAULT_DATASET) + run_parser.add_argument("--train-split", default=DEFAULT_TRAIN_SPLIT) + run_parser.add_argument("--eval-split", default=DEFAULT_EVAL_SPLIT) + run_parser.add_argument("--run-root", default="runs/dattri-scaling") + run_parser.add_argument("--run-path") + run_parser.add_argument("--tag") + run_parser.add_argument("--notes") + run_parser.set_defaults(func=cmd_run) + + args = parser.parse_args(argv) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/examples/kronfluence_benchmark.py b/examples/kronfluence_benchmark.py new file mode 100644 index 00000000..72620d27 --- /dev/null +++ b/examples/kronfluence_benchmark.py @@ -0,0 +1,619 @@ +"""Utilities for benchmarking Kronfluence influence analysis scaling.""" + +from __future__ import annotations + +import argparse +import json +import math +import sys +import textwrap +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Iterable + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, load_dataset +from kronfluence.analyzer import Analyzer, prepare_model +from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.task import Task +from matplotlib import pyplot as plt +from transformers import AutoModelForCausalLM, AutoTokenizer + +from bergson.utils import assert_type +from benchmark_common import MODEL_SPECS, ModelSpec, DEFAULT_DATASET + +SCHEMA_VERSION = 1 +DEFAULT_TRAIN_SPLIT = "train" +DEFAULT_EVAL_SPLIT = "validation" + + +class LossTask(Task): + def compute_train_loss( + self, + batch: Any, + model: nn.Module, + sample: bool = False, + ) -> torch.Tensor: + input_ids = batch["input_ids"].cuda() + output = model(input_ids, batch["attention_mask"].cuda()) + loss = F.cross_entropy( + output.logits[:, :-1].flatten(0, 1), input_ids[:, 1:].flatten(0, 1) + ) + return loss + + def compute_measurement( + self, + batch: Any, + model: nn.Module, + ) -> torch.Tensor: + return self.compute_train_loss(batch, model) + + +@dataclass +class RunRecord: + schema_version: int + status: str + model_key: str + model_name: str + params: float + train_examples: int + eval_examples: int + dataset: str + train_split: str + eval_split: str + factors_name: str + scores_name: str + strategy: str + use_empirical_fisher: bool + covariance_max_examples: int + per_device_batch_size: int + per_device_query_batch_size: int + per_device_train_batch_size: int + amp_dtype: str + activation_covariance_dtype: str + gradient_covariance_dtype: str + per_sample_gradient_dtype: str + score_dtype: str + offload_activations_to_cpu: bool + runtime_seconds: float | None + start_time: str + end_time: str + run_path: str + notes: str | None + error: str | None + + +def parse_examples(value: str) -> int: + text = value.strip().lower().replace(",", "") + if text.endswith("examples"): + text = text[:-8] + if not text: + raise ValueError("empty example spec") + + suffixes = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000} + unit = 1 + if text[-1] in suffixes: + unit = suffixes[text[-1]] + text = text[:-1] + number = float(text) + return int(number * unit) + + +def format_examples(examples: int) -> str: + if examples >= 1_000_000_000: + value = examples / 1_000_000_000 + 
suffix = "B" + elif examples >= 1_000_000: + value = examples / 1_000_000 + suffix = "M" + elif examples >= 1_000: + value = examples / 1_000 + suffix = "K" + else: + return str(examples) + if value.is_integer(): + return f"{int(value)}{suffix}" + return f"{value:.2f}{suffix}" + + +def timestamp() -> str: + return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" + + +def ensure_run_path( + base: Path, + spec: ModelSpec, + train_examples: int, + eval_examples: int, + tag: str | None, +) -> Path: + train_label = format_examples(train_examples) + eval_label = format_examples(eval_examples) + run_tag = tag or datetime.utcnow().strftime("%Y%m%d-%H%M%S") + path = base / spec.key / f"{train_label}-{eval_label}-{run_tag}" + path.mkdir(parents=True, exist_ok=True) + return path + + +def save_record(path: Path, record: RunRecord) -> None: + with open(path / "benchmark.json", "w", encoding="utf-8") as fh: + json.dump(asdict(record), fh, indent=2) + + +def load_records(root: Path) -> list[RunRecord]: + records: list[RunRecord] = [] + for meta in root.rglob("benchmark.json"): + try: + with open(meta, "r", encoding="utf-8") as fh: + payload = json.load(fh) + records.append(RunRecord(**payload)) + except Exception as exc: # noqa: BLE001 + print(f"Warning: failed to read {meta}: {exc}", file=sys.stderr) + return records + + +def summarize_records(records: Iterable[RunRecord]) -> pd.DataFrame: + if not records: + return pd.DataFrame() + df = pd.DataFrame([asdict(r) for r in records]) + if "params" in df.columns: + df["params_b"] = df["params"] / 1_000_000_000 + return df + + +def estimate_scaling(df: pd.DataFrame) -> tuple[pd.DataFrame, dict[str, float]]: + subset = df.query("status == 'success' and runtime_seconds.notnull()") + if subset.empty: + raise ValueError("No successful runs with recorded runtime found.") + + X = np.column_stack( + [ + np.ones(len(subset)), + np.log(subset["train_examples"].astype(float)), + np.log(subset["eval_examples"].astype(float)), + np.log(subset["params"].astype(float)), + ] + ) + y = np.log(subset["runtime_seconds"].astype(float)) + coeffs, *_ = np.linalg.lstsq(X, y, rcond=None) + log_pred = X @ coeffs + subset = subset.copy() + subset["runtime_pred"] = np.exp(log_pred) + + resid = y - log_pred + ss_res = np.sum(resid**2) + ss_tot = np.sum((y - y.mean()) ** 2) + r2 = 1.0 - ss_res / ss_tot if ss_tot else float("nan") + + params = { + "log_scale": float(coeffs[0]), + "beta_train_examples": float(coeffs[1]), + "beta_eval_examples": float(coeffs[2]), + "beta_params": float(coeffs[3]), + "scale": float(math.exp(coeffs[0])), + "r2": float(r2), + "num_samples": int(len(subset)), + } + return subset, params + + +def plot_scaling(df: pd.DataFrame, out_path: Path) -> None: + if df.empty: + return + fig, ax = plt.subplots(figsize=(8, 6)) + for model_key, group in df.groupby("model_key"): + grp = group.sort_values("train_examples") + ax.plot( + grp["train_examples"], + grp["runtime_seconds"], + marker="o", + linewidth=1.5, + label=model_key, + ) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Training examples") + ax.set_ylabel("Wall clock time (s)") + ax.set_title("Kronfluence influence analysis scaling") + ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.6) + ax.legend(title="Model", fontsize="small") + fig.tight_layout() + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=200) + plt.close(fig) + + +def cmd_run(args: argparse.Namespace) -> None: + if args.model not in MODEL_SPECS: + raise ValueError(f"Unknown 
model '{args.model}'") + spec = MODEL_SPECS[args.model] + train_examples = parse_examples(args.train_examples) + eval_examples = parse_examples(args.eval_examples) + print( + f"Running Kronfluence benchmark for {args.model} with {train_examples} train " + f"and {eval_examples} eval examples" + ) + + run_root = Path(args.run_root).resolve() + run_root.mkdir(parents=True, exist_ok=True) + run_path = ( + Path(args.run_path).resolve() + if args.run_path + else ensure_run_path(run_root, spec, train_examples, eval_examples, args.tag) + ) + + start_wall = timestamp() + start = time.perf_counter() + status = "success" + error_message: str | None = None + + try: + # Load model and tokenizer + model = AutoModelForCausalLM.from_pretrained( + spec.hf_id, torch_dtype=torch.bfloat16, device_map="auto" + ) + model.cuda() + + tokenizer = AutoTokenizer.from_pretrained(spec.hf_id) + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(batch): + return tokenizer.batch_encode_plus( + batch["text"], + return_tensors="pt", + padding=True, + truncation=True, + max_length=args.max_length, + ) + + # Load datasets + train_dataset = assert_type( + Dataset, load_dataset(args.dataset, split=args.train_split) + ) + train_dataset = train_dataset.select(range(train_examples + eval_examples)) + + eval_dataset = train_dataset.select( + range(train_examples, train_examples + eval_examples) + ) + train_dataset = train_dataset.select(range(train_examples)) + + train_dataset = train_dataset.map(tokenize, batched=True) + eval_dataset = eval_dataset.map(tokenize, batched=True) + + train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + + # Set up Kronfluence + task = LossTask() + model = prepare_model(model=model, task=task) + analyzer = Analyzer(analysis_name=args.analysis_name, model=model, task=task) + + # Fit factors + analyzer.fit_all_factors( + factors_name=args.factors_name, + dataset=train_dataset, + per_device_batch_size=args.per_device_batch_size, + overwrite_output_dir=True, + factor_args=FactorArguments( + strategy=args.strategy, + use_empirical_fisher=args.use_empirical_fisher, + covariance_max_examples=args.covariance_max_examples, + amp_dtype=getattr(torch, args.amp_dtype), + activation_covariance_dtype=getattr( + torch, args.activation_covariance_dtype + ), + gradient_covariance_dtype=getattr( + torch, args.gradient_covariance_dtype + ), + ), + ) + + if args.do_query: + # Compute pairwise scores + analyzer.compute_pairwise_scores( + scores_name=args.scores_name, + factors_name=args.factors_name, + query_dataset=eval_dataset, + train_dataset=train_dataset, + per_device_query_batch_size=args.per_device_query_batch_size, + per_device_train_batch_size=args.per_device_train_batch_size, + score_args=ScoreArguments( + amp_dtype=getattr(torch, args.amp_dtype), + per_sample_gradient_dtype=getattr( + torch, args.per_sample_gradient_dtype + ), + score_dtype=getattr(torch, args.score_dtype), + offload_activations_to_cpu=args.offload_activations_to_cpu, + ), + ) + + # Load scores to verify completion + # scores = analyzer.load_pairwise_scores(scores_name=args.scores_name) + + except Exception as exc: # noqa: BLE001 + status = "error" + error_message = repr(exc) + + runtime = time.perf_counter() - start + end_wall = timestamp() + + record = RunRecord( + schema_version=SCHEMA_VERSION, + status=status, + model_key=spec.key, + model_name=spec.hf_id, + params=spec.params, + train_examples=train_examples, + 
eval_examples=eval_examples, + dataset=args.dataset, + train_split=args.train_split, + eval_split=args.eval_split, + factors_name=args.factors_name, + scores_name=args.scores_name, + strategy=args.strategy, + use_empirical_fisher=args.use_empirical_fisher, + covariance_max_examples=args.covariance_max_examples, + per_device_batch_size=args.per_device_batch_size, + per_device_query_batch_size=args.per_device_query_batch_size, + per_device_train_batch_size=args.per_device_train_batch_size, + amp_dtype=args.amp_dtype, + activation_covariance_dtype=args.activation_covariance_dtype, + gradient_covariance_dtype=args.gradient_covariance_dtype, + per_sample_gradient_dtype=args.per_sample_gradient_dtype, + score_dtype=args.score_dtype, + offload_activations_to_cpu=args.offload_activations_to_cpu, + runtime_seconds=runtime, + start_time=start_wall, + end_time=end_wall, + run_path=str(run_path), + notes=args.notes, + error=error_message, + ) + save_record(run_path, record) + + print(json.dumps(asdict(record), indent=2)) + + if status != "success": + sys.exit(1) + + +def default_train_examples() -> list[str]: + return ["1K", "10K", "100K", "1M"] + + +def default_eval_examples() -> list[str]: + return ["100", "1K", "10K"] + + +def existing_success_lookup( + records: Iterable[RunRecord], +) -> set[tuple[str, int, int, str]]: + return { + (r.model_key, r.train_examples, r.eval_examples, r.strategy) + for r in records + if r.status == "success" + } + + +def cmd_commands(args: argparse.Namespace) -> None: + train_examples = [parse_examples(tok) for tok in args.train_examples] + eval_examples = [parse_examples(tok) for tok in args.eval_examples] + models = args.models or list(MODEL_SPECS.keys()) + + run_root = Path(args.run_root).resolve() + records = load_records(run_root) + seen = existing_success_lookup(records) + + for model_key in models: + if model_key not in MODEL_SPECS: + raise ValueError(f"Unknown model '{model_key}'") + for train_ex in train_examples: + for eval_ex in eval_examples: + key = (model_key, train_ex, eval_ex, args.strategy) + if key in seen and not args.include_completed: + continue + pieces = [ + "python", + "examples/kronfluence_benchmark.py", + "run", + model_key, + format_examples(train_ex), + format_examples(eval_ex), + ] + if args.tag_prefix: + pieces.extend( + [ + "--tag", + f"{args.tag_prefix}{format_examples(train_ex)}-{format_examples(eval_ex)}", + ] + ) + if args.do_query: + pieces.append("--do-query") + if args.use_empirical_fisher: + pieces.append("--use-empirical-fisher") + if args.offload_activations_to_cpu: + pieces.append("--offload-activations-to-cpu") + if args.max_length is not None: + pieces.extend(["--max-length", str(args.max_length)]) + if args.covariance_max_examples is not None: + pieces.extend( + ["--covariance-max-examples", str(args.covariance_max_examples)] + ) + print(" ".join(pieces)) + + +def cmd_fit(args: argparse.Namespace) -> None: + run_root = Path(args.run_root).resolve() + records = load_records(run_root) + df = summarize_records(records) + + if df.empty: + print("No benchmark records found.") + return + + # Select dfs where error is None + df = df.query("error.isna()") + + df_path = Path(args.output_table).resolve() + df_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(df_path, index=False) + print(f"Wrote combined table to {df_path}") + + try: + subset, params = estimate_scaling(df) + except ValueError as exc: + print(f"Skipping fit: {exc}") + return + + fit_path = Path(args.fit_output).resolve() + fit_path.parent.mkdir(parents=True, 
exist_ok=True) + with open(fit_path, "w", encoding="utf-8") as fh: + json.dump(params, fh, indent=2) + print(f"Saved scaling fit parameters to {fit_path}") + + plot_path = Path(args.plot_output).resolve() + plot_scaling(subset, plot_path) + print(f"Saved scaling plot to {plot_path}") + + +def add_common_run_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("model", help="Key for the model to benchmark") + parser.add_argument( + "train_examples", help="Target training examples (e.g. 10K, 1M)" + ) + parser.add_argument( + "eval_examples", help="Target evaluation examples (e.g. 100, 1K)" + ) + parser.add_argument( + "--strategy", default="diagonal", choices=["diagonal", "kfac", "ekfac"] + ) + parser.add_argument("--use-empirical-fisher", action="store_true") + parser.add_argument("--covariance-max-examples", type=int, default=100) + parser.add_argument("--per-device-batch-size", type=int, default=1) + parser.add_argument("--per-device-query-batch-size", type=int, default=1) + parser.add_argument("--per-device-train-batch-size", type=int, default=1) + parser.add_argument( + "--amp-dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"] + ) + parser.add_argument( + "--activation-covariance-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument( + "--gradient-covariance-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument( + "--per-sample-gradient-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument( + "--score-dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"] + ) + parser.add_argument("--offload-activations-to-cpu", action="store_true") + parser.add_argument("--dataset", default=DEFAULT_DATASET) + parser.add_argument("--train-split", default=DEFAULT_TRAIN_SPLIT) + parser.add_argument("--eval-split", default=DEFAULT_EVAL_SPLIT) + parser.add_argument("--analysis-name", default="kronfluence_benchmark") + parser.add_argument("--factors-name", default="my_factors") + parser.add_argument("--scores-name", default="my_scores") + parser.add_argument("--run-root", default="runs/kronfluence-scaling") + parser.add_argument("--run-path") + parser.add_argument("--tag") + parser.add_argument("--max-length", type=int) + parser.add_argument("--notes") + parser.add_argument("--do-query", action="store_true") + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Benchmark Kronfluence influence analysis scaling", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent( + """Examples:\n" + " python examples/kronfluence_benchmark.py run pythia-160m 10K 1K\n" + " python examples/kronfluence_benchmark.py commands --tag-prefix exp-\n" + " python examples/kronfluence_benchmark.py fit""" + ), + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + run_parser = subparsers.add_parser( + "run", help="Execute a single Kronfluence benchmark run" + ) + add_common_run_args(run_parser) + run_parser.set_defaults(func=cmd_run) + + cmd_parser = subparsers.add_parser("commands", help="List run commands") + cmd_parser.add_argument( + "--train-examples", nargs="*", default=default_train_examples() + ) + cmd_parser.add_argument( + "--eval-examples", nargs="*", default=default_eval_examples() + ) + cmd_parser.add_argument("--models", nargs="*") + cmd_parser.add_argument( + "--strategy", default="diagonal", choices=["diagonal", "kfac", "ekfac"] + 
) + cmd_parser.add_argument("--use-empirical-fisher", action="store_true") + cmd_parser.add_argument("--covariance-max-examples", type=int) + cmd_parser.add_argument("--per-device-batch-size", type=int, default=1) + cmd_parser.add_argument("--per-device-query-batch-size", type=int, default=1) + cmd_parser.add_argument("--per-device-train-batch-size", type=int, default=1) + cmd_parser.add_argument( + "--amp-dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"] + ) + cmd_parser.add_argument( + "--activation-covariance-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + cmd_parser.add_argument( + "--gradient-covariance-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + cmd_parser.add_argument( + "--per-sample-gradient-dtype", + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + ) + cmd_parser.add_argument( + "--score-dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"] + ) + cmd_parser.add_argument("--offload-activations-to-cpu", action="store_true") + cmd_parser.add_argument("--dataset", default=DEFAULT_DATASET) + cmd_parser.add_argument("--train-split", default=DEFAULT_TRAIN_SPLIT) + cmd_parser.add_argument("--eval-split", default=DEFAULT_EVAL_SPLIT) + cmd_parser.add_argument("--analysis-name", default="kronfluence_benchmark") + cmd_parser.add_argument("--factors-name", default="my_factors") + cmd_parser.add_argument("--scores-name", default="my_scores") + cmd_parser.add_argument("--run-root", default="runs/kronfluence-scaling") + cmd_parser.add_argument("--tag-prefix") + cmd_parser.add_argument("--include-completed", action="store_true") + cmd_parser.add_argument("--max-length", type=int) + cmd_parser.add_argument("--do-query", action="store_true") + cmd_parser.set_defaults(func=cmd_commands) + + fit_parser = subparsers.add_parser("fit", help="Aggregate results and fit scaling") + fit_parser.add_argument("--run-root", default="runs/kronfluence-scaling") + fit_parser.add_argument("--output-table", default="data/kronfluence_scaling.csv") + fit_parser.add_argument("--fit-output", default="data/kronfluence_scaling_fit.json") + fit_parser.add_argument("--plot-output", default="figures/kronfluence_scaling.png") + fit_parser.set_defaults(func=cmd_fit) + + args = parser.parse_args(argv) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/examples/run_full_benchmark.py b/examples/run_full_benchmark.py new file mode 100644 index 00000000..6b6a1953 --- /dev/null +++ b/examples/run_full_benchmark.py @@ -0,0 +1,396 @@ +"""Coordinate running dattri and bergson benchmarks and generate comparison plots.""" + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from pathlib import Path +from typing import Any + +import pandas as pd +from matplotlib import pyplot as plt + +# Import from same directory +from examples.benchmark_common import ( + MODEL_SPECS, ModelSpec, DEFAULT_DATASET, format_tokens, parse_tokens, timestamp +) +from examples.benchmark_dattri import load_records as load_dattri_records +from examples.benchmark_bergson import load_records as load_bergson_records +from examples.benchmark_dattri import RunRecord as DattriRecord +from examples.benchmark_bergson import RunRecord as BergsonRecord + +def run_benchmark( + method: str, + model: str, + train_tokens: int, + eval_tokens: int, + run_root: str, + **kwargs: Any, +) -> bool: + """Run a single benchmark.""" + if method == "dattri": + cmd = [ + sys.executable, + "-m", 
+ "examples.benchmark_dattri", + "run", + model, + format_tokens(train_tokens), + format_tokens(eval_tokens), + "--run-root", + run_root, + ] + elif method == "bergson": + cmd = [ + sys.executable, + "-m", + "examples.benchmark_bergson", + "run", + model, + format_tokens(train_tokens), + format_tokens(eval_tokens), + "--run-root", + run_root, + ] + if "max_eval_examples" in kwargs: + cmd.extend(["--max-eval-examples", str(kwargs["max_eval_examples"])]) + # Enable FSDP for larger models (>= 1B parameters) + if model in MODEL_SPECS and MODEL_SPECS[model].params >= 1_000_000_000: + cmd.append("--fsdp") + else: + raise ValueError(f"Unknown method: {method}") + + if "batch_size" in kwargs: + cmd.extend(["--batch-size", str(kwargs["batch_size"])]) + if "max_length" in kwargs: + cmd.extend(["--max-length", str(kwargs["max_length"])]) + + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"Error running {method} benchmark:") + print(result.stderr) + return False + + print(f"Successfully ran {method} benchmark") + return True + + +def load_all_records( + dattri_root: Path, + bergson_root: Path, +) -> tuple[list[DattriRecord], list[BergsonRecord]]: + """Load all benchmark records.""" + dattri_records = load_dattri_records(dattri_root) if dattri_root.exists() else [] + bergson_records = load_bergson_records(bergson_root) if bergson_root.exists() else [] + return dattri_records, bergson_records + + +def create_comparison_dataframe( + dattri_records: list[DattriRecord], + bergson_records: list[BergsonRecord], +) -> pd.DataFrame: + """Create a combined dataframe for comparison.""" + rows = [] + + # Add dattri records + for r in dattri_records: + if r.status == "success" and r.runtime_seconds is not None: + rows.append({ + "method": "dattri", + "model_key": r.model_key, + "model_params": r.params, + "train_tokens": r.train_tokens, + "eval_tokens": r.eval_tokens, + "runtime_seconds": r.runtime_seconds, + "reduce_seconds": None, # Dattri doesn't separate reduce/score + "score_seconds": None, + }) + + # Add bergson records + for r in bergson_records: + if r.status == "success" and r.total_runtime_seconds is not None: + rows.append({ + "method": "bergson", + "model_key": r.model_key, + "model_params": r.params, + "train_tokens": r.train_tokens, + "eval_tokens": r.eval_tokens, + "runtime_seconds": r.total_runtime_seconds, + "reduce_seconds": r.reduce_seconds, + "score_seconds": r.score_seconds, + }) + + return pd.DataFrame(rows) + + +def plot_comparison(df: pd.DataFrame, output_path: Path) -> None: + """Create comparison plots.""" + if df.empty: + print("No data to plot") + return + + # Filter successful runs + df = df[df["runtime_seconds"].notna()] + + # Create figure with subplots + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + + # Plot 1: Runtime vs train tokens (by model) + ax1 = axes[0, 0] + for method in df["method"].unique(): + for model_key in df["model_key"].unique(): + subset = df[(df["method"] == method) & (df["model_key"] == model_key)] + if not subset.empty: + subset = subset.sort_values("train_tokens") + ax1.plot( + subset["train_tokens"], + subset["runtime_seconds"], + marker="o", + label=f"{method}-{model_key}", + linewidth=1.5, + ) + ax1.set_xscale("log") + ax1.set_yscale("log") + ax1.set_xlabel("Training Tokens") + ax1.set_ylabel("Total Runtime (seconds)") + ax1.set_title("Runtime Scaling: Total Time") + ax1.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.6) + 
ax1.legend(fontsize="small", ncol=2) + + # Plot 2: Runtime vs model params (by token scale) + ax2 = axes[0, 1] + for method in df["method"].unique(): + for train_tokens in sorted(df["train_tokens"].unique())[:5]: # Top 5 token scales + subset = df[(df["method"] == method) & (df["train_tokens"] == train_tokens)] + if not subset.empty: + subset = subset.sort_values("model_params") + ax2.plot( + subset["model_params"], + subset["runtime_seconds"], + marker="o", + label=f"{method}-{format_tokens(train_tokens)}", + linewidth=1.5, + ) + ax2.set_xscale("log") + ax2.set_yscale("log") + ax2.set_xlabel("Model Parameters") + ax2.set_ylabel("Total Runtime (seconds)") + ax2.set_title("Runtime Scaling: Model Size") + ax2.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.6) + ax2.legend(fontsize="small", ncol=2) + + # Plot 3: Bergson reduce vs score breakdown + ax3 = axes[1, 0] + bergson_df = df[df["method"] == "bergson"] + if not bergson_df.empty and bergson_df["reduce_seconds"].notna().any(): + for model_key in bergson_df["model_key"].unique(): + subset = bergson_df[bergson_df["model_key"] == model_key].sort_values("train_tokens") + if subset["reduce_seconds"].notna().any(): + ax3.plot( + subset["train_tokens"], + subset["reduce_seconds"], + marker="s", + label=f"{model_key} (reduce)", + linewidth=1.5, + linestyle="-", + ) + if subset["score_seconds"].notna().any(): + ax3.plot( + subset["train_tokens"], + subset["score_seconds"], + marker="^", + label=f"{model_key} (score)", + linewidth=1.5, + linestyle="--", + ) + ax3.set_xscale("log") + ax3.set_yscale("log") + ax3.set_xlabel("Training Tokens") + ax3.set_ylabel("Runtime (seconds)") + ax3.set_title("Bergson: Reduce vs Score Breakdown") + ax3.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.6) + ax3.legend(fontsize="small") + + # Plot 4: Speedup comparison (dattri / bergson) + ax4 = axes[1, 1] + speedup_data = [] + for model_key in df["model_key"].unique(): + for train_tokens in df["train_tokens"].unique(): + dattri_subset = df[(df["method"] == "dattri") & (df["model_key"] == model_key) & (df["train_tokens"] == train_tokens)] + bergson_subset = df[(df["method"] == "bergson") & (df["model_key"] == model_key) & (df["train_tokens"] == train_tokens)] + + if not dattri_subset.empty and not bergson_subset.empty: + dattri_time = dattri_subset["runtime_seconds"].iloc[0] + bergson_time = bergson_subset["runtime_seconds"].iloc[0] + speedup = dattri_time / bergson_time if bergson_time > 0 else None + if speedup is not None: + speedup_data.append({ + "model_key": model_key, + "train_tokens": train_tokens, + "speedup": speedup, + }) + + if speedup_data: + speedup_df = pd.DataFrame(speedup_data) + for model_key in speedup_df["model_key"].unique(): + subset = speedup_df[speedup_df["model_key"] == model_key].sort_values("train_tokens") + ax4.plot( + subset["train_tokens"], + subset["speedup"], + marker="o", + label=model_key, + linewidth=1.5, + ) + ax4.axhline(y=1.0, color="black", linestyle="--", linewidth=1, alpha=0.5) + ax4.set_xscale("log") + ax4.set_xlabel("Training Tokens") + ax4.set_ylabel("Speedup (dattri / bergson)") + ax4.set_title("Relative Performance: Dattri vs Bergson") + ax4.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.6) + ax4.legend(fontsize="small") + + plt.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=200) + plt.close() + print(f"Saved comparison plot to {output_path}") + + +def cmd_run(args: argparse.Namespace) -> None: + """Run benchmarks for 
specified models and token scales.""" + models = args.models or ["pythia-14m", "pythia-70m"] + token_scales = [parse_tokens(ts) for ts in args.token_scales] + eval_tokens = parse_tokens(args.eval_tokens) + + dattri_root = Path(args.run_root) / "dattri-scaling" + bergson_root = Path(args.run_root) / "bergson-scaling" + + # Check existing runs + dattri_records, bergson_records = load_all_records(dattri_root, bergson_root) + existing_dattri = { + (r.model_key, r.train_tokens, r.eval_tokens) + for r in dattri_records + if r.status == "success" + } + existing_bergson = { + (r.model_key, r.train_tokens, r.eval_tokens) + for r in bergson_records + if r.status == "success" + } + + # Run benchmarks + for model in models: + if model not in MODEL_SPECS: + print(f"Warning: Unknown model {model}, skipping") + continue + + for train_tokens in token_scales: + # Run dattri + if not args.skip_dattri: + key = (model, train_tokens, eval_tokens) + if key not in existing_dattri or args.force: + print(f"\n{'='*60}") + print(f"Running Dattri: {model}, {format_tokens(train_tokens)} train tokens") + print(f"{'='*60}") + success = run_benchmark( + "dattri", + model, + train_tokens, + eval_tokens, + str(dattri_root), + batch_size=args.batch_size, + max_length=args.max_length, + ) + if not success: + print(f"Failed to run dattri benchmark for {model} {format_tokens(train_tokens)}") + else: + print(f"Skipping dattri {model} {format_tokens(train_tokens)} (already exists)") + + # Run bergson + if not args.skip_bergson: + key = (model, train_tokens, eval_tokens) + if key not in existing_bergson or args.force: + print(f"\n{'='*60}") + print(f"Running Bergson: {model}, {format_tokens(train_tokens)} train tokens") + print(f"{'='*60}") + success = run_benchmark( + "bergson", + model, + train_tokens, + eval_tokens, + str(bergson_root), + batch_size=args.batch_size, + max_length=args.max_length, + max_eval_examples=args.num_test, + ) + if not success: + print(f"Failed to run bergson benchmark for {model} {format_tokens(train_tokens)}") + else: + print(f"Skipping bergson {model} {format_tokens(train_tokens)} (already exists)") + + +def cmd_plot(args: argparse.Namespace) -> None: + """Generate comparison plots from existing benchmark results.""" + dattri_root = Path(args.run_root) / "dattri-scaling" + bergson_root = Path(args.run_root) / "bergson-scaling" + + dattri_records, bergson_records = load_all_records(dattri_root, bergson_root) + + df = create_comparison_dataframe(dattri_records, bergson_records) + + # Save CSV + csv_path = Path(args.output_csv) + csv_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(csv_path, index=False) + print(f"Saved comparison data to {csv_path}") + + # Create plots + if not args.skip_plots: + plot_path = Path(args.plot_output) + plot_comparison(df, plot_path) + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Coordinate dattri and bergson benchmarks", + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + # Run command + run_parser = subparsers.add_parser("run", help="Run benchmarks") + run_parser.add_argument("--models", nargs="*", help="Models to benchmark") + run_parser.add_argument( + "--token-scales", + nargs="*", + default=["1M", "2M", "5M", "10M"], + help="Token scales to test (e.g. 
1M 10M)", + ) + run_parser.add_argument("--eval-tokens", default="100K", help="Evaluation tokens") + run_parser.add_argument("--batch-size", type=int, default=4) + run_parser.add_argument("--max-length", type=int, default=512) + run_parser.add_argument("--num-test", type=int, default=10, help="Number of test examples for bergson") + run_parser.add_argument("--run-root", default="runs") + run_parser.add_argument("--skip-dattri", action="store_true") + run_parser.add_argument("--skip-bergson", action="store_true") + run_parser.add_argument("--force", action="store_true", help="Re-run existing benchmarks") + run_parser.set_defaults(func=cmd_run) + + # Plot command + plot_parser = subparsers.add_parser("plot", help="Generate comparison plots") + plot_parser.add_argument("--run-root", default="runs") + plot_parser.add_argument("--output-csv", default="data/benchmark_comparison.csv") + plot_parser.add_argument("--plot-output", default="figures/benchmark_comparison.png") + plot_parser.add_argument("--skip-plots", action="store_true") + plot_parser.set_defaults(func=cmd_plot) + + args = parser.parse_args(argv) + args.func(args) + + +if __name__ == "__main__": + main() +