Skip to content

Commit a52607d

Browse files
committed
Remove np and array_api_strict as required imports in _array_api_utils
1 parent e8c8ee9 commit a52607d

File tree

1 file changed

+50
-30
lines changed

1 file changed

+50
-30
lines changed

glass/_array_api_utils.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818

1919
from typing import TYPE_CHECKING, Any, TypeAlias
2020

21-
import array_api_strict
22-
import numpy as np
23-
import numpy.random
24-
2521
if TYPE_CHECKING:
2622
from collections.abc import Callable
2723
from types import ModuleType
2824

25+
import numpy as np
2926
from array_api_strict._array_object import Array as AArray
3027
from jaxtyping import Array as JAXArray
3128
from numpy.typing import DTypeLike, NDArray
@@ -34,11 +31,11 @@
3431

3532
Size: TypeAlias = int | tuple[int, ...] | None
3633

37-
AnyArray: TypeAlias = NDArray[Any] | JAXArray
38-
ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray
39-
DoubleArray: TypeAlias = NDArray[np.double] | JAXArray
40-
FloatArray: TypeAlias = NDArray[np.float64] | JAXArray
41-
IntArray: TypeAlias = NDArray[np.int_] | JAXArray
34+
AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
35+
ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
36+
DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
37+
FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
38+
IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray
4239

4340

4441
def import_numpy(backend: str, function_name: str) -> ModuleType:
@@ -68,7 +65,7 @@ def import_numpy(backend: str, function_name: str) -> ModuleType:
6865
backend does not implement a needed function.
6966
"""
7067
try:
71-
import numpy as np # noqa: PLC0415
68+
import numpy # noqa: ICN001, PLC0415
7269

7370
except ModuleNotFoundError as err:
7471
msg = (
@@ -79,7 +76,7 @@ def import_numpy(backend: str, function_name: str) -> ModuleType:
7976
)
8077
raise ModuleNotFoundError(msg) from err
8178
else:
82-
return np
79+
return numpy
8380

8481

8582
def get_namespace(*arrays: AnyArray) -> ModuleType:
@@ -135,15 +132,20 @@ def rng_dispatcher(
135132
NotImplementedError
136133
If the array backend is not supported.
137134
"""
138-
backend = array.__array_namespace__().__name__
135+
xp = get_namespace(array)
136+
backend = xp.__name__
137+
139138
if backend == "jax.numpy":
140139
import glass.jax # noqa: PLC0415
141140

142141
return glass.jax.Generator(seed=42)
142+
143143
if backend == "numpy":
144-
return np.random.default_rng()
144+
return xp.random.default_rng() # type: ignore[no-any-return]
145+
145146
if backend == "array_api_strict":
146147
return Generator(seed=42)
148+
147149
msg = "the array backend in not supported"
148150
raise NotImplementedError(msg)
149151

