Skip to content

Commit cb067e6

Browse files
Implement dpnp.linalg.lu_solve() 2D inputs (#2575)
This PR suggests adding `dpnp.linalg.lu_solve()` for 2D arrays similar to scipy.linalg.lu_solve()
1 parent 1982dac commit cb067e6

File tree

10 files changed

+654
-1
lines changed

10 files changed

+654
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
1818
* Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
1919
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
20+
* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)
2021

2122
### Changed
2223

doc/reference/linalg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ Solving linear equations
8686
dpnp.linalg.solve
8787
dpnp.linalg.tensorsolve
8888
dpnp.linalg.lstsq
89+
dpnp.linalg.lu_solve
8990
dpnp.linalg.inv
9091
dpnp.linalg.pinv
9192
dpnp.linalg.tensorinv

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
dpnp_inv,
5858
dpnp_lstsq,
5959
dpnp_lu_factor,
60+
dpnp_lu_solve,
6061
dpnp_matrix_power,
6162
dpnp_matrix_rank,
6263
dpnp_multi_dot,
@@ -81,6 +82,7 @@
8182
"inv",
8283
"lstsq",
8384
"lu_factor",
85+
"lu_solve",
8486
"matmul",
8587
"matrix_norm",
8688
"matrix_power",
@@ -905,7 +907,7 @@ def lstsq(a, b, rcond=None):
905907

906908
def lu_factor(a, overwrite_a=False, check_finite=True):
907909
"""
908-
Compute the pivoted LU decomposition of a matrix.
910+
Compute the pivoted LU decomposition of the matrix `a`.
909911
910912
The decomposition is::
911913
@@ -947,6 +949,11 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
947949
This function synchronizes in order to validate array elements
948950
when ``check_finite=True``.
949951
952+
See Also
953+
--------
954+
:obj:`dpnp.linalg.lu_solve` : Solve an equation system using
955+
the LU factorization of the matrix `a`.
956+
950957
Examples
951958
--------
952959
>>> import dpnp as np
@@ -966,6 +973,81 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
966973
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
967974

968975

976+
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
977+
"""
978+
Solve a linear system, :math:`a x = b`, given the LU factorization of `a`.
979+
980+
For full documentation refer to :obj:`scipy.linalg.lu_solve`.
981+
982+
Parameters
983+
----------
984+
lu_and_piv : {tuple of dpnp.ndarrays or usm_ndarrays}
985+
LU factorization of matrix `a` (M, M) together with pivot indices.
986+
b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
987+
Right-hand side
988+
trans : {0, 1, 2}, optional
989+
Type of system to solve:
990+
991+
===== =================
992+
trans system
993+
===== =================
994+
0 :math:`a x = b`
995+
1 :math:`a^T x = b`
996+
2 :math:`a^H x = b`
997+
===== =================
998+
999+
Default: ``0``.
1000+
overwrite_b : {None, bool}, optional
1001+
Whether to overwrite data in `b` (may increase performance).
1002+
1003+
Default: ``False``.
1004+
check_finite : {None, bool}, optional
1005+
Whether to check that the input arrays (`lu` and `b`) contain only finite numbers.
1006+
Disabling may give a performance gain, but may result in problems
1007+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
1008+
1009+
Default: ``True``.
1010+
1011+
Returns
1012+
-------
1013+
x : {(M,), (M, K)} dpnp.ndarray
1014+
Solution to the system
1015+
1016+
Warning
1017+
-------
1018+
This function synchronizes in order to validate array elements
1019+
when ``check_finite=True``.
1020+
1021+
See Also
1022+
--------
1023+
:obj:`dpnp.linalg.lu_factor` : LU factorize a matrix.
1024+
1025+
Examples
1026+
--------
1027+
>>> import dpnp as np
1028+
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
1029+
>>> b = np.array([1, 1, 1, 1])
1030+
>>> lu, piv = np.linalg.lu_factor(A)
1031+
>>> x = np.linalg.lu_solve((lu, piv), b)
1032+
>>> np.allclose(A @ x - b, np.zeros((4,)))
1033+
array(True)
1034+
1035+
"""
1036+
1037+
(lu, piv) = lu_and_piv
1038+
dpnp.check_supported_arrays_type(lu, piv, b)
1039+
assert_stacked_2d(lu)
1040+
1041+
return dpnp_lu_solve(
1042+
lu,
1043+
piv,
1044+
b,
1045+
trans=trans,
1046+
overwrite_b=overwrite_b,
1047+
check_finite=check_finite,
1048+
)
1049+
1050+
9691051
def matmul(x1, x2, /):
9701052
"""
9711053
Computes the matrix product.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,6 +2477,121 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24772477
return (a_h, ipiv_h)
24782478

24792479

