1818
1919from typing import TYPE_CHECKING , Any , TypeAlias
2020
21- import array_api_strict
22- import numpy as np
23- import numpy .random
24-
2521if TYPE_CHECKING :
2622 from collections .abc import Callable
2723 from types import ModuleType
2824
25+ import numpy as np
2926 from array_api_strict ._array_object import Array as AArray
3027 from jaxtyping import Array as JAXArray
3128 from numpy .typing import DTypeLike , NDArray
3431
3532 Size : TypeAlias = int | tuple [int , ...] | None
3633
37- AnyArray : TypeAlias = NDArray [Any ] | JAXArray
38- ComplexArray : TypeAlias = NDArray [np .complex128 ] | JAXArray
39- DoubleArray : TypeAlias = NDArray [np .double ] | JAXArray
40- FloatArray : TypeAlias = NDArray [np .float64 ] | JAXArray
41- IntArray : TypeAlias = NDArray [np .int_ ] | JAXArray
34+ AnyArray : TypeAlias = NDArray [Any ] | JAXArray | AArray
35+ ComplexArray : TypeAlias = NDArray [np .complex128 ] | JAXArray | AArray
36+ DoubleArray : TypeAlias = NDArray [np .double ] | JAXArray | AArray
37+ FloatArray : TypeAlias = NDArray [np .float64 ] | JAXArray | AArray
38+ IntArray : TypeAlias = NDArray [np .int_ ] | JAXArray | AArray
4239
4340
4441def import_numpy (backend : str , function_name : str ) -> ModuleType :
@@ -68,7 +65,7 @@ def import_numpy(backend: str, function_name: str) -> ModuleType:
6865 backend does not implement a needed function.
6966 """
7067 try :
71- import numpy as np # noqa: PLC0415
68+ import numpy # noqa: ICN001, PLC0415
7269
7370 except ModuleNotFoundError as err :
7471 msg = (
@@ -79,7 +76,7 @@ def import_numpy(backend: str, function_name: str) -> ModuleType:
7976 )
8077 raise ModuleNotFoundError (msg ) from err
8178 else :
82- return np
79+ return numpy
8380
8481
8582def get_namespace (* arrays : AnyArray ) -> ModuleType :
@@ -135,15 +132,20 @@ def rng_dispatcher(
135132 NotImplementedError
136133 If the array backend is not supported.
137134 """
138- backend = array .__array_namespace__ ().__name__
135+ xp = get_namespace (array )
136+ backend = xp .__name__
137+
139138 if backend == "jax.numpy" :
140139 import glass .jax # noqa: PLC0415
141140
142141 return glass .jax .Generator (seed = 42 )
142+
143143 if backend == "numpy" :
144- return np .random .default_rng ()
144+ return xp .random .default_rng () # type: ignore[no-any-return]
145+
145146 if backend == "array_api_strict" :
146147 return Generator (seed = 42 )
148+
147149 msg = "the array backend in not supported"
148150 raise NotImplementedError (msg )
149151
@@ -156,11 +158,11 @@ class Generator:
156158 with array_api_strict.
157159 """
158160
159- __slots__ = ("rng" , )
161+ __slots__ = ("axp" , "nxp" , "rng" )
160162
161163 def __init__ (
162164 self ,
163- seed : int | bool | NDArray [ np . int_ | np . bool ] | None = None , # noqa: FBT001
165+ seed : int | bool | AArray | None = None , # noqa: FBT001
164166 ) -> None :
165167 """
166168 Initialize the Generator.
@@ -170,13 +172,18 @@ def __init__(
170172 seed : int | bool | NDArray[np.int_ | np.bool] | None, optional
171173 Seed for the random number generator.
172174 """
173- self .rng = numpy .random .default_rng (seed = seed ) # type: ignore[arg-type]
175+ import array_api_strict # noqa: PLC0415
176+ import numpy as np # noqa: PLC0415
177+
178+ self .axp = array_api_strict
179+ self .nxp = np
180+ self .rng = self .nxp .random .default_rng (seed = seed )
174181
175182 def random (
176183 self ,
177184 size : Size = None ,
178- dtype : DTypeLike | None = np . float64 ,
179- out : NDArray [ Any ] | None = None ,
185+ dtype : DTypeLike | None = None ,
186+ out : AArray | None = None ,
180187 ) -> AArray :
181188 """
182189 Return random floats in the half-open interval [0.0, 1.0).
@@ -195,12 +202,13 @@ def random(
195202 AArray
196203 Array of random floats.
197204 """
198- return array_api_strict .asarray (self .rng .random (size , dtype , out )) # type: ignore[arg-type]
205+ dtype = dtype if dtype is not None else self .nxp .float64
206+ return self .axp .asarray (self .rng .random (size , dtype , out )) # type: ignore[arg-type]
199207
200208 def normal (
201209 self ,
202- loc : float | NDArray [ np . floating ] = 0.0 ,
203- scale : float | NDArray [ np . floating ] = 1.0 ,
210+ loc : float | AArray = 0.0 ,
211+ scale : float | AArray = 1.0 ,
204212 size : Size = None ,
205213 ) -> AArray :
206214 """
@@ -220,9 +228,9 @@ def normal(
220228 AArray
221229 Array of samples from the normal distribution.
222230 """
223- return array_api_strict .asarray (self .rng .normal (loc , scale , size ))
231+ return self . axp .asarray (self .rng .normal (loc , scale , size ))
224232
225- def poisson (self , lam : float | NDArray [ np . floating ] , size : Size = None ) -> AArray :
233+ def poisson (self , lam : float | AArray , size : Size = None ) -> AArray :
226234 """
227235 Draw samples from a Poisson distribution.
228236
@@ -238,13 +246,13 @@ def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArra
238246 AArray
239247 Array of samples from the Poisson distribution.
240248 """
241- return array_api_strict .asarray (self .rng .poisson (lam , size ))
249+ return self . axp .asarray (self .rng .poisson (lam , size ))
242250
243251 def standard_normal (
244252 self ,
245253 size : Size = None ,
246- dtype : DTypeLike | None = np . float64 ,
247- out : NDArray [ Any ] | None = None ,
254+ dtype : DTypeLike | None = None ,
255+ out : AArray | None = None ,
248256 ) -> AArray :
249257 """
250258 Draw samples from a standard Normal distribution (mean=0, stdev=1).
@@ -263,12 +271,13 @@ def standard_normal(
263271 AArray
264272 Array of samples from the standard normal distribution.
265273 """
266- return array_api_strict .asarray (self .rng .standard_normal (size , dtype , out )) # type: ignore[arg-type]
274+ dtype = dtype if dtype is not None else self .nxp .float64
275+ return self .axp .asarray (self .rng .standard_normal (size , dtype , out )) # type: ignore[arg-type]
267276
268277 def uniform (
269278 self ,
270- low : float | NDArray [ np . floating ] = 0.0 ,
271- high : float | NDArray [ np . floating ] = 1.0 ,
279+ low : float | AArray = 0.0 ,
280+ high : float | AArray = 1.0 ,
272281 size : Size = None ,
273282 ) -> AArray :
274283 """
@@ -288,7 +297,7 @@ def uniform(
288297 AArray
289298 Array of samples from the uniform distribution.
290299 """
291- return array_api_strict .asarray (self .rng .uniform (low , high , size ))
300+ return self . axp .asarray (self .rng .uniform (low , high , size ))
292301
293302
294303class XPAdditions :
@@ -354,8 +363,10 @@ def trapezoid(
354363 import glass .jax # noqa: PLC0415
355364
356365 return glass .jax .trapezoid (y , x = x , dx = dx , axis = axis )
366+
357367 if self .backend == "numpy" :
358368 return self .xp .trapezoid (y , x = x , dx = dx , axis = axis )
369+
359370 if self .backend == "array_api_strict" :
360371 np = import_numpy (self .backend , "trapezoid" )
361372
@@ -395,6 +406,7 @@ def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
395406 """
396407 if self .backend in {"numpy" , "jax.numpy" }:
397408 return self .xp .union1d (ar1 , ar2 )
409+
398410 if self .backend == "array_api_strict" :
399411 np = import_numpy (self .backend , "union1d" )
400412
@@ -452,6 +464,7 @@ def interp( # noqa: PLR0913
452464 return self .xp .interp (
453465 x , x_points , y_points , left = left , right = right , period = period
454466 )
467+
455468 if self .backend == "array_api_strict" :
456469 np = import_numpy (self .backend , "interp" )
457470
@@ -492,6 +505,7 @@ def gradient(self, f: AnyArray) -> AnyArray:
492505 """
493506 if self .backend in {"numpy" , "jax.numpy" }:
494507 return self .xp .gradient (f )
508+
495509 if self .backend == "array_api_strict" :
496510 np = import_numpy (self .backend , "gradient" )
497511
@@ -546,6 +560,7 @@ def linalg_lstsq(
546560 """
547561 if self .backend in {"numpy" , "jax.numpy" }:
548562 return self .xp .linalg .lstsq (a , b , rcond = rcond ) # type: ignore[no-any-return]
563+
549564 if self .backend == "array_api_strict" :
550565 np = import_numpy (self .backend , "linalg.lstsq" )
551566
@@ -585,6 +600,7 @@ def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
585600 """
586601 if self .backend in {"numpy" , "jax.numpy" }:
587602 return self .xp .einsum (subscripts , * operands )
603+
588604 if self .backend == "array_api_strict" :
589605 np = import_numpy (self .backend , "einsum" )
590606
@@ -637,6 +653,7 @@ def apply_along_axis(
637653 """
638654 if self .backend in {"numpy" , "jax.numpy" }:
639655 return self .xp .apply_along_axis (func1d , axis , arr , * args , ** kwargs )
656+
640657 if self .backend == "array_api_strict" :
641658 # Import here to prevent users relying on numpy unless in this instance
642659 np = import_numpy (self .backend , "apply_along_axis" )
@@ -679,6 +696,7 @@ def vectorize(
679696 """
680697 if self .backend == "numpy" :
681698 return self .xp .vectorize (pyfunc , otypes = otypes ) # type: ignore[no-any-return]
699+
682700 if self .backend in {"array_api_strict" , "jax.numpy" }:
683701 # Import here to prevent users relying on numpy unless in this instance
684702 np = import_numpy (self .backend , "vectorize" )
@@ -709,6 +727,7 @@ def radians(self, deg_arr: AnyArray) -> AnyArray:
709727 """
710728 if self .backend in {"numpy" , "jax.numpy" }:
711729 return self .xp .radians (deg_arr )
730+
712731 if self .backend == "array_api_strict" :
713732 np = import_numpy (self .backend , "radians" )
714733
@@ -738,6 +757,7 @@ def degrees(self, deg_arr: AnyArray) -> AnyArray:
738757 """
739758 if self .backend in {"numpy" , "jax.numpy" }:
740759 return self .xp .degrees (deg_arr )
760+
741761 if self .backend == "array_api_strict" :
742762 np = import_numpy (self .backend , "degrees" )
743763
0 commit comments