diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a7f4d4997d..c005190cc7 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -211,18 +211,39 @@ function make_mlir_fn( num_partitions, num_replicas = 1, 1 + ctx = MLIR.IR.context() + mod = MLIR.IR.mmodule() + + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) + end + + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) + push!(MLIR.IR.region(func, 1), fnbody) + Ops.activate_constant_context!(fnbody) + N = length(args) seen_args = OrderedIdDict() traced_args = Vector{Any}(undef, N) - for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, - args[i], - (:args, i), - concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; - toscalar, - runtime, - ) + + try + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, + args[i], + (:args, i), + concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; + toscalar, + runtime, + ) + end + finally + MLIR.IR.deactivate!(fnbody) + Ops.deactivate_constant_context!(fnbody) end linear_args = Reactant.TracedType[] @@ -247,9 +268,6 @@ function make_mlir_fn( sym_visibility = MLIR.IR.Attribute("private") end - ctx = MLIR.IR.context() - mod = MLIR.IR.mmodule() - # Insert meshes for the sharded arguments traced_args_to_shardings = OrderedIdDict() for (k, v) in seen_args @@ -264,17 +282,6 @@ function make_mlir_fn( end end - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region(), - ) - end - - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) - push!(MLIR.IR.region(func, 1), fnbody) - Ops.activate_constant_context!(fnbody) @assert MLIR.IR._has_block() @@ -312,23 +319,29 @@ function make_mlir_fn( seen_results = OrderedIdDict() - traced_result = Reactant.make_tracer( - seen_results, - result, - (:result,), - concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath; - runtime, - ) - - # marks buffers to be donated - for i in 1:N - Reactant.make_tracer( + MLIR.IR.activate!(fnbody) + traced_result = try + traced_result = Reactant.make_tracer( seen_results, - traced_args[i], - concretein ? (:resargs, i) : (), - Reactant.NoStopTracedTrack; + result, + (:result,), + concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath; runtime, ) + + # marks buffers to be donated + for i in 1:N + Reactant.make_tracer( + seen_results, + traced_args[i], + concretein ? (:resargs, i) : (), + Reactant.NoStopTracedTrack; + runtime, + ) + end + traced_result + finally + MLIR.IR.deactivate!(fnbody) end linear_results = Reactant.TracedType[]