Skip to content
Merged
3 changes: 3 additions & 0 deletions lib/cunumeric_jl_wrapper/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ void wrap_unary_reds(jlcxx::Module&);

// Binary op codes
void wrap_binary_ops(jlcxx::Module&);

// Linear algebra op codes: exports the SOLVE and MP_SOLVE task IDs as
// constants on the given Julia module.
void wrap_linalg_ops(jlcxx::Module&);
2 changes: 1 addition & 1 deletion lib/cunumeric_jl_wrapper/include/ufi.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* Copyright 2026 Northwestern University,
* Carnegie Mellon University
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down
5 changes: 5 additions & 0 deletions lib/cunumeric_jl_wrapper/src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,8 @@ void wrap_binary_ops(jlcxx::Module& mod) {
mod.set_const("SUBTRACT",
CuPyNumericBinaryOpCode::CUPYNUMERIC_BINOP_SUBTRACT);
}

// Registers the linear-algebra task identifiers as constants on the Julia
// module `mod`. Each CuPyNumericOpCode enumerator is wrapped in a
// legate::LocalTaskID before being exported, under the names "SOLVE" and
// "MP_SOLVE".
void wrap_linalg_ops(jlcxx::Module& mod) {
mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE});
mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE});
}
1 change: 1 addition & 0 deletions lib/cunumeric_jl_wrapper/src/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) {
wrap_unary_ops(mod);
wrap_binary_ops(mod);
wrap_unary_reds(mod);
wrap_linalg_ops(mod);

using jlcxx::ParameterList;
using jlcxx::Parametric;
Expand Down
6 changes: 6 additions & 0 deletions src/ndarray/detail/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,9 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real)
# successful completion
return true
end

"""
    nda_to_logical_store(arr::NDArray{T,N})

Return a `Legate.LogicalStore{T,N}` for `arr`.

The handle obtained from `cuNumeric.get_store(arr)` is dereferenced and wrapped
in a `Legate.LogicalArray{T,N}`; `Legate.data` of that array is then wrapped in
a `Legate.LogicalStore{T,N}`. Both wrappers are given `size(arr)`.
"""
function nda_to_logical_store(arr::NDArray{T,N}) where {T,N}
    array_handle = cuNumeric.get_store(arr)
    dims = size(arr)
    logical_array = Legate.LogicalArray{T,N}(array_handle[], dims)
    return Legate.LogicalStore{T,N}(Legate.data(logical_array), dims)
end
101 changes: 101 additions & 0 deletions src/ndarray/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
    choose_nd_color_shape(shape::NTuple{N,Int}) -> NTuple{N,Int}

Choose how many colors (partitions) to use along each axis of `shape`.

For `N <= 2` every axis gets a single color. For `N > 2` the first axis starts
with `Legate.num_procs()` colors. While that count is even, a factor of two is
moved from the first axis to the axis with the largest extent-per-color,
provided that weight exceeds twice the first axis' weight. Only the first
`N - 2` axes take part in the re-balancing; the last two axes always keep a
single color.
"""
function choose_nd_color_shape(shape::NTuple{N,Int}) where N
    color_shape = Base.ones(Int, N)
    if N > 2
        color_shape[1] = Legate.num_procs()
        # Re-balance by powers of two only; stop once the first axis count is odd.
        while iseven(color_shape[1])
            weight_per_dim = [shape[i] / color_shape[i] for i in 1:N-2]
            _, heaviest = findmax(weight_per_dim)
            # Stop as soon as no axis is more than twice as loaded as the first one.
            weight_per_dim[heaviest] > 2 * weight_per_dim[1] || break
            color_shape[1] ÷= 2
            color_shape[heaviest] *= 2
        end
    end
    # `Val(N)` keeps the tuple length in the type domain so the result is inferrable.
    return ntuple(i -> color_shape[i], Val(N))
end

"""
    prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int})

Return `(tilesize, color_shape)` for tiling an array of extent `full_shape`.

The tile size along each axis is the extent divided (rounding up) by the color
count proposed by `choose_nd_color_shape`. The returned color shape is then
recomputed as the extent divided (rounding up) by that tile size.
"""
function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where N
    # Integer division that rounds up.
    ceil_div(extent, parts) = (extent + parts - 1) ÷ parts

    requested_colors = choose_nd_color_shape(full_shape)
    tilesize = map(ceil_div, full_shape, requested_colors)
    color_shape = map(ceil_div, full_shape, tilesize)
    return tilesize, color_shape
end

"""
    solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray)

Launch the `cuNumeric.SOLVE` manual task with `a` and `b` as inputs and `x` as
output.

All three arrays are converted to logical stores and partitioned by tiling.
The tile size of `a` and the launch domain come from
`prepare_manual_task_for_batched_matrices(size(a))`. `b` and `x` share one tile
size: that of `a` on every axis but the last, where the full last extent of `b`
is used. Returns whatever `Legate.submit_manual_task` returns.
"""
function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N}
    rhs_cols = size(b)[end]
    tile_a, colors = prepare_manual_task_for_batched_matrices(size(a))
    # The last axis of b/x is never split: each tile spans all of b's last extent.
    tile_bx = (tile_a[1:end-1]..., rhs_cols)

    a_store = nda_to_logical_store(a)
    b_store = nda_to_logical_store(b)
    x_store = nda_to_logical_store(x)

    a_part = Legate.partition_by_tiling(a_store, collect(tile_a))
    b_part = Legate.partition_by_tiling(b_store, collect(tile_bx))
    x_part = Legate.partition_by_tiling(x_store, collect(tile_bx))

    runtime = Legate.get_runtime()
    launch_domain = Legate.domain_from_shape(
        Legate.Shape(Legate.to_cxx_vector(colors))
    )
    task = Legate.create_manual_task(
        runtime, cuNumeric.get_lib(), cuNumeric.SOLVE, launch_domain
    )

    Legate.add_input(task, a_part)
    Legate.add_input(task, b_part)
    Legate.add_output(task, x_part)

    return Legate.submit_manual_task(runtime, task)
end

"""
    solve(a::NDArray{T,N}, b::NDArray{S,M})

