
Conversation


@baberabb baberabb commented Nov 14, 2025

Couple of bug fixes to do with bias:

  • The gradients were contracted over both the batch and sequence dimensions (dim=(0,1)) rather than just the sequence dimension (dim=1); see the sketch after this list.
  • Normalize the weights with Adam before concatenating the bias, to avoid a shape mismatch ([N, O, I+1] / [O, I] division error). The biases are currently concatenated raw, as I wasn't sure of the best way to handle them. More in the comments.
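
A minimal sketch of the intended per-example contraction, with illustrative shapes and names (not the repo's actual API):

import torch

# Hypothetical shapes: grad_out is the per-token output gradient, x the layer input.
batch, seq, d_in, d_out = 2, 5, 3, 4
grad_out = torch.randn(batch, seq, d_out)
x = torch.randn(batch, seq, d_in)

# Per-example weight gradient: contract over the sequence dimension only,
# keeping the batch dimension -> [batch, d_out, d_in].
weight_grad = torch.einsum("bso,bsi->boi", grad_out, x)

# Per-example bias gradient: sum over the sequence dimension only (dim=1),
# not over batch and sequence (dim=(0,1)) -> [batch, d_out].
bias_grad = grad_out.sum(dim=1)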

update:

  • Added a bias_avg_sq field to AdafactorNormalizer and AdamNormalizer to keep track of the bias second moments, so bias normalization can be handled separately from the weight gradients in AdafactorNormalizer.normalize_() (sketched after this list):
    • Normalize the bias from the raw gradient G before weight processing
    • Sum bias gradients over the sequence dimension
    • Append the normalized bias as an extra column when include_bias=True
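
A minimal sketch of that bias path, assuming the shapes noted in the comments; the real AdafactorNormalizer.normalize_() signature and weight handling may differ:

import torch

def append_normalized_bias(weight_grad, grad_out, bias_avg_sq, eps=1e-30):
    # weight_grad: per-example weight gradient, shape [batch, d_out, d_in]
    # grad_out:    per-token output gradient,   shape [batch, seq, d_out]
    # bias_avg_sq: running second moment of the bias gradient, shape [d_out]
    bias_grad = grad_out.sum(dim=1)                            # sum over the sequence dim only
    bias_grad = bias_grad / bias_avg_sq.clamp_min(eps).sqrt()  # normalize from the raw gradient
    # Append the normalized bias as an extra input column -> [batch, d_out, d_in + 1]
    return torch.cat([weight_grad, bias_grad.unsqueeze(-1)], dim=-1)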

Modified GradientCollectorCallback (with help from Claude):

  • Extracted bias second moments from both the Adam and Adafactor optimizers (see the sketch after this list)
  • Added a scale_by_lr(lr) method to AdafactorNormalizer (this also fixes a bug where optimizer state tensors were being modified in-place)
  • Added test_optimizer_state_extraction
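
A rough sketch of the state extraction; the exact state keys are assumptions here (torch.optim.Adam stores the second moment under "exp_avg_sq", and Adafactor implementations typically use "exp_avg_sq" for 1-D parameters such as biases, but this may differ):

import torch

def bias_second_moment(optimizer: torch.optim.Optimizer, bias: torch.nn.Parameter):
    # Look up the per-parameter state; returns None if the optimizer hasn't stepped yet.
    state = optimizer.state.get(bias, {})
    avg_sq = state.get("exp_avg_sq")
    # Clone so later scaling (e.g. scale_by_lr) never mutates the optimizer's own tensors.
    return avg_sq.clone() if avg_sq is not None else None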

Also added some unit tests. #75 should probably be merged before this.

Someone better at linear algebra than me should probably have a look at this as well.

@luciaquirke
Collaborator

This is fabulous, thank you!! 🙏 Interested to hear what Nora thinks but I reckon exposing second moments for bias through the normalizer would be great

@luciaquirke
Collaborator

Running

pip install -e ".[dev]"
pre-commit install

Should add formatting on commit, let me know if that doesn't work for some reason

@baberabb
Author

> Running
>
> pip install -e ".[dev]"
> pre-commit install
>
> Should add formatting on commit, let me know if that doesn't work for some reason

Oh yeah, it was a problem with the ruff linter: it doesn't fix line-length errors (it leaves those to the formatter). Will add black back.

@baberabb baberabb force-pushed the bias branch 2 times, most recently from 0a1cdb2 to 31e5008 on November 18, 2025 at 00:30
@luciaquirke
Collaborator

@LouisYRYJ if we merge this in the next few days, will it interfere with your big PR?

@baberabb we currently can't merge this because it breaks the build

@baberabb
Author

baberabb commented Dec 16, 2025

Removed the workflow files! Do you want me to rebase this on the other PR branch?

@luciaquirke
Collaborator

luciaquirke commented Dec 16, 2025

Louis's PR just merged!! If you can rebase this on main we should be able to merge it too 🙏 🙏 🚀 TODO for me: do another once-over to remember where we're at with the normalizers too.

bias_avg_sq=self.bias_avg_sq, # Preserve bias second moments
)

def scale_by_lr(self, lr: float | Tensor) -> None:
Author


I added this method, but let me know if it's not necessary. Also, it does in-place ops, mirroring normalize, but maybe new tensors would be better?

Contributor


I don't see the advantage of new tensors. If they are in-place operations I would consider just calling the function, i.e.
self.row.mul_(lr_sqrt)

instead of doing
self.row = self.row.mul_(lr_sqrt)

Otherwise one might think these are not in-place.

Author


I was waffling between the two, hence this cursed syntax. Went with new tensors: these fields are initially references to the optimizer's state tensors, and it's easy to forget to replace them with copies before calling this.
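
For what it's worth, a sketch of the non-mutating variant, mirroring the self.row.mul_(lr_sqrt) line quoted above (field names and scaling here are assumptions, not the repo's actual code):

from dataclasses import dataclass

from torch import Tensor

@dataclass
class FactoredState:
    # Stand-in for the normalizer's factored second-moment fields.
    row: Tensor
    col: Tensor

    def scale_by_lr(self, lr: float | Tensor) -> None:
        lr_sqrt = lr**0.5 if isinstance(lr, float) else lr.sqrt()
        # Plain multiplication allocates fresh tensors, so the optimizer state
        # tensors that row and col initially reference are never mutated.
        self.row = self.row * lr_sqrt
        self.col = self.col * lr_sqrt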

Contributor


Can we maybe choose only one then?

i = i + 1
setattr(module, LayerAdapter.in_attr(module), i)
if p is not None:
    # Only project if no bias (bias requires full gradient to be materialized)
Contributor


I am confused by this. You can project with bias if you have no normalizer, right?

Author


Yup, added it back in.

if isinstance(normalizer, AdafactorNormalizer):
bias_grad = None

match normalizer:
Author

@baberabb baberabb Dec 19, 2025


These are quite repetitive, but I thought it would be more readable and less error-prone if the logic of each case is kept separate.

Also, what are your thoughts on upstreaming this to HookCollectorBase? All the other classes seem to be using the same logic, and people can always override it if they need to do something custom.

Contributor


I agree that it is quite repetitive. What exactly do you want to upstream? The forward and backward methods? I will be adding instances of HookBaseCollector in the near future that will not be of this form, but maybe we could have a generic GradientCollector class that is then further inherited?
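
One possible shape for that generic base, purely illustrative (hypothetical class and method names; hook registration and the actual collector API are omitted, and a 3-D [batch, seq, dim] activation is assumed):

import torch
from torch import nn

class GradientCollectorBase:
    """Shared forward/backward hook logic; subclasses decide what to do with the grads."""

    def forward_hook(self, module: nn.Module, args: tuple, output):
        # Cache the input activations needed to reconstruct per-example gradients.
        self._inputs = args[0].detach()

    def backward_hook(self, module: nn.Module, grad_input: tuple, grad_output: tuple):
        grad_out = grad_output[0]
        # Contract over the sequence dimension only, keeping per-example gradients.
        weight_grad = torch.einsum("bso,bsi->boi", grad_out, self._inputs)
        self.process_grads(module, weight_grad, grad_out.sum(dim=1))

    def process_grads(self, module: nn.Module, weight_grad, bias_grad):
        # Normalizer- or projection-specific handling lives in subclasses.
        raise NotImplementedError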
