Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] control flow if #735

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
CC = Core.Compiler
ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{
typeof(Reactant.set_reactant_abi)
}
EnzymeInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter

shift_off(s, _) = s
shift_off(s::Core.SSAValue, new_index::Vector) = Core.SSAValue(new_index[s.id])


apply(c::Expr, new_index) = begin
return Expr(c.head, (shift_off(apply(a, new_index), new_index) for a in c.args)...)
end

apply(c, _new_index) = c

#add a conversion to Bool before a lowered if
goto_if_not_protection(src::Core.CodeInfo) = begin
new_index = []
offset = 0
for (i, t) in enumerate(typeof.(src.code))
t == Core.GotoIfNot && (offset += 2)
push!(new_index, i + offset)
end

nc = []
ncl = []
for (i, c) in enumerate(src.code)
v = nothing
if c isa Core.GotoIfNot
push!(nc, GlobalRef(Main, :convert))
push!(nc, Expr(:call, (Core.SSAValue(new_index[i] - 2), GlobalRef(Main, :Bool), shift_off(c.cond, new_index))...))
append!(ncl, [src.codelocs[i] for _ in 1:2])
v = Core.GotoIfNot(Core.SSAValue(new_index[i] - 1), new_index[c.dest])
elseif c isa Core.GotoNode
v = Core.GotoNode(new_index[c.label])
elseif c isa Core.ReturnNode
v = Core.ReturnNode(shift_off(c.val, new_index))
elseif c isa Expr
v = apply(c, new_index)
else
v = c
end
push!(nc, v)
push!(ncl, src.codelocs[i])
end
new = copy(src)
new.code = nc
new.codelocs = ncl
for _ in 1:offset
push!(new.ssaflags, 0x00000000)
end
new.ssavaluetypes = src.ssavaluetypes + offset
return new
end

vec = []
vec2 = []
function CC.inlining_policy(
interp::ReactantInter,
@nospecialize(src),
@nospecialize(info::CC.CallInfo),
stmt_flag::UInt32,
)
#typeof(src) in [CC.IRCode, Core.CodeInfo] || return;
#push!(vec, (CC.copy(src), info))
#push!(vec2, stacktrace())
#=info isa CC.ConstCallInfo && (info = info.call)
push!(vec, info)
if info isa MethodMatchInfo
mm::Core.MethodMatch = first(info.results.matches)
m::Method = mm.method
if m.name == :convert && m.sig isa DataType
if m.sig.types[3] == Reactant.TracedRNumber{Bool}
return true
end
end
end
#push!(vec2, info)
=#
return nothing
@invoke CC.inlining_policy(
interp::EnzymeInter, src, info::CC.CallInfo, stmt_flag::UInt32
)
end

#=vec2 = []
CC.finish!(ji::ReactantInter, caller::CC.InferenceState) = begin
res = @invoke CC.finish!(ji::EnzymeInter, caller::CC.InferenceState)
push!(vec2, res)
end
=#


22 changes: 21 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,11 @@ function rewrite_insts!(ir, interp, guaranteed_error)
return ir, any_changed
end

include("ext.jl")

global dico = Dict()
global dico2 = Dict()

# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
# In particular this entails two pieces:
# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance
Expand Down Expand Up @@ -562,7 +567,19 @@ function call_with_reactant_generator(
ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
rt = CC.widenconst(CC.ignorelimited(result.result))
else
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
world = CC.get_inference_world(interp)
src = CC.retrieve_code_info(result.linfo, world)
#dico2[mi]=(CC.copy(src), goto_if_not_protection(src))
@error src
src = goto_if_not_protection(src)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Billy's concern was that since the transformation here happens before inlining, the code that will later get doesn't have the needed conversions inserted.
I'm wondering if you run into this issue?
If so, the idea was to do the transformation by specializing the abstract interpreter.
If that's necessary, the correct place to do it would be InferenceState like I did here:
https://github.com/JuliaLLVM/MLIR.jl/blob/cf22c54222033ed4f5f3d68a620448bcaf766ac4/src/Generate/absint.jl#L151C1-L160C4
based on an old pr from valentin:
JuliaGPU/GPUCompiler.jl#311

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an example of such a case in mind! I will try with this one firstly (seem easier to handle).

@error src
CC.maybe_validate_code(result.linfo, src, "lowered")
frame = CC.InferenceState(result, src, :no, interp)
CC.typeinf(interp, frame)
opt = CC.OptimizationState(frame, interp)
ir = CC.run_passes_ipo_safe(opt.src, opt, result)
rt = CC.widenconst(CC.ignorelimited(result.result))
end

if guaranteed_error
Expand Down Expand Up @@ -768,3 +785,6 @@ end
$(Expr(:meta, :generated_only))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end



Loading