2480+
def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2481+
"""
2482+
dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)
2483+
2484+
Solve an equation system (SciPy-compatible behavior).
2485+
2486+
This function mimics the behavior of `scipy.linalg.lu_solve` including
2487+
support for `trans`, `overwrite_b`, `check_finite`,
2488+
and 0-based pivot indexing.
2489+
2490+
"""
2491+
2492+
res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
2493+
2494+
res_type = _common_type(lu, b)
2495+
2496+
# TODO: add broadcasting
2497+
if lu.shape[0] != b.shape[0]:
2498+
raise ValueError(
2499+
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
2500+
)
2501+
2502+
if b.size == 0:
2503+
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
2504+
2505+
if lu.ndim > 2:
2506+
raise NotImplementedError("Batched matrices are not supported")
2507+
2508+
if check_finite:
2509+
if not dpnp.isfinite(lu).all():
2510+
raise ValueError(
2511+
"LU factorization array must not contain infs or NaNs.\n"
2512+
"Note that when a singular matrix is given, unlike "
2513+
"dpnp.linalg.lu_factor returns an array containing NaN."
2514+
)
2515+
if not dpnp.isfinite(b).all():
2516+
raise ValueError(
2517+
"Right-hand side array must not contain infs or NaNs"
2518+
)
2519+
2520+
lu_usm_arr = dpnp.get_usm_ndarray(lu)
2521+
b_usm_arr = dpnp.get_usm_ndarray(b)
2522+
2523+
# dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
2524+
# convert to 1-based for oneMKL getrs
2525+
piv_h = piv + 1
2526+
2527+
_manager = dpu.SequentialOrderManager[exec_q]
2528+
dep_evs = _manager.submitted_events
2529+
2530+
# oneMKL LAPACK getrs overwrites `lu`.
2531+
lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)
2532+
2533+
# use DPCTL tensor function to fill the copy of the input array
2534+
# from the input array
2535+
ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2536+
src=lu_usm_arr,
2537+
dst=lu_h.get_array(),
2538+
sycl_queue=lu.sycl_queue,
2539+
depends=dep_evs,
2540+
)
2541+
_manager.add_event_pair(ht_ev, lu_copy_ev)
2542+
2543+
# SciPy-compatible behavior
2544+
# Copy is required if:
2545+
# - overwrite_b is False (always copy),
2546+
# - dtype mismatch,
2547+
# - not F-contiguous,
2548+
# - not writeable
2549+
if not overwrite_b or _is_copy_required(b, res_type):
2550+
b_h = dpnp.empty_like(
2551+
b, order="F", dtype=res_type, usm_type=res_usm_type
2552+
)
2553+
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2554+
src=b_usm_arr,
2555+
dst=b_h.get_array(),
2556+
sycl_queue=b.sycl_queue,
2557+
depends=dep_evs,
2558+
)
2559+
_manager.add_event_pair(ht_ev, b_copy_ev)
2560+
dep_evs = [lu_copy_ev, b_copy_ev]
2561+
else:
2562+
# input is suitable for in-place modification
2563+
b_h = b
2564+
dep_evs = [lu_copy_ev]
2565+
2566+
if not isinstance(trans, int):
2567+
raise TypeError("`trans` must be an integer")
2568+
2569+
# Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2570+
if trans == 0:
2571+
trans_mkl = li.Transpose.N
2572+
elif trans == 1:
2573+
trans_mkl = li.Transpose.T
2574+
elif trans == 2:
2575+
trans_mkl = li.Transpose.C
2576+
else:
2577+
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
2578+
2579+
# Call the LAPACK extension function _getrs
2580+
# to solve the system of linear equations with an LU-factored
2581+
# coefficient square matrix, with multiple right-hand sides.
2582+
ht_ev, getrs_ev = li._getrs(
2583+
exec_q,
2584+
lu_h.get_array(),
2585+
piv_h.get_array(),
2586+
b_h.get_array(),
2587+
trans_mkl,
2588+
depends=dep_evs,
2589+
)
2590+
_manager.add_event_pair(ht_ev, getrs_ev)
2591+
2592+
return b_h
2593+
2594+
24802595
def dpnp_matrix_power(a, n):
24812596
"""
24822597
dpnp_matrix_power(a, n)

dpnp/tests/helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.util
12
from sys import platform
23

34
import dpctl
@@ -488,6 +489,14 @@ def is_ptl(device=None):
488489
return _get_dev_mask(device) in (0xB000, 0xFD00)
489490

490491

492+
def is_scipy_available():
493+
"""
494+
Return True if SciPy is installed and can be found,
495+
False otherwise.
496+
"""
497+
return importlib.util.find_spec("scipy") is not None
498+
499+
491500
def is_tgllp_iris_xe(device=None):
492501
"""
493502
Return True if a test is running on Tiger Lake-LP with Iris Xe GPU device,

0 commit comments

Comments
 (0)