Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ Pkg = "1"
Preferences = "1"
Random = "1"
StatsBase = "0.34"
cunumeric_jl_wrapper_jll = "25.10.4"
cunumeric_jl_wrapper_jll = "25.10.3"
cupynumeric_jll = "25.10.3"
julia = "1.10"
3 changes: 3 additions & 0 deletions lib/cunumeric_jl_wrapper/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ void wrap_unary_reds(jlcxx::Module&);

// Binary op codes
void wrap_binary_ops(jlcxx::Module&);

// Linear algebra op codes
void wrap_linalg_ops(jlcxx::Module& mod);
2 changes: 1 addition & 1 deletion lib/cunumeric_jl_wrapper/include/ufi.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* Copyright 2026 Northwestern University,
* Carnegie Mellon University University
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down
5 changes: 5 additions & 0 deletions lib/cunumeric_jl_wrapper/src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,8 @@ void wrap_binary_ops(jlcxx::Module& mod) {
mod.set_const("SUBTRACT",
CuPyNumericBinaryOpCode::CUPYNUMERIC_BINOP_SUBTRACT);
}

void wrap_linalg_ops(jlcxx::Module& mod) {
mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE});
mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE});
}
1 change: 1 addition & 0 deletions lib/cunumeric_jl_wrapper/src/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) {
wrap_unary_ops(mod);
wrap_binary_ops(mod);
wrap_unary_reds(mod);
wrap_linalg_ops(mod);

using jlcxx::ParameterList;
using jlcxx::Parametric;
Expand Down
19 changes: 17 additions & 2 deletions src/cuNumeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,26 @@ const DEFAULT_FLOAT = Float32
const DEFAULT_INT = Int32

const SUPPORTED_INT_TYPES = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64}
const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64} # Float16 not supported yet
# Float16 is only supported by the backend when built with CUDA
@static if HAS_CUDA
const SUPPORTED_FLOAT_TYPES = Union{Float16,Float32,Float64}
else
const SUPPORTED_FLOAT_TYPES = Union{Float32,Float64}
end

const SUPPORTED_COMPLEX_TYPES = Union{ComplexF32,ComplexF64}

const SUPPORTED_NUMERIC_TYPES = Union{
SUPPORTED_INT_TYPES,SUPPORTED_FLOAT_TYPES,SUPPORTED_COMPLEX_TYPES
}

const SUPPORTED_LINALG_TYPES = Union{
SUPPORTED_INT_TYPES,Float32,Float64,SUPPORTED_COMPLEX_TYPES
}

# solve has no integer/Float16 backend kernel — float/complex only.
const SUPPORTED_SOLVE_TYPES = Union{Float32,Float64,SUPPORTED_COMPLEX_TYPES}

const SUPPORTED_ARRAY_TYPES = Union{Bool,SUPPORTED_NUMERIC_TYPES}
const SUPPORTED_TYPES = Union{SUPPORTED_ARRAY_TYPES,String}

Expand Down Expand Up @@ -144,6 +158,7 @@ include("ndarray/broadcast.jl")
include("ndarray/ndarray.jl")
include("ndarray/unary.jl")
include("ndarray/binary.jl")
include("ndarray/linalg.jl")

# scoping macro
include("scoping.jl")
Expand Down Expand Up @@ -230,7 +245,7 @@ function __init__()
_is_precompiling() && return nothing

# Cannot set LEGATE_CONFIG on CI machines used
# to register packages. So we will just skip starting
# to register packages. So we will just skip starting
# legate/cunumeric when using registry CI machines.
get(ENV, "JULIA_REGISTRYCI_AUTOMERGE", false) == "true" && return nothing

Expand Down
22 changes: 16 additions & 6 deletions src/ndarray/detail/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,21 @@ function nda_array_equal(rhs1::NDArray{T,N}, rhs2::NDArray{T,N}) where {T,N}
return NDArray(ptr, Bool, Val(1))
end

