Fix bias gradient computations #76
```diff
@@ -62,14 +62,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
 class AdafactorNormalizer(Normalizer):
     """
     Row and column sums of second moments of gradients for a matrix-valued parameter.
 
     Args:
         row: Row statistics [O]
         col: Column statistics [I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     row: Tensor  # shape [O]
     col: Tensor  # shape [I]
+    bias_avg_sq: Tensor | None = None  # shape [O]
 
     def __post_init__(self):
         assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
         assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+        if self.bias_avg_sq is not None:
+            assert (
+                self.bias_avg_sq.ndim == 1
+            ), f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
 
     @torch.compile
     def normalize_(
```
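For orientation, here is a small sketch (not part of the PR) of the statistics the dataclass above holds: factored second moments for a weight matrix of shape `[O, I]` plus elementwise second moments for its bias. The variable names and the accumulation scheme (a single squared gradient rather than a running average) are simplifying assumptions.

```python
import torch

torch.manual_seed(0)
O, I = 16, 32
weight_grad = torch.randn(O, I)   # gradient of a weight matrix, shape [O, I]
bias_grad = torch.randn(O)        # gradient of the corresponding bias, shape [O]

grad_sq = weight_grad**2
row = grad_sq.mean(dim=1)         # row statistics, shape [O]
col = grad_sq.mean(dim=0)         # column statistics, shape [I]
bias_avg_sq = bias_grad**2        # bias second moments, shape [O]

assert row.shape == (O,) and col.shape == (I,) and bias_avg_sq.shape == (O,)
```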
```diff
@@ -114,22 +124,40 @@ def to_adam(self) -> "AdamNormalizer":
         """
         Convert this Adafactor normalizer to an Adam normalizer by materializing the
         rank-one second moment matrix.
+
+        Preserves bias_avg_sq if present.
         """
         # Compute the second moment matrix as a square matrix of shape [O, I]
         # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
         # add it outside the square root. This could cause infs though if there are
         # any exactly zero rows or columns, so we should be careful.
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
+        return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
+
+    def scale_by_lr(self, lr: float | Tensor) -> None:
+        """Scale normalizer by learning rate.
+
+        Factorized dimensions (row, col) are scaled by lr.
+        Bias is scaled by lr**2.
+        """
+        lr_sqrt = lr**0.5
+        self.row.mul_(lr_sqrt)
+        self.col.mul_(lr_sqrt)
+        self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None
 
 
 @dataclass
 class AdamNormalizer(Normalizer):
     """
     Contains the second moments of the gradients.
 
     Args:
         avg_sq: Second moments for weights [O, I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     avg_sq: Tensor
+    bias_avg_sq: Tensor | None = None
 
     @torch.compile
     def normalize_(
```
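As a sanity check on the materialization formula above (this snippet is illustrative, not part of the PR's diff): when the second-moment matrix is exactly rank one, the row/column means computed by `to_adafactor` further down reconstruct it exactly through `torch.outer(row, col) / row.mean()`.

```python
import torch

torch.manual_seed(0)
O, I = 8, 5
r = torch.rand(O) + 0.1            # keep everything strictly positive
c = torch.rand(I) + 0.1
true_avg_sq = torch.outer(r, c)    # rank-one second moments, shape [O, I]

# Factored statistics, matching the to_adafactor formulas in the diff
row = true_avg_sq.mean(dim=1)      # shape [O]
col = true_avg_sq.mean(dim=0)      # shape [I]

# Materialization, matching the to_adam formula in the diff
reconstructed = torch.outer(row, col) / row.mean()
assert torch.allclose(reconstructed, true_avg_sq, atol=1e-6)
```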
```diff
@@ -147,6 +175,8 @@ def to_adafactor(self) -> AdafactorNormalizer:
         Convert this Adam normalizer to an Adafactor normalizer, minimizing the
         I-divergence (generalized Kullback-Leibler divergence) between the original
         and the factored second moments.
+
+        Preserves bias_avg_sq if present.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
         assert (
```
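The I-divergence claim in the docstring can be spot-checked numerically. The snippet below is illustrative only; it uses the standard generalized KL divergence for nonnegative matrices, `D(A, B) = sum(A * log(A / B) - A + B)`, and checks that the factored reconstruction scores no worse than a rescaled (hence suboptimal) rank-one candidate. This is consistent with, though of course not a proof of, the minimization claim.

```python
import torch

def i_div(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Generalized KL (I-) divergence between positive matrices a and b.
    return (a * torch.log(a / b) - a + b).sum()

torch.manual_seed(0)
avg_sq = torch.rand(6, 4) + 0.1                    # a generic positive matrix
row, col = avg_sq.mean(dim=1), avg_sq.mean(dim=0)  # factored statistics from the diff
factored = torch.outer(row, col) / row.mean()      # materialized rank-one approximation

assert i_div(avg_sq, factored) <= i_div(avg_sq, 1.1 * factored)
```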
```diff
@@ -157,8 +187,17 @@ def to_adafactor(self) -> AdafactorNormalizer:
         return AdafactorNormalizer(
             row=self.avg_sq.mean(dim=1),  # shape [O]
             col=self.avg_sq.mean(dim=0),  # shape [I]
+            bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
         )
+
+    def scale_by_lr(self, lr: float | Tensor) -> None:
```
A review thread on the new scale_by_lr method:

Author: I added this method, but let me know if it's not necessary. Also it does in-place ops, mirroring

Contributor: I don't see the advantage of new tensors. If they are in-place operations I would consider just calling the function, i.e. instead of doing

Otherwise one may think these are not in-place?

Author: I was waffling between the two, hence this cursed syntax. New tensors, since these are references initially, and it's easy to forget to replace them before calling.

Contributor: Can we maybe choose only one then?
The hunk continues below the thread:

```diff
+        """Scale normalizer to incorporate learning rate.
+
+        Both avg_sq and bias_avg_sq are divided by lr².
+        """
+        self.avg_sq.mul_(lr)
+        self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None
 
 
 @dataclass
 class GradientProcessor:
```
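Since the thread above turns on whether the one-line conditional expression reads as an in-place update, here is a minimal stand-alone sketch of the equivalent statement form; the helper name `scale_bias_` is hypothetical and not part of the PR.

```python
import torch

def scale_bias_(bias_avg_sq: torch.Tensor | None, lr: float) -> None:
    # Statement form of the update written in the diff as
    #     self.bias_avg_sq.mul_(lr) if self.bias_avg_sq is not None else None
    # mul_ mutates the tensor in place, so no reassignment is needed.
    if bias_avg_sq is not None:
        bias_avg_sq.mul_(lr)
```

Both forms mutate the existing tensor via `mul_`; the difference is purely stylistic.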
A second review thread, on the repeated hook logic:

Author: These are quite repetitive, but I thought it would be more readable and less error-prone if the logic of each case is separate. Also, what are your thoughts on upstreaming this to HookCollectorBase? All the other classes seem to be using the same logic, and people can always overload it if they need to do something custom.

Contributor: I agree that it is quite repetitive. What exactly do you want to upstream? The forward and backward methods? I will be adding instances of HookBaseCollector in the near future that will not be of this form. But maybe we could have a generic GradientCollector that is then further inherited?
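To make the upstreaming idea concrete, here is one possible shape for the shared logic, sketched under the assumption that collectors register paired forward/backward hooks on a module. The name GradientCollectorBase, the context-manager interface, and the method names are illustrative assumptions, not the repo's existing API.

```python
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle

class GradientCollectorBase:
    """Hypothetical sketch of a shared base for hook-based collectors.

    Subclasses override the two hook methods; registration and cleanup
    live in one place instead of being repeated per collector.
    """

    def __init__(self, module: nn.Module):
        self.module = module
        self._handles: list[RemovableHandle] = []

    def _forward_hook(self, module, inputs, output):
        raise NotImplementedError  # record forward statistics here

    def _backward_hook(self, module, grad_input, grad_output):
        raise NotImplementedError  # record gradient statistics here

    def __enter__(self):
        self._handles.append(self.module.register_forward_hook(self._forward_hook))
        self._handles.append(self.module.register_full_backward_hook(self._backward_hook))
        return self

    def __exit__(self, *exc):
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
```

A concrete collector would then subclass this and implement only the two hook bodies.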