diff --git a/Project.toml b/Project.toml index 648882fa..c898d6b4 100644 --- a/Project.toml +++ b/Project.toml @@ -44,6 +44,6 @@ Pkg = "1" Preferences = "1" Random = "1" StatsBase = "0.34" -cunumeric_jl_wrapper_jll = "25.10.4" +cunumeric_jl_wrapper_jll = "25.10.3" cupynumeric_jll = "25.10.3" julia = "1.10" diff --git a/lib/cunumeric_jl_wrapper/include/types.h b/lib/cunumeric_jl_wrapper/include/types.h index c75754db..88a18dc9 100644 --- a/lib/cunumeric_jl_wrapper/include/types.h +++ b/lib/cunumeric_jl_wrapper/include/types.h @@ -65,3 +65,6 @@ void wrap_unary_reds(jlcxx::Module&); // Binary op codes void wrap_binary_ops(jlcxx::Module&); + +// Linear algebra op codes +void wrap_linalg_ops(jlcxx::Module& mod); diff --git a/lib/cunumeric_jl_wrapper/include/ufi.h b/lib/cunumeric_jl_wrapper/include/ufi.h index 91fa0151..b132dd66 100644 --- a/lib/cunumeric_jl_wrapper/include/ufi.h +++ b/lib/cunumeric_jl_wrapper/include/ufi.h @@ -1,6 +1,6 @@ /* Copyright 2026 Northwestern University, * Carnegie Mellon University University - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/lib/cunumeric_jl_wrapper/src/types.cpp b/lib/cunumeric_jl_wrapper/src/types.cpp index 2e73ebd5..f181dc2e 100644 --- a/lib/cunumeric_jl_wrapper/src/types.cpp +++ b/lib/cunumeric_jl_wrapper/src/types.cpp @@ -162,3 +162,8 @@ void wrap_binary_ops(jlcxx::Module& mod) { mod.set_const("SUBTRACT", CuPyNumericBinaryOpCode::CUPYNUMERIC_BINOP_SUBTRACT); } + +void wrap_linalg_ops(jlcxx::Module& mod) { + mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE}); + mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE}); +} \ No newline at end of file diff --git a/lib/cunumeric_jl_wrapper/src/wrapper.cpp b/lib/cunumeric_jl_wrapper/src/wrapper.cpp index 912732fd..c5c792fa 100644 --- a/lib/cunumeric_jl_wrapper/src/wrapper.cpp +++ b/lib/cunumeric_jl_wrapper/src/wrapper.cpp @@ -65,6 +65,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) { wrap_unary_ops(mod); wrap_binary_ops(mod); wrap_unary_reds(mod); + wrap_linalg_ops(mod); using jlcxx::ParameterList; using jlcxx::Parametric; diff --git a/src/cuNumeric.jl b/src/cuNumeric.jl index 3a9280e6..b4f18637 100644 --- a/src/cuNumeric.jl +++ b/src/cuNumeric.jl @@ -55,12 +55,26 @@ const DEFAULT_FLOAT = Float32 const DEFAULT_INT = Int32 const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} -const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} # Float16 not supported yet +# Float16 is only supported by the backend when built with CUDA +@static if HAS_CUDA + const SUPPORTED_FLOAT_TYPES = Union{Float16,Float32,Float64} +else + const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} +end + const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64} const SUPPORTED_NUMERIC_TYPES = Union{ SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES } + +const SUPPORTED_LINALG_TYPES = Union{ + SUPPORTED_INT_TYPES,Float32,Float64,SUPPORTED_COMPLEX_TYPES +} + +# solve has no integer/Float16 backend kernel — float/complex only. +const SUPPORTED_SOLVE_TYPES = Union{Float32,Float64,SUPPORTED_COMPLEX_TYPES} + const SUPPORTED_ARRAY_TYPES = Union{Bool,SUPPORTED_NUMERIC_TYPES} const SUPPORTED_TYPES = Union{SUPPORTED_ARRAY_TYPES,String} @@ -144,6 +158,7 @@ include("ndarray/broadcast.jl") include("ndarray/ndarray.jl") include("ndarray/unary.jl") include("ndarray/binary.jl") +include("ndarray/linalg.jl") # scoping macro include("scoping.jl") @@ -230,7 +245,7 @@ function __init__() _is_precompiling() && return nothing # Cannot set LEGATE_CONFIG on CI machines used - # to register packages. So we will just skip starting + # to register packages. So we will just skip starting # legate/cunumeric when using registry CI machines. get(ENV, "JULIA_REGISTRYCI_AUTOMERGE", false) == "true" && return nothing diff --git a/src/ndarray/detail/ndarray.jl b/src/ndarray/detail/ndarray.jl index a13073e0..9d630fcc 100644 --- a/src/ndarray/detail/ndarray.jl +++ b/src/ndarray/detail/ndarray.jl @@ -236,18 +236,21 @@ function nda_array_equal(rhs1::NDArray{T,N}, rhs2::NDArray{T,N}) where {T,N} return NDArray(ptr, Bool, Val(1)) end -function nda_diag(arr::NDArray, k::Int32) +# 2D -> 1D: extract the k-th diagonal. Backend only supports the 2D case +# (1D-construct and >2D both abort), so non-2D input is a MethodError. +function nda_diag(arr::NDArray{T,2}, k::Int32) where {T} ptr = ccall((:nda_diag, libnda), NDArray_t, (NDArray_t, Int32), arr.ptr, k) - return NDArray(ptr) + return NDArray(ptr, T, Val(1)) end -function nda_unique(arr::NDArray) +# unique always returns a flat 1D array of the input's element type +function nda_unique(arr::NDArray{T}) where {T} ptr = ccall((:nda_unique, libnda), NDArray_t, (NDArray_t,), arr.ptr) - return NDArray(ptr) + return NDArray(ptr, T, Val(1)) end function nda_ravel(arr::NDArray) @@ -315,11 +318,12 @@ function nda_trace( return NDArray(ptr, T, Val(1)) end -function nda_transpose(arr::NDArray) +# transpose reverses the axes: element type and rank are preserved +function nda_transpose(arr::NDArray{T,N}) where {T,N} ptr = ccall((:nda_transpose, libnda), NDArray_t, (NDArray_t,), arr.ptr) - return NDArray(ptr) + return NDArray(ptr, T, Val(N)) end function nda_attach_external(arr::AbstractArray{T,N}) where {T,N} @@ -501,3 +505,9 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real) # successful completion return true end + +function nda_to_logical_store(arr::NDArray{T,N}) where {T,N} + la_handle = cuNumeric.get_store(arr) # LogicalArrayImplAllocated (returned by value) + st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle, size(arr))) + return Legate.LogicalStore{T,N}(st_handle, size(arr)) +end diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl new file mode 100644 index 00000000..1c01b44a --- /dev/null +++ b/src/ndarray/linalg.jl @@ -0,0 +1,119 @@ +function choose_nd_color_shape(shape::NTuple{N,Int}) where {N} + color_shape = Base.ones(Int, N) + if N > 2 + color_shape[1] = Legate.num_procs() + done = false + while !done && color_shape[1] % 2 == 0 + weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)] + max_weight, idx = findmax(weight_per_dim) + if weight_per_dim[idx] > 2 * weight_per_dim[1] + color_shape[1] ÷= 2 + color_shape[idx] *= 2 + else + done = true + end + end + end + return Tuple(color_shape) +end + +function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N} + initial_color_shape = choose_nd_color_shape(full_shape) + tilesize = Tuple( + (full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N + ) + color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N) + return tilesize, color_shape +end + +function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N} + nrhs = size(b)[end] + full_shape = size(a) + tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape) + tilesize_b = (tilesize_a[1:(end - 1)]..., nrhs) + + store_a = nda_to_logical_store(a) + store_b = nda_to_logical_store(b) + store_x = nda_to_logical_store(x) + + tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a)) + tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b)) + tiled_x = Legate.partition_by_tiling(store_x, collect(tilesize_b)) + + rt = Legate.get_runtime() + domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape))) + lib = cuNumeric.get_lib() + task = Legate.create_manual_task(rt, lib, cuNumeric.SOLVE, domain) + + Legate.add_input(task, tiled_a) + Legate.add_input(task, tiled_b) + Legate.add_output(task, tiled_x) + + Legate.submit_manual_task(rt, task) +end + +# solve runs in floating point: +# int/bool inputs promote to Float64 (matching cupynumeric) +const _SOLVE_PROMOTABLE = Union{SUPPORTED_INT_TYPES,Bool} +const _SOLVE_ACCEPTED = Union{SUPPORTED_SOLVE_TYPES,_SOLVE_PROMOTABLE} +_solve_eltype(::Type{T}) where {T<:_SOLVE_PROMOTABLE} = Float64 +_solve_eltype(::Type{T}) where {T<:SUPPORTED_SOLVE_TYPES} = T + +# Type/dim guards dispatch on one argument at a time, then forward to `_solve`. +function solve(a::NDArray{<:_SOLVE_ACCEPTED}, b::NDArray{<:_SOLVE_ACCEPTED}) + A, B = eltype(a), eltype(b) + O = promote_type(_solve_eltype(A), _solve_eltype(B)) + # int/bool -> float is an implicit promotion, disallowed unless `allowpromotion` + A <: _SOLVE_PROMOTABLE && assertpromotion(solve, A, O) + B <: _SOLVE_PROMOTABLE && assertpromotion(solve, B, O) + return _solve_check_a_dims(unchecked_promote_arr(a, O), unchecked_promote_arr(b, O)) +end + +function solve(a::NDArray, b::NDArray) + bad = eltype(a) <: _SOLVE_ACCEPTED ? eltype(b) : eltype(a) + throw(ArgumentError("array type $bad is unsupported in solve")) +end + +# `a` must be at least 2D, `b` at least 1D. +function _solve_check_a_dims(a::NDArray{<:Any,0}, b::NDArray) + throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional")) +end +function _solve_check_a_dims(a::NDArray{<:Any,1}, b::NDArray) + throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional")) +end +_solve_check_a_dims(a::NDArray, b::NDArray) = _solve_check_b_dims(a, b) + +function _solve_check_b_dims(a::NDArray, b::NDArray{<:Any,0}) + throw(ArgumentError("0-dimensional array given. Array must be at least one-dimensional")) +end +_solve_check_b_dims(a::NDArray, b::NDArray) = _solve(a, b) + +# 2D case: (m,m),(m)->(m). +# Backend needs rhs "b" to be 2D. We reshape b from (n,) to (n,1) +function _solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S} + m = size(b)[1] + return reshape(_solve(a, reshape(b, (m, 1))), (m,)) +end + +# 2D (m,m),(m,n)->(m,n) and batched (...,m,m),(...,m,n)->(...,m,n) +function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N} + size(a)[end - 1] != size(a)[end] && + throw(ArgumentError("Last 2 dimensions of the array must be square")) + size(a)[end] != size(b)[end - 1] && + throw( + ArgumentError( + "Input operand 1 has a mismatch in its dimension " * + "$(N-2), with signature (...,m,m),(...,m,n)->(...,m,n)" * + " (size $(size(b)[end-1]) is different from $(size(a)[end]))", + ), + ) + prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...) + x = zeros(T, size(b)...) + solve_batched(a, b, x) + return x +end + +# Mismatched batch dimensions +function _solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M} + throw(ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)")) +end diff --git a/src/ndarray/ndarray.jl b/src/ndarray/ndarray.jl index d58f6572..086a60ca 100644 --- a/src/ndarray/ndarray.jl +++ b/src/ndarray/ndarray.jl @@ -30,21 +30,27 @@ function transpose(arr::NDArray) end @doc""" - cuNumeric.eye(rows::Int; T=Float32) + cuNumeric.eye([T,] rows::Int) -Create a 2D identity `NDArray` of size `rows x rows` with element type `T`. +Create a 2D identity `NDArray` of size `rows x rows` with element type `T` +(defaults to `DEFAULT_FLOAT`). """ -function eye(rows::Int; T::Type{S}=Float64) where {S} - return nda_eye(Int32(rows), S) +function eye(::Type{T}, rows::Int) where {T} + return nda_eye(Int32(rows), T) +end +function eye(rows::Int) + return eye(DEFAULT_FLOAT, rows) end @doc""" - cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1, T=Float32) + cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1) -Compute the trace of the `NDArray` along the specified axes. +Compute the trace (sum of a diagonal) of the `NDArray`. +The accumulator type follows promotions of other reductions like 'sum'. """ -function trace(arr::NDArray; offset::Int=0, a1::Int=0, a2::Int=1, T::Type{S}=Float32) where {S} - return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), S) +function trace(arr::NDArray{T}; offset::Int=0, a1::Int=0, a2::Int=1) where {T} + T_OUT = Base.promote_op(Base.sum, Vector{T}) + return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), T_OUT) end @doc""" diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl index 32a18100..8f9c060d 100644 --- a/test/tests/linalg.jl +++ b/test/tests/linalg.jl @@ -19,22 +19,24 @@ =# @testset "transpose" begin - A = rand(Float64, 4, 3) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 4, 3) + nda = cuNumeric.NDArray(A) - ref = transpose(A) - out = cuNumeric.transpose(nda) + ref = transpose(A) + out = cuNumeric.transpose(nda) - allowscalar() do - @test cuNumeric.compare(ref, out, atol(Float64), rtol(Float64)) + allowscalar() do + @test cuNumeric.compare(ref, out, atol(T), rtol(T)) + end end end @testset "eye" begin - for T in (Float32, Float64, Int32) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) n = 5 ref = Matrix{T}(I, n, n) - out = cuNumeric.eye(n; T=T) + out = cuNumeric.eye(T, n) allowscalar() do @test cuNumeric.compare(ref, out, atol(T), rtol(T)) end @@ -42,41 +44,45 @@ end end @testset "trace" begin - A = rand(Float64, 6, 6) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 6, 6) + nda = cuNumeric.NDArray(A) - ref = tr(A) - out = cuNumeric.trace(nda) - - allowscalar() do - @test ref ≈ out[1] atol=atol(Float32) rtol=rtol(Float32) + ref = sum(diag(A)) # widens ints like trace's accumulator + out = cuNumeric.trace(nda) + allowscalar() do + @test ref ≈ out[1] atol=atol(eltype(ref)) rtol=rtol(eltype(ref)) + end end end @testset "trace with offset" begin - A = rand(Float32, 5, 5) - nda = cuNumeric.NDArray(A) - - for k in (-2, -1, 0, 1, 2) - ref = sum(diag(A, k)) - out = cuNumeric.trace(nda; offset=k) - - allowscalar() do - @test ref ≈ out[1] atol=atol(Float32) rtol=rtol(Float32) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 5, 5) + nda = cuNumeric.NDArray(A) + + @testset "offset=$(k)" for k in (-2, -1, 0, 1, 2) + ref = sum(diag(A, k)) + out = cuNumeric.trace(nda; offset=k) + allowscalar() do + @test ref ≈ out[1] atol=atol(eltype(ref)) rtol=rtol(eltype(ref)) + end end end end @testset "diag" begin - A = rand(Int, 6, 6) - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = my_rand(T, 6, 6) + nda = cuNumeric.NDArray(A) - for k in (-2, 0, 3) - ref = diag(A, k) - out = cuNumeric.diag(nda; k=k) + @testset "k=$(k)" for k in (-2, 0, 3) + ref = diag(A, k) + out = cuNumeric.diag(nda; k=k) - allowscalar() do - @test cuNumeric.compare(ref, out, atol(Int32), rtol(Int32)) + allowscalar() do + @test cuNumeric.compare(ref, out, atol(T), rtol(T)) + end end end end @@ -94,11 +100,88 @@ end # end @testset "unique" begin - A = [1, 2, 2, 3, 4, 4, 4, 5] - nda = cuNumeric.NDArray(A) + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + A = T[1, 2, 2, 3, 4, 4, 4, 5] + nda = cuNumeric.NDArray(A) + + ref = unique(A) + out = cuNumeric.unique(nda) + + @test Set(Array(out)) == Set(ref) + end +end + +@testset "solve diagonal" begin + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + n = 4 + A = cuNumeric.zeros(T, n, n) + b = cuNumeric.zeros(T, n, 1) + cuNumeric.@allowscalar for i in 1:n + A[i, i] = T(4) + b[i, 1] = T(1) + end + x = cuNumeric.solve(A, b) + allowscalar() do + @test cuNumeric.compare(fill(T(0.25), n, 1), x, atol(T), rtol(T)) + end + end +end + +@testset "solve identity" begin + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + n = 4 + A = cuNumeric.NDArray(Matrix{T}(I, n, n)) + b = cuNumeric.NDArray(reshape(T.(collect(1:n)), n, 1)) + x = cuNumeric.solve(A, b) + ref = reshape(T.(collect(1:n)), n, 1) + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) + end + end +end + +@testset "solve general" begin + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + A_ref = T[2 1; 5 7] + b_ref = T[11; 13;;] # creates a 2d matrix instead of vector + A = cuNumeric.NDArray(A_ref) + b = cuNumeric.NDArray(b_ref) + x = cuNumeric.solve(A, b) + ref = A_ref \ b_ref + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) + end + end +end - ref = unique(A) - out = cuNumeric.unique(nda) +@testset "solve vector rhs" begin + @testset verbose=true for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + A_ref = T[2 1; 5 7] + b_ref = T[11, 13] + x = cuNumeric.solve(cuNumeric.NDArray(A_ref), cuNumeric.NDArray(b_ref)) + @test ndims(x) == 1 + ref = A_ref \ b_ref + allowscalar() do + @test cuNumeric.compare(ref, x, atol(T), rtol(T)) + end + end +end - @test sort(Array(out)) == sort(ref) +@testset "solve promotion" begin + @testset verbose=true for T in (Int32, Int64, Bool) + A = cuNumeric.NDArray(T[1 0; 0 1]) + b = cuNumeric.NDArray(reshape(T[1, 1], 2, 1)) + + # int/bool requires promotion to float. Will throw without allowpromtion() + @test_throws "Implicit promotion" cuNumeric.solve(A, b) + + # ...allowed under @allowpromotion, result is Float64 + allowpromotion() do + x = cuNumeric.solve(A, b) + ref = Float64[1 0; 0 1] \ Float64[1; 1;;] + allowscalar() do + @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64)) + end + end + end end diff --git a/test/tests/stability.jl b/test/tests/stability.jl index 648f4bc4..7e488d27 100644 --- a/test/tests/stability.jl +++ b/test/tests/stability.jl @@ -1,4 +1,4 @@ -#= Copyright 2026 Northwestern University, +#= Copyright 2026 Northwestern University, * Carnegie Mellon University University * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -98,3 +98,36 @@ end @test @inferred(a ./ b) !== nothing @test @inferred(((a .* b) .+ a) .* 2.0f0) !== nothing end + +@testset verbose = true "solve" begin + # native float/complex, 2D and 1D rhs + @testset "$(T)" for T in Base.uniontypes(cuNumeric.SUPPORTED_SOLVE_TYPES) + A = cuNumeric.NDArray(T[2 1; 5 7]) + b2 = cuNumeric.NDArray(T[11; 13;;]) # creates a 2d matrix instead of vector + b1 = cuNumeric.NDArray(T[11, 13]) + @test @inferred(cuNumeric.solve(A, b2)) !== nothing + @test @inferred(cuNumeric.solve(A, b1)) !== nothing + end + + # int/bool promote to Float64 (under allowpromotion) and stay inferrable + @testset "promote $(T)" for T in (Int32, Int64, Bool) + A = cuNumeric.NDArray(T[1 0; 0 1]) + b = cuNumeric.NDArray(reshape(T[1, 1], 2, 1)) + allowpromotion() do + @test @inferred(cuNumeric.solve(A, b)) !== nothing + end + end +end + +@testset verbose = true "linalg ops" begin + @testset "$(T)" for T in Base.uniontypes(cuNumeric.SUPPORTED_LINALG_TYPES) + M = cuNumeric.zeros(T, 4, 3) + sq = cuNumeric.zeros(T, 5, 5) + v = cuNumeric.zeros(T, 8) + @test @inferred(cuNumeric.eye(T, 5)) !== nothing + @test @inferred(cuNumeric.transpose(M)) !== nothing + @test @inferred(cuNumeric.trace(sq)) !== nothing + @test @inferred(cuNumeric.diag(sq)) !== nothing + @test @inferred(cuNumeric.unique(v)) !== nothing + end +end