diff --git a/glass/fields.py b/glass/fields.py index 2cd26174..9cb9fa81 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -213,12 +213,12 @@ def cls2cov( for j in range(nf): begin, end = end, end + j + 1 for i, cl in enumerate(cls[begin:end][: nc + 1]): - if i == 0 and np.any(xp.less(cl, 0)): + if i == 0 and xp.any(xp.less(cl, 0)): msg = "negative values in cl" raise ValueError(msg) - n = cl.size - cov[:n, i] = cl - cov[n:, i] = 0 + n = cl.shape[0] + cov = xpx.at(cov)[:n, i].set(cl) + cov = xpx.at(cov)[n:, i].set(0.0) cov /= 2 yield cov diff --git a/tests/benchmarks/test_fields.py b/tests/benchmarks/test_fields.py index 43b7513c..b75a2fad 100644 --- a/tests/benchmarks/test_fields.py +++ b/tests/benchmarks/test_fields.py @@ -127,14 +127,11 @@ def test_cls2cov( benchmark: BenchmarkFixture, compare: Compare, generator_consumer: GeneratorConsumer, - urngb: UnifiedGenerator, xpb: ModuleType, ) -> None: """Benchmarks for glass.cls2cov.""" - # check output values and shape - nl, nf, nc = 3, 2, 2 - array_in = [urngb.random(3) for _ in range(1_000)] + array_in = [xpb.arange(i + 1.0, i + 4.0) for i in range(1_000)] def function_to_benchmark() -> list[Any]: generator = glass.cls2cov( @@ -151,16 +148,8 @@ def function_to_benchmark() -> list[Any]: assert cov.shape == (nl, nc + 1) assert cov.dtype == xpb.float64 - compare.assert_allclose( - cov[:, 0], - xpb.asarray([0.348684, 0.047089, 0.487811]), - atol=1e-6, - ) - compare.assert_allclose( - cov[:, 1], - [0.38057, 0.393032, 0.064057], - atol=1e-6, - ) + compare.assert_allclose(cov[:, 0], xpb.asarray([1.0, 1.5, 2.0])) + compare.assert_allclose(cov[:, 1], xpb.asarray([1.5, 2.0, 2.5])) compare.assert_allclose(cov[:, 2], 0) diff --git a/tests/core/test_fields.py b/tests/core/test_fields.py index e031a8ca..0a155bd7 100644 --- a/tests/core/test_fields.py +++ b/tests/core/test_fields.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib.util from typing import TYPE_CHECKING import healpy as hp @@ -16,6 +17,8 @@ from tests.fixtures.helper_classes import Compare +HAVE_JAX = importlib.util.find_spec("jax") is not None + @pytest.fixture(scope="session") def not_triangle_numbers() -> list[int]: @@ -139,17 +142,57 @@ def test_iternorm(xp: ModuleType) -> None: assert s.shape == (3,) -def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None: - # Call jax version of iternorm once jax version is written - if xp.__name__ == "jax.numpy": - pytest.skip("Arrays in cls2cov are not immutable, so do not support jax") +@pytest.mark.skipif(not HAVE_JAX, reason="test requires jax") +def test_cls2cov_jax(compare: type[Compare], jnp: ModuleType) -> None: + nl, nf, nc = 3, 3, 2 + generator = glass.cls2cov( + [ + jnp.asarray(arr) + for arr in [ + [1.0, 0.5, 0.3], + [0.8, 0.4, 0.2], + [0.7, 0.6, 0.1], + [0.9, 0.5, 0.3], + [0.6, 0.3, 0.2], + [0.8, 0.7, 0.4], + ] + ], + nl, + nf, + nc, + ) + + cov1 = jnp.asarray(next(generator), copy=False) + cov2 = jnp.asarray(next(generator), copy=False) + cov3 = next(generator) + + assert cov1.shape == (nl, nc + 1) + assert cov2.shape == (nl, nc + 1) + assert cov3.shape == (nl, nc + 1) + + assert cov1.dtype == jnp.float64 + assert cov2.dtype == jnp.float64 + assert cov3.dtype == jnp.float64 + + # cov1 has the expected value for the first iteration (different to cov1_copy) + compare.assert_allclose(cov1[:, 0], jnp.asarray([0.5, 0.25, 0.15])) + + # The copies should not be equal + with pytest.raises(AssertionError, match="Not equal to tolerance"): + compare.assert_allclose(cov1, cov2) + + with pytest.raises(AssertionError, match="Not equal to tolerance"): + compare.assert_allclose(cov2, cov3) + + +def test_cls2cov_no_jax(compare: type[Compare], xpb: ModuleType) -> None: # check output values and shape nl, nf, nc = 3, 2, 2 generator = glass.cls2cov( - [xp.asarray([1.0, 0.5, 0.3]), None, xp.asarray([0.7, 0.6, 0.1])], + [xpb.asarray([1.0, 0.5, 0.3]), None, xpb.asarray([0.7, 0.6, 0.1])], nl, nf, nc, @@ -157,9 +200,9 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None: cov = next(generator) assert cov.shape == (nl, nc + 1) - assert cov.dtype == xp.float64 + assert cov.dtype == xpb.float64 - compare.assert_allclose(cov[:, 0], xp.asarray([0.5, 0.25, 0.15])) + compare.assert_allclose(cov[:, 0], xpb.asarray([0.5, 0.25, 0.15])) compare.assert_allclose(cov[:, 1], 0) compare.assert_allclose(cov[:, 2], 0) @@ -167,7 +210,7 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None: generator = glass.cls2cov( [ - xp.asarray(arr) + xpb.asarray(arr) for arr in [ [-1.0, 0.5, 0.3], [0.8, 0.4, 0.2], @@ -187,7 +230,7 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None: generator = glass.cls2cov( [ - xp.asarray(arr) + xpb.asarray(arr) for arr in [ [1.0, 0.5, 0.3], [0.8, 0.4, 0.2], @@ -202,23 +245,34 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None: nc, ) - cov1 = xp.asarray(next(generator), copy=True) - cov2 = xp.asarray(next(generator), copy=True) + cov1 = xpb.asarray(next(generator), copy=False) + cov1_copy = xpb.asarray(cov1, copy=True) + cov2 = xpb.asarray(next(generator), copy=False) + cov2_copy = xpb.asarray(cov2, copy=True) cov3 = next(generator) assert cov1.shape == (nl, nc + 1) assert cov2.shape == (nl, nc + 1) assert cov3.shape == (nl, nc + 1) - assert cov1.dtype == xp.float64 - assert cov2.dtype == xp.float64 - assert cov3.dtype == xp.float64 + assert cov1.dtype == xpb.float64 + assert cov2.dtype == xpb.float64 + assert cov3.dtype == xpb.float64 + + # cov1|2|3 reuse the same data, so should all equal the third result + compare.assert_allclose(cov1[:, 0], xpb.asarray([0.45, 0.25, 0.15])) + compare.assert_allclose(cov1, cov2) + compare.assert_allclose(cov2, cov3) + + # cov1 has the expected value for the first iteration (different to cov1_copy) + compare.assert_allclose(cov1_copy[:, 0], xpb.asarray([0.5, 0.25, 0.15])) + # The copies should not be equal with pytest.raises(AssertionError, match="Not equal to tolerance"): - compare.assert_allclose(cov1, cov2) + compare.assert_allclose(cov1_copy, cov2_copy) with pytest.raises(AssertionError, match="Not equal to tolerance"): - compare.assert_allclose(cov2, cov3) + compare.assert_allclose(cov2_copy, cov3) def test_lognormal_gls() -> None: diff --git a/tests/fixtures/array_backends.py b/tests/fixtures/array_backends.py index 5dedf423..f7ff7862 100644 --- a/tests/fixtures/array_backends.py +++ b/tests/fixtures/array_backends.py @@ -125,6 +125,12 @@ def xpb(request: pytest.FixtureRequest) -> ModuleType: return request.param # type: ignore[no-any-return] +@pytest.fixture(scope="session") +def jnp() -> ModuleType: + """Fixture for the jax.numpy array backend.""" + return xp_available_backends["jax.numpy"] + + @pytest.fixture(scope="session") def uxpx(xp: ModuleType) -> _utils.XPAdditions: """