3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ changes that do not affect the user.

### Added

- Added the function `torchjd.autojac.jac` to compute the Jacobian of some outputs with respect to
some inputs, without doing any aggregation. Its interface is very similar to
`torch.autograd.grad`.
- Added `__all__` to the `__init__.py` of packages. This should prevent Pylance from triggering warnings when importing from `torchjd`.

## [0.8.0] - 2025-11-13
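For context on the changelog entry above, here is a rough sketch of the parallel with ``torch.autograd.grad``, reusing the docstring example from the diff below; it is illustrative only, not part of the changes.

import torch

from torchjd.autojac import jac

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

# torch.autograd.grad sums the contributions into a single gradient per input.
(grad,) = torch.autograd.grad(y1 + y2, [param], retain_graph=True)  # tensor([1., 5.])

# jac instead keeps one row per output value: shape (2, 2).
(jacobian,) = jac([y1, y2], [param])  # tensor([[-1., 1.], [2., 4.]])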
1 change: 1 addition & 0 deletions docs/source/docs/autojac/index.rst
@@ -10,3 +10,4 @@ autojac

backward.rst
mtl_backward.rst
jac.rst
6 changes: 6 additions & 0 deletions docs/source/docs/autojac/jac.rst
@@ -0,0 +1,6 @@
:hide-toc:

jac
===

.. autofunction:: torchjd.autojac.jac
3 changes: 2 additions & 1 deletion src/torchjd/autojac/__init__.py
@@ -6,6 +6,7 @@
"""

from ._backward import backward
from ._jac import jac
from ._mtl_backward import mtl_backward

__all__ = ["backward", "mtl_backward"]
__all__ = ["backward", "jac", "mtl_backward"]
114 changes: 114 additions & 0 deletions src/torchjd/autojac/_jac.py
@@ -0,0 +1,114 @@
from collections.abc import Iterable, Sequence

from torch import Tensor

from torchjd.autojac._transform._base import Transform
from torchjd.autojac._transform._diagonalize import Diagonalize
from torchjd.autojac._transform._init import Init
from torchjd.autojac._transform._jac import Jac
from torchjd.autojac._transform._ordered_set import OrderedSet
from torchjd.autojac._utils import (
as_checked_ordered_set,
check_optional_positive_chunk_size,
get_leaf_tensors,
)


def jac(
outputs: Sequence[Tensor] | Tensor,
inputs: Iterable[Tensor] | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
) -> tuple[Tensor, ...]:
r"""
Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
result as a tuple, with one element per input tensor.

:param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
matrices will have one row for each value of each of these tensors.
:param inputs: The tensors with respect to which the Jacobian must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``outputs`` parameter.
:param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
``False``.
:param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
backward pass. If set to ``None``, all coordinates of ``outputs`` will be differentiated in
parallel at once. If set to ``1``, all coordinates will be differentiated sequentially. A
larger value results in faster differentiation, but also higher memory usage. Defaults to
``None``.

.. admonition:: Example

The following example shows how to use ``jac``.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> jacobians = jac([y1, y2], [param])
>>>
>>> jacobians
(tensor([[-1.,  1.],
         [ 2.,  4.]]),)

The returned tuple contains a single tensor (because there is a single input, ``param``), which
is the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.

.. warning::
To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
limitations: `it does not work on the output of compiled functions
<https://github.com/pytorch/pytorch/issues/138422>`_, `when some tensors have
<https://github.com/TorchJD/torchjd/issues/184>`_ ``retains_grad=True`` or `when using an
RNN on CUDA <https://github.com/TorchJD/torchjd/issues/220>`_, for instance. If you
experience issues with ``jac``, try setting ``parallel_chunk_size=1`` to avoid relying on
``torch.vmap``.
"""

check_optional_positive_chunk_size(parallel_chunk_size)
outputs_ = as_checked_ordered_set(outputs, "outputs")

if inputs is None:
inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
else:
inputs_ = OrderedSet(inputs)

if len(outputs_) == 0:
raise ValueError("`outputs` cannot be empty")

if len(inputs_) == 0:
raise ValueError("`inputs` cannot be empty")

jac_transform = _create_transform(
outputs=outputs_,
inputs=inputs_,
retain_graph=retain_graph,
parallel_chunk_size=parallel_chunk_size,
)

result = jac_transform({})
return tuple(result.values())


def _create_transform(
outputs: OrderedSet[Tensor],
inputs: OrderedSet[Tensor],
retain_graph: bool,
parallel_chunk_size: int | None,
) -> Transform:
# Transform that creates gradient outputs containing only ones.
init = Init(outputs)

# Transform that turns the gradients into Jacobians.
diag = Diagonalize(outputs)

# Transform that computes the required Jacobians.
jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph)

return jac << diag << init
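The warning in the docstring above suggests falling back to ``parallel_chunk_size=1`` when ``torch.vmap`` cannot be used. A minimal sketch of that fallback, reusing the same example; the result should match the default parallel path.

import torch

from torchjd.autojac import jac

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

# Differentiate the output coordinates one at a time instead of via torch.vmap.
(jacobian,) = jac([y1, y2], [param], parallel_chunk_size=1)
# Expected: tensor([[-1., 1.], [2., 4.]])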
20 changes: 20 additions & 0 deletions tests/doc/test_jac.py
@@ -0,0 +1,20 @@
"""
This file contains the test of the jac usage example, verifying the values in the returned tuple of Jacobians.
"""

from torch.testing import assert_close


def test_jac():
import torch

from torchjd.autojac import jac

param = torch.tensor([1.0, 2.0], requires_grad=True)
# Compute arbitrary quantities that are functions of param
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()
jacobians = jac([y1, y2], [param])

assert len(jacobians) == 1
assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
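The expected value in the assertion above follows directly from the two outputs: ``y1 = [-1., 1.] @ param`` contributes the constant row ``[-1., 1.]``, and ``y2 = (param ** 2).sum()`` contributes ``2 * param = [2., 4.]``. For illustration, a small cross-check of those rows with plain ``torch.autograd.grad``:

import torch
from torch.testing import assert_close

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param  # dy1/dparam = [-1., 1.]
y2 = (param**2).sum()                   # dy2/dparam = 2 * param = [2., 4.]

(row1,) = torch.autograd.grad(y1, [param], retain_graph=True)
(row2,) = torch.autograd.grad(y2, [param])

assert_close(torch.stack([row1, row2]), torch.tensor([[-1.0, 1.0], [2.0, 4.0]]))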
44 changes: 44 additions & 0 deletions tests/unit/autojac/test_jac.py
@@ -0,0 +1,44 @@
from utils.tensors import tensor_

from torchjd.autojac import jac
from torchjd.autojac._jac import _create_transform
from torchjd.autojac._transform import OrderedSet


def test_check_create_transform():
"""Tests that _create_transform creates a valid Transform."""

a1 = tensor_([1.0, 2.0], requires_grad=True)
a2 = tensor_([3.0, 4.0], requires_grad=True)

y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum() + a2.norm()

transform = _create_transform(
outputs=OrderedSet([y1, y2]),
inputs=OrderedSet([a1, a2]),
retain_graph=False,
parallel_chunk_size=None,
)

output_keys = transform.check_keys(set())
assert output_keys == {a1, a2}


def test_jac():
"""Tests that jac works."""

a1 = tensor_([1.0, 2.0], requires_grad=True)
a2 = tensor_([3.0, 4.0], requires_grad=True)
inputs = [a1, a2]

y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum() + a2.norm()
outputs = [y1, y2]

jacobians = jac(outputs, inputs)

assert len(jacobians) == len(inputs)
for jacobian, a in zip(jacobians, inputs):
assert jacobian.shape[0] == len(outputs)
assert jacobian.shape[1:] == a.shape
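The shape assertions above use scalar outputs only. Per the docstring ("one row for each value of each of these tensors"), a vector-valued output should contribute one row per coordinate. The sketch below extrapolates that reading; the exact shape is an assumption based on the docstring and the assertions above, not something this diff tests.

import torch

from torchjd.autojac import jac

param = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.stack([param.sum(), (param**2).sum(), (param**3).sum()])  # 3 output values

(jacobian,) = jac(y, [param])
assert jacobian.shape == (3, 2)  # one row per output value, trailing dims match param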