diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index 81b2e55b..e8f79bdf 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -9,6 +9,7 @@ object TestSuites { TestSuite("testsuite.core.VirtualDispatch"), TestSuite("testsuite.core.InterfaceCall"), TestSuite("testsuite.core.AsInstanceOfTest"), + TestSuite("testsuite.core.IsInstanceOfTest"), TestSuite("testsuite.core.ClassOfTest"), TestSuite("testsuite.core.ClosureTest"), TestSuite("testsuite.core.FieldsTest"), diff --git a/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala new file mode 100644 index 00000000..0c0f86af --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala @@ -0,0 +1,59 @@ +package testsuite.core + +import testsuite.Assert.ok + +object IsInstanceOfTest { + def main(): Unit = { + ok(testInheritance()) + ok(testMixinAll(new Child())) + ok(testMixinAll(new Parent())) + ok(testMixinAll(new Base {})) + ok(testMixin()) + ok(testPrimitiveIsInstanceOfBase(5)) + ok(testPrimitiveIsInstanceOfBase("foo")) + + ok(!testInt("foo")) + ok(!testInt(2147483648L)) + ok(testInt(3)) + ok(!testInt(new Child())) + ok(testString("foo")) + ok(!testString(new Child())) + } + + private def testInheritance(): Boolean = { + val child = new Child() + val parent = new Parent() + child.isInstanceOf[Parent] && child.isInstanceOf[Child] && + parent.isInstanceOf[Parent] && !parent.isInstanceOf[Child] + } + + private def testMixinAll(o: Base): Boolean = { + o.isInstanceOf[Base] && o.isInstanceOf[Base1] && o.isInstanceOf[Base2] + } + + private def testMixin(): Boolean = { + val base1 = new Base1 {} + val base2 = new Base2 {} + base1.isInstanceOf[Base1] && + !base1.isInstanceOf[Base2] && + !base1.isInstanceOf[Base] && + !base2.isInstanceOf[Base1] && + base2.isInstanceOf[Base2] && + !base2.isInstanceOf[Base] + } + + private def testPrimitiveIsInstanceOfBase(p: Any): Boolean = + !p.isInstanceOf[Base] + + private def testInt(e: Any): Boolean = e.isInstanceOf[Int] + private def testString(e: Any): Boolean = e.isInstanceOf[String] +} + +class Parent extends Base { + def foo(): Int = 5 +} +class Child extends Parent + +trait Base1 +trait Base2 +trait Base extends Base1 with Base2 diff --git a/wasm/src/main/scala/Compiler.scala b/wasm/src/main/scala/Compiler.scala index 8d54656b..eeec7f48 100644 --- a/wasm/src/main/scala/Compiler.scala +++ b/wasm/src/main/scala/Compiler.scala @@ -72,10 +72,9 @@ object Compiler { else a.className.compareTo(b.className) < 0 } - sortedClasses.foreach(showLinkedClass(_)) + // sortedClasses.foreach(showLinkedClass(_)) Preprocessor.preprocess(sortedClasses)(context) - println("preprocessed") HelperFunctions.genGlobalHelpers() builder.genPrimitiveTypeDataGlobals() sortedClasses.foreach { clazz => diff --git a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala index 9202db4d..cd064bab 100644 --- a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala +++ b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala @@ -3,6 +3,8 @@ package wasm.ir2wasm import org.scalajs.ir.{Trees => IRTrees} import org.scalajs.ir.{Types => IRTypes} import org.scalajs.ir.{Names => IRNames} +import org.scalajs.linker.standard.LinkedClass +import org.scalajs.ir.ClassKind import wasm.wasm4s._ import wasm.wasm4s.WasmContext._ @@ -685,4 +687,101 @@ object HelperFunctions { fctx.buildAndAddToContext() } + /** Generate type inclusion test for interfaces `isInstanceOf[]` will be compiled to + * the CALL of the generated this function. + * + * TODO: Efficient type inclusion test Current implementation walk through the itables of the + * Object which takes O(N) (where N = number of interfaces the expr implements) See: + * https://github.com/tanishiking/scala-wasm/issues/27#issuecomment-2008252049 + */ + def genInstanceTest(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { + import WasmImmediate._ + assert(clazz.kind == ClassKind.Interface) + + val fctx = WasmFunctionContext( + Names.WasmFunctionName.instanceTest(clazz.name.name), + List("expr" -> WasmAnyRef), + List(WasmInt32) + ) + val List(exprParam) = fctx.paramIndices + + import fctx.instrs + + val found = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val cnt = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val len = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val itables = fctx.addLocal( + fctx.genSyntheticLocalName(), + Types.WasmRefNullType(Types.WasmHeapType.Type(WasmArrayType.itables.name)) + ) + + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(found) + + fctx.block(WasmRefNullType(WasmHeapType.Simple.Any)) { testFail => + // if expr is not an instance of Object, return false + instrs += LOCAL_GET(exprParam) + instrs += BR_ON_CAST_FAIL( + CastFlags(true, false), + testFail, + HeapType(Types.WasmHeapType.Simple.Any), + HeapType(Types.WasmHeapType.ObjectType) + ) + + // if the itables is null (no interfaces are implemented) + instrs += LOCAL_GET(exprParam) + instrs += REF_CAST(HeapType(Types.WasmHeapType.ObjectType)) + instrs += STRUCT_GET(TypeIdx(Types.WasmHeapType.ObjectType.typ), StructFieldIdx(1)) + instrs += LOCAL_TEE(itables) + instrs += REF_IS_NULL + instrs += BR_IF(testFail) + + // found := 0 + // len := length(itables) + // loop $loopLabel { + // if (cnt < len) { + // if (itables(cnt) is instance of testClassName's itable) { + // found := 1 + // } else { + // cnt := cnt + 1 + // br $loopLabel + // } + // } + // return found + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(cnt) + // len := length(itables) + instrs += LOCAL_GET(itables) + instrs += ARRAY_LEN + instrs += LOCAL_SET(len) + + fctx.loop() { loopLabel => + instrs += LOCAL_GET(cnt) + instrs += LOCAL_GET(len) + instrs += I32_LT_U + fctx.ifThen() { + instrs += LOCAL_GET(itables) + instrs += LOCAL_GET(cnt) + instrs += ARRAY_GET(TypeIdx(WasmArrayType.itables.name)) + instrs += REF_TEST( + HeapType(Types.WasmHeapType.Type(WasmTypeName.WasmITableTypeName(clazz.name.name))) + ) + fctx.ifThenElse() { + instrs += I32_CONST(I32(1)) + instrs += LOCAL_SET(found) + } { + instrs += LOCAL_GET(cnt) + instrs += I32_CONST(I32(1)) + instrs += I32_ADD + instrs += LOCAL_SET(cnt) + instrs += BR(loopLabel) + } + } + } + } + instrs += DROP + instrs += LOCAL_GET(found) + fctx.buildAndAddToContext() + } + } diff --git a/wasm/src/main/scala/ir2wasm/Preprocessor.scala b/wasm/src/main/scala/ir2wasm/Preprocessor.scala index 4b8b6751..f1ef8a3e 100644 --- a/wasm/src/main/scala/ir2wasm/Preprocessor.scala +++ b/wasm/src/main/scala/ir2wasm/Preprocessor.scala @@ -17,8 +17,11 @@ object Preprocessor { for (clazz <- classes) preprocess(clazz) - for (clazz <- classes) + for (clazz <- classes) { collectAbstractMethodCalls(clazz) + if (!clazz.hasDirectInstances && clazz.hasInstanceTests) + HelperFunctions.genInstanceTest(clazz) + } } private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 84f55ad8..42c06ec5 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -1087,15 +1087,13 @@ private class WasmExpressionBuilder private ( genIsPrimType(primType) case None => val info = ctx.getClassInfo(testClassName) - if (info.isInterface) { - // TODO: run-time type test for interface - println(tree) - ??? - } else { + + if (info.isInterface) + instrs += CALL(FuncIdx(WasmFunctionName.instanceTest(testClassName))) + else instrs += REF_TEST( HeapType(Types.WasmHeapType.Type(WasmStructTypeName(testClassName))) ) - } } case IRTypes.ArrayType(_) => diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala index 99236756..0b834bd7 100644 --- a/wasm/src/main/scala/wasm4s/Names.scala +++ b/wasm/src/main/scala/wasm4s/Names.scala @@ -131,13 +131,12 @@ object Names { def forExport(exportedName: String): WasmFunctionName = new WasmFunctionName("export", exportedName) - // Adding prefix __ to avoid name clashes with user code. - // It should be safe not to add prefix to the method name - // since loadModule is a static method and it's not registered in the vtable. def loadModule(clazz: IRNames.ClassName): WasmFunctionName = new WasmFunctionName("loadModule", clazz.nameString) def newDefault(clazz: IRNames.ClassName): WasmFunctionName = new WasmFunctionName("new", clazz.nameString) + def instanceTest(clazz: IRNames.ClassName): WasmFunctionName = + new WasmFunctionName("instanceTest", clazz.nameString) val start = new WasmFunctionName("start", "start")