@@ -156,11 +158,11 @@ class Generator:
156158
with array_api_strict.
157159
"""
158160

159-
__slots__ = ("rng",)
161+
__slots__ = ("axp", "nxp", "rng")
160162

161163
def __init__(
162164
self,
163-
seed: int | bool | NDArray[np.int_ | np.bool] | None = None, # noqa: FBT001
165+
seed: int | bool | AArray | None = None, # noqa: FBT001
164166
) -> None:
165167
"""
166168
Initialize the Generator.
@@ -170,13 +172,18 @@ def __init__(
170172
seed : int | bool | NDArray[np.int_ | np.bool] | None, optional
171173
Seed for the random number generator.
172174
"""
173-
self.rng = numpy.random.default_rng(seed=seed) # type: ignore[arg-type]
175+
import array_api_strict # noqa: PLC0415
176+
import numpy as np # noqa: PLC0415
177+
178+
self.axp = array_api_strict
179+
self.nxp = np
180+
self.rng = self.nxp.random.default_rng(seed=seed)
174181

175182
def random(
176183
self,
177184
size: Size = None,
178-
dtype: DTypeLike | None = np.float64,
179-
out: NDArray[Any] | None = None,
185+
dtype: DTypeLike | None = None,
186+
out: AArray | None = None,
180187
) -> AArray:
181188
"""
182189
Return random floats in the half-open interval [0.0, 1.0).
@@ -195,12 +202,13 @@ def random(
195202
AArray
196203
Array of random floats.
197204
"""
198-
return array_api_strict.asarray(self.rng.random(size, dtype, out)) # type: ignore[arg-type]
205+
dtype = dtype if dtype is not None else self.nxp.float64
206+
return self.axp.asarray(self.rng.random(size, dtype, out)) # type: ignore[arg-type]
199207

200208
def normal(
201209
self,
202-
loc: float | NDArray[np.floating] = 0.0,
203-
scale: float | NDArray[np.floating] = 1.0,
210+
loc: float | AArray = 0.0,
211+
scale: float | AArray = 1.0,
204212
size: Size = None,
205213
) -> AArray:
206214
"""
@@ -220,9 +228,9 @@ def normal(
220228
AArray
221229
Array of samples from the normal distribution.
222230
"""
223-
return array_api_strict.asarray(self.rng.normal(loc, scale, size))
231+
return self.axp.asarray(self.rng.normal(loc, scale, size))
224232

225-
def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArray:
233+
def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
226234
"""
227235
Draw samples from a Poisson distribution.
228236
@@ -238,13 +246,13 @@ def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArra
238246
AArray
239247
Array of samples from the Poisson distribution.
240248
"""
241-
return array_api_strict.asarray(self.rng.poisson(lam, size))
249+
return self.axp.asarray(self.rng.poisson(lam, size))
242250

243251
def standard_normal(
244252
self,
245253
size: Size = None,
246-
dtype: DTypeLike | None = np.float64,
247-
out: NDArray[Any] | None = None,
254+
dtype: DTypeLike | None = None,
255+
out: AArray | None = None,
248256
) -> AArray:
249257
"""
250258
Draw samples from a standard Normal distribution (mean=0, stdev=1).
@@ -263,12 +271,13 @@ def standard_normal(
263271
AArray
264272
Array of samples from the standard normal distribution.
265273
"""
266-
return array_api_strict.asarray(self.rng.standard_normal(size, dtype, out)) # type: ignore[arg-type]
274+
dtype = dtype if dtype is not None else self.nxp.float64
275+
return self.axp.asarray(self.rng.standard_normal(size, dtype, out)) # type: ignore[arg-type]
267276

268277
def uniform(
269278
self,
270-
low: float | NDArray[np.floating] = 0.0,
271-
high: float | NDArray[np.floating] = 1.0,
279+
low: float | AArray = 0.0,
280+
high: float | AArray = 1.0,
272281
size: Size = None,
273282
) -> AArray:
274283
"""
@@ -288,7 +297,7 @@ def uniform(
288297
AArray
289298
Array of samples from the uniform distribution.
290299
"""
291-
return array_api_strict.asarray(self.rng.uniform(low, high, size))
300+
return self.axp.asarray(self.rng.uniform(low, high, size))
292301

293302

294303
class XPAdditions:
@@ -354,8 +363,10 @@ def trapezoid(
354363
import glass.jax # noqa: PLC0415
355364

356365
return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
366+
357367
if self.backend == "numpy":
358368
return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)
369+
359370
if self.backend == "array_api_strict":
360371
np = import_numpy(self.backend, "trapezoid")
361372

@@ -395,6 +406,7 @@ def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
395406
"""
396407
if self.backend in {"numpy", "jax.numpy"}:
397408
return self.xp.union1d(ar1, ar2)
409+
398410
if self.backend == "array_api_strict":
399411
np = import_numpy(self.backend, "union1d")
400412

@@ -452,6 +464,7 @@ def interp( # noqa: PLR0913
452464
return self.xp.interp(
453465
x, x_points, y_points, left=left, right=right, period=period
454466
)
467+
455468
if self.backend == "array_api_strict":
456469
np = import_numpy(self.backend, "interp")
457470

@@ -492,6 +505,7 @@ def gradient(self, f: AnyArray) -> AnyArray:
492505
"""
493506
if self.backend in {"numpy", "jax.numpy"}:
494507
return self.xp.gradient(f)
508+
495509
if self.backend == "array_api_strict":
496510
np = import_numpy(self.backend, "gradient")
497511

@@ -546,6 +560,7 @@ def linalg_lstsq(
546560
"""
547561
if self.backend in {"numpy", "jax.numpy"}:
548562
return self.xp.linalg.lstsq(a, b, rcond=rcond) # type: ignore[no-any-return]
563+
549564
if self.backend == "array_api_strict":
550565
np = import_numpy(self.backend, "linalg.lstsq")
551566

@@ -585,6 +600,7 @@ def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
585600
"""
586601
if self.backend in {"numpy", "jax.numpy"}:
587602
return self.xp.einsum(subscripts, *operands)
603+
588604
if self.backend == "array_api_strict":
589605
np = import_numpy(self.backend, "einsum")
590606

@@ -637,6 +653,7 @@ def apply_along_axis(
637653
"""
638654
if self.backend in {"numpy", "jax.numpy"}:
639655
return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
656+
640657
if self.backend == "array_api_strict":
641658
# Import here to prevent users relying on numpy unless in this instance
642659
np = import_numpy(self.backend, "apply_along_axis")
@@ -679,6 +696,7 @@ def vectorize(
679696
"""
680697
if self.backend == "numpy":
681698
return self.xp.vectorize(pyfunc, otypes=otypes) # type: ignore[no-any-return]
699+
682700
if self.backend in {"array_api_strict", "jax.numpy"}:
683701
# Import here to prevent users relying on numpy unless in this instance
684702
np = import_numpy(self.backend, "vectorize")
@@ -709,6 +727,7 @@ def radians(self, deg_arr: AnyArray) -> AnyArray:
709727
"""
710728
if self.backend in {"numpy", "jax.numpy"}:
711729
return self.xp.radians(deg_arr)
730+
712731
if self.backend == "array_api_strict":
713732
np = import_numpy(self.backend, "radians")
714733

@@ -738,6 +757,7 @@ def degrees(self, deg_arr: AnyArray) -> AnyArray:
738757
"""
739758
if self.backend in {"numpy", "jax.numpy"}:
740759
return self.xp.degrees(deg_arr)
760+
741761
if self.backend == "array_api_strict":
742762
np = import_numpy(self.backend, "degrees")
743763

0 commit comments

Comments
 (0)