Skip to content

Commit b3e076e

Browse files
committed
More precise return types
1 parent 5793fed commit b3e076e

13 files changed

+342
-83
lines changed

compiler/lib-wasm/code_generation.ml

+89-6
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
199199
(* I31, struct, array and none have no other subtype *)
200200
| _, (I31 | Type _ | Struct | Array | None_) -> false, st
201201

202+
(*ZZZ*)
203+
let rec type_index_lub ty ty' st =
204+
if Var.equal ty ty'
205+
then Some ty
206+
else
207+
let type_field = Hashtbl.find st.context.types ty in
208+
match type_field.supertype with
209+
| None -> None
210+
| Some ty -> (
211+
match type_index_lub ty ty' st with
212+
| Some ty -> Some ty
213+
| None -> (
214+
let type_field = Hashtbl.find st.context.types ty' in
215+
match type_field.supertype with
216+
| None -> None
217+
| Some ty' -> type_index_lub ty ty' st))
218+
219+
let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
220+
match ty, ty' with
221+
| (Func | Extern), _ | _, (Func | Extern) -> assert false
222+
| None_, _ -> return ty'
223+
| _, None_ | Struct, Struct | Array, Array -> return ty
224+
| Any, _ | _, Any -> return W.Any
225+
| Eq, _
226+
| _, Eq
227+
| (Struct | Array | Type _), I31
228+
| I31, (Struct | Array | Type _)
229+
| Struct, Array
230+
| Array, Struct -> return (Eq : W.heap_type)
231+
| Struct, Type t | Type t, Struct -> (
232+
fun st ->
233+
let type_field = Hashtbl.find st.context.types t in
234+
match type_field.typ with
235+
| Struct _ -> W.Struct, st
236+
| Array _ | Func _ -> W.Eq, st)
237+
| Array, Type t | Type t, Array -> (
238+
fun st ->
239+
let type_field = Hashtbl.find st.context.types t in
240+
match type_field.typ with
241+
| Array _ -> W.Struct, st
242+
| Struct _ | Func _ -> W.Eq, st)
243+
| Type t, Type t' -> (
244+
let* r = fun st -> type_index_lub t t' st, st in
245+
match r with
246+
| Some t'' -> return (Type t'' : W.heap_type)
247+
| None -> (
248+
fun st ->
249+
let type_field = Hashtbl.find st.context.types t in
250+
let type_field' = Hashtbl.find st.context.types t' in
251+
match type_field.typ, type_field'.typ with
252+
| Struct _, Struct _ -> (Struct : W.heap_type), st
253+
| Array _, Array _ -> W.Array, st
254+
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
255+
| I31, I31 -> return W.I31
256+
257+
let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
258+
match ty, ty' with
259+
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
260+
let* typ = heap_type_lub typ typ' in
261+
return (W.Ref { nullable = nullable || nullable'; typ })
262+
| _ -> assert false
263+
202264
let register_global name ?exported_name ?(constant = false) typ init st =
203265
st.context.other_fields <-
204266
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
@@ -700,13 +762,28 @@ let push e =
700762
instr (Push e')
701763
| _ -> instr (Push e)
702764

765+
let blk' ty l st =
766+
let instrs = st.instrs in
767+
let (), st = l { st with instrs = [] } in
768+
let ty, st =
769+
match st.instrs with
770+
| Push e :: _ ->
771+
(let* ty' = expression_type e in
772+
match ty' with
773+
| None -> return ty
774+
| Some ty' -> return { ty with W.result = [ ty' ] })
775+
st
776+
| _ -> ty, st
777+
in
778+
(List.rev st.instrs, ty), { st with instrs }
779+
703780
let loop ty l =
704-
let* instrs = blk l in
705-
instr (Loop (ty, instrs))
781+
let* instrs, ty' = blk' ty l in
782+
instr (Loop (ty', instrs))
706783

707784
let block ty l =
708-
let* instrs = blk l in
709-
instr (Block (ty, instrs))
785+
let* instrs, ty' = blk' ty l in
786+
instr (Block (ty', instrs))
710787

711788
let block_expr ty l =
712789
let* instrs = blk l in
@@ -779,7 +856,7 @@ let init_code context = instrs context.init_code
779856

780857
let function_body ~context ~param_names ~body =
781858
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
782-
let (), st = body st in
859+
let res, st = body st in
783860
let local_count, body = st.var_count, List.rev st.instrs in
784861
let local_types = Array.make local_count (Var.fresh (), None) in
785862
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
@@ -797,4 +874,10 @@ let function_body ~context ~param_names ~body =
797874
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
798875
|> Array.to_list
799876
in
800-
locals, body
877+
locals, res, body
878+
879+
let eval ~context e =
880+
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
881+
let r, st = e st in
882+
assert (st.var_count = 0 && List.is_empty st.instrs);
883+
r

compiler/lib-wasm/code_generation.mli

+6-2
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ val register_type : string -> (unit -> type_def t) -> Wasm_ast.var t
160160

161161
val heap_type_sub : Wasm_ast.heap_type -> Wasm_ast.heap_type -> bool t
162162

163+
val value_type_lub : Wasm_ast.value_type -> Wasm_ast.value_type -> Wasm_ast.value_type t
164+
163165
val register_import :
164166
?import_module:string -> name:string -> Wasm_ast.import_desc -> Wasm_ast.var t
165167

@@ -202,8 +204,8 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t
202204
val function_body :
203205
context:context
204206
-> param_names:Code.Var.t list
205-
-> body:unit t
206-
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
207+
-> body:'a t
208+
-> (Wasm_ast.var * Wasm_ast.value_type) list * 'a * Wasm_ast.instruction list
207209

208210
val variable_type : Code.Var.t -> Wasm_ast.value_type option t
209211

@@ -214,3 +216,5 @@ val array_placeholder : Code.Var.t -> expression
214216
val default_value :
215217
Wasm_ast.value_type
216218
-> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t
219+
220+
val eval : context:context -> 'a t -> 'a

compiler/lib-wasm/curry.ml

+42-11
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,16 @@ module Make (Target : Target_sig.S) = struct
9595
loop m [] f None
9696
in
9797
let param_names = args @ [ f ] in
98-
let locals, body = function_body ~context ~param_names ~body in
98+
let locals, _, body = function_body ~context ~param_names ~body in
9999
W.Function
100-
{ name; exported_name = None; typ = Type.func_type 1; param_names; locals; body }
100+
{ name
101+
; exported_name = None
102+
; typ = Type.func_type 1
103+
; type_name = Some (eval ~context (Type.function_type ~cps:false 1))
104+
; param_names
105+
; locals
106+
; body
107+
}
101108

102109
let curry_name n m = Printf.sprintf "curry_%d_%d" n m
103110

@@ -123,9 +130,16 @@ module Make (Target : Target_sig.S) = struct
123130
push (Closure.curry_allocate ~cps:false ~arity m ~f:name' ~closure:f ~arg:x)
124131
in
125132
let param_names = [ x; f ] in
126-
let locals, body = function_body ~context ~param_names ~body in
133+
let locals, _, body = function_body ~context ~param_names ~body in
127134
W.Function
128-
{ name; exported_name = None; typ = Type.func_type 1; param_names; locals; body }
135+
{ name
136+
; exported_name = None
137+
; typ = Type.func_type 1
138+
; type_name = Some (eval ~context (Type.function_type ~cps:false 1))
139+
; param_names
140+
; locals
141+
; body
142+
}
129143
:: functions
130144

131145
let curry ~arity ~name = curry ~arity arity ~name
@@ -167,9 +181,16 @@ module Make (Target : Target_sig.S) = struct
167181
loop m [] f None
168182
in
169183
let param_names = args @ [ f ] in
170-
let locals, body = function_body ~context ~param_names ~body in
184+
let locals, _, body = function_body ~context ~param_names ~body in
171185
W.Function
172-
{ name; exported_name = None; typ = Type.func_type 2; param_names; locals; body }
186+
{ name
187+
; exported_name = None
188+
; typ = Type.func_type 2
189+
; type_name = Some (eval ~context (Type.function_type ~cps:true 1))
190+
; param_names
191+
; locals
192+
; body
193+
}
173194

174195
let cps_curry_name n m = Printf.sprintf "cps_curry_%d_%d" n m
175196

@@ -199,9 +220,16 @@ module Make (Target : Target_sig.S) = struct
199220
instr (W.Return (Some c))
200221
in
201222
let param_names = [ x; cont; f ] in
202-
let locals, body = function_body ~context ~param_names ~body in
223+
let locals, _, body = function_body ~context ~param_names ~body in
203224
W.Function
204-
{ name; exported_name = None; typ = Type.func_type 2; param_names; locals; body }
225+
{ name
226+
; exported_name = None
227+
; typ = Type.func_type 2
228+
; type_name = Some (eval ~context (Type.function_type ~cps:true 1))
229+
; param_names
230+
; locals
231+
; body
232+
}
205233
:: functions
206234

207235
let cps_curry ~arity ~name = cps_curry ~arity arity ~name
@@ -236,11 +264,12 @@ module Make (Target : Target_sig.S) = struct
236264
build_applies (load f) l)
237265
in
238266
let param_names = l @ [ f ] in
239-
let locals, body = function_body ~context ~param_names ~body in
267+
let locals, _, body = function_body ~context ~param_names ~body in
240268
W.Function
241269
{ name
242270
; exported_name = None
243271
; typ = Type.primitive_type (arity + 1)
272+
; type_name = None
244273
; param_names
245274
; locals
246275
; body
@@ -282,11 +311,12 @@ module Make (Target : Target_sig.S) = struct
282311
push (call ~cps:true ~arity:2 (load f) [ x; iterate ]))
283312
in
284313
let param_names = l @ [ f ] in
285-
let locals, body = function_body ~context ~param_names ~body in
314+
let locals, _, body = function_body ~context ~param_names ~body in
286315
W.Function
287316
{ name
288317
; exported_name = None
289318
; typ = Type.primitive_type (arity + 1)
319+
; type_name = None
290320
; param_names
291321
; locals
292322
; body
@@ -316,11 +346,12 @@ module Make (Target : Target_sig.S) = struct
316346
instr (W.Return (Some e))
317347
in
318348
let param_names = l @ [ f ] in
319-
let locals, body = function_body ~context ~param_names ~body in
349+
let locals, _, body = function_body ~context ~param_names ~body in
320350
W.Function
321351
{ name
322352
; exported_name = None
323353
; typ = Type.func_type arity
354+
; type_name = Some (eval ~context (Type.function_type ~cps (arity - 1)))
324355
; param_names
325356
; locals
326357
; body

compiler/lib-wasm/gc_target.ml

+29-6
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,35 @@ module Type = struct
205205
let primitive_type n =
206206
{ W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] }
207207

208-
let func_type n = primitive_type (n + 1)
209-
210-
let function_type ~cps n =
211-
let n = if cps then n + 1 else n in
212-
register_type (Printf.sprintf "function_%d" n) (fun () ->
213-
return { supertype = None; final = true; typ = W.Func (func_type n) })
208+
let func_type ?(ret = value) n =
209+
{ W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ ret ] }
210+
211+
let rec function_type ~cps ?ret n =
212+
let n' = if cps then n + 1 else n in
213+
let ret_str =
214+
match ret with
215+
| None -> ""
216+
| Some (W.Ref { nullable = false; typ }) -> (
217+
match typ with
218+
| Eq -> "_eq" (*ZZZ remove ret in that case*)
219+
| I31 -> "_i31"
220+
| Struct -> "_struct"
221+
| Array -> "_array"
222+
| None_ -> "_none"
223+
| Type v -> (
224+
match Code.Var.get_name v with
225+
| None -> assert false
226+
| Some name -> "_" ^ name)
227+
| _ -> assert false)
228+
| _ -> assert false
229+
in
230+
register_type (Printf.sprintf "function_%d%s" n' ret_str) (fun () ->
231+
match ret with
232+
| None -> return { supertype = None; final = false; typ = W.Func (func_type n') }
233+
| Some ret ->
234+
let* super = function_type ~cps n in
235+
return
236+
{ supertype = Some super; final = false; typ = W.Func (func_type ~ret n') })
214237

215238
let closure_common_fields ~cps =
216239
let* fun_ty = function_type ~cps 1 in

0 commit comments

Comments
 (0)