diff --git a/glass/_array_api_utils.py b/glass/_array_api_utils.py index 52f54c32..633bc90b 100644 --- a/glass/_array_api_utils.py +++ b/glass/_array_api_utils.py @@ -36,6 +36,7 @@ ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray + IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray class CompatibleBackendNotFoundError(Exception): @@ -654,3 +655,101 @@ def apply_along_axis( msg = "the array backend in not supported" raise NotImplementedError(msg) + + def vectorize( + self, + pyfunc: Callable[..., Any], + otypes: tuple[type[float]], + ) -> Callable[..., Any]: + """ + Returns an object that acts like pyfunc, but takes arrays as input. + + Parameters + ---------- + pyfunc + Python function to vectorize. + otypes + Output types. + + Returns + ------- + Vectorized function. + + Raises + ------ + NotImplementedError + If the array backend is not supported. + + Notes + ----- + See https://github.com/glass-dev/glass/issues/671 + """ + if self.xp.__name__ == "numpy": + return self.xp.vectorize(pyfunc, otypes=otypes) # type: ignore[no-any-return] + + if self.xp.__name__ in {"array_api_strict", "jax.numpy"}: + # Import here to prevent users relying on numpy unless in this instance + np = import_numpy(self.xp.__name__) + + return np.vectorize(pyfunc, otypes=otypes) # type: ignore[no-any-return] + + msg = "the array backend in not supported" + raise NotImplementedError(msg) + + def radians(self, deg_arr: AnyArray) -> AnyArray: + """ + Convert angles from degrees to radians. + + Parameters + ---------- + deg_arr + Array of angles in degrees. + + Returns + ------- + Array of angles in radians. + + Raises + ------ + NotImplementedError + If the array backend is not supported. + """ + if self.xp.__name__ in {"numpy", "jax.numpy"}: + return self.xp.radians(deg_arr) + + if self.xp.__name__ == "array_api_strict": + np = import_numpy(self.xp.__name__) + + return self.xp.asarray(np.radians(deg_arr)) + + msg = "the array backend in not supported" + raise NotImplementedError(msg) + + def degrees(self, deg_arr: AnyArray) -> AnyArray: + """ + Convert angles from radians to degrees. + + Parameters + ---------- + deg_arr + Array of angles in radians. + + Returns + ------- + Array of angles in degrees. + + Raises + ------ + NotImplementedError + If the array backend is not supported. + """ + if self.xp.__name__ in {"numpy", "jax.numpy"}: + return self.xp.degrees(deg_arr) + + if self.xp.__name__ == "array_api_strict": + np = import_numpy(self.xp.__name__) + + return self.xp.asarray(np.degrees(deg_arr)) + + msg = "the array backend in not supported" + raise NotImplementedError(msg) diff --git a/glass/arraytools.py b/glass/arraytools.py index fdfe3410..1a945c17 100644 --- a/glass/arraytools.py +++ b/glass/arraytools.py @@ -13,9 +13,9 @@ if TYPE_CHECKING: from typing import Unpack - from numpy.typing import DTypeLike, NDArray + from numpy.typing import NDArray - from glass._array_api_utils import FloatArray + from glass._array_api_utils import AnyArray, FloatArray, IntArray def broadcast_first( @@ -183,11 +183,9 @@ def trapezoid_product( def cumulative_trapezoid( - f: NDArray[np.int_] | NDArray[np.float64], - x: NDArray[np.int_] | NDArray[np.float64], - dtype: DTypeLike | None = None, - out: NDArray[np.float64] | None = None, -) -> NDArray[np.float64]: + f: IntArray | FloatArray, + x: IntArray | FloatArray, +) -> AnyArray: """ Cumulative trapezoidal rule along last axis. @@ -197,19 +195,18 @@ def cumulative_trapezoid( The function values. x The x-coordinates. - dtype - The output data type. - out - The output array. Returns ------- The cumulative integral of the function. """ - if out is None: - out = np.empty_like(f, dtype=dtype) + xp = _utils.get_namespace(f, x) - np.cumsum((f[..., 1:] + f[..., :-1]) / 2 * np.diff(x), axis=-1, out=out[..., 1:]) - out[..., 0] = 0 - return out + f = xp.asarray(f, dtype=xp.float64) + x = xp.asarray(x, dtype=xp.float64) + + # Compute the cumulative trapezoid without mutating any arrays + return xp.cumulative_sum( + (f[..., 1:] + f[..., :-1]) * 0.5 * xp.diff(x), axis=-1, include_initial=True + ) diff --git a/glass/galaxies.py b/glass/galaxies.py index 87cc5be8..357b53c3 100644 --- a/glass/galaxies.py +++ b/glass/galaxies.py @@ -129,7 +129,7 @@ def redshifts_from_nz( # go through extra dimensions; also works if dims is empty for k in np.ndindex(dims): # compute the CDF of each galaxy population - cdf = glass.arraytools.cumulative_trapezoid(nz_out[k], z_out[k], dtype=float) + cdf = glass.arraytools.cumulative_trapezoid(nz_out[k], z_out[k]) cdf /= cdf[-1] # sample redshifts and store result diff --git a/glass/lensing.py b/glass/lensing.py index 89e37a1e..72220672 100644 --- a/glass/lensing.py +++ b/glass/lensing.py @@ -31,17 +31,22 @@ from __future__ import annotations +from numbers import Number from typing import TYPE_CHECKING, Literal, overload import healpy as hp import numpy as np +import glass._array_api_utils as _utils + if TYPE_CHECKING: from collections.abc import Sequence + from types import ModuleType from numpy.typing import NDArray import glass + from glass._array_api_utils import ComplexArray, FloatArray from glass.cosmology import Cosmology @@ -601,12 +606,13 @@ def multi_plane_weights( def deflect( - lon: float | NDArray[np.float64], - lat: float | NDArray[np.float64], - alpha: complex | list[float] | NDArray[np.complex128] | NDArray[np.float64], + lon: float | FloatArray, + lat: float | FloatArray, + alpha: complex | ComplexArray | FloatArray, + xp: ModuleType | None = None, ) -> tuple[ - NDArray[np.float64], - NDArray[np.float64], + FloatArray, + FloatArray, ]: r""" Apply deflections to positions. @@ -623,11 +629,20 @@ def deflect( alpha Deflection values. Must be complex-valued or have a leading axis of size 2 for the real and imaginary component. + xp + The array library backend to use for array operations. If this is not + specified, the array library will be determined from the other parameters, + if possible. Returns ------- The longitudes and latitudes after deflection. + Raises + ------ + ValueError + If neither an array nor the array backend ``xp`` are provided. + Notes ----- Deflections on the sphere are :term:`defined ` as @@ -639,28 +654,40 @@ def deflect( exponential map. """ - alpha = np.asanyarray(alpha) - if np.iscomplexobj(alpha): - alpha1, alpha2 = alpha.real, alpha.imag + arrays_to_check = tuple( + x + for x in (lon, lat, alpha) + if not isinstance(x, Number) and not isinstance(x, list) + ) + if xp is None: + if len(arrays_to_check) == 0: + msg = "Either, one positional input must be an array or xp must be provided" + raise ValueError(msg) + xp = _utils.get_namespace(*arrays_to_check) + uxpx = _utils.XPAdditions(xp) + + alpha = xp.asarray(alpha) + if xp.isdtype(alpha.dtype, "complex floating"): # type: ignore[union-attr] + alpha1, alpha2 = xp.real(alpha), xp.imag(alpha) else: - alpha1, alpha2 = alpha + alpha1, alpha2 = alpha # type: ignore[misc] # we know great-circle navigation: # θ' = arctan2(√[(cosθ sin|α| - sinθ cos|α| cosγ)² + (sinθ sinγ)²], # cosθ cos|α| + sinθ sin|α| cosγ) # δ = arctan2(sin|α| sinγ, sinθ cos|α| - cosθ sin|α| cosγ) - t = np.radians(lat) - ct, st = np.sin(t), np.cos(t) # sin and cos flipped: lat not co-lat + t = uxpx.radians(xp.asarray(lat)) + ct, st = xp.sin(t), xp.cos(t) # sin and cos flipped: lat not co-lat - a = np.hypot(alpha1, alpha2) # abs(alpha) - g = np.arctan2(alpha2, alpha1) # arg(alpha) - ca, sa = np.cos(a), np.sin(a) - cg, sg = np.cos(g), np.sin(g) + a = xp.hypot(alpha1, alpha2) # abs(alpha) + g = xp.atan2(alpha2, alpha1) # arg(alpha) + ca, sa = xp.cos(a), xp.sin(a) + cg, sg = xp.cos(g), xp.sin(g) # flipped atan2 arguments for lat instead of co-lat - tp = np.arctan2(ct * ca + st * sa * cg, np.hypot(ct * sa - st * ca * cg, st * sg)) + tp = xp.atan2(ct * ca + st * sa * cg, xp.hypot(ct * sa - st * ca * cg, st * sg)) - d = np.arctan2(sa * sg, st * ca - ct * sa * cg) + d = xp.atan2(sa * sg, st * ca - ct * sa * cg) - return lon - np.degrees(d), np.degrees(tp) + return lon - uxpx.degrees(d), uxpx.degrees(tp) diff --git a/glass/observations.py b/glass/observations.py index 63c5ba45..19540271 100644 --- a/glass/observations.py +++ b/glass/observations.py @@ -29,16 +29,22 @@ import itertools import math +from numbers import Number from typing import TYPE_CHECKING import healpy as hp import numpy as np +import glass._array_api_utils as _utils import glass.arraytools if TYPE_CHECKING: + from types import ModuleType + from numpy.typing import NDArray + from glass._array_api_utils import FloatArray + def vmap_galactic_ecliptic( nside: int, @@ -88,12 +94,12 @@ def vmap_galactic_ecliptic( def gaussian_nz( - z: NDArray[np.float64], - mean: float | NDArray[np.float64], - sigma: float | NDArray[np.float64], + z: FloatArray, + mean: float | FloatArray, + sigma: float | FloatArray, *, - norm: float | NDArray[np.float64] | None = None, -) -> NDArray[np.float64]: + norm: float | FloatArray | None = None, +) -> FloatArray: """ Gaussian redshift distribution. @@ -119,11 +125,20 @@ def gaussian_nz( The redshift distribution at the given ``z`` values. """ - mean = np.reshape(mean, np.shape(mean) + (1,) * np.ndim(z)) - sigma = np.reshape(sigma, np.shape(sigma) + (1,) * np.ndim(z)) + arrays_to_check = tuple( + x for x in (z, mean, sigma, norm) if not (isinstance(x, Number) or x is None) + ) + xp = _utils.get_namespace(*arrays_to_check) + uxpx = _utils.XPAdditions(xp) - nz = np.exp(-(((z - mean) / sigma) ** 2) / 2) - nz /= np.trapezoid(nz, z, axis=-1)[..., np.newaxis] + mean = xp.asarray(mean, dtype=xp.float64) + sigma = xp.asarray(sigma, dtype=xp.float64) + + mean = xp.reshape(mean, mean.shape + (1,) * z.ndim) # type: ignore[union-attr] + sigma = xp.reshape(sigma, sigma.shape + (1,) * z.ndim) # type: ignore[union-attr] + + nz = xp.exp(-(((z - mean) / sigma) ** 2) / 2) + nz /= uxpx.trapezoid(nz, z, axis=-1)[..., xp.newaxis] if norm is not None: nz *= norm @@ -132,13 +147,13 @@ def gaussian_nz( def smail_nz( - z: NDArray[np.float64], - z_mode: float | NDArray[np.float64], - alpha: float | NDArray[np.float64], - beta: float | NDArray[np.float64], + z: FloatArray, + z_mode: float | FloatArray, + alpha: float | FloatArray, + beta: float | FloatArray, *, - norm: float | NDArray[np.float64] | None = None, -) -> NDArray[np.float64]: + norm: float | FloatArray | None = None, +) -> FloatArray: r""" Redshift distribution following Smail et al. (1994). @@ -174,17 +189,25 @@ def smail_nz( where :math:`z_0` is matched to the given mode of the distribution. """ - z_mode = np.asanyarray(z_mode)[..., np.newaxis] - alpha = np.asanyarray(alpha)[..., np.newaxis] - beta = np.asanyarray(beta)[..., np.newaxis] + arrays_to_check = tuple( + x + for x in (z, z_mode, alpha, beta, norm) + if not (isinstance(x, Number) or x is None) + ) + xp = _utils.get_namespace(*arrays_to_check) + uxpx = _utils.XPAdditions(xp) - pz = z**alpha * np.exp(-alpha / beta * (z / z_mode) ** beta) - pz /= np.trapezoid(pz, z, axis=-1)[..., np.newaxis] + z_mode = xp.asarray(z_mode, dtype=xp.float64)[..., xp.newaxis] + alpha = xp.asarray(alpha, dtype=xp.float64)[..., xp.newaxis] + beta = xp.asarray(beta, dtype=xp.float64)[..., xp.newaxis] + + pz = z**alpha * xp.exp(-alpha / beta * (z / z_mode) ** beta) + pz /= uxpx.trapezoid(pz, z, axis=-1)[..., xp.newaxis] if norm is not None: pz *= norm - return pz # type: ignore[no-any-return] + return pz def fixed_zbins( @@ -193,6 +216,7 @@ def fixed_zbins( *, nbins: int | None = None, dz: float | None = None, + xp: ModuleType | None = None, ) -> list[tuple[float, float]]: """ Tomographic redshift bins of fixed size. @@ -210,6 +234,9 @@ def fixed_zbins( Number of redshift bins. Only one of ``nbins`` and ``dz`` can be given. dz Size of redshift bin. Only one of ``nbins`` and ``dz`` can be given. + xp + The array library backend to use for array operations. If this is not + specified, numpy with be used. Returns ------- @@ -221,10 +248,16 @@ def fixed_zbins( If both ``nbins`` and ``dz`` are given. """ + xp = np if xp is None else xp + if nbins is not None and dz is None: - zbinedges = np.linspace(zmin, zmax, nbins + 1) + zbinedges = xp.linspace(zmin, zmax, nbins + 1) elif nbins is None and dz is not None: - zbinedges = np.arange(zmin, np.nextafter(zmax + dz, zmax), dz) + zbinedges = xp.arange( + zmin, + xp.nextafter(xp.asarray(zmax + dz), xp.asarray(zmax)), + dz, + ) else: msg = "exactly one of nbins and dz must be given" raise ValueError(msg) @@ -233,8 +266,8 @@ def fixed_zbins( def equal_dens_zbins( - z: NDArray[np.float64], - nz: NDArray[np.float64], + z: FloatArray, + nz: FloatArray, nbins: int, ) -> list[tuple[float, float]]: """ @@ -257,23 +290,26 @@ def equal_dens_zbins( A list of redshift bin edges. """ + xp = _utils.get_namespace(z, nz) + uxpx = _utils.XPAdditions(xp) + # compute the normalised cumulative distribution function # first compute the cumulative integral (by trapezoidal rule) # then normalise: the first z is at CDF = 0, the last z at CDF = 1 # interpolate to find the z values at CDF = i/nbins for i = 0, ..., nbins cuml_nz = glass.arraytools.cumulative_trapezoid(nz, z) - cuml_nz /= cuml_nz[[-1]] - zbinedges = np.interp(np.linspace(0, 1, nbins + 1), cuml_nz, z) + cuml_nz /= cuml_nz[-1] + zbinedges = uxpx.interp(xp.linspace(0, 1, nbins + 1), cuml_nz, z) return list(itertools.pairwise(zbinedges)) def tomo_nz_gausserr( - z: NDArray[np.float64], - nz: NDArray[np.float64], + z: FloatArray, + nz: FloatArray, sigma_0: float, zbins: list[tuple[float, float]], -) -> NDArray[np.float64]: +) -> FloatArray: """ Tomographic redshift bins with a Gaussian redshift error. @@ -308,23 +344,28 @@ def tomo_nz_gausserr( produce redshift bins of fixed size """ + xp = _utils.get_namespace(z, nz) + uxpx = _utils.XPAdditions(xp) + # converting zbins into an array: - zbins_arr = np.asanyarray(zbins) + zbins_arr = xp.asarray(zbins) # bin edges and adds a new axis - z_lower = zbins_arr[:, 0, np.newaxis] - z_upper = zbins_arr[:, 1, np.newaxis] + z_lower = zbins_arr[:, 0, xp.newaxis] + z_upper = zbins_arr[:, 1, xp.newaxis] # we need a vectorised version of the error function: - erf = np.vectorize(math.erf, otypes=(float,)) + erf = uxpx.vectorize(math.erf, otypes=(float,)) # compute the probabilities that redshifts z end up in each bin # then apply probability as weights to given nz # leading axis corresponds to the different bins sz = 2**0.5 * sigma_0 * (1 + z) - binned_nz = erf((z - z_lower) / sz) - binned_nz -= erf((z - z_upper) / sz) - binned_nz /= 1 + erf(z / sz) + # we need to call xp.asarray here because erf will return a numpy + # array for array libs which do not implement vectorize. + binned_nz = xp.asarray(erf((z - z_lower) / sz)) + binned_nz -= xp.asarray(erf((z - z_upper) / sz)) + binned_nz /= 1 + xp.asarray(erf(z / sz)) binned_nz *= nz - return binned_nz # type: ignore[no-any-return] + return binned_nz diff --git a/tests/test_arraytools.py b/tests/test_arraytools.py index 5760259d..35d2f47b 100644 --- a/tests/test_arraytools.py +++ b/tests/test_arraytools.py @@ -1,3 +1,4 @@ +from types import ModuleType from typing import TYPE_CHECKING import numpy as np @@ -151,48 +152,21 @@ def test_trapezoid_product() -> None: np.testing.assert_allclose(s, 1.0) -def test_cumulative_trapezoid() -> None: +def test_cumulative_trapezoid(xp: ModuleType) -> None: # 1D f and x - f = np.array([1, 2, 3, 4]) - x = np.array([0, 1, 2, 3]) - - # default dtype (int) + f = xp.asarray([1, 2, 3, 4]) + x = xp.asarray([0, 1, 2, 3]) ct = glass.arraytools.cumulative_trapezoid(f, x) - np.testing.assert_allclose(ct, np.array([0, 1, 4, 7])) - - # explicit dtype (float) - - ct = glass.arraytools.cumulative_trapezoid(f, x, dtype=float) - np.testing.assert_allclose(ct, np.array([0.0, 1.5, 4.0, 7.5])) - - # explicit return array - - out = np.zeros((4,)) - ct = glass.arraytools.cumulative_trapezoid(f, x, dtype=float, out=out) - np.testing.assert_equal(ct, out) + np.testing.assert_allclose(ct, xp.asarray([0.0, 1.5, 4.0, 7.5])) # 2D f and 1D x - - f = np.array([[1, 4, 9, 16], [2, 3, 5, 7]]) - x = np.array([0, 1, 2.5, 4]) - - # default dtype (int) + f = xp.asarray([[1, 4, 9, 16], [2, 3, 5, 7]]) + x = xp.asarray([0, 1, 2.5, 4]) ct = glass.arraytools.cumulative_trapezoid(f, x) - np.testing.assert_allclose(ct, np.array([[0, 2, 12, 31], [0, 2, 8, 17]])) - - # explicit dtype (float) - - ct = glass.arraytools.cumulative_trapezoid(f, x, dtype=float) np.testing.assert_allclose( ct, np.array([[0.0, 2.5, 12.25, 31.0], [0.0, 2.5, 8.5, 17.5]]), ) - - # explicit return array - - out = np.zeros((2, 4)) - ct = glass.arraytools.cumulative_trapezoid(f, x, dtype=float, out=out) - np.testing.assert_equal(ct, out) diff --git a/tests/test_lensing.py b/tests/test_lensing.py index 670917e2..46999e29 100644 --- a/tests/test_lensing.py +++ b/tests/test_lensing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING import healpix @@ -9,6 +10,8 @@ import glass if TYPE_CHECKING: + from types import ModuleType + from glass.cosmology import Cosmology @@ -102,29 +105,50 @@ def test_multi_plane_weights( @pytest.mark.parametrize("usecomplex", [True, False]) -def test_deflect_nsew(usecomplex: bool) -> None: # noqa: FBT001 +def test_deflect_nsew(xp: ModuleType, usecomplex: bool) -> None: # noqa: FBT001 d = 5.0 - r = np.radians(d) + r = math.radians(d) def alpha(re: float, im: float, *, usecomplex: bool) -> complex | list[float]: return re + 1j * im if usecomplex else [re, im] # north - lon, lat = glass.deflect(0.0, 0.0, alpha(r, 0, usecomplex=usecomplex)) + lon, lat = glass.deflect(0.0, 0.0, alpha(r, 0, usecomplex=usecomplex), xp=xp) np.testing.assert_allclose([lon, lat], [0.0, d], atol=1e-15) # south - lon, lat = glass.deflect(0.0, 0.0, alpha(-r, 0, usecomplex=usecomplex)) + lon, lat = glass.deflect(0.0, 0.0, alpha(-r, 0, usecomplex=usecomplex), xp=xp) np.testing.assert_allclose([lon, lat], [0.0, -d], atol=1e-15) # east - lon, lat = glass.deflect(0.0, 0.0, alpha(0, r, usecomplex=usecomplex)) + lon, lat = glass.deflect(0.0, 0.0, alpha(0, r, usecomplex=usecomplex), xp=xp) np.testing.assert_allclose([lon, lat], [-d, 0.0], atol=1e-15) # west - lon, lat = glass.deflect(0.0, 0.0, alpha(0, -r, usecomplex=usecomplex)) + lon, lat = glass.deflect(0.0, 0.0, alpha(0, -r, usecomplex=usecomplex), xp=xp) + np.testing.assert_allclose([lon, lat], [d, 0.0], atol=1e-15) + + # At least one input is an array + lon, lat = glass.deflect( + xp.asarray(0.0), xp.asarray(0.0), alpha(0, -r, usecomplex=usecomplex) + ) np.testing.assert_allclose([lon, lat], [d, 0.0], atol=1e-15) + lon, lat = glass.deflect( + xp.asarray([0.0, 0.0]), + xp.asarray([0.0, 0.0]), + alpha(0, -r, usecomplex=usecomplex), + ) + np.testing.assert_allclose(lon, xp.asarray([d, d]), atol=1e-15) + np.testing.assert_allclose(lat, 0.0, atol=1e-15) + + # No inputs are arrays and xp not provided + with pytest.raises( + ValueError, + match="Either, one positional input must be an array or xp must be provided", + ): + glass.deflect(0.0, 0.0, alpha(0, -r, usecomplex=usecomplex)) + def test_deflect_many(rng: np.random.Generator) -> None: n = 1000 diff --git a/tests/test_observations.py b/tests/test_observations.py index a1066413..ae940b64 100644 --- a/tests/test_observations.py +++ b/tests/test_observations.py @@ -1,9 +1,19 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + import healpix import numpy as np import pytest import glass +if TYPE_CHECKING: + from types import ModuleType + + from conftest import UnifiedGenerator + def test_vmap_galactic_ecliptic() -> None: """Add unit tests for :func:`glass.vmap_galactic_ecliptic`.""" @@ -34,11 +44,11 @@ def test_vmap_galactic_ecliptic() -> None: glass.vmap_galactic_ecliptic(n_side, ecliptic=(1, 2, 3)) # type: ignore[arg-type] -def test_gaussian_nz(rng: np.random.Generator) -> None: +def test_gaussian_nz(xp: ModuleType, urng: UnifiedGenerator) -> None: """Add unit tests for :func:`glass.gaussian_nz`.""" mean = 0 sigma = 1 - z = np.linspace(0, 1, 11) + z = xp.linspace(0, 1, 11) # check passing in the norm @@ -49,25 +59,25 @@ def test_gaussian_nz(rng: np.random.Generator) -> None: norm = 1 nz = glass.gaussian_nz(z, mean, sigma, norm=norm) - np.testing.assert_allclose(nz.sum() / nz.shape, norm, rtol=1e-2) + np.testing.assert_allclose(xp.sum(nz) / nz.shape[0], norm, rtol=1e-2) # check multidimensionality size nz = glass.gaussian_nz( z, - np.tile(mean, z.shape), - np.tile(sigma, z.shape), - norm=rng.normal(size=z.shape), + xp.tile(xp.asarray(mean), z.shape), + xp.tile(xp.asarray(sigma), z.shape), + norm=urng.normal(size=z.shape), ) - np.testing.assert_array_equal(nz.shape, (len(z), len(z))) + assert nz.shape == (z.size, z.size) -def test_smail_nz() -> None: +def test_smail_nz(xp: ModuleType) -> None: """Add unit tests for :func:`glass.smail_nz`.""" alpha = 1 beta = 1 mode = 1 - z = np.linspace(0, 1, 11) + z = xp.linspace(0, 1, 11) # check passing in the norm @@ -75,64 +85,74 @@ def test_smail_nz() -> None: np.testing.assert_array_equal(pz, np.zeros_like(pz)) -def test_fixed_zbins() -> None: +def test_fixed_zbins(xp: ModuleType) -> None: """Add unit tests for :func:`glass.fixed_zbins`.""" - zmin = 0 - zmax = 1 + zmin = 0.0 + zmax = 1.0 # check nbins input nbins = 5 - expected_zbins = [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)] - zbins = glass.fixed_zbins(zmin, zmax, nbins=nbins) - np.testing.assert_array_equal(len(zbins), nbins) + expected_zbins = xp.asarray( + [ + tuple(xp.asarray(i) for i in pair) + for pair in [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)] + ] + ) + zbins = glass.fixed_zbins(zmin, zmax, nbins=nbins, xp=xp) + assert len(zbins) == nbins np.testing.assert_allclose(zbins, expected_zbins, rtol=1e-15) # check dz input dz = 0.2 - zbins = glass.fixed_zbins(zmin, zmax, dz=dz) - np.testing.assert_array_equal(len(zbins), np.ceil((zmax - zmin) / dz)) + zbins = glass.fixed_zbins(zmin, zmax, dz=dz, xp=xp) + assert len(zbins) == math.ceil((zmax - zmin) / dz) np.testing.assert_allclose(zbins, expected_zbins, rtol=1e-15) # check dz for spacing which results in a max value above zmax - zbins = glass.fixed_zbins(zmin, zmax, dz=0.3) - np.testing.assert_array_less(zmax, zbins[-1][1]) + zbins = glass.fixed_zbins(zmin, zmax, dz=0.3, xp=xp) + assert zmax < zbins[-1][1] # check error raised with pytest.raises(ValueError, match="exactly one of nbins and dz must be given"): - glass.fixed_zbins(zmin, zmax, nbins=nbins, dz=dz) + glass.fixed_zbins(zmin, zmax, nbins=nbins, dz=dz, xp=xp) -def test_equal_dens_zbins() -> None: +def test_equal_dens_zbins(xp: ModuleType) -> None: """Add unit tests for :func:`glass.equal_dens_zbins`.""" - z = np.linspace(0, 1, 11) + z = xp.linspace(0, 1, 11) nbins = 5 # check expected zbins returned - expected_zbins = [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)] - zbins = glass.equal_dens_zbins(z, np.ones_like(z), nbins) + expected_zbins = xp.asarray( + [ + tuple(xp.asarray(i) for i in pair) + for pair in [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)] + ] + ) + zbins = glass.equal_dens_zbins(z, xp.ones_like(z), nbins) np.testing.assert_allclose(zbins, expected_zbins, rtol=1e-15) # check output shape - np.testing.assert_array_equal(len(zbins), nbins) + assert len(zbins) == nbins -def test_tomo_nz_gausserr() -> None: +def test_tomo_nz_gausserr(xp: ModuleType) -> None: """Add unit tests for :func:`glass.tomo_nz_gausserr`.""" sigma_0 = 0.1 - z = np.linspace(0, 1, 11) + z = xp.linspace(0, 1, 11) zbins = [(0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)] # check zeros returned - binned_nz = glass.tomo_nz_gausserr(z, np.zeros_like(z), sigma_0, zbins) - np.testing.assert_array_equal(binned_nz, np.zeros_like(binned_nz)) + binned_nz = glass.tomo_nz_gausserr(z, xp.zeros_like(z), sigma_0, zbins) + np.testing.assert_array_equal(binned_nz, xp.zeros_like(binned_nz)) # check the shape of the output - np.testing.assert_array_equal(binned_nz.shape, (len(zbins), len(z))) + assert binned_nz.shape == (len(zbins), z.size)