Solve the linear system `a * x = b` and return `x` as a new array with element
type `T` and the same size as `b`.

`a` must have at least two dimensions with its last two extents equal; `b` must
have at least one dimension. For `N == 2` the first extent of `b` must match
`size(a)[2]`. For `N > 2` (batched matrices) `b` must have the same number of
dimensions as `a`, and its second-to-last extent must match the last extent of
`a`. If either input has a zero extent, the zero-initialised result is returned
without launching a task.

# Throws
- `ArgumentError`: if any of the dimensionality, element-type (`Float16`) or
  shape requirements above is violated.
"""
function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M}
    if N < 2
        throw(ArgumentError("$(N)-dimensional array given. Array must be at least two-dimensional"))
    end
    if M < 1
        throw(ArgumentError("$(M)-dimensional array given. Array must be at least one-dimensional"))
    end
    if T == Float16 || S == Float16
        throw(ArgumentError("array type float16 is unsupported in linalg"))
    end
    if size(a)[end-1] != size(a)[end]
        throw(ArgumentError("Last 2 dimensions of the array must be square"))
    end
    if N == 2 && size(a)[2] != size(b)[1]
        # Only the signature shown in the message depends on whether b is a vector.
        signature = M == 1 ? "(m,m),(m)->(m)" : "(m,m),(m,n)->(m,n)"
        throw(ArgumentError(
            "Input operand 1 has a mismatch in its dimension 0, " *
            "with signature $(signature) (size $(size(b)[1]) " *
            "is different from $(size(a)[2]))"
        ))
    end
    if N > 2
        if N != M
            throw(ArgumentError(
                "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)"
            ))
        end
        if size(a)[end] != size(b)[end-1]
            throw(ArgumentError(
                "Input operand 1 has a mismatch in its dimension " *
                "$(M-2), with signature (...,m,m),(...,m,n)->(...,m,n)" *
                " (size $(size(b)[end-1]) is different from $(size(a)[end]))"
            ))
        end
    end

    # NOTE(review): the result uses a's element type T even when S differs, and a
    # 1-D `b` passes validation although solve_batched builds an N-entry tile
    # shape for it -- confirm both cases are supported by the SOLVE task.
    x = zeros(T, size(b)...)

    # Nothing to solve when either operand is empty. Testing the extents directly
    # avoids the integer wrap-around a product of extents could hit.
    if 0 in size(a) || 0 in size(b)
        return x
    end

    solve_batched(a, b, x)
    return x
end
2 changes: 1 addition & 1 deletion src/ndarray/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -762,4 +762,4 @@ end

"""
    isapprox(arr::NDArray{T}, arr2::NDArray{T}; atol=0, rtol=0)

Forward to `compare(arr, arr2, atol, rtol)`. Both tolerances default to `0`.
"""
Base.isapprox(arr::NDArray{T}, arr2::NDArray{T}; atol=0, rtol=0) where {T} =
    compare(arr, arr2, atol, rtol)
end
37 changes: 37 additions & 0 deletions test/tests/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,40 @@ end

@test sort(Array(out)) == sort(ref)
end

@testset "solve diagonal" begin
    # A = 4I and b = ones, so every entry of the solution is 1/4.
    dim = 4
    lhs = cuNumeric.zeros(Float64, dim, dim)
    rhs = cuNumeric.zeros(Float64, dim, 1)
    cuNumeric.@allowscalar for k in 1:dim
        lhs[k, k] = 4.0
        rhs[k, 1] = 1.0
    end

    sol = cuNumeric.solve(lhs, rhs)

    allowscalar() do
        expected = fill(0.25, dim, 1)
        @test cuNumeric.compare(expected, sol, atol(Float64), rtol(Float64))
    end
end

@testset "solve identity" begin
    # With A = I the solution must equal the right-hand side.
    dim = 4
    # Build a fresh dim×1 host column [1.0, ..., dim] on every call.
    host_column() = reshape(collect(1.0:dim), dim, 1)

    eye = cuNumeric.NDArray(Matrix{Float64}(I, dim, dim))
    rhs = cuNumeric.NDArray(host_column())
    sol = cuNumeric.solve(eye, rhs)
    expected = host_column()

    allowscalar() do
        @test cuNumeric.compare(expected, sol, atol(Float64), rtol(Float64))
    end
end

@testset "solve general" begin
    # Compare against the host-side `\` solve of the same 2×2 system.
    host_A = [2.0 1.0; 5.0 7.0]
    host_b = reshape([11.0, 13.0], 2, 1)

    sol = cuNumeric.solve(cuNumeric.NDArray(host_A), cuNumeric.NDArray(host_b))
    expected = host_A \ host_b

    allowscalar() do
        @test cuNumeric.compare(expected, sol, atol(Float64), rtol(Float64))
    end
end
Loading