-
Notifications
You must be signed in to change notification settings - Fork 12
feat: Add autojac.jac #505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ValerianRey
wants to merge
2
commits into
main
Choose a base branch
from
add-autojac-jac
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,3 +10,4 @@ autojac | |
|
|
||
| backward.rst | ||
| mtl_backward.rst | ||
| jac.rst | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| :hide-toc: | ||
|
|
||
| jac | ||
| === | ||
|
|
||
| .. autofunction:: torchjd.autojac.jac |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| from collections.abc import Sequence | ||
| from typing import Iterable | ||
|
|
||
| from torch import Tensor | ||
|
|
||
| from torchjd.autojac._transform._base import Transform | ||
| from torchjd.autojac._transform._diagonalize import Diagonalize | ||
| from torchjd.autojac._transform._init import Init | ||
| from torchjd.autojac._transform._jac import Jac | ||
| from torchjd.autojac._transform._ordered_set import OrderedSet | ||
| from torchjd.autojac._utils import ( | ||
| as_checked_ordered_set, | ||
| check_optional_positive_chunk_size, | ||
| get_leaf_tensors, | ||
| ) | ||
|
|
||
|
|
||
| def jac( | ||
| outputs: Sequence[Tensor] | Tensor, | ||
| inputs: Iterable[Tensor] | None = None, | ||
| retain_graph: bool = False, | ||
| parallel_chunk_size: int | None = None, | ||
| ) -> tuple[Tensor, ...]: | ||
| r""" | ||
| Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the | ||
| result as a tuple, with one element per input tensor. | ||
|
|
||
| :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian | ||
| matrices will have one row for each value of each of these tensors. | ||
| :param inputs: The tensors with respect to which the Jacobian must be computed. These must have | ||
| their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors | ||
| that were used to compute the ``outputs`` parameter. | ||
| :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to | ||
| ``False``. | ||
| :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the | ||
| backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in | ||
| parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A | ||
| larger value results in faster differentiation, but also higher memory usage. Defaults to | ||
| ``None``. | ||
|
|
||
| .. admonition:: | ||
| Example | ||
|
|
||
| The following example shows how to use ``jac``. | ||
|
|
||
| >>> import torch | ||
| >>> | ||
| >>> from torchjd.autojac import jac | ||
| >>> | ||
| >>> param = torch.tensor([1., 2.], requires_grad=True) | ||
| >>> # Compute arbitrary quantities that are function of param | ||
| >>> y1 = torch.tensor([-1., 1.]) @ param | ||
| >>> y2 = (param ** 2).sum() | ||
| >>> | ||
| >>> jacobians = jac([y1, y2], [param]) | ||
| >>> | ||
| >>> jacobians | ||
| (tensor([-1., 1.], | ||
| [ 2., 4.]]),) | ||
|
|
||
| The returned tuple contains a single tensor (because there is a single param), that is the | ||
| Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. | ||
|
|
||
| .. warning:: | ||
| To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some | ||
| limitations: `it does not work on the output of compiled functions | ||
| <https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have | ||
| <https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an | ||
| RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you | ||
| experience issues with ``backward`` try to use ``parallel_chunk_size=1`` to avoid relying on | ||
| ``torch.vmap``. | ||
| """ | ||
|
|
||
| check_optional_positive_chunk_size(parallel_chunk_size) | ||
| outputs_ = as_checked_ordered_set(outputs, "outputs") | ||
|
|
||
| if inputs is None: | ||
| inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set()) | ||
| else: | ||
| inputs_ = OrderedSet(inputs) | ||
|
|
||
| if len(outputs_) == 0: | ||
| raise ValueError("`outputs` cannot be empty") | ||
|
|
||
| if len(inputs_) == 0: | ||
| raise ValueError("`inputs` cannot be empty") | ||
|
|
||
| jac_transform = _create_transform( | ||
| outputs=outputs_, | ||
| inputs=inputs_, | ||
| retain_graph=retain_graph, | ||
| parallel_chunk_size=parallel_chunk_size, | ||
| ) | ||
|
|
||
| result = jac_transform({}) | ||
| return tuple(val for val in result.values()) | ||
|
|
||
|
|
||
| def _create_transform( | ||
| outputs: OrderedSet[Tensor], | ||
| inputs: OrderedSet[Tensor], | ||
| retain_graph: bool, | ||
| parallel_chunk_size: int | None, | ||
| ) -> Transform: | ||
| # Transform that creates gradient outputs containing only ones. | ||
| init = Init(outputs) | ||
|
|
||
| # Transform that turns the gradients into Jacobians. | ||
| diag = Diagonalize(outputs) | ||
|
|
||
| # Transform that computes the required Jacobians. | ||
| jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) | ||
|
|
||
| return jac << diag << init | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| """ | ||
| This file contains the test of the jac usage example, with a verification of the value of the obtained jacobians tuple. | ||
| """ | ||
|
|
||
| from torch.testing import assert_close | ||
|
|
||
|
|
||
| def test_jac(): | ||
| import torch | ||
|
|
||
| from torchjd.autojac import jac | ||
|
|
||
| param = torch.tensor([1.0, 2.0], requires_grad=True) | ||
| # Compute arbitrary quantities that are function of param | ||
| y1 = torch.tensor([-1.0, 1.0]) @ param | ||
| y2 = (param**2).sum() | ||
| jacobians = jac([y1, y2], [param]) | ||
|
|
||
| assert len(jacobians) == 1 | ||
| assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| from utils.tensors import tensor_ | ||
|
|
||
| from torchjd.autojac import jac | ||
| from torchjd.autojac._jac import _create_transform | ||
| from torchjd.autojac._transform import OrderedSet | ||
|
|
||
|
|
||
| def test_check_create_transform(): | ||
| """Tests that _create_transform creates a valid Transform.""" | ||
|
|
||
| a1 = tensor_([1.0, 2.0], requires_grad=True) | ||
| a2 = tensor_([3.0, 4.0], requires_grad=True) | ||
|
|
||
| y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() | ||
| y2 = (a1**2).sum() + a2.norm() | ||
|
|
||
| transform = _create_transform( | ||
| outputs=OrderedSet([y1, y2]), | ||
| inputs=OrderedSet([a1, a2]), | ||
| retain_graph=False, | ||
| parallel_chunk_size=None, | ||
| ) | ||
|
|
||
| output_keys = transform.check_keys(set()) | ||
| assert output_keys == {a1, a2} | ||
|
|
||
|
|
||
| def test_jac(): | ||
| """Tests that jac works.""" | ||
|
|
||
| a1 = tensor_([1.0, 2.0], requires_grad=True) | ||
| a2 = tensor_([3.0, 4.0], requires_grad=True) | ||
| inputs = [a1, a2] | ||
|
|
||
| y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() | ||
| y2 = (a1**2).sum() + a2.norm() | ||
| outputs = [y1, y2] | ||
|
|
||
| jacobians = jac(outputs, inputs) | ||
|
|
||
| assert len(jacobians) == len([a1, a2]) | ||
| for jacobian, a in zip(jacobians, [a1, a2]): | ||
| assert jacobian.shape[0] == len([y1, y2]) | ||
| assert jacobian.shape[1:] == a.shape |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.