8 changes: 5 additions & 3 deletions docs/source/examples/amp.rst
@@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler
following example shows the resulting code for a multi-task learning use-case.

.. code-block:: python
:emphasize-lines: 2, 17, 27, 34, 36-38
:emphasize-lines: 2, 18, 28, 35-36, 38-39

import torch
from torch.amp import GradScaler
@@ -21,6 +21,7 @@ following example shows the resulting code for a multi-task learning use-case.

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.utils import jac_to_grad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
@@ -48,10 +49,11 @@ following example shows the resulting code for a multi-task learning use-case.
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
optimizer.zero_grad()
mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
mtl_backward(losses=scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

.. hint::
Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For
21 changes: 12 additions & 9 deletions docs/source/examples/basic_usage.rst
@@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``:

from torchjd import autojac
from torchjd.aggregation import UPGrad
from torchjd.utils import jac_to_grad

Define the model and the optimizer, as usual:

@@ -59,20 +60,16 @@ We can now compute the losses associated with each element of the batch.

The last steps are similar to gradient descent-based optimization, but using the two losses.

Reset the ``.grad`` field of each model parameter:

.. code-block:: python

optimizer.zero_grad()

Perform the Jacobian descent backward pass:

.. code-block:: python

autojac.backward([loss1, loss2], aggregator)
autojac.backward([loss1, loss2])
jac_to_grad(model.parameters(), aggregator)

This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
Jacobian matrix.
The first function will populate the ``.jac`` field of each model parameter with the corresponding
Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad``
field of the parameters. It also resets the ``.jac`` fields to ``None`` to save some memory.
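
To make the two steps concrete, here is a hypothetical inspection of a single parameter around
these two calls (a sketch only, assuming ``.jac`` and ``.grad`` are directly readable attributes
of the parameters, as described above):

.. code-block:: python

    param = next(model.parameters())

    autojac.backward([loss1, loss2])
    print(param.jac.shape)   # Jacobian of the two losses with respect to this parameter

    jac_to_grad(model.parameters(), aggregator)
    print(param.jac)         # None: reset by jac_to_grad to save memory
    print(param.grad.shape)  # aggregated result, same shape as the parameter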

Update each parameter based on its ``.grad`` field, using the ``optimizer``:

@@ -81,3 +78,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
optimizer.step()

The model's parameters have been updated!

As usual, you should now reset the ``.grad`` field of each model parameter:

.. code-block:: python

optimizer.zero_grad()
4 changes: 2 additions & 2 deletions docs/source/examples/iwmtl.rst
@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
The following example shows how to do that.

.. code-block:: python
:emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
:emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]

optimizer.zero_grad()
# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

.. note::
In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
17 changes: 10 additions & 7 deletions docs/source/examples/iwrm.rst
@@ -50,6 +50,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac




X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16)

@@ -64,26 +65,27 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
loss = loss_fn(y_hat, y) # shape: [] (scalar)
optimizer.zero_grad()
loss.backward()


optimizer.step()
optimizer.zero_grad()

In this baseline example, the update may negatively affect the loss of some elements of the
batch.
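
For instance, a hypothetical way to observe this (not part of the original example, reusing the
variables from the training loop above) is to compare the per-element squared errors before and
after a single update:

.. code-block:: python

    y_hat = model(x).squeeze(dim=1)        # shape: [16]
    loss = loss_fn(y_hat, y)               # shape: [] (scalar)
    before = ((y_hat - y) ** 2).detach()   # per-element losses before the update
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        after = (model(x).squeeze(dim=1) - y) ** 2  # per-element losses after the update
    print((after > before).any())          # can be True: some elements got worse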

.. tab-item:: autojac

.. code-block:: python
:emphasize-lines: 5-6, 12, 16, 21, 23
:emphasize-lines: 5-7, 13, 17, 22-24

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward
from torchjd.utils import jac_to_grad

X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16)
@@ -99,19 +101,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
backward(losses, aggregator)

backward(losses)
jac_to_grad(model.parameters(), aggregator)

optimizer.step()
optimizer.zero_grad()

Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
and use it to update the model such that no loss from the batch is (locally) increased.

.. tab-item:: autogram (recommended)

.. code-block:: python
:emphasize-lines: 5-6, 12, 16-17, 21, 23-25
:emphasize-lines: 5-6, 13, 17-18, 22-25

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -120,6 +122,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine


X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16)

@@ -134,11 +137,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses) # shape: [16, 16]
weights = weighting(gramian) # shape: [16]
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

Here, the per-sample gradients are never fully stored in memory, leading to large
improvements in memory usage and speed compared to autojac, in most practical cases. The
8 changes: 5 additions & 3 deletions docs/source/examples/lightning_integration.rst
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
<../docs/autojac/mtl_backward>` at each training iteration.

.. code-block:: python
:emphasize-lines: 9-10, 18, 32
:emphasize-lines: 9-11, 19, 32-33

import torch
from lightning import LightningModule, Trainer
@@ -23,6 +23,7 @@ The following code example demonstrates a basic multi-task learning setup using

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.utils import jac_to_grad

class Model(LightningModule):
def __init__(self):
@@ -43,9 +44,10 @@ The following code example demonstrates a basic multi-task learning setup using
loss2 = mse_loss(output2, target2)

opt = self.optimizers()
opt.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()

def configure_optimizers(self) -> OptimizerLRScheduler:
optimizer = Adam(self.parameters(), lr=1e-3)
8 changes: 5 additions & 3 deletions docs/source/examples/monitoring.rst
@@ -15,7 +15,7 @@ Jacobian descent is doing something different than gradient descent. With
they have a negative inner product).

.. code-block:: python
:emphasize-lines: 9-11, 13-18, 33-34
:emphasize-lines: 10-12, 14-19, 34-35

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -24,6 +24,7 @@ they have a negative inner product).

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.utils import jac_to_grad

def print_weights(_, __, weights: torch.Tensor) -> None:
"""Prints the extracted weights."""
@@ -63,6 +64,7 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
8 changes: 5 additions & 3 deletions docs/source/examples/mtl.rst
@@ -19,14 +19,15 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


.. code-block:: python
:emphasize-lines: 5-6, 19, 33
:emphasize-lines: 5-7, 20, 33-34

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import mtl_backward
from torchjd.utils import jac_to_grad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
@@ -52,9 +53,10 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
mtl_backward(losses=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()

.. note::
In this example, the Jacobian is only with respect to the shared parameters. The task-specific
2 changes: 1 addition & 1 deletion docs/source/examples/partial_jd.rst
@@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses)
weights = weighting(gramian)
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
8 changes: 5 additions & 3 deletions docs/source/examples/rnn.rst
@@ -6,14 +6,15 @@ element of the output sequences. If the gradients of these losses are likely to
descent can be leveraged to enhance optimization.

.. code-block:: python
:emphasize-lines: 5-6, 10, 17, 20
:emphasize-lines: 5-7, 11, 18, 20-21

import torch
from torch.nn import RNN
from torch.optim import SGD

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward
from torchjd.utils import jac_to_grad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
optimizer = SGD(rnn.parameters(), lr=0.1)
@@ -26,9 +27,10 @@ descent can be leveraged to enhance optimization.
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

optimizer.zero_grad()
backward(losses, aggregator, parallel_chunk_size=1)
backward(losses, parallel_chunk_size=1)
jac_to_grad(rnn.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()

.. note::
At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and