146 changes: 114 additions & 32 deletions bergson/collector/gradient_collectors.py
@@ -84,26 +84,28 @@ def setup(self) -> None:

Sets up a Builder for gradient storage if not using a Scorer.
"""
model_device = (
getattr(self.model, "device", None) or next(self.model.parameters()).device
)
model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

assert isinstance(
self.model.device, torch.device
model_device, torch.device
), "Model device is not set correctly"
if self.cfg.include_bias and self.processor.normalizers is not None:
raise NotImplementedError(
"Bias with normalizers not supported yet, "
"consider disabling bias inclusion for now."
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
self.hi = torch.finfo(self.save_dtype).max

self.per_doc_losses = torch.full(
(len(self.data),),
device=self.model.device,
device=model_device,
dtype=self.save_dtype,
fill_value=0.0,
)
@@ -143,13 +145,8 @@ def forward_hook(self, module: nn.Module, a: Float[Tensor, "N S I"]) -> None:
a_factor = a_factor.rsqrt()
a = a * a_factor.type_as(a) # [N, S, I] * [I] → [N, S, I]

if module._has_bias:
# Append ones to activation for bias term
ones = torch.ones(a.size(0), a.size(1), 1, device=a.device, dtype=a.dtype)
a = torch.cat([a, ones], dim=-1)
i = i + 1
setattr(module, LayerAdapter.in_attr(module), i)
if p is not None:
# Only project if no bias (bias requires full gradient to be materialized)
Contributor: I am confused by this. You can project with bias, if you have no normalizer, right?

Author: Yup, added it back in.
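A minimal, self-contained sketch of the point under discussion (the shapes and the projection matrix R below are made up for illustration; this is not the collector's code): when no normalizer rescales individual gradient entries, the bias column can be folded into the activations, and the right projection commutes with the outer product, so the full gradient never has to be materialized.

import torch

N, S, I, O, q = 2, 5, 4, 3, 2
a = torch.randn(N, S, I)   # activations, as in the forward hook
g = torch.randn(N, S, O)   # output gradients, as in the backward hook

# Fold the bias into the activations by appending a column of ones.
a_aug = torch.cat([a, torch.ones(N, S, 1)], dim=-1)      # [N, S, I+1]

# Per-example gradient of [W | b]; the last column is the bias gradient.
P_full = g.mT @ a_aug                                     # [N, O, I+1]
assert torch.allclose(P_full[..., -1], g.sum(dim=1), atol=1e-6)

# A hypothetical right projection R can be applied to the activations before
# the outer product or to the materialized gradient afterwards; both agree.
R = torch.randn(I + 1, q)
assert torch.allclose(g.mT @ (a_aug @ R), P_full @ R, atol=1e-5)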

if p is not None and not module._has_bias:
a_projection = self.projection(name, p, i, "right", a.device, a.dtype).T
a = a @ a_projection # type: ignore
# set module._inputs to a
@@ -171,25 +168,100 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
i = getattr(module, LayerAdapter.in_attr(module))
o = getattr(module, LayerAdapter.out_attr(module))
normalizer = self.processor.normalizers.get(name)

if isinstance(normalizer, AdamNormalizer):
full_gradient = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I]
P = normalizer.normalize_(full_gradient)
if p is not None:
g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
a_projection = self.projection(name, p, i, "right", g.device, g.dtype).T
P = g_projection @ P @ a_projection
else:
if isinstance(normalizer, AdafactorNormalizer):
bias_grad = None

match normalizer:
Author (@baberabb, Dec 19, 2025): These are quite repetitive, but I thought it would be more readable and less error-prone if the logic of each case is separate.

Also, what are your thoughts on upstreaming this to HookCollectorBase? All the other classes seem to be using the same logic, and people can always overload it if they need to do something custom.

Contributor: I agree that it is quite repetitive. What exactly do you want to upstream, the forward and backward methods? I will be adding instances of HookCollectorBase in the near future that will not be of this form. But we could maybe have a generic GradientCollector that is then further inherited?
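As a sketch of the "generic GradientCollector" idea (the class and method names below are hypothetical, not bergson's actual API): the normalizer match statement could live once in an intermediate base class, and subclasses would only decide what happens to the resulting per-document gradient.

from torch import Tensor, nn


class GradientCollectorBase:
    """Hypothetical intermediate base; would sit under HookCollectorBase."""

    def _normalize_and_project(self, module: nn.Module, g: Tensor) -> Tensor:
        # The shared AdamNormalizer / AdafactorNormalizer / None cases go here.
        ...

    def backward_hook(self, module: nn.Module, g: Tensor) -> None:
        P = self._normalize_and_project(module, g)
        self._consume(module, P)

    def _consume(self, module: nn.Module, P: Tensor) -> None:
        # Subclass hook: write to a Builder, score, or stream to disk.
        raise NotImplementedError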

case AdamNormalizer():
if module._has_bias and normalizer.bias_avg_sq is not None:
# Normalize bias with bias second moments # [N, S, O] → [N, O]
bias_grad = g.sum(dim=1) / normalizer.bias_avg_sq.sqrt().add(1e-8)

P = g.mT @ a # [N, O, S] @ [N, S, I] → [N, O, I]
P = normalizer.normalize_(P)

# Append pre-normalized bias gradient
if bias_grad is not None:
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1]
i += 1

if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
a_projection = self.projection(
name, p, i, "right", g.device, g.dtype
).T
P = g_projection @ P @ a_projection

case AdafactorNormalizer():
if module._has_bias and normalizer.bias_avg_sq is not None:
# Compute bias from RAW g (before row normalization)
bias_grad = g.sum(dim=1) # [N, S, O] → [N, O]
# Normalize bias with bias second moments
bias_grad = bias_grad / normalizer.bias_avg_sq.add(1e-30).sqrt()

# Apply row normalization to g (for weights)
g_factor = normalizer.row.add(1e-30)
g_factor = g_factor.mean().sqrt() * g_factor.rsqrt()
g = g * g_factor.type_as(g) # [N, S, O] * [O] → [N, S, O]

if p is not None:
g_projection = self.projection(name, p, o, "left", g.device, g.dtype)
g = g @ g_projection.T # [N, S, p]

P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]
# If bias is present, materialize full gradient then project
if bias_grad is not None:
P = g.mT @ a # [N, O, I/q]

# Append pre-normalized bias gradient
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1]
i += 1

# Project the entire normalized gradient
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
a_projection = self.projection(
name, p, i, "right", a.device, a.dtype
).T
P = g_projection @ P @ a_projection
else:
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
g = g @ g_projection.T # [N, S, p]

P = g.mT @ a # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]

case None:
if module._has_bias:
# Compute bias from RAW g (before any projection)
bias_grad = g.sum(dim=1) # [N, S, O] → [N, O]

# Materialize full gradient then project if needed
P = g.mT @ a # [N, O, I]

# Append bias gradient
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2) # [N, O, I+1]
i += 1

# Project the entire gradient if needed
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
a_projection = self.projection(
name, p, i, "right", a.device, a.dtype
).T
P = g_projection @ P @ a_projection
else:
if p is not None:
g_projection = self.projection(
name, p, o, "left", g.device, g.dtype
)
g = g @ g_projection.T

P = g.mT @ a # [N, O/p, I/q]
case _:
raise ValueError(f"Unknown normalizer type {type(normalizer)}")

P = P.flatten(1).clamp_(self.lo, self.hi)

@@ -293,9 +365,14 @@ class TraceCollector(HookCollectorBase):
"""Dtype for stored gradients."""

def setup(self) -> None:

model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
@@ -403,9 +480,14 @@ class StreamingGradientCollector(HookCollectorBase):
"""Dtype for stored gradients."""

def setup(self) -> None:

model_dtype = (
getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
)

# TODO: handle more elegantly?
self.save_dtype = (
torch.float32 if self.model.dtype == torch.float32 else torch.float16
torch.float32 if model_dtype == torch.float32 else torch.float16
)

self.lo = torch.finfo(self.save_dtype).min
45 changes: 44 additions & 1 deletion bergson/gradients.py
@@ -62,14 +62,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.

Args:
row: Row statistics [O]
col: Column statistics [I]
bias_avg_sq: Optional second moments for bias [O]
"""

row: Tensor # shape [O]
col: Tensor # shape [I]
bias_avg_sq: Tensor | None = None # shape [O]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
if self.bias_avg_sq is not None:
assert (
self.bias_avg_sq.ndim == 1
), f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"

@torch.compile
def normalize_(
@@ -114,22 +124,42 @@ def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.

Preserves bias_avg_sq if present.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)
return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
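A quick numerical check of the rank-one materialization above (illustrative only, not a test from the repo): when the second moments are themselves rank one, the Adafactor factorization followed by this reconstruction recovers them exactly.

import torch

O, I = 4, 3
avg_sq = torch.outer(torch.rand(O) + 0.1, torch.rand(I) + 0.1)  # rank-one second moments

row = avg_sq.mean(dim=1)   # what to_adafactor stores, shape [O]
col = avg_sq.mean(dim=0)   # shape [I]

recon = torch.outer(row, col) / row.mean()   # what to_adam materializes
assert torch.allclose(recon, avg_sq, atol=1e-6)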

def scale_by_lr(self, lr: float | Tensor) -> None:
"""Scale normalizer by learning rate.

Factorized dimensions (row, col) are scaled by lr.
Bias is scaled by lr**2.
"""
lr_sqrt = lr**0.5
self.row = self.row.mul_(lr_sqrt)
self.col = self.col.mul_(lr_sqrt)
self.bias_avg_sq = (
self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None
)


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.

Args:
avg_sq: Second moments for weights [O, I]
bias_avg_sq: Optional second moments for bias [O]
"""

avg_sq: Tensor
bias_avg_sq: Tensor | None = None

@torch.compile
def normalize_(
@@ -147,6 +177,8 @@ def to_adafactor(self) -> AdafactorNormalizer:
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.

Preserves bias_avg_sq if present.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
@@ -157,6 +189,17 @@ def to_adafactor(self) -> AdafactorNormalizer:
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
bias_avg_sq=self.bias_avg_sq, # Preserve bias second moments
)

def scale_by_lr(self, lr: float | Tensor) -> None:
Author: I added this method, but let me know if it's not necessary. Also, it does in-place ops, mirroring normalize_, but maybe new tensors would be better?

Contributor: I don't see the advantage of new tensors. If they are in-place operations I would consider just calling the function, i.e.
self.row.mul_(lr_sqrt)

instead of doing
self.row = self.row.mul_(lr_sqrt)

Otherwise one may think these are not in-place?

Author: Was waffling between the two, hence this cursed syntax. New tensors, since these are references initially and it's easy to forget to replace them before calling.

Contributor: Can we maybe choose only one then?
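A tiny illustration of the aliasing concern above (hypothetical tensors, not the normalizer's actual state): an in-place mul_ on a tensor that is still referenced elsewhere, say by an optimizer state dict, rewrites that shared storage, while an out-of-place multiply leaves it alone.

import torch

opt_state_row = torch.ones(3)   # pretend this lives in an optimizer state dict
row = opt_state_row             # the normalizer keeps a reference, not a copy

row.mul_(0.5)                   # in-place: the optimizer state is modified too
print(opt_state_row)            # tensor([0.5000, 0.5000, 0.5000])

row = opt_state_row * 0.5       # out-of-place: new tensor, shared state untouched
print(opt_state_row)            # still tensor([0.5000, 0.5000, 0.5000])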

"""Scale normalizer to incorporate learning rate.

Both avg_sq and bias_avg_sq are divided by lr².
"""
self.avg_sq = self.avg_sq.mul_(lr)
self.bias_avg_sq = (
self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None
)

