diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 69c273c865..a4f0543268 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -352,7 +352,9 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...) # figure out the optimal workgroupsize automatically if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing - if !Reactant.Compiler.PartitionKA[] || raising() + if !Reactant.Compiler.PartitionKA[] || + raising() || + Reactant.Compiler.backend() in ("cpu", "tpu") threads = prod(ndrange) else config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange)) diff --git a/src/Compiler.jl b/src/Compiler.jl index 73595dec7b..c3924c899e 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -700,6 +700,38 @@ function raising!(f, is_raising::Bool) end end +function activate_backend!(backend::String) + stack = get!(task_local_storage(), :reactant_backend) do + String[] + end + push!(stack, backend) + return nothing +end + +function deactivate_backend!(backend::String) + key = :reactant_backend + backend === last(task_local_storage(key)) || + error("Deactivating wrong Reactant backend context") + return pop!(task_local_storage(key)) +end + +function backend(; throw_error::Bool=true) + key = :reactant_backend + if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key))) + throw_error && error("No Reactant backend context") + end + return last(task_local_storage(key)::Vector{String}) +end + +function backend!(f, backend::String) + activate_backend!(backend) + try + return f() + finally + deactivate_backend!(backend) + end +end + function compile_mlir!( mod, f, @@ -747,12 +779,14 @@ function compile_mlir!( end is_raising = raise isa String || raise activate_raising!(is_raising) + activate_backend!(backend) mlir_fn_res = try Reactant.TracedUtils.make_mlir_fn( f, args, fn_kwargs, "main", true; input_shardings, runtime ) finally + deactivate_backend!(backend) deactivate_raising!(is_raising) deactivate_sdycache!(sdycache) deactivate_callcache!(callcache)