feat(autojac): Use jac_to_grad to aggregate .jac fields #510
Merged
37 commits
99b4260 feat: Use .jac fields (ValerianRey)
7dda66f Delete jac field instead of setting to None (ValerianRey)
e08e45f [WIP] Fix jac undefined errors (ValerianRey)
0bcf1c0 Merge branch 'main' into revamp-interface (ValerianRey)
57e5c6d Add check that the number of rows of the jacobians is consistant (ValerianRey)
6ea5983 Add check of jac shape before assigning to .jac (ValerianRey)
b7b3a75 Add changelog entry (ValerianRey)
28cf701 Move check of jacobian shape outside of if/else (ValerianRey)
9a2a0ec Improve docstring of AccumulateJac and AccumulateGrad (ValerianRey)
9e855e2 Merge branch 'main' into revamp-interface (ValerianRey)
382e9d3 Fix tests (ValerianRey)
7e95c37 Simplify a test (ValerianRey)
1876351 Add unit tests for AccumulateJac (ValerianRey)
4f24e39 Refactor accumulate tests to use loops and assert helpers (ValerianRey)
b1aaee9 Move newly added functions (ValerianRey)
ce9231b Rename retain_jacs to retain_jac (ValerianRey)
c353713 Rename params to tensors in jac_to_grad (ValerianRey)
5cf8c1c Ad jac_to_grad tests (ValerianRey)
e93f2d9 Remove duplicated optimizer.zero_grad() lines (ValerianRey)
2fb6856 Fix formulation about freeing jacs (ValerianRey)
57fe5b4 Add doc entry for jac_to_grad and usage example (ValerianRey)
0a1fc21 Add comments in jac_to_grad example (ValerianRey)
f1ee074 Fix docstring of test_backward.py (ValerianRey)
8b3d447 Fix formatting in backward docstring (ValerianRey)
87b66f8 Fix comment in accumulate_jacs that applied to accumulate_grads (ValerianRey)
1394395 Fix error message in _check_expects_grad (ValerianRey)
16349a0 Fix wrong import in basic_usage.rst (ValerianRey)
0a8cc62 Add explanation about how jac_to_grad works in jac_to_grad's docstring (ValerianRey)
c13a75b Improve description of parameters in jac_to_grad (ValerianRey)
674f6ad Improve error message and usage example of jac_to_grad (ValerianRey)
8a0fb0e Make _disunite_gradient use less memory (ValerianRey)
0e8add2 Free .jacs earlier to divide by two peak memory (ValerianRey)
430a8a2 Use Tensor.split in _disunit_gradient (ValerianRey)
f0fe529 Add kwargs to assert_jac_close and assert_grad_close (ValerianRey)
cff6d8e Rename expected_jacobian to J in some test (ValerianRey)
84bd552 Move asserts to tests/utils and use them in doc tests (ValerianRey)
4bb561d Rename test and update docstring to match its changes (ValerianRey)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,3 +10,4 @@ autojac | |
|
|
||
| backward.rst | ||
| mtl_backward.rst | ||
| jac_to_grad.rst | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| :hide-toc: | ||
|
|
||
| jac_to_grad | ||
| =========== | ||
|
|
||
| .. autofunction:: torchjd.autojac.jac_to_grad |
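The PR title and commit messages (aggregating `.jac` fields, `_disunite_gradient`, `Tensor.split`, freeing `.jac` early) suggest the general shape of the operation: flatten each tensor's Jacobian, join them into one matrix, reduce the matrix to a single gradient vector, then split that vector back into per-tensor `.grad` fields. The sketch below illustrates that idea only. `mean_jac_to_grad` is a hypothetical helper written for this example, it is not the `torchjd.autojac.jac_to_grad` implementation, and averaging the rows is an assumed stand-in for whatever aggregator the library actually applies.

```python
import torch
from torch import Tensor


def mean_jac_to_grad(tensors: list[Tensor]) -> None:
    """Hypothetical stand-in: turn each tensor's .jac into a .grad by averaging rows."""
    # Flatten each Jacobian from [m, *t.shape] to [m, t.numel()] and join the columns.
    jacobian_matrix = torch.cat(
        [t.jac.reshape(t.jac.shape[0], -1) for t in tensors], dim=1
    )
    # Drop the per-tensor Jacobians once they have been joined.
    for t in tensors:
        del t.jac
    # Assumed aggregator: the mean over the m rows.
    gradient_vector = jacobian_matrix.mean(dim=0)
    # Split the flat vector back into one chunk per tensor and restore the shapes.
    chunks = gradient_vector.split([t.numel() for t in tensors])
    for t, chunk in zip(tensors, chunks):
        t.grad = chunk.reshape(t.shape)


a = torch.zeros(2, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
# Each Jacobian has one row per objective (m = 2 here).
a.jac = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b.jac = torch.tensor([[0.0, 0.0, 6.0], [2.0, 4.0, 0.0]])

mean_jac_to_grad([a, b])
print(a.grad, b.grad)
```

After the call, each tensor carries an ordinary `.grad` that an optimizer can consume, and the `.jac` attributes are gone.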
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| from collections.abc import Iterable | ||
| from typing import cast | ||
|
|
||
| from torch import Tensor | ||
|
|
||
|
|
||
| class TensorWithJac(Tensor): | ||
| """ | ||
| Tensor known to have a populated jac field. | ||
|
|
||
| Should not be directly instantiated, but can be used as a type hint and can be cast to. | ||
| """ | ||
|
|
||
| jac: Tensor | ||
|
|
||
|
|
||
| def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: | ||
| for param, jac in zip(params, jacobians, strict=True): | ||
| _check_expects_grad(param, field_name=".jac") | ||
| # We check that the shape is correct to be consistent with torch, which checks that the grad | ||
| # shape is correct before assigning it. | ||
| if jac.shape[1:] != param.shape: | ||
| raise RuntimeError( | ||
| f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " | ||
| f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" | ||
| " jacobian are the same size" | ||
| ) | ||
|
|
||
| if hasattr(param, "jac"): # No check for None because jac cannot be None | ||
| param_ = cast(TensorWithJac, param) | ||
| param_.jac += jac | ||
| else: | ||
| # We do not clone the value to save memory and time, so subsequent modifications of | ||
| # the value of param.jac (subsequent accumulations) will also affect the value of | ||
| # jac, and outside changes to the value of jac will also affect the value of | ||
| # param.jac. So to be safe, the values of jacobians should not be used | ||
| # anymore after being passed to this function. | ||
| # | ||
| # We do not detach from the computation graph because the value can have grad_fn | ||
| # that we want to keep track of (in case it was obtained via create_graph=True). | ||
| param.__setattr__("jac", jac) | ||
|
|
||
|
|
||
| def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: | ||
| for param, grad in zip(params, gradients, strict=True): | ||
| _check_expects_grad(param, field_name=".grad") | ||
| if hasattr(param, "grad") and param.grad is not None: | ||
| param.grad += grad | ||
| else: | ||
| param.grad = grad | ||
|
|
||
|
|
||
| def _check_expects_grad(tensor: Tensor, field_name: str) -> None: | ||
| if not _expects_grad(tensor): | ||
| raise ValueError( | ||
| f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n" | ||
| "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." | ||
| ) | ||
|
|
||
|
|
||
| def _expects_grad(tensor: Tensor) -> bool: | ||
| """ | ||
| Determines whether a Tensor expects its .grad attribute to be populated. | ||
| See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. | ||
| """ | ||
|
|
||
| return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) | ||
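The two guards in the file above can be tried in isolation. The following is a minimal sketch that copies the `_expects_grad` condition and the `jac.shape[1:] != param.shape` check into standalone helpers. The names `expects_grad` and `jac_shape_matches` are hypothetical and exist only for this illustration; they are not part of the library.

```python
import torch
from torch import Tensor


def expects_grad(tensor: Tensor) -> bool:
    # Same condition as _expects_grad in the diff above.
    return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)


def jac_shape_matches(jac: Tensor, param: Tensor) -> bool:
    # Every row of the Jacobian must have the shape of the tensor it belongs to.
    return jac.shape[1:] == param.shape


leaf = torch.ones(3, requires_grad=True)  # leaf tensor that requires grad
frozen = torch.ones(3)                    # leaf tensor that does not require grad
intermediate = leaf * 2                   # non-leaf: .grad is not kept by default
retained = leaf * 2
retained.retain_grad()                    # non-leaf that explicitly retains its grad

for name, tensor in [
    ("leaf", leaf),
    ("frozen", frozen),
    ("intermediate", intermediate),
    ("retained", retained),
]:
    print(name, expects_grad(tensor))

print(jac_shape_matches(torch.zeros(2, 3), leaf))  # rows of shape [3]
print(jac_shape_matches(torch.zeros(2, 4), leaf))  # rows of shape [4]
```

Only `leaf` and `retained` pass the first check, so only they could receive a `.jac` or `.grad` field from the accumulate functions. The second helper accepts a Jacobian of shape `[2, 3]` for a tensor of shape `[3]` and rejects one of shape `[2, 4]`.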