Commit 17d67f8

fix rebase

1 parent efa88c9 commit 17d67f8
File tree

1 file changed: +85 -34 lines

ext/ReactantCUDAExt.jl

Lines changed: 85 additions & 34 deletions
@@ -5,18 +5,22 @@ using Reactant:
     Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
 using ReactantCore: @trace
 using KernelAbstractions: KernelAbstractions
+import KernelAbstractions as KA
 using Libdl
+const ReactantKernelAbstractionsExt = Base.get_extension(
+    Reactant, :ReactantKernelAbstractionsExt
+)
+const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend

 using Adapt

-KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
-KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()
-
 struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
     ptr::Core.LLVMPtr{T,A}

     function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
-        push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
+        gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
+        push!(gc_vec, xs)
+        @assert gc_vec[end] === xs
         ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
         return new(ptr)
     end
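
The constructor change above roots the traced array in the per-context GC vector before taking a raw pointer to it, since a pointer obtained from Base.pointer_from_objref is only safe while the object remains reachable. A minimal sketch of that rooting pattern, with hypothetical names (roots, root_and_pointer) standing in for Reactant.Compiler.context_gc_vector:

    # One root vector per context; objects pushed here stay reachable for the
    # context's lifetime, so raw pointers to them remain valid.
    const roots = IdDict{Any,Vector{Any}}()

    function root_and_pointer(ctx, x)
        vec = get!(Vector{Any}, roots, ctx)
        push!(vec, x)             # root `x` before taking its address
        @assert vec[end] === x    # same sanity check as in the diff
        return Base.pointer_from_objref(x)
    end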
@@ -261,6 +265,78 @@ function Adapt.adapt_structure(
     )
 end

+function threads_to_workgroupsize(threads, ndrange)
+    total = 1
+    return map(ndrange) do n
+        x = min(div(threads, total), n)
+        total *= x
+        return x
+    end
+end
+
+function ka_with_reactant(ndrange, workgroupsize, obj, args...)
+    backend = KA.backend(obj)
+
+    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
+        obj, ndrange, workgroupsize
+    )
+    # this might not be the final context, since we may tune the workgroupsize
+    ctx = KA.mkcontext(obj, ndrange, iterspace)
+
+    # If the kernel is statically sized we can tell the compiler about that
+    if KA.workgroupsize(obj) <: KA.StaticSize
+        maxthreads = prod(KA.get(KA.workgroupsize(obj)))
+    else
+        maxthreads = nothing
+    end
+
+    kernel = CUDA.@cuda launch = false always_inline = backend.always_inline maxthreads =
+        maxthreads obj.f(ctx, args...)
+
+    # figure out the optimal workgroupsize automatically
+    if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
+        if !Reactant.Compiler.PartitionKA[]
+            threads = prod(ndrange)
+        else
+            config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
+            if backend.prefer_blocks
+                # Prefer blocks over threads
+                threads = min(prod(ndrange), config.threads)
+                # XXX: Some kernels performs much better with all blocks active
+                cu_blocks = max(cld(prod(ndrange), threads), config.blocks)
+                threads = cld(prod(ndrange), cu_blocks)
+            else
+                threads = config.threads
+            end
+            workgroupsize = threads_to_workgroupsize(threads, ndrange)
+            iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
+        end
+        ctx = KA.mkcontext(obj, ndrange, iterspace)
+    end
+
+    blocks = length(KA.blocks(iterspace))
+    threads = length(KA.workitems(iterspace))
+
+    if blocks == 0
+        return nothing
+    end
+
+    # Launch kernel
+    kernel(ctx, args...; threads, blocks)
+
+    return nothing
+end
+
+Reactant.@reactant_overlay @noinline function (obj::KA.Kernel{ReactantBackend})(
+    args...; ndrange=nothing, workgroupsize=nothing
+)
+    return Reactant.call_with_reactant(
+        ka_with_reactant, ndrange, workgroupsize, obj, args...
+    )
+end
+
+Adapt.adapt_storage(to::KA.ConstAdaptor, a::CuTracedArray) = Base.Experimental.Const(a)
+
 function recudaconvert(arg)
     return adapt(ReactantKernelAdaptor(), arg)
 end
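
The new threads_to_workgroupsize helper greedily splits a flat thread budget across the ndrange dimensions without exceeding any extent. Illustrative values, which follow directly from the definition above:

    threads_to_workgroupsize(256, (1024, 1024))  # (256, 1)
    threads_to_workgroupsize(256, (16, 1024))    # (16, 16)
    threads_to_workgroupsize(256, (8, 8))        # (8, 8), capped by the ndrange itself

With the @reactant_overlay method in place, calling a KernelAbstractions kernel on the ReactantBackend inside a traced function is routed through ka_with_reactant. A hypothetical end-to-end sketch, assuming Reactant.to_rarray and Reactant.@jit as entry points and that KernelAbstractions.get_backend on a Reactant array now returns a ReactantBackend:

    using KernelAbstractions, Reactant

    @kernel function square!(x)
        i = @index(Global, Linear)
        x[i] = x[i] * x[i]
    end

    function apply_square(x)
        backend = KernelAbstractions.get_backend(x)
        square!(backend)(x; ndrange=length(x))   # dispatches through ka_with_reactant
        return x
    end

    y = Reactant.@jit apply_square(Reactant.to_rarray(collect(Float32, 1:64)))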
@@ -618,8 +694,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
     cdata = MLIR.IR.result(
         MLIR.Dialects.llvm.mlir_constant(;
-            res=array_ty,
-            value=MLIR.IR.DenseElementsAttribute(to_bytes(a)), #TODO: mlir_constant cannot be processed by the julia generator atm.
+            res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
         ),
         1,
     )
@@ -841,45 +916,21 @@ function Reactant.make_tracer(
 end

 function __init__()
-    if isdefined(CUDA.CUDA_Driver_jll, :libcuda) && CUDA.CUDA_Driver_jll.libcuda !== nothing
-        handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
-        if handle === nothing
-            handle = C_NULL
-        end
-        ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
-        if ptr1 === nothing
-            ptr1 = C_NULL
-        end
-        ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
-        if ptr2 === nothing
-            ptr2 = C_NULL
-        end
-        ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
-        if ptr3 === nothing
-            ptr3 = C_NULL
-        end
-        Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
-        Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
-        Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
-        ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false)
-        if ptr4 === nothing
-            ptr4 = C_NULL
-        end
-        Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4)
-    end
     if CUDA.functional()
         target = CUDA._compiler_config(CUDA.device()).target
         Reactant.Compiler.cubinChip[] = "sm_$(target.cap.major)$(target.cap.minor)"
     end
     return nothing
 end

-@static if !Sys.isapple() && Sys.ARCH != :aarch64
+# In Julia v1.11.3 precompiling this module caches bad code:
+# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
+@static if !Sys.isapple()
     Reactant.PrecompileTools.@setup_workload begin
         Reactant.initialize_dialect()
         client = Reactant.XLA.CPUClient(; checkcount=false)
         Reactant.PrecompileTools.@compile_workload begin
-            @static if Reactant.precompilation_supported()
+            @static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
                 function square_kernel!(x)
                     i = CUDA.threadIdx().x
                     x[i] *= x[i]