Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
10d4fd6
gh-410: port observations.gaussian_nz
connoraird Oct 15, 2025
02b16a4
gh-410: port observations.smail_nz
connoraird Oct 15, 2025
78fa466
gh-410: port observations.fixed_zbins
connoraird Oct 15, 2025
5159d45
gh-410: port observations.equal_dens_zbins
connoraird Oct 16, 2025
1999e59
gh-410: port observations.tomo_nz_gausserr
connoraird Oct 16, 2025
05d106e
Wrap numpy import into function and remove unnecessary wrapping of ja…
connoraird Oct 17, 2025
c563ae1
gh-410: port lensing.deflect
connoraird Oct 17, 2025
67ba07c
Add detailed docstrings to _array_api_utils
connoraird Oct 17, 2025
0f33225
Remove np and array_api_strict as required imports in _array_api_utils
connoraird Oct 17, 2025
395237d
Remove types from docstrings in new functions
connoraird Oct 20, 2025
43ebb9d
Use UnifiedGenerator from conftest
connoraird Oct 20, 2025
7a674e4
Merge branch 'main' into connor/issue-410
paddyroddy Oct 20, 2025
4baad48
Add xp.nextafter back in to include max bin
connoraird Oct 20, 2025
3cc480b
Merge remote-tracking branch 'origin/main' into connor/issue-410
connoraird Oct 22, 2025
0289412
Use np.testing.assert_allclose instead of pytest.approx
connoraird Oct 22, 2025
3046206
Fix docstring by making Raises description fit on one line
connoraird Oct 22, 2025
b0d8bce
Merge branch 'main' into connor/issue-410
connoraird Oct 22, 2025
92fd697
undo leftove changes to np.testing calls
connoraird Oct 22, 2025
c84090a
Rmove comment
connoraird Oct 22, 2025
b0c900d
Undo docstring change
connoraird Oct 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions glass/_array_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray


