diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 14eedf88..c0d94933 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -84,18 +84,20 @@ def setup(self) -> None: Sets up a Builder for gradient storage if not using a Scorer. """ + model_device = ( + getattr(self.model, "device", None) or next(self.model.parameters()).device + ) + model_dtype = ( + getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype + ) + assert isinstance( - self.model.device, torch.device + model_device, torch.device ), "Model device is not set correctly" - if self.cfg.include_bias and self.processor.normalizers is not None: - raise NotImplementedError( - "Bias with normalizers not supported yet, " - "consider disabling bias inclusion for now." - ) # TODO: handle more elegantly? self.save_dtype = ( - torch.float32 if self.model.dtype == torch.float32 else torch.float16 + torch.float32 if model_dtype == torch.float32 else torch.float16 ) self.lo = torch.finfo(self.save_dtype).min @@ -103,7 +105,7 @@ def setup(self) -> None: self.per_doc_losses = torch.full( (len(self.data),), - device=self.model.device, + device=model_device, dtype=self.save_dtype, fill_value=0.0, ) @@ -143,16 +145,18 @@ def forward_hook(self, module: nn.Module, a: Float[Tensor, "N S I"]) -> None: a_factor = a_factor.rsqrt() a = a * a_factor.type_as(a) # [N, S, I] * [I] → [N, S, I] - if module._has_bias: - # Append ones to activation for bias term + # For normalizer cases, bias normalization differs from weight normalization, + # so we handle bias separately in backward hook + if module._has_bias and normalizer is None: ones = torch.ones(a.size(0), a.size(1), 1, device=a.device, dtype=a.dtype) - a = torch.cat([a, ones], dim=-1) + a = torch.cat([a, ones], dim=-1) # [N, S, I+1] i = i + 1 setattr(module, LayerAdapter.in_attr(module), i) - if p is not None: + + if p is not None and (normalizer is None or not module._has_bias): a_projection = self.projection(name, p, i, "right", a.device, a.dtype).T - a = a @ a_projection # type: ignore - # set module._inputs to a + a = a @ a_projection # [N, S, I(+1)] @ [I(+1), p] → [N, S, p] + module._inputs = a @HookCollectorBase.split_attention_heads @@ -171,25 +175,79 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): i = getattr(module, LayerAdapter.in_attr(module)) o = getattr(module, LayerAdapter.out_attr(module)) normalizer = self.processor.normalizers.get(name) - - if isinstance(normalizer, AdamNormalizer): - full_gradient = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I] - P = normalizer.normalize_(full_gradient) - if p is not None: - g_projection = self.projection(name, p, o, "left", g.device, g.dtype) - a_projection = self.projection(name, p, i, "right", g.device, g.dtype).T - P = g_projection @ P @ a_projection - else: - if isinstance(normalizer, AdafactorNormalizer): + bias_grad = None + + match normalizer: + case AdamNormalizer(): + if module._has_bias and normalizer.bias_avg_sq is not None: + # Normalize bias with bias second moments # [N, S, O] → [N, O] + bias_grad = g.sum(dim=1) / normalizer.bias_avg_sq.sqrt().add(1e-8) + + P = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I] + P = normalizer.normalize_(P) + + # Append pre-normalized bias gradient + if bias_grad is not None: + P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1] + i += 1 + + if p is not None: + g_projection = self.projection( + name, p, o, "left", g.device, g.dtype + ) + a_projection = 
self.projection( + name, p, i, "right", g.device, g.dtype + ).T + P = g_projection @ P @ a_projection + + case AdafactorNormalizer(): + if module._has_bias and normalizer.bias_avg_sq is not None: + # Compute bias from RAW g (before row normalization) + bias_grad = g.sum(dim=1) # [N, S, O] → [N, O] + # Normalize bias with bias second moments + bias_grad = bias_grad / normalizer.bias_avg_sq.add(1e-30).sqrt() + + # Apply row normalization to g (for weights) g_factor = normalizer.row.add(1e-30) g_factor = g_factor.mean().sqrt() * g_factor.rsqrt() g = g * g_factor.type_as(g) # [N, S, O] * [O] → [N, S, O] - if p is not None: - g_projection = self.projection(name, p, o, "left", g.device, g.dtype) - g = g @ g_projection.T # [N, S, p] - - P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] + # If bias is present, materialize full gradient then project + if bias_grad is not None: + P = g.mT @ a # [N, O, I] + + # Append pre-normalized bias gradient + P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1] + i += 1 + + # Project the entire normalized gradient + if p is not None: + g_projection = self.projection( + name, p, o, "left", g.device, g.dtype + ) + a_projection = self.projection( + name, p, i, "right", a.device, a.dtype + ).T + P = g_projection @ P @ a_projection + else: + if p is not None: + g_projection = self.projection( + name, p, o, "left", g.device, g.dtype + ) + g = g @ g_projection.T # [N, S, p] + + P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] + + case None: + if p is not None: + g_projection = self.projection( + name, p, o, "left", g.device, g.dtype + ) + g = g @ g_projection.T # [N, S, p] + + P = g.mT @ a # [N, O/p, I(+1)/p] + case _: + raise ValueError(f"Unknown normalizer type {type(normalizer)}") P = P.flatten(1).clamp_(self.lo, self.hi) @@ -293,9 +351,14 @@ class TraceCollector(HookCollectorBase): """Dtype for stored gradients.""" def setup(self) -> None: + + model_dtype = ( + getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype + ) + # TODO: handle more elegantly? self.save_dtype = ( - torch.float32 if self.model.dtype == torch.float32 else torch.float16 + torch.float32 if model_dtype == torch.float32 else torch.float16 ) self.lo = torch.finfo(self.save_dtype).min @@ -403,9 +466,14 @@ class StreamingGradientCollector(HookCollectorBase): """Dtype for stored gradients.""" def setup(self) -> None: + + model_dtype = ( + getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype + ) + # TODO: handle more elegantly? self.save_dtype = ( - torch.float32 if self.model.dtype == torch.float32 else torch.float16 + torch.float32 if model_dtype == torch.float32 else torch.float16 ) self.lo = torch.finfo(self.save_dtype).min diff --git a/bergson/gradients.py b/bergson/gradients.py index 3b4089de..b9f37643 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -62,14 +62,24 @@ def state_dict(self) -> dict[str, str | Tensor]: class AdafactorNormalizer(Normalizer): """ Row and column sums of second moments of gradients for a matrix-valued parameter. 
+
+    Args:
+        row: Row statistics [O]
+        col: Column statistics [I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     row: Tensor  # shape [O]
     col: Tensor  # shape [I]
+    bias_avg_sq: Tensor | None = None  # shape [O]
 
     def __post_init__(self):
         assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
         assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+        if self.bias_avg_sq is not None:
+            assert (
+                self.bias_avg_sq.ndim == 1
+            ), f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
 
     @torch.compile
     def normalize_(
@@ -114,22 +124,40 @@ def to_adam(self) -> "AdamNormalizer":
         """
         Convert this Adafactor normalizer to an Adam normalizer by materializing
         the rank-one second moment matrix.
+
+        Preserves bias_avg_sq if present.
         """
         # Compute the second moment matrix as a square matrix of shape [O, I]
         # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
         # add it outside the square root. This could cause infs though if there are
         # any exactly zero rows or columns, so we should be careful.
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
+        return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
+
+    def scale_by_lr(self, lr: float | Tensor) -> None:
+        """Scale normalizer by learning rate.
+
+        The factorized row and col statistics are each scaled by sqrt(lr);
+        bias_avg_sq is scaled by lr.
+        """
+        lr_sqrt = lr**0.5
+        self.row.mul_(lr_sqrt)
+        self.col.mul_(lr_sqrt)
+        if self.bias_avg_sq is not None: self.bias_avg_sq.mul_(lr)
 
 
 @dataclass
 class AdamNormalizer(Normalizer):
     """
     Contains the second moments of the gradients.
+
+    Args:
+        avg_sq: Second moments for weights [O, I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     avg_sq: Tensor
+    bias_avg_sq: Tensor | None = None
 
     @torch.compile
     def normalize_(
@@ -147,6 +175,8 @@ def to_adafactor(self) -> AdafactorNormalizer:
         Convert this Adam normalizer to an Adafactor normalizer, minimizing
         the I-divergence (generalized Kullback-Leibler divergence) between the
         original and the factored second moments.
+
+        Preserves bias_avg_sq if present.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
         assert (
@@ -157,8 +187,17 @@ def to_adafactor(self) -> AdafactorNormalizer:
         return AdafactorNormalizer(
             row=self.avg_sq.mean(dim=1),  # shape [O]
             col=self.avg_sq.mean(dim=0),  # shape [I]
+            bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
         )
 
+    def scale_by_lr(self, lr: float | Tensor) -> None:
+        """Scale normalizer to incorporate learning rate.
+
+        Both avg_sq and bias_avg_sq are multiplied by lr.
+        """
+        self.avg_sq.mul_(lr)
+        if self.bias_avg_sq is not None: self.bias_avg_sq.mul_(lr)
+
 
 @dataclass
 class GradientProcessor:
diff --git a/bergson/huggingface.py b/bergson/huggingface.py
index e3bd08b3..9b92bc35 100644
--- a/bergson/huggingface.py
+++ b/bergson/huggingface.py
@@ -239,7 +239,6 @@ def on_step_end(
         **kwargs,
     ):
         self.on_substep_end(args, state, control)
-        print("Step end")
 
         # Record training order if enabled
         if self.order is not None:
@@ -271,7 +270,7 @@ def on_step_end(
             for name, param in model.named_parameters()
             if param.requires_grad
         }
-        normalizers: dict[str, AdafactorNormalizer] = {}
+        normalizers: dict[str, AdafactorNormalizer | AdamNormalizer] = {}
 
         assert self.collector is not None
         proc = self.collector.processor
@@ -279,32 +278,76 @@ def on_step_end(
         # Read normalizers off of the optimizer state. We need to figure out
         # what type of optimizer this is first.
+ # Collect references to both weight and bias second moments per layer + layer_second_moments: dict[str, dict[str, Tensor]] = {} + for group in optimizer.param_groups: - lr_sqrt = group["lr"] ** 0.5 + group_lr = group["lr"] for param in group["params"]: - name = param_to_name[param].removesuffix(".weight") - if name not in self.collector.target_info: + param_name = param_to_name[param] + + # Extract layer name (remove .weight or .bias suffix) + if param_name.endswith(".weight"): + param_type = "weight" + layer_name = param_name.removesuffix(".weight") + elif param_name.endswith(".bias"): + param_type = "bias" + layer_name = param_name.removesuffix(".bias") + else: + continue + + if layer_name not in self.collector.target_info: continue p_state = optimizer.state[param] - # Adam-like optimizer - if (eas := p_state.get("exp_avg_sq")) is not None: - norm = AdamNormalizer(eas).to_adafactor() + # Initialize layer dict if needed, storing this group's learning rate + if layer_name not in layer_second_moments: + layer_second_moments[layer_name] = {"lr": group_lr} - # Adafactor-like optimizer - elif (vr := p_state.get("exp_avg_sq_row")) is not None: + # Check for Adafactor FIRST (more specific than Adam) + # Adafactor-like optimizer: weights have factorized moments + if (vr := p_state.get("exp_avg_sq_row")) is not None: vc = p_state.get("exp_avg_sq_col") - norm = AdafactorNormalizer(vr, vc) - else: - continue - - # Scale the gradient by the current learning rate. It's factorized - # so we multiply each factor by the square root of the LR. - norm.row *= lr_sqrt - norm.col *= lr_sqrt - normalizers[name] = norm + if param_type == "weight": + # Factorized second moments for weights + layer_second_moments[layer_name]["row"] = vr + layer_second_moments[layer_name]["col"] = vc + elif param_type == "bias": + # Adafactor stores bias as regular exp_avg_sq + bias_eas = p_state.get("exp_avg_sq") + if bias_eas is not None: + layer_second_moments[layer_name]["bias"] = bias_eas + # Adam-like optimizer: has exp_avg_sq for both weight and bias + elif (eas := p_state.get("exp_avg_sq")) is not None: + layer_second_moments[layer_name][param_type] = eas + + # Build normalizers from collected second moments + for layer_name, moments in layer_second_moments.items(): + lr = moments["lr"] + + # Adam-like: has weight exp_avg_sq + if "weight" in moments: + weight_eas = moments["weight"] * lr + bias_eas = moments.get("bias") + bias_eas = bias_eas * lr if bias_eas is not None else None + + norm = AdamNormalizer(weight_eas, bias_eas) + + # Adafactor-like: has row/col factorization + elif "row" in moments and "col" in moments: + row = moments["row"] * lr**0.5 + col = moments["col"] * lr**0.5 + bias_eas = moments.get("bias") + bias_eas = bias_eas * lr if bias_eas is not None else None + + norm = AdafactorNormalizer(row, col, bias_eas) + else: + # No weight moments found - skip this layer + continue + + normalizers[layer_name] = norm proc.normalizers = normalizers diff --git a/tests/test_gradients.py b/tests/test_gradients.py index ada5305d..b3acad46 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,7 +1,10 @@ import tempfile +from collections import defaultdict from pathlib import Path +import pytest import torch +import torch.nn as nn from datasets import Dataset from transformers import AutoConfig, AutoModelForCausalLM @@ -15,13 +18,136 @@ ) -def test_GPTNeoX(): +# Test fixtures +@pytest.fixture +def test_params(): + """Common test parameters used across gradient tests. 
+ + Returns: + dict: Test dimensions with keys: + - N: Batch size (4) + - S: Sequence length (6) + - I: Input dimension (5) + - O: Output dimension (3) + """ + return {"N": 4, "S": 6, "I": 5, "O": 3} + + +@pytest.fixture +def simple_model_class(test_params): + """Factory for creating test model classes. + + Creates simple neural network models for testing gradient collection. + Supports both single-layer and two-layer architectures. + + Returns: + callable: Factory function that takes: + - include_bias (bool): Whether to include bias terms + - num_layers (int): Number of linear layers (1 or 2, default 2) + + Examples: + >>> ModelClass = simple_model_class(include_bias=True, num_layers=1) + >>> model = ModelClass() # Single layer: fc + >>> ModelClass = simple_model_class(include_bias=False, num_layers=2) + >>> model = ModelClass() # Two layers: fc1, relu, fc2 + """ + I, O = test_params["I"], test_params["O"] + + def _make_model(include_bias: bool, num_layers: int = 2): + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + if num_layers == 1: + self.fc = nn.Linear(I, O, bias=include_bias) + self.layers = nn.Sequential(self.fc) + else: # num_layers == 2 + self.fc1 = nn.Linear(I, O * 2, bias=include_bias) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(O * 2, O, bias=include_bias) + self.layers = nn.Sequential(self.fc1, self.relu, self.fc2) + + def forward(self, x): + return self.layers(x) + + return SimpleModel + + return _make_model + + +@pytest.fixture +def trained_model_with_normalizers(simple_model_class, test_params): + """Factory for creating trained models with Adam second moments. + + Creates a two-layer model, runs several training steps with Adam optimizer, + then extracts second moments (exp_avg_sq) to create AdamNormalizers for + both weights and biases. + + Returns: + callable: Factory function that takes: + - include_bias (bool): Whether to include bias normalizers + + Returns tuple of (model, normalizers) where: + - model: Trained SimpleModel instance + - normalizers: Dict mapping layer names to AdamNormalizer instances + with weight and optional bias second moments + """ + N, S, I = test_params["N"], test_params["S"], test_params["I"] + + def _create(include_bias: bool): + torch.manual_seed(42) + ModelClass = simple_model_class(include_bias) + model = ModelClass().to("cpu") + + optimizer = torch.optim.Adam(model.parameters()) + + # Run a few training steps to build up second moments + for _ in range(5): + optimizer.zero_grad() + out = model(torch.randn(N, S, I)) + loss = (out**2).sum() + loss.backward() + optimizer.step() + + # Extract normalizers from optimizer state + normalizers = {} + for name, param in model.named_parameters(): + if "weight" in name: + layer_name = name.replace(".weight", "") + exp_avg_sq = optimizer.state[param]["exp_avg_sq"] + + # Get bias second moments if bias is included + bias_avg_sq = None + if include_bias: + bias_param_name = layer_name + ".bias" + for p_name, p in model.named_parameters(): + if p_name == bias_param_name: + bias_avg_sq = optimizer.state[p]["exp_avg_sq"] + break + + normalizers[layer_name] = AdamNormalizer(exp_avg_sq, bias_avg_sq) + + return model, normalizers + + return _create + + +def test_gradient_collector_proj_norm(): + """Test gradient collection with projection and normalization. 
+ + Verifies that GradientCollector correctly: + - Collects gradients with and without random projection + - Applies Adam and Adafactor normalization + - Saves and loads GradientProcessor state + - Produces consistent results across save/load cycles + """ temp_dir = Path(tempfile.mkdtemp()) print(temp_dir) config = AutoConfig.from_pretrained("trl-internal-testing/tiny-GPTNeoXForCausalLM") model = AutoModelForCausalLM.from_config(config) + # It's important that we use a batch size of one so that we can simply use the + # aggregate gradients from the backward itself and compare against those tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=model.device) inputs = dict(input_ids=tokens, labels=tokens) data = Dataset.from_dict({"input_ids": tokens.tolist()}) @@ -48,6 +174,7 @@ def test_GPTNeoX(): adafactors: dict[str, AdafactorNormalizer] = {} adams: dict[str, AdamNormalizer] = {} + # Go through the motions of what GradientCollector does, but after the fact for name, collected_grad in collected_grads.items(): layer = model.get_submodule(name) @@ -115,3 +242,252 @@ def test_GPTNeoX(): ) previous_collected_grads = collected_grads.copy() + + +@pytest.mark.parametrize("include_bias", [True, False]) +def test_gradient_collector_batched( + include_bias: bool, trained_model_with_normalizers, test_params +): + """Test per-sample gradient collection with Adam normalization. + + Tests gradient collection with and without bias terms by: + - Computing ground truth gradients via individual backward passes + - Comparing against GradientCollector's batched computation + - Verifying proper bias normalization using Adam second moments + + Args: + include_bias: Whether to include bias gradients in collection + """ + temp_dir = Path(tempfile.mkdtemp()) + N, S, I = test_params["N"], test_params["S"], test_params["I"] + + model, normalizers = trained_model_with_normalizers(include_bias) + + # Create dummy dataset for GradientCollector + dummy_data = Dataset.from_dict({"input_ids": [[1] * 10] * N}) + + # Create config for GradientCollector + cfg = IndexConfig( + run_path=str(temp_dir / "run"), + skip_index=True, + ) + + processor = GradientProcessor( + normalizers=normalizers, projection_dim=None, include_bias=include_bias + ) + collector = GradientCollector( + model=model, + cfg=cfg, + data=dummy_data, + processor=processor, + target_modules={"fc1", "fc2"}, + ) + + x = torch.randn(N, S, I) + with collector: + model.zero_grad() + out = model(x) + loss = (out**2).sum() + loss.backward() + + # Copy collected gradients from collector.mod_grads + collected_grads = collector.mod_grads.copy() + + def compute_ground_truth(): + """Compute gradients using individual backward passes, with normalization.""" + model.zero_grad() + output = model(x) # [N, S, O] + + # Per-sample losses + per_sample_losses = (output**2).sum(dim=(1, 2)) # [N] + + ground_truth_grads = defaultdict(list) + for n in range(N): + model.zero_grad() + per_sample_losses[n].backward(retain_graph=True) + + # manually normalize + for layer_name in ["fc1", "fc2"]: + layer = model.get_submodule(layer_name) + grad = layer.weight.grad.clone() + + grad = normalizers[layer_name].normalize_(grad) + + if include_bias: + bias_grad = layer.bias.grad.clone() + # Normalize bias with bias second moments + # (matching GradientCollector) + bias_grad = bias_grad / normalizers[ + layer_name + ].bias_avg_sq.sqrt().add(1e-8) + bias_grad = bias_grad.unsqueeze(1) + grad = torch.cat([grad, bias_grad], dim=1) + + # Flatten to match GradientCollector's output format + 
ground_truth_grads[layer_name].append(grad.flatten()) + + for layer_name in ["fc1", "fc2"]: + ground_truth_grads[layer_name] = torch.stack(ground_truth_grads[layer_name]) + + return ground_truth_grads + + ground_truth = compute_ground_truth() + for layer_name in ["fc1", "fc2"]: + torch.testing.assert_close( + collected_grads[layer_name], ground_truth[layer_name] + ) + + +def test_bias_gradients(test_params, simple_model_class): + """Test per-sample bias gradient computation without normalizers. + + Validates that GradientCollector correctly computes bias gradients when + no normalizers are provided by: + - Computing ground truth via individual backward passes + - Collecting bias gradients using GradientCollector + - Verifying bias gradients match (summed over sequence dimension) + + This tests the no-normalizer bias collection path added to support + bias gradients without Adam/Adafactor second moments. + """ + temp_dir = Path(tempfile.mkdtemp()) + torch.manual_seed(42) + N, S, I, O = test_params["N"], test_params["S"], test_params["I"], test_params["O"] + + ModelClass = simple_model_class(include_bias=True, num_layers=1) + model = ModelClass().to("cpu") + x = torch.randn(N, S, I) + + # bias gradient is a sum over sequence dimension for each n + def compute_ground_truth(model) -> torch.Tensor: + """Compute gradients using individual backward passes.""" + model.zero_grad() + output = model(x) # [N, S, O] + + per_sample_losses = (output**2).sum(dim=(1, 2)) # [N] + + bias_grads = [] + for n in range(N): + model.zero_grad() + per_sample_losses[n].backward(retain_graph=True) + bias_grads.append(model.fc.bias.grad.clone()) + + return torch.stack(bias_grads, dim=0) # [N, O] + + ground_truth = compute_ground_truth(model) + + # GradientCollector with include_bias=True + # Create dummy dataset for GradientCollector + dummy_data = Dataset.from_dict({"input_ids": [[1] * 10] * N}) + + # Create config for GradientCollector + cfg = IndexConfig( + run_path=str(temp_dir / "run"), + skip_index=True, + ) + + processor = GradientProcessor(include_bias=True, projection_dim=None) + collector = GradientCollector( + model=model, + cfg=cfg, + data=dummy_data, + processor=processor, + target_modules={"fc"}, + ) + + with collector: + model.zero_grad() + output = model(x) + loss = (output**2).sum() + loss.backward() + + # Reshape from [N, O*(I+1)] to [N, O, I+1] to extract bias from last column + collected = collector.mod_grads["fc"].reshape(N, O, I + 1) + bias_grads = collected[..., -1] + + assert bias_grads.shape == ( + N, + O, + ), f"Expected shape ({N}, {O}), got {bias_grads.shape}" + assert ground_truth.shape == ( + N, + 3, + ), f"Expected shape ({N}, {O}), got {ground_truth.shape}" + + # Compare to ground truth + torch.testing.assert_close(bias_grads, ground_truth) + + +@pytest.mark.parametrize("include_bias", [True, False]) +def test_gradient_collector_with_projection( + include_bias: bool, trained_model_with_normalizers, test_params +): + """Test gradient collection with random projection and bias terms. + + Validates that combining random projection with bias collection works correctly: + - Verifies output shape is [N, projection_dim²] regardless of bias inclusion + - Checks gradients are non-zero (projection doesn't zero them out) + - Confirms deterministic behavior (same input = same output) + + This tests the critical path where bias gradients are concatenated to weight + gradients BEFORE applying the random projection, ensuring the projection + accounts for the increased dimensionality. 
+ + Args: + include_bias: Whether to include bias gradients in collection + """ + temp_dir = Path(tempfile.mkdtemp()) + N, S, I = test_params["N"], test_params["S"], test_params["I"] + P = 4 # projection dimension + + model, normalizers = trained_model_with_normalizers(include_bias) + + # Create dummy dataset for GradientCollector + dummy_data = Dataset.from_dict({"input_ids": [[1] * 10] * N}) + + # Create config for GradientCollector + cfg = IndexConfig( + run_path=str(temp_dir / "run"), + skip_index=True, + ) + + processor = GradientProcessor( + normalizers=normalizers, projection_dim=P, include_bias=include_bias + ) + collector = GradientCollector( + model=model, + cfg=cfg, + data=dummy_data, + processor=processor, + target_modules={"fc1", "fc2"}, + ) + + x = torch.randn(N, S, I) + with collector: + model.zero_grad() + out = model(x) + loss = (out**2).sum() + loss.backward() + + # Check shapes - with projection, output should be [N, P*P] + for layer_name in ["fc1", "fc2"]: + collected = collector.mod_grads[layer_name] + assert collected.shape == ( + N, + P * P, + ), f"Expected shape ({N}, {P*P}), got {collected.shape} for {layer_name}" + + # Check that gradients are not all zeros + assert collected.abs().sum() > 0, f"Gradients are all zeros for {layer_name}" + + # Check determinism - running twice should give same results + with collector: + model.zero_grad() + out = model(x) + loss = (out**2).sum() + loss.backward() + + collected2 = collector.mod_grads[layer_name] + torch.testing.assert_close( + collected, collected2, msg=f"Gradients not deterministic for {layer_name}" + ) diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index 0e3c6a71..ff9fba13 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -1,4 +1,10 @@ import os +from pathlib import Path + +from torch import nn + +from bergson import GradientProcessor +from bergson.gradients import AdafactorNormalizer, AdamNormalizer os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["WANDB_MODE"] = "disabled" @@ -6,7 +12,13 @@ import pytest import torch from datasets import Dataset -from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments +from transformers import ( + Adafactor, + AutoConfig, + AutoModelForCausalLM, + Trainer, + TrainingArguments, +) from trl import SFTConfig, SFTTrainer from bergson.data import load_gradients @@ -245,3 +257,132 @@ def test_sft_trainer(self, tmp_path, model, dataset): saved_order = Dataset.load_from_disk(str(order_file)) assert len(saved_order) > 0 assert all(key in saved_order[0] for key in ["_idx", "global_step", "epoch"]) + + @pytest.mark.parametrize("optimizer_name", ["adam", "adafactor"]) + @pytest.mark.parametrize("include_bias", [True, False]) + def test_optimizer_state_extraction(self, optimizer_name: str, include_bias: bool): + """Test that normalizers are correctly extracted from optimizer state. + + This tests the huggingface.py callback by: + 1. Training a model with an optimizer + 2. Calling the callback's on_step_end method + 3. 
Verifying against raw optimizer state + """ + torch.manual_seed(42) + N = 4 + S = 6 + I = 5 + O = 3 + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(I, O * 2, bias=include_bias) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(O * 2, O, bias=include_bias) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + torch.manual_seed(42) + model = SimpleModel() + + # Create optimizer + if optimizer_name == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + else: + optimizer = Adafactor( + model.parameters(), scale_parameter=False, relative_step=False, lr=0.001 + ) + + # Train a few steps to build up second moments + for _ in range(5): + optimizer.zero_grad() + out = model(torch.randn(N, S, I)) + loss = (out**2).sum() + loss.backward() + optimizer.step() + + # Extract normalizers using the ACTUAL callback + from unittest.mock import Mock, patch + + from bergson.huggingface import GradientCollectorCallback + + # Create callback with minimal setup + callback = GradientCollectorCallback( + path=Path("/tmp/test"), + use_optimizer_state=True, + include_bias=include_bias, + ) + + # Mock the collector and processor + mock_collector = Mock() + mock_collector.processor = GradientProcessor( + normalizers={}, include_bias=include_bias + ) + mock_collector.target_info = {"fc1": None, "fc2": None} # Track these layers + callback.collector = mock_collector + + # Mock on_substep_end to avoid needing train_grad_buffer + with patch.object(callback, "on_substep_end"): + # Call the ACTUAL callback method + callback.on_step_end( + args=Mock(), + state=Mock(epoch=0, global_step=1), + control=Mock(), + model=model, + optimizer=optimizer, + ) + + # Get the normalizers the callback extracted + normalizers = callback.collector.processor.normalizers + + # Verify against raw optimizer state (independent ground truth) + for layer_name in ["fc1", "fc2"]: + layer = model.get_submodule(layer_name) + norm = normalizers[layer_name] + + # Get raw state from optimizer + weight_state = optimizer.state[layer.weight] + lr = optimizer.param_groups[0]["lr"] + + if optimizer_name == "adam": + # Check normalizer type + assert isinstance(norm, AdamNormalizer) + + # Ground truth: Adam stores full exp_avg_sq + raw_exp_avg_sq = weight_state["exp_avg_sq"] + expected_avg_sq = raw_exp_avg_sq * lr + + torch.testing.assert_close(norm.avg_sq, expected_avg_sq) + + elif optimizer_name == "adafactor": + # Check normalizer type + assert isinstance(norm, AdafactorNormalizer) + + # Ground truth: Adafactor stores row/col directly + lr_sqrt = lr**0.5 + raw_row = weight_state["exp_avg_sq_row"] + raw_col = weight_state["exp_avg_sq_col"] + + # Our normalizer should match (scaled by LR) + expected_row = raw_row * lr_sqrt + expected_col = raw_col * lr_sqrt + + torch.testing.assert_close(norm.row, expected_row) + torch.testing.assert_close(norm.col, expected_col) + + # Verify bias handling + if include_bias and layer.bias is not None: + bias_state = optimizer.state[layer.bias] + raw_bias_exp_avg_sq = bias_state["exp_avg_sq"] + expected_bias = raw_bias_exp_avg_sq * lr + + assert ( + norm.bias_avg_sq is not None + ), f"Expected bias_avg_sq for {layer_name}" + torch.testing.assert_close(norm.bias_avg_sq, expected_bias) + else: + assert ( + norm.bias_avg_sq is None + ), f"Unexpected bias_avg_sq for {layer_name}"
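
A minimal end-to-end sketch of how the bias-aware Adam normalizers introduced in this patch are meant to be used, mirroring the setup in tests/test_gradients.py. The toy model, layer names, hyperparameters, and the exact import path are assumptions for illustration, not part of the patch.

# Minimal sketch, assuming the names below are importable as in the tests;
# mirror the imports used in tests/test_gradients.py for the real paths.
import torch
import torch.nn as nn
from datasets import Dataset

from bergson import (  # assumed import path
    AdamNormalizer,
    GradientCollector,
    GradientProcessor,
    IndexConfig,
)


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 6, bias=True)
        self.fc2 = nn.Linear(6, 3, bias=True)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = Tiny()
opt = torch.optim.Adam(model.parameters())
for _ in range(5):  # a few steps so exp_avg_sq is populated in the optimizer state
    opt.zero_grad()
    (model(torch.randn(4, 6, 5)) ** 2).sum().backward()
    opt.step()

# One normalizer per layer, carrying both weight and bias second moments.
normalizers = {
    name: AdamNormalizer(
        opt.state[layer.weight]["exp_avg_sq"],
        opt.state[layer.bias]["exp_avg_sq"],
    )
    for name, layer in [("fc1", model.fc1), ("fc2", model.fc2)]
}

processor = GradientProcessor(
    normalizers=normalizers, projection_dim=None, include_bias=True
)
collector = GradientCollector(
    model=model,
    cfg=IndexConfig(run_path="/tmp/bergson-sketch", skip_index=True),
    data=Dataset.from_dict({"input_ids": [[1] * 10] * 4}),
    processor=processor,
    target_modules={"fc1", "fc2"},
)

with collector:
    model.zero_grad()
    (model(torch.randn(4, 6, 5)) ** 2).sum().backward()

# Per-example, Adam-normalized gradients with the bias column appended,
# flattened to [N, O * (I + 1)] per target module when projection_dim is None.
print({name: g.shape for name, g in collector.mod_grads.items()})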