From 13c43462f94da3f711afce32b544af5b429db500 Mon Sep 17 00:00:00 2001 From: Nader Rahhal <107228500+Nader-Rahhal@users.noreply.github.com> Date: Tue, 19 May 2026 12:05:03 -0500 Subject: [PATCH 01/14] init --- lib/cunumeric_jl_wrapper/include/ufi.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/cunumeric_jl_wrapper/include/ufi.h b/lib/cunumeric_jl_wrapper/include/ufi.h index f5558c16..d189c8fa 100644 --- a/lib/cunumeric_jl_wrapper/include/ufi.h +++ b/lib/cunumeric_jl_wrapper/include/ufi.h @@ -1,6 +1,6 @@ /* Copyright 2026 Northwestern University, * Carnegie Mellon University University - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at From a131d5ec6ecf3d356613eb24e651ee1a0b929dfa Mon Sep 17 00:00:00 2001 From: Nader Rahhal <107228500+Nader-Rahhal@users.noreply.github.com> Date: Wed, 20 May 2026 12:49:16 -0500 Subject: [PATCH 02/14] matrix solve --- lib/cunumeric_jl_wrapper/include/types.h | 3 + lib/cunumeric_jl_wrapper/src/types.cpp | 5 ++ lib/cunumeric_jl_wrapper/src/wrapper.cpp | 1 + src/ndarray/ndarray.jl | 108 +++++++++++++++++++++++ test/tests/linalg.jl | 37 ++++++++ 5 files changed, 154 insertions(+) diff --git a/lib/cunumeric_jl_wrapper/include/types.h b/lib/cunumeric_jl_wrapper/include/types.h index c75754db..88a18dc9 100644 --- a/lib/cunumeric_jl_wrapper/include/types.h +++ b/lib/cunumeric_jl_wrapper/include/types.h @@ -65,3 +65,6 @@ void wrap_unary_reds(jlcxx::Module&); // Binary op codes void wrap_binary_ops(jlcxx::Module&); + +// Linear algebra op codes +void wrap_linalg_ops(jlcxx::Module& mod); diff --git a/lib/cunumeric_jl_wrapper/src/types.cpp b/lib/cunumeric_jl_wrapper/src/types.cpp index 2e73ebd5..f181dc2e 100644 --- a/lib/cunumeric_jl_wrapper/src/types.cpp +++ b/lib/cunumeric_jl_wrapper/src/types.cpp @@ -162,3 +162,8 @@ void wrap_binary_ops(jlcxx::Module& mod) { mod.set_const("SUBTRACT", 
CuPyNumericBinaryOpCode::CUPYNUMERIC_BINOP_SUBTRACT); } + +void wrap_linalg_ops(jlcxx::Module& mod) { + mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE}); + mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE}); +} \ No newline at end of file diff --git a/lib/cunumeric_jl_wrapper/src/wrapper.cpp b/lib/cunumeric_jl_wrapper/src/wrapper.cpp index 29334333..e4b19331 100644 --- a/lib/cunumeric_jl_wrapper/src/wrapper.cpp +++ b/lib/cunumeric_jl_wrapper/src/wrapper.cpp @@ -68,6 +68,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) { wrap_unary_ops(mod); wrap_binary_ops(mod); wrap_unary_reds(mod); + wrap_linalg_ops(mod); using jlcxx::ParameterList; using jlcxx::Parametric; diff --git a/src/ndarray/ndarray.jl b/src/ndarray/ndarray.jl index d58f6572..31a51664 100644 --- a/src/ndarray/ndarray.jl +++ b/src/ndarray/ndarray.jl @@ -763,3 +763,111 @@ end function Base.isapprox(arr::NDArray{T}, arr2::NDArray{T}; atol=0, rtol=0) where {T} return compare(arr, arr2, atol, rtol) end + +function nda_to_logical_store(arr::NDArray{T,N}) where {T,N} + la_handle = cuNumeric.get_store(arr) + st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle[], size(arr))) + return Legate.LogicalStore{T,N}(st_handle, size(arr)) +end + +function choose_nd_color_shape(shape::NTuple{N,Int}) where N + color_shape = Base.ones(Int, N) + if N > 2 + color_shape[1] = Legate.num_procs() + done = false + while !done && color_shape[1] % 2 == 0 + weight_per_dim = [shape[i] / color_shape[i] for i in 1:N-2] + max_weight, idx = findmax(weight_per_dim) + if weight_per_dim[idx] > 2 * weight_per_dim[1] + color_shape[1] ÷= 2 + color_shape[idx] *= 2 + else + done = true + end + end + end + return Tuple(color_shape) +end + +function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where N + initial_color_shape = choose_nd_color_shape(full_shape) + tilesize = Tuple((full_shape[i] + initial_color_shape[i] - 1) ÷ 
initial_color_shape[i] for i in 1:N) + color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N) + return tilesize, color_shape +end + +function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} + nrhs = size(b)[end] + full_shape = size(a) + tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape) + tilesize_b = (tilesize_a[1:end-1]..., nrhs) + + store_a = nda_to_logical_store(a) + store_b = nda_to_logical_store(b) + store_x = nda_to_logical_store(x) + + tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a)) + tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b)) + tiled_x = Legate.partition_by_tiling(store_x, collect(tilesize_b)) + + rt = Legate.get_runtime() + domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape))) + lib = cuNumeric.get_lib() + task = Legate.create_manual_task(rt, lib, cuNumeric.SOLVE, domain) + + Legate.add_input(task, tiled_a) + Legate.add_input(task, tiled_b) + Legate.add_output(task, tiled_x) + + Legate.submit_manual_task(rt, task) +end + +function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} + if N < 2 + throw(ArgumentError("$(N)-dimensional array given. Array must be at least two-dimensional")) + end + if M < 1 + throw(ArgumentError("$(M)-dimensional array given. 
Array must be at least one-dimensional")) + end + if T == Float16 || S == Float16 + throw(ArgumentError("array type float16 is unsupported in linalg")) + end + if size(a)[end-1] != size(a)[end] + throw(ArgumentError("Last 2 dimensions of the array must be square")) + end + if N == 2 && size(a)[2] != size(b)[1] + if M == 1 + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + else + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + end + end + if N > 2 + if N != M + throw(ArgumentError( + "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)" + )) + end + if size(a)[end] != size(b)[end-1] + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension " * + "$(M-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * + " (size $(size(b)[end-1]) is different from $(size(a)[end]))" + )) + end + end + if prod(size(a)) == 0 || prod(size(b)) == 0 + return zeros(T, size(b)...) + end + x = zeros(T, size(b)...) 
+ solve_batched(a, b, x) + return x +end \ No newline at end of file diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index 32a18100..e2e66d03 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -102,3 +102,40 @@ end @test sort(Array(out)) == sort(ref) end + +@testset "solve diagonal" begin + n = 4 + A = cuNumeric.zeros(Float64, n, n) + b = cuNumeric.zeros(Float64, n, 1) + cuNumeric.@allowscalar for i in 1:n + A[i, i] = 4.0 + b[i, 1] = 1.0 + end + x = cuNumeric.solve(A, b) + allowscalar() do + @test cuNumeric.compare(fill(0.25, n, 1), x, atol(Float64), rtol(Float64)) + end +end + +@testset "solve identity" begin + n = 4 + A = cuNumeric.NDArray(Matrix{Float64}(I, n, n)) + b = cuNumeric.NDArray(reshape(collect(1.0:n), n, 1)) + x = cuNumeric.solve(A, b) + ref = reshape(collect(1.0:n), n, 1) + allowscalar() do + @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + end +end + +@testset "solve general" begin + A_ref = [2.0 1.0; 5.0 7.0] + b_ref = [11.0; 13.0;;] + A = cuNumeric.NDArray(A_ref) + b = cuNumeric.NDArray(b_ref) + x = cuNumeric.solve(A, b) + ref = A_ref \ b_ref + allowscalar() do + @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + end +end From b5395b321a8d31f2d7061db1f8b831d38f5fca8a Mon Sep 17 00:00:00 2001 From: Nader Rahhal <107228500+Nader-Rahhal@users.noreply.github.com> Date: Wed, 20 May 2026 12:57:24 -0500 Subject: [PATCH 03/14] reorg files --- src/ndarray/detail/ndarray.jl | 6 ++ src/ndarray/linalg.jl | 101 +++++++++++++++++++++++++++++++ src/ndarray/ndarray.jl | 108 ---------------------------------- 3 files changed, 107 insertions(+), 108 deletions(-) create mode 100644 src/ndarray/linalg.jl diff --git a/src/ndarray/detail/ndarray.jl b/src/ndarray/detail/ndarray.jl index b11c4101..282dd568 100644 --- a/src/ndarray/detail/ndarray.jl +++ b/src/ndarray/detail/ndarray.jl @@ -494,3 +494,9 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real) # successful completion return true end + 
+function nda_to_logical_store(arr::NDArray{T,N}) where {T,N} + la_handle = cuNumeric.get_store(arr) + st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle[], size(arr))) + return Legate.LogicalStore{T,N}(st_handle, size(arr)) +end diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl new file mode 100644 index 00000000..d430943c --- /dev/null +++ b/src/ndarray/linalg.jl @@ -0,0 +1,101 @@ +function choose_nd_color_shape(shape::NTuple{N,Int}) where N + color_shape = Base.ones(Int, N) + if N > 2 + color_shape[1] = Legate.num_procs() + done = false + while !done && color_shape[1] % 2 == 0 + weight_per_dim = [shape[i] / color_shape[i] for i in 1:N-2] + max_weight, idx = findmax(weight_per_dim) + if weight_per_dim[idx] > 2 * weight_per_dim[1] + color_shape[1] ÷= 2 + color_shape[idx] *= 2 + else + done = true + end + end + end + return Tuple(color_shape) +end + +function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where N + initial_color_shape = choose_nd_color_shape(full_shape) + tilesize = Tuple((full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N) + color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N) + return tilesize, color_shape +end + +function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} + nrhs = size(b)[end] + full_shape = size(a) + tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape) + tilesize_b = (tilesize_a[1:end-1]..., nrhs) + + store_a = nda_to_logical_store(a) + store_b = nda_to_logical_store(b) + store_x = nda_to_logical_store(x) + + tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a)) + tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b)) + tiled_x = Legate.partition_by_tiling(store_x, collect(tilesize_b)) + + rt = Legate.get_runtime() + domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape))) + lib = cuNumeric.get_lib() + task = Legate.create_manual_task(rt, 
lib, cuNumeric.SOLVE, domain) + + Legate.add_input(task, tiled_a) + Legate.add_input(task, tiled_b) + Legate.add_output(task, tiled_x) + + Legate.submit_manual_task(rt, task) +end + +function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} + if N < 2 + throw(ArgumentError("$(N)-dimensional array given. Array must be at least two-dimensional")) + end + if M < 1 + throw(ArgumentError("$(M)-dimensional array given. Array must be at least one-dimensional")) + end + if T == Float16 || S == Float16 + throw(ArgumentError("array type float16 is unsupported in linalg")) + end + if size(a)[end-1] != size(a)[end] + throw(ArgumentError("Last 2 dimensions of the array must be square")) + end + if N == 2 && size(a)[2] != size(b)[1] + if M == 1 + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + else + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + end + end + if N > 2 + if N != M + throw(ArgumentError( + "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)" + )) + end + if size(a)[end] != size(b)[end-1] + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension " * + "$(M-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * + " (size $(size(b)[end-1]) is different from $(size(a)[end]))" + )) + end + end + if prod(size(a)) == 0 || prod(size(b)) == 0 + return zeros(T, size(b)...) + end + x = zeros(T, size(b)...) 
+ solve_batched(a, b, x) + return x +end \ No newline at end of file diff --git a/src/ndarray/ndarray.jl b/src/ndarray/ndarray.jl index 31a51664..df20f2ce 100644 --- a/src/ndarray/ndarray.jl +++ b/src/ndarray/ndarray.jl @@ -762,112 +762,4 @@ end function Base.isapprox(arr::NDArray{T}, arr2::NDArray{T}; atol=0, rtol=0) where {T} return compare(arr, arr2, atol, rtol) -end - -function nda_to_logical_store(arr::NDArray{T,N}) where {T,N} - la_handle = cuNumeric.get_store(arr) - st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle[], size(arr))) - return Legate.LogicalStore{T,N}(st_handle, size(arr)) -end - -function choose_nd_color_shape(shape::NTuple{N,Int}) where N - color_shape = Base.ones(Int, N) - if N > 2 - color_shape[1] = Legate.num_procs() - done = false - while !done && color_shape[1] % 2 == 0 - weight_per_dim = [shape[i] / color_shape[i] for i in 1:N-2] - max_weight, idx = findmax(weight_per_dim) - if weight_per_dim[idx] > 2 * weight_per_dim[1] - color_shape[1] ÷= 2 - color_shape[idx] *= 2 - else - done = true - end - end - end - return Tuple(color_shape) -end - -function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where N - initial_color_shape = choose_nd_color_shape(full_shape) - tilesize = Tuple((full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N) - color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N) - return tilesize, color_shape -end - -function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} - nrhs = size(b)[end] - full_shape = size(a) - tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape) - tilesize_b = (tilesize_a[1:end-1]..., nrhs) - - store_a = nda_to_logical_store(a) - store_b = nda_to_logical_store(b) - store_x = nda_to_logical_store(x) - - tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a)) - tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b)) - tiled_x = 
Legate.partition_by_tiling(store_x, collect(tilesize_b)) - - rt = Legate.get_runtime() - domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape))) - lib = cuNumeric.get_lib() - task = Legate.create_manual_task(rt, lib, cuNumeric.SOLVE, domain) - - Legate.add_input(task, tiled_a) - Legate.add_input(task, tiled_b) - Legate.add_output(task, tiled_x) - - Legate.submit_manual_task(rt, task) -end - -function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} - if N < 2 - throw(ArgumentError("$(N)-dimensional array given. Array must be at least two-dimensional")) - end - if M < 1 - throw(ArgumentError("$(M)-dimensional array given. Array must be at least one-dimensional")) - end - if T == Float16 || S == Float16 - throw(ArgumentError("array type float16 is unsupported in linalg")) - end - if size(a)[end-1] != size(a)[end] - throw(ArgumentError("Last 2 dimensions of the array must be square")) - end - if N == 2 && size(a)[2] != size(b)[1] - if M == 1 - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) - else - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) - end - end - if N > 2 - if N != M - throw(ArgumentError( - "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)" - )) - end - if size(a)[end] != size(b)[end-1] - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension " * - "$(M-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * - " (size $(size(b)[end-1]) is different from $(size(a)[end]))" - )) - end - end - if prod(size(a)) == 0 || prod(size(b)) == 0 - return zeros(T, size(b)...) - end - x = zeros(T, size(b)...) 
- solve_batched(a, b, x) - return x end \ No newline at end of file From ffcc4393cb8cc7b450dada18492ff0c3a3cdfc0e Mon Sep 17 00:00:00 2001 From: Nader Rahhal <107228500+Nader-Rahhal@users.noreply.github.com> Date: Sun, 24 May 2026 16:31:09 -0500 Subject: [PATCH 04/14] dynamic dispatch --- src/ndarray/linalg.jl | 113 ++++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 44 deletions(-) diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl index d430943c..29bab673 100644 --- a/src/ndarray/linalg.jl +++ b/src/ndarray/linalg.jl @@ -50,52 +50,77 @@ function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} Legate.submit_manual_task(rt, task) end -function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} - if N < 2 - throw(ArgumentError("$(N)-dimensional array given. Array must be at least two-dimensional")) - end - if M < 1 - throw(ArgumentError("$(M)-dimensional array given. Array must be at least one-dimensional")) - end - if T == Float16 || S == Float16 - throw(ArgumentError("array type float16 is unsupported in linalg")) - end - if size(a)[end-1] != size(a)[end] +# Dimension guards +function solve(a::NDArray{T,1}, b::NDArray{S,M}) where {T,S,M} + throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional")) +end + +function solve(a::NDArray{T,0}, b::NDArray{S,M}) where {T,S,M} + throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional")) +end + +function solve(a::NDArray{T,N}, b::NDArray{S,0}) where {T,N,S} + throw(ArgumentError("0-dimensional array given. 
Array must be at least one-dimensional")) +end + +# Float16 guards +function solve(a::NDArray{Float16,N}, b::NDArray{S,M}) where {N,S,M} + throw(ArgumentError("array type float16 is unsupported in linalg")) +end + +function solve(a::NDArray{T,N}, b::NDArray{Float16,M}) where {T,N,M} + throw(ArgumentError("array type float16 is unsupported in linalg")) +end + +# 2D case: (m,m),(m)->( m) +function solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} + size(a)[end-1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) - end - if N == 2 && size(a)[2] != size(b)[1] - if M == 1 - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) - else - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) - end - end - if N > 2 - if N != M - throw(ArgumentError( - "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)" - )) - end - if size(a)[end] != size(b)[end-1] - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension " * - "$(M-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * - " (size $(size(b)[end-1]) is different from $(size(a)[end]))" - )) - end - end - if prod(size(a)) == 0 || prod(size(b)) == 0 - return zeros(T, size(b)...) - end + size(a)[2] != size(b)[1] && + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) + x = zeros(T, size(b)...) 
+ solve_batched(a, b, x) + return x +end + +# 2D case: (m,m),(m,n)->(m,n) +function solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} + size(a)[end-1] != size(a)[end] && + throw(ArgumentError("Last 2 dimensions of the array must be square")) + size(a)[2] != size(b)[1] && + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))" + )) + prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) x = zeros(T, size(b)...) solve_batched(a, b, x) return x +end + +# Batched case: (...,m,m),(...,m,n)->(...,m,n) +function solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} + size(a)[end-1] != size(a)[end] && + throw(ArgumentError("Last 2 dimensions of the array must be square")) + size(a)[end] != size(b)[end-1] && + throw(ArgumentError( + "Input operand 1 has a mismatch in its dimension " * + "$(N-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * + " (size $(size(b)[end-1]) is different from $(size(a)[end]))" + )) + prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) + x = zeros(T, size(b)...) 
+ solve_batched(a, b, x) + return x +end + +# Mismatched batch dimensions +function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} + throw(ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)")) end \ No newline at end of file From a98f2ec39276670c8859f5c397b2e643bc5c8a65 Mon Sep 17 00:00:00 2001 From: krasow Date: Sun, 31 May 2026 13:57:20 -0500 Subject: [PATCH 05/14] downgrade Project.toml wrapper version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 648882fa..c898d6b4 100644 --- a/Project.toml +++ b/Project.toml @@ -44,6 +44,6 @@ Pkg = "1" Preferences = "1" Random = "1" StatsBase = "0.34" -cunumeric_jl_wrapper_jll = "25.10.4" +cunumeric_jl_wrapper_jll = "25.10.3" cupynumeric_jll = "25.10.3" julia = "1.10" From efdfa48bf08bf65e1e68048bcf14077a3aaef794 Mon Sep 17 00:00:00 2001 From: krasow Date: Sun, 31 May 2026 14:25:45 -0500 Subject: [PATCH 06/14] add include for linalg.jl --- src/cuNumeric.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index 3a9280e6..0380251c 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -144,6 +144,7 @@ include("ndarray/broadcast.jl") include("ndarray/ndarray.jl") include("ndarray/unary.jl") include("ndarray/binary.jl") +include("ndarray/linalg.jl") # scoping macro include("scoping.jl") @@ -230,7 +231,7 @@ function __init__() _is_precompiling() && return nothing # Cannot set LEGATE_CONFIG on CI machines used - # to register packages. So we will just skip starting + # to register packages. So we will just skip starting # legate/cunumeric when using registry CI machines. 
get(ENV, "JULIA_REGISTRYCI_AUTOMERGE", false) == "true" && return nothing From 5727aad2bbffb478d8b67460850210accce7fd62 Mon Sep 17 00:00:00 2001 From: krasow Date: Sun, 31 May 2026 14:58:35 -0500 Subject: [PATCH 07/14] SUPPORTED_FLOAT_TYPES --- src/cuNumeric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index 0380251c..8125a469 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -55,7 +55,7 @@ const DEFAULT_FLOAT = Float32 const DEFAULT_INT = Int32 const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} -const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} # Float16 not supported yet +const SUPPORTED_FLOAT_TYPES = Union{Float16,Float32,Float64} const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64} const SUPPORTED_NUMERIC_TYPES = Union{ From b12e3e0801ac97d5189ded7ff3e67d1e0b1e4e58 Mon Sep 17 00:00:00 2001 From: krasow Date: Sun, 31 May 2026 18:38:59 -0500 Subject: [PATCH 08/14] redo guards --- src/cuNumeric.jl | 3 +- src/ndarray/linalg.jl | 72 ++++++++++++++++++++++++------------------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index 8125a469..bc3fed48 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -55,7 +55,8 @@ const DEFAULT_FLOAT = Float32 const DEFAULT_INT = Int32 const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} -const SUPPORTED_FLOAT_TYPES = Union{Float16,Float32,Float64} +# Float16 is only supported by the backend when built with CUDA +const SUPPORTED_FLOAT_TYPES = HAS_CUDA ? 
Union{Float16,Float32,Float64} : Union{Float32,Float64} const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64} const SUPPORTED_NUMERIC_TYPES = Union{ diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl index 29bab673..6b25882d 100644 --- a/src/ndarray/linalg.jl +++ b/src/ndarray/linalg.jl @@ -1,10 +1,10 @@ -function choose_nd_color_shape(shape::NTuple{N,Int}) where N +function choose_nd_color_shape(shape::NTuple{N,Int}) where {N} color_shape = Base.ones(Int, N) if N > 2 - color_shape[1] = Legate.num_procs() + color_shape[1] = Legate.num_procs() done = false while !done && color_shape[1] % 2 == 0 - weight_per_dim = [shape[i] / color_shape[i] for i in 1:N-2] + weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)] max_weight, idx = findmax(weight_per_dim) if weight_per_dim[idx] > 2 * weight_per_dim[1] color_shape[1] ÷= 2 @@ -17,9 +17,11 @@ function choose_nd_color_shape(shape::NTuple{N,Int}) where N return Tuple(color_shape) end -function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where N +function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N} initial_color_shape = choose_nd_color_shape(full_shape) - tilesize = Tuple((full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N) + tilesize = Tuple( + (full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N + ) color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N) return tilesize, color_shape end @@ -28,7 +30,7 @@ function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} nrhs = size(b)[end] full_shape = size(a) tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape) - tilesize_b = (tilesize_a[1:end-1]..., nrhs) + tilesize_b = (tilesize_a[1:(end - 1)]..., nrhs) store_a = nda_to_logical_store(a) store_b = nda_to_logical_store(b) @@ -64,24 +66,28 @@ function solve(a::NDArray{T,N}, b::NDArray{S,0}) where {T,N,S} end # Float16 
guards -function solve(a::NDArray{Float16,N}, b::NDArray{S,M}) where {N,S,M} - throw(ArgumentError("array type float16 is unsupported in linalg")) -end +@static if HAS_CUDA + function solve(a::NDArray{Float16,N}, b::NDArray{S,M}) where {N,S,M} + throw(ArgumentError("array type float16 is unsupported in linalg")) + end -function solve(a::NDArray{T,N}, b::NDArray{Float16,M}) where {T,N,M} - throw(ArgumentError("array type float16 is unsupported in linalg")) + function solve(a::NDArray{T,N}, b::NDArray{Float16,M}) where {T,N,M} + throw(ArgumentError("array type float16 is unsupported in linalg")) + end end # 2D case: (m,m),(m)->( m) function solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} - size(a)[end-1] != size(a)[end] && + size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) size(a)[2] != size(b)[1] && - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) + throw( + ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))", + ), + ) prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) x = zeros(T, size(b)...) 
solve_batched(a, b, x) @@ -90,14 +96,16 @@ end # 2D case: (m,m),(m,n)->(m,n) function solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} - size(a)[end-1] != size(a)[end] && + size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) size(a)[2] != size(b)[1] && - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))" - )) + throw( + ArgumentError( + "Input operand 1 has a mismatch in its dimension 0, " * + "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * + "is different from $(size(a)[2]))", + ), + ) prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) x = zeros(T, size(b)...) solve_batched(a, b, x) @@ -106,14 +114,16 @@ end # Batched case: (...,m,m),(...,m,n)->(...,m,n) function solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} - size(a)[end-1] != size(a)[end] && + size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) - size(a)[end] != size(b)[end-1] && - throw(ArgumentError( - "Input operand 1 has a mismatch in its dimension " * - "$(N-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * - " (size $(size(b)[end-1]) is different from $(size(a)[end]))" - )) + size(a)[end] != size(b)[end - 1] && + throw( + ArgumentError( + "Input operand 1 has a mismatch in its dimension " * + "$(N-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * + " (size $(size(b)[end-1]) is different from $(size(a)[end]))", + ), + ) prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) x = zeros(T, size(b)...) 
solve_batched(a, b, x) @@ -123,4 +133,4 @@ end # Mismatched batch dimensions function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} throw(ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)")) -end \ No newline at end of file +end From 803f950c49f41f6930ed2e702b20aac3c431111b Mon Sep 17 00:00:00 2001 From: krasow Date: Sun, 31 May 2026 20:58:26 -0500 Subject: [PATCH 09/14] fix dereference issue --- src/ndarray/detail/ndarray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ndarray/detail/ndarray.jl b/src/ndarray/detail/ndarray.jl index 21b6d8e1..fff056ab 100644 --- a/src/ndarray/detail/ndarray.jl +++ b/src/ndarray/detail/ndarray.jl @@ -503,7 +503,7 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real) end function nda_to_logical_store(arr::NDArray{T,N}) where {T,N} - la_handle = cuNumeric.get_store(arr) - st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle[], size(arr))) + la_handle = cuNumeric.get_store(arr) # LogicalArrayImplAllocated (returned by value) + st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle, size(arr))) return Legate.LogicalStore{T,N}(st_handle, size(arr)) end From abe8f1073b2679c1fa16dc2dc3e03cad6e2e741a Mon Sep 17 00:00:00 2001 From: krasow Date: Mon, 1 Jun 2026 10:15:41 -0500 Subject: [PATCH 10/14] static guard check --- src/cuNumeric.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index bc3fed48..926b0c3e 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -56,7 +56,12 @@ const DEFAULT_INT = Int32 const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} # Float16 is only supported by the backend when built with CUDA -const SUPPORTED_FLOAT_TYPES = HAS_CUDA ? 
Union{Float16,Float32,Float64} : Union{Float32,Float64} +@static if HAS_CUDA + const SUPPORTED_FLOAT_TYPES = Union{Float16,Float32,Float64} +else + const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} +end + const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64} const SUPPORTED_NUMERIC_TYPES = Union{ From fa07db738877c99370c4dfd1d956db58d5835493 Mon Sep 17 00:00:00 2001 From: Nader-Rahhal Date: Mon, 1 Jun 2026 16:59:45 -0400 Subject: [PATCH 11/14] add type coverage --- test/tests/linalg.jl | 61 +++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index e2e66d03..dc716f58 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -104,38 +104,47 @@ end end @testset "solve diagonal" begin - n = 4 - A = cuNumeric.zeros(Float64, n, n) - b = cuNumeric.zeros(Float64, n, 1) - cuNumeric.@allowscalar for i in 1:n - A[i, i] = 4.0 - b[i, 1] = 1.0 - end - x = cuNumeric.solve(A, b) - allowscalar() do - @test cuNumeric.compare(fill(0.25, n, 1), x, atol(Float64), rtol(Float64)) + @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) + n = 4 + T_comp = T <: Integer ? Float64 : T + A = cuNumeric.zeros(T, n, n) + b = cuNumeric.zeros(T, n, 1) + cuNumeric.@allowscalar for i in 1:n + A[i, i] = T(4) + b[i, 1] = T(1) + end + x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) + allowscalar() do + @test cuNumeric.compare(fill(T_comp(0.25), n, 1), x, atol(T_comp), rtol(T_comp)) + end end end @testset "solve identity" begin - n = 4 - A = cuNumeric.NDArray(Matrix{Float64}(I, n, n)) - b = cuNumeric.NDArray(reshape(collect(1.0:n), n, 1)) - x = cuNumeric.solve(A, b) - ref = reshape(collect(1.0:n), n, 1) - allowscalar() do - @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) + n = 4 + T_comp = T <: Integer ? 
Float64 : T + A = cuNumeric.NDArray(Matrix{T}(I, n, n)) + b = cuNumeric.NDArray(reshape(T.(collect(1:n)), n, 1)) + x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) + ref = reshape(Float64.(collect(1:n)), n, 1) + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T_comp), rtol(T_comp)) + end end end @testset "solve general" begin - A_ref = [2.0 1.0; 5.0 7.0] - b_ref = [11.0; 13.0;;] - A = cuNumeric.NDArray(A_ref) - b = cuNumeric.NDArray(b_ref) - x = cuNumeric.solve(A, b) - ref = A_ref \ b_ref - allowscalar() do - @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) + T_comp = T <: Integer ? Float64 : T + A_ref = T[2 1; 5 7] + b_ref = T[11; 13;;] + A = cuNumeric.NDArray(A_ref) + b = cuNumeric.NDArray(b_ref) + x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) + ref = Float64.(A_ref) \ Float64.(b_ref) + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T_comp), rtol(T_comp)) + end end -end +end \ No newline at end of file From 1c89944f3a467e47f0878ce0d3a809edbcf70a66 Mon Sep 17 00:00:00 2001 From: krasow Date: Tue, 2 Jun 2026 11:37:47 -0500 Subject: [PATCH 12/14] update solve and linalg test cases --- src/cuNumeric.jl | 8 ++++ src/ndarray/linalg.jl | 42 +++++++++-------- test/tests/linalg.jl | 107 +++++++++++++++++++++++------------------- 3 files changed, 90 insertions(+), 67 deletions(-) diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index 926b0c3e..b4f18637 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -67,6 +67,14 @@ const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64} const SUPPORTED_NUMERIC_TYPES = Union{ SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES } + +const SUPPORTED_LINALG_TYPES = Union{ + SUPPORTED_INT_TYPES,Float32,Float64,SUPPORTED_COMPLEX_TYPES +} + +# solve has no integer/Float16 backend kernel — float/complex only. 
+const SUPPORTED_SOLVE_TYPES = Union{Float32,Float64,SUPPORTED_COMPLEX_TYPES} + const SUPPORTED_ARRAY_TYPES = Union{Bool,SUPPORTED_NUMERIC_TYPES} const SUPPORTED_TYPES = Union{SUPPORTED_ARRAY_TYPES,String} diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl index 6b25882d..547b361b 100644 --- a/src/ndarray/linalg.jl +++ b/src/ndarray/linalg.jl @@ -52,32 +52,34 @@ function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} Legate.submit_manual_task(rt, task) end -# Dimension guards -function solve(a::NDArray{T,1}, b::NDArray{S,M}) where {T,S,M} - throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional")) +# Type/dim guards dispatch on one argument at a time, then forward to `_solve`. +function solve( + a::NDArray{<:SUPPORTED_SOLVE_TYPES}, b::NDArray{<:SUPPORTED_SOLVE_TYPES} +) + return _solve_check_a_dims(a, b) end -function solve(a::NDArray{T,0}, b::NDArray{S,M}) where {T,S,M} - throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional")) +function solve(a::NDArray, b::NDArray) + bad = eltype(a) <: SUPPORTED_SOLVE_TYPES ? eltype(b) : eltype(a) + throw(ArgumentError("array type $bad is unsupported in solve")) end -function solve(a::NDArray{T,N}, b::NDArray{S,0}) where {T,N,S} - throw(ArgumentError("0-dimensional array given. Array must be at least one-dimensional")) +# `a` must be at least 2D, `b` at least 1D. +function _solve_check_a_dims(a::NDArray{<:Any,0}, b::NDArray) + throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional")) end +function _solve_check_a_dims(a::NDArray{<:Any,1}, b::NDArray) + throw(ArgumentError("1-dimensional array given. 
Array must be at least two-dimensional")) +end +_solve_check_a_dims(a::NDArray, b::NDArray) = _solve_check_b_dims(a, b) -# Float16 guards -@static if HAS_CUDA - function solve(a::NDArray{Float16,N}, b::NDArray{S,M}) where {N,S,M} - throw(ArgumentError("array type float16 is unsupported in linalg")) - end - - function solve(a::NDArray{T,N}, b::NDArray{Float16,M}) where {T,N,M} - throw(ArgumentError("array type float16 is unsupported in linalg")) - end +function _solve_check_b_dims(a::NDArray, b::NDArray{<:Any,0}) + throw(ArgumentError("0-dimensional array given. Array must be at least one-dimensional")) end +_solve_check_b_dims(a::NDArray, b::NDArray) = _solve(a, b) # 2D case: (m,m),(m)->( m) -function solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} +function _solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) size(a)[2] != size(b)[1] && @@ -95,7 +97,7 @@ function solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} end # 2D case: (m,m),(m,n)->(m,n) -function solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} +function _solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) size(a)[2] != size(b)[1] && @@ -113,7 +115,7 @@ function solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} end # Batched case: (...,m,m),(...,m,n)->(...,m,n) -function solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} +function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) size(a)[end] != size(b)[end - 1] && @@ -131,6 +133,6 @@ function solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} end # Mismatched batch dimensions -function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} +function _solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} 
throw(ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)")) end diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index dc716f58..40dcf1e9 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -19,19 +19,21 @@ =# @testset "transpose" begin - A = rand(Float64, 4, 3) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 4, 3) + nda = cuNumeric.NDArray(A) - ref = transpose(A) - out = cuNumeric.transpose(nda) + ref = transpose(A) + out = cuNumeric.transpose(nda) - allowscalar() do - @test cuNumeric.compare(ref, out, atol(Float64), rtol(Float64)) + allowscalar() do + @test cuNumeric.compare(ref, out, atol(T), rtol(T)) + end end end @testset "eye" begin - for T in (Float32, Float64, Int32) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) n = 5 ref = Matrix{T}(I, n, n) out = cuNumeric.eye(n; T=T) @@ -42,41 +44,53 @@ end end @testset "trace" begin - A = rand(Float64, 6, 6) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + # float accumulation for ints for overflow prevention + T_acc = T <: Integer ? Float64 : T - ref = tr(A) - out = cuNumeric.trace(nda) + A = my_rand(T, 6, 6) + nda = cuNumeric.NDArray(A) - allowscalar() do - @test ref ≈ out[1] atol=atol(Float32) rtol=rtol(Float32) + ref = sum(T_acc.(diag(A))) + out = cuNumeric.trace(nda; T=T_acc) + + allowscalar() do + @test ref ≈ out[1] atol=atol(T_acc) rtol=rtol(T_acc) + end end end @testset "trace with offset" begin - A = rand(Float32, 5, 5) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + # float accumulation for ints for overflow prevention + T_acc = T <: Integer ? 
Float64 : T - for k in (-2, -1, 0, 1, 2) - ref = sum(diag(A, k)) - out = cuNumeric.trace(nda; offset=k) + A = my_rand(T, 5, 5) + nda = cuNumeric.NDArray(A) - allowscalar() do - @test ref ≈ out[1] atol=atol(Float32) rtol=rtol(Float32) + @testset "offset=$(k)" for k in (-2, -1, 0, 1, 2) + ref = sum(T_acc.(diag(A, k))) + out = cuNumeric.trace(nda; offset=k, T=T_acc) + + allowscalar() do + @test ref ≈ out[1] atol=atol(T_acc) rtol=rtol(T_acc) + end end end end @testset "diag" begin - A = rand(Int, 6, 6) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 6, 6) + nda = cuNumeric.NDArray(A) - for k in (-2, 0, 3) - ref = diag(A, k) - out = cuNumeric.diag(nda; k=k) + @testset "k=$(k)" for k in (-2, 0, 3) + ref = diag(A, k) + out = cuNumeric.diag(nda; k=k) - allowscalar() do - @test cuNumeric.compare(ref, out, atol(Int32), rtol(Int32)) + allowscalar() do + @test cuNumeric.compare(ref, out, atol(T), rtol(T)) + end end end end @@ -94,57 +108,56 @@ end # end @testset "unique" begin - A = [1, 2, 2, 3, 4, 4, 4, 5] - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = T[1, 2, 2, 3, 4, 4, 4, 5] + nda = cuNumeric.NDArray(A) - ref = unique(A) - out = cuNumeric.unique(nda) + ref = unique(A) + out = cuNumeric.unique(nda) - @test sort(Array(out)) == sort(ref) + @test Set(Array(out)) == Set(ref) + end end @testset "solve diagonal" begin - @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) n = 4 - T_comp = T <: Integer ? 
Float64 : T A = cuNumeric.zeros(T, n, n) b = cuNumeric.zeros(T, n, 1) cuNumeric.@allowscalar for i in 1:n A[i, i] = T(4) b[i, 1] = T(1) end - x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) + x = cuNumeric.solve(A, b) allowscalar() do - @test cuNumeric.compare(fill(T_comp(0.25), n, 1), x, atol(T_comp), rtol(T_comp)) + @test cuNumeric.compare(fill(T(0.25), n, 1), x, atol(T), rtol(T)) end end end @testset "solve identity" begin - @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) n = 4 - T_comp = T <: Integer ? Float64 : T A = cuNumeric.NDArray(Matrix{T}(I, n, n)) b = cuNumeric.NDArray(reshape(T.(collect(1:n)), n, 1)) - x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) - ref = reshape(Float64.(collect(1:n)), n, 1) + x = cuNumeric.solve(A, b) + ref = reshape(T.(collect(1:n)), n, 1) allowscalar() do - @test cuNumeric.compare(ref, x, atol(T_comp), rtol(T_comp)) + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) end end end @testset "solve general" begin - @testset for T in (Float32, Float64, ComplexF32, ComplexF64, Int8, Int16, Int32, Int64) - T_comp = T <: Integer ? 
Float64 : T + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) A_ref = T[2 1; 5 7] b_ref = T[11; 13;;] A = cuNumeric.NDArray(A_ref) b = cuNumeric.NDArray(b_ref) - x = cuNumeric.solve(cuNumeric.as_type(A, T_comp), cuNumeric.as_type(b, T_comp)) - ref = Float64.(A_ref) \ Float64.(b_ref) + x = cuNumeric.solve(A, b) + ref = A_ref \ b_ref allowscalar() do - @test cuNumeric.compare(ref, x, atol(T_comp), rtol(T_comp)) + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) end end -end \ No newline at end of file +end From ba9c535e4c1a4c02f7cec14ecd87aa82bacc8bbc Mon Sep 17 00:00:00 2001 From: krasow Date: Tue, 2 Jun 2026 14:07:44 -0500 Subject: [PATCH 13/14] add promotion and stability checks to solve --- src/ndarray/linalg.jl | 59 ++++++++++++++--------------------------- test/tests/linalg.jl | 34 +++++++++++++++++++++++- test/tests/stability.jl | 22 ++++++++++++++- 3 files changed, 74 insertions(+), 41 deletions(-) diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl index 547b361b..1c01b44a 100644 --- a/src/ndarray/linalg.jl +++ b/src/ndarray/linalg.jl @@ -52,15 +52,25 @@ function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} Legate.submit_manual_task(rt, task) end +# solve runs in floating point: +# int/bool inputs promote to Float64 (matching cupynumeric) +const _SOLVE_PROMOTABLE = Union{SUPPORTED_INT_TYPES,Bool} +const _SOLVE_ACCEPTED = Union{SUPPORTED_SOLVE_TYPES,_SOLVE_PROMOTABLE} +_solve_eltype(::Type{T}) where {T<:_SOLVE_PROMOTABLE} = Float64 +_solve_eltype(::Type{T}) where {T<:SUPPORTED_SOLVE_TYPES} = T + # Type/dim guards dispatch on one argument at a time, then forward to `_solve`.
-function solve( - a::NDArray{<:SUPPORTED_SOLVE_TYPES}, b::NDArray{<:SUPPORTED_SOLVE_TYPES} -) - return _solve_check_a_dims(a, b) +function solve(a::NDArray{<:_SOLVE_ACCEPTED}, b::NDArray{<:_SOLVE_ACCEPTED}) + A, B = eltype(a), eltype(b) + O = promote_type(_solve_eltype(A), _solve_eltype(B)) + # int/bool -> float is an implicit promotion, disallowed unless `allowpromotion` + A <: _SOLVE_PROMOTABLE && assertpromotion(solve, A, O) + B <: _SOLVE_PROMOTABLE && assertpromotion(solve, B, O) + return _solve_check_a_dims(unchecked_promote_arr(a, O), unchecked_promote_arr(b, O)) end function solve(a::NDArray, b::NDArray) - bad = eltype(a) <: SUPPORTED_SOLVE_TYPES ? eltype(b) : eltype(a) + bad = eltype(a) <: _SOLVE_ACCEPTED ? eltype(b) : eltype(a) throw(ArgumentError("array type $bad is unsupported in solve")) end @@ -78,43 +88,14 @@ function _solve_check_b_dims(a::NDArray, b::NDArray{<:Any,0}) end _solve_check_b_dims(a::NDArray, b::NDArray) = _solve(a, b) -# 2D case: (m,m),(m)->( m) +# 2D case: (m,m),(m)->(m). +# Backend needs rhs "b" to be 2D. We reshape b from (n,) to (n,1) function _solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} - size(a)[end - 1] != size(a)[end] && - throw(ArgumentError("Last 2 dimensions of the array must be square")) - size(a)[2] != size(b)[1] && - throw( - ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m)->(m) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))", - ), - ) - prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) - x = zeros(T, size(b)...) 
- solve_batched(a, b, x) - return x -end - -# 2D case: (m,m),(m,n)->(m,n) -function _solve(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S} - size(a)[end - 1] != size(a)[end] && - throw(ArgumentError("Last 2 dimensions of the array must be square")) - size(a)[2] != size(b)[1] && - throw( - ArgumentError( - "Input operand 1 has a mismatch in its dimension 0, " * - "with signature (m,m),(m,n)->(m,n) (size $(size(b)[1]) " * - "is different from $(size(a)[2]))", - ), - ) - prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) - x = zeros(T, size(b)...) - solve_batched(a, b, x) - return x + m = size(b)[1] + return reshape(_solve(a, reshape(b, (m, 1))), (m,)) end -# Batched case: (...,m,m),(...,m,n)->(...,m,n) +# 2D (m,m),(m,n)->(m,n) and batched (...,m,m),(...,m,n)->(...,m,n) function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} size(a)[end - 1] != size(a)[end] && throw(ArgumentError("Last 2 dimensions of the array must be square")) diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index 40dcf1e9..9062135d 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -151,7 +151,7 @@ end @testset "solve general" begin @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) A_ref = T[2 1; 5 7] - b_ref = T[11; 13;;] + b_ref = T[11; 13;;] # creates a 2d matrix instead of vector A = cuNumeric.NDArray(A_ref) b = cuNumeric.NDArray(b_ref) x = cuNumeric.solve(A, b) @@ -161,3 +161,35 @@ end end end end + +@testset "solve vector rhs" begin + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + A_ref = T[2 1; 5 7] + b_ref = T[11, 13] + x = cuNumeric.solve(cuNumeric.NDArray(A_ref), cuNumeric.NDArray(b_ref)) + @test ndims(x) == 1 + ref = A_ref \ b_ref + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) + end + end +end + +@testset "solve promotion" begin + @testset verbose=true for T in (Int32, Int64, Bool) + A = cuNumeric.NDArray(T[1 0; 0 1]) + b = 
cuNumeric.NDArray(reshape(T[1, 1], 2, 1)) + + # int/bool requires promotion to float. Will throw without allowpromotion() + @test_throws "Implicit promotion" cuNumeric.solve(A, b) + + # ...allowed under @allowpromotion, result is Float64 + allowpromotion() do + x = cuNumeric.solve(A, b) + ref = Float64[1 0; 0 1] \ Float64[1; 1;;] + allowscalar() do + @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + end + end + end +end diff --git a/test/tests/stability.jl b/test/tests/stability.jl index 648f4bc4..13d6fad0 100644 --- a/test/tests/stability.jl +++ b/test/tests/stability.jl @@ -1,4 +1,4 @@ -#= Copyright 2026 Northwestern University, +#= Copyright 2026 Northwestern University, * Carnegie Mellon University University * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -98,3 +98,23 @@ end @test @inferred(a ./ b) !== nothing @test @inferred(((a .* b) .+ a) .* 2.0f0) !== nothing end + +@testset verbose = true "solve" begin + # native float/complex, 2D and 1D rhs + @testset "$(T)" for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + A = cuNumeric.NDArray(T[2 1; 5 7]) + b2 = cuNumeric.NDArray(T[11; 13;;]) # creates a 2d matrix instead of vector + b1 = cuNumeric.NDArray(T[11, 13]) + @test @inferred(cuNumeric.solve(A, b2)) !== nothing + @test @inferred(cuNumeric.solve(A, b1)) !== nothing + end + + # int/bool promote to Float64 (under allowpromotion) and stay inferrable + @testset "promote $(T)" for T in (Int32, Int64, Bool) + A = cuNumeric.NDArray(T[1 0; 0 1]) + b = cuNumeric.NDArray(reshape(T[1, 1], 2, 1)) + allowpromotion() do + @test @inferred(cuNumeric.solve(A, b)) !== nothing + end + end +end From b41cd2352fdf581b501089007ec23a805fea8597 Mon Sep 17 00:00:00 2001 From: krasow Date: Tue, 2 Jun 2026 15:21:19 -0500 Subject: [PATCH 14/14] fix stability in trace and eye. provide testing support.
change public API --- src/ndarray/detail/ndarray.jl | 16 ++++++++++------ src/ndarray/ndarray.jl | 24 +++++++++++++++--------- test/tests/linalg.jl | 22 +++++++--------------- test/tests/stability.jl | 13 +++++++++++++ 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/src/ndarray/detail/ndarray.jl b/src/ndarray/detail/ndarray.jl index fff056ab..9d630fcc 100644 --- a/src/ndarray/detail/ndarray.jl +++ b/src/ndarray/detail/ndarray.jl @@ -236,18 +236,21 @@ function nda_array_equal(rhs1::NDArray{T,N}, rhs2::NDArray{T,N}) where {T,N} return NDArray(ptr, Bool, Val(1)) end -function nda_diag(arr::NDArray, k::Int32) +# 2D -> 1D: extract the k-th diagonal. Backend only supports the 2D case +# (1D-construct and >2D both abort), so non-2D input is a MethodError. +function nda_diag(arr::NDArray{T,2}, k::Int32) where {T} ptr = ccall((:nda_diag, libnda), NDArray_t, (NDArray_t, Int32), arr.ptr, k) - return NDArray(ptr) + return NDArray(ptr, T, Val(1)) end -function nda_unique(arr::NDArray) +# unique always returns a flat 1D array of the input's element type +function nda_unique(arr::NDArray{T}) where {T} ptr = ccall((:nda_unique, libnda), NDArray_t, (NDArray_t,), arr.ptr) - return NDArray(ptr) + return NDArray(ptr, T, Val(1)) end function nda_ravel(arr::NDArray) @@ -315,11 +318,12 @@ function nda_trace( return NDArray(ptr, T, Val(1)) end -function nda_transpose(arr::NDArray) +# transpose reverses the axes: element type and rank are preserved +function nda_transpose(arr::NDArray{T,N}) where {T,N} ptr = ccall((:nda_transpose, libnda), NDArray_t, (NDArray_t,), arr.ptr) - return NDArray(ptr) + return NDArray(ptr, T, Val(N)) end function nda_attach_external(arr::AbstractArray{T,N}) where {T,N} diff --git a/src/ndarray/ndarray.jl b/src/ndarray/ndarray.jl index df20f2ce..086a60ca 100644 --- a/src/ndarray/ndarray.jl +++ b/src/ndarray/ndarray.jl @@ -30,21 +30,27 @@ function transpose(arr::NDArray) end @doc""" - cuNumeric.eye(rows::Int; T=Float32) + cuNumeric.eye([T,] 
rows::Int) -Create a 2D identity `NDArray` of size `rows x rows` with element type `T`. +Create a 2D identity `NDArray` of size `rows x rows` with element type `T` +(defaults to `DEFAULT_FLOAT`). """ -function eye(rows::Int; T::Type{S}=Float64) where {S} - return nda_eye(Int32(rows), S) +function eye(::Type{T}, rows::Int) where {T} + return nda_eye(Int32(rows), T) +end +function eye(rows::Int) + return eye(DEFAULT_FLOAT, rows) end @doc""" - cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1, T=Float32) + cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1) -Compute the trace of the `NDArray` along the specified axes. +Compute the trace (sum of a diagonal) of the `NDArray`. +The accumulator type follows promotions of other reductions like 'sum'. """ -function trace(arr::NDArray; offset::Int=0, a1::Int=0, a2::Int=1, T::Type{S}=Float32) where {S} - return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), S) +function trace(arr::NDArray{T}; offset::Int=0, a1::Int=0, a2::Int=1) where {T} + T_OUT = Base.promote_op(Base.sum, Vector{T}) + return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), T_OUT) end @doc""" @@ -762,4 +768,4 @@ end function Base.isapprox(arr::NDArray{T}, arr2::NDArray{T}; atol=0, rtol=0) where {T} return compare(arr, arr2, atol, rtol) -end \ No newline at end of file +end diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index 9062135d..8f9c060d 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -36,7 +36,7 @@ end @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) n = 5 ref = Matrix{T}(I, n, n) - out = cuNumeric.eye(n; T=T) + out = cuNumeric.eye(T, n) allowscalar() do @test cuNumeric.compare(ref, out, atol(T), rtol(T)) end @@ -45,35 +45,27 @@ end @testset "trace" begin @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) - # float accumulation for ints for overflow prevention - T_acc = T <: Integer ? 
Float64 : T - A = my_rand(T, 6, 6) nda = cuNumeric.NDArray(A) - ref = sum(T_acc.(diag(A))) - out = cuNumeric.trace(nda; T=T_acc) - + ref = sum(diag(A)) # widens ints like trace's accumulator + out = cuNumeric.trace(nda) allowscalar() do - @test ref ≈ out[1] atol=atol(T_acc) rtol=rtol(T_acc) + @test ref ≈ out[1] atol=atol(eltype(ref)) rtol=rtol(eltype(ref)) end end end @testset "trace with offset" begin @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) - # float accumulation for ints for overflow prevention - T_acc = T <: Integer ? Float64 : T - A = my_rand(T, 5, 5) nda = cuNumeric.NDArray(A) @testset "offset=$(k)" for k in (-2, -1, 0, 1, 2) - ref = sum(T_acc.(diag(A, k))) - out = cuNumeric.trace(nda; offset=k, T=T_acc) - + ref = sum(diag(A, k)) + out = cuNumeric.trace(nda; offset=k) allowscalar() do - @test ref ≈ out[1] atol=atol(T_acc) rtol=rtol(T_acc) + @test ref ≈ out[1] atol=atol(eltype(ref)) rtol=rtol(eltype(ref)) end end end diff --git a/test/tests/stability.jl b/test/tests/stability.jl index 13d6fad0..7e488d27 100644 --- a/test/tests/stability.jl +++ b/test/tests/stability.jl @@ -118,3 +118,16 @@ end end end end + +@testset verbose = true "linalg ops" begin + @testset "$(T)" for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + M = cuNumeric.zeros(T, 4, 3) + sq = cuNumeric.zeros(T, 5, 5) + v = cuNumeric.zeros(T, 8) + @test @inferred(cuNumeric.eye(T, 5)) !== nothing + @test @inferred(cuNumeric.transpose(M)) !== nothing + @test @inferred(cuNumeric.trace(sq)) !== nothing + @test @inferred(cuNumeric.diag(sq)) !== nothing + @test @inferred(cuNumeric.unique(v)) !== nothing + end +end