Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 5128aa5

Browse files
authored
Merge pull request #118 from sjrd/better-receiver-type-for-static-dispatch
Better receiver types for static dispatch.
2 parents e36f30c + b256724 commit 5128aa5

8 files changed

+108
-89
lines changed

wasm/src/main/scala/ir2wasm/HelperFunctions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ object HelperFunctions {
550550

551551
val objectClassInfo = ctx.getClassInfo(IRNames.ObjectClass)
552552
instrs ++= objectClassInfo.tableEntries.map { methodName =>
553-
ctx.refFuncWithDeclaration(objectClassInfo.resolvedMethodInfos(methodName).wasmName)
553+
ctx.refFuncWithDeclaration(objectClassInfo.resolvedMethodInfos(methodName).tableEntryName)
554554
}
555555
instrs += STRUCT_NEW(WasmTypeName.WasmStructTypeName.ObjectVTable)
556556
instrs += LOCAL_TEE(arrayTypeDataLocal)

wasm/src/main/scala/ir2wasm/TypeTransformer.scala

+10-26
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,6 @@ object TypeTransformer {
3030
case _ => List(transformType(t))
3131
}
3232

33-
def transformTypeRef(t: IRTypes.TypeRef)(implicit ctx: ReadOnlyWasmContext): Types.WasmType =
34-
t match {
35-
case arrayTypeRef: IRTypes.ArrayTypeRef =>
36-
Types.WasmRefType.nullable(
37-
Names.WasmTypeName.WasmStructTypeName.forArrayClass(arrayTypeRef)
38-
)
39-
case IRTypes.ClassRef(className) =>
40-
transformClassByName(className)
41-
case IRTypes.PrimRef(tpe) => transformPrimType(tpe)
42-
}
43-
4433
/** Transforms a value type to a unique Wasm type.
4534
*
4635
* This method cannot be used for `void` and `nothing`, since they have no corresponding Wasm
@@ -54,28 +43,23 @@ object TypeTransformer {
5443
Types.WasmRefType.nullable(
5544
Names.WasmTypeName.WasmStructTypeName.forArrayClass(tpe.arrayTypeRef)
5645
)
57-
case IRTypes.ClassType(className) => transformClassByName(className)
46+
case IRTypes.ClassType(className) => transformClassType(className)
5847
case IRTypes.RecordType(fields) => ???
5948
case IRTypes.StringType | IRTypes.UndefType =>
6049
Types.WasmRefType.any
6150
case p: IRTypes.PrimTypeWithRef => transformPrimType(p)
6251
}
6352

64-
private def transformClassByName(
53+
def transformClassType(
6554
className: IRNames.ClassName
66-
)(implicit ctx: ReadOnlyWasmContext): Types.WasmType = {
67-
className match {
68-
case _ =>
69-
val info = ctx.getClassInfo(className)
70-
if (info.isAncestorOfHijackedClass)
71-
Types.WasmRefType.anyref
72-
else if (info.isInterface)
73-
Types.WasmRefType.nullable(Types.WasmHeapType.ObjectType)
74-
else
75-
Types.WasmRefType.nullable(
76-
Names.WasmTypeName.WasmStructTypeName.forClass(className)
77-
)
78-
}
55+
)(implicit ctx: ReadOnlyWasmContext): Types.WasmRefType = {
56+
val info = ctx.getClassInfo(className)
57+
if (info.isAncestorOfHijackedClass)
58+
Types.WasmRefType.anyref
59+
else if (info.isInterface)
60+
Types.WasmRefType.nullable(Types.WasmHeapType.ObjectType)
61+
else
62+
Types.WasmRefType.nullable(WasmTypeName.WasmStructTypeName.forClass(className))
7963
}
8064

8165
private def transformPrimType(

wasm/src/main/scala/ir2wasm/WasmBuilder.scala

+79-35
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
337337
val proxyId = ctx.getReflectiveProxyId(proxyInfo.methodName)
338338
List(
339339
I32_CONST(proxyId),
340-
REF_FUNC(proxyInfo.wasmName),
340+
REF_FUNC(proxyInfo.tableEntryName),
341341
STRUCT_NEW(WasmStructTypeName.reflectiveProxy)
342342
)
343343
} :+ ARRAY_NEW_FIXED(WasmArrayTypeName.reflectiveProxies, reflectiveProxies.size)
@@ -417,7 +417,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
417417
classInfo.resolvedMethodInfos.valuesIterator.filter(_.methodName.isReflectiveProxy).toList
418418
val typeDataFieldValues = genTypeDataFieldValues(clazz, reflectiveProxies)
419419
val vtableElems = classInfo.tableEntries.map { methodName =>
420-
REF_FUNC(classInfo.resolvedMethodInfos(methodName).wasmName)
420+
REF_FUNC(classInfo.resolvedMethodInfos(methodName).tableEntryName)
421421
}
422422
val globalVTable =
423423
genTypeDataGlobal(typeRef, vtableTypeName, typeDataFieldValues, vtableElems)
@@ -489,29 +489,35 @@ class WasmBuilder(coreSpec: CoreSpec) {
489489
ctor.name.name
490490
)
491491

492+
val resultTyp = WasmRefType(typeName)
493+
492494
implicit val fctx = WasmFunctionContext(
493495
WasmFunctionName.loadModule(clazz.className),
494496
Nil,
495-
List(WasmRefType.nullable(typeName))
497+
List(resultTyp)
496498
)
497499

500+
val instanceLocal = fctx.addLocal("instance", resultTyp)
501+
498502
import fctx.instrs
499503

500-
// global.get $module_name
501-
// ref.if_null
502-
// ref.null $module_type
503-
// call $module_init ;; should set to global
504-
// end
505-
// global.get $module_name
506-
instrs += GLOBAL_GET(globalInstanceName) // [rt]
507-
instrs += REF_IS_NULL // [rt] -> [i32] (bool)
508-
instrs += IF(BlockType.ValueType())
509-
instrs += CALL(WasmFunctionName.newDefault(clazz.name.name))
510-
instrs += GLOBAL_SET(globalInstanceName)
511-
instrs += GLOBAL_GET(globalInstanceName)
512-
instrs += CALL(ctorName)
513-
instrs += END
514-
instrs += GLOBAL_GET(globalInstanceName) // [rt]
504+
fctx.block(resultTyp) { nonNullLabel =>
505+
// load global, return if not null
506+
instrs += GLOBAL_GET(globalInstanceName)
507+
instrs += BR_ON_NON_NULL(nonNullLabel)
508+
509+
// create an instance and call its constructor
510+
instrs += CALL(WasmFunctionName.newDefault(clazz.name.name))
511+
instrs += LOCAL_TEE(instanceLocal)
512+
instrs += CALL(ctorName)
513+
514+
// store it in the global
515+
instrs += LOCAL_GET(instanceLocal)
516+
instrs += GLOBAL_SET(globalInstanceName)
517+
518+
// return it
519+
instrs += LOCAL_GET(instanceLocal)
520+
}
515521

516522
fctx.buildAndAddToContext()
517523
}
@@ -1076,31 +1082,28 @@ class WasmBuilder(coreSpec: CoreSpec) {
10761082
private def genFunction(
10771083
clazz: LinkedClass,
10781084
method: IRTrees.MethodDef
1079-
)(implicit ctx: WasmContext): WasmFunction = {
1085+
)(implicit ctx: WasmContext): Unit = {
10801086
implicit val pos = method.pos
10811087

1082-
val functionName = Names.WasmFunctionName(
1083-
method.flags.namespace,
1084-
clazz.name.name,
1085-
method.name.name
1086-
)
1088+
val namespace = method.flags.namespace
1089+
val className = clazz.className
1090+
val methodName = method.methodName
1091+
1092+
val functionName = Names.WasmFunctionName(namespace, className, methodName)
1093+
1094+
val isHijackedClass = ctx.getClassInfo(className).kind == ClassKind.HijackedClass
10871095

1088-
// Receiver type for non-constructor methods needs to be `(ref any)` because params are invariant
1089-
// Otherwise, vtable can't be a subtype of the supertype's subtype
1090-
// Constructor can use the exact type because it won't be registered to vtables.
10911096
val receiverTyp =
1092-
if (method.flags.namespace.isStatic)
1097+
if (namespace.isStatic)
10931098
None
1094-
else if (clazz.kind == ClassKind.HijackedClass)
1095-
Some(transformType(IRTypes.BoxedClassToPrimType(clazz.name.name)))
1096-
else if (method.flags.namespace.isConstructor)
1097-
Some(WasmRefType.nullable(WasmTypeName.WasmStructTypeName.forClass(clazz.name.name)))
1099+
else if (isHijackedClass)
1100+
Some(transformType(IRTypes.BoxedClassToPrimType(className)))
10981101
else
1099-
Some(WasmRefType.any)
1102+
Some(transformClassType(className).toNonNullable)
11001103

11011104
// Prepare for function context, set receiver and parameters
11021105
implicit val fctx = WasmFunctionContext(
1103-
Some(clazz.className),
1106+
Some(className),
11041107
functionName,
11051108
receiverTyp,
11061109
method.args,
@@ -1111,7 +1114,48 @@ class WasmBuilder(coreSpec: CoreSpec) {
11111114
val body = method.body.getOrElse(throw new Exception("abstract method cannot be transformed"))
11121115
WasmExpressionBuilder.generateIRBody(body, method.resultType)
11131116

1114-
fctx.buildAndAddToContext(useFunctionTypeInMainRecType = true)
1117+
fctx.buildAndAddToContext()
1118+
1119+
if (namespace == IRTrees.MemberNamespace.Public && !isHijackedClass) {
1120+
/* Also generate the bridge that is stored in the table entries. In table
1121+
* entries, the receiver type is always `(ref any)`.
1122+
*
1123+
* TODO: generate this only when the method is actually referred to from
1124+
* at least one table.
1125+
*/
1126+
1127+
implicit val fctx = WasmFunctionContext(
1128+
Some(className),
1129+
WasmFunctionName.forTableEntry(className, methodName),
1130+
Some(WasmRefType.any),
1131+
method.args,
1132+
method.resultType
1133+
)
1134+
1135+
import fctx.instrs
1136+
1137+
val receiverLocal :: paramLocals = fctx.paramIndices: @unchecked
1138+
1139+
// Load and cast down the receiver
1140+
instrs += LOCAL_GET(receiverLocal)
1141+
receiverTyp match {
1142+
case Some(Types.WasmRefType(_, WasmHeapType.Any)) =>
1143+
() // no cast necessary
1144+
case Some(receiverTyp: Types.WasmRefType) =>
1145+
instrs += REF_CAST(receiverTyp)
1146+
case _ =>
1147+
throw new AssertionError(s"Unexpected receiver type $receiverTyp")
1148+
}
1149+
1150+
// Load the other parameters
1151+
for (paramLocal <- paramLocals)
1152+
instrs += LOCAL_GET(paramLocal)
1153+
1154+
// Call the statically resolved method
1155+
instrs += RETURN_CALL(functionName)
1156+
1157+
fctx.buildAndAddToContext(useFunctionTypeInMainRecType = true)
1158+
}
11151159
}
11161160

