@@ -5,18 +5,22 @@ using Reactant:
     Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
 using ReactantCore: @trace
 using KernelAbstractions: KernelAbstractions
+import KernelAbstractions as KA
 using Libdl
+const ReactantKernelAbstractionsExt = Base.get_extension(
+    Reactant, :ReactantKernelAbstractionsExt
+)
+const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend

 using Adapt

-KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
-KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()
-
 struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
     ptr::Core.LLVMPtr{T,A}

     function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
-        push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
+        gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
+        push!(gc_vec, xs)
+        @assert gc_vec[end] === xs
         ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
         return new(ptr)
     end
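
Worth noting for readers of this hunk: Base.get_extension(parent, :ExtName) returns the extension module only after all of its trigger packages have been loaded, and nothing otherwise, so the two const bindings above assume ReactantKernelAbstractionsExt is already loaded when this file is evaluated. A minimal illustrative sketch of that lookup (the REPL-style lines and the error message are assumptions, not part of this diff):

using Reactant, KernelAbstractions   # loading both triggers ReactantKernelAbstractionsExt

# Returns the extension module, or nothing if it has not been loaded yet.
ext = Base.get_extension(Reactant, :ReactantKernelAbstractionsExt)
ext === nothing && error("ReactantKernelAbstractionsExt is not loaded")
ext.ReactantBackend   # the KernelAbstractions backend type defined by that extension
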
@@ -261,6 +265,78 @@ function Adapt.adapt_structure(
     )
 end

+function threads_to_workgroupsize(threads, ndrange)
+    total = 1
+    return map(ndrange) do n
+        x = min(div(threads, total), n)
+        total *= x
+        return x
+    end
+end
+
+function ka_with_reactant(ndrange, workgroupsize, obj, args...)
+    backend = KA.backend(obj)
+
+    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
+        obj, ndrange, workgroupsize
+    )
+    # this might not be the final context, since we may tune the workgroupsize
+    ctx = KA.mkcontext(obj, ndrange, iterspace)
+
+    # If the kernel is statically sized we can tell the compiler about that
+    if KA.workgroupsize(obj) <: KA.StaticSize
+        maxthreads = prod(KA.get(KA.workgroupsize(obj)))
+    else
+        maxthreads = nothing
+    end
+
+    kernel = CUDA.@cuda launch = false always_inline = backend.always_inline maxthreads =
+        maxthreads obj.f(ctx, args...)
+
+    # figure out the optimal workgroupsize automatically
+    if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
+        if !Reactant.Compiler.PartitionKA[]
+            threads = prod(ndrange)
+        else
+            config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
+            if backend.prefer_blocks
+                # Prefer blocks over threads
+                threads = min(prod(ndrange), config.threads)
+                # XXX: Some kernels perform much better with all blocks active
+                cu_blocks = max(cld(prod(ndrange), threads), config.blocks)
+                threads = cld(prod(ndrange), cu_blocks)
+            else
+                threads = config.threads
+            end
+            workgroupsize = threads_to_workgroupsize(threads, ndrange)
+            iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
+        end
+        ctx = KA.mkcontext(obj, ndrange, iterspace)
+    end
+
+    blocks = length(KA.blocks(iterspace))
+    threads = length(KA.workitems(iterspace))
+
+    if blocks == 0
+        return nothing
+    end
+
+    # Launch kernel
+    kernel(ctx, args...; threads, blocks)
+
+    return nothing
+end
+
+Reactant.@reactant_overlay @noinline function (obj::KA.Kernel{ReactantBackend})(
+    args...; ndrange=nothing, workgroupsize=nothing
+)
+    return Reactant.call_with_reactant(
+        ka_with_reactant, ndrange, workgroupsize, obj, args...
+    )
+end
+
+Adapt.adapt_storage(to::KA.ConstAdaptor, a::CuTracedArray) = Base.Experimental.Const(a)
+
 function recudaconvert(arg)
     return adapt(ReactantKernelAdaptor(), arg)
 end
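
Taken together, the overlay above means that while Reactant is tracing, calling a KernelAbstractions kernel whose backend is a ReactantBackend is rerouted through ka_with_reactant and ultimately a CUDA.@cuda launch. A hedged end-to-end sketch of what this enables; the kernel and function names are illustrative, and it assumes a functional CUDA device and that KernelAbstractions.get_backend on Reactant arrays yields a ReactantBackend via the companion extension:

using Reactant, KernelAbstractions, CUDA

# An ordinary KernelAbstractions kernel; nothing Reactant-specific in its body.
@kernel function mul2!(x)
    i = @index(Global, Linear)
    @inbounds x[i] *= 2
end

function apply_mul2!(x)
    backend = KernelAbstractions.get_backend(x)   # expected to be a ReactantBackend
    mul2!(backend)(x; ndrange=length(x))          # hits the @reactant_overlay while tracing
    return x
end

x = Reactant.to_rarray(ones(Float32, 64))
apply2! = Reactant.@compile apply_mul2!(x)        # traces apply_mul2!, kernel launch included
apply2!(x)

As a side note on the sizing logic: threads_to_workgroupsize greedily folds a flat thread count back into the iteration shape, so 256 threads over an ndrange of (32, 32) become a (32, 8) workgroup.
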
@@ -618,8 +694,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
     cdata = MLIR.IR.result(
         MLIR.Dialects.llvm.mlir_constant(;
-            res=array_ty,
-            value=MLIR.IR.DenseElementsAttribute(to_bytes(a)), # TODO: mlir_constant cannot be processed by the julia generator atm.
+            res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
         ),
         1,
     )
@@ -841,45 +916,21 @@ function Reactant.make_tracer(
 end

 function __init__()
-    if isdefined(CUDA.CUDA_Driver_jll, :libcuda) && CUDA.CUDA_Driver_jll.libcuda !== nothing
-        handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
-        if handle === nothing
-            handle = C_NULL
-        end
-        ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
-        if ptr1 === nothing
-            ptr1 = C_NULL
-        end
-        ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
-        if ptr2 === nothing
-            ptr2 = C_NULL
-        end
-        ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
-        if ptr3 === nothing
-            ptr3 = C_NULL
-        end
-        Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
-        Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
-        Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
-        ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false)
-        if ptr4 === nothing
-            ptr4 = C_NULL
-        end
-        Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4)
-    end
     if CUDA.functional()
         target = CUDA._compiler_config(CUDA.device()).target
         Reactant.Compiler.cubinChip[] = "sm_$(target.cap.major)$(target.cap.minor)"
     end
     return nothing
 end

-@static if !Sys.isapple() && Sys.ARCH != :aarch64
+# In Julia v1.11.3 precompiling this module caches bad code:
+# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
+@static if !Sys.isapple()
     Reactant.PrecompileTools.@setup_workload begin
         Reactant.initialize_dialect()
         client = Reactant.XLA.CPUClient(; checkcount=false)
         Reactant.PrecompileTools.@compile_workload begin
-            @static if Reactant.precompilation_supported()
+            @static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
                 function square_kernel!(x)
                     i = CUDA.threadIdx().x
                     x[i] *= x[i]