diff --git a/CHANGELOG.md b/CHANGELOG.md
index 76bbad54..71439b51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ changes that do not affect the user.
 
 ### Added
 
+- Added the function `torchjd.autojac.jac` to compute the Jacobian of some outputs with respect to
+  some inputs, without doing any aggregation. Its interface is very similar to
+  `torch.autograd.grad`.
 - Added `__all__` in the `__init__.py` of packages. This should prevent PyLance from triggering
   warnings when importing from `torchjd`.
 
 ## [0.8.0] - 2025-11-13
diff --git a/docs/source/docs/autojac/index.rst b/docs/source/docs/autojac/index.rst
index 4ca478cf..bb3bedb6 100644
--- a/docs/source/docs/autojac/index.rst
+++ b/docs/source/docs/autojac/index.rst
@@ -10,3 +10,4 @@ autojac
 
    backward.rst
    mtl_backward.rst
+   jac.rst
diff --git a/docs/source/docs/autojac/jac.rst b/docs/source/docs/autojac/jac.rst
new file mode 100644
index 00000000..20db9f32
--- /dev/null
+++ b/docs/source/docs/autojac/jac.rst
@@ -0,0 +1,6 @@
+:hide-toc:
+
+jac
+===
+
+.. autofunction:: torchjd.autojac.jac
diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/autojac/__init__.py
index 846c062c..37f3e264 100644
--- a/src/torchjd/autojac/__init__.py
+++ b/src/torchjd/autojac/__init__.py
@@ -6,6 +6,7 @@
 """
 
 from ._backward import backward
+from ._jac import jac
 from ._mtl_backward import mtl_backward
 
-__all__ = ["backward", "mtl_backward"]
+__all__ = ["backward", "jac", "mtl_backward"]
diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py
new file mode 100644
index 00000000..7fa7e416
--- /dev/null
+++ b/src/torchjd/autojac/_jac.py
@@ -0,0 +1,114 @@
+from collections.abc import Sequence
+from typing import Iterable
+
+from torch import Tensor
+
+from torchjd.autojac._transform._base import Transform
+from torchjd.autojac._transform._diagonalize import Diagonalize
+from torchjd.autojac._transform._init import Init
+from torchjd.autojac._transform._jac import Jac
+from torchjd.autojac._transform._ordered_set import OrderedSet
+from torchjd.autojac._utils import (
+    as_checked_ordered_set,
+    check_optional_positive_chunk_size,
+    get_leaf_tensors,
+)
+
+
+def jac(
+    outputs: Sequence[Tensor] | Tensor,
+    inputs: Iterable[Tensor] | None = None,
+    retain_graph: bool = False,
+    parallel_chunk_size: int | None = None,
+) -> tuple[Tensor, ...]:
+    r"""
+    Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
+    result as a tuple, with one element per input tensor.
+
+    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
+        matrices will have one row for each value of each of these tensors.
+    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
+        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
+        that were used to compute the ``outputs`` parameter.
+    :param retain_graph: If ``False``, the graph used to compute the Jacobian will be freed.
+        Defaults to ``False``.
+    :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
+        backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in
+        parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
+        larger value results in faster differentiation, but also higher memory usage. Defaults to
+        ``None``.
+
+    .. admonition::
+        Example
+
+        The following example shows how to use ``jac``.
+
+        >>> import torch
+        >>>
+        >>> from torchjd.autojac import jac
+        >>>
+        >>> param = torch.tensor([1., 2.], requires_grad=True)
+        >>> # Compute arbitrary quantities that are functions of param
+        >>> y1 = torch.tensor([-1., 1.]) @ param
+        >>> y2 = (param ** 2).sum()
+        >>>
+        >>> jacobians = jac([y1, y2], [param])
+        >>>
+        >>> jacobians
+        (tensor([[-1.,  1.],
+                [ 2.,  4.]]),)
+
+        The returned tuple contains a single tensor (because there is a single input), which is
+        the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
+
+    .. warning::
+        To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
+        limitations: `it does not work on the output of compiled functions
+        `_, `when some tensors have
+        `_ ``retains_grad=True`` or `when using an
+        RNN on CUDA `_, for instance. If you
+        experience issues with ``jac``, try to use ``parallel_chunk_size=1`` to avoid relying on
+        ``torch.vmap``.
+    """
+
+    check_optional_positive_chunk_size(parallel_chunk_size)
+    outputs_ = as_checked_ordered_set(outputs, "outputs")
+
+    if inputs is None:
+        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
+    else:
+        inputs_ = OrderedSet(inputs)
+
+    if len(outputs_) == 0:
+        raise ValueError("`outputs` cannot be empty")
+
+    if len(inputs_) == 0:
+        raise ValueError("`inputs` cannot be empty")
+
+    jac_transform = _create_transform(
+        outputs=outputs_,
+        inputs=inputs_,
+        retain_graph=retain_graph,
+        parallel_chunk_size=parallel_chunk_size,
+    )
+
+    result = jac_transform({})
+    return tuple(val for val in result.values())
+
+
+def _create_transform(
+    outputs: OrderedSet[Tensor],
+    inputs: OrderedSet[Tensor],
+    retain_graph: bool,
+    parallel_chunk_size: int | None,
+) -> Transform:
+    # Transform that creates gradient outputs containing only ones.
+    init = Init(outputs)
+
+    # Transform that turns the gradients into Jacobians.
+    diag = Diagonalize(outputs)
+
+    # Transform that computes the required Jacobians.
+    jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph)
+
+    return jac << diag << init
diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py
new file mode 100644
index 00000000..c741117d
--- /dev/null
+++ b/tests/doc/test_jac.py
@@ -0,0 +1,20 @@
+"""
+This file contains the test of the jac usage example, with a verification of the value of the obtained jacobians tuple.
+""" + +from torch.testing import assert_close + + +def test_jac(): + import torch + + from torchjd.autojac import jac + + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + jacobians = jac([y1, y2], [param]) + + assert len(jacobians) == 1 + assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py new file mode 100644 index 00000000..92aadfcd --- /dev/null +++ b/tests/unit/autojac/test_jac.py @@ -0,0 +1,44 @@ +from utils.tensors import tensor_ + +from torchjd.autojac import jac +from torchjd.autojac._jac import _create_transform +from torchjd.autojac._transform import OrderedSet + + +def test_check_create_transform(): + """Tests that _create_transform creates a valid Transform.""" + + a1 = tensor_([1.0, 2.0], requires_grad=True) + a2 = tensor_([3.0, 4.0], requires_grad=True) + + y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + a2.norm() + + transform = _create_transform( + outputs=OrderedSet([y1, y2]), + inputs=OrderedSet([a1, a2]), + retain_graph=False, + parallel_chunk_size=None, + ) + + output_keys = transform.check_keys(set()) + assert output_keys == {a1, a2} + + +def test_jac(): + """Tests that jac works.""" + + a1 = tensor_([1.0, 2.0], requires_grad=True) + a2 = tensor_([3.0, 4.0], requires_grad=True) + inputs = [a1, a2] + + y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + a2.norm() + outputs = [y1, y2] + + jacobians = jac(outputs, inputs) + + assert len(jacobians) == len([a1, a2]) + for jacobian, a in zip(jacobians, [a1, a2]): + assert jacobian.shape[0] == len([y1, y2]) + assert jacobian.shape[1:] == a.shape