From 2dc75bb27ef61b0555becc77e4dedbc11c99b9e5 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 12 Mar 2025 08:45:41 -0500 Subject: [PATCH 1/3] Add backend tls --- ext/ReactantCUDAExt.jl | 2 +- src/Compiler.jl | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 69c273c865..64e17dd257 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -352,7 +352,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...) # figure out the optimal workgroupsize automatically if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing - if !Reactant.Compiler.PartitionKA[] || raising() + if !Reactant.Compiler.PartitionKA[] || raising() || backend() in ("cpu", "tpu") threads = prod(ndrange) else config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange)) diff --git a/src/Compiler.jl b/src/Compiler.jl index 73595dec7b..da02ec3aae 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -700,6 +700,38 @@ function raising!(f, is_raising::Bool) end end +function activate_backend!(backend::String) + stack = get!(task_local_storage(), :reactant_backend) do + String[] + end + push!(stack, backend) + return nothing +end + +function deactivate_backend!(backend::String) + key = :reactant_backend + backend === last(task_local_storage(key)) || + error("Deactivating wrong Reactant backend context") + return pop!(task_local_storage(key)) +end + +function backend(; throw_error::Bool=true) + key = :reactant_backend + if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key))) + throw_error && error("No Reactant backend context") + end + return last(task_local_storage(key)::Vector{Bool}) +end + +function backend!(f, backend::String) + activate_backend!(backend) + try + return f() + finally + deactivate_backend!(backend) + end +end + function compile_mlir!( mod, f, @@ -747,12 +779,14 @@ function compile_mlir!( end is_raising = raise isa String || raise activate_raising!(is_raising) + activate_backend!(backend) mlir_fn_res = try Reactant.TracedUtils.make_mlir_fn( f, args, fn_kwargs, "main", true; input_shardings, runtime ) finally + deactivate_backend!(backend) deactivate_raising!(is_raising) deactivate_sdycache!(sdycache) deactivate_callcache!(callcache) From d24d09b1245e83360c7f92ef1ac6edb8e2639f95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Wed, 12 Mar 2025 14:14:10 +0000 Subject: [PATCH 2/3] Apply suggestions from code review --- ext/ReactantCUDAExt.jl | 2 +- src/Compiler.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 64e17dd257..c27e9be4e1 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -352,7 +352,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...) # figure out the optimal workgroupsize automatically if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing - if !Reactant.Compiler.PartitionKA[] || raising() || backend() in ("cpu", "tpu") + if !Reactant.Compiler.PartitionKA[] || raising() || Reactant.Compiler.backend() in ("cpu", "tpu") threads = prod(ndrange) else config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange)) diff --git a/src/Compiler.jl b/src/Compiler.jl index da02ec3aae..c3924c899e 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -720,7 +720,7 @@ function backend(; throw_error::Bool=true) if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key))) throw_error && error("No Reactant backend context") end - return last(task_local_storage(key)::Vector{Bool}) + return last(task_local_storage(key)::Vector{String}) end function backend!(f, backend::String) From c0ec6dcb3eca614f04dbdda6d4f504b3c90953e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Wed, 12 Mar 2025 14:16:22 +0000 Subject: [PATCH 3/3] Update ext/ReactantCUDAExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantCUDAExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index c27e9be4e1..a4f0543268 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -352,7 +352,9 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...) # figure out the optimal workgroupsize automatically if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing - if !Reactant.Compiler.PartitionKA[] || raising() || Reactant.Compiler.backend() in ("cpu", "tpu") + if !Reactant.Compiler.PartitionKA[] || + raising() || + Reactant.Compiler.backend() in ("cpu", "tpu") threads = prod(ndrange) else config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))