@@ -2477,6 +2477,121 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
2477
2477
return (a_h , ipiv_h )
2478
2478
2479
2479
2480
def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)

    Solve an equation system (SciPy-compatible behavior).

    This function mimics the behavior of `scipy.linalg.lu_solve` including
    support for `trans`, `overwrite_b`, `check_finite`,
    and 0-based pivot indexing.

    Parameters
    ----------
    lu : array
        LU factorization of the coefficient matrix. Inputs with more than
        two dimensions are rejected.
    piv : array
        Pivot indices, 0-based. They are shifted by one before being passed
        to the LAPACK extension.
    b : array
        Right-hand side. Its first dimension must equal ``lu.shape[0]``.
    trans : int, optional
        Type of system to solve: ``0`` (no transpose), ``1`` (transpose) or
        ``2`` (conjugate transpose).
    overwrite_b : bool, optional
        If ``True`` and `b` needs no conversion (as reported by
        ``_is_copy_required``), the solution is written into `b` itself.
    check_finite : bool, optional
        If ``True``, `lu` and `b` are checked for infs and NaNs.

    Returns
    -------
    out : array
        The array that was handed to ``_getrs`` as the right-hand side:
        either `b` itself (in-place case) or a new F-ordered array of the
        common result dtype. For an empty `b`, an uninitialized empty array
        shaped like `b` is returned without any further validation.

    Raises
    ------
    ValueError
        If the shapes of `lu` and `b` are incompatible, if `check_finite`
        is set and an input contains infs or NaNs, or if `trans` is an
        integer other than 0, 1 or 2.
    NotImplementedError
        If `lu` has more than two dimensions.
    TypeError
        If `trans` is not an integer.

    """

    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    res_type = _common_type(lu, b)

    # TODO: add broadcasting
    if lu.shape[0] != b.shape[0]:
        raise ValueError(
            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
        )

    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    if lu.ndim > 2:
        raise NotImplementedError("Batched matrices are not supported")

    if check_finite:
        if not dpnp.isfinite(lu).all():
            raise ValueError(
                "LU factorization array must not contain infs or NaNs.\n"
                "Note that when a singular matrix is given, unlike "
                "scipy.linalg.lu_factor, dpnp.linalg.lu_factor returns "
                "an array containing NaN."
            )
        if not dpnp.isfinite(b).all():
            raise ValueError(
                "Right-hand side array must not contain infs or NaNs"
            )

    # Validate `trans` before any device allocation or copy is submitted,
    # so that an invalid value does not leave useless work in the queue.
    if not isinstance(trans, int):
        raise TypeError("`trans` must be an integer")

    # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
    if trans == 0:
        trans_mkl = li.Transpose.N
    elif trans == 1:
        trans_mkl = li.Transpose.T
    elif trans == 2:
        trans_mkl = li.Transpose.C
    else:
        raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
    # convert to 1-based for oneMKL getrs
    piv_h = piv + 1

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # oneMKL LAPACK getrs overwrites `lu`.
    lu_h = dpnp.empty_like(
        lu, order="F", dtype=res_type, usm_type=res_usm_type
    )

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # SciPy-compatible behavior
    # Copy is required if:
    # - overwrite_b is False (always copy),
    # - dtype mismatch,
    # - not F-contiguous,
    # - not writeable
    if not overwrite_b or _is_copy_required(b, res_type):
        b_h = dpnp.empty_like(
            b, order="F", dtype=res_type, usm_type=res_usm_type
        )
        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=b_usm_arr,
            dst=b_h.get_array(),
            sycl_queue=b.sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, b_copy_ev)
        # getrs must wait for both staging copies
        dep_evs = [lu_copy_ev, b_copy_ev]
    else:
        # input is suitable for in-place modification
        b_h = b
        dep_evs = [lu_copy_ev]

    # Call the LAPACK extension function _getrs
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_ev = li._getrs(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans_mkl,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, getrs_ev)

    return b_h
2593
+
2594
+
2480
2595
def dpnp_matrix_power (a , n ):
2481
2596
"""
2482
2597
dpnp_matrix_power(a, n)
0 commit comments