Skip to content

Commit 8d57b1b

Browse files
committed
AbstractInterpreter: add a hook to customize bestguess calculation
Currently, the code that updates `bestguess` using `ReturnNode` information includes hardcodes that relate to `Conditional` and `LimitedAccuracy`. These behaviors are actually lattice-dependent and therefore should be overloadable by `AbstractInterpreter`. Additionally, particularly in Diffractor, a clever strategy is required to update return types in a way that it takes into account information from both the original method and its rule method (xref: JuliaDiff/Diffractor.jl#202). This also requires such an overload to exist. In response to these needs, this commit introduces an implementation of a hook named `update_bestguess!`.
1 parent 441fcb1 commit 8d57b1b

File tree

2 files changed

+58
-48
lines changed

2 files changed

+58
-48
lines changed

base/compiler/abstractinterpretation.jl

+37-30
Original file line numberDiff line numberDiff line change
@@ -2892,17 +2892,49 @@ function init_vartable!(vartable::VarTable, frame::InferenceState)
28922892
return vartable
28932893
end
28942894

2895+
function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
2896+
currstate::VarTable, @nospecialize(rt))
2897+
bestguess = frame.bestguess
2898+
nargs = narguments(frame, #=include_va=#false)
2899+
slottypes = frame.slottypes
2900+
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
2901+
# narrow representation of bestguess slightly to prepare for tmerge with rt
2902+
if rt isa InterConditional && bestguess isa Const
2903+
slot_id = rt.slot
2904+
old_id_type = slottypes[slot_id]
2905+
if bestguess.val === true && rt.elsetype !== Bottom
2906+
bestguess = InterConditional(slot_id, old_id_type, Bottom)
2907+
elseif bestguess.val === false && rt.thentype !== Bottom
2908+
bestguess = InterConditional(slot_id, Bottom, old_id_type)
2909+
end
2910+
end
2911+
# copy limitations to return value
2912+
if !isempty(frame.pclimitations)
2913+
union!(frame.limitations, frame.pclimitations)
2914+
empty!(frame.pclimitations)
2915+
end
2916+
if !isempty(frame.limitations)
2917+
rt = LimitedAccuracy(rt, copy(frame.limitations))
2918+
end
2919+
𝕃ₚ = ipo_lattice(interp)
2920+
if !(𝕃ₚ, rt, bestguess)
2921+
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
2922+
frame.bestguess = tmerge(𝕃ₚ, bestguess, rt) # new (wider) return type for frame
2923+
return true
2924+
else
2925+
return false
2926+
end
2927+
end
2928+
28952929
# make as much progress on `frame` as possible (without handling cycles)
28962930
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
28972931
@assert !is_inferred(frame)
28982932
frame.dont_work_on_me = true # mark that this function is currently on the stack
28992933
W = frame.ip
2900-
nargs = narguments(frame, #=include_va=#false)
2901-
slottypes = frame.slottypes
29022934
ssavaluetypes = frame.ssavaluetypes
29032935
bbs = frame.cfg.blocks
29042936
nbbs = length(bbs)
2905-
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
2937+
𝕃ᵢ = typeinf_lattice(interp)
29062938

29072939
currbb = frame.currbb
29082940
if currbb != 1
@@ -3003,35 +3035,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
30033035
end
30043036
end
30053037
elseif isa(stmt, ReturnNode)
3006-
bestguess = frame.bestguess
30073038
rt = abstract_eval_value(interp, stmt.val, currstate, frame)
3008-
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
3009-
# narrow representation of bestguess slightly to prepare for tmerge with rt
3010-
if rt isa InterConditional && bestguess isa Const
3011-
let slot_id = rt.slot
3012-
old_id_type = slottypes[slot_id]
3013-
if bestguess.val === true && rt.elsetype !== Bottom
3014-
bestguess = InterConditional(slot_id, old_id_type, Bottom)
3015-
elseif bestguess.val === false && rt.thentype !== Bottom
3016-
bestguess = InterConditional(slot_id, Bottom, old_id_type)
3017-
end
3018-
end
3019-
end
3020-
# copy limitations to return value
3021-
if !isempty(frame.pclimitations)
3022-
union!(frame.limitations, frame.pclimitations)
3023-
empty!(frame.pclimitations)
3024-
end
3025-
if !isempty(frame.limitations)
3026-
rt = LimitedAccuracy(rt, copy(frame.limitations))
3027-
end
3028-
if !(𝕃ₚ, rt, bestguess)
3029-
# new (wider) return type for frame
3030-
bestguess = tmerge(𝕃ₚ, bestguess, rt)
3031-
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
3032-
frame.bestguess = bestguess
3039+
if update_bestguess!(interp, frame, currstate, rt)
30333040
for (caller, caller_pc) in frame.cycle_backedges
3034-
if !(caller.ssavaluetypes[caller_pc] === Any)
3041+
if caller.ssavaluetypes[caller_pc] !== Any
30353042
# no reason to revisit if that call-site doesn't affect the final result
30363043
push!(caller.ip, block_for_inst(caller.cfg, caller_pc))
30373044
end

base/compiler/typeinfer.jl

+21-18
Original file line numberDiff line numberDiff line change
@@ -870,26 +870,10 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
870870
# since the inliner will request to use it later
871871
cache = :local
872872
else
873+
rt = cached_return_type(code)
873874
effects = ipo_effects(code)
874875
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
875-
rettype = code.rettype
876-
if isdefined(code, :rettype_const)
877-
rettype_const = code.rettype_const
878-
# the second subtyping/egal conditions are necessary to distinguish usual cases
879-
# from rare cases when `Const` wrapped those extended lattice type objects
880-
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
881-
rettype = PartialStruct(rettype, rettype_const)
882-
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
883-
rettype = rettype_const
884-
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
885-
rettype = rettype_const
886-
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
887-
rettype = rettype_const
888-
else
889-
rettype = Const(rettype_const)
890-
end
891-
end
892-
return EdgeCallResult(rettype, mi, effects)
876+
return EdgeCallResult(rt, mi, effects)
893877
end
894878
else
895879
cache = :global # cache edge targets by default
@@ -933,6 +917,25 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
933917
return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame))
934918
end
935919

920+
function cached_return_type(code::CodeInstance)
921+
rettype = code.rettype
922+
isdefined(code, :rettype_const) || return rettype
923+
rettype_const = code.rettype_const
924+
# the second subtyping/egal conditions are necessary to distinguish usual cases
925+
# from rare cases when `Const` wrapped those extended lattice type objects
926+
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
927+
return PartialStruct(rettype, rettype_const)
928+
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
929+
return rettype_const
930+
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
931+
return rettype_const
932+
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
933+
return rettype_const
934+
else
935+
return Const(rettype_const)
936+
end
937+
end
938+
936939
#### entry points for inferring a MethodInstance given a type signature ####
937940

938941
# compute an inferred AST and return type

0 commit comments

Comments
 (0)