Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ jobs:
matrix:
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
os: [ubuntu-latest, macOS-latest, windows-latest]
device: [cpu, mps]
exclude:
# Only test on MPS when the OS is macOS-latest.
- os: ubuntu-latest
device: mps
- os: windows-latest
device: mps

steps:
- uses: actions/checkout@v4
Expand All @@ -29,6 +36,8 @@ jobs:
- name: Install default (with full options) and test dependencies
run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test
- name: Run unit and doc tests with coverage report
env:
PYTEST_TORCH_DEVICE: ${{ matrix.device }}
run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
Expand Down
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ uv run pre-commit install
CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit
```

- If you work on a macOS device with the Metal Performance Shaders (MPS) backend, you can check
that the unit tests pass on it:
```bash
PYTEST_TORCH_DEVICE=mps uv run pytest tests/unit
```

- To check that the usage examples from docstrings and `.rst` files are correct, we test their
behavior in `tests/doc`. To run these tests, do:
```bash
Expand Down
12 changes: 11 additions & 1 deletion src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,17 @@ def forward(self, gramian: Tensor) -> Tensor:

@staticmethod
def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
try:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
except NotImplementedError:
# This will happen on MPS until they add support for aten::_linalg_eigh.eigenvalues
# See status in https://github.com/pytorch/pytorch/issues/141287
# In this case, perform the eigh on CPU and move back to the original device
original_device = M.device
lambda_, V = torch.linalg.eigh(M.cpu(), UPLO="U")
lambda_ = lambda_.to(device=original_device)
V = V.to(device=original_device)

tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)

Expand Down
23 changes: 21 additions & 2 deletions tests/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,34 @@

import torch

_POSSIBLE_TEST_DEVICES = {"cpu", "cuda:0", "mps"}

try:
_device_str = os.environ["PYTEST_TORCH_DEVICE"]
except KeyError:
_device_str = "cpu" # Default to cpu if environment variable not set

if _device_str != "cuda:0" and _device_str != "cpu":
raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}")
if _device_str not in _POSSIBLE_TEST_DEVICES:
raise ValueError(
f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n"
f"Possible devices: {_POSSIBLE_TEST_DEVICES}"
)

if _device_str == "cuda:0" and not torch.cuda.is_available():
raise ValueError('Requested device "cuda:0" but cuda is not available.')

if _device_str == "mps":
# Check that MPS is available (following https://docs.pytorch.org/docs/stable/notes/mps.html)
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
raise ValueError(
"MPS not available because the current PyTorch install was not built with MPS "
"enabled."
)
else:
raise ValueError(
"MPS not available because the current MacOS version is not 12.3+ and/or you do not"
" have an MPS-enabled device on this machine."
)

DEVICE = torch.device(_device_str)
11 changes: 10 additions & 1 deletion tests/unit/aggregation/_matrix_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,14 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
# project A onto the orthogonal complement of Q
A_proj = A - Q @ (Q.T @ A)

Q_prime, _ = torch.linalg.qr(A_proj)
try:
Q_prime, _ = torch.linalg.qr(A_proj)
except NotImplementedError:
# This will happen on MPS until they add support for aten::linalg_qr.out
# See status in https://github.com/pytorch/pytorch/issues/141287
# In this case, perform the qr on CPU and move back to the original device
original_device = A_proj.device
Q_prime, _ = torch.linalg.qr(A_proj.to(device="cpu"))
Q_prime = Q_prime.to(device=original_device)

return Q_prime
2 changes: 1 addition & 1 deletion tests/utils/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@contextmanager
def fork_rng(seed: int = 0) -> Generator[Any, None, None]:
    """
    Fork the global RNG state, seed it with ``seed``, and restore the previous state on exit.

    :param seed: Seed passed to ``torch.manual_seed`` inside the forked context.
    """
    # fork_rng always forks the CPU RNG; only non-CPU devices (cuda, mps, ...) must be listed
    # explicitly so that their generator state is forked and restored as well.
    devices = [] if DEVICE.type == "cpu" else [DEVICE]
    with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx:
        torch.manual_seed(seed)
        yield ctx
Loading