class CompatibleBackendNotFoundError(Exception):
Expand Down Expand Up @@ -652,3 +653,101 @@ def apply_along_axis(

msg = "the array backend in not supported"
raise NotImplementedError(msg)

def vectorize(
self,
pyfunc: Callable[..., Any],
otypes: tuple[type[float]],
) -> Callable[..., Any]:
"""
Returns an object that acts like pyfunc, but takes arrays as input.

Parameters
----------
pyfunc
Python function to vectorize.
otypes
Output types.

Returns
-------
Vectorized function.

Raises
------
NotImplementedError
If the array backend is not supported.

Notes
-----
See https://github.com/glass-dev/glass/issues/671
"""
if self.xp.__name__ == "numpy":
return self.xp.vectorize(pyfunc, otypes=otypes) # type: ignore[no-any-return]

if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
# Import here to prevent users relying on numpy unless in this instance
np = import_numpy(self.xp.__name__)

return np.vectorize(pyfunc, otypes=otypes) # type: ignore[no-any-return]

msg = "the array backend in not supported"
raise NotImplementedError(msg)

def radians(self, deg_arr: AnyArray) -> AnyArray:
"""
Convert angles from degrees to radians.

Parameters
----------
deg_arr
Array of angles in degrees.

Returns
-------
Array of angles in radians.

Raises
------
NotImplementedError
If the array backend is not supported.
"""
if self.xp.__name__ in {"numpy", "jax.numpy"}:
return self.xp.radians(deg_arr)

if self.xp.__name__ == "array_api_strict":
np = import_numpy(self.xp.__name__)

return self.xp.asarray(np.radians(deg_arr))

msg = "the array backend in not supported"
raise NotImplementedError(msg)

def degrees(self, deg_arr: AnyArray) -> AnyArray:
"""
Convert angles from radians to degrees.

Parameters
----------
deg_arr
Array of angles in radians.

Returns
-------
Array of angles in degrees.

Raises
------
NotImplementedError
If the array backend is not supported.
"""
if self.xp.__name__ in {"numpy", "jax.numpy"}:
return self.xp.degrees(deg_arr)

if self.xp.__name__ == "array_api_strict":
np = import_numpy(self.xp.__name__)

return self.xp.asarray(np.degrees(deg_arr))

msg = "the array backend in not supported"
raise NotImplementedError(msg)
29 changes: 13 additions & 16 deletions glass/arraytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
if TYPE_CHECKING:
from typing import Unpack

from numpy.typing import DTypeLike, NDArray
from numpy.typing import NDArray

from glass._array_api_utils import FloatArray
from glass._array_api_utils import AnyArray, FloatArray, IntArray


def broadcast_first(
Expand Down Expand Up @@ -183,11 +183,9 @@ def trapezoid_product(


def cumulative_trapezoid(
f: NDArray[np.int_] | NDArray[np.float64],
x: NDArray[np.int_] | NDArray[np.float64],
dtype: DTypeLike | None = None,
out: NDArray[np.float64] | None = None,
) -> NDArray[np.float64]:
f: IntArray | FloatArray,
x: IntArray | FloatArray,
) -> AnyArray:
"""
Cumulative trapezoidal rule along last axis.

Expand All @@ -197,19 +195,18 @@ def cumulative_trapezoid(
The function values.
x
The x-coordinates.
dtype
The output data type.
out
The output array.

Returns
-------
The cumulative integral of the function.

"""
if out is None:
out = np.empty_like(f, dtype=dtype)
xp = _utils.get_namespace(f, x)

np.cumsum((f[..., 1:] + f[..., :-1]) / 2 * np.diff(x), axis=-1, out=out[..., 1:])
out[..., 0] = 0
return out
f = xp.asarray(f, dtype=xp.float64)
x = xp.asarray(x, dtype=xp.float64)

# Compute the cumulative trapezoid without mutating any arrays
return xp.cumulative_sum(
(f[..., 1:] + f[..., :-1]) * 0.5 * xp.diff(x), axis=-1, include_initial=True
)
2 changes: 1 addition & 1 deletion glass/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def redshifts_from_nz(
# go through extra dimensions; also works if dims is empty
for k in np.ndindex(dims):
# compute the CDF of each galaxy population
cdf = glass.arraytools.cumulative_trapezoid(nz_out[k], z_out[k], dtype=float)
cdf = glass.arraytools.cumulative_trapezoid(nz_out[k], z_out[k])
cdf /= cdf[-1]

# sample redshifts and store result
Expand Down
55 changes: 37 additions & 18 deletions glass/lensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,22 @@

from __future__ import annotations

from numbers import Number
from typing import TYPE_CHECKING, Literal, overload

import healpy as hp
import numpy as np

import glass._array_api_utils as _utils

if TYPE_CHECKING:
from collections.abc import Sequence
from types import ModuleType

from numpy.typing import NDArray

import glass
from glass._array_api_utils import ComplexArray, FloatArray
from glass.cosmology import Cosmology


Expand Down Expand Up @@ -601,12 +606,13 @@ def multi_plane_weights(


def deflect(
lon: float | NDArray[np.float64],
lat: float | NDArray[np.float64],
alpha: complex | list[float] | NDArray[np.complex128] | NDArray[np.float64],
lon: float | FloatArray,
lat: float | FloatArray,
alpha: complex | ComplexArray | FloatArray,
xp: ModuleType | None = None,
) -> tuple[
NDArray[np.float64],
NDArray[np.float64],
FloatArray,
FloatArray,
]:
r"""
Apply deflections to positions.
Expand Down Expand Up @@ -639,28 +645,41 @@ def deflect(
exponential map.

"""
alpha = np.asanyarray(alpha)
if np.iscomplexobj(alpha):
alpha1, alpha2 = alpha.real, alpha.imag
arrays_to_check = tuple(
x
for x in (lon, lat, alpha)
if not isinstance(x, Number) and not isinstance(x, list)
)
if len(arrays_to_check) == 0:
if xp is None:
msg = "Either, one positional input must be an array or xp must be provided"
raise ValueError(msg)
else:
xp = _utils.get_namespace(*arrays_to_check)
uxpx = _utils.XPAdditions(xp)

alpha = xp.asarray(alpha)
if xp.isdtype(alpha.dtype, "complex floating"): # type: ignore[union-attr]
alpha1, alpha2 = xp.real(alpha), xp.imag(alpha)
else:
alpha1, alpha2 = alpha
alpha1, alpha2 = alpha # type: ignore[misc]

# we know great-circle navigation:
# θ' = arctan2(√[(cosθ sin|α| - sinθ cos|α| cosγ)² + (sinθ sinγ)²],
# cosθ cos|α| + sinθ sin|α| cosγ)
# δ = arctan2(sin|α| sinγ, sinθ cos|α| - cosθ sin|α| cosγ)

t = np.radians(lat)
ct, st = np.sin(t), np.cos(t) # sin and cos flipped: lat not co-lat
t = uxpx.radians(xp.asarray(lat))
ct, st = xp.sin(t), xp.cos(t) # sin and cos flipped: lat not co-lat

a = np.hypot(alpha1, alpha2) # abs(alpha)
g = np.arctan2(alpha2, alpha1) # arg(alpha)
ca, sa = np.cos(a), np.sin(a)
cg, sg = np.cos(g), np.sin(g)
a = xp.hypot(alpha1, alpha2) # abs(alpha)
g = xp.atan2(alpha2, alpha1) # arg(alpha)
ca, sa = xp.cos(a), xp.sin(a)
cg, sg = xp.cos(g), xp.sin(g)

# flipped atan2 arguments for lat instead of co-lat
tp = np.arctan2(ct * ca + st * sa * cg, np.hypot(ct * sa - st * ca * cg, st * sg))
tp = xp.atan2(ct * ca + st * sa * cg, xp.hypot(ct * sa - st * ca * cg, st * sg))

d = np.arctan2(sa * sg, st * ca - ct * sa * cg)
d = xp.atan2(sa * sg, st * ca - ct * sa * cg)

return lon - np.degrees(d), np.degrees(tp)
return lon - uxpx.degrees(d), uxpx.degrees(tp)
Loading
Loading