Merged
37 commits
99b4260
feat: Use .jac fields
ValerianRey Jan 8, 2026
7dda66f
Delete jac field instead of setting to None
ValerianRey Jan 8, 2026
e08e45f
[WIP] Fix jac undefined errors
ValerianRey Jan 8, 2026
0bcf1c0
Merge branch 'main' into revamp-interface
ValerianRey Jan 8, 2026
57e5c6d
Add check that the number of rows of the jacobians is consistent
ValerianRey Jan 8, 2026
6ea5983
Add check of jac shape before assigning to .jac
ValerianRey Jan 8, 2026
b7b3a75
Add changelog entry
ValerianRey Jan 8, 2026
28cf701
Move check of jacobian shape outside of if/else
ValerianRey Jan 8, 2026
9a2a0ec
Improve docstring of AccumulateJac and AccumulateGrad
ValerianRey Jan 8, 2026
9e855e2
Merge branch 'main' into revamp-interface
ValerianRey Jan 8, 2026
382e9d3
Fix tests
ValerianRey Jan 9, 2026
7e95c37
Simplify a test
ValerianRey Jan 9, 2026
1876351
Add unit tests for AccumulateJac
ValerianRey Jan 12, 2026
4f24e39
Refactor accumulate tests to use loops and assert helpers
ValerianRey Jan 12, 2026
b1aaee9
Move newly added functions
ValerianRey Jan 12, 2026
ce9231b
Rename retain_jacs to retain_jac
ValerianRey Jan 12, 2026
c353713
Rename params to tensors in jac_to_grad
ValerianRey Jan 12, 2026
5cf8c1c
Add jac_to_grad tests
ValerianRey Jan 12, 2026
e93f2d9
Remove duplicated optimizer.zero_grad() lines
ValerianRey Jan 12, 2026
2fb6856
Fix formulation about freeing jacs
ValerianRey Jan 12, 2026
57fe5b4
Add doc entry for jac_to_grad and usage example
ValerianRey Jan 12, 2026
0a1fc21
Add comments in jac_to_grad example
ValerianRey Jan 12, 2026
f1ee074
Fix docstring of test_backward.py
ValerianRey Jan 12, 2026
8b3d447
Fix formatting in backward docstring
ValerianRey Jan 14, 2026
87b66f8
Fix comment in accumulate_jacs that applied to accumulate_grads
ValerianRey Jan 14, 2026
1394395
Fix error message in _check_expects_grad
ValerianRey Jan 14, 2026
16349a0
Fix wrong import in basic_usage.rst
ValerianRey Jan 14, 2026
0a8cc62
Add explanation about how jac_to_grad works in jac_to_grad's docstring
ValerianRey Jan 14, 2026
c13a75b
Improve description of parameters in jac_to_grad
ValerianRey Jan 14, 2026
674f6ad
Improve error message and usage example of jac_to_grad
ValerianRey Jan 14, 2026
8a0fb0e
Make _disunite_gradient use less memory
ValerianRey Jan 14, 2026
0e8add2
Free .jacs earlier to divide peak memory by two
ValerianRey Jan 14, 2026
430a8a2
Use Tensor.split in _disunit_gradient
ValerianRey Jan 14, 2026
f0fe529
Add kwargs to assert_jac_close and assert_grad_close
ValerianRey Jan 14, 2026
cff6d8e
Rename expected_jacobian to J in some test
ValerianRey Jan 14, 2026
84bd552
Move asserts to tests/utils and use them in doc tests
ValerianRey Jan 14, 2026
4bb561d
Rename test and update docstring to match its changes
ValerianRey Jan 14, 2026
25 changes: 25 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,31 @@ changelog does not include internal changes that do not affect the user.

### Changed

- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new
function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac` fields into
`.grad` fields.
This means that users now have more control over what they do with the Jacobians (they can easily
aggregate them group by group, or even param by param, if they want; see the sketch after the
migration snippets below), but it now requires an extra line of code to perform the Jacobian
descent step. To update, please change:
```python
backward(losses, aggregator)
```
to
```python
backward(losses)
jac_to_grad(model.parameters(), aggregator)
```
and
```python
mtl_backward(losses, features, aggregator)
```
to
```python
mtl_backward(losses, features)
jac_to_grad(shared_module.parameters(), aggregator)
```
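
As an illustration of the extra flexibility mentioned above, here is a minimal sketch of
aggregating the Jacobians group by group (one `jac_to_grad` call per layer). The model, data, and
choice of `UPGrad` are hypothetical; the sketch only assumes that `backward` populates the `.jac`
fields and that `jac_to_grad` accepts any iterable of parameters whose `.jac` fields are populated,
as described in this entry.
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

# Hypothetical model and data, for illustration only.
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
optimizer = SGD(model.parameters(), lr=0.1)
loss_fn = MSELoss(reduction="none")
aggregator = UPGrad()

x = torch.randn(16, 10)
y = torch.randn(16)
losses = loss_fn(model(x).squeeze(dim=1), y)  # one loss per sample

backward(losses)  # populates the .jac field of every model parameter

# Aggregate the Jacobians group by group: one jac_to_grad call per layer.
jac_to_grad(model[0].parameters(), aggregator)
jac_to_grad(model[2].parameters(), aggregator)

optimizer.step()
optimizer.zero_grad()
```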

- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.

1 change: 1 addition & 0 deletions docs/source/docs/autojac/index.rst
@@ -10,3 +10,4 @@ autojac

backward.rst
mtl_backward.rst
jac_to_grad.rst
6 changes: 6 additions & 0 deletions docs/source/docs/autojac/jac_to_grad.rst
@@ -0,0 +1,6 @@
:hide-toc:

jac_to_grad
===========

.. autofunction:: torchjd.autojac.jac_to_grad
7 changes: 4 additions & 3 deletions docs/source/examples/amp.rst
@@ -12,15 +12,15 @@ case, the losses) should preferably be scaled with a `GradScaler
following example shows the resulting code for a multi-task learning use-case.

.. code-block:: python
:emphasize-lines: 2, 17, 27, 34-37
:emphasize-lines: 2, 17, 27, 34-35, 37-38

import torch
from torch.amp import GradScaler
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.autojac import mtl_backward, jac_to_grad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
@@ -48,7 +48,8 @@ following example shows the resulting code for a multi-task learning use-case.
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
mtl_backward(losses=scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
9 changes: 6 additions & 3 deletions docs/source/examples/basic_usage.rst
@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:

from torchjd import autojac
from torchjd.aggregation import UPGrad
from torchjd.autojac import jac_to_grad

Define the model and the optimizer, as usual:

@@ -63,10 +64,12 @@ Perform the Jacobian descent backward pass:

.. code-block:: python

autojac.backward([loss1, loss2], aggregator)
autojac.backward([loss1, loss2])
jac_to_grad(model.parameters(), aggregator)

This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
Jacobian matrix.
The first function will populate the ``.jac`` field of each model parameter with the corresponding
Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
field of the parameters. It also deletes the ``.jac`` fields to save some memory.
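
A rough sketch of what this two-step flow looks like on a tiny, hypothetical model (the shape of
`.jac`, with one row per loss, follows from the description above; the deletion of `.jac` after
aggregation is as stated in this paragraph):
```python
import torch
from torch.nn import Linear

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

model = Linear(10, 1)  # hypothetical tiny model
x = torch.randn(4, 10)
loss1 = model(x).mean()
loss2 = (model(x) ** 2).mean()

backward([loss1, loss2])
for p in model.parameters():
    assert p.jac.shape == (2, *p.shape)  # one row per loss, each row shaped like p

jac_to_grad(model.parameters(), UPGrad())
for p in model.parameters():
    assert p.grad.shape == p.shape   # aggregated gradient
    assert not hasattr(p, "jac")     # .jac was deleted to free memory
```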

Update each parameter based on its ``.grad`` field, using the ``optimizer``:

8 changes: 4 additions & 4 deletions docs/source/examples/iwrm.rst
@@ -76,14 +76,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
.. tab-item:: autojac

.. code-block:: python
:emphasize-lines: 5-6, 12, 16, 21-22
:emphasize-lines: 5-6, 12, 16, 21-23

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward
from torchjd.autojac import backward, jac_to_grad

X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16)
@@ -99,8 +99,8 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
backward(losses, aggregator)

backward(losses)
jac_to_grad(model.parameters(), aggregator)

optimizer.step()
optimizer.zero_grad()
7 changes: 4 additions & 3 deletions docs/source/examples/lightning_integration.rst
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
<../docs/autojac/mtl_backward>` at each training iteration.

.. code-block:: python
:emphasize-lines: 9-10, 18, 31
:emphasize-lines: 9-10, 18, 31-32

import torch
from lightning import LightningModule, Trainer
@@ -22,7 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using
from torch.utils.data import DataLoader, TensorDataset

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.autojac import mtl_backward, jac_to_grad

class Model(LightningModule):
def __init__(self):
@@ -43,7 +43,8 @@ The following code example demonstrates a basic multi-task learning setup using
loss2 = mse_loss(output2, target2)

opt = self.optimizers()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()

5 changes: 3 changes & 2 deletions docs/source/examples/monitoring.rst
@@ -23,7 +23,7 @@ they have a negative inner product).
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.autojac import mtl_backward, jac_to_grad

def print_weights(_, __, weights: torch.Tensor) -> None:
"""Prints the extracted weights."""
@@ -63,6 +63,7 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
7 changes: 4 additions & 3 deletions docs/source/examples/mtl.rst
@@ -19,14 +19,14 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


.. code-block:: python
:emphasize-lines: 5-6, 19, 32
:emphasize-lines: 5-6, 19, 32-33

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.autojac import mtl_backward, jac_to_grad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
@@ -52,7 +52,8 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()

7 changes: 4 additions & 3 deletions docs/source/examples/rnn.rst
@@ -6,14 +6,14 @@ element of the output sequences. If the gradients of these losses are likely to
descent can be leveraged to enhance optimization.

.. code-block:: python
:emphasize-lines: 5-6, 10, 17, 19
:emphasize-lines: 5-6, 10, 17, 19-20

import torch
from torch.nn import RNN
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward
from torchjd.autojac import backward, jac_to_grad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
optimizer = SGD(rnn.parameters(), lr=0.1)
Expand All @@ -26,7 +26,8 @@ descent can be leveraged to enhance optimization.
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

backward(losses, aggregator, parallel_chunk_size=1)
backward(losses, parallel_chunk_size=1)
jac_to_grad(rnn.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()

3 changes: 2 additions & 1 deletion src/torchjd/autojac/__init__.py
@@ -6,6 +6,7 @@
"""

from ._backward import backward
from ._jac_to_grad import jac_to_grad
from ._mtl_backward import mtl_backward

__all__ = ["backward", "mtl_backward"]
__all__ = ["backward", "jac_to_grad", "mtl_backward"]
67 changes: 67 additions & 0 deletions src/torchjd/autojac/_accumulation.py
@@ -0,0 +1,67 @@
from collections.abc import Iterable
from typing import cast

from torch import Tensor


class TensorWithJac(Tensor):
"""
Tensor known to have a populated jac field.

Should not be directly instantiated, but can be used as a type hint and can be cast to.
"""

jac: Tensor


def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
for param, jac in zip(params, jacobians, strict=True):
_check_expects_grad(param, field_name=".jac")
# We check that the shape is correct to be consistent with torch, which checks that the grad
# shape is correct before assigning it.
if jac.shape[1:] != param.shape:
raise RuntimeError(
f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of "
f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the"
" jacobian are the same size"
)

if hasattr(param, "jac"): # No check for None because jac cannot be None
param_ = cast(TensorWithJac, param)
param_.jac += jac
else:
# We do not clone the value to save memory and time, so subsequent modifications of
# the value of param.jac (subsequent accumulations) will also affect the corresponding
# value passed in jacobians, and outside changes to that value will also affect
# param.jac. So to be safe, the values of jacobians should not be used anymore after
# being passed to this function.
#
# We do not detach from the computation graph because the value can have grad_fn
# that we want to keep track of (in case it was obtained via create_graph=True).
param.__setattr__("jac", jac)


def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:
for param, grad in zip(params, gradients, strict=True):
_check_expects_grad(param, field_name=".grad")
if hasattr(param, "grad") and param.grad is not None:
param.grad += grad
else:
param.grad = grad


def _check_expects_grad(tensor: Tensor, field_name: str) -> None:
if not _expects_grad(tensor):
raise ValueError(
f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n"
"`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
)


def _expects_grad(tensor: Tensor) -> bool:
"""
Determines whether a Tensor expects its .grad attribute to be populated.
See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information.
"""

return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)
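
To make the accumulation semantics described in the comments above concrete, here is a hedged
sketch (not part of the library's test suite; it imports the internal `_accumulation` module
directly, purely for illustration): the first call assigns the provided Jacobian without cloning,
and a second call accumulates in place, which also mutates the originally passed tensor.
```python
import torch

# Internal module from this PR; imported directly only for illustration.
from torchjd.autojac._accumulation import accumulate_jacs

param = torch.zeros(3, requires_grad=True)   # leaf tensor that expects a gradient
jac1 = torch.ones(2, 3)                      # 2 rows, each with the shape of param
jac2 = 2 * torch.ones(2, 3)

accumulate_jacs([param], [jac1])
assert param.jac is jac1                     # assigned without cloning

accumulate_jacs([param], [jac2])
assert torch.equal(param.jac, torch.full((2, 3), 3.0))  # accumulated in place
assert torch.equal(jac1, torch.full((2, 3), 3.0))       # jac1 itself was mutated (no clone)
```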
42 changes: 16 additions & 26 deletions src/torchjd/autojac/_backward.py
@@ -2,28 +2,23 @@

from torch import Tensor

from torchjd.aggregation import Aggregator

from ._transform import Accumulate, Aggregate, Diagonalize, Init, Jac, OrderedSet, Transform
from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors


def backward(
tensors: Sequence[Tensor] | Tensor,
aggregator: Aggregator,
inputs: Iterable[Tensor] | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
) -> None:
r"""
Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its
aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the
``inputs``.

:param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
matrices will have one row for each value of each of these tensors.
:param aggregator: Aggregator used to reduce the Jacobian into a vector.
:param inputs: The tensors with respect to which the Jacobian must be computed. These must have
Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and
accumulates them in the ``.jac`` fields of the ``inputs``.

:param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
have one row for each value of each of these tensors.
:param inputs: The tensors with respect to which the Jacobians must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``tensors`` parameter.
:param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
@@ -41,20 +36,20 @@ def backward(

>>> import torch
>>>
>>> from torchjd.aggregation import UPGrad
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are function of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2], UPGrad())
>>> backward([y1, y2])
>>>
>>> param.grad
tensor([0.5000, 2.5000])
>>> param.jac
tensor([[-1., 1.],
[ 2., 4.]])

The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of
The ``.jac`` field of ``param`` now contains the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
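
For completeness, a sketch of the natural follow-up to this example; the expected ``.grad`` value
matches the previous version of this docstring, which used ``UPGrad``, and passing a plain list of
tensors to ``jac_to_grad`` is assumed to be supported:
```python
>>> from torchjd.aggregation import UPGrad
>>> from torchjd.autojac import jac_to_grad
>>>
>>> jac_to_grad([param], UPGrad())  # aggregates param.jac into param.grad
>>> param.grad
tensor([0.5000, 2.5000])
```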

.. warning::
@@ -80,7 +75,6 @@

backward_transform = _create_transform(
tensors=tensors_,
aggregator=aggregator,
inputs=inputs_,
retain_graph=retain_graph,
parallel_chunk_size=parallel_chunk_size,
@@ -91,12 +85,11 @@

def _create_transform(
tensors: OrderedSet[Tensor],
aggregator: Aggregator,
inputs: OrderedSet[Tensor],
retain_graph: bool,
parallel_chunk_size: int | None,
) -> Transform:
"""Creates the Jacobian descent backward transform."""
"""Creates the backward transform."""

# Transform that creates gradient outputs containing only ones.
init = Init(tensors)
Expand All @@ -107,10 +100,7 @@ def _create_transform(
# Transform that computes the required Jacobians.
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

# Transform that aggregates the Jacobians.
aggregate = Aggregate(aggregator, inputs)

# Transform that accumulates the result in the .grad field of the inputs.
accumulate = Accumulate()
# Transform that accumulates the result in the .jac field of the inputs.
accumulate = AccumulateJac()

return accumulate << aggregate << jac << diag << init
return accumulate << jac << diag << init