diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 570f6dcc3..e7057d6b8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,6 +19,13 @@ jobs: matrix: python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] os: [ubuntu-latest, macOS-latest, windows-latest] + device: [cpu, mps] + exclude: + # Only test on MPS when the OS is macOS-latest. + - os: ubuntu-latest + device: mps + - os: windows-latest + device: mps steps: - uses: actions/checkout@v4 @@ -29,6 +36,8 @@ jobs: - name: Install default (with full options) and test dependencies run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report + env: + PYTEST_TORCH_DEVICE: ${{ matrix.device }} run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 267b69b3d..f3623650b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -87,6 +87,12 @@ uv run pre-commit install CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit ``` + - If you work on a macOS device supporting the Metal Performance Shaders (MPS) backend, you can + check that the unit tests pass on it: + ```bash + PYTEST_TORCH_DEVICE=mps uv run pytest tests/unit + ``` + - To check that the usage examples from docstrings and `.rst` files are correct, we test their behavior in `tests/doc` 
To run these tests, do: ```bash diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 3f0b71199..7fe291dca 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -82,7 +82,17 @@ def forward(self, gramian: Tensor) -> Tensor: @staticmethod def _compute_balance_transformation(M: Tensor) -> Tensor: - lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig + try: + lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig + except NotImplementedError: + # This will happen on MPS until they add support for aten::_linalg_eigh.eigenvalues + # See status in https://github.com/pytorch/pytorch/issues/141287 + # In this case, perform the eigh on CPU and move back to the original device + original_device = M.device + lambda_, V = torch.linalg.eigh(M.cpu(), UPLO="U") + lambda_ = lambda_.to(device=original_device) + V = V.to(device=original_device) + tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) diff --git a/tests/device.py b/tests/device.py index 7be2c75c2..da4f20d90 100644 --- a/tests/device.py +++ b/tests/device.py @@ -2,15 +2,34 @@ import torch +_POSSIBLE_TEST_DEVICES = {"cpu", "cuda:0", "mps"} + try: _device_str = os.environ["PYTEST_TORCH_DEVICE"] except KeyError: _device_str = "cpu" # Default to cpu if environment variable not set -if _device_str != "cuda:0" and _device_str != "cpu": - raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}") +if _device_str not in _POSSIBLE_TEST_DEVICES: + raise ValueError( + f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n" + f"Possible devices: {_POSSIBLE_TEST_DEVICES}" + ) if _device_str == "cuda:0" and not torch.cuda.is_available(): raise ValueError('Requested device "cuda:0" but cuda is not available.') +if _device_str == "mps": + # Check that MPS is available (following 
https://docs.pytorch.org/docs/stable/notes/mps.html) + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + raise ValueError( + "MPS not available because the current PyTorch install was not built with MPS " + "enabled." + ) + else: + raise ValueError( + "MPS not available because the current MacOS version is not 12.3+ and/or you do not" + " have an MPS-enabled device on this machine." + ) + DEVICE = torch.device(_device_str) diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 137ca2e69..40db71a81 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -167,5 +167,14 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = # project A onto the orthogonal complement of Q A_proj = A - Q @ (Q.T @ A) - Q_prime, _ = torch.linalg.qr(A_proj) + try: + Q_prime, _ = torch.linalg.qr(A_proj) + except NotImplementedError: + # This will happen on MPS until they add support for aten::linalg_qr.out + # See status in https://github.com/pytorch/pytorch/issues/141287 + # In this case, perform the qr on CPU and move back to the original device + original_device = A_proj.device + Q_prime, _ = torch.linalg.qr(A_proj.to(device="cpu")) + Q_prime = Q_prime.to(device=original_device) + return Q_prime diff --git a/tests/utils/contexts.py b/tests/utils/contexts.py index ef4c0ecf1..8971ff587 100644 --- a/tests/utils/contexts.py +++ b/tests/utils/contexts.py @@ -10,7 +10,7 @@ @contextmanager def fork_rng(seed: int = 0) -> Generator[Any, None, None]: - devices = [DEVICE] if DEVICE.type == "cuda" else [] + devices = [] if DEVICE.type == "cpu" else [DEVICE] with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx: torch.manual_seed(seed) yield ctx