Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #14 from tanishiking/fix-virtual-dispatch
Browse files Browse the repository at this point in the history
fix: Add tests for virtual dispatch + fix bug for abstract class
  • Loading branch information
tanishiking authored Mar 11, 2024
2 parents e198109 + af2d6ad commit e956e49
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 130 deletions.
3 changes: 3 additions & 0 deletions cli/src/main/scala/TestSuites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ object TestSuites {
val suites = List(
TestSuite("testsuite.core.simple.Simple", "simple"),
TestSuite("testsuite.core.add.Add", "add"),
TestSuite("testsuite.core.add.Add", "add"),
TestSuite("testsuite.core.virtualdispatch.VirtualDispatch", "virtualDispatch"),
TestSuite("testsuite.core.interfacecall.InterfaceCall", "interfaceCall"),
TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"),
TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono")
)
Expand Down
31 changes: 31 additions & 0 deletions test-suite/src/main/scala/testsuite/core/InterfaceCall.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package testsuite.core.interfacecall

import scala.scalajs.js.annotation._

object InterfaceCall {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("interfaceCall")
def test(): Boolean = {
val c = new Concrete()
c.plus(c.zero, 1) == 1 && c.minus(1, c.zero) == 1
}

class Concrete extends AddSub with Zero {
override def zero: Int = 0
}

trait Adder {
def plus(a: Int, b: Int) = a + b
}

trait Sub {
def minus(a: Int, b: Int): Int = a - b
}

trait AddSub extends Adder with Sub

trait Zero {
def zero: Int
}
}
54 changes: 54 additions & 0 deletions test-suite/src/main/scala/testsuite/core/VirtualDispatch.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package testsuite.core.virtualdispatch

import scala.scalajs.js.annotation._

object VirtualDispatch {
def main(): Unit = { val _ = test() }

@JSExportTopLevel("virtualDispatch")
def test(): Boolean = {
val a = new A
val b = new B

testA(a) &&
testB(a, isInstanceOfA = true) &&
testB(b, isInstanceOfA = false) &&
testC(a, isInstanceOfA = true) &&
testC(b, isInstanceOfA = false)
}

def testA(a: A): Boolean = {
a.a == 2 && a.impl == 2 && a.b == 1 && a.c == 1
}

def testB(b: B, isInstanceOfA: Boolean): Boolean = {
if (isInstanceOfA) {
b.b == 1 && b.c == 1 && b.impl == 2
} else {
b.b == 1 && b.c == 1 && b.impl == 0
}
}

def testC(c: C, isInstanceOfA: Boolean): Boolean = {
if (isInstanceOfA) {
c.c == 1 && c.impl == 2
} else {
c.c == 1 && c.impl == 0
}
}

class A extends B {
def a: Int = 2
override def impl = 2
}

class B extends C {
def b: Int = 1
override def c: Int = 1
}

abstract class C {
def c: Int
def impl: Int = 0
}
}
23 changes: 22 additions & 1 deletion wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ object Preprocessor {

private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
clazz.kind match {
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass =>
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface |
ClassKind.HijackedClass =>
collectMethods(clazz)
case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass |
ClassKind.AbstractJSType | ClassKind.NativeJSClass =>
Expand Down Expand Up @@ -61,6 +62,26 @@ object Preprocessor {
)
}

/** Collect WasmFunctionInfo based on the abstract method call
*
* ```
* class A extends B:
* def a = 1
*
* class B extends C:
* def b: Int = 1
* override def c: Int = 1
*
* abstract class C:
* def c: Int
* ```
*
* why we need this? - The problem is that the frontend linker gets rid of abstract method
* entirely.
*
* It keeps B.c because it's concrete and used. But because `C.c` isn't there at all anymore, if
* we have val `x: C` and we call `x.c`, we don't find the method at all.
*/
private def collectAbstractMethodCalls(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
object traverser extends Traversers.Traverser {
import IRTrees._
Expand Down
155 changes: 84 additions & 71 deletions wasm/src/main/scala/ir2wasm/WasmBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,40 +42,97 @@ class WasmBuilder {
}
}

/** @return
* Optionally returns the generated struct type for this class. If the given LinkedClass is an
* abstract class, returns None
*/
private def transformClassCommon(
clazz: LinkedClass
)(implicit ctx: WasmContext): WasmStructType = {
val (vtableType, vtableName) = genVTable(clazz)
// gen functions
clazz.methods.foreach { method =>
genFunction(clazz, method)
}
val className = clazz.name.name

// generate vtable type, this should be done for both abstract and concrete classes
val vtable = ctx.calculateVtableType(className)
val vtableType = genVTableType(clazz, vtable.functions)
ctx.addGCType(vtableType)

val isAbstractClass = {
// If number of declared functions doesn't match number of defined functions, it must be a AbstractClass
// TODO: better way to check if it's abstract class
val definedFunctions = ctx.calculateGlobalVTable(className)
val declaredFunctions = vtable.functions
declaredFunctions.length != definedFunctions.length
}

// we should't generate global vtable for abstract class because
// - Can't generate Global vtable because we can't fill the slot for abstract methods
// - We won't access vtable for abstract classes since we can't instantiate abstract classes, there's no point generating
//
// However, I couldn't find a way to test if the LinkedClass is abstract
// "clazz.methods.exists(m => m.body.isEmpty)" doesn't work because abstract methods are removed at linker optimization
// the WasmFunctionInfo of the abstract methods will be added specially in Preprocessor
val (gVtable, gItable) = if (!isAbstractClass) {
// Generate global vtable
val functions = ctx.calculateGlobalVTable(className)
val vtableInit = functions.map { method =>
WasmInstr.REF_FUNC(method.name)
} :+ WasmInstr.STRUCT_NEW(vtableType.name)
val vtableName = Names.WasmGlobalName.WasmGlobalVTableName(clazz.name.name)
val globalVTable =
WasmGlobal(
vtableName,
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
WasmExpr(vtableInit),
isMutable = false
)
ctx.addGlobal(globalVTable)

// Generate class itable
val globalClassITable = calculateClassITable(clazz)
globalClassITable.foreach(ctx.addGlobal)

(Some(globalVTable), globalClassITable)
} else (None, None)

// Declare the strcut type for the class
genStructNewDefault(clazz, gVtable, gItable)
val vtableField = WasmStructField(
Names.WasmFieldName.vtable,
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
isMutable = false
)
calculateClassITable(clazz) match {
case None =>
genStructNewDefault(clazz, vtableName, None)
case Some(globalITable) =>
ctx.addGlobal(globalITable)
genStructNewDefault(clazz, vtableName, Some(globalITable))
}

// type definition
val fields = clazz.fields.map(transformField)
val structType = WasmStructType(
Names.WasmTypeName.WasmStructTypeName(clazz.name.name),
vtableField +: WasmStructField.itables +: fields,
clazz.superClass.map(s => Names.WasmTypeName.WasmStructTypeName(s.name))
)
ctx.addGCType(structType)

// implementation of methods
clazz.methods.foreach { method =>
genFunction(clazz, method)
}

structType
}

private def genVTableType(clazz: LinkedClass, functions: List[WasmFunctionInfo])(implicit
ctx: WasmContext
): WasmStructType = {
val vtableFields =
functions.map { method =>
WasmStructField(
Names.WasmFieldName(method.name),
WasmRefNullType(WasmHeapType.Func(method.toWasmFunctionType().name)),
isMutable = false
)
}
WasmStructType(
Names.WasmTypeName.WasmVTableTypeName(clazz.name.name),
vtableFields,
clazz.superClass.map(s => Names.WasmTypeName.WasmVTableTypeName(s.name))
)
}

private def genLoadModuleFunc(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
import WasmImmediate._
assert(clazz.kind == ClassKind.ModuleClass)
Expand Down Expand Up @@ -118,10 +175,18 @@ class WasmBuilder {

private def genStructNewDefault(
clazz: LinkedClass,
vtable: WasmGlobalName.WasmGlobalVTableName,
vtable: Option[WasmGlobal],
itable: Option[WasmGlobal]
)(implicit ctx: WasmContext): Unit = {
val getVTable = GLOBAL_GET(WasmImmediate.GlobalIdx(vtable))
val getVTable = vtable match {
case None =>
REF_NULL(
WasmImmediate.HeapType(
WasmHeapType.Type(WasmTypeName.WasmVTableTypeName(clazz.name.name))
)
)
case Some(v) => GLOBAL_GET(WasmImmediate.GlobalIdx(v.name))
}
val getITable = itable match {
case None => REF_NULL(WasmImmediate.HeapType(WasmHeapType.Type(WasmArrayType.itables.name)))
case Some(i) => GLOBAL_GET(WasmImmediate.GlobalIdx(i.name))
Expand Down Expand Up @@ -156,21 +221,7 @@ class WasmBuilder {
)(implicit ctx: ReadOnlyWasmContext): Option[WasmGlobal] = {
val classItables = ctx.calculateClassItables(clazz.name.name)
if (!classItables.isEmpty) {
// val classITableTypeName = WasmTypeName.WasmITableTypeName(clazz.name.name)
// val classITableType = WasmStructType(
// classITableTypeName,
// interfaceInfos.map { info =>
// val itableTypeName = WasmTypeName.WasmITableTypeName(info.name)
// WasmStructField(
// Names.WasmFieldName(itableTypeName),
// WasmRefType(WasmHeapType.Type(itableTypeName)),
// isMutable = false
// )
// },
// None
// )

val vtable = ctx.calculateVtable(clazz.name.name)
val vtable = ctx.calculateVtableType(clazz.name.name)

val itablesInit: List[WasmInstr] = classItables.itables.flatMap { iface =>
iface.methods.map { method =>
Expand All @@ -194,44 +245,6 @@ class WasmBuilder {
} else None
}

private def genVTable(
clazz: LinkedClass
)(implicit ctx: WasmContext): (WasmStructType, WasmGlobalName.WasmGlobalVTableName) = {
val className = clazz.name.name
def genVTableType(vtable: WasmVTable): WasmStructType = {
val vtableFields =
vtable.functions.map { method =>
WasmStructField(
Names.WasmFieldName(method.name),
WasmRefNullType(WasmHeapType.Func(method.toWasmFunctionType().name)),
isMutable = false
)
}
WasmStructType(
Names.WasmTypeName.WasmVTableTypeName.fromIR(clazz.name.name),
vtableFields,
clazz.superClass.map(s => Names.WasmTypeName.WasmVTableTypeName.fromIR(s.name))
)
}

val vtableName = Names.WasmGlobalName.WasmGlobalVTableName(clazz.name.name)

val vtable = ctx.calculateVtable(className)
val vtableType = genVTableType(vtable)
ctx.addGCType(vtableType)

val globalVTable =
WasmGlobal(
vtableName,
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
WasmExpr(vtable.toVTableEntries(vtableType.name)),
isMutable = false
)
ctx.addGlobal(globalVTable)

(vtableType, vtableName)
}

private def transformClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
assert(clazz.kind == ClassKind.Class)
transformClassCommon(clazz)
Expand Down
Loading

0 comments on commit e956e49

Please sign in to comment.