Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix creation of traced values #1007

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 50 additions & 37 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,39 @@ function make_mlir_fn(

num_partitions, num_replicas = 1, 1

ctx = MLIR.IR.context()
mod = MLIR.IR.mmodule()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name * "_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
)
end

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)
Ops.activate_constant_context!(fnbody)

N = length(args)
seen_args = OrderedIdDict()
traced_args = Vector{Any}(undef, N)
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args,
args[i],
(:args, i),
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
toscalar,
runtime,
)

try
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args,
args[i],
(:args, i),
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
toscalar,
runtime,
)
end
finally
MLIR.IR.deactivate!(fnbody)
Ops.deactivate_constant_context!(fnbody)
end

linear_args = Reactant.TracedType[]
Expand All @@ -247,9 +268,6 @@ function make_mlir_fn(
sym_visibility = MLIR.IR.Attribute("private")
end

ctx = MLIR.IR.context()
mod = MLIR.IR.mmodule()

# Insert meshes for the sharded arguments
traced_args_to_shardings = OrderedIdDict()
for (k, v) in seen_args
Expand All @@ -264,17 +282,6 @@ function make_mlir_fn(
end
end

func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name * "_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
)
end

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)
Ops.activate_constant_context!(fnbody)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@assert MLIR.IR._has_block()

Expand Down Expand Up @@ -312,23 +319,29 @@ function make_mlir_fn(

seen_results = OrderedIdDict()

traced_result = Reactant.make_tracer(
seen_results,
result,
(:result,),
concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath;
runtime,
)

# marks buffers to be donated
for i in 1:N
Reactant.make_tracer(
MLIR.IR.activate!(fnbody)
traced_result = try
traced_result = Reactant.make_tracer(
seen_results,
traced_args[i],
concretein ? (:resargs, i) : (),
Reactant.NoStopTracedTrack;
result,
(:result,),
concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath;
runtime,
)

# marks buffers to be donated
for i in 1:N
Reactant.make_tracer(
seen_results,
traced_args[i],
concretein ? (:resargs, i) : (),
Reactant.NoStopTracedTrack;
runtime,
)
end
traced_result
finally
MLIR.IR.deactivate!(fnbody)
end

linear_results = Reactant.TracedType[]
Expand Down
Loading