diff --git a/CHANGELOG.md b/CHANGELOG.md index c27d3fed4..39a62ffcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 changes that do not affect the user. ## [Unreleased] +## [0.7.1] - 2025-06-12 +### Added +- Seamless sparse-matrix support (SpMM and adjacency handling) for TorchJD, as SparseMatMul is currently not compatible with Jacobian Descent due to torch.vmap() dependencies. + ## [0.7.0] - 2025-06-04 diff --git a/docs/source/docs/sparse.rst b/docs/source/docs/sparse.rst new file mode 100644 index 000000000..64938e77e --- /dev/null +++ b/docs/source/docs/sparse.rst @@ -0,0 +1,6 @@ +:hide-toc: + +sparse.sparse_mm +================ + +.. autofunction:: torchjd.sparse.sparse_mm diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 84d42a462..580cdc529 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -28,6 +28,7 @@ This section contains some usage examples for TorchJD. basic_usage.rst iwrm.rst mtl.rst + sparse.rst rnn.rst monitoring.rst lightning_integration.rst diff --git a/docs/source/examples/sparse.rst b/docs/source/examples/sparse.rst new file mode 100644 index 000000000..53d47d4e7 --- /dev/null +++ b/docs/source/examples/sparse.rst @@ -0,0 +1,34 @@ +Quick example +============================== + +TorchJD now offers helpers that make working with sparse adjacency matrices +transparent. +The key entry-point is :pyfunc:`torchjd.sparse.sparse_mm`, +a vmap-aware autograd function that replaces the usual +``torch.sparse.mm`` inside Jacobian Descent pipelines. + +The snippet below shows how you can mix a sparse objective (involving +``A @ p``) with a dense one, then aggregate their Jacobians using +:pyclass:`torchjd.aggregation.UPGrad`. + +.. doctest:: + + >>> import torch + >>> from torchjd import backward + >>> from torchjd.sparse import sparse_mm # patches torch automatically + >>> from torchjd.aggregation import UPGrad + >>> + >>> # 2×2 off-diagonal adjacency matrix + >>> A = torch.sparse_coo_tensor( + ... indices=[[0, 1], [1, 0]], + ... values=[1.0, 1.0], + ... size=(2, 2) + ... ).coalesce() + >>> + >>> p = torch.tensor([1.0, 2.0], requires_grad=True) + >>> + >>> y1 = sparse_mm(A, p.unsqueeze(1)).sum() # sparse term + >>> y2 = (p ** 2).sum() # dense term + >>> backward([y1, y2], UPGrad()) # Jacobian Descent step + >>> p.grad # doctest:+ELLIPSIS + tensor([1.0000, 1.6667]) diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py deleted file mode 100644 index 0491e90a0..000000000 --- a/src/torchjd/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -This package enable Jacobian descent, through the `backward` and `mtl_backward` functions, which -are meant to replace the call to `torch.backward` or `loss.backward` in gradient descent. To combine -the information of the Jacobian, an aggregator from the `aggregation` package has to be used. -""" - -from ._autojac import backward, mtl_backward diff --git a/src/torchjd/_autojac/__init__.py b/src/torchjd/_autojac/__init__.py index e2175c165..be1b1d9d7 100644 --- a/src/torchjd/_autojac/__init__.py +++ b/src/torchjd/_autojac/__init__.py @@ -1,2 +1,4 @@ +from torchjd.sparse import sparse_mm + from ._backward import backward from ._mtl_backward import mtl_backward diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..f1648c928 --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1,19 @@ +"""Public interface for TorchJD sparse helpers. + +Importing ``torchjd`` automatically activates seamless sparse support, +unless the environment variable ``TORCHJD_DISABLE_SPARSE`` is set to +``"1"`` **before** the first TorchJD import. +""" + +from __future__ import annotations + +import os + +from ._autograd import sparse_mm # re-export +from ._patch import enable_seamless_sparse + +__all__ = ["sparse_mm"] + +# feature flag +if os.getenv("TORCHJD_DISABLE_SPARSE", "0") != "1": + enable_seamless_sparse() diff --git a/src/torchjd/sparse/_autograd.py b/src/torchjd/sparse/_autograd.py new file mode 100644 index 000000000..76d22108f --- /dev/null +++ b/src/torchjd/sparse/_autograd.py @@ -0,0 +1,62 @@ +"""Vmap-compatible sparse @ dense for TorchJD.""" + +from __future__ import annotations + +from typing import Tuple + +import torch + +from ._registry import to_coalesced_coo + +_orig_sparse_mm = getattr(torch.sparse, "_orig_mm", torch.sparse.mm) + + +class _SparseMatMul(torch.autograd.Function): + """y = A @ X where **A** is sparse and **X** is dense.""" + + @staticmethod + def forward(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: # noqa: D401 + A = to_coalesced_coo(A_like) + + if X.dim() == 3: # (B, N, d) + B, N, d = X.shape + X2d = X.reshape(B * N, d).view(N, B * d) + Y2d = _orig_sparse_mm(A, X2d) # pragma: no cover + return Y2d.view(N, B, d).permute(1, 0, 2) # pragma: no cover + + return _orig_sparse_mm(A, X) + + @staticmethod + def setup_context(ctx, inputs, output) -> None: # noqa: D401 + A_like, _ = inputs + ctx.save_for_backward(to_coalesced_coo(A_like)) + + @staticmethod + def backward(ctx, dY: torch.Tensor) -> Tuple[None, torch.Tensor]: + (A,) = ctx.saved_tensors + AT = A.transpose(0, 1) + + if dY.dim() == 3: # batched + B, N, d = dY.shape + dY2d = dY.permute(1, 0, 2).reshape(N, B * d) + dX2d = _orig_sparse_mm(AT, dY2d) + dX = dX2d.view(N, B, d).permute(1, 0, 2) + return None, dX + + return None, _orig_sparse_mm(AT, dY) # pragma: no cover + + @staticmethod + def vmap(info, in_dims, A_unbatched, X_batched): # noqa: D401 + A = A_unbatched # shared + X = X_batched # (B, N, d) + + B, N, d = X.shape + X2d = X.reshape(B * N, d).view(N, B * d) + Y2d = _orig_sparse_mm(A, X2d) + Y = Y2d.view(N, B, d).permute(1, 0, 2) + return Y, 0 # output & out-dims + + +def sparse_mm(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: + """Return ``A @ X`` through the vmap-safe sparse Function.""" + return _SparseMatMul.apply(A_like, X) diff --git a/src/torchjd/sparse/_patch.py b/src/torchjd/sparse/_patch.py new file mode 100644 index 000000000..6993dac11 --- /dev/null +++ b/src/torchjd/sparse/_patch.py @@ -0,0 +1,81 @@ +"""Monkey-patch hooks that route sparse ops through TorchJD wrappers. + +This module is imported from ``torchjd.sparse`` at import-time. +Patch execution is *idempotent* – calling :pyfunc:`enable_seamless_sparse` +multiple times is safe. +""" + +from __future__ import annotations + +import warnings +from importlib import import_module +from types import MethodType +from typing import Callable + +import torch + +from ._autograd import sparse_mm + +# The wheel might exist yet be ABI-incompatible with the current +# PyTorch, which raises *OSError* at import-time. + +try: # pragma: no cover + torch_sparse = import_module("torch_sparse") # type: ignore +except (ModuleNotFoundError, OSError): + torch_sparse = None + + +# Helpers +def _wrap_mm(orig_fn: Callable, wrapper: Callable) -> Callable: + """Return a patched ``torch.sparse.mm`` that defers to *wrapper*.""" + + def patched(A, X): # noqa: D401 + if isinstance(A, torch.Tensor) and A.is_sparse and X.dim() >= 2: + return wrapper(A, X) + return orig_fn(A, X) + + return patched + + +def _wrap_tensor_matmul(orig_fn: Callable) -> Callable: + def patched(self, other): # noqa: D401 + if self.is_sparse and isinstance(other, torch.Tensor) and other.dim() >= 2: + return sparse_mm(self, other) + return orig_fn(self, other) + + return patched + + +# Public API +def enable_seamless_sparse() -> None: + """Patch common call-sites so users need *no* explicit imports.""" + # torch.sparse.mm + if getattr(torch.sparse, "_orig_mm", None) is None: + torch.sparse._orig_mm = torch.sparse.mm # type: ignore[attr-defined] + torch.sparse.mm = _wrap_mm(torch.sparse._orig_mm, sparse_mm) # type: ignore[attr-defined] + + # tensor @ dense + if getattr(torch.Tensor, "_orig_matmul", None) is None: + torch.Tensor._orig_matmul = torch.Tensor.__matmul__ # type: ignore[attr-defined] # noqa: E501 + torch.Tensor.__matmul__ = _wrap_tensor_matmul( + torch.Tensor._orig_matmul # type: ignore[attr-defined] + ) # type: ignore[attr-defined] + + # torch_sparse (optional) + if torch_sparse is None: + warnings.warn( + "torch_sparse not found: SpSpMM will use slow fallback.", + RuntimeWarning, + stacklevel=2, + ) # pragma: no cover + return + + if not hasattr(torch_sparse.SparseTensor, "_orig_matmul"): + + def _sparse_tensor_matmul(self, dense): # noqa: D401 + return sparse_mm(self, dense) + + torch_sparse.SparseTensor._orig_matmul = torch_sparse.SparseTensor.matmul # type: ignore[attr-defined] # noqa: E501 + torch_sparse.SparseTensor.matmul = MethodType( # type: ignore[attr-defined] + _sparse_tensor_matmul, torch_sparse.SparseTensor + ) diff --git a/src/torchjd/sparse/_registry.py b/src/torchjd/sparse/_registry.py new file mode 100644 index 000000000..e498cc176 --- /dev/null +++ b/src/torchjd/sparse/_registry.py @@ -0,0 +1,11 @@ +"""Central registry of sparse conversions and helpers. + +For now this file simply re-exports :func:`to_coalesced_coo`, but keeps +the door open for future registration logic. +""" + +from __future__ import annotations + +from ._utils import to_coalesced_coo + +__all__ = ["to_coalesced_coo"] diff --git a/src/torchjd/sparse/_utils.py b/src/torchjd/sparse/_utils.py new file mode 100644 index 000000000..795bc7fc6 --- /dev/null +++ b/src/torchjd/sparse/_utils.py @@ -0,0 +1,37 @@ +"""Utility helpers shared by the sparse sub-package.""" + +from __future__ import annotations + +from typing import Any + +import torch + +try: + import importlib + + torch_sparse = importlib.import_module("torch_sparse") # type: ignore +except (ModuleNotFoundError, OSError): # pragma: no cover + torch_sparse = None + + +def to_coalesced_coo(x: Any) -> torch.Tensor: + """Convert *x* to a **coalesced** PyTorch sparse COO tensor.""" + + if isinstance(x, torch.Tensor) and x.is_sparse: + return x.coalesce() + + if torch_sparse and isinstance(x, torch_sparse.SparseTensor): # type: ignore + return x.to_torch_sparse_coo_tensor().coalesce() + + try: + import scipy.sparse as sp # pragma: no cover + + if isinstance(x, sp.spmatrix): + coo = x.tocoo() + indices = torch.as_tensor([coo.row, coo.col], dtype=torch.long) + values = torch.as_tensor(coo.data, dtype=torch.get_default_dtype()) + return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce() + except ModuleNotFoundError: # pragma: no cover + pass + + raise TypeError(f"Unsupported sparse type: {type(x)}") # pragma: no cover diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 53b735c58..50ba4931d 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -9,7 +9,7 @@ def test_backward(): import torch - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad param = torch.tensor([1.0, 2.0], requires_grad=True) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 0f8ac3567..6753a638f 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -27,7 +27,7 @@ def test_basic_usage(): loss2 = loss_fn(output[:, 1], target2) optimizer.zero_grad() - torchjd.backward([loss1, loss2], aggregator) + torchjd._autojac.backward([loss1, loss2], aggregator) optimizer.step() @@ -58,7 +58,7 @@ def test_iwrm_with_ssjd(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad X = torch.randn(8, 16, 10) @@ -87,7 +87,7 @@ def test_mtl(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) @@ -136,7 +136,7 @@ def test_lightning_integration(): from torch.optim import Adam from torch.utils.data import DataLoader, TensorDataset - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad class Model(LightningModule): @@ -190,7 +190,7 @@ def test_rnn(): from torch.nn import RNN from torch.optim import SGD - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) @@ -215,7 +215,7 @@ def test_monitoring(): from torch.nn.functional import cosine_similarity from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad def print_weights(_, __, weights: torch.Tensor) -> None: @@ -267,7 +267,7 @@ def test_amp(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index b9f0cd6cc..2cd519683 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -3,7 +3,7 @@ from torch.autograd import grad from torch.testing import assert_close -from torchjd import backward +from torchjd._autojac import backward from torchjd._autojac._backward import _create_transform from torchjd._autojac._transform import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index e952d04bd..e924eb4f7 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -3,7 +3,7 @@ from torch.autograd import grad from torch.testing import assert_close -from torchjd import mtl_backward +from torchjd._autojac import mtl_backward from torchjd._autojac._mtl_backward import _create_transform from torchjd._autojac._transform import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad diff --git a/tests/unit/sparse/test_mm.py b/tests/unit/sparse/test_mm.py new file mode 100644 index 000000000..c6143e3d8 --- /dev/null +++ b/tests/unit/sparse/test_mm.py @@ -0,0 +1,64 @@ +import pytest +import torch + +from torchjd.sparse import sparse_mm +from torchjd.sparse._utils import to_coalesced_coo + +try: + import importlib + import types + + torch_sparse = importlib.import_module("torch_sparse") # noqa: E402 + HAVE_TORCH_SPARSE = isinstance(torch_sparse, types.ModuleType) +except (ModuleNotFoundError, OSError): + HAVE_TORCH_SPARSE = False + + +try: + import scipy.sparse as sp + + HAVE_SCIPY = True +except ModuleNotFoundError: + HAVE_SCIPY = False + + +def _dense_graph(): + idx = torch.tensor([[0, 1], [1, 0]]) + return torch.sparse_coo_tensor(idx, torch.ones(2)).coalesce() + + +def _batched_features(device): + # shape (B, N, d) with B=3, N=2, d=4 + return torch.randn(3, 2, 4, device=device, dtype=torch.float32) + + +@pytest.mark.parametrize("device", ["cpu"]) +def test_vmap_branch(device): + A = _dense_graph().to(device) + X = _batched_features(device) + Y = sparse_mm(A, X) # calls vmap-aware branch + assert Y.shape == X.shape + + +@pytest.mark.skipif(not HAVE_SCIPY, reason="SciPy not installed") +def test_scipy_path(): + import numpy as np + import scipy.sparse as sp + + coo = sp.coo_matrix(([1, 1], ([0, 1], [1, 0])), shape=(2, 2)) + A = to_coalesced_coo(coo) + assert A.is_sparse and A.is_coalesced() + + +@pytest.mark.skipif(not HAVE_TORCH_SPARSE, reason="torch_sparse not installed") +def test_torch_sparse_path(): + import torch_sparse as tsp + + row = torch.tensor([0, 1]) + col = torch.tensor([1, 0]) + val = torch.ones(2) + A_ts = tsp.SparseTensor(row=row, col=col, value=val, sparse_sizes=(2, 2)) + A = to_coalesced_coo(A_ts) + X = torch.randn(2, 3) + Y = sparse_mm(A, X) + assert Y.shape == (2, 3) diff --git a/tests/unit/sparse/test_mm_3d.py b/tests/unit/sparse/test_mm_3d.py new file mode 100644 index 000000000..c3c73a38e --- /dev/null +++ b/tests/unit/sparse/test_mm_3d.py @@ -0,0 +1,18 @@ +import torch + +from torchjd.sparse import sparse_mm + + +def test_forward_backward_3d(): + # sparse 2×2 matrix + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + + # 3-D dense tensor (B=3, N=2, d=4) + X = torch.randn(3, 2, 4, requires_grad=True) + + Y = sparse_mm(A, X) # exercises 3-D forward branch + loss = Y.sum() + loss.backward() # exercises 3-D backward branch + + # Gradient should be ones because A.T @ 1 = [1,1] → broadcast + assert torch.allclose(X.grad, torch.ones_like(X), atol=1e-6) diff --git a/tests/unit/sparse/test_mm_sequential.py b/tests/unit/sparse/test_mm_sequential.py new file mode 100644 index 000000000..e78c779d3 --- /dev/null +++ b/tests/unit/sparse/test_mm_sequential.py @@ -0,0 +1,21 @@ +import torch + +from torchjd._autojac import backward +from torchjd.aggregation import UPGrad +from torchjd.sparse import sparse_mm + + +def test_sequential_backward(): + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + p = torch.tensor([1.0, 2.0], requires_grad=True) + + # Make y1 require A@p, y2 a simple L2 term + y1 = sparse_mm(A, p.unsqueeze(1)).sum() # shape (2,1) → scalar + y2 = (p**2).sum() + + # Force sequential JD (no vmap) to touch the else-branch in backward() + backward([y1, y2], UPGrad(), parallel_chunk_size=1) + + # Gradient shape & basic sanity check + assert p.grad.shape == p.shape + assert torch.isfinite(p.grad).all() diff --git a/tests/unit/sparse/test_mm_single.py b/tests/unit/sparse/test_mm_single.py new file mode 100644 index 000000000..79bd9e577 --- /dev/null +++ b/tests/unit/sparse/test_mm_single.py @@ -0,0 +1,13 @@ +import torch + +from torchjd.sparse import sparse_mm + + +def test_single_forward_backward(): + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + X = torch.randn(2, 5, requires_grad=True) + Y = sparse_mm(A, X) # (2,5) + loss = Y.sum() + loss.backward() + # gradient should equal A.T @ 1 = [1,1] + assert torch.allclose(X.grad, torch.ones_like(X)) diff --git a/tests/unit/sparse/test_mm_vmap.py b/tests/unit/sparse/test_mm_vmap.py new file mode 100644 index 000000000..f57e52b91 --- /dev/null +++ b/tests/unit/sparse/test_mm_vmap.py @@ -0,0 +1,25 @@ +import torch +from torch.func import vmap + +from torchjd.sparse import sparse_mm + + +def test_batched_vmap_forward_backward(): + """ + Touch the custom vmap rule in _SparseMatMul to push per-file coverage + above the 90 % guideline. + """ + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + B, N, d = 4, 2, 3 + X = torch.randn(B, N, d, requires_grad=True) + + # vmap over the first dim (B) so SparseMatMul.vmap executes + def _single(inp): + return sparse_mm(A, inp).sum() + + loss = vmap(_single)(X).sum() + loss.backward() + + # Analytic gradient: A.T @ 1 = [1,1] broadcast to (B,N,d) + expected = torch.ones_like(X) + assert torch.allclose(X.grad, expected, atol=1e-6) diff --git a/tests/unit/sparse/test_patch.py b/tests/unit/sparse/test_patch.py new file mode 100644 index 000000000..0ee6c0269 --- /dev/null +++ b/tests/unit/sparse/test_patch.py @@ -0,0 +1,12 @@ +import torch + +from torchjd.sparse._patch import enable_seamless_sparse + + +def test_monkey_patch_matmul(): + enable_seamless_sparse() # idempotent + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + X = torch.randn(2, 3) + Y1 = A @ X # should hit sparse_mm via patched __matmul__ + Y2 = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) # placeholder + assert torch.allclose(Y1.sum(), (A.to_dense() @ X).sum()) diff --git a/tests/unit/sparse/test_patch_idempotent.py b/tests/unit/sparse/test_patch_idempotent.py new file mode 100644 index 000000000..32b0b36c8 --- /dev/null +++ b/tests/unit/sparse/test_patch_idempotent.py @@ -0,0 +1,6 @@ +from torchjd.sparse._patch import enable_seamless_sparse + + +def test_enable_patch_idempotent(): + enable_seamless_sparse() # first call patches + enable_seamless_sparse() # second call should be a no-op diff --git a/tests/unit/sparse/test_patch_import.py b/tests/unit/sparse/test_patch_import.py new file mode 100644 index 000000000..2b687d311 --- /dev/null +++ b/tests/unit/sparse/test_patch_import.py @@ -0,0 +1,46 @@ +import importlib +import sys +import types +from contextlib import contextmanager + + +@contextmanager +def fake_torch_sparse(): + """ + Context manager that injects a *minimal* torch_sparse stub. + The Dummy.SparseTensor *must* expose a ``matmul`` attribute because + enable_seamless_sparse() tries to save and patch it. + """ + mod = types.ModuleType("torch_sparse") + + class Dummy: # noqa: D401 + # placeholder matmul so _patch can grab the attribute + def matmul(self, dense): + raise NotImplementedError + + mod.SparseTensor = Dummy # type: ignore + sys.modules["torch_sparse"] = mod + try: + yield + finally: + sys.modules.pop("torch_sparse", None) + + +def test_patch_without_torch_sparse(monkeypatch): + monkeypatch.setitem(sys.modules, "torch_sparse", None) + from importlib import reload + + import torchjd.sparse._patch as p + + reload(p) # re-import to trigger patch + assert p.torch_sparse is None # slow fallback branch hit + + +def test_patch_with_dummy_torch_sparse(monkeypatch): + with fake_torch_sparse(): + from importlib import reload + + import torchjd.sparse._patch as p + + reload(p) + assert p.torch_sparse is not None # optional branch hit diff --git a/tests/unit/sparse/test_patch_torch_sparse_branch.py b/tests/unit/sparse/test_patch_torch_sparse_branch.py new file mode 100644 index 000000000..5687b0ba5 --- /dev/null +++ b/tests/unit/sparse/test_patch_torch_sparse_branch.py @@ -0,0 +1,48 @@ +import importlib +import sys +import types +from importlib import reload + + +def _make_dummy_torch_sparse(): + """ + Return a minimal torch_sparse stub: + + * SparseTensor.matmul – so _patch can save & wrap it. + * SparseTensor.to_torch_sparse_coo_tensor – so _utils branch works. + """ + dummy_mod = types.ModuleType("torch_sparse") + + class DummyTensor: # noqa: D401 + def matmul(self, dense): + raise NotImplementedError + + def to_torch_sparse_coo_tensor(self): + import torch + + return torch.sparse_coo_tensor([[0], [0]], [1.0], (1, 1)) + + dummy_mod.SparseTensor = DummyTensor # type: ignore[attr-defined] + return dummy_mod + + +def test_full_torch_sparse_branch(monkeypatch): + # Inject fresh stub + monkeypatch.setitem(sys.modules, "torch_sparse", _make_dummy_torch_sparse()) + + # Force the patch module to re-evaluate from scratch + # Remove earlier sentinel attributes so enable_seamless_sparse() re-patches + import torch + + import torchjd.sparse._patch as p # noqa: E402 + + for attr in ("_orig_mm",): + if hasattr(torch.sparse, attr): + delattr(torch.sparse, attr) # type: ignore[attr-defined] + + # Run patch + reload(p) + p.enable_seamless_sparse() + + # Optional branch should have set _orig_matmul + assert hasattr(p.torch_sparse.SparseTensor, "_orig_matmul") diff --git a/tests/unit/sparse/test_patch_warn_branch.py b/tests/unit/sparse/test_patch_warn_branch.py new file mode 100644 index 000000000..303f66abb --- /dev/null +++ b/tests/unit/sparse/test_patch_warn_branch.py @@ -0,0 +1,28 @@ +""" +Covers the branch in _patch.enable_seamless_sparse() that emits a warning +when *no* ``torch_sparse`` package is available. +""" + +import importlib +import sys +import types +import warnings + +import torch + + +def test_warn_branch(monkeypatch): + monkeypatch.setitem(sys.modules, "torch_sparse", None) + + if hasattr(torch.sparse, "_orig_mm"): + delattr(torch.sparse, "_orig_mm") # type: ignore[attr-defined] + + import torchjd.sparse._patch as p # noqa: E402 + + p = importlib.reload(p) + + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + p.enable_seamless_sparse() # <- emits RuntimeWarning branch + + assert any("SpSpMM will use slow fallback" in str(w.message) for w in rec) diff --git a/tests/unit/sparse/test_sparse_mm_wrapper.py b/tests/unit/sparse/test_sparse_mm_wrapper.py new file mode 100644 index 000000000..18a938735 --- /dev/null +++ b/tests/unit/sparse/test_sparse_mm_wrapper.py @@ -0,0 +1,13 @@ +import torch + +from torchjd.sparse._patch import enable_seamless_sparse + + +def test_torch_sparse_mm_wrapper(): + enable_seamless_sparse() # idempotent + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + X = torch.randn(2, 3) + + out = torch.sparse.mm(A, X) # routed through wrapper + ref = A.to_dense() @ X + assert torch.allclose(out, ref, atol=1e-6) diff --git a/tests/unit/sparse/test_utils_scipy.py b/tests/unit/sparse/test_utils_scipy.py new file mode 100644 index 000000000..d3bd96855 --- /dev/null +++ b/tests/unit/sparse/test_utils_scipy.py @@ -0,0 +1,16 @@ +import importlib + +import numpy as np +import pytest + +scipy = pytest.importorskip("scipy") # skip if SciPy not available +from torchjd.sparse._utils import to_coalesced_coo + + +def test_to_coalesced_coo_from_scipy(): + sp = importlib.import_module("scipy.sparse") + # 2×2 off-diagonal ones + coo = sp.coo_matrix((np.ones(2), ([0, 1], [1, 0])), shape=(2, 2)) + tsr = to_coalesced_coo(coo) # exercises SciPy branch + dense = tsr.to_dense() + assert dense[0, 1] == dense[1, 0] == 1 and dense.sum() == 2 diff --git a/tests/unit/sparse/test_utils_torch_sparse.py b/tests/unit/sparse/test_utils_torch_sparse.py new file mode 100644 index 000000000..3e4fa043a --- /dev/null +++ b/tests/unit/sparse/test_utils_torch_sparse.py @@ -0,0 +1,30 @@ +import importlib +import sys +import types + +import torch + + +def test_to_coalesced_coo_torch_sparse(monkeypatch): + dummy = types.ModuleType("torch_sparse") + + class DummyTensor: # noqa: D401 + def __init__(self): + self.row = torch.tensor([0]) + self.col = torch.tensor([0]) + self.value = torch.tensor([1.0]) + + def to_torch_sparse_coo_tensor(self): + return torch.sparse_coo_tensor(torch.stack([self.row, self.col]), self.value, (1, 1)) + + def matmul(self, other): + raise NotImplementedError + + dummy.SparseTensor = DummyTensor # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "torch_sparse", dummy) + + utils = importlib.reload(importlib.import_module("torchjd.sparse._utils")) + to_coalesced_coo = utils.to_coalesced_coo + + tsr = to_coalesced_coo(DummyTensor()) + assert tsr.is_sparse and tsr._nnz() == 1