diff --git a/bergson/__init__.py b/bergson/__init__.py index 7019af22..64ff9aab 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -1,6 +1,7 @@ __version__ = "0.4.6" from .collection import collect_gradients +from .collector.gradient_collectors import GradientCollector from .config import ( AttentionConfig, DataConfig, @@ -11,6 +12,7 @@ ) from .data import load_gradient_dataset, load_gradients from .gradients import GradientProcessor +from .normalizer.fit_normalizers import fit_normalizers from .query.attributor import Attributor from .query.faiss_index import FaissConfig from .score.scorer import Scorer @@ -20,10 +22,12 @@ "collect_gradients", "load_gradients", "load_gradient_dataset", + "fit_normalizers", "Attributor", "FaissConfig", "FiniteDiff", "GradientProcessor", + "GradientCollector", "IndexConfig", "DataConfig", "AttentionConfig", diff --git a/bergson/__main__.py b/bergson/__main__.py index 705873cd..a4fbd59e 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -107,16 +107,10 @@ def execute(self): self.command.execute() -def get_parser(): - """Get the argument parser. Used for documentation generation.""" - parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT) - parser.add_arguments(Main, dest="prog") - return parser - - def main(args: Optional[list[str]] = None): """Parse CLI arguments and dispatch to the selected subcommand.""" - parser = get_parser() + parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT) + parser.add_arguments(Main, dest="prog") prog: Main = parser.parse_args(args=args).prog prog.execute() diff --git a/bergson/build.py b/bergson/build.py index 6aae2274..d1ece8c9 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -56,7 +56,7 @@ def build_worker( ) model, target_modules = setup_model_and_peft(cfg, rank) - processor = create_processor(cfg, rank) + processor = create_processor(model, ds, cfg, rank, target_modules) attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules} diff --git a/bergson/collector/collector.py b/bergson/collector/collector.py index 3680e3fc..3393dbc4 100644 --- a/bergson/collector/collector.py +++ b/bergson/collector/collector.py @@ -477,10 +477,10 @@ def run_with_collector_hooks( total_processed = torch.tensor(0, device=self.model.device) prof = self._setup_profiler() step = 0 - with prof: for indices in tqdm( - self.batches, disable=self.rank != 0, desc=f"Computing {desc}" + self.batches, + desc=f"Computing {desc}", ): batch = self.data[indices] @@ -503,6 +503,7 @@ def run_with_collector_hooks( step += 1 self.collector.process_batch(indices, losses=losses) + total_processed += len(indices) self.collector.teardown() if dist.is_initialized(): diff --git a/bergson/config.py b/bergson/config.py index 403277bc..17c3392e 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -95,6 +95,9 @@ class IndexConfig: processor_path: str = "" """Path to a precomputed processor.""" + normalizer: Literal["adafactor", "adam", "none"] = "none" + """Type of normalizer to use for the gradients.""" + skip_preconditioners: bool = False """Whether to skip computing preconditioners for the gradients.""" diff --git a/bergson/gradients.py b/bergson/gradients.py index 3b4089de..6175ceba 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -58,108 +58,6 @@ def state_dict(self) -> dict[str, str | Tensor]: } -@dataclass -class AdafactorNormalizer(Normalizer): - """ - Row and column sums of second moments of gradients for a matrix-valued parameter. 
- """ - - row: Tensor # shape [O] - col: Tensor # shape [I] - - def __post_init__(self): - assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D" - assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D" - - @torch.compile - def normalize_( - self, - grad: Tensor, - eps: float = 1e-30, - ) -> Tensor: - """ - Normalize the row and column sums by adding a small epsilon. - - Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They - recommend 1e-30, but we use 1e-16 for extra numerical stability. - """ - # We follow the Adafactor implementation in the tensor2tensor repo, which is - # different from the paper and from the PyTorch implementation. First add eps - # to ensure these second moments are sufficiently far from zero. Then we don't - # need to worry about numerical stability anywhere else, and we don't need to - # materialize the outer product at any point. - r, c = self.row.add(eps), self.col.add(eps) - - # This is the denominator for V, the rank-one matrix of second moment estimates: - # V = torch.outer(r, c) / denom - # V_ij = r_i * c_j / denom - # But we want to (implicitly) take the Hadamard product with the elementwise - # reciprocal square root of V: - # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt() - denom = r.mean() - - # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying - # by diag(a) and right-multiplying by diag(b). In this case we can represent - # the elementwise reciprocal square root of V as ab^T where: - # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt() - a = denom.sqrt() * r.rsqrt_() # shape [O] - b = c.rsqrt_() - - # Implicitly do the Hadamard product - grad *= a[:, None] # [N, O] * [O] → [N, O] - grad *= b[None, :] - return grad - - def to_adam(self) -> "AdamNormalizer": - """ - Convert this Adafactor normalizer to an Adam normalizer by materializing the - rank-one second moment matrix. - """ - # Compute the second moment matrix as a square matrix of shape [O, I] - # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to - # add it outside the square root. This could cause infs though if there are - # any exactly zero rows or columns, so we should be careful. - avg_sq = torch.outer(self.row, self.col) / self.row.mean() - return AdamNormalizer(avg_sq=avg_sq) - - -@dataclass -class AdamNormalizer(Normalizer): - """ - Contains the second moments of the gradients. - """ - - avg_sq: Tensor - - @torch.compile - def normalize_( - self, - grad: Tensor, - eps: float = 1e-8, - ) -> Tensor: - """Normalize the gradients by the square root of the second moments.""" - # Adam-style epsilon is added outside the square root - denom = self.avg_sq.sqrt() - return grad.div_(denom.add_(eps)) - - def to_adafactor(self) -> AdafactorNormalizer: - """ - Convert this Adam normalizer to an Adafactor normalizer, minimizing the - I-divergence (generalized Kullback-Leibler divergence) between the original - and the factored second moments. 
- """ - # We assume avg_sq is a square matrix of shape [O, I] - assert ( - self.avg_sq.ndim == 2 - ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" - - # Compute row and column means - return AdafactorNormalizer( - row=self.avg_sq.mean(dim=1), # shape [O] - col=self.avg_sq.mean(dim=0), # shape [I] - ) - - @dataclass class GradientProcessor: """Configuration for processing and compressing gradients.""" @@ -317,3 +215,105 @@ def out_attr(layer: nn.Module) -> str: return "out_channels" case _: raise ValueError(f"Unsupported layer type: {type(layer)}") + + +@dataclass +class AdafactorNormalizer(Normalizer): + """ + Row and column sums of second moments of gradients for a matrix-valued parameter. + """ + + row: Tensor # shape [O] + col: Tensor # shape [I] + + def __post_init__(self): + assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D" + assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D" + + @torch.compile + def normalize_( + self, + grad: Tensor, + eps: float = 1e-30, + ) -> Tensor: + """ + Normalize the row and column sums by adding a small epsilon. + + Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They + recommend 1e-30, but we use 1e-16 for extra numerical stability. + """ + # We follow the Adafactor implementation in the tensor2tensor repo, which is + # different from the paper and from the PyTorch implementation. First add eps + # to ensure these second moments are sufficiently far from zero. Then we don't + # need to worry about numerical stability anywhere else, and we don't need to + # materialize the outer product at any point. + r, c = self.row.add(eps), self.col.add(eps) + + # This is the denominator for V, the rank-one matrix of second moment estimates: + # V = torch.outer(r, c) / denom + # V_ij = r_i * c_j / denom + # But we want to (implicitly) take the Hadamard product with the elementwise + # reciprocal square root of V: + # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt() + denom = r.mean() + + # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying + # by diag(a) and right-multiplying by diag(b). In this case we can represent + # the elementwise reciprocal square root of V as ab^T where: + # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt() + a = denom.sqrt() * r.rsqrt_() # shape [O] + b = c.rsqrt_() + + # Implicitly do the Hadamard product + grad *= a[:, None] # [N, O] * [O] → [N, O] + grad *= b[None, :] + return grad + + def to_adam(self) -> "AdamNormalizer": + """ + Convert this Adafactor normalizer to an Adam normalizer by materializing the + rank-one second moment matrix. + """ + # Compute the second moment matrix as a square matrix of shape [O, I] + # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to + # add it outside the square root. This could cause infs though if there are + # any exactly zero rows or columns, so we should be careful. + avg_sq = torch.outer(self.row, self.col) / self.row.mean() + return AdamNormalizer(avg_sq=avg_sq) + + +@dataclass +class AdamNormalizer(Normalizer): + """ + Contains the second moments of the gradients. 
+ """ + + avg_sq: Tensor + + @torch.compile + def normalize_( + self, + grad: Tensor, + eps: float = 1e-8, + ) -> Tensor: + """Normalize the gradients by the square root of the second moments.""" + # Adam-style epsilon is added outside the square root + denom = self.avg_sq.sqrt() + return grad.div_(denom.add_(eps)) + + def to_adafactor(self) -> AdafactorNormalizer: + """ + Convert this Adam normalizer to an Adafactor normalizer, minimizing the + I-divergence (generalized Kullback-Leibler divergence) between the original + and the factored second moments. + """ + # We assume avg_sq is a square matrix of shape [O, I] + assert ( + self.avg_sq.ndim == 2 + ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" + + # Compute row and column means + return AdafactorNormalizer( + row=self.avg_sq.mean(dim=1), # shape [O] + col=self.avg_sq.mean(dim=0), # shape [I] + ) diff --git a/bergson/normalizer/__init__.py b/bergson/normalizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bergson/normalizer/fit_normalizers.py b/bergson/normalizer/fit_normalizers.py new file mode 100644 index 00000000..fe36c9fd --- /dev/null +++ b/bergson/normalizer/fit_normalizers.py @@ -0,0 +1,284 @@ +import math +import random +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +import torch.nn as nn +from datasets import Dataset +from jaxtyping import Float +from torch import Tensor +from transformers import PreTrainedModel + +from bergson.collector.collector import CollectorComputer, HookCollectorBase +from bergson.config import IndexConfig +from bergson.gradients import ( + AdafactorNormalizer, + AdamNormalizer, + LayerAdapter, + Normalizer, +) +from bergson.process_preconditioners import process_preconditioners +from bergson.utils.utils import assert_type + + +@dataclass(kw_only=True) +class NormalizerCollector(HookCollectorBase): + """ + Collects per-sample gradients from model layers and writes them to disk. + + - For each forward/backward hook, we compute the the gradient or a low-rank + approximation via random projections, if cfg.projection_dim is set. + - Supports also normalization via Adam or Adafactor normalizers. + - Uses Builder for index construction and gradient saving. + - Also supports Scorer for on-the-fly scoring of gradients. + """ + + data: Dataset + """The dataset being processed.""" + + cfg: IndexConfig + """Configuration for gradient index.""" + + normalizers: dict[str, Normalizer] = field(default_factory=dict) + + mod_grads: dict = field(default_factory=dict) + """Temporary storage for gradients during a batch, keyed by module name.""" + + def __init__(self, *args, **kwargs): + self.data = assert_type(Dataset, kwargs["data"]) + self.cfg = assert_type(IndexConfig, kwargs["cfg"]) + self.normalizers = {} + self.mod_grads = {} + + self.callback = ( + self.adafactor_update + if self.cfg.normalizer == "adafactor" + else self.adam_update + ) + + # Extract parent class arguments + parent_kwargs = { + k: v + for k, v in kwargs.items() + if k + in { + "model", + "filter_modules", + "target_modules", + "processor", + "attention_cfgs", + } + } + parent_kwargs["filter_modules"] = self.cfg.filter_modules + + super().__init__(*args, **parent_kwargs) + + def adafactor_update(self, name: str, g: torch.Tensor): + # We follow the tensor2tensor implementation of Adafactor, which + # takes the mean rather than summing over the rows and columns. 
+        # row: mean over columns, shape [O]
+        sq = g.float().square_().sum(0)
+        row_acc = sq.mean(dim=1)
+        # col: mean over rows, shape [I]
+        col_acc = sq.mean(dim=0)
+
+        if (normalizer := self.normalizers.get(name)) is None:
+            # initialize accumulators at zero
+            self.normalizers[name] = normalizer = AdafactorNormalizer(
+                torch.zeros_like(row_acc),
+                torch.zeros_like(col_acc),
+            )
+        else:
+            assert isinstance(normalizer, AdafactorNormalizer)
+
+        # in-place accumulate
+        normalizer.row.add_(row_acc)
+        normalizer.col.add_(col_acc)
+
+    def adam_update(self, name: str, g: torch.Tensor):
+        sq = g.square_().float().sum(0)
+
+        # initialize accumulators at zero
+        if (normalizer := self.normalizers.get(name)) is None:
+            self.normalizers[name] = normalizer = AdamNormalizer(torch.zeros_like(sq))
+        else:
+            assert isinstance(normalizer, AdamNormalizer)
+
+        # in-place accumulate
+        normalizer.avg_sq.add_(sq)
+
+    def setup(self) -> None:
+        """
+        Initialize collector state.
+
+        Checks that bias inclusion is not combined with normalizers and picks the
+        dtype range used to clamp accumulated gradients.
+        """
+        assert isinstance(
+            self.model.device, torch.device
+        ), "Model device is not set correctly"
+        if self.cfg.include_bias and self.processor.normalizers is not None:
+            raise NotImplementedError(
+                "Bias with normalizers not supported yet, "
+                "consider disabling bias inclusion for now."
+            )
+
+        # TODO: handle more elegantly?
+        self.save_dtype = (
+            torch.float32 if self.model.dtype == torch.float32 else torch.float16
+        )
+
+        self.lo = torch.finfo(self.save_dtype).min
+        self.hi = torch.finfo(self.save_dtype).max
+
+    def forward_hook(self, module: nn.Module, a: Float[Tensor, "N S I"]) -> None:
+        """
+        Cache activations for gradient computation, applying normalizer
+        preprocessing and random projection if configured.
+        Stores the result in module._inputs for use in backward_hook.
+        """
+        p = self.processor.projection_dim
+        name = assert_type(str, module._name)
+        i = getattr(module, LayerAdapter.in_attr(module))
+        normalizer = self.processor.normalizers.get(name)
+
+        if isinstance(normalizer, AdamNormalizer):
+            module._inputs = a
+            return
+        if isinstance(normalizer, AdafactorNormalizer):
+            a_factor = normalizer.col.add(1e-30)
+            a_factor = a_factor.rsqrt()
+            a = a * a_factor.type_as(a)  # [N, S, I] * [I] → [N, S, I]
+
+        if module._has_bias:
+            # Append ones to activation for bias term
+            ones = torch.ones(a.size(0), a.size(1), 1, device=a.device, dtype=a.dtype)
+            a = torch.cat([a, ones], dim=-1)
+            i = i + 1
+            setattr(module, LayerAdapter.in_attr(module), i)
+        if p is not None:
+            a_projection = self.projection(name, p, i, "right", a.device, a.dtype).T
+            a = a @ a_projection  # type: ignore
+        # set module._inputs to a
+        module._inputs = a
+
+    @HookCollectorBase.split_attention_heads
+    def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
+        """
+        Compute the per-sample gradient and pass it to the normalizer update.
+
+        The gradient is computed as the outer product g.T @ a (with optional
+        projection and normalization).
+        """
+        a = module._inputs  # [N, S, I/q]
+
+        assert isinstance(a, torch.Tensor), "Activation cache missing for module"
+        name = assert_type(str, module._name)
+        p = self.processor.projection_dim
+        i = getattr(module, LayerAdapter.in_attr(module))
+        o = getattr(module, LayerAdapter.out_attr(module))
+        normalizer = self.processor.normalizers.get(name)
+
+        if isinstance(normalizer, AdamNormalizer):
+            full_gradient = g.mT @ a  # [N, O, S] @ [N, S, I] → [N, O, I]
+            P = normalizer.normalize_(full_gradient)
+            if p is not None:
+                g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
+                a_projection = self.projection(name, p, i, "right", g.device, g.dtype).T
+                P = g_projection @ P @ a_projection
+        else:
+            if isinstance(normalizer, AdafactorNormalizer):
+                g_factor = normalizer.row.add(1e-30)
+                g_factor = g_factor.mean().sqrt() * g_factor.rsqrt()
+                g = g * g_factor.type_as(g)  # [N, S, O] * [O] → [N, S, O]
+
+            if p is not None:
+                g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
+                g = g @ g_projection.T  # [N, S, p]
+
+            P = g.mT @ a  # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]
+
+        P = P.flatten(1).clamp_(self.lo, self.hi)
+
+        if not self.cfg.skip_preconditioners:
+            P = P.float()
+            if name in self.processor.preconditioners:
+                self.processor.preconditioners[name].addmm_(P.mT, P)
+            else:
+                self.processor.preconditioners[name] = P.mT @ P
+
+        # self.mod_grads[name] = P.to(dtype=self.save_dtype)
+        self.callback(name, P)
+
+        del module._inputs
+
+    def process_batch(self, indices: list[int], **kwargs):
+        """Clear temporary per-batch storage."""
+        self.mod_grads.clear()
+
+    def teardown(self):
+        """
+        Finalize normalizer collection: process any accumulated preconditioners and
+        save the processor on rank 0.
+        """
+        grad_sizes = {name: math.prod(s) for name, s in self.shapes().items()}
+        if self.processor.preconditioners:
+            process_preconditioners(
+                self.processor,
+                self.processor.preconditioners,
+                len(self.data),
+                grad_sizes,
+                self.rank,
+            )
+
+        if self.rank == 0:
+            self.processor.save(self.cfg.partial_run_path)
+
+
+def fit_normalizers(
+    model: PreTrainedModel,
+    data: Dataset,
+    cfg: IndexConfig,
+    batches: list[list[int]],
+    *,
+    target_modules: set[str] | None = None,
+) -> dict[str, Normalizer]:
+    """
+    Estimate the second moments of the model's gradients using a subset of the dataset.
+ """ + # Just to make the pbar more accurate + rng = random.Random(0) + rng.shuffle(batches) + + collector = NormalizerCollector( + model=model.base_model, # type: ignore + data=data, + cfg=cfg, + target_modules=target_modules, + ) + computer = CollectorComputer( + model=model, + data=data, + collector=collector, + cfg=cfg, + ) + computer.run_with_collector_hooks(desc="Estimating normalizers") + + normalizers = collector.normalizers + + # Divide by the number of documents processed and average across all ranks + for normalizer in normalizers.values(): + if isinstance(normalizer, AdamNormalizer): + normalizer.avg_sq.div_(len(data)) + + if dist.is_initialized(): + dist.all_reduce(normalizer.avg_sq, op=dist.ReduceOp.AVG) + + elif isinstance(normalizer, AdafactorNormalizer): + normalizer.row.div_(len(data)) + normalizer.col.div_(len(data)) + + if dist.is_initialized(): + dist.all_reduce(normalizer.row, op=dist.ReduceOp.AVG) + dist.all_reduce(normalizer.col, op=dist.ReduceOp.AVG) + + return normalizers diff --git a/bergson/reduce.py b/bergson/reduce.py index a0567232..7074c6c9 100644 --- a/bergson/reduce.py +++ b/bergson/reduce.py @@ -59,7 +59,7 @@ def reduce_worker( ) model, target_modules = setup_model_and_peft(index_cfg, rank) - processor = create_processor(index_cfg, rank) + processor = create_processor(model, ds, index_cfg, rank, target_modules) attention_cfgs = { module: index_cfg.attention for module in index_cfg.split_attention_modules diff --git a/bergson/score/score.py b/bergson/score/score.py index bc4a93c4..de2506a1 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -4,13 +4,12 @@ from dataclasses import asdict from datetime import timedelta from pathlib import Path -from typing import Literal, cast +from typing import Literal import torch import torch.distributed as dist from datasets import Dataset, IterableDataset from tqdm.auto import tqdm -from transformers import PreTrainedModel from bergson.collection import collect_gradients from bergson.config import IndexConfig, ScoreConfig @@ -260,8 +259,7 @@ def score_worker( ) model, target_modules = setup_model_and_peft(index_cfg, rank) - model = cast(PreTrainedModel, model) - processor = create_processor(index_cfg, rank) + processor = create_processor(model, ds, index_cfg, rank, target_modules) attention_cfgs = { module: index_cfg.attention for module in index_cfg.split_attention_modules diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index 8f12d039..1c41822d 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pandas as pd @@ -17,14 +18,53 @@ ) from bergson.config import DataConfig, IndexConfig -from bergson.data import load_data_string, tokenize -from bergson.gradients import GradientProcessor +from bergson.data import allocate_batches, load_data_string, tokenize +from bergson.gradients import GradientProcessor, Normalizer +from bergson.normalizer.fit_normalizers import fit_normalizers from bergson.utils.utils import assert_type, get_layer_list +def create_normalizers( + model: PreTrainedModel, + ds: Dataset | IterableDataset, + cfg: IndexConfig, + target_modules: set[str] | None = None, +) -> dict[str, Normalizer]: + """Create normalizers for the model""" + if cfg.normalizer != "none": + # Evenly sample `stats_sample_size` examples to compute statistics + if isinstance(ds, Dataset): + if cfg.stats_sample_size is not None and cfg.stats_sample_size < len(ds): + stats_ds = 
ds.shuffle(seed=0).select(range(cfg.stats_sample_size)) + else: + stats_ds = ds + else: + if cfg.stats_sample_size is None: + stats_iterable_ds = ds + else: + stats_iterable_ds = ds.shuffle(seed=0).take(cfg.stats_sample_size) + + stats_ds = assert_type( + Dataset, Dataset.from_generator(lambda: iter(stats_iterable_ds)) + ) + + return fit_normalizers( + model, + stats_ds, + cfg, + batches=allocate_batches(stats_ds["length"][:], cfg.token_batch_size), + target_modules=target_modules, + ) + + return {} + + def create_processor( + model: PreTrainedModel, + ds: Dataset | IterableDataset, cfg: IndexConfig, rank: int, + target_modules: set[str] | None = None, ) -> GradientProcessor: """Handle processor creation and normalizer fitting""" processor_path = Path(cfg.processor_path) @@ -37,8 +77,10 @@ def create_processor( map_location=f"cuda:{rank}", ) else: + normalizers = create_normalizers(model, ds, cfg, target_modules) + processor = GradientProcessor( - {}, + normalizers, projection_dim=cfg.projection_dim or None, reshape_to_square=cfg.reshape_to_square, projection_type=cfg.projection_type, @@ -149,6 +191,8 @@ def setup_model_and_peft( fully_shard(layer) fully_shard(model) + model = cast(PreTrainedModel, model) + return model, target_modules # type: ignore diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py new file mode 100644 index 00000000..00184a4a --- /dev/null +++ b/tests/test_normalizer.py @@ -0,0 +1,30 @@ +import torch.nn as nn + +from bergson import fit_normalizers +from bergson.config import IndexConfig + + +def test_fit_normalizers_runs(tmp_path, model, dataset): + target_modules = { + name + for name, module in model.base_model.named_modules() + if isinstance(module, nn.Linear) + } + print("len dataset", len(dataset)) + print("target_modules", target_modules) + + dataset = dataset.repeat(10) + + normalizers = fit_normalizers( + model, + dataset, + cfg=IndexConfig( + run_path=str(tmp_path), + skip_preconditioners=True, + normalizer="adam", + ), + batches=[[idx] for idx in range(len(dataset))], + target_modules=target_modules, + ) + + assert len(normalizers) == len(target_modules)
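
Usage sketch (illustrative only; assumes a loaded PreTrainedModel `model` and a tokenized Dataset `dataset` with a "length" column, mirroring create_normalizers() and tests/test_normalizer.py above):

    from bergson import GradientProcessor, fit_normalizers
    from bergson.config import IndexConfig
    from bergson.data import allocate_batches

    # Fit Adafactor-style normalizers on the dataset, then reuse them in a
    # GradientProcessor, the same way create_processor() now does internally.
    cfg = IndexConfig(
        run_path="runs/normalizer-demo",
        normalizer="adafactor",
        skip_preconditioners=True,
    )
    batches = allocate_batches(dataset["length"][:], cfg.token_batch_size)
    normalizers = fit_normalizers(
        model,
        dataset,
        cfg,
        batches=batches,
        target_modules=None,  # or an explicit set of module names, as in the test
    )
    processor = GradientProcessor(
        normalizers,
        projection_dim=cfg.projection_dim or None,
        reshape_to_square=cfg.reshape_to_square,
        projection_type=cfg.projection_type,
    )

When cfg.normalizer is "none", create_normalizers returns an empty dict and create_processor constructs the GradientProcessor exactly as before.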