Fix bias gradient computations #76
Conversation
This is fabulous, thank you!! 🙏 Interested to hear what Nora thinks but I reckon exposing second moments for bias through the normalizer would be great
Running pre-commit should add formatting on commit, let me know if that doesn't work for some reason
Oh yeah, it was a problem with the ruff linter. It doesn't fix line length errors (leaves that to the formatter). Will add black back.
Force-pushed from 0a1cdb2 to 31e5008.
@LouisYRYJ if we merge this in the next few days will it interfere with your big PR? @baberabb we currently can't merge this because it breaks the build
Removed the workflow files! Do you want me to rebase this on the other PR branch?
Louis' PR just merged!! If you can rebase this on main we should be able to merge it too 🙏 🙏 🚀 TODO: I'll do another once-over to remember where we're at with the normalizers too
    bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
)
def scale_by_lr(self, lr: float | Tensor) -> None:
I added this method, but let me know if not necessary. Also it does in-place ops, mirroring normalize, but maybe new tensors would be better?
I don't see the advantage of new tensors. If they are in-place operations I would consider just calling the function, i.e.
self.row.mul_(lr_sqrt)
instead of doing
self.row = self.row.mul_(lr_sqrt)
Otherwise one may think these are not in-place?
I was waffling between the two, hence this cursed syntax. I'd go with new tensors, since these are initially references (into the optimizer state) and it's easy to forget to replace them before calling.
can we maybe choose only one then?
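For concreteness, here is a minimal sketch of the two options being discussed, assuming the normalizer stores `row`/`col` second-moment tensors taken from the optimizer state (class and method names here are illustrative, not the repo's exact API):

```python
from torch import Tensor


class _FactoredMoments:
    """Toy stand-in for the normalizer's factored second moments."""

    def __init__(self, row: Tensor, col: Tensor):
        self.row = row  # may still be a reference into the optimizer state
        self.col = col

    def scale_by_lr_inplace(self, lr: float | Tensor) -> None:
        # In-place variant: mutates the stored tensors, which also mutates
        # the optimizer state if `row`/`col` are references to it.
        lr_sqrt = lr**0.5 if isinstance(lr, float) else lr.sqrt()
        self.row.mul_(lr_sqrt)
        self.col.mul_(lr_sqrt)

    def scale_by_lr_copy(self, lr: float | Tensor) -> None:
        # Out-of-place variant: rebinds to fresh tensors, leaving the original
        # optimizer state untouched even if `row`/`col` were references.
        lr_sqrt = lr**0.5 if isinstance(lr, float) else lr.sqrt()
        self.row = self.row * lr_sqrt
        self.col = self.col * lr_sqrt
```

The `self.row = self.row.mul_(lr_sqrt)` form in the diff does both at once (it mutates and rebinds to the same object), which is what reads as "cursed" above.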
i = i + 1
setattr(module, LayerAdapter.in_attr(module), i)
if p is not None:
    # Only project if no bias (bias requires full gradient to be materialized)
I am confused by this. You can project with bias if you have no normalizer, right?
yup. added it back in
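A small sketch of the condition as resolved in this thread: projection is fine whenever there is no normalizer, and also whenever there is no bias; only the combination of a normalizer and a bias forces the full gradient to be materialized first. The helper name and arguments are hypothetical, not the collector's actual code.

```python
def should_project_early(has_bias: bool, normalizer) -> bool:
    # No normalizer: the projected gradient is all we need, bias or not.
    if normalizer is None:
        return True
    # With a normalizer, a bias means the full gradient has to be
    # materialized before normalization, so skip the early projection.
    return not has_bias
```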
if isinstance(normalizer, AdafactorNormalizer):
    bias_grad = None

match normalizer:
These are quite repetitive, but I thought it would be more readable and less error-prone if the logic of each case is kept separate.
Also, what are your thoughts on upstreaming this to HookCollectorBase? All the other classes seem to be using the same logic, and people can always overload it if they need to do something custom.
I agree that it is quite repetitive. What exactly do you want to upstream? The forward and backward methods? I will be adding instances of HookCollectorBase in the near future that will not be of this form. But maybe we could have a generic GradientCollector that is then further inherited from?
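One possible shape for that "generic GradientCollector" idea, purely as a sketch: the per-normalizer match lives once in a shared subclass of HookCollectorBase, and concrete collectors override only the handlers they care about. The method names and dispatch point are assumptions, and imports of the repo's classes are omitted.

```python
class GradientCollectorBase(HookCollectorBase):
    """Shared per-normalizer dispatch; subclasses override only the handlers."""

    def process_grad(self, module, grad, normalizer):
        match normalizer:
            case AdafactorNormalizer():
                return self.handle_adafactor(module, grad, normalizer)
            case AdamNormalizer():
                return self.handle_adam(module, grad, normalizer)
            case None:
                return self.handle_unnormalized(module, grad)

    # Default implementations could live here; collectors that need
    # something custom overload the relevant handler.
    def handle_adafactor(self, module, grad, normalizer): ...
    def handle_adam(self, module, grad, normalizer): ...
    def handle_unnormalized(self, module, grad): ...
```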
Couple of bug fixes to do with bias:

- Bias gradients are now summed over both the batch and sequence dimensions (dim=(0, 1)), rather than just the sequence (dim=1).
- Fixed normalization of gradients that include the bias column (previously a [N, O, I+1] / [O, I] division error).

The biases are currently concatenated raw, as I wasn't sure the best way to handle them. More in comment.

update:

- Added a bias_avg_sq field to AdafactorNormalizer and AdamNormalizer to keep track of the bias second moments, so bias normalization can be handled separately from the weight gradients in AdafactorNormalizer.normalize_(): the bias column is split from G before weight processing.
- Modified GradientCollectorCallback (with help from claude).
- Added a scale_by_lr(lr) method to AdafactorNormalizer (also fixes a bug where optimizer state tensors were being modified in-place).
- Updated test_optimizer_state_extraction.

Also added some unit tests. #75 should probably be merged before this. Someone better at linear algebra than me should probably have a look at this as well (a sketch of the bias handling follows below).
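To make the shape bookkeeping concrete, here is a hedged sketch of the bias-aware normalization described above. It follows the [N, O, I+1] layout from the description (last column = bias) and uses a standard Adafactor-style factored denominator; function name, epsilon handling, and the exact factoring are illustrative, not the repo's actual normalize_() implementation.

```python
import torch
from torch import Tensor


def normalize_with_bias(
    g: Tensor,            # [N, O, I+1] per-example grads, last column is the bias
    row_avg_sq: Tensor,   # [O] Adafactor row second moments for the weight block
    col_avg_sq: Tensor,   # [I] Adafactor column second moments for the weight block
    bias_avg_sq: Tensor,  # [O] second moments for the bias
    eps: float = 1e-30,
) -> Tensor:
    # Split the bias column off *before* touching the weights, so the [O, I]
    # denominator below never meets an [O, I+1] gradient (the division error
    # mentioned in the description).
    weight_g, bias_g = g[..., :-1], g[..., -1]        # [N, O, I], [N, O]

    # The bias gets its own second-moment normalization via bias_avg_sq.
    bias_g = bias_g / (bias_avg_sq + eps).sqrt()

    # Factored Adafactor-style denominator for the weight block.
    denom = (row_avg_sq[:, None] * col_avg_sq[None, :] / row_avg_sq.sum() + eps).sqrt()
    weight_g = weight_g / denom                       # [O, I] broadcasts over N

    # Re-append the bias column so downstream code sees [N, O, I+1] again.
    return torch.cat([weight_g, bias_g[..., None]], dim=-1)
```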