11171161
private def transformField(

wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+7-23
Original file line numberDiff line numberDiff line change
@@ -1699,21 +1699,6 @@ private class WasmExpressionBuilder private (
16991699
case _ => t.tpe
17001700
}
17011701

1702-
/* If the receiver is a Class/ModuleClass, its wasm type will be declared
1703-
* as `(ref any)`, and therefore we must cast it down.
1704-
*/
1705-
fixedTpe match {
1706-
case IRTypes.ClassType(className) if className != IRNames.ObjectClass =>
1707-
val info = ctx.getClassInfo(className)
1708-
if (info.kind.isClass) {
1709-
instrs += REF_CAST(Types.WasmRefType(WasmStructTypeName.forClass(className)))
1710-
} else if (info.isInterface) {
1711-
instrs += REF_CAST(Types.WasmRefType(Types.WasmHeapType.ObjectType))
1712-
}
1713-
case _ =>
1714-
()
1715-
}
1716-
17171702
fixedTpe
17181703
}
17191704

@@ -1908,8 +1893,7 @@ private class WasmExpressionBuilder private (
19081893
* if the given class is an ancestor of hijacked classes (which in practice
19091894
* is only the case for j.l.Object).
19101895
*/
1911-
val instanceTyp =
1912-
Types.WasmRefType.nullable(WasmStructTypeName.forClass(n.className))
1896+
val instanceTyp = Types.WasmRefType(WasmStructTypeName.forClass(n.className))
19131897
val localInstance = fctx.addSyntheticLocal(instanceTyp)
19141898

19151899
fctx.markPosition(n)
@@ -1941,7 +1925,7 @@ private class WasmExpressionBuilder private (
19411925
val primLocal = fctx.addSyntheticLocal(primTyp)
19421926

19431927
val boxClassType = IRTypes.ClassType(boxClassName)
1944-
val boxTyp = TypeTransformer.transformType(boxClassType)(ctx)
1928+
val boxTyp = TypeTransformer.transformClassType(boxClassName)(ctx).toNonNullable
19451929
val instanceLocal = fctx.addSyntheticLocal(boxTyp)
19461930

19471931
/* The generated code is as follows. Before the codegen, the stack contains
@@ -1986,12 +1970,12 @@ private class WasmExpressionBuilder private (
19861970

19871971
private def genWrapAsThrowable(tree: IRTrees.WrapAsThrowable): IRTypes.Type = {
19881972
val throwableClassType = IRTypes.ClassType(IRNames.ThrowableClass)
1989-
val throwableTyp = TypeTransformer.transformType(throwableClassType)(ctx)
1973+
val nonNullThrowableTyp = Types.WasmRefType(Types.WasmHeapType.ThrowableType)
19901974

1991-
val jsExceptionClassType = IRTypes.ClassType(SpecialNames.JSExceptionClass)
1992-
val jsExceptionTyp = TypeTransformer.transformType(jsExceptionClassType)(ctx)
1975+
val jsExceptionTyp =
1976+
TypeTransformer.transformClassType(SpecialNames.JSExceptionClass)(ctx).toNonNullable
19931977

1994-
fctx.block(throwableTyp) { doneLabel =>
1978+
fctx.block(nonNullThrowableTyp) { doneLabel =>
19951979
genTree(tree.expr, IRTypes.AnyType)
19961980

19971981
fctx.markPosition(tree)
@@ -2000,7 +1984,7 @@ private class WasmExpressionBuilder private (
20001984
instrs += BR_ON_CAST(
20011985
doneLabel,
20021986
Types.WasmRefType.anyref,
2003-
Types.WasmRefType(Types.WasmHeapType.ThrowableType)
1987+
nonNullThrowableTyp
20041988
)
20051989

20061990
// otherwise, wrap in a new JavaScriptException

wasm/src/main/scala/wasm4s/Instructions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ object WasmInstr {
258258
with StackPolymorphicInstr
259259
case object RETURN extends WasmSimpleInstr("return", 0x0F) with StackPolymorphicInstr
260260
case class CALL(i: WasmFunctionName) extends WasmFuncInstr("call", 0x10, i)
261+
case class RETURN_CALL(i: WasmFunctionName) extends WasmFuncInstr("return_call", 0x12, i)
261262
case class THROW(i: WasmTagName) extends WasmTagInstr("throw", 0x08, i) with StackPolymorphicInstr
262263
case object THROW_REF extends WasmSimpleInstr("throw_ref", 0x0A) with StackPolymorphicInstr
263264
case class TRY_TABLE(i: BlockType, cs: List[CatchClause], label: Option[WasmLabelName] = None)

wasm/src/main/scala/wasm4s/Names.scala

+3
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ object Names {
118118
}
119119
}
120120

121+
def forTableEntry(clazz: IRNames.ClassName, method: IRNames.MethodName): WasmFunctionName =
122+
new WasmFunctionName("t#" + clazz.nameString, method.nameString)
123+
121124
def forExport(exportedName: String): WasmFunctionName =
122125
new WasmFunctionName("export", exportedName)
123126

wasm/src/main/scala/wasm4s/Types.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ object Types {
2323
case object WasmInt8 extends WasmPackedType("i8", 0x78)
2424
case object WasmInt16 extends WasmPackedType("i16", 0x77)
2525

26-
final case class WasmRefType(nullable: Boolean, heapType: WasmHeapType) extends WasmType
26+
final case class WasmRefType(nullable: Boolean, heapType: WasmHeapType) extends WasmType {
27+
def toNullable: WasmRefType = WasmRefType(true, heapType)
28+
def toNonNullable: WasmRefType = WasmRefType(false, heapType)
29+
}
2730

2831
object WasmRefType {
2932

wasm/src/main/scala/wasm4s/WasmContext.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
578578
instrs += WasmInstr.I32_CONST(idx)
579579

580580
for (method <- iface.tableEntries)
581-
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).wasmName)
581+
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
582582
instrs += WasmInstr.STRUCT_NEW(WasmTypeName.WasmStructTypeName.forITable(iface.name))
583583
instrs += WasmInstr.ARRAY_SET(WasmTypeName.WasmArrayTypeName.itables)
584584
}
@@ -598,7 +598,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
598598
instrs += I32_CONST(getItableIdx(interfaceName))
599599

600600
for (method <- interfaceInfo.tableEntries)
601-
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).wasmName)
601+
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
602602
instrs += STRUCT_NEW(WasmStructTypeName.forITable(interfaceName))
603603
instrs += ARRAY_SET(WasmArrayTypeName.itables)
604604
}
@@ -879,7 +879,7 @@ object WasmContext {
879879
val ownerClass: IRNames.ClassName,
880880
val methodName: IRNames.MethodName
881881
) {
882-
val wasmName = WasmFunctionName(IRTrees.MemberNamespace.Public, ownerClass, methodName)
882+
val tableEntryName = WasmFunctionName.forTableEntry(ownerClass, methodName)
883883

884884
private var effectivelyFinal: Boolean = true
885885

0 commit comments

Comments
 (0)