Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions glass/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ def cls2cov(
for j in range(nf):
begin, end = end, end + j + 1
for i, cl in enumerate(cls[begin:end][: nc + 1]):
if i == 0 and np.any(xp.less(cl, 0)):
if i == 0 and xp.any(xp.less(cl, 0)):
msg = "negative values in cl"
raise ValueError(msg)
n = cl.size
cov[:n, i] = cl
cov[n:, i] = 0
n = cl.shape[0]
cov = xpx.at(cov)[:n, i].set(cl)
cov = xpx.at(cov)[n:, i].set(0.0)
cov /= 2
yield cov

Expand Down
17 changes: 3 additions & 14 deletions tests/benchmarks/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,11 @@ def test_cls2cov(
benchmark: BenchmarkFixture,
compare: Compare,
generator_consumer: GeneratorConsumer,
urngb: UnifiedGenerator,
xpb: ModuleType,
) -> None:
"""Benchmarks for glass.cls2cov."""
# check output values and shape

nl, nf, nc = 3, 2, 2
array_in = [urngb.random(3) for _ in range(1_000)]
array_in = [xpb.arange(i + 1.0, i + 4.0) for i in range(1_000)]

def function_to_benchmark() -> list[Any]:
generator = glass.cls2cov(
Expand All @@ -151,16 +148,8 @@ def function_to_benchmark() -> list[Any]:
assert cov.shape == (nl, nc + 1)
assert cov.dtype == xpb.float64

compare.assert_allclose(
cov[:, 0],
xpb.asarray([0.348684, 0.047089, 0.487811]),
atol=1e-6,
)
compare.assert_allclose(
cov[:, 1],
[0.38057, 0.393032, 0.064057],
atol=1e-6,
)
compare.assert_allclose(cov[:, 0], xpb.asarray([1.0, 1.5, 2.0]))
compare.assert_allclose(cov[:, 1], xpb.asarray([1.5, 2.0, 2.5]))
compare.assert_allclose(cov[:, 2], 0)


Expand Down
4 changes: 0 additions & 4 deletions tests/core/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ def test_iternorm(xp: ModuleType) -> None:


def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
# Call jax version of iternorm once jax version is written
if xp.__name__ == "jax.numpy":
pytest.skip("Arrays in cls2cov are not immutable, so do not support jax")

# check output values and shape

nl, nf, nc = 3, 2, 2
Expand Down
Loading