Skip to content

Commit 9b53570

Browse files
committed
Fix a numpy 1.26/2.0 compatibility issue
1 parent b5b6680 commit 9b53570

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

array_api_strict/linalg.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,16 @@ def slogdet(x: Array, /) -> SlogdetResult:
305305
# To workaround this, the below is the code from np.linalg.solve except
306306
# only calling solve1 in the exactly 1D case.
307307
def _solve(a, b):
308-
from numpy.linalg._linalg import (
308+
try:
309+
from numpy.linalg._linalg import (
309310
_makearray, _assert_stacked_2d, _assert_stacked_square,
310311
_commonType, isComplexType, _raise_linalgerror_singular
311-
)
312+
)
313+
except ImportError:
314+
from numpy.linalg.linalg import (
315+
_makearray, _assert_stacked_2d, _assert_stacked_square,
316+
_commonType, isComplexType, _raise_linalgerror_singular
317+
)
312318
from numpy.linalg import _umath_linalg
313319

314320
a, _ = _makearray(a)

0 commit comments

Comments
 (0)