function nda_diag(arr::NDArray, k::Int32)
# 2D -> 1D: extract the k-th diagonal. Backend only supports the 2D case
# (1D-construct and >2D both abort), so non-2D input is a MethodError.
function nda_diag(arr::NDArray{T,2}, k::Int32) where {T}
ptr = ccall((:nda_diag, libnda),
NDArray_t, (NDArray_t, Int32),
arr.ptr, k)
return NDArray(ptr)
return NDArray(ptr, T, Val(1))
end

function nda_unique(arr::NDArray)
# unique always returns a flat 1D array of the input's element type
function nda_unique(arr::NDArray{T}) where {T}
ptr = ccall((:nda_unique, libnda),
NDArray_t, (NDArray_t,),
arr.ptr)
return NDArray(ptr)
return NDArray(ptr, T, Val(1))
end

function nda_ravel(arr::NDArray)
Expand Down Expand Up @@ -315,11 +318,12 @@ function nda_trace(
return NDArray(ptr, T, Val(1))
end

function nda_transpose(arr::NDArray)
# transpose reverses the axes: element type and rank are preserved
function nda_transpose(arr::NDArray{T,N}) where {T,N}
ptr = ccall((:nda_transpose, libnda),
NDArray_t, (NDArray_t,),
arr.ptr)
return NDArray(ptr)
return NDArray(ptr, T, Val(N))
end

function nda_attach_external(arr::AbstractArray{T,N}) where {T,N}
Expand Down Expand Up @@ -501,3 +505,9 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real)
# successful completion
return true
end

function nda_to_logical_store(arr::NDArray{T,N}) where {T,N}
la_handle = cuNumeric.get_store(arr) # LogicalArrayImplAllocated (returned by value)
st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle, size(arr)))
return Legate.LogicalStore{T,N}(st_handle, size(arr))
end
119 changes: 119 additions & 0 deletions src/ndarray/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
function choose_nd_color_shape(shape::NTuple{N,Int}) where {N}
color_shape = Base.ones(Int, N)
if N > 2
color_shape[1] = Legate.num_procs()
done = false
while !done && color_shape[1] % 2 == 0
weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)]
max_weight, idx = findmax(weight_per_dim)
if weight_per_dim[idx] > 2 * weight_per_dim[1]
color_shape[1] ÷= 2
color_shape[idx] *= 2
else
done = true
end
end
end
return Tuple(color_shape)
end

function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N}
initial_color_shape = choose_nd_color_shape(full_shape)
tilesize = Tuple(
(full_shape[i] + initial_color_shape[i] - 1) ÷ initial_color_shape[i] for i in 1:N
)
color_shape = Tuple((full_shape[i] + tilesize[i] - 1) ÷ tilesize[i] for i in 1:N)
return tilesize, color_shape
end

function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N}
nrhs = size(b)[end]
full_shape = size(a)
tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(full_shape)
tilesize_b = (tilesize_a[1:(end - 1)]..., nrhs)

store_a = nda_to_logical_store(a)
store_b = nda_to_logical_store(b)
store_x = nda_to_logical_store(x)

tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a))
tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b))
tiled_x = Legate.partition_by_tiling(store_x, collect(tilesize_b))

rt = Legate.get_runtime()
domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape)))
lib = cuNumeric.get_lib()
task = Legate.create_manual_task(rt, lib, cuNumeric.SOLVE, domain)

Legate.add_input(task, tiled_a)
Legate.add_input(task, tiled_b)
Legate.add_output(task, tiled_x)

Legate.submit_manual_task(rt, task)
end

# solve runs in floating point:
# int/bool inputs promote to Float64 (matching cupynumeric)
const _SOLVE_PROMOTABLE = Union{SUPPORTED_INT_TYPES,Bool}
const _SOLVE_ACCEPTED = Union{SUPPORTED_SOLVE_TYPES,_SOLVE_PROMOTABLE}
_solve_eltype(::Type{T}) where {T<:_SOLVE_PROMOTABLE} = Float64
_solve_eltype(::Type{T}) where {T<:SUPPORTED_SOLVE_TYPES} = T

# Type/dim guards dispatch on one argument at a time, then forward to `_solve`.
function solve(a::NDArray{<:_SOLVE_ACCEPTED}, b::NDArray{<:_SOLVE_ACCEPTED})
A, B = eltype(a), eltype(b)
O = promote_type(_solve_eltype(A), _solve_eltype(B))
# int/bool -> float is an implicit promotion, disallowed unless `allowpromotion`
A <: _SOLVE_PROMOTABLE && assertpromotion(solve, A, O)
B <: _SOLVE_PROMOTABLE && assertpromotion(solve, B, O)
return _solve_check_a_dims(unchecked_promote_arr(a, O), unchecked_promote_arr(b, O))
end

function solve(a::NDArray, b::NDArray)
bad = eltype(a) <: _SOLVE_ACCEPTED ? eltype(b) : eltype(a)
throw(ArgumentError("array type $bad is unsupported in solve"))
end

# `a` must be at least 2D, `b` at least 1D.
function _solve_check_a_dims(a::NDArray{<:Any,0}, b::NDArray)
throw(ArgumentError("0-dimensional array given. Array must be at least two-dimensional"))
end
function _solve_check_a_dims(a::NDArray{<:Any,1}, b::NDArray)
throw(ArgumentError("1-dimensional array given. Array must be at least two-dimensional"))
end
_solve_check_a_dims(a::NDArray, b::NDArray) = _solve_check_b_dims(a, b)

function _solve_check_b_dims(a::NDArray, b::NDArray{<:Any,0})
throw(ArgumentError("0-dimensional array given. Array must be at least one-dimensional"))
end
_solve_check_b_dims(a::NDArray, b::NDArray) = _solve(a, b)

# 2D case: (m,m),(m)->(m).
# Backend needs rhs "b" to be 2D. We reshape b from (n,) to (n,1)
function _solve(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S}
m = size(b)[1]
return reshape(_solve(a, reshape(b, (m, 1))), (m,))
end

# 2D (m,m),(m,n)->(m,n) and batched (...,m,m),(...,m,n)->(...,m,n)
function _solve(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N}
size(a)[end - 1] != size(a)[end] &&
throw(ArgumentError("Last 2 dimensions of the array must be square"))
size(a)[end] != size(b)[end - 1] &&
throw(
ArgumentError(
"Input operand 1 has a mismatch in its dimension " *
"$(N-2), with signature (...,m,m),(...,m,n)->(...,m,n)" *
" (size $(size(b)[end-1]) is different from $(size(a)[end]))",
),
)
prod(size(a)) == 0 || prod(size(b)) == 0 && return zeros(T, size(b)...)
x = zeros(T, size(b)...)
solve_batched(a, b, x)
return x
end

# Mismatched batch dimensions
function _solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M}
throw(ArgumentError("Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)"))
end
22 changes: 14 additions & 8 deletions src/ndarray/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,27 @@ function transpose(arr::NDArray)
end

@doc"""
cuNumeric.eye(rows::Int; T=Float32)
cuNumeric.eye([T,] rows::Int)

Create a 2D identity `NDArray` of size `rows x rows` with element type `T`.
Create a 2D identity `NDArray` of size `rows x rows` with element type `T`
(defaults to `DEFAULT_FLOAT`).
"""
function eye(rows::Int; T::Type{S}=Float64) where {S}
return nda_eye(Int32(rows), S)
function eye(::Type{T}, rows::Int) where {T}
return nda_eye(Int32(rows), T)
end
function eye(rows::Int)
return eye(DEFAULT_FLOAT, rows)
end

@doc"""
cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1, T=Float32)
cuNumeric.trace(arr::NDArray; offset=0, a1=0, a2=1)

Compute the trace of the `NDArray` along the specified axes.
Compute the trace (sum of a diagonal) of the `NDArray`.
The accumulator type follows promotions of other reductions like 'sum'.
"""
function trace(arr::NDArray; offset::Int=0, a1::Int=0, a2::Int=1, T::Type{S}=Float32) where {S}
return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), S)
function trace(arr::NDArray{T}; offset::Int=0, a1::Int=0, a2::Int=1) where {T}
T_OUT = Base.promote_op(Base.sum, Vector{T})
return nda_trace(arr, Int32(offset), Int32(a1), Int32(a2), T_OUT)
end

@doc"""
Expand Down
Loading
Loading