130 changes: 99 additions & 31 deletions bergson/collector/gradient_collectors.py
@@ -84,26 +84,28 @@ def setup(self) -> None:

Sets up a Builder for gradient storage if not using a Scorer.
"""
model_device = (
getattr(self.model, "device", None) or next(self.model.parameters()).device
)
model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

assert isinstance(
self.model.device, torch.device
model_device, torch.device
), "Model device is not set correctly"
if self.cfg.include_bias and self.processor.normalizers is not None:
raise NotImplementedError(
"Bias with normalizers not supported yet, "
"consider disabling bias inclusion for now."
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
self.hi = torch.finfo(self.save_dtype).max

self.per_doc_losses = torch.full(
(len(self.data),),
device=self.model.device,
device=model_device,
dtype=self.save_dtype,
fill_value=0.0,
)
@@ -143,16 +145,18 @@ def forward_hook(self, module: nn.Module, a: Float[Tensor, "N S I"]) -> None:
a_factor = a_factor.rsqrt()
a = a * a_factor.type_as(a) # [N, S, I] * [I] → [N, S, I]

if module._has_bias:
# Append ones to activation for bias term
# For normalizer cases, bias normalization differs from weight normalization,
# so we handle bias separately in backward hook
if module._has_bias and normalizer is None:
ones = torch.ones(a.size(0), a.size(1), 1, device=a.device, dtype=a.dtype)
a = torch.cat([a, ones], dim=-1)
a = torch.cat([a, ones], dim=-1) # [N, S, I+1]
i = i + 1
setattr(module, LayerAdapter.in_attr(module), i)
if p is not None:

if p is not None and (normalizer is None or not module._has_bias):
a_projection = self.projection(name, p, i, "right", a.device, a.dtype).T
a = a @ a_projection # type: ignore
# set module._inputs to a
a = a @ a_projection # [N, S, I(+1)] @ [I(+1), p] → [N, S, p]

module._inputs = a

@HookCollectorBase.split_attention_heads
@@ -171,25 +175,79 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
i = getattr(module, LayerAdapter.in_attr(module))
o = getattr(module, LayerAdapter.out_attr(module))
normalizer = self.processor.normalizers.get(name)

if isinstance(normalizer, AdamNormalizer):
full_gradient = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I]
P = normalizer.normalize_(full_gradient)
if p is not None:
g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
a_projection = self.projection(name, p, i, "right", g.device, g.dtype).T
P = g_projection @ P @ a_projection
else:
if isinstance(normalizer, AdafactorNormalizer):
bias_grad = None

match normalizer:
@baberabb (Author, Dec 19, 2025):

These are quite repetitive, but I thought it would be more readable and less error-prone if the logic of each case is kept separate.

Also, what are your thoughts on upstreaming this to HookCollectorBase? All the other classes seem to be using the same logic, and people can always override it if they need to do something custom.

Contributor:

I agree that it is quite repetitive. What exactly do you want to upstream? The forward and backward methods? I will be adding instances of HookCollectorBase in the near future that will not be of this form, but we could maybe have a generic GradientCollector that is then further inherited.

case AdamNormalizer():
if module._has_bias and normalizer.bias_avg_sq is not None:
# Normalize bias with bias second moments # [N, S, O] → [N, O]
bias_grad = g.sum(dim=1) / normalizer.bias_avg_sq.sqrt().add(1e-8)

P = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I]
P = normalizer.normalize_(P)

# Append pre-normalized bias gradient
if bias_grad is not None:
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1]
i += 1

if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
a_projection = self.projection(
name, p, i, "right", g.device, g.dtype
).T
P = g_projection @ P @ a_projection

case AdafactorNormalizer():
if module._has_bias and normalizer.bias_avg_sq is not None:
# Compute bias from RAW g (before row normalization)
bias_grad = g.sum(dim=1) # [N, S, O] → [N, O]
# Normalize bias with bias second moments
bias_grad = bias_grad / normalizer.bias_avg_sq.add(1e-30).sqrt()

# Apply row normalization to g (for weights)
g_factor = normalizer.row.add(1e-30)
g_factor = g_factor.mean().sqrt() * g_factor.rsqrt()
g = g * g_factor.type_as(g) # [N, S, O] * [O] → [N, S, O]

if p is not None:
g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
g = g @ g_projection.T # [N, S, p]

P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]
# If bias is present, materialize full gradient then project
if bias_grad is not None:
P = g.mT @ a # [N, O, I]

# Append pre-normalized bias gradient
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1]
i += 1

# Project the entire normalized gradient
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
a_projection = self.projection(
name, p, i, "right", a.device, a.dtype
).T
P = g_projection @ P @ a_projection
else:
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
g = g @ g_projection.T # [N, S, p]

P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]

case None:
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
g = g @ g_projection.T # [N, S, p]

P = g.mT @ a # [N, O/p, I(+1)/p]
case _:
raise ValueError(f"Unknown normalizer type {type(normalizer)}")

P = P.flatten(1).clamp_(self.lo, self.hi)

@@ -293,9 +351,14 @@ class TraceCollector(HookCollectorBase):
"""Dtype for stored gradients."""

def setup(self) -> None:

model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
@@ -403,9 +466,14 @@ class StreamingGradientCollector(HookCollectorBase):
"""Dtype for stored gradients."""

def setup(self) -> None:

model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
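For reference, a minimal, self-contained sketch of the per-case dispatch discussed in the `match normalizer:` thread above, roughly the shape a shared helper on a generic gradient-collector base could take. `AdamToy`, `AdafactorToy`, and `per_sample_gradient` are illustrative stand-ins rather than the project's classes, and random projection and bias handling are omitted.

```python
from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class AdamToy:
    avg_sq: Tensor  # [O, I] second moments

    def normalize_(self, grad: Tensor) -> Tensor:
        # Elementwise Adam-style normalization of a materialized gradient.
        return grad.div_(self.avg_sq.sqrt().add(1e-8))


@dataclass
class AdafactorToy:
    row: Tensor  # [O] row statistics
    col: Tensor  # [I] column statistics


def per_sample_gradient(g: Tensor, a: Tensor, normalizer=None) -> Tensor:
    """g: [N, S, O] output grads, a: [N, S, I] inputs -> P: [N, O, I]."""
    match normalizer:
        case AdamToy():
            # Materialize the full per-sample gradient, then normalize elementwise.
            return normalizer.normalize_(g.mT @ a)
        case AdafactorToy():
            # Row-normalize g before the contraction; the column factor is applied
            # to `a` in the forward hook in the real collector.
            g_factor = normalizer.row.add(1e-30)
            g_factor = g_factor.mean().sqrt() * g_factor.rsqrt()
            return (g * g_factor.type_as(g)).mT @ a
        case None:
            return g.mT @ a
        case _:
            raise ValueError(f"Unknown normalizer type {type(normalizer)}")


n, s, o, i = 2, 3, 4, 5
g, a = torch.randn(n, s, o), torch.randn(n, s, i)
print(per_sample_gradient(g, a, AdafactorToy(torch.rand(o), torch.rand(i))).shape)
```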
41 changes: 40 additions & 1 deletion bergson/gradients.py
@@ -62,14 +62,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.

Args:
row: Row statistics [O]
col: Column statistics [I]
bias_avg_sq: Optional second moments for bias [O]
"""

row: Tensor # shape [O]
col: Tensor # shape [I]
bias_avg_sq: Tensor | None = None # shape [O]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
if self.bias_avg_sq is not None:
assert (
self.bias_avg_sq.ndim == 1
), f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"

@torch.compile
def normalize_(
@@ -114,22 +124,40 @@ def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.

Preserves bias_avg_sq if present.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)
return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)

def scale_by_lr(self, lr: float | Tensor) -> None:
"""Scale normalizer by learning rate.

Factorized dimensions (row, col) are each scaled by sqrt(lr).
The bias second moments are scaled by lr.
"""
lr_sqrt = lr**0.5
self.row.mul_(lr_sqrt)
self.col.mul_(lr_sqrt)
self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.

Args:
avg_sq: Second moments for weights [O, I]
bias_avg_sq: Optional second moments for bias [O]
"""

avg_sq: Tensor
bias_avg_sq: Tensor | None = None

@torch.compile
def normalize_(
@@ -147,6 +175,8 @@ def to_adafactor(self) -> AdafactorNormalizer:
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.

Preserves bias_avg_sq if present.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
@@ -157,8 +187,17 @@ def to_adafactor(self) -> AdafactorNormalizer:
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
bias_avg_sq=self.bias_avg_sq, # Preserve bias second moments
)

def scale_by_lr(self, lr: float | Tensor) -> None:
@baberabb (Author):

I added this method, but let me know if it's not necessary. Also, it does in-place ops, mirroring normalize_, but maybe new tensors would be better?

Contributor:

I don't see the advantage of new tensors. If they are in-place operations, I would consider just calling the function, i.e.
self.row.mul_(lr_sqrt)
instead of doing
self.row = self.row.mul_(lr_sqrt)
Otherwise one may think these are not in-place?

@baberabb (Author):

I was waffling between the two, hence this cursed syntax. New tensors, since these are references initially, and it's easy to forget to replace them before calling.

Contributor:

Can we maybe choose only one then?

"""Scale normalizer to incorporate learning rate.

Both avg_sq and bias_avg_sq are multiplied by lr.
"""
self.avg_sq.mul_(lr)
self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None


@dataclass
class GradientProcessor:
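Following up on the `scale_by_lr` thread above, a sketch of the out-of-place alternative ("new tensors") that was floated: return a fresh normalizer instead of mutating the second-moment tensors, which start out as references into the optimizer state. The scaling factors mirror the in-place version in this diff (sqrt(lr) for row/col, lr for the bias second moments); `AdafactorToy` and `scaled_by_lr` are illustrative names only.

```python
from dataclasses import dataclass, replace

import torch
from torch import Tensor


@dataclass
class AdafactorToy:
    row: Tensor                        # [O]
    col: Tensor                        # [I]
    bias_avg_sq: Tensor | None = None  # [O]


def scaled_by_lr(norm: AdafactorToy, lr: float) -> AdafactorToy:
    """Return a copy with the learning rate folded in; the input is untouched."""
    lr_sqrt = lr**0.5
    return replace(
        norm,
        row=norm.row * lr_sqrt,  # factorized dims each pick up sqrt(lr)
        col=norm.col * lr_sqrt,
        bias_avg_sq=None if norm.bias_avg_sq is None else norm.bias_avg_sq * lr,
    )


norm = AdafactorToy(torch.rand(4), torch.rand(5), torch.rand(4))
scaled = scaled_by_lr(norm, lr=1e-3)
assert scaled.row is not norm.row  # optimizer-state tensors are left intact
```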
81 changes: 62 additions & 19 deletions bergson/huggingface.py
@@ -239,7 +239,6 @@ def on_step_end(
**kwargs,
):
self.on_substep_end(args, state, control)
print("Step end")

# Record training order if enabled
if self.order is not None:
Expand Down Expand Up @@ -271,40 +270,84 @@ def on_step_end(
for name, param in model.named_parameters()
if param.requires_grad
}
normalizers: dict[str, AdafactorNormalizer] = {}
normalizers: dict[str, AdafactorNormalizer | AdamNormalizer] = {}

assert self.collector is not None
proc = self.collector.processor
proc.normalizers = {}

# Read normalizers off of the optimizer state. We need to figure out
# what type of optimizer this is first.
# Collect references to both weight and bias second moments per layer
layer_second_moments: dict[str, dict[str, Tensor]] = {}

for group in optimizer.param_groups:
lr_sqrt = group["lr"] ** 0.5
group_lr = group["lr"]

for param in group["params"]:
name = param_to_name[param].removesuffix(".weight")
if name not in self.collector.target_info:
param_name = param_to_name[param]

# Extract layer name (remove .weight or .bias suffix)
if param_name.endswith(".weight"):
param_type = "weight"
layer_name = param_name.removesuffix(".weight")
elif param_name.endswith(".bias"):
param_type = "bias"
layer_name = param_name.removesuffix(".bias")
else:
continue

if layer_name not in self.collector.target_info:
continue

p_state = optimizer.state[param]

# Adam-like optimizer
if (eas := p_state.get("exp_avg_sq")) is not None:
norm = AdamNormalizer(eas).to_adafactor()
# Initialize layer dict if needed, storing this group's learning rate
if layer_name not in layer_second_moments:
layer_second_moments[layer_name] = {"lr": group_lr}

# Adafactor-like optimizer
elif (vr := p_state.get("exp_avg_sq_row")) is not None:
# Check for Adafactor FIRST (more specific than Adam)
# Adafactor-like optimizer: weights have factorized moments
if (vr := p_state.get("exp_avg_sq_row")) is not None:
vc = p_state.get("exp_avg_sq_col")
norm = AdafactorNormalizer(vr, vc)
else:
continue

# Scale the gradient by the current learning rate. It's factorized
# so we multiply each factor by the square root of the LR.
norm.row *= lr_sqrt
norm.col *= lr_sqrt
normalizers[name] = norm
if param_type == "weight":
# Factorized second moments for weights
layer_second_moments[layer_name]["row"] = vr
layer_second_moments[layer_name]["col"] = vc
elif param_type == "bias":
# Adafactor stores bias as regular exp_avg_sq
bias_eas = p_state.get("exp_avg_sq")
if bias_eas is not None:
layer_second_moments[layer_name]["bias"] = bias_eas
# Adam-like optimizer: has exp_avg_sq for both weight and bias
elif (eas := p_state.get("exp_avg_sq")) is not None:
layer_second_moments[layer_name][param_type] = eas

# Build normalizers from collected second moments
for layer_name, moments in layer_second_moments.items():
lr = moments["lr"]

# Adam-like: has weight exp_avg_sq
if "weight" in moments:
weight_eas = moments["weight"] * lr
bias_eas = moments.get("bias")
bias_eas = bias_eas * lr if bias_eas is not None else None

norm = AdamNormalizer(weight_eas, bias_eas)

# Adafactor-like: has row/col factorization
elif "row" in moments and "col" in moments:
row = moments["row"] * lr**0.5
col = moments["col"] * lr**0.5
bias_eas = moments.get("bias")
bias_eas = bias_eas * lr if bias_eas is not None else None

norm = AdafactorNormalizer(row, col, bias_eas)
else:
# No weight moments found - skip this layer
continue

normalizers[layer_name] = norm

proc.normalizers = normalizers

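For reference, a condensed, stand-alone sketch of the two-pass flow in the `on_step_end` hunk above: group second-moment tensors by layer first, then build one normalizer per layer. `group_second_moments` and the toy state dict are illustrative; the real callback reads these tensors from `optimizer.state[param]` and maps parameters to names via `param_to_name`.

```python
import torch
from torch import Tensor


def group_second_moments(
    state: dict[str, dict[str, Tensor]],
) -> dict[str, dict[str, Tensor]]:
    """Map '<layer>.weight' / '<layer>.bias' entries to one dict per layer."""
    per_layer: dict[str, dict[str, Tensor]] = {}
    for param_name, p_state in state.items():
        layer, _, kind = param_name.rpartition(".")
        if kind not in ("weight", "bias"):
            continue
        slot = per_layer.setdefault(layer, {})
        if "exp_avg_sq_row" in p_state:  # Adafactor-style factorized weight
            slot["row"] = p_state["exp_avg_sq_row"]
            slot["col"] = p_state["exp_avg_sq_col"]
        elif "exp_avg_sq" in p_state:    # Adam weight, or any bias slot
            slot["bias" if kind == "bias" else "weight"] = p_state["exp_avg_sq"]
    return per_layer


# Toy optimizer state: an Adafactor-factorized weight plus its bias slot.
toy_state = {
    "mlp.up.weight": {
        "exp_avg_sq_row": torch.rand(8),
        "exp_avg_sq_col": torch.rand(16),
    },
    "mlp.up.bias": {"exp_avg_sq": torch.rand(8)},
}
moments = group_second_moments(toy_state)
assert set(moments["mlp.up"]) == {"row", "col", "bias"}
```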