From 6b2bc770805a2f23c673d56841944d6725f72cfe Mon Sep 17 00:00:00 2001 From: Rikito Taniguchi Date: Thu, 14 Mar 2024 14:25:12 +0900 Subject: [PATCH] WIP: Array --- cli/src/main/scala/TestSuites.scala | 2 +- sample/src/main/scala/Sample.scala | 91 +--------- .../main/scala/testsuite/core/ArrayTest.scala | 27 +++ .../main/scala/ir2wasm/TypeTransformer.scala | 56 +++--- .../scala/ir2wasm/WasmExpressionBuilder.scala | 169 +++++++++++++----- wasm/src/main/scala/wasm4s/Names.scala | 3 +- wasm/src/main/scala/wasm4s/Wasm.scala | 6 +- wasm/src/main/scala/wasm4s/WasmContext.scala | 24 ++- 8 files changed, 208 insertions(+), 170 deletions(-) create mode 100644 test-suite/src/main/scala/testsuite/core/ArrayTest.scala diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index b773ce8e..827421de 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -5,7 +5,7 @@ object TestSuites { val suites = List( TestSuite("testsuite.core.simple.Simple", "simple"), TestSuite("testsuite.core.add.Add", "add"), - TestSuite("testsuite.core.add.Add", "add"), + TestSuite("testsuite.core.array.ArrayTest", "array"), TestSuite("testsuite.core.virtualdispatch.VirtualDispatch", "virtualDispatch"), TestSuite("testsuite.core.interfacecall.InterfaceCall", "interfaceCall"), TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"), diff --git a/sample/src/main/scala/Sample.scala b/sample/src/main/scala/Sample.scala index 1e687e4e..b23ba8bb 100644 --- a/sample/src/main/scala/Sample.scala +++ b/sample/src/main/scala/Sample.scala @@ -4,95 +4,14 @@ import scala.annotation.tailrec import scala.scalajs.js.annotation._ -// -// class Base { -// def sqrt(x: Int) = x * x -// } -// object Main { @JSExportTopLevel("test") def test() = { - val i = 4 - val loopFib = fib(new LoopFib {}, i) - val recFib = fib(new RecFib {}, i) - val tailrecFib = fib(new TailRecFib {}, i) - loopFib == recFib && loopFib == tailrecFib - } - def fib(fib: Fib, n: Int): Int = fib.fib(n) -} + val a = Array(Array(1), Array(2), Array(3)) + a(0) = Array(100) // Assign(ArraySelect(...), ...) + a(0)(0) == 100 // ArraySelect(...) - -trait LoopFib extends Fib { - def fib(n: Int): Int = { - var a = 0 - var b = 1 - var i = 0 - while (i < n) { - val temp = b - b = a + b - a = temp - i += 1 - } - a + val d3 = Array.ofDim(3, 3, 3) + d3(0)(0)(0) == 0 } - } - -trait RecFib extends Fib { - def fib(n: Int): Int = - if (n <= 1) { - n - } else { - fib(n - 1) + fib(n - 2) - } -} - -trait TailRecFib extends Fib { - def fib(n: Int): Int = fibLoop(n, 0, 1) - - @tailrec - final def fibLoop(n: Int, a: Int, b: Int): Int = - if (n == 0) a - else fibLoop(n - 1, b, a + b) -} - -trait Fib { - def fib(n: Int): Int - // = { - // if (n <= 1) { - // n - // } else { - // fib(n - 1) + fib(n - 2) - // } - // } - -} - -// -// -// object Bar { -// def bar(b: Base) = b.base -// } - -// class Base extends Incr { -// override def incr(x: Int) = foo(x) + 1 -// } -// -// trait Incr extends BaseTrait { -// // val one = 1 -// def incr(x: Int): Int -// } -// -// trait BaseTrait { -// def foo(x: Int) = x -// } - -// object Foo { -// def foo = -// Main.ident(1) -// } -// -// class Derived(override val i: Int) extends Base(i) { -// def derived(x: Int) = x * i -// override def base(x: Int): Int = x * i -// } diff --git a/test-suite/src/main/scala/testsuite/core/ArrayTest.scala b/test-suite/src/main/scala/testsuite/core/ArrayTest.scala new file mode 100644 index 00000000..86c812ae --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/ArrayTest.scala @@ -0,0 +1,27 @@ +package testsuite.core.array + +import scala.scalajs.js.annotation._ + +object ArrayTest { + def main(): Unit = { val _ = test() } + @JSExportTopLevel("array") + def test(): Boolean = { + testSimple() && testNested() && testSelect() + } + + def testSimple(): Boolean = { + val a = Array(1, 2, 3) + a.length == 3 + } + + def testNested(): Boolean = { + val a = Array(Array(1, 2), Array(2), Array(3)) + a.length == 3 + } + + def testSelect(): Boolean = { + val a = Array(Array(1), Array(2), Array(3)) + a(0) = Array(100) // Assign(ArraySelect(...), ...) + a(0)(0) == 100 // ArraySelect(...) + } +} \ No newline at end of file diff --git a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala index ff5e19c9..c7c862d2 100644 --- a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala +++ b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala @@ -15,7 +15,7 @@ object TypeTransformer { def transformFunctionType( // clazz: WasmContext.WasmClassInfo, method: WasmContext.WasmFunctionInfo - )(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionType = { + )(implicit ctx: TypeDefinableWasmContext): WasmFunctionType = { // val className = clazz.name val name = method.name val receiverType = makeReceiverType @@ -43,41 +43,35 @@ object TypeTransformer { t match { case IRTypes.AnyType => Types.WasmAnyRef - case tpe @ IRTypes.ArrayType(IRTypes.ArrayTypeRef(elemType, size)) => - // TODO - // val wasmElemTy = - // elemType match { - // case IRTypes.ClassRef(className) => - // // val gcTypeSym = context.gcTypes.reference(Ident(className.nameString)) - // Types.WasmRefType(Types.WasmHeapType.Type(Names.WasmGCTypeName.fromIR(className))) - // case IRTypes.PrimRef(tpe) => - // transform(tpe) - // } - // val field = WasmStructField("TODO", wasmElemTy, isMutable = false) - // val arrayTySym = - // context.gcTypes.define(WasmArrayType(Names.WasmGCTypeName.fromIR(tpe), field)) - // Types.WasmRefType(Types.WasmHeapType.Type(arrayTySym)) - ??? - case clazz @ IRTypes.ClassType(className) => - className match { - case _ => - val info = ctx.getClassInfo(clazz.className) - if (info.isAncestorOfHijackedClass) - Types.WasmAnyRef - else if (info.isInterface) - Types.WasmRefNullType(Types.WasmHeapType.ObjectType) - else - Types.WasmRefNullType( - Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName(className)) - ) - } - case IRTypes.RecordType(fields) => ??? + case tpe: IRTypes.ArrayType => + Types.WasmRefNullType( + Types.WasmHeapType.Type(Names.WasmTypeName.WasmArrayTypeName(tpe)) + ) + case IRTypes.ClassType(className) => transformClassByName(className) + case IRTypes.RecordType(fields) => ??? case IRTypes.StringType | IRTypes.UndefType => Types.WasmRefType.any case p: IRTypes.PrimTypeWithRef => transformPrimType(p) } - def transformPrimType( + private def transformClassByName( + className: IRNames.ClassName + )(implicit ctx: ReadOnlyWasmContext): Types.WasmType = { + className match { + case _ => + val info = ctx.getClassInfo(className) + if (info.isAncestorOfHijackedClass) + Types.WasmAnyRef + else if (info.isInterface) + Types.WasmRefNullType(Types.WasmHeapType.ObjectType) + else + Types.WasmRefNullType( + Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName(className)) + ) + } + } + + private def transformPrimType( t: IRTypes.PrimTypeWithRef ): Types.WasmType = t match { diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 691f7e12..50ac07f8 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -20,7 +20,7 @@ import org.scalajs.ir.Position object WasmExpressionBuilder { def transformBody(tree: IRTrees.Tree, resultType: IRTypes.Type)(implicit - ctx: FunctionTypeWriterWasmContext, + ctx: TypeDefinableWasmContext, fctx: WasmFunctionContext ): WasmExpr = { val builder = new WasmExpressionBuilder(ctx, fctx) @@ -40,8 +40,8 @@ object WasmExpressionBuilder { private object PrimTypeWithBoxUnbox { def unapply(primType: IRTypes.PrimTypeWithRef): Option[IRTypes.PrimTypeWithRef] = { primType match { - case IRTypes.BooleanType | IRTypes.ByteType | IRTypes.ShortType | - IRTypes.IntType | IRTypes.FloatType | IRTypes.DoubleType => + case IRTypes.BooleanType | IRTypes.ByteType | IRTypes.ShortType | IRTypes.IntType | + IRTypes.FloatType | IRTypes.DoubleType => Some(primType) case _ => None @@ -51,7 +51,7 @@ object WasmExpressionBuilder { } private class WasmExpressionBuilder private ( - ctx: FunctionTypeWriterWasmContext, + ctx: TypeDefinableWasmContext, fctx: WasmFunctionContext ) { import WasmExpressionBuilder._ @@ -123,6 +123,11 @@ private class WasmExpressionBuilder private ( case t: IRTrees.While => genWhile(t) case t: IRTrees.Skip => IRTypes.NoType case t: IRTrees.IdentityHashCode => genIdentityHashCode(t) + // array + case t: IRTrees.ArrayLength => genArrayLength(t) + // case t: IRTrees.NewArray => genNewArray(t) + case t: IRTrees.ArraySelect => genArraySelect(t) + case t: IRTrees.ArrayValue => genArrayValue(t) case _ => println(tree) ??? @@ -141,11 +146,9 @@ private class WasmExpressionBuilder private ( // case IRTrees.JSLinkingInfo(pos) => // case IRTrees.Select(tpe) => // case IRTrees.Return(pos) => - // case IRTrees.ArrayLength(pos) => // case IRTrees.While(pos) => // case IRTrees.LoadJSConstructor(pos) => // case IRTrees.JSSuperMethodCall(pos) => - // case IRTrees.NewArray(pos) => // case IRTrees.Match(tpe) => // case IRTrees.Throw(pos) => // case IRTrees.JSNew(pos) => @@ -162,7 +165,6 @@ private class WasmExpressionBuilder private ( // case IRTrees.GetClass(pos) => // case IRTrees.JSImportMeta(pos) => // case IRTrees.JSSuperSelect(pos) => - // case IRTrees.ArraySelect(tpe) => // case IRTrees.JSSelect(pos) => // case IRTrees.LoadJSModule(pos) => // case IRTrees.JSFunctionApply(pos) => @@ -171,7 +173,6 @@ private class WasmExpressionBuilder private ( // case IRTrees.Clone(pos) => // case IRTrees.CreateJSClass(pos) => // case IRTrees.Transient(pos) => - // case IRTrees.ArrayValue(pos) => // case IRTrees.JSDelete(pos) => // case IRTrees.ForIn(pos) => // case IRTrees.JSArrayConstr(pos) => @@ -245,7 +246,18 @@ private class WasmExpressionBuilder private ( genTree(t.rhs, t.lhs.tpe) instrs += STRUCT_SET(TypeIdx(className), idx) - case assign: IRTrees.ArraySelect => ??? // array.set + case sel: IRTrees.ArraySelect => + val typeName = sel.array.tpe match { + case arrTy: IRTypes.ArrayType => WasmArrayTypeName(arrTy) + case _ => + throw new IllegalArgumentException( + s"ArraySelect.array must be an array type, but has type ${sel.array.tpe}" + ) + } + genTreeAuto(sel.array) + genTree(sel.index, IRTypes.IntType) + genTree(t.rhs, t.lhs.tpe) + instrs += ARRAY_SET(TypeIdx(typeName)) case assign: IRTrees.RecordSelect => ??? // struct.set case assign: IRTrees.JSPrivateSelect => ??? case assign: IRTrees.JSSelect => ??? @@ -276,7 +288,9 @@ private class WasmExpressionBuilder private ( // statically resolved call with non-null argument val receiverClassName = IRTypes.PrimTypeToBoxedClass(prim) genApplyStatically( - IRTrees.ApplyStatically(t.flags, t.receiver, receiverClassName, t.method, t.args)(t.tpe)(t.pos) + IRTrees.ApplyStatically(t.flags, t.receiver, receiverClassName, t.method, t.args)(t.tpe)( + t.pos + ) ) case IRTypes.ClassType(className) if IRNames.HijackedClasses.contains(className) => @@ -294,9 +308,9 @@ private class WasmExpressionBuilder private ( implicit val pos: Position = t.pos val receiverClassName = t.receiver.tpe match { - case ClassType(className) => className - case IRTypes.AnyType => IRNames.ObjectClass - case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") + case ClassType(className) => className + case IRTypes.AnyType => IRNames.ObjectClass + case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") } val receiverClassInfo = ctx.getClassInfo(receiverClassName) @@ -388,14 +402,17 @@ private class WasmExpressionBuilder private ( fctx.locals.define(WasmLocal(receiverLocal, Types.WasmRefType.any, isParameter = false)) instrs += LOCAL_SET(LocalIdx(receiverLocal)) - val argsLocals: List[WasmLocalName] = for ((arg, typeRef) <- t.args.zip(t.method.name.paramTypeRefs)) yield { - val typ = ctx.inferTypeFromTypeRef(typeRef) - genTree(arg, typ) - val localName = fctx.genSyntheticLocalName() - fctx.locals.define(WasmLocal(localName, TypeTransformer.transformType(typ)(ctx), isParameter = false)) - instrs += LOCAL_SET(LocalIdx(localName)) - localName - } + val argsLocals: List[WasmLocalName] = + for ((arg, typeRef) <- t.args.zip(t.method.name.paramTypeRefs)) yield { + val typ = ctx.inferTypeFromTypeRef(typeRef) + genTree(arg, typ) + val localName = fctx.genSyntheticLocalName() + fctx.locals.define( + WasmLocal(localName, TypeTransformer.transformType(typ)(ctx), isParameter = false) + ) + instrs += LOCAL_SET(LocalIdx(localName)) + localName + } instrs += LOCAL_GET(LocalIdx(receiverLocal)) argsLocals } @@ -458,21 +475,19 @@ private class WasmExpressionBuilder private ( } /** Generates a vtable- or itable-based dispatch. - * - * Before this code gen, the stack must contain the receiver and the args of - * the target method. In addition, the receiver must be available in the - * local `receiverLocalForDispatch`. The two occurrences of the receiver - * must have the type for dispatch. - * - * After this code gen, the stack contains the result. If the result type is - * `NothingType`, `genTableDispatch` leaves the stack in an arbitrary state. - * It is up to the caller to insert an `unreachable` instruction when - * appropriate. - */ + * + * Before this code gen, the stack must contain the receiver and the args of the target method. + * In addition, the receiver must be available in the local `receiverLocalForDispatch`. The two + * occurrences of the receiver must have the type for dispatch. + * + * After this code gen, the stack contains the result. If the result type is `NothingType`, + * `genTableDispatch` leaves the stack in an arbitrary state. It is up to the caller to insert an + * `unreachable` instruction when appropriate. + */ def genTableDispatch( - receiverClassInfo: WasmContext.WasmClassInfo, - methodName: IRNames.MethodName, - receiverLocalForDispatch: WasmLocalName + receiverClassInfo: WasmContext.WasmClassInfo, + methodName: IRNames.MethodName, + receiverLocalForDispatch: WasmLocalName ): Unit = { // Generates an itable-based dispatch. def genITableDispatch(): Unit = { @@ -960,7 +975,9 @@ private class WasmExpressionBuilder private ( case IRTypes.NothingType => () // unreachable case IRTypes.NoType => - throw new AssertionError(s"Found expression of type void in String_+ at ${tree.pos}: $tree") + throw new AssertionError( + s"Found expression of type void in String_+ at ${tree.pos}: $tree" + ) } case IRTypes.ClassType(IRNames.BoxedStringClass) => @@ -1095,7 +1112,9 @@ private class WasmExpressionBuilder private ( case IRTypes.UndefType | IRTypes.StringType => () case PrimTypeWithBoxUnbox(primType) => - instrs += CALL(WasmImmediate.FuncIdx(WasmFunctionName.unboxOrNull(primType.primRef))) + instrs += CALL( + WasmImmediate.FuncIdx(WasmFunctionName.unboxOrNull(primType.primRef)) + ) case IRTypes.CharType => val structTypeName = WasmStructTypeName(SpecialNames.CharBoxClass) instrs += REF_CAST_NULL(HeapType(Types.WasmHeapType.Type(structTypeName))) @@ -1121,11 +1140,11 @@ private class WasmExpressionBuilder private ( } /** Unbox the `anyref` on the stack to the target `PrimType`. - * - * `targetTpe` must not be `NothingType`, `NullType` nor `NoType`. - * - * The type left on the stack is non-nullable. - */ + * + * `targetTpe` must not be `NothingType`, `NullType` nor `NoType`. + * + * The type left on the stack is non-nullable. + */ private def genUnbox(targetTpe: IRTypes.PrimType)(implicit pos: Position): Unit = { targetTpe match { case IRTypes.UndefType => @@ -1296,7 +1315,10 @@ private class WasmExpressionBuilder private ( } /** Codegen to box a primitive `char`/`long` into a `CharacterBox`/`LongBox`. */ - private def genBox(primType: IRTypes.PrimTypeWithRef, boxClassName: IRNames.ClassName): IRTypes.Type = { + private def genBox( + primType: IRTypes.PrimTypeWithRef, + boxClassName: IRNames.ClassName + ): IRTypes.Type = { // `primTyp` is `i32` for `char` (containing a `u16` value) or `i64` for `long`. val primTyp = TypeTransformer.transformType(primType)(ctx) val primLocal = WasmLocal(fctx.genSyntheticLocalName(), primTyp, isParameter = false) @@ -1345,4 +1367,65 @@ private class WasmExpressionBuilder private ( IRTypes.IntType } + + // =============================================================================== + // array + // =============================================================================== + private def genArrayLength(t: IRTrees.ArrayLength): IRTypes.Type = { + genTreeAuto(t.array) + instrs += ARRAY_LEN + IRTypes.IntType + } + + private def genNewArray(t: IRTrees.NewArray): IRTypes.Type = { + ??? + } + + /** For getting element from an array, array.set should be generated by transformation of + * `Assign(ArraySelect(...), ...)` + */ + private def genArraySelect(t: IRTrees.ArraySelect): IRTypes.Type = { + val arrayType = t.array.tpe match { + case t: IRTypes.ArrayType => t + case _ => + throw new IllegalArgumentException( + s"ArraySelect.array must be an array type, but has type ${t.array.tpe}" + ) + } + genTreeAuto(t.array) + genTree(t.index, IRTypes.IntType) + instrs += ARRAY_GET(TypeIdx(WasmTypeName.WasmArrayTypeName(arrayType))) + + val typeRef = arrayType.arrayTypeRef + if (typeRef.dimensions > 1) IRTypes.ArrayType(typeRef.copy(dimensions = typeRef.dimensions - 1)) + else + typeRef.base match { + case IRTypes.ClassRef(className) => ClassType(className) + case IRTypes.PrimRef(tpe) => tpe + } + } + + private def genArrayValue(t: IRTrees.ArrayValue): IRTypes.Type = { + val irElemTy = extractArrayElemType(t.typeRef) + val wasmElemTy = TypeTransformer.transformType(irElemTy)(ctx) + val arrTyName = Names.WasmTypeName.WasmArrayTypeName(t.tpe) + ctx.addArrayType( + WasmArrayType( + arrTyName, + WasmStructField(Names.WasmFieldName.arrayField, wasmElemTy, isMutable = true) + ) + ) + t.elems.foreach(genTreeAuto) + instrs += ARRAY_NEW_FIXED(TypeIdx(arrTyName), I32(t.elems.size)) + t.tpe + } + + private def extractArrayElemType(typeRef: IRTypes.ArrayTypeRef): IRTypes.Type = { + if (typeRef.dimensions > 1) IRTypes.ArrayType(typeRef.copy(dimensions = typeRef.dimensions - 1)) + else + typeRef.base match { + case IRTypes.ClassRef(className) => ClassType(className) + case IRTypes.PrimRef(tpe) => tpe + } + } } diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala index da8feb91..0278c6fa 100644 --- a/wasm/src/main/scala/wasm4s/Names.scala +++ b/wasm/src/main/scala/wasm4s/Names.scala @@ -140,6 +140,7 @@ object Names { val vtable = new WasmFieldName("vtable") val itable = new WasmFieldName("itable") val itables = new WasmFieldName("itables") + val arrayField = new WasmFieldName("field") } // GC types ==== @@ -161,7 +162,7 @@ object Names { def apply(ty: IRTypes.ArrayType) = { val ref = ty.arrayTypeRef // TODO: better naming? - new WasmArrayTypeName(s"${ref.base.displayName}_${ref.dimensions}") + new WasmArrayTypeName(s"array_${ref.base.displayName}_${ref.dimensions}") } val itables = new WasmArrayTypeName("itable") } diff --git a/wasm/src/main/scala/wasm4s/Wasm.scala b/wasm/src/main/scala/wasm4s/Wasm.scala index 98bad118..994403dc 100644 --- a/wasm/src/main/scala/wasm4s/Wasm.scala +++ b/wasm/src/main/scala/wasm4s/Wasm.scala @@ -81,7 +81,7 @@ case class WasmStructType( ) extends WasmGCTypeDefinition case class WasmArrayType( - name: WasmTypeName, + name: WasmTypeName.WasmArrayTypeName, field: WasmStructField ) extends WasmGCTypeDefinition object WasmArrayType { @@ -111,6 +111,7 @@ object WasmStructField { */ class WasmModule( private val _functionTypes: mutable.ListBuffer[WasmFunctionType] = new mutable.ListBuffer(), + private val _arrayTypes: mutable.Set[WasmArrayType] = new mutable.HashSet(), private val _recGroupTypes: mutable.ListBuffer[WasmStructType] = new mutable.ListBuffer(), // val importsInOrder: List[WasmNamedModuleField] = Nil, private val _imports: mutable.ListBuffer[WasmImport] = new mutable.ListBuffer(), @@ -131,6 +132,7 @@ class WasmModule( ) { def addImport(imprt: WasmImport): Unit = _imports.addOne(imprt) def addFunction(function: WasmFunction): Unit = _definedFunctions.addOne(function) + def addArrayType(typ: WasmArrayType): Unit = _arrayTypes.addOne(typ) def addFunctionType(typ: WasmFunctionType): Unit = _functionTypes.addOne(typ) def addRecGroupType(typ: WasmStructType): Unit = _recGroupTypes.addOne(typ) def addGlobal(typ: WasmGlobal): Unit = _globals.addOne(typ) @@ -138,7 +140,7 @@ class WasmModule( def functionTypes = _functionTypes.toList def recGroupTypes = WasmModule.tsort(_recGroupTypes.toList) - def arrayTypes = List(WasmArrayType.itables) + def arrayTypes = List(WasmArrayType.itables) ++ _arrayTypes.toList def imports = _imports.toList def definedFunctions = _definedFunctions.toList def globals = _globals.toList diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala index b7ad7271..36bacb45 100644 --- a/wasm/src/main/scala/wasm4s/WasmContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmContext.scala @@ -105,7 +105,7 @@ trait ReadOnlyWasmContext { } } -trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmContext => +trait TypeDefinableWasmContext extends ReadOnlyWasmContext { this: WasmContext => protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int] def addFunctionType(sig: WasmFunctionSignature): WasmFunctionTypeName = { @@ -120,9 +120,12 @@ trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmCont case Some(value) => WasmFunctionTypeName(value) } } + def addArrayType(ty: WasmArrayType): Unit = { + module.addArrayType(ty) + } } -class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext { +class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { import WasmContext._ def addExport(exprt: WasmExport[_]): Unit = module.addExport(exprt) def addFunction(fun: WasmFunction): Unit = { @@ -141,7 +144,11 @@ class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit = classInfo.put(name, info) - private def addHelperImport(name: WasmFunctionName, params: List[WasmType], results: List[WasmType]): Unit = { + private def addHelperImport( + name: WasmFunctionName, + params: List[WasmType], + results: List[WasmType] + ): Unit = { val sig = WasmFunctionSignature(params, results) val typ = WasmFunctionType(addFunctionType(sig), sig) module.addImport(WasmImport(name.className, name.methodName, WasmImportDesc.Func(name, typ))) @@ -176,7 +183,11 @@ class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext addHelperImport(WasmFunctionName.intToString, List(WasmInt32), List(WasmRefType.any)) addHelperImport(WasmFunctionName.longToString, List(WasmInt64), List(WasmRefType.any)) addHelperImport(WasmFunctionName.doubleToString, List(WasmFloat64), List(WasmRefType.any)) - addHelperImport(WasmFunctionName.stringConcat, List(WasmRefType.any, WasmRefType.any), List(WasmRefType.any)) + addHelperImport( + WasmFunctionName.stringConcat, + List(WasmRefType.any, WasmRefType.any), + List(WasmRefType.any) + ) addHelperImport(WasmFunctionName.isString, List(WasmAnyRef), List(WasmInt32)) addHelperImport(WasmFunctionName.jsValueHashCode, List(WasmRefType.any), List(WasmInt32)) @@ -245,7 +256,7 @@ object WasmContext { // flags: IRTrees.MemberFlags, isAbstract: Boolean ) { - def toWasmFunctionType()(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionType = + def toWasmFunctionType()(implicit ctx: TypeDefinableWasmContext): WasmFunctionType = TypeTransformer.transformFunctionType(this) } @@ -290,7 +301,8 @@ object WasmContext { .getOrElse(throw new Error(s"Function not found: $name")) def resolveWithIdx(name: WasmFunctionName): (Int, WasmFunctionInfo) = { val idx = functions.indexWhere(_.name.methodName == name.methodName) - if (idx < 0) throw new Error(s"Function not found: $name among ${functions.map(_.name.methodName)}") + if (idx < 0) + throw new Error(s"Function not found: $name among ${functions.map(_.name.methodName)}") else (idx, functions(idx)) } }