Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions glass/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ def cls2cov(
for j in range(nf):
begin, end = end, end + j + 1
for i, cl in enumerate(cls[begin:end][: nc + 1]):
if i == 0 and np.any(xp.less(cl, 0)):
if i == 0 and xp.any(xp.less(cl, 0)):
msg = "negative values in cl"
raise ValueError(msg)
n = cl.size
cov[:n, i] = cl
cov[n:, i] = 0
n = cl.shape[0]
cov = xpx.at(cov)[:n, i].set(cl)
cov = xpx.at(cov)[n:, i].set(0.0)
cov /= 2
yield cov

Expand Down
17 changes: 3 additions & 14 deletions tests/benchmarks/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,11 @@ def test_cls2cov(
benchmark: BenchmarkFixture,
compare: Compare,
generator_consumer: GeneratorConsumer,
urngb: UnifiedGenerator,
xpb: ModuleType,
) -> None:
"""Benchmarks for glass.cls2cov."""
# check output values and shape

nl, nf, nc = 3, 2, 2
array_in = [urngb.random(3) for _ in range(1_000)]
array_in = [xpb.arange(i + 1.0, i + 4.0) for i in range(1_000)]

def function_to_benchmark() -> list[Any]:
generator = glass.cls2cov(
Expand All @@ -151,16 +148,8 @@ def function_to_benchmark() -> list[Any]:
assert cov.shape == (nl, nc + 1)
assert cov.dtype == xpb.float64

compare.assert_allclose(
cov[:, 0],
xpb.asarray([0.348684, 0.047089, 0.487811]),
atol=1e-6,
)
compare.assert_allclose(
cov[:, 1],
[0.38057, 0.393032, 0.064057],
atol=1e-6,
)
compare.assert_allclose(cov[:, 0], xpb.asarray([1.0, 1.5, 2.0]))
compare.assert_allclose(cov[:, 1], xpb.asarray([1.5, 2.0, 2.5]))
compare.assert_allclose(cov[:, 2], 0)


Expand Down
28 changes: 20 additions & 8 deletions tests/core/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ def test_iternorm(xp: ModuleType) -> None:


def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
# Call jax version of iternorm once jax version is written
if xp.__name__ == "jax.numpy":
pytest.skip("Arrays in cls2cov are not immutable, so do not support jax")

# check output values and shape

nl, nf, nc = 3, 2, 2
Expand Down Expand Up @@ -202,8 +198,10 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
nc,
)

cov1 = xp.asarray(next(generator), copy=True)
cov2 = xp.asarray(next(generator), copy=True)
cov1 = xp.asarray(next(generator), copy=False)
cov1_copy = xp.asarray(cov1, copy=True)
cov2 = xp.asarray(next(generator), copy=False)
cov2_copy = xp.asarray(cov2, copy=True)
cov3 = next(generator)

assert cov1.shape == (nl, nc + 1)
Expand All @@ -214,11 +212,25 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
assert cov2.dtype == xp.float64
assert cov3.dtype == xp.float64

with pytest.raises(AssertionError, match="Not equal to tolerance"):
# Jax enforces the creation of a copy rather than the reuse of memory.
if xp.__name__ != "jax.numpy":
# cov1|2|3 reuse the same data, so should all equal the third result
compare.assert_allclose(cov1[:, 0], xp.asarray([0.45, 0.25, 0.15]))
compare.assert_allclose(cov1, cov2)
compare.assert_allclose(cov2, cov3)
else:
compare.assert_allclose(cov1, cov1_copy)
compare.assert_allclose(cov2, cov2_copy)

# cov1 has the expected value for the first iteration (different to cov1_copy)
compare.assert_allclose(cov1_copy[:, 0], xp.asarray([0.5, 0.25, 0.15]))

# The copies should not be equal
with pytest.raises(AssertionError, match="Not equal to tolerance"):
compare.assert_allclose(cov2, cov3)
compare.assert_allclose(cov1_copy, cov2_copy)

with pytest.raises(AssertionError, match="Not equal to tolerance"):
compare.assert_allclose(cov2_copy, cov3)


def test_lognormal_gls() -> None:
Expand Down