From bee8fde47f21f08edb7daeb044588a0d64e35b2c Mon Sep 17 00:00:00 2001 From: Rikito Taniguchi Date: Thu, 21 Mar 2024 15:53:51 +0900 Subject: [PATCH 1/3] Fix #27: Implement isInstanceOf[interface] https://github.com/tanishiking/scala-wasm/issues/27 Current implementation is temporal and we should switch to the efficient type inclusion test implementation in future. This impl test if the given expression is instance of the given interface by walking through the itables of the Object, which takes O(N) (where N = number of interfaces the expr implements) Also, this procedure can't be extracted as a helper function in Wasm because the immediate argument of `ref.test` is specific to the `testClassName`, which result in a number of instructions everywhere at `isInstanceOf[interface]`. see: https://github.com/tanishiking/scala-wasm/issues/27#issuecomment-2008252049 --- cli/src/main/scala/TestSuites.scala | 1 + .../testsuite/core/IsInstanceOfTest.scala | 59 +++++++++++++ .../scala/ir2wasm/WasmExpressionBuilder.scala | 87 ++++++++++++++++++- 3 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index 81b2e55b..e8f79bdf 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -9,6 +9,7 @@ object TestSuites { TestSuite("testsuite.core.VirtualDispatch"), TestSuite("testsuite.core.InterfaceCall"), TestSuite("testsuite.core.AsInstanceOfTest"), + TestSuite("testsuite.core.IsInstanceOfTest"), TestSuite("testsuite.core.ClassOfTest"), TestSuite("testsuite.core.ClosureTest"), TestSuite("testsuite.core.FieldsTest"), diff --git a/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala new file mode 100644 index 00000000..82fa08a1 --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala @@ -0,0 +1,59 @@ +package testsuite.core + +import testsuite.Assert.ok + +object IsInstanceOfTest { + def main(): Unit = { + ok(testInheritance()) + ok(testMixinAll(new Child())) + ok(testMixinAll(new Parent())) + ok(testMixinAll(new Base {})) + ok(testMixin()) + ok(testPrimitiveIsInstanceOfBase(5)) + ok(testPrimitiveIsInstanceOfBase("foo")) + + ok(!testInt("foo")) + ok(!testInt(2147483648L)) + ok(testInt(3)) + ok(!testInt(new Child())) + ok(testString("foo")) + ok(!testString(new Child())) + } + + private def testInheritance(): Boolean = { + val child = new Child() + val parent = new Parent() + child.isInstanceOf[Parent] && child.isInstanceOf[Child] && + parent.isInstanceOf[Parent] && !parent.isInstanceOf[Child] + } + + private def testMixinAll(o: Base): Boolean = { + o.isInstanceOf[Base] && o.isInstanceOf[Base1] & o.isInstanceOf[Base2] + } + + private def testMixin(): Boolean = { + val base1 = new Base1 {} + val base2 = new Base2 {} + base1.isInstanceOf[Base1] && + !base1.isInstanceOf[Base2] && + !base1.isInstanceOf[Base] && + !base2.isInstanceOf[Base1] && + base2.isInstanceOf[Base2] && + !base2.isInstanceOf[Base] + } + + private def testPrimitiveIsInstanceOfBase(p: Any): Boolean = + !p.isInstanceOf[Base] + + private def testInt(e: Any): Boolean = e.isInstanceOf[Int] + private def testString(e: Any): Boolean = e.isInstanceOf[String] +} + +class Parent extends Base { + def foo(): Int = 5 +} +class Child extends Parent + +trait Base1 +trait Base2 +trait Base extends Base1 with Base2 diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 84f55ad8..46c923dc 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -1087,10 +1087,91 @@ private class WasmExpressionBuilder private ( genIsPrimType(primType) case None => val info = ctx.getClassInfo(testClassName) + + // TODO: Efficient type inclusion test + // Current implementation walk through the itables of the Object which takes O(N) + // (where N = number of interfaces the expr implements) + // Also, this procedure can't be extracted as a helper function in Wasm because the + // immediate argument of `ref.test` is specific to the `testClassName`, + // which result in a number of instructions everywhere at `isInstanceOf[interface]`. + // See: https://github.com/tanishiking/scala-wasm/issues/27#issuecomment-2008252049 if (info.isInterface) { - // TODO: run-time type test for interface - println(tree) - ??? + val expr = fctx.addLocal( + fctx.genSyntheticLocalName(), + TypeTransformer.transformType(IRTypes.AnyType)(ctx) + ) + val found = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val cnt = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val len = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val itables = fctx.addLocal( + fctx.genSyntheticLocalName(), + Types.WasmRefNullType(Types.WasmHeapType.Type(WasmArrayType.itables.name)) + ) + + instrs += LOCAL_SET(expr) + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(found) + + fctx.block() { testFail => + // if expr is not an instance of Object, return false + instrs += LOCAL_GET(expr) + instrs += REF_TEST(HeapType(Types.WasmHeapType.ObjectType)) + instrs += I32_CONST(I32(1)) + instrs += I32_XOR + instrs += BR_IF(testFail) + + // if the itables is null (no interfaces are implemented) + instrs += LOCAL_GET(expr) + instrs += REF_CAST(HeapType(Types.WasmHeapType.ObjectType)) + instrs += STRUCT_GET(TypeIdx(Types.WasmHeapType.ObjectType.typ), StructFieldIdx(1)) + instrs += LOCAL_TEE(itables) + instrs += REF_IS_NULL + instrs += BR_IF(testFail) + + // found := 0 + // len := length(itables) + // loop $loopLabel { + // if (cnt < len) { + // if (itables(cnt) is instance of testClassName's itable) { + // found := 1 + // } else { + // cnt := cnt + 1 + // br $loopLabel + // } + // } + // return found + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(cnt) + // len := length(itables) + instrs += LOCAL_GET(itables) + instrs += ARRAY_LEN + instrs += LOCAL_SET(len) + + fctx.loop() { loopLabel => + instrs += LOCAL_GET(cnt) + instrs += LOCAL_GET(len) + instrs += I32_LT_U + fctx.ifThen() { + instrs += LOCAL_GET(itables) + instrs += LOCAL_GET(cnt) + instrs += ARRAY_GET(TypeIdx(WasmArrayType.itables.name)) + instrs += REF_TEST( + HeapType(Types.WasmHeapType.Type(WasmITableTypeName(testClassName))) + ) + fctx.ifThenElse() { + instrs += I32_CONST(I32(1)) + instrs += LOCAL_SET(found) + } { + instrs += LOCAL_GET(cnt) + instrs += I32_CONST(I32(1)) + instrs += I32_ADD + instrs += LOCAL_SET(cnt) + instrs += BR(loopLabel) + } + } + } + } + instrs += LOCAL_GET(found) } else { instrs += REF_TEST( HeapType(Types.WasmHeapType.Type(WasmStructTypeName(testClassName))) From 56f36852168a63b92c57dfcfae9feb05d1513dec Mon Sep 17 00:00:00 2001 From: Rikito Taniguchi Date: Thu, 21 Mar 2024 21:12:57 +0900 Subject: [PATCH 2/3] Generate instanceTest for each interfaces (that have instanceTests) --- wasm/src/main/scala/Compiler.scala | 3 +- .../main/scala/ir2wasm/HelperFunctions.scala | 96 +++++++++++++++++++ .../src/main/scala/ir2wasm/Preprocessor.scala | 5 +- .../scala/ir2wasm/WasmExpressionBuilder.scala | 89 +---------------- wasm/src/main/scala/wasm4s/Names.scala | 5 +- 5 files changed, 106 insertions(+), 92 deletions(-) diff --git a/wasm/src/main/scala/Compiler.scala b/wasm/src/main/scala/Compiler.scala index 8d54656b..eeec7f48 100644 --- a/wasm/src/main/scala/Compiler.scala +++ b/wasm/src/main/scala/Compiler.scala @@ -72,10 +72,9 @@ object Compiler { else a.className.compareTo(b.className) < 0 } - sortedClasses.foreach(showLinkedClass(_)) + // sortedClasses.foreach(showLinkedClass(_)) Preprocessor.preprocess(sortedClasses)(context) - println("preprocessed") HelperFunctions.genGlobalHelpers() builder.genPrimitiveTypeDataGlobals() sortedClasses.foreach { clazz => diff --git a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala index 9202db4d..81f21bbe 100644 --- a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala +++ b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala @@ -3,6 +3,8 @@ package wasm.ir2wasm import org.scalajs.ir.{Trees => IRTrees} import org.scalajs.ir.{Types => IRTypes} import org.scalajs.ir.{Names => IRNames} +import org.scalajs.linker.standard.LinkedClass +import org.scalajs.ir.ClassKind import wasm.wasm4s._ import wasm.wasm4s.WasmContext._ @@ -685,4 +687,98 @@ object HelperFunctions { fctx.buildAndAddToContext() } + /** Generate type inclusion test for interfaces `isInstanceOf[]` will be compiled to + * the CALL of the generated this function. + * + * TODO: Efficient type inclusion test Current implementation walk through the itables of the + * Object which takes O(N) (where N = number of interfaces the expr implements) See: + * https://github.com/tanishiking/scala-wasm/issues/27#issuecomment-2008252049 + */ + def genInstanceTest(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { + import WasmImmediate._ + assert(clazz.kind == ClassKind.Interface) + + val fctx = WasmFunctionContext( + Names.WasmFunctionName.instanceTest(clazz.name.name), + List("expr" -> WasmAnyRef), + List(WasmInt32) + ) + val List(exprParam) = fctx.paramIndices + + import fctx.instrs + + val found = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val cnt = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val len = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) + val itables = fctx.addLocal( + fctx.genSyntheticLocalName(), + Types.WasmRefNullType(Types.WasmHeapType.Type(WasmArrayType.itables.name)) + ) + + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(found) + + fctx.block() { testFail => + // if expr is not an instance of Object, return false + instrs += LOCAL_GET(exprParam) + instrs += REF_TEST(HeapType(Types.WasmHeapType.ObjectType)) + instrs += I32_CONST(I32(1)) + instrs += I32_XOR + instrs += BR_IF(testFail) + + // if the itables is null (no interfaces are implemented) + instrs += LOCAL_GET(exprParam) + instrs += REF_CAST(HeapType(Types.WasmHeapType.ObjectType)) + instrs += STRUCT_GET(TypeIdx(Types.WasmHeapType.ObjectType.typ), StructFieldIdx(1)) + instrs += LOCAL_TEE(itables) + instrs += REF_IS_NULL + instrs += BR_IF(testFail) + + // found := 0 + // len := length(itables) + // loop $loopLabel { + // if (cnt < len) { + // if (itables(cnt) is instance of testClassName's itable) { + // found := 1 + // } else { + // cnt := cnt + 1 + // br $loopLabel + // } + // } + // return found + instrs += I32_CONST(I32(0)) + instrs += LOCAL_SET(cnt) + // len := length(itables) + instrs += LOCAL_GET(itables) + instrs += ARRAY_LEN + instrs += LOCAL_SET(len) + + fctx.loop() { loopLabel => + instrs += LOCAL_GET(cnt) + instrs += LOCAL_GET(len) + instrs += I32_LT_U + fctx.ifThen() { + instrs += LOCAL_GET(itables) + instrs += LOCAL_GET(cnt) + instrs += ARRAY_GET(TypeIdx(WasmArrayType.itables.name)) + instrs += REF_TEST( + HeapType(Types.WasmHeapType.Type(WasmTypeName.WasmITableTypeName(clazz.name.name))) + ) + fctx.ifThenElse() { + instrs += I32_CONST(I32(1)) + instrs += LOCAL_SET(found) + } { + instrs += LOCAL_GET(cnt) + instrs += I32_CONST(I32(1)) + instrs += I32_ADD + instrs += LOCAL_SET(cnt) + instrs += BR(loopLabel) + } + } + } + } + instrs += LOCAL_GET(found) + fctx.buildAndAddToContext() + } + } diff --git a/wasm/src/main/scala/ir2wasm/Preprocessor.scala b/wasm/src/main/scala/ir2wasm/Preprocessor.scala index 4b8b6751..f1ef8a3e 100644 --- a/wasm/src/main/scala/ir2wasm/Preprocessor.scala +++ b/wasm/src/main/scala/ir2wasm/Preprocessor.scala @@ -17,8 +17,11 @@ object Preprocessor { for (clazz <- classes) preprocess(clazz) - for (clazz <- classes) + for (clazz <- classes) { collectAbstractMethodCalls(clazz) + if (!clazz.hasDirectInstances && clazz.hasInstanceTests) + HelperFunctions.genInstanceTest(clazz) + } } private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 46c923dc..42c06ec5 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -1088,95 +1088,12 @@ private class WasmExpressionBuilder private ( case None => val info = ctx.getClassInfo(testClassName) - // TODO: Efficient type inclusion test - // Current implementation walk through the itables of the Object which takes O(N) - // (where N = number of interfaces the expr implements) - // Also, this procedure can't be extracted as a helper function in Wasm because the - // immediate argument of `ref.test` is specific to the `testClassName`, - // which result in a number of instructions everywhere at `isInstanceOf[interface]`. - // See: https://github.com/tanishiking/scala-wasm/issues/27#issuecomment-2008252049 - if (info.isInterface) { - val expr = fctx.addLocal( - fctx.genSyntheticLocalName(), - TypeTransformer.transformType(IRTypes.AnyType)(ctx) - ) - val found = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) - val cnt = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) - val len = fctx.addLocal(fctx.genSyntheticLocalName(), Types.WasmInt32) - val itables = fctx.addLocal( - fctx.genSyntheticLocalName(), - Types.WasmRefNullType(Types.WasmHeapType.Type(WasmArrayType.itables.name)) - ) - - instrs += LOCAL_SET(expr) - instrs += I32_CONST(I32(0)) - instrs += LOCAL_SET(found) - - fctx.block() { testFail => - // if expr is not an instance of Object, return false - instrs += LOCAL_GET(expr) - instrs += REF_TEST(HeapType(Types.WasmHeapType.ObjectType)) - instrs += I32_CONST(I32(1)) - instrs += I32_XOR - instrs += BR_IF(testFail) - - // if the itables is null (no interfaces are implemented) - instrs += LOCAL_GET(expr) - instrs += REF_CAST(HeapType(Types.WasmHeapType.ObjectType)) - instrs += STRUCT_GET(TypeIdx(Types.WasmHeapType.ObjectType.typ), StructFieldIdx(1)) - instrs += LOCAL_TEE(itables) - instrs += REF_IS_NULL - instrs += BR_IF(testFail) - - // found := 0 - // len := length(itables) - // loop $loopLabel { - // if (cnt < len) { - // if (itables(cnt) is instance of testClassName's itable) { - // found := 1 - // } else { - // cnt := cnt + 1 - // br $loopLabel - // } - // } - // return found - instrs += I32_CONST(I32(0)) - instrs += LOCAL_SET(cnt) - // len := length(itables) - instrs += LOCAL_GET(itables) - instrs += ARRAY_LEN - instrs += LOCAL_SET(len) - - fctx.loop() { loopLabel => - instrs += LOCAL_GET(cnt) - instrs += LOCAL_GET(len) - instrs += I32_LT_U - fctx.ifThen() { - instrs += LOCAL_GET(itables) - instrs += LOCAL_GET(cnt) - instrs += ARRAY_GET(TypeIdx(WasmArrayType.itables.name)) - instrs += REF_TEST( - HeapType(Types.WasmHeapType.Type(WasmITableTypeName(testClassName))) - ) - fctx.ifThenElse() { - instrs += I32_CONST(I32(1)) - instrs += LOCAL_SET(found) - } { - instrs += LOCAL_GET(cnt) - instrs += I32_CONST(I32(1)) - instrs += I32_ADD - instrs += LOCAL_SET(cnt) - instrs += BR(loopLabel) - } - } - } - } - instrs += LOCAL_GET(found) - } else { + if (info.isInterface) + instrs += CALL(FuncIdx(WasmFunctionName.instanceTest(testClassName))) + else instrs += REF_TEST( HeapType(Types.WasmHeapType.Type(WasmStructTypeName(testClassName))) ) - } } case IRTypes.ArrayType(_) => diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala index 99236756..0b834bd7 100644 --- a/wasm/src/main/scala/wasm4s/Names.scala +++ b/wasm/src/main/scala/wasm4s/Names.scala @@ -131,13 +131,12 @@ object Names { def forExport(exportedName: String): WasmFunctionName = new WasmFunctionName("export", exportedName) - // Adding prefix __ to avoid name clashes with user code. - // It should be safe not to add prefix to the method name - // since loadModule is a static method and it's not registered in the vtable. def loadModule(clazz: IRNames.ClassName): WasmFunctionName = new WasmFunctionName("loadModule", clazz.nameString) def newDefault(clazz: IRNames.ClassName): WasmFunctionName = new WasmFunctionName("new", clazz.nameString) + def instanceTest(clazz: IRNames.ClassName): WasmFunctionName = + new WasmFunctionName("instanceTest", clazz.nameString) val start = new WasmFunctionName("start", "start") From ba4858b6f672acf0463c11a02f11adc81a39df28 Mon Sep 17 00:00:00 2001 From: Rikito Taniguchi Date: Thu, 21 Mar 2024 21:27:01 +0900 Subject: [PATCH 3/3] br_on_cast_fail instead of REF_TEST and BR_IF / fix test --- .../scala/testsuite/core/IsInstanceOfTest.scala | 2 +- wasm/src/main/scala/ir2wasm/HelperFunctions.scala | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala index 82fa08a1..0c0f86af 100644 --- a/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala +++ b/test-suite/src/main/scala/testsuite/core/IsInstanceOfTest.scala @@ -28,7 +28,7 @@ object IsInstanceOfTest { } private def testMixinAll(o: Base): Boolean = { - o.isInstanceOf[Base] && o.isInstanceOf[Base1] & o.isInstanceOf[Base2] + o.isInstanceOf[Base] && o.isInstanceOf[Base1] && o.isInstanceOf[Base2] } private def testMixin(): Boolean = { diff --git a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala index 81f21bbe..cd064bab 100644 --- a/wasm/src/main/scala/ir2wasm/HelperFunctions.scala +++ b/wasm/src/main/scala/ir2wasm/HelperFunctions.scala @@ -718,13 +718,15 @@ object HelperFunctions { instrs += I32_CONST(I32(0)) instrs += LOCAL_SET(found) - fctx.block() { testFail => + fctx.block(WasmRefNullType(WasmHeapType.Simple.Any)) { testFail => // if expr is not an instance of Object, return false instrs += LOCAL_GET(exprParam) - instrs += REF_TEST(HeapType(Types.WasmHeapType.ObjectType)) - instrs += I32_CONST(I32(1)) - instrs += I32_XOR - instrs += BR_IF(testFail) + instrs += BR_ON_CAST_FAIL( + CastFlags(true, false), + testFail, + HeapType(Types.WasmHeapType.Simple.Any), + HeapType(Types.WasmHeapType.ObjectType) + ) // if the itables is null (no interfaces are implemented) instrs += LOCAL_GET(exprParam) @@ -777,6 +779,7 @@ object HelperFunctions { } } } + instrs += DROP instrs += LOCAL_GET(found) fctx.buildAndAddToContext() }