diff --git a/CHANGELOG.md b/CHANGELOG.md index 84d753344..e933c3ff2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,31 @@ changelog does not include internal changes that do not affect the user. ### Changed +- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the + Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new + function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac` fields into + `.grad` fields. + This means that users now have more control over what they do with the Jacobians (they can easily + aggregate them group by group or even param by param if they want), but it now requires an extra + line of code to do the Jacobian descent step. To update, please change: + ```python + backward(losses, aggregator) + ``` + to + ```python + backward(losses) + jac_to_grad(model.parameters(), aggregator) + ``` + and + ```python + mtl_backward(losses, features, aggregator) + ``` + to + ```python + mtl_backward(losses, features) + jac_to_grad(shared_module.parameters(), aggregator) + ``` + - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory efficiency of `autojac`. diff --git a/docs/source/docs/autojac/index.rst b/docs/source/docs/autojac/index.rst index 4ca478cf2..5eeb22af6 100644 --- a/docs/source/docs/autojac/index.rst +++ b/docs/source/docs/autojac/index.rst @@ -10,3 +10,4 @@ autojac backward.rst mtl_backward.rst + jac_to_grad.rst diff --git a/docs/source/docs/autojac/jac_to_grad.rst b/docs/source/docs/autojac/jac_to_grad.rst new file mode 100644 index 000000000..0b61f00ed --- /dev/null +++ b/docs/source/docs/autojac/jac_to_grad.rst @@ -0,0 +1,6 @@ +:hide-toc: + +jac_to_grad +=========== + +.. autofunction:: torchjd.autojac.jac_to_grad diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index 0aad8da00..974316672 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler following example shows the resulting code for a multi-task learning use-case. .. code-block:: python - :emphasize-lines: 2, 17, 27, 34-37 + :emphasize-lines: 2, 17, 27, 34-35, 37-38 import torch from torch.amp import GradScaler @@ -20,7 +20,7 @@ following example shows the resulting code for a multi-task learning use-case. from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import mtl_backward, jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -48,7 +48,8 @@ following example shows the resulting code for a multi-task learning use-case.
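As mentioned in the changelog above, the `.jac` fields make it possible to aggregate the Jacobians group by group after a single backward pass. A minimal sketch of this pattern (the `shared`/`head` split and the choice of aggregators are illustrative assumptions, not part of this change):

```python
# Illustrative sketch: aggregate the Jacobians group by group after one backward pass.
import torch
from torch.nn import Linear, ReLU, Sequential

from torchjd.aggregation import Mean, UPGrad
from torchjd.autojac import backward, jac_to_grad

shared = Sequential(Linear(10, 5), ReLU())  # hypothetical parameter group 1
head = Linear(5, 1)                         # hypothetical parameter group 2

x = torch.randn(16, 10)
losses = (head(shared(x)) ** 2).squeeze(dim=1)  # one loss per instance, shape [16]

backward(losses)  # populates the .jac field of every leaf parameter
jac_to_grad(shared.parameters(), UPGrad())  # resolve conflicts among shared parameters
jac_to_grad(head.parameters(), Mean())      # plain averaging for the head parameters
```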
loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) + mtl_backward(losses=scaled_losses, features=features) + jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() optimizer.zero_grad() diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 1cca64b76..c3ee871cb 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -19,7 +19,8 @@ Import several classes from ``torch`` and ``torchjd``: from torch.optim import SGD from torchjd import autojac from torchjd.aggregation import UPGrad + from torchjd.autojac import jac_to_grad Define the model and the optimizer, as usual: @@ -63,10 +63,12 @@ Perform the Jacobian descent backward pass: .. code-block:: python - autojac.backward([loss1, loss2], aggregator) + autojac.backward([loss1, loss2]) + jac_to_grad(model.parameters(), aggregator) -This will populate the ``.grad`` field of each model parameter with the corresponding aggregated -Jacobian matrix. +The first function will populate the ``.jac`` field of each model parameter with the corresponding +Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad`` +field of the parameters. It also deletes the ``.jac`` fields to save some memory. Update each parameter based on its ``.grad`` field, using the ``optimizer``: diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index d1b524260..ebc2bde55 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -76,14 +76,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autojac .. code-block:: python - :emphasize-lines: 5-6, 12, 16, 21-22 + :emphasize-lines: 5-6, 12, 16, 21-23 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward + from torchjd.autojac import backward, jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -99,8 +99,8 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - backward(losses, aggregator) - + backward(losses) + jac_to_grad(model.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index c1fbba3b7..61449b97f 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using <../docs/autojac/mtl_backward>` at each training iteration. ..
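Since the Jacobians remain available in the `.jac` fields between the two calls, they can also be inspected before aggregation, for instance to check whether the per-loss gradients conflict. A small illustrative sketch, assuming a two-output model similar to the basic usage example (the cosine-similarity check is an assumption, not part of the documented examples):

```python
# Illustrative sketch: inspect per-parameter Jacobians between backward() and jac_to_grad().
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
loss_fn = MSELoss()

x = torch.randn(16, 10)
target1, target2 = torch.randn(16), torch.randn(16)
output = model(x)
losses = [loss_fn(output[:, 0], target1), loss_fn(output[:, 1], target2)]

backward(losses)  # populates the .jac fields, one row per loss

for name, param in model.named_parameters():
    rows = param.jac.reshape(2, -1)  # .jac is set dynamically by backward()
    rows = rows / rows.norm(dim=1, keepdim=True).clamp_min(1e-12)
    print(name, (rows[0] @ rows[1]).item())  # cosine similarity between the two loss gradients

jac_to_grad(model.parameters(), UPGrad())  # then aggregate into .grad as usual
```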
code-block:: python - :emphasize-lines: 9-10, 18, 31 + :emphasize-lines: 9-10, 18, 31-32 import torch from lightning import LightningModule, Trainer @@ -22,7 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using from torch.utils.data import DataLoader, TensorDataset from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import mtl_backward, jac_to_grad class Model(LightningModule): def __init__(self): @@ -43,7 +43,8 @@ The following code example demonstrates a basic multi-task learning setup using loss2 = mse_loss(output2, target2) opt = self.optimizers() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index f12fd1da3..69cc0e1bc 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -23,7 +23,7 @@ they have a negative inner product). from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import mtl_backward, jac_to_grad def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" @@ -63,6 +63,7 @@ they have a negative inner product). loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index dd7703403..ce74647ba 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -19,14 +19,14 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. .. code-block:: python - :emphasize-lines: 5-6, 19, 32 + :emphasize-lines: 5-6, 19, 32-33 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import mtl_backward, jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -52,7 +52,8 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/docs/source/examples/rnn.rst b/docs/source/examples/rnn.rst index a257c9383..42eb2f913 100644 --- a/docs/source/examples/rnn.rst +++ b/docs/source/examples/rnn.rst @@ -6,14 +6,14 @@ element of the output sequences. If the gradients of these losses are likely to descent can be leveraged to enhance optimization. .. 
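The `retain_jac` flag of `jac_to_grad` (documented in `_jac_to_grad.py` below) allows aggregating the same Jacobians more than once, for example to compare aggregators. A minimal sketch, illustrative only:

```python
# Illustrative sketch: reuse the same Jacobians with two aggregators via retain_jac=True.
import torch

from torchjd.aggregation import Mean, UPGrad
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param ** 2).sum()

backward([y1, y2])  # param.jac is [[-1., 1.], [2., 4.]]

jac_to_grad([param], Mean(), retain_jac=True)  # keep param.jac for a second aggregation
mean_grad, param.grad = param.grad, None       # stash the result and reset .grad

jac_to_grad([param], UPGrad())                 # param.jac is freed after this call
print(mean_grad, param.grad)                   # compare the two aggregations
```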
code-block:: python - :emphasize-lines: 5-6, 10, 17, 19 + :emphasize-lines: 5-6, 10, 17, 19-20 import torch from torch.nn import RNN from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward + from torchjd.autojac import backward, jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) @@ -26,7 +26,8 @@ descent can be leveraged to enhance optimization. output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - backward(losses, aggregator, parallel_chunk_size=1) + backward(losses, parallel_chunk_size=1) + jac_to_grad(rnn.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/autojac/__init__.py index 846c062c1..ab99d98b8 100644 --- a/src/torchjd/autojac/__init__.py +++ b/src/torchjd/autojac/__init__.py @@ -6,6 +6,7 @@ """ from ._backward import backward +from ._jac_to_grad import jac_to_grad from ._mtl_backward import mtl_backward -__all__ = ["backward", "mtl_backward"] +__all__ = ["backward", "jac_to_grad", "mtl_backward"] diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py new file mode 100644 index 000000000..d561aaeae --- /dev/null +++ b/src/torchjd/autojac/_accumulation.py @@ -0,0 +1,68 @@ +from collections.abc import Iterable +from typing import cast + +from torch import Tensor + + +class TensorWithJac(Tensor): + """ + Tensor known to have a populated jac field. + + Should not be directly instantiated, but can be used as a type hint and can be cast to. + """ + + jac: Tensor + + +def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: + for param, jac in zip(params, jacobians, strict=True): + _check_expects_grad(param) + # We check that the shape is correct to be consistent with torch, which checks that the grad + # shape is correct before assigning it. + if jac.shape[1:] != param.shape: + raise RuntimeError( + f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " + f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" + " jacobian are the same size" + ) + + if hasattr(param, "jac"): # No check for None because jac cannot be None + param_ = cast(TensorWithJac, param) + param_.jac += jac + else: + # We do not clone the value to save memory and time, so subsequent modifications of + # the value of param.jac (subsequent accumulations) will also affect the value of the + # provided jacobian, and outside changes to the provided jacobian will also affect + # the value of param.jac. So to be safe, the provided jacobians should not be used + # anymore after being passed to this function. + # + # We do not detach from the computation graph because the value can have grad_fn + # that we want to keep track of (in case it was obtained via create_graph=True and a + # differentiable aggregator).
+ param.__setattr__("jac", jac) + + +def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: + for param, grad in zip(params, gradients, strict=True): + _check_expects_grad(param) + if hasattr(param, "grad") and param.grad is not None: + param.grad += grad + else: + param.grad = grad + + +def _check_expects_grad(tensor: Tensor) -> None: + if not _expects_grad(tensor): + raise ValueError( + "Cannot populate the .grad field of a Tensor that does not satisfy:" + "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." + ) + + +def _expects_grad(tensor: Tensor) -> bool: + """ + Determines whether a Tensor expects its .grad attribute to be populated. + See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. + """ + + return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index ca3009bc2..46ac2d484 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -2,28 +2,23 @@ from torch import Tensor -from torchjd.aggregation import Aggregator - -from ._transform import Accumulate, Aggregate, Diagonalize, Init, Jac, OrderedSet, Transform +from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors def backward( tensors: Sequence[Tensor] | Tensor, - aggregator: Aggregator, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: r""" - Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its - aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the - ``inputs``. - - :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian - matrices will have one row for each value of each of these tensors. - :param aggregator: Aggregator used to reduce the Jacobian into a vector. - :param inputs: The tensors with respect to which the Jacobian must be computed. These must have + Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and + accumulates them in the `.jac` fields of the `inputs`. + + :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will + have one row for each value of each of these tensors. + :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to @@ -41,7 +36,6 @@ def backward( >>> import torch >>> - >>> from torchjd.aggregation import UPGrad >>> from torchjd.autojac import backward >>> >>> param = torch.tensor([1., 2.], requires_grad=True) @@ -49,12 +43,13 @@ def backward( >>> y1 = torch.tensor([-1., 1.]) @ param >>> y2 = (param ** 2).sum() >>> - >>> backward([y1, y2], UPGrad()) + >>> backward([y1, y2]) >>> - >>> param.grad - tensor([0.5000, 2.5000]) + >>> param.jac + tensor([[-1., 1.], + [ 2., 4.]]) - The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of + The ``.jac`` field of ``param`` now contains the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. .. 
warning:: @@ -80,7 +75,6 @@ backward_transform = _create_transform( tensors=tensors_, - aggregator=aggregator, inputs=inputs_, retain_graph=retain_graph, parallel_chunk_size=parallel_chunk_size, @@ -91,12 +85,11 @@ def _create_transform( tensors: OrderedSet[Tensor], - aggregator: Aggregator, inputs: OrderedSet[Tensor], retain_graph: bool, parallel_chunk_size: int | None, ) -> Transform: - """Creates the Jacobian descent backward transform.""" + """Creates the backward transform.""" # Transform that creates gradient outputs containing only ones. init = Init(tensors) @@ -107,10 +100,7 @@ def _create_transform( # Transform that computes the required Jacobians. jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) - # Transform that aggregates the Jacobians. - aggregate = Aggregate(aggregator, inputs) - - # Transform that accumulates the result in the .grad field of the inputs. - accumulate = Accumulate() + # Transform that accumulates the result in the .jac field of the inputs. + accumulate = AccumulateJac() - return accumulate << aggregate << jac << diag << init + return accumulate << jac << diag << init diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py new file mode 100644 index 000000000..2924043dc --- /dev/null +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -0,0 +1,103 @@ +from collections.abc import Iterable +from typing import cast + +import torch +from torch import Tensor + +from torchjd.aggregation import Aggregator + +from ._accumulation import TensorWithJac, accumulate_grads + + +def jac_to_grad( + tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False +) -> None: + r""" + Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result + into their ``.grad`` fields. + + :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must + have the same first dimension (number of outputs). + :param aggregator: The aggregator used to reduce the Jacobians into gradients. + :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors. + + .. admonition:: + Example + + This example shows how to use ``jac_to_grad`` after a call to ``backward`` + + >>> import torch + >>> + >>> from torchjd.autojac import backward, jac_to_grad + >>> from torchjd.aggregation import UPGrad + >>> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> # Compute arbitrary quantities that are function of param + >>> y1 = torch.tensor([-1., 1.]) @ param + >>> y2 = (param ** 2).sum() + >>> + >>> backward([y1, y2]) # param now has a .jac field + >>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field + >>> param.grad + tensor([0.5000, 2.5000]) + + The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of + :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. + """ + + tensors_ = list[TensorWithJac]() + for t in tensors: + if not hasattr(t, "jac"): + raise ValueError( + "Some `jac` fields were not populated. Did you use `autojac.backward` before " "calling `jac_to_grad`?"
+ ) + t_ = cast(TensorWithJac, t) + tensors_.append(t_) + + if len(tensors_) == 0: + return + + jacobians = [t.jac for t in tensors_] + + if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): + raise ValueError("All Jacobians should have the same number of rows.") + + jacobian_matrix = _unite_jacobians(jacobians) + gradient_vector = aggregator(jacobian_matrix) + gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) + accumulate_grads(tensors_, gradients) + + if not retain_jac: + _free_jacs(tensors_) + + +def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: + jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] + jacobian_matrix = torch.concat(jacobian_matrices, dim=1) + return jacobian_matrix + + +def _disunite_gradient( + gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: + gradient_vectors = [] + start = 0 + for jacobian in jacobians: + end = start + jacobian[0].numel() + current_gradient_vector = gradient_vector[start:end] + gradient_vectors.append(current_gradient_vector) + start = end + gradients = [g.view(t.shape) for t, g in zip(tensors, gradient_vectors, strict=True)] + return gradients + + +def _free_jacs(tensors: Iterable[TensorWithJac]) -> None: + """ + Deletes the ``.jac`` field of the provided tensors. + + :param tensors: The tensors whose ``.jac`` fields should be cleared. + """ + + for t in tensors: + del t.jac diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 4bdac023c..0d05447d5 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -2,16 +2,23 @@ from torch import Tensor -from torchjd.aggregation import Aggregator - -from ._transform import Accumulate, Aggregate, Grad, Init, Jac, OrderedSet, Select, Stack, Transform +from ._transform import ( + AccumulateGrad, + AccumulateJac, + Grad, + Init, + Jac, + OrderedSet, + Select, + Stack, + Transform, +) from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors def mtl_backward( losses: Sequence[Tensor], features: Sequence[Tensor] | Tensor, - aggregator: Aggregator, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, retain_graph: bool = False, @@ -23,21 +30,18 @@ def mtl_backward( This function computes the gradient of each task-specific loss with respect to its task-specific parameters and accumulates it in their ``.grad`` fields. Then, it computes the Jacobian of all - losses with respect to the shared parameters, aggregates it and accumulates the result in their - ``.grad`` fields. + losses with respect to the shared parameters and accumulates it in their ``.jac`` fields. - :param losses: The task losses. The Jacobian matrix will have one row per loss. + :param losses: The task losses. The Jacobians will have one row per loss. :param features: The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty. - :param aggregator: Aggregator used to reduce the Jacobian into a vector. :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags must be set to ``True``. If not provided, the parameters considered for each task will default to the leaf tensors that are in the computation graph of its loss, but that were not used to compute the ``features``. - :param shared_params: The parameters of the shared feature extractor. 
The Jacobian matrix will - have one column for each value in these tensors. Their ``requires_grad`` flags must be set - to ``True``. If not provided, defaults to the leaf tensors that are in the computation graph - of the ``features``. + :param shared_params: The parameters of the shared feature extractor. Their ``requires_grad`` + flags must be set to ``True``. If not provided, defaults to the leaf tensors that are in the + computation graph of the ``features``. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to ``False``. :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the @@ -95,7 +99,6 @@ def mtl_backward( backward_transform = _create_transform( losses=losses_, features=features_, - aggregator=aggregator, tasks_params=tasks_params_, shared_params=shared_params_, retain_graph=retain_graph, @@ -108,7 +111,6 @@ def mtl_backward( def _create_transform( losses: OrderedSet[Tensor], features: OrderedSet[Tensor], - aggregator: Aggregator, tasks_params: list[OrderedSet[Tensor]], shared_params: OrderedSet[Tensor], retain_graph: bool, @@ -140,13 +142,10 @@ def _create_transform( # Transform that computes the Jacobians of the losses w.r.t. the shared parameters. jac = Jac(features, shared_params, parallel_chunk_size, retain_graph) - # Transform that aggregates the Jacobians. - aggregate = Aggregate(aggregator, shared_params) - - # Transform that accumulates the result in the .grad field of the shared parameters. - accumulate = Accumulate() + # Transform that accumulates the result in the .jac field of the shared parameters. + accumulate = AccumulateJac() - return accumulate << aggregate << jac << stack + return accumulate << jac << stack def _create_task_transform( @@ -167,7 +166,7 @@ def _create_task_transform( # Transform that accumulates the gradients w.r.t. the task-specific parameters into their # .grad fields. - accumulate = Accumulate() << Select(task_params) + accumulate = AccumulateGrad() << Select(task_params) # Transform that backpropagates the gradients of the losses w.r.t. the features. backpropagate = Select(features) diff --git a/src/torchjd/autojac/_transform/__init__.py b/src/torchjd/autojac/_transform/__init__.py index 46be392d9..10d1c5125 100644 --- a/src/torchjd/autojac/_transform/__init__.py +++ b/src/torchjd/autojac/_transform/__init__.py @@ -1,5 +1,4 @@ -from ._accumulate import Accumulate -from ._aggregate import Aggregate +from ._accumulate import AccumulateGrad, AccumulateJac from ._base import Composition, Conjunction, RequirementError, Transform from ._diagonalize import Diagonalize from ._grad import Grad @@ -10,8 +9,8 @@ from ._stack import Stack __all__ = [ - "Accumulate", - "Aggregate", + "AccumulateGrad", + "AccumulateJac", "Composition", "Conjunction", "Diagonalize", diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 7bfce193c..082ef1df7 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -1,49 +1,38 @@ from torch import Tensor +from .._accumulation import accumulate_grads, accumulate_jacs from ._base import TensorDict, Transform -class Accumulate(Transform): +class AccumulateGrad(Transform): """ Transform from Gradients to {} that accumulates gradients with respect to keys into their ``grad`` field. + + The Gradients are not cloned and may be modified in-place by subsequent accumulations, so they + should not be used elsewhere. 
""" def __call__(self, gradients: TensorDict) -> TensorDict: - for key in gradients.keys(): - _check_expects_grad(key) - if hasattr(key, "grad") and key.grad is not None: - key.grad += gradients[key] - else: - # We do not clone the value to save memory and time, so subsequent modifications of - # the value of key.grad (subsequent accumulations) will also affect the value of - # gradients[key] and outside changes to the value of gradients[key] will also affect - # the value of key.grad. So to be safe, the values of gradients should not be used - # anymore after being passed to this function. - # - # We do not detach from the computation graph because the value can have grad_fn - # that we want to keep track of (in case it was obtained via create_graph=True and a - # differentiable aggregator). - key.grad = gradients[key] - + accumulate_grads(gradients.keys(), gradients.values()) return {} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: return set() -def _check_expects_grad(tensor: Tensor) -> None: - if not _expects_grad(tensor): - raise ValueError( - "Cannot populate the .grad field of a Tensor that does not satisfy:" - "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." - ) - - -def _expects_grad(tensor: Tensor) -> bool: +class AccumulateJac(Transform): """ - Determines whether a Tensor expects its .grad attribute to be populated. - See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. + Transform from Jacobians to {} that accumulates jacobians with respect to keys into their + ``jac`` field. + + The Jacobians are not cloned and may be modified in-place by subsequent accumulations, so they + should not be used elsewhere. """ - return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) + def __call__(self, jacobians: TensorDict) -> TensorDict: + accumulate_jacs(jacobians.keys(), jacobians.values()) + return {} + + def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + return set() diff --git a/src/torchjd/autojac/_transform/_aggregate.py b/src/torchjd/autojac/_transform/_aggregate.py deleted file mode 100644 index 6f1b2ccad..000000000 --- a/src/torchjd/autojac/_transform/_aggregate.py +++ /dev/null @@ -1,151 +0,0 @@ -from collections import OrderedDict -from collections.abc import Hashable -from typing import TypeVar - -import torch -from torch import Tensor - -from torchjd.aggregation import Aggregator - -from ._base import RequirementError, TensorDict, Transform -from ._ordered_set import OrderedSet - -_KeyType = TypeVar("_KeyType", bound=Hashable) -_ValueType = TypeVar("_ValueType") - - -class Aggregate(Transform): - """ - Transform aggregating Jacobians into Gradients. - - It does so by reshaping these Jacobians into matrices, concatenating them into a single matrix, - applying an aggregator to it, separating the result back into one gradient vector per key, and - finally reshaping those into gradients of the same shape as their corresponding keys. - - :param aggregator: The aggregator used to aggregate the concatenated jacobian matrix. - :param key_order: Order in which the different jacobian matrices must be concatenated. 
- """ - - def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]): - matrixify = _Matrixify() - aggregate_matrices = _AggregateMatrices(aggregator, key_order) - reshape = _Reshape() - - self._aggregator_str = str(aggregator) - self.transform = reshape << aggregate_matrices << matrixify - - def __call__(self, input: TensorDict) -> TensorDict: - return self.transform(input) - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return self.transform.check_keys(input_keys) - - -class _AggregateMatrices(Transform): - """ - Transform aggregating JacobiansMatrices into GradientsVectors. - - It does so by concatenating the matrices into a single matrix, applying an aggregator to it and - separating the result back into one gradient vector per key. - - :param aggregator: The aggregator used to aggregate the concatenated jacobian matrix. - :param key_order: Order in which the different jacobian matrices must be concatenated. - """ - - def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]): - self.key_order = key_order - self.aggregator = aggregator - - def __call__(self, jacobian_matrices: TensorDict) -> TensorDict: - """ - Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using - the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to - the part of the obtained gradient vector, that corresponds to the jacobian matrix given for - that key. - - :param jacobian_matrices: The dictionary of jacobian matrices to aggregate. The first - dimension of each jacobian matrix should be the same. - """ - ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order) - return self._aggregate_group(ordered_matrices, self.aggregator) - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - if not set(self.key_order) == input_keys: - raise RequirementError( - f"The input_keys must match the key_order. Found input_keys {input_keys} and" - f"key_order {self.key_order}." - ) - return input_keys - - @staticmethod - def _select_ordered_subdict( - dictionary: dict[_KeyType, _ValueType], ordered_keys: OrderedSet[_KeyType] - ) -> OrderedDict[_KeyType, _ValueType]: - """ - Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``. - Returns an OrderedDict in the same order as the provided ``ordered_keys``. - """ - - return OrderedDict([(key, dictionary[key]) for key in ordered_keys]) - - @staticmethod - def _aggregate_group( - jacobian_matrices: OrderedDict[Tensor, Tensor], aggregator: Aggregator - ) -> TensorDict: - """ - Unites the jacobian matrices and aggregates them using an - :class:`~torchjd.aggregation._aggregator_bases.Aggregator`. Returns the obtained gradient - vectors. 
- """ - - if len(jacobian_matrices) == 0: - return {} - - united_jacobian_matrix = _AggregateMatrices._unite(jacobian_matrices) - united_gradient_vector = aggregator(united_jacobian_matrix) - gradient_vectors = _AggregateMatrices._disunite(united_gradient_vector, jacobian_matrices) - return gradient_vectors - - @staticmethod - def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor: - return torch.cat(list(jacobian_matrices.values()), dim=1) - - @staticmethod - def _disunite( - united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor] - ) -> TensorDict: - gradient_vectors = {} - start = 0 - for key, jacobian_matrix in jacobian_matrices.items(): - end = start + jacobian_matrix.shape[1] - current_gradient_vector = united_gradient_vector[start:end] - gradient_vectors[key] = current_gradient_vector - start = end - return gradient_vectors - - -class _Matrixify(Transform): - """Transform reshaping Jacobians into JacobianMatrices.""" - - def __call__(self, jacobians: TensorDict) -> TensorDict: - jacobian_matrices = { - key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items() - } - return jacobian_matrices - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return input_keys - - -class _Reshape(Transform): - """Transform reshaping GradientVectors into Gradients.""" - - def __call__(self, gradient_vectors: TensorDict) -> TensorDict: - gradients = { - key: gradient_vector.view(key.shape) - for key, gradient_vector in gradient_vectors.items() - } - return gradients - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return input_keys diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index ca2e1a259..f989a1a77 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -1,6 +1,6 @@ """ This file contains the test of the backward usage example, with a verification of the value of the -obtained `.grad` field. +obtained `.jac` field. """ from torch.testing import assert_close @@ -9,7 +9,6 @@ def test_backward(): import torch - from torchjd.aggregation import UPGrad from torchjd.autojac import backward param = torch.tensor([1.0, 2.0], requires_grad=True) @@ -17,6 +16,6 @@ def test_backward(): y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() - backward([y1, y2], UPGrad()) + backward([y1, y2]) - assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) + assert_close(param.jac, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py new file mode 100644 index 000000000..57bd42f08 --- /dev/null +++ b/tests/doc/test_jac_to_grad.py @@ -0,0 +1,22 @@ +""" +This file contains the test of the jac_to_grad usage example, with a verification of the value of +the obtained `.grad` field. 
+""" + +from torch.testing import assert_close + + +def test_jac_to_grad(): + import torch + + from torchjd.aggregation import UPGrad + from torchjd.autojac import backward, jac_to_grad + + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + backward([y1, y2]) # param now has a .jac field + jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field + + assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 867aad6b9..b89ac77be 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -14,7 +14,7 @@ def test_amp(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import jac_to_grad, mtl_backward shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -42,7 +42,8 @@ def test_amp(): loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) + mtl_backward(losses=scaled_losses, features=features) + jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() optimizer.zero_grad() @@ -55,6 +56,7 @@ def test_basic_usage(): from torchjd import autojac from torchjd.aggregation import UPGrad + from torchjd.autojac import jac_to_grad model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2)) optimizer = SGD(model.parameters(), lr=0.1) @@ -69,7 +71,8 @@ def test_basic_usage(): loss1 = loss_fn(output[:, 0], target1) loss2 = loss_fn(output[:, 1], target2) - autojac.backward([loss1, loss2], aggregator) + autojac.backward([loss1, loss2]) + jac_to_grad(model.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -148,7 +151,7 @@ def test_autojac(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward + from torchjd.autojac import backward, jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -163,7 +166,8 @@ def test_autojac(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - backward(losses, aggregator) + backward(losses) + jac_to_grad(model.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -219,7 +223,7 @@ def test_lightning_integration(): from torch.utils.data import DataLoader, TensorDataset from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import jac_to_grad, mtl_backward class Model(LightningModule): def __init__(self): @@ -240,7 +244,9 @@ def training_step(self, batch, batch_idx) -> None: loss2 = mse_loss(output2, target2) opt = self.optimizers() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) + + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() @@ -274,7 +280,7 @@ def test_monitoring(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import jac_to_grad, mtl_backward def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" @@ -314,7 +320,8 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. 
loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -325,7 +332,7 @@ def test_mtl(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward + from torchjd.autojac import jac_to_grad, mtl_backward shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -351,7 +358,8 @@ def test_mtl(): loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -395,7 +403,7 @@ def test_rnn(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward + from torchjd.autojac import backward, jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) @@ -408,6 +416,7 @@ def test_rnn(): output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - backward(losses, aggregator, parallel_chunk_size=1) + backward(losses, parallel_chunk_size=1) + jac_to_grad(rnn.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/tests/unit/autojac/_asserts.py b/tests/unit/autojac/_asserts.py new file mode 100644 index 000000000..742998d81 --- /dev/null +++ b/tests/unit/autojac/_asserts.py @@ -0,0 +1,35 @@ +from typing import cast + +import torch +from torch.testing import assert_close + +from torchjd.autojac._accumulation import TensorWithJac + + +def assert_has_jac(t: torch.Tensor) -> None: + assert hasattr(t, "jac") + t_ = cast(TensorWithJac, t) + assert t_.jac is not None and t_.jac.shape[1:] == t_.shape + + +def assert_has_no_jac(t: torch.Tensor) -> None: + assert not hasattr(t, "jac") + + +def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor) -> None: + assert hasattr(t, "jac") + t_ = cast(TensorWithJac, t) + assert_close(t_.jac, expected_jac) + + +def assert_has_grad(t: torch.Tensor) -> None: + assert (t.grad is not None) and (t.shape == t.grad.shape) + + +def assert_has_no_grad(t: torch.Tensor) -> None: + assert t.grad is None + + +def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor) -> None: + assert t.grad is not None + assert_close(t.grad, expected_grad) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 45db6d61b..6dadf1efc 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -1,105 +1,189 @@ from pytest import mark, raises +from unit.autojac._asserts import assert_grad_close, assert_jac_close from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import ones_, tensor_, zeros_ -from torchjd.autojac._transform import Accumulate +from torchjd.autojac._transform import AccumulateGrad, AccumulateJac -def test_single_accumulation(): +def test_single_grad_accumulation(): """ - Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run + Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when 
run once. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([]) - value2 = ones_([1]) - value3 = ones_([2, 3]) - input = {key1: value1, key2: value2, key3: value3} + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_(shape) for shape in shapes] + input = dict(zip(keys, values)) - accumulate = Accumulate() + accumulate = AccumulateGrad() output = accumulate(input) - expected_output = {} + assert_tensor_dicts_are_close(output, {}) - assert_tensor_dicts_are_close(output, expected_output) + for key, value in zip(keys, values): + assert_grad_close(key, value) - grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} - expected_grads = {key1: value1, key2: value2, key3: value3} - assert_tensor_dicts_are_close(grads, expected_grads) +@mark.parametrize("iterations", [1, 2, 4, 10, 13]) +def test_multiple_grad_accumulations(iterations: int): + """ + Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run + `iterations` times. + """ + + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_(shape) for shape in shapes] + accumulate = AccumulateGrad() + + for i in range(iterations): + # Clone values to ensure that we accumulate values that are not ever used afterwards + input = {key: value.clone() for key, value in zip(keys, values)} + accumulate(input) + + for key, value in zip(keys, values): + assert_grad_close(key, iterations * value) + + +def test_accumulate_grad_fails_when_no_requires_grad(): + """ + Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a + tensor that does not require grad. + """ + + key = zeros_([1], requires_grad=False) + value = ones_([1]) + input = {key: value} + + accumulate = AccumulateGrad() + + with raises(ValueError): + accumulate(input) + + +def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): + """ + Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a + tensor that is not a leaf and that does not retain grad. + """ + + key = tensor_([1.0], requires_grad=True) * 2 + value = ones_([1]) + input = {key: value} + + accumulate = AccumulateGrad() + + with raises(ValueError): + accumulate(input) + + +def test_accumulate_grad_check_keys(): + """Tests that the `check_keys` method works correctly for AccumulateGrad.""" + + key = tensor_([1.0], requires_grad=True) + accumulate = AccumulateGrad() + + output_keys = accumulate.check_keys({key}) + assert output_keys == set() + + +def test_single_jac_accumulation(): + """ + Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run + once. 
+ """ + + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_([4] + shape) for shape in shapes] + input = dict(zip(keys, values)) + + accumulate = AccumulateJac() + + output = accumulate(input) + assert_tensor_dicts_are_close(output, {}) + + for key, value in zip(keys, values): + assert_jac_close(key, value) @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_accumulation(iterations: int): +def test_multiple_jac_accumulations(iterations: int): """ - Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run + Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run `iterations` times. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([]) - value2 = ones_([1]) - value3 = ones_([2, 3]) + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_([4] + shape) for shape in shapes] - accumulate = Accumulate() + accumulate = AccumulateJac() for i in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards - input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()} + input = {key: value.clone() for key, value in zip(keys, values)} accumulate(input) - grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} - expected_grads = { - key1: iterations * value1, - key2: iterations * value2, - key3: iterations * value3, - } + for key, value in zip(keys, values): + assert_jac_close(key, iterations * value) - assert_tensor_dicts_are_close(grads, expected_grads) - -def test_no_requires_grad_fails(): +def test_accumulate_jac_fails_when_no_requires_grad(): """ - Tests that the Accumulate transform raises an error when it tries to populate a .grad of a + Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that does not require grad. """ key = zeros_([1], requires_grad=False) - value = ones_([1]) + value = ones_([4, 1]) input = {key: value} - accumulate = Accumulate() + accumulate = AccumulateJac() with raises(ValueError): accumulate(input) -def test_no_leaf_and_no_retains_grad_fails(): +def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): """ - Tests that the Accumulate transform raises an error when it tries to populate a .grad of a + Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that is not a leaf and that does not retain grad. """ key = tensor_([1.0], requires_grad=True) * 2 - value = ones_([1]) + value = ones_([4, 1]) input = {key: value} - accumulate = Accumulate() + accumulate = AccumulateJac() with raises(ValueError): accumulate(input) -def test_check_keys(): - """Tests that the `check_keys` method works correctly.""" +def test_accumulate_jac_fails_when_shape_mismatch(): + """ + Tests that the AccumulateJac transform raises an error when the jacobian shape does not match + the parameter shape (ignoring the first dimension). 
+ """ + + key = zeros_([2, 3], requires_grad=True) + value = ones_([4, 3, 2]) # Wrong shape: should be [4, 2, 3], not [4, 3, 2] + input = {key: value} + + accumulate = AccumulateJac() + + with raises(RuntimeError): + accumulate(input) + + +def test_accumulate_jac_check_keys(): + """Tests that the `check_keys` method works correctly for AccumulateJac.""" key = tensor_([1.0], requires_grad=True) - accumulate = Accumulate() + accumulate = AccumulateJac() output_keys = accumulate.check_keys({key}) assert output_keys == set() diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py deleted file mode 100644 index 5beaed20f..000000000 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ /dev/null @@ -1,155 +0,0 @@ -import math - -import torch -from pytest import mark, raises -from settings import DEVICE -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import rand_, tensor_, zeros_ - -from torchjd.aggregation import Random -from torchjd.autojac._transform import OrderedSet, RequirementError -from torchjd.autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape -from torchjd.autojac._transform._base import TensorDict - - -def _make_jacobian_matrices(n_outputs: int, rng: torch.Generator) -> TensorDict: - jacobian_shapes = [[n_outputs, math.prod(shape)] for shape in _param_shapes] - jacobian_list = [rand_(shape, generator=rng) for shape in jacobian_shapes] - jacobian_matrices = {key: jac for key, jac in zip(_keys, jacobian_list)} - return jacobian_matrices - - -_param_shapes = [ - [], - [1], - [2], - [5], - [1, 1], - [2, 3], - [5, 5], - [1, 1, 1], - [2, 3, 4], - [5, 5, 5], - [1, 1, 1, 1], - [2, 3, 4, 5], - [5, 5, 5, 5], -] -_keys = [zeros_(shape) for shape in _param_shapes] - -_rng = torch.Generator(device=DEVICE) -_rng.manual_seed(0) -_jacobian_matrix_dicts = [_make_jacobian_matrices(n_outputs, _rng) for n_outputs in [1, 2, 5]] - - -@mark.parametrize("jacobian_matrices", _jacobian_matrix_dicts) -def test_aggregate_matrices_output_structure(jacobian_matrices: TensorDict): - """ - Tests that applying _AggregateMatrices to various dictionaries of jacobian matrices gives an - output of the desired structure. 
- """ - - aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet(_keys)) - gradient_vectors = aggregate_matrices(jacobian_matrices) - - assert set(jacobian_matrices.keys()) == set(gradient_vectors.keys()) - - for key in jacobian_matrices.keys(): - assert gradient_vectors[key].numel() == jacobian_matrices[key][0].numel() - - -def test_aggregate_matrices_empty_dict(): - """Tests that applying _AggregateMatrices to an empty input gives an empty output.""" - - aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet([])) - gradient_vectors = aggregate_matrices({}) - assert len(gradient_vectors) == 0 - - -def test_matrixify(): - """Tests that the Matrixify transform correctly creates matrices from the jacobians.""" - - n_outputs = 5 - key1 = zeros_([]) - key2 = zeros_([1]) - key3 = zeros_([2, 3]) - value1 = tensor_([1.0] * n_outputs) - value2 = tensor_([[2.0]] * n_outputs) - value3 = tensor_([[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]] * n_outputs) - input = {key1: value1, key2: value2, key3: value3} - - matrixify = _Matrixify() - - output = matrixify(input) - expected_output = { - key1: tensor_([[1.0]] * n_outputs), - key2: tensor_([[2.0]] * n_outputs), - key3: tensor_([[3.0, 4.0, 5.0, 6.0, 7.0, 8.0]] * n_outputs), - } - - assert_tensor_dicts_are_close(output, expected_output) - - -def test_reshape(): - """Tests that the Reshape transform correctly creates gradients from gradient vectors.""" - - key1 = zeros_([]) - key2 = zeros_([1]) - key3 = zeros_([2, 3]) - value1 = tensor_([1.0]) - value2 = tensor_([2.0]) - value3 = tensor_([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - input = {key1: value1, key2: value2, key3: value3} - - reshape = _Reshape() - - output = reshape(input) - expected_output = { - key1: tensor_(1.0), - key2: tensor_([2.0]), - key3: tensor_([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]), - } - - assert_tensor_dicts_are_close(output, expected_output) - - -def test_aggregate_matrices_check_keys(): - """ - Tests that the `check_keys` method works correctly: the input_keys must match the stored - key_order. 
- """ - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - key3 = tensor_([2.0]) - aggregate = _AggregateMatrices(Random(), OrderedSet([key2, key1])) - - output_keys = aggregate.check_keys({key1, key2}) - assert output_keys == {key1, key2} - - with raises(RequirementError): - aggregate.check_keys({key1}) - - with raises(RequirementError): - aggregate.check_keys({key1, key2, key3}) - - -def test_matrixify_check_keys(): - """Tests that the `check_keys` method works correctly.""" - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - matrixify = _Matrixify() - - output_keys = matrixify.check_keys({key1, key2}) - assert output_keys == {key1, key2} - - -def test_reshape_check_keys(): - """Tests that the `check_keys` method works correctly.""" - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - reshape = _Reshape() - - output_keys = reshape.check_keys({key1, key2}) - assert output_keys == {key1, key2} diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 8a943e834..a712dcefe 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -5,7 +5,7 @@ from utils.tensors import tensor_, zeros_ from torchjd.autojac._transform import ( - Accumulate, + AccumulateGrad, Conjunction, Diagonalize, Grad, @@ -186,10 +186,10 @@ def test_conjunction_is_associative(): def test_conjunction_accumulate_select(): """ - Tests that it is possible to conjunct an Accumulate and a Select in this order. - It is not trivial since the type of the TensorDict returned by the first transform (Accumulate) - is EmptyDict, which is not the type that the conjunction should return (Gradients), but a - subclass of it. + Tests that it is possible to conjunct an AccumulateGrad and a Select in this order. + It is not trivial since the type of the TensorDict returned by the first transform + (AccumulateGrad) is EmptyDict, which is not the type that the conjunction should return + (Gradients), but a subclass of it. 
""" key = tensor_([1.0, 2.0, 3.0], requires_grad=True) @@ -197,7 +197,7 @@ def test_conjunction_accumulate_select(): input = {key: value} select = Select(set()) - accumulate = Accumulate() + accumulate = AccumulateGrad() conjunction = accumulate | select output = conjunction(input) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 885a9c154..3bd50e810 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,14 +1,13 @@ import torch from pytest import mark, raises -from torch.autograd import grad -from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import backward from torchjd.autojac._backward import _create_transform from torchjd.autojac._transform import OrderedSet +from ._asserts import assert_has_jac, assert_has_no_jac, assert_jac_close + def test_check_create_transform(): """Tests that _create_transform creates a valid Transform.""" @@ -21,7 +20,6 @@ def test_check_create_transform(): transform = _create_transform( tensors=OrderedSet([y1, y2]), - aggregator=Mean(), inputs=OrderedSet([a1, a2]), retain_graph=False, parallel_chunk_size=None, @@ -31,8 +29,7 @@ def test_check_create_transform(): assert output_keys == set() -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) -def test_various_aggregators(aggregator: Aggregator): +def test_shape_is_correct(): """Tests that backward works for various aggregators.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -41,24 +38,22 @@ def test_various_aggregators(aggregator: Aggregator): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], aggregator) + backward([y1, y2]) for a in [a1, a2]: - assert (a.grad is not None) and (a.shape == a.grad.shape) + assert_has_jac(a) -@mark.parametrize("aggregator", [Mean(), UPGrad()]) @mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_inputs", [True, False]) @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( - aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool, chunk_size: int | None, ): """ - Tests that the .grad value filled by backward is correct in a simple example of matrix-vector + Tests that the .jac value filled by backward is correct in a simple example of matrix-vector product. """ @@ -73,16 +68,15 @@ def test_value_is_correct( backward( [output], - aggregator, inputs=inputs, parallel_chunk_size=chunk_size, ) - assert_close(input.grad, aggregator(J)) + assert_jac_close(input, J) def test_empty_inputs(): - """Tests that backward does not fill the .grad values if no input is specified.""" + """Tests that backward does not fill the .jac values if no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True) @@ -90,15 +84,15 @@ def test_empty_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], Mean(), inputs=[]) + backward([y1, y2], inputs=[]) for a in [a1, a2]: - assert a.grad is None + assert_has_no_jac(a) def test_partial_inputs(): """ - Tests that backward fills the right .grad values when only a subset of the actual inputs are + Tests that backward fills the right .jac values when only a subset of the actual inputs are specified as inputs. 
""" @@ -108,10 +102,10 @@ def test_partial_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], Mean(), inputs=[a1]) + backward([y1, y2], inputs=[a1]) - assert (a1.grad is not None) and (a1.shape == a1.grad.shape) - assert a2.grad is None + assert_has_jac(a1) + assert_has_no_jac(a2) def test_empty_tensors_fails(): @@ -121,34 +115,41 @@ def test_empty_tensors_fails(): a2 = tensor_([3.0, 4.0], requires_grad=True) with raises(ValueError): - backward([], UPGrad(), inputs=[a1, a2]) + backward([], inputs=[a1, a2]) def test_multiple_tensors(): """ Tests that giving multiple tensors to backward is equivalent to giving a single tensor - containing the all the values of the original tensors. + containing all the values of the original tensors. """ - aggregator = UPGrad() + J1 = tensor_([[-1.0, 1.0], [2.0, 4.0]]) + J2 = tensor_([[1.0, 1.0], [0.6, 0.8]]) + # First computation graph: multiple tensors a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True) - inputs = [a1, a2] y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], aggregator, retain_graph=True) + backward([y1, y2]) + + assert_jac_close(a1, J1) + assert_jac_close(a2, J2) - input_to_grad = {a: a.grad for a in inputs} - for a in inputs: - a.grad = None + # Second computation graph: single concatenated tensor + b1 = tensor_([1.0, 2.0], requires_grad=True) + b2 = tensor_([3.0, 4.0], requires_grad=True) - backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator) + z1 = tensor_([-1.0, 1.0]) @ b1 + b2.sum() + z2 = (b1**2).sum() + b2.norm() - for a in inputs: - assert (a.grad == input_to_grad[a]).all() + backward(torch.cat([z1.reshape(-1), z2.reshape(-1)])) + + assert_jac_close(b1, J1) + assert_jac_close(b2, J2) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -161,10 +162,10 @@ def test_various_valid_chunk_sizes(chunk_size): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], UPGrad(), parallel_chunk_size=chunk_size) + backward([y1, y2], parallel_chunk_size=chunk_size) for a in [a1, a2]: - assert (a.grad is not None) and (a.shape == a.grad.shape) + assert_has_jac(a) @mark.parametrize("chunk_size", [0, -1]) @@ -178,7 +179,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): y2 = (a1**2).sum() + a2.norm() with raises(ValueError): - backward([y1, y2], UPGrad(), parallel_chunk_size=chunk_size) + backward([y1, y2], parallel_chunk_size=chunk_size) def test_input_retaining_grad_fails(): @@ -192,8 +193,13 @@ def test_input_retaining_grad_fails(): b.retain_grad() y = 3 * b + # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor + # (and it also fills b.jac with the correct Jacobian) + backward(tensors=y, inputs=[b]) + with raises(RuntimeError): - backward(tensors=y, aggregator=UPGrad(), inputs=[b]) + # Using such a BatchedTensor should result in an error + _ = -b.grad def test_non_input_retaining_grad_fails(): @@ -208,7 +214,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor - backward(tensors=y, aggregator=UPGrad(), inputs=[a]) + backward(tensors=y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -227,18 +233,12 @@ def test_tensor_used_multiple_times(chunk_size: int | None): c = a * b d = a * c e = a * d - aggregator = UPGrad() - backward([d, e], aggregator=aggregator, parallel_chunk_size=chunk_size) + 
backward([d, e], parallel_chunk_size=chunk_size) - expected_jacobian = tensor_( - [ - [2.0 * 3.0 * (a**2).item()], - [2.0 * 4.0 * (a**3).item()], - ], - ) + expected_jacobian = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()]) - assert_close(a.grad, aggregator(expected_jacobian).squeeze()) + assert_jac_close(a, expected_jacobian) def test_repeated_tensors(): @@ -257,7 +257,7 @@ def test_repeated_tensors(): y2 = (a1**2).sum() + (a2**2).sum() with raises(ValueError): - backward([y1, y1, y2], Sum()) + backward([y1, y1, y2]) def test_repeated_inputs(): @@ -273,10 +273,10 @@ def test_repeated_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + (a2**2).sum() - expected_grad_wrt_a1 = grad([y1, y2], a1, retain_graph=True)[0] - expected_grad_wrt_a2 = grad([y1, y2], a2, retain_graph=True)[0] + J1 = tensor_([[-1.0, 1.0], [2.0, 4.0]]) + J2 = tensor_([[1.0, 1.0], [6.0, 8.0]]) - backward([y1, y2], Sum(), inputs=[a1, a1, a2]) + backward([y1, y2], inputs=[a1, a1, a2]) - assert_close(a1.grad, expected_grad_wrt_a1) - assert_close(a2.grad, expected_grad_wrt_a2) + assert_jac_close(a1, J1) + assert_jac_close(a2, J2) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py new file mode 100644 index 000000000..7d16247c9 --- /dev/null +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -0,0 +1,103 @@ +from pytest import mark, raises +from unit.autojac._asserts import assert_grad_close, assert_has_jac, assert_has_no_jac +from utils.tensors import tensor_ + +from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad +from torchjd.autojac._jac_to_grad import jac_to_grad + + +@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) +def test_various_aggregators(aggregator: Aggregator): + """Tests that jac_to_grad works for various aggregators.""" + + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + expected_grad = aggregator(jac) + g1 = expected_grad[0] + g2 = expected_grad[1:] + + jac_to_grad([t1, t2], aggregator) + + assert_grad_close(t1, g1) + assert_grad_close(t2, g2) + + +def test_single_tensor(): + """Tests that jac_to_grad works when a single tensor is provided.""" + + aggregator = UPGrad() + t = tensor_([2.0, 3.0, 4.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t.__setattr__("jac", jac) + g = aggregator(jac) + + jac_to_grad([t], aggregator) + + assert_grad_close(t, g) + + +def test_no_jac_field(): + """Tests that jac_to_grad fails when a tensor does not have a jac field.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_no_requires_grad(): + """Tests that jac_to_grad fails when a tensor does not require grad.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=False) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_row_mismatch(): + """Tests that jac_to_grad fails when the number of rows of the .jac is not constant.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = 
tensor_([2.0, 3.0], requires_grad=True) + t1.__setattr__("jac", tensor_([5.0, 6.0, 7.0])) # 3 rows + t2.__setattr__("jac", tensor_([[1.0, 2.0], [3.0, 4.0]])) # 2 rows + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_no_tensors(): + """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" + + jac_to_grad([], aggregator=UPGrad()) + + +@mark.parametrize("retain_jac", [True, False]) +def test_jacs_are_freed(retain_jac: bool): + """Tests that jac_to_grad frees the jac fields if an only if retain_jac is False.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + jac_to_grad([t1, t2], aggregator, retain_jac=retain_jac) + + check = assert_has_jac if retain_jac else assert_has_no_jac + check(t1) + check(t2) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 86595f927..1bd247f6d 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -2,14 +2,21 @@ from pytest import mark, raises from settings import DTYPE from torch.autograd import grad -from torch.testing import assert_close from utils.tensors import arange_, rand_, randn_, tensor_ -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import mtl_backward from torchjd.autojac._mtl_backward import _create_transform from torchjd.autojac._transform import OrderedSet +from ._asserts import ( + assert_grad_close, + assert_has_grad, + assert_has_jac, + assert_has_no_grad, + assert_has_no_jac, + assert_jac_close, +) + def test_check_create_transform(): """Tests that _create_transform creates a valid Transform.""" @@ -26,7 +33,6 @@ def test_check_create_transform(): transform = _create_transform( losses=OrderedSet([y1, y2]), features=OrderedSet([f1, f2]), - aggregator=Mean(), tasks_params=[OrderedSet([p1]), OrderedSet([p2])], shared_params=OrderedSet([p0]), retain_graph=False, @@ -37,9 +43,8 @@ def test_check_create_transform(): assert output_keys == set() -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) -def test_various_aggregators(aggregator: Aggregator): - """Tests that mtl_backward works for various aggregators.""" +def test_shape_is_correct(): + """Tests that mtl_backward works correctly.""" p0 = tensor_([1.0, 2.0], requires_grad=True) p1 = tensor_([1.0, 2.0], requires_grad=True) @@ -50,26 +55,25 @@ def test_various_aggregators(aggregator: Aggregator): y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f1, f2]) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) -@mark.parametrize("aggregator", [Mean(), UPGrad()]) @mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_shared_params", [True, False]) @mark.parametrize("manually_specify_tasks_params", [True, False]) @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( - aggregator: Aggregator, shape: tuple[int, int], manually_specify_shared_params: bool, manually_specify_tasks_params: bool, chunk_size: int | None, ): """ - Tests that the .grad value filled by mtl_backward is correct in a simple example of 
+ Tests that the .jac value filled by mtl_backward is correct in a simple example of matrix-vector product for three tasks whose loss are given by a simple inner product of the shared features with the task parameter. @@ -100,20 +104,17 @@ def test_value_is_correct( mtl_backward( losses=[y1, y2, y3], features=f, - aggregator=aggregator, tasks_params=tasks_params, shared_params=shared_params, parallel_chunk_size=chunk_size, ) - assert_close(p1.grad, f) - assert_close(p2.grad, f) - assert_close(p3.grad, f) + assert_grad_close(p1, f) + assert_grad_close(p2, f) + assert_grad_close(p3, f) expected_jacobian = torch.stack((p1, p2, p3)) @ J - expected_aggregation = aggregator(expected_jacobian) - - assert_close(p0.grad, expected_aggregation) + assert_jac_close(p0, expected_jacobian) def test_empty_tasks_fails(): @@ -125,7 +126,7 @@ def test_empty_tasks_fails(): f2 = (p0**2).sum() + p0.norm() with raises(ValueError): - mtl_backward(losses=[], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[], features=[f1, f2]) def test_single_task(): @@ -138,10 +139,10 @@ def test_single_task(): f2 = (p0**2).sum() + p0.norm() y1 = f1 * p1[0] + f2 * p1[1] - mtl_backward(losses=[y1], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1], features=[f1, f2]) - for p in [p0, p1]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + assert_has_grad(p1) def test_incoherent_task_number_fails(): @@ -163,7 +164,6 @@ def test_incoherent_task_number_fails(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[p1]], # Wrong shared_params=[p0], ) @@ -171,14 +171,13 @@ def test_incoherent_task_number_fails(): mtl_backward( losses=[y1], # Wrong features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) def test_empty_params(): - """Tests that mtl_backward does not fill the .grad values if no parameter is specified.""" + """Tests that mtl_backward does not fill the .jac/.grad values if no parameter is specified.""" p0 = tensor_([1.0, 2.0], requires_grad=True) p1 = tensor_([1.0, 2.0], requires_grad=True) @@ -192,13 +191,13 @@ def test_empty_params(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[], []], shared_params=[], ) - for p in [p0, p1, p2]: - assert p.grad is None + assert_has_no_jac(p0) + for p in [p1, p2]: + assert_has_no_grad(p) def test_multiple_params_per_task(): @@ -216,10 +215,11 @@ def test_multiple_params_per_task(): y1 = f1 * p1_a + (f2 * p1_b).sum() + (f1 * p1_c).sum() y2 = f1 * p2_a * (f2 * p2_b).sum() - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[f1, f2]) - for p in [p0, p1_a, p1_b, p1_c, p2_a, p2_b]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1_a, p1_b, p1_c, p2_a, p2_b]: + assert_has_grad(p) @mark.parametrize( @@ -249,19 +249,20 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): mtl_backward( losses=[y1, y2], features=features, - aggregator=UPGrad(), tasks_params=[[p1], [p2]], # Enforce differentiation w.r.t. params that haven't been used shared_params=shared_params, ) - for p in [*shared_params, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + for p in shared_params: + assert_has_jac(p) + for p in [p1, p2]: + assert_has_grad(p) def test_partial_params(): """ - Tests that mtl_backward fills the right .grad values when only a subset of the parameters are - specified as inputs. 
+ Tests that mtl_backward fills the right .jac/.grad values when only a subset of the parameters + are specified as inputs. """ p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -276,14 +277,13 @@ def test_partial_params(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=Mean(), tasks_params=[[p1], []], shared_params=[p0], ) - assert (p0.grad is not None) and (p0.shape == p0.grad.shape) - assert (p1.grad is not None) and (p1.shape == p1.grad.shape) - assert p2.grad is None + assert_has_jac(p0) + assert_has_grad(p1) + assert_has_no_grad(p2) def test_empty_features_fails(): @@ -299,7 +299,7 @@ def test_empty_features_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[]) @mark.parametrize( @@ -323,10 +323,11 @@ def test_various_single_features(shape: tuple[int, ...]): y1 = (f * p1[0]).sum() + (f * p1[1]).sum() y2 = (f * p2[0]).sum() * (f * p2[1]).sum() - mtl_backward(losses=[y1, y2], features=f, aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=f) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) @mark.parametrize( @@ -354,10 +355,11 @@ def test_various_feature_lists(shapes: list[tuple[int]]): y1 = sum([(f * p).sum() for f, p in zip(features, p1)]) y2 = (features[0] * p2).sum() - mtl_backward(losses=[y1, y2], features=features, aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=features) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) def test_non_scalar_loss_fails(): @@ -373,7 +375,7 @@ def test_non_scalar_loss_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[f1, f2]) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -392,12 +394,12 @@ def test_various_valid_chunk_sizes(chunk_size): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), parallel_chunk_size=chunk_size, ) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) @mark.parametrize("chunk_size", [0, -1]) @@ -417,15 +419,14 @@ def test_non_positive_chunk_size_fails(chunk_size: int): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), parallel_chunk_size=chunk_size, ) def test_shared_param_retaining_grad_fails(): """ - Tests that mtl_backward raises an error when some shared param in the computation graph of the - ``features`` parameter retains grad and vmap has to be used. + Tests that mtl_backward fails to fill a valid `.grad` when some shared param in the computation + graph of the ``features`` parameter retains grad and vmap has to be used. 
""" p0 = tensor_(1.0, requires_grad=True) @@ -438,14 +439,17 @@ def test_shared_param_retaining_grad_fails(): y1 = p1 * f y2 = p2 * f + # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor + mtl_backward( + losses=[y1, y2], + features=[f], + tasks_params=[[p1], [p2]], + shared_params=[a, p0], + ) + with raises(RuntimeError): - mtl_backward( - losses=[y1, y2], - features=[f], - aggregator=UPGrad(), - tasks_params=[[p1], [p2]], - shared_params=[a, p0], - ) + # Using such a BatchedTensor should result in an error + _ = -a.grad def test_shared_activation_retaining_grad_fails(): @@ -468,7 +472,6 @@ def test_shared_activation_retaining_grad_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) @@ -490,15 +493,14 @@ def test_tasks_params_overlap(): y1 = f * p1 * p12 y2 = f * p2 * p12 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f]) - assert_close(p2.grad, f * p12) - assert_close(p1.grad, f * p12) - assert_close(p12.grad, f * p1 + f * p2) + assert_grad_close(p2, f * p12) + assert_grad_close(p1, f * p12) + assert_grad_close(p12, f * p1 + f * p2) J = tensor_([[-8.0, 8.0], [-12.0, 12.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_tasks_params_are_the_same(): @@ -511,13 +513,12 @@ def test_tasks_params_are_the_same(): y1 = f * p1 y2 = f + p1 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f]) - assert_close(p1.grad, f + 1) + assert_grad_close(p1, f + 1) J = tensor_([[-2.0, 2.0], [-1.0, 1.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_task_params_is_subset_of_other_task_params(): @@ -534,14 +535,13 @@ def test_task_params_is_subset_of_other_task_params(): y1 = f * p1 y2 = y1 * p2 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator, retain_graph=True) + mtl_backward(losses=[y1, y2], features=[f], retain_graph=True) - assert_close(p2.grad, y1) - assert_close(p1.grad, p2 * f + f) + assert_grad_close(p2, y1) + assert_grad_close(p1, p2 * f + f) J = tensor_([[-2.0, 2.0], [-6.0, 6.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_shared_params_overlapping_with_tasks_params_fails(): @@ -562,7 +562,6 @@ def test_shared_params_overlapping_with_tasks_params_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), tasks_params=[[p1], [p0, p2]], # Problem: p0 is also shared shared_params=[p0], ) @@ -586,7 +585,6 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), ) @@ -610,7 +608,7 @@ def test_repeated_losses(): with raises(ValueError): losses = [y1, y1, y2] - mtl_backward(losses=losses, features=[f1, f2], aggregator=Sum(), retain_graph=True) + mtl_backward(losses=losses, features=[f1, f2], retain_graph=True) def test_repeated_features(): @@ -633,7 +631,7 @@ def test_repeated_features(): with raises(ValueError): features = [f1, f1, f2] - mtl_backward(losses=[y1, y2], features=features, aggregator=Sum()) + mtl_backward(losses=[y1, y2], features=features) def test_repeated_shared_params(): @@ -648,20 +646,20 @@ def test_repeated_shared_params(): p2 = tensor_([3.0, 4.0], requires_grad=True) f1 = tensor_([-1.0, 1.0]) @ p0 - f2 = (p0**2).sum() + p0.norm() + f2 = (p0**2).sum() y1 = f1 * p1[0] + f2 
* p1[1] y2 = f1 * p2[0] + f2 * p2[1] - expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] - expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] - expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + J0 = tensor_([[3.0, 9.0], [5.0, 19.0]]) + g1 = grad([y1], [p1], retain_graph=True)[0] + g2 = grad([y2], [p2], retain_graph=True)[0] shared_params = [p0, p0] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), shared_params=shared_params) + mtl_backward(losses=[y1, y2], features=[f1, f2], shared_params=shared_params) - assert_close(p0.grad, expected_grad_wrt_p0) - assert_close(p1.grad, expected_grad_wrt_p1) - assert_close(p2.grad, expected_grad_wrt_p2) + assert_jac_close(p0, J0) + assert_grad_close(p1, g1) + assert_grad_close(p2, g2) def test_repeated_task_params(): @@ -676,17 +674,17 @@ def test_repeated_task_params(): p2 = tensor_([3.0, 4.0], requires_grad=True) f1 = tensor_([-1.0, 1.0]) @ p0 - f2 = (p0**2).sum() + p0.norm() + f2 = (p0**2).sum() y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] - expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] - expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + J0 = tensor_([[3.0, 9.0], [5.0, 19.0]]) + g1 = grad([y1], [p1], retain_graph=True)[0] + g2 = grad([y2], [p2], retain_graph=True)[0] tasks_params = [[p1, p1], [p2]] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), tasks_params=tasks_params) + mtl_backward(losses=[y1, y2], features=[f1, f2], tasks_params=tasks_params) - assert_close(p0.grad, expected_grad_wrt_p0) - assert_close(p1.grad, expected_grad_wrt_p1) - assert_close(p2.grad, expected_grad_wrt_p2) + assert_jac_close(p0, J0) + assert_grad_close(p1, g1) + assert_grad_close(p2, g2)
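
The new test modules import `assert_has_jac`, `assert_has_no_jac`, `assert_jac_close`, `assert_grad_close`, `assert_has_grad`, and `assert_has_no_grad` from `unit.autojac._asserts`, which is not part of this diff. A minimal sketch of what such helpers could look like, inferred purely from how the tests above use them (the file location, exact checks, and in particular the shape assertions are assumptions):

```python
# Hypothetical sketch of tests/unit/autojac/_asserts.py, inferred from how the
# helpers are used in this diff. The real module may differ; the shape checks
# are assumptions.
from torch import Tensor
from torch.testing import assert_close


def assert_has_jac(t: Tensor) -> None:
    """Assert that a .jac field was populated, with one gradient row per objective."""
    jac = getattr(t, "jac", None)
    assert jac is not None
    assert jac.shape[1:] == t.shape


def assert_has_no_jac(t: Tensor) -> None:
    """Assert that no .jac field was populated."""
    assert getattr(t, "jac", None) is None


def assert_jac_close(t: Tensor, expected_jac: Tensor) -> None:
    """Assert that the populated .jac matches the expected Jacobian."""
    assert_close(t.jac, expected_jac)


def assert_has_grad(t: Tensor) -> None:
    """Assert that a .grad field of the right shape was populated."""
    assert t.grad is not None and t.grad.shape == t.shape


def assert_has_no_grad(t: Tensor) -> None:
    """Assert that no .grad field was populated."""
    assert t.grad is None


def assert_grad_close(t: Tensor, expected_grad: Tensor) -> None:
    """Assert that the populated .grad matches the expected gradient."""
    assert_close(t.grad, expected_grad)
```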
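
The new `tests/unit/autojac/test_jac_to_grad.py` pins down the contract of `torchjd.autojac.jac_to_grad` quite precisely: every tensor must require grad and carry a `.jac` field, all `.jac` fields must have the same number of rows, the Jacobians are aggregated into a single gradient vector that is split back into per-tensor `.grad` fields, the `.jac` fields are freed unless `retain_jac=True`, and an empty input is a no-op. A reference sketch of that contract, not the library's actual implementation, could look like this:

```python
# Hypothetical re-implementation of the behaviour exercised by test_jac_to_grad.py.
# This is a sketch for illustration only; the real torchjd.autojac.jac_to_grad may
# be implemented differently.
from collections.abc import Iterable

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator


def jac_to_grad_sketch(
    tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False
) -> None:
    tensors = list(tensors)
    if not tensors:
        return  # nothing to do, mirroring test_no_tensors

    jacs = []
    for t in tensors:
        jac = getattr(t, "jac", None)
        if jac is None:
            raise ValueError("Each tensor must have a populated .jac field.")
        if not t.requires_grad:
            raise ValueError("Each tensor must require grad.")
        jacs.append(jac.reshape(jac.shape[0], -1))  # one row per objective

    n_rows = {j.shape[0] for j in jacs}
    if len(n_rows) != 1:
        raise ValueError("All .jac fields must have the same number of rows.")

    # Aggregate the full Jacobian into a single gradient vector, then split it back.
    matrix = torch.cat(jacs, dim=1)
    vector = aggregator(matrix)
    offset = 0
    for t in tensors:
        numel = t.numel()
        t.grad = vector[offset : offset + numel].reshape(t.shape)
        offset += numel
        if not retain_jac:
            del t.jac
```

Flattening each `.jac` to `(m, numel)` before concatenation is what makes the per-column split in `test_various_aggregators` (`jac[:, 0]` for the scalar, `jac[:, 1:]` for the vector) line up with the aggregated `expected_grad[0]` and `expected_grad[1:]`.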
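
Several tests now compare against hard-coded Jacobians instead of recomputing them with `torch.autograd.grad` (e.g. `J1` and `J2` in `test_multiple_tensors` and `test_repeated_inputs`, `J0` in `test_repeated_shared_params` and `test_repeated_task_params`). Each row is the gradient of one scalar output with respect to the input, so they can be re-derived quickly; for example, for the graph of `test_multiple_tensors` (using `torch.tensor` directly rather than the test suite's `tensor_` helper):

```python
# Quick re-derivation of the hard-coded J1 and J2 used in test_multiple_tensors:
# each row of the Jacobian is the gradient of one scalar output w.r.t. the input.
import torch
from torch.autograd import grad

a1 = torch.tensor([1.0, 2.0], requires_grad=True)
a2 = torch.tensor([3.0, 4.0], requires_grad=True)

y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
y2 = (a1**2).sum() + a2.norm()

rows_1 = [grad(y, [a1], retain_graph=True)[0] for y in (y1, y2)]
rows_2 = [grad(y, [a2], retain_graph=True)[0] for y in (y1, y2)]

print(torch.stack(rows_1))  # [[-1., 1.], [2., 4.]]   -> J1
print(torch.stack(rows_2))  # [[1., 1.], [0.6, 0.8]]  -> J2
```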