Skip to content

Commit

Permalink
Optimize calling a known function
Browse files Browse the repository at this point in the history
  • Loading branch information
vouillon committed Jan 7, 2025
1 parent 6b05a5c commit d41d4c0
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 97 deletions.
17 changes: 12 additions & 5 deletions compiler/lib-wasm/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,16 @@ module Generate (Target : Target_sig.S) = struct

let zero_divide_pc = -2

(* [exact_call kind] is [true] when the application kind guarantees an
   exact call ([Exact] or [Known _]) and [false] for [Generic]. *)
let exact_call = function
  | Exact | Known _ -> true
  | Generic -> false

let rec translate_expr ctx context x e =
match e with
| Apply { f; args; exact }
when exact || List.length args = if Var.Set.mem x ctx.in_cps then 2 else 1 ->
| Apply { f; args; kind }
when exact_call kind || List.length args = if Var.Set.mem x ctx.in_cps then 2 else 1
->
let rec loop acc l =
match l with
| [] -> (
Expand All @@ -204,13 +210,14 @@ module Generate (Target : Target_sig.S) = struct
if b
then return (W.Call (f, List.rev (closure :: acc)))
else
match funct with
| W.RefFunc g ->
match funct, kind with
| W.RefFunc g, _ ->
(* Functions with constant closures ignore their
environment. In case of partial application, we
still need the closure. *)
let* cl = if exact then Value.unit else return closure in
let* cl = if exact_call kind then Value.unit else return closure in
return (W.Call (g, List.rev (cl :: acc)))
| _, Known g -> return (W.Call (g, List.rev (closure :: acc)))
| _ -> return (W.Call_ref (ty, funct, List.rev (closure :: acc))))
| x :: r ->
let* x = load x in
Expand Down
17 changes: 12 additions & 5 deletions compiler/lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,16 @@ type field_type =
| Non_float
| Float

(* What is known about a function application at a call site. *)
type apply_kind =
| Generic (* not known to be exact *)
| Exact (* # of arguments = # of parameters *)
| Known of Var.t (* Exact and we know which function is called *)

type expr =
| Apply of
{ f : Var.t
; args : Var.t list
; exact : bool
; kind : apply_kind
}
| Block of int * Var.t array * array_or_not * mutability
| Field of Var.t * int * field_type
Expand Down Expand Up @@ -556,10 +561,12 @@ module Print = struct

let expr f e =
match e with
| Apply { f = g; args; exact } ->
if exact
then Format.fprintf f "%a!(%a)" Var.print g var_list args
else Format.fprintf f "%a(%a)" Var.print g var_list args
| Apply { f = g; args; kind } -> (
match kind with
| Generic -> Format.fprintf f "%a(%a)" Var.print g var_list args
| Exact -> Format.fprintf f "%a!(%a)" Var.print g var_list args
| Known h -> Format.fprintf f "%a{=%a}(%a)" Var.print g Var.print h var_list args
)
| Block (t, a, _, mut) ->
Format.fprintf
f
Expand Down
7 changes: 6 additions & 1 deletion compiler/lib/code.mli
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,16 @@ type field_type =
| Non_float
| Float

(* What is known about a function application at a call site. *)
type apply_kind =
| Generic (* not known to be exact *)
| Exact (* # of arguments = # of parameters *)
| Known of Var.t (* Exact and we know which function is called *)

type expr =
| Apply of
{ f : Var.t
; args : Var.t list
; exact : bool (* if true, then # of arguments = # of parameters *)
; kind : apply_kind
}
| Block of int * Var.t array * array_or_not * mutability
| Field of Var.t * int * field_type
Expand Down
68 changes: 40 additions & 28 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,15 @@ let allocate_closure ~st ~params ~body ~branch =
let name = Var.fresh () in
[ Let (name, Closure (params, (pc, []))) ], name

let tail_call ~st ?(instrs = []) ~exact ~in_cps ~check ~f args =
assert (exact || check);
let tail_call ~st ?(instrs = []) ~kind ~in_cps ~check ~f args =
assert (
match kind with
| Generic -> check
| Exact | Known _ -> true);
let ret = Var.fresh () in
if check then st.trampolined_calls := Var.Set.add ret !(st.trampolined_calls);
if in_cps then st.in_cps := Var.Set.add ret !(st.in_cps);
instrs @ [ Let (ret, Apply { f; args; exact }) ], Return ret
instrs @ [ Let (ret, Apply { f; args; kind }) ], Return ret

let cps_branch ~st ~src (pc, args) =
match Addr.Set.mem pc st.blocks_to_transform with
Expand All @@ -359,14 +362,8 @@ let cps_branch ~st ~src (pc, args) =
(* We check the stack depth only for backward edges (so, at
least once per loop iteration) *)
let check = Hashtbl.find st.block_order src >= Hashtbl.find st.block_order pc in
tail_call
~st
~instrs
~exact:true
~in_cps:false
~check
~f:(closure_of_pc ~st pc)
args
let f = closure_of_pc ~st pc in
tail_call ~st ~instrs ~kind:(Known f) ~in_cps:false ~check ~f args

let cps_jump_cont ~st ~src ((pc, _) as cont) =
match Addr.Set.mem pc st.blocks_to_transform with
Expand Down Expand Up @@ -433,7 +430,7 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
(* If the number of successive 'returns' is unbounded in CPS, it
means that we have an unbounded number of calls in direct style
(even with tail call optimization) *)
tail_call ~st ~exact:true ~in_cps:false ~check:false ~f:k [ x ]
tail_call ~st ~kind:Exact ~in_cps:false ~check:false ~f:k [ x ]
| Raise (x, rmode) -> (
assert (List.is_empty alloc_jump_closures);
match Hashtbl.find_opt st.matching_exn_handler pc with
Expand Down Expand Up @@ -468,7 +465,7 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
tail_call
~st
~instrs:(Let (exn_handler, Prim (Extern "caml_pop_trap", [])) :: instrs)
~exact:true
~kind:Exact
~in_cps:false
~check:false
~f:exn_handler
Expand Down Expand Up @@ -522,6 +519,14 @@ let cps_last ~st ~alloc_jump_closures pc (last : last) ~k : instr list * last =
@ (Let (exn_handler, Prim (Extern "caml_pop_trap", [])) :: body)
, branch ))

(* [refine_kind k k'] returns whichever of the two kinds is the most
   informative: a [Known _] kind wins over anything else, then [Exact],
   then [Generic]. When both are equally informative, [k] is returned. *)
let refine_kind k k' =
  match k, k' with
  | Known _, _ | Exact, (Exact | Generic) | Generic, Generic -> k
  | (Exact | Generic), Known _ | Generic, Exact -> k'

let rewrite_instr ~st (instr : instr) : instr =
match instr with
| Let (x, Closure (_, (pc, _))) when Var.Set.mem x st.cps_needed ->
Expand All @@ -542,27 +547,34 @@ let rewrite_instr ~st (instr : instr) : instr =
(Extern "caml_alloc_dummy_function", [ size; Pc (Int (Targetint.succ a)) ])
)
| _ -> assert false)
| Let (x, Apply { f; args; _ }) when not (Var.Set.mem x st.cps_needed) ->
| Let (x, Apply { f; args; kind }) when not (Var.Set.mem x st.cps_needed) ->
(* At the moment, we turn into CPS any function not called with
the right number of parameters *)
assert (
let kind' =
(* If this function is unknown to the global flow analysis, then it was
introduced by the lambda lifting and we don't have exactness info any more. *)
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
|| Global_flow.exact_call st.flow_info f (List.length args));
Let (x, Apply { f; args; exact = true })
if Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
then Exact
else Global_flow.apply_kind st.flow_info f (List.length args)
in
assert (
match kind' with
| Generic -> false
| Exact | Known _ -> true);
Let (x, Apply { f; args; kind = refine_kind kind kind' })
| Let (_, e) when effect_primitive_or_application e ->
(* For the CPS target, applications of CPS functions and effect primitives require
more work (allocating a continuation and/or modifying end-of-block branches) and
are handled in a specialized function. *)
assert false
| _ -> instr

let call_exact flow_info (f : Var.t) nargs : bool =
let call_kind flow_info (f : Var.t) nargs =
(* If [f] is unknown to the global flow analysis, then it was introduced by
the lambda lifting and we don't have exactness information about it. *)
Var.idx f < Var.Tbl.length flow_info.Global_flow.info_approximation
&& Global_flow.exact_call flow_info f nargs
if Var.idx f >= Var.Tbl.length flow_info.Global_flow.info_approximation
then Generic
else Global_flow.apply_kind flow_info f nargs

let cps_instr ~st (instr : instr) : instr list =
match instr with
Expand All @@ -571,7 +583,7 @@ let cps_instr ~st (instr : instr) : instr list =
Otherwise, the runtime primitive is used. *)
let unit = Var.fresh_n "unit" in
[ Let (unit, Constant (Int Targetint.zero))
; Let (x, Apply { exact = call_exact st.flow_info f 1; f; args = [ unit ] })
; Let (x, Apply { kind = call_kind st.flow_info f 1; f; args = [ unit ] })
]
| _ -> [ rewrite_instr ~st instr ]

Expand Down Expand Up @@ -646,11 +658,11 @@ let cps_block ~st ~k ~orig_pc block =
[ Let (x, e) ], Return x)
in
match e with
| Apply { f; args; exact } when Var.Set.mem x st.cps_needed ->
| Apply { f; args; kind } when Var.Set.mem x st.cps_needed ->
Some
(fun ~k ->
let exact = exact || call_exact st.flow_info f (List.length args) in
tail_call ~st ~exact ~in_cps:true ~check:true ~f (args @ [ k ]))
let kind = refine_kind kind (call_kind st.flow_info f (List.length args)) in
tail_call ~st ~kind ~in_cps:true ~check:true ~f (args @ [ k ]))
| Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg; tail ]) ->
Some
(fun ~k ->
Expand All @@ -659,7 +671,7 @@ let cps_block ~st ~k ~orig_pc block =
~st
~instrs:
[ Let (k', Prim (Extern "caml_resume_stack", [ Pv stack; tail; Pv k ])) ]
~exact:(call_exact st.flow_info f 1)
~kind:(call_kind st.flow_info f 1)
~in_cps:true
~check:true
~f
Expand Down Expand Up @@ -747,8 +759,8 @@ let rewrite_direct_block ~st ~cps_needed ~closure_info ~pc block =
(* We just need to call [f] in direct style. *)
let unit = Var.fresh_n "unit" in
let unit_val = Int Targetint.zero in
let exact = call_exact st.flow_info f 1 in
[ Let (unit, Constant unit_val); Let (x, Apply { exact; f; args = [ unit ] }) ]
let kind = call_kind st.flow_info f 1 in
[ Let (unit, Constant unit_val); Let (x, Apply { kind; f; args = [ unit ] }) ]
| (Let _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _ | Event _) as instr
-> [ instr ]
in
Expand Down
14 changes: 12 additions & 2 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,14 @@ module Share = struct
List.fold_left block.body ~init:share ~f:(fun share i ->
match i with
| Let (_, Constant c) -> get_constant c share
| Let (x, Apply { args; exact; _ }) ->
| Let (x, Apply { args; kind; _ }) ->
let trampolined = Var.Set.mem x trampolined_calls in
let in_cps = Var.Set.mem x in_cps in
let exact =
match kind with
| Generic -> false
| Exact | Known _ -> true
in
if (not exact) || trampolined
then
add_apply
Expand Down Expand Up @@ -1230,7 +1235,12 @@ let remove_unused_tail_args ctx exact trampolined args =
let rec translate_expr ctx loc x e level : (_ * J.statement_list) Expr_builder.t =
let open Expr_builder in
match e with
| Apply { f; args; exact } ->
| Apply { f; args; kind } ->
let exact =
match kind with
| Generic -> false
| Exact | Known _ -> true
in
let trampolined = Var.Set.mem x ctx.Ctx.trampolined_calls in
let args = remove_unused_tail_args ctx exact trampolined args in
let* () = info ~need_loc:true mutator_p in
Expand Down
14 changes: 7 additions & 7 deletions compiler/lib/generate_closure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ let rec collect_apply pc blocks visited tc =
match block.branch with
| Return x -> (
match List.last block.body with
| Some (Let (y, Apply { f; exact = true; _ })) when Code.Var.compare x y = 0 ->
Some (add_multi f pc tc)
| Some (Let (y, Apply { f; kind = Exact | Known _; _ }))
when Code.Var.compare x y = 0 -> Some (add_multi f pc tc)
| None -> None
| Some _ -> None)
| _ -> None
Expand Down Expand Up @@ -100,7 +100,7 @@ module Trampoline = struct
match counter with
| None ->
{ params = []
; body = [ Let (return, Apply { f; args; exact = true }) ]
; body = [ Let (return, Apply { f; args; kind = Known f }) ]
; branch = Return return
}
| Some counter ->
Expand All @@ -110,7 +110,7 @@ module Trampoline = struct
[ Let
( counter_plus_1
, Prim (Extern "%int_add", [ Pv counter; Pc (Int Targetint.one) ]) )
; Let (return, Apply { f; args = counter_plus_1 :: args; exact = true })
; Let (return, Apply { f; args = counter_plus_1 :: args; kind = Known f })
]
; branch = Return return
}
Expand Down Expand Up @@ -139,14 +139,14 @@ module Trampoline = struct
(match counter with
| None ->
[ Event loc
; Let (result1, Apply { f; args; exact = true })
; Let (result1, Apply { f; args; kind = Known f })
; Event Parse_info.zero
; Let (result2, Prim (Extern "caml_trampoline", [ Pv result1 ]))
]
| Some counter ->
[ Event loc
; Let (counter, Constant (Int Targetint.zero))
; Let (result1, Apply { f; args = counter :: args; exact = true })
; Let (result1, Apply { f; args = counter :: args; kind = Known f })
; Event Parse_info.zero
; Let (result2, Prim (Extern "caml_trampoline", [ Pv result1 ]))
])
Expand Down Expand Up @@ -222,7 +222,7 @@ module Trampoline = struct
let bounce_call_pc = free_pc + 1 in
let free_pc = free_pc + 2 in
match List.rev block.body with
| Let (x, Apply { f; args; exact = true }) :: rem_rev ->
| Let (x, Apply { f; args; kind = Exact | Known _ }) :: rem_rev ->
assert (Var.equal f ci.f_name);
let blocks =
Addr.Map.add
Expand Down
39 changes: 26 additions & 13 deletions compiler/lib/global_flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -704,17 +704,29 @@ let f ~fast p =
; info_return_vals = rets
}

let exact_call info f n =
let apply_kind info f n =
match Var.Tbl.get info.info_approximation f with
| Top | Values { others = true; _ } -> false
| Values { known; others = false } ->
Var.Set.for_all
(fun g ->
match info.info_defs.(Var.idx g) with
| Expr (Closure (params, _)) -> List.length params = n
| Expr (Block _) -> true
| Expr _ | Phi _ -> assert false)
known
| Top | Values { others = true; _ } -> Generic
| Values { known; others = false } -> (
match
Var.Set.fold
(fun g acc ->
match info.info_defs.(Var.idx g) with
| Expr (Closure (params, _)) ->
if List.length params = n
then
match acc with
| None -> Some (Known g)
| Some (Known _) -> Some Exact
| Some (Exact | Generic) -> acc
else Some Generic
| Expr (Block _) -> acc
| Expr _ | Phi _ -> assert false)
known
None
with
| None -> Exact
| Some kind -> kind)

let function_arity info f =
match Var.Tbl.get info.info_approximation f with
Expand All @@ -727,9 +739,10 @@ let function_arity info f =
| Expr (Closure (params, _)) -> (
let n = List.length params in
match acc with
| None -> Some (Some n)
| Some (Some n') when n <> n' -> Some None
| Some _ -> acc)
| None -> Some (Some (n, Known g))
| Some (Some (n', _)) when n <> n' -> Some None
| Some (Some (_, Known _)) -> Some (Some (n, Exact))
| Some (None | Some (_, (Exact | Generic))) -> acc)
| Expr (Block _) -> acc
| Expr _ | Phi _ -> assert false)
known
Expand Down
4 changes: 2 additions & 2 deletions compiler/lib/global_flow.mli
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ type info =

val f : fast:bool -> Code.program -> info

val exact_call : info -> Var.t -> int -> bool
val apply_kind : info -> Var.t -> int -> Code.apply_kind

val function_arity : info -> Var.t -> int option
val function_arity : info -> Var.t -> (int * Code.apply_kind) option
Loading

0 comments on commit d41d4c0

Please sign in to comment.