Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #116 from sjrd/static-dispatch-when-possible
Browse files Browse the repository at this point in the history
Statically resolve Apply calls when possible.
  • Loading branch information
sjrd authored Apr 23, 2024
2 parents 1be5a2b + 33f6b34 commit 7fd86c1
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 303 deletions.
2 changes: 1 addition & 1 deletion wasm/src/main/scala/WebAssemblyLinkerBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ final class WebAssemblyLinkerBackend(

// sortedClasses.foreach(cls => println(utils.LinkedClassPrinters.showLinkedClass(cls)))

Preprocessor.preprocess(sortedClasses)(context)
Preprocessor.preprocess(sortedClasses, onlyModule.topLevelExports)(context)
HelperFunctions.genGlobalHelpers()
builder.genPrimitiveTypeDataGlobals()
sortedClasses.foreach { clazz =>
Expand Down
7 changes: 4 additions & 3 deletions wasm/src/main/scala/ir2wasm/HelperFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,10 @@ object HelperFunctions {
// reflectiveProxies
instrs += ARRAY_NEW_FIXED(WasmArrayTypeName.reflectiveProxies, 0) // TODO

instrs ++= ctx
.calculateGlobalVTable(IRNames.ObjectClass)
.map(method => WasmInstr.REF_FUNC(method.name))
val objectClassInfo = ctx.getClassInfo(IRNames.ObjectClass)
instrs ++= objectClassInfo.tableEntries.map { methodName =>
ctx.refFuncWithDeclaration(objectClassInfo.resolvedMethodInfos(methodName).wasmName)
}
instrs += STRUCT_NEW(WasmTypeName.WasmStructTypeName.ObjectVTable)
instrs += LOCAL_TEE(arrayTypeDataLocal)

Expand Down
118 changes: 60 additions & 58 deletions wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,27 @@ import org.scalajs.ir.{Names => IRNames}
import org.scalajs.ir.ClassKind
import org.scalajs.ir.Traversers

import org.scalajs.linker.standard.LinkedClass
import org.scalajs.linker.standard.{LinkedClass, LinkedTopLevelExport}

import EmbeddedConstants._
import WasmContext._

object Preprocessor {
def preprocess(classes: List[LinkedClass])(implicit ctx: WasmContext): Unit = {
def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport])(implicit
ctx: WasmContext
): Unit = {
for (clazz <- classes)
preprocess(clazz)

val collector = new AbstractMethodCallCollector(ctx)
for (clazz <- classes)
collector.collectAbstractMethodCalls(clazz)
for (tle <- tles)
collector.collectAbstractMethodCalls(tle)

for (clazz <- classes) {
collectAbstractMethodCalls(clazz)
ctx.getClassInfo(clazz.className).buildMethodTable()

if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)
HelperFunctions.genInstanceTest(clazz)
HelperFunctions.genCloneFunction(clazz)
Expand All @@ -46,20 +55,14 @@ object Preprocessor {
Nil
}

val classMethodInfos = {
val classConcretePublicMethodNames = {
if (kind.isClass || kind == ClassKind.HijackedClass) {
clazz.methods
.filter(_.flags.namespace == IRTrees.MemberNamespace.Public)
.map(method => makeWasmFunctionInfo(clazz, method))
} else {
Nil
}
}
val reflectiveProxies = {
if (kind.isClass || kind == ClassKind.HijackedClass) {
clazz.methods
.filter(_.name.name.isReflectiveProxy)
.map(method => makeWasmFunctionInfo(clazz, method))
for {
m <- clazz.methods
if m.body.isDefined && m.flags.namespace == IRTrees.MemberNamespace.Public
} yield {
m.methodName
}
} else {
Nil
}
Expand All @@ -85,11 +88,11 @@ object Preprocessor {
ctx.putClassInfo(
clazz.name.name,
new WasmClassInfo(
ctx,
clazz.name.name,
kind,
clazz.jsClassCaptures,
classMethodInfos,
reflectiveProxies,
classConcretePublicMethodNames,
allFieldDefs,
clazz.superClass.map(_.name),
clazz.interfaces.map(_.name),
Expand Down Expand Up @@ -129,19 +132,6 @@ object Preprocessor {
clazz.ancestors.foreach(ancestor => ctx.getClassInfo(ancestor).setHasInstances())
}

private def makeWasmFunctionInfo(
clazz: LinkedClass,
method: IRTrees.MethodDef
): WasmFunctionInfo = {
WasmFunctionInfo(
Names.WasmFunctionName(method.flags.namespace, clazz.name.name, method.name.name),
method.args.map(_.ptpe),
method.resultType,
isAbstract = method.body.isEmpty,
isReflectiveProxy = method.name.name.isReflectiveProxy
)
}

/** Collect WasmFunctionInfo based on the abstract method call
*
* ```
Expand All @@ -162,35 +152,47 @@ object Preprocessor {
* It keeps B.c because it's concrete and used. But because `C.c` isn't there at all anymore, if
* we have val `x: C` and we call `x.c`, we don't find the method at all.
*/
private def collectAbstractMethodCalls(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
object traverser extends Traversers.Traverser {
import IRTrees._

override def traverse(tree: Tree): Unit = {
super.traverse(tree)

tree match {
case Apply(flags, receiver, methodName, _) =>
receiver.tpe match {
case IRTypes.ClassType(className) =>
val classInfo = ctx.getClassInfo(className)
if (classInfo.hasInstances)
classInfo.maybeAddAbstractMethod(methodName.name, ctx)
case _ =>
()
}

case _ =>
()
}
private class AbstractMethodCallCollector(ctx: WasmContext) extends Traversers.Traverser {
import IRTrees._

def collectAbstractMethodCalls(clazz: LinkedClass): Unit = {
for (method <- clazz.methods)
traverseMethodDef(method)
for (jsConstructor <- clazz.jsConstructorDef)
traverseJSConstructorDef(jsConstructor)
for (export <- clazz.exportedMembers)
traverseJSMethodPropDef(export)
}

def collectAbstractMethodCalls(tle: LinkedTopLevelExport): Unit = {
tle.tree match {
case IRTrees.TopLevelMethodExportDef(_, jsMethodDef) =>
traverseJSMethodPropDef(jsMethodDef)
case _ =>
()
}
}

for (method <- clazz.methods)
traverser.traverseMethodDef(method)
for (jsConstructor <- clazz.jsConstructorDef)
traverser.traverseJSConstructorDef(jsConstructor)
for (export <- clazz.exportedMembers)
traverser.traverseJSMethodPropDef(export)
override def traverse(tree: Tree): Unit = {
super.traverse(tree)

tree match {
case Apply(flags, receiver, methodName, _) if !methodName.name.isReflectiveProxy =>
receiver.tpe match {
case IRTypes.ClassType(className) =>
val classInfo = ctx.getClassInfo(className)
if (classInfo.hasInstances)
classInfo.registerDynamicCall(methodName.name)
case IRTypes.AnyType =>
ctx.getClassInfo(IRNames.ObjectClass).registerDynamicCall(methodName.name)
case _ =>
// For all other cases, including arrays, we will always perform a static dispatch
()
}

case _ =>
()
}
}
}
}
18 changes: 0 additions & 18 deletions wasm/src/main/scala/ir2wasm/TypeTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,6 @@ import wasm.wasm4s.Names._

object TypeTransformer {

val makeReceiverType: Types.WasmType =
Types.WasmRefType.any

def transformFunctionType(
// clazz: WasmContext.WasmClassInfo,
method: WasmContext.WasmFunctionInfo
)(implicit ctx: TypeDefinableWasmContext): WasmTypeName = {
// val className = clazz.name
val name = method.name
val receiverType = makeReceiverType
// if (clazz.kind.isClass) List(makeReceiverType) else Nil
val sig = WasmFunctionSignature(
receiverType +: method.argTypes.map(transformType),
transformResultType(method.resultType)
)
ctx.addFunctionTypeInMainRecType(sig)
}

/** This transformation should be used only for the result types of functions or blocks.
*
* `nothing` translates to an empty result type list, because Wasm does not have a bottom type
Expand Down
65 changes: 32 additions & 33 deletions wasm/src/main/scala/ir2wasm/WasmBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,13 @@ class WasmBuilder(coreSpec: CoreSpec) {
}
}

private def genTypeDataFieldValues(clazz: LinkedClass, vtableElems: List[WasmFunctionInfo])(
implicit ctx: WasmContext
private def genTypeDataFieldValues(
clazz: LinkedClass,
reflectiveProxies: List[ConcreteMethodInfo]
)(implicit
ctx: WasmContext
): List[WasmInstr] = {
import WasmFieldName.typeData._
import WasmFieldName.typeData.{reflectiveProxies => _, _}

val className = clazz.className
val classInfo = ctx.getClassInfo(className)
Expand Down Expand Up @@ -201,7 +204,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
classInfo.specialInstanceTypes,
IRTypes.ClassRef(clazz.className),
isJSClassInstanceFuncOpt,
vtableElems
reflectiveProxies
)
}

Expand Down Expand Up @@ -267,7 +270,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
specialInstanceTypes: Int,
typeRef: IRTypes.NonArrayTypeRef,
isJSClassInstanceFuncOpt: Option[WasmFunctionName],
vtableElems: List[WasmFunctionInfo]
reflectiveProxies: List[ConcreteMethodInfo]
)(implicit
ctx: WasmContext
): List[WasmInstr] = {
Expand Down Expand Up @@ -329,16 +332,15 @@ class WasmBuilder(coreSpec: CoreSpec) {
case Some(funcName) => REF_FUNC(funcName)
}

val reflectiveProxies: List[WasmInstr] = {
val proxies = vtableElems.filter(_.isReflectiveProxy)
proxies.flatMap { method =>
val proxyId = ctx.getReflectiveProxyId(method.name.simpleName)
val reflectiveProxiesInstrs: List[WasmInstr] = {
reflectiveProxies.flatMap { proxyInfo =>
val proxyId = ctx.getReflectiveProxyId(proxyInfo.methodName)
List(
I32_CONST(proxyId),
REF_FUNC(method.name),
STRUCT_NEW(Names.WasmTypeName.WasmStructTypeName.reflectiveProxy)
REF_FUNC(proxyInfo.wasmName),
STRUCT_NEW(WasmStructTypeName.reflectiveProxy)
)
} :+ ARRAY_NEW_FIXED(Names.WasmTypeName.WasmArrayTypeName.reflectiveProxies, proxies.size)
} :+ ARRAY_NEW_FIXED(WasmArrayTypeName.reflectiveProxies, reflectiveProxies.size)
}

nameDataValue :::
Expand Down Expand Up @@ -368,7 +370,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
// reflective proxies - used to reflective call on the class at runtime.
// Generated instructions create an array of reflective proxy structs, where each struct
// contains the ID of the reflective proxy and a reference to the actual method implementation.
reflectiveProxies
reflectiveProxiesInstrs
}

private def genTypeDataGlobal(
Expand Down Expand Up @@ -399,8 +401,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
val classInfo = ctx.getClassInfo(className)

// generate vtable type, this should be done for both abstract and concrete classes
val vtable = ctx.calculateVtableType(className)
val vtableTypeName = genVTableType(clazz, vtable.functions)
val vtableTypeName = genVTableType(classInfo)

val isAbstractClass = !clazz.hasDirectInstances

Expand All @@ -412,9 +413,12 @@ class WasmBuilder(coreSpec: CoreSpec) {

if (!isAbstractClass) {
// Generate an actual vtable
val functions = ctx.calculateGlobalVTable(className)
val typeDataFieldValues = genTypeDataFieldValues(clazz, functions)
val vtableElems = functions.map(method => WasmInstr.REF_FUNC(method.name))
val reflectiveProxies =
classInfo.resolvedMethodInfos.valuesIterator.filter(_.methodName.isReflectiveProxy).toList
val typeDataFieldValues = genTypeDataFieldValues(clazz, reflectiveProxies)
val vtableElems = classInfo.tableEntries.map { methodName =>
REF_FUNC(classInfo.resolvedMethodInfos(methodName).wasmName)
}
val globalVTable =
genTypeDataGlobal(typeRef, vtableTypeName, typeDataFieldValues, vtableElems)
ctx.addGlobal(globalVTable)
Expand Down Expand Up @@ -447,21 +451,19 @@ class WasmBuilder(coreSpec: CoreSpec) {
structType
}

private def genVTableType(clazz: LinkedClass, functions: List[WasmFunctionInfo])(implicit
ctx: WasmContext
): WasmTypeName = {
val typeName = Names.WasmTypeName.WasmStructTypeName.forVTable(clazz.name.name)
private def genVTableType(classInfo: WasmClassInfo)(implicit ctx: WasmContext): WasmTypeName = {
val typeName = Names.WasmTypeName.WasmStructTypeName.forVTable(classInfo.name)
val vtableFields =
functions.map { method =>
classInfo.tableEntries.map { methodName =>
WasmStructField(
Names.WasmFieldName.forMethodTableEntry(method.name),
WasmRefType.nullable(method.toWasmFunctionType()),
Names.WasmFieldName.forMethodTableEntry(methodName),
WasmRefType(ctx.tableFunctionType(methodName)),
isMutable = false
)
}
val superType = clazz.superClass match {
val superType = classInfo.superClass match {
case None => WasmTypeName.WasmStructTypeName.typeData
case Some(s) => WasmTypeName.WasmStructTypeName.forVTable(s.name)
case Some(s) => WasmTypeName.WasmStructTypeName.forVTable(s)
}
val structType = WasmStructType(
WasmStructType.typeData.fields ::: vtableFields
Expand Down Expand Up @@ -561,18 +563,15 @@ class WasmBuilder(coreSpec: CoreSpec) {
val classInfo = ctx.getClassInfo(clazz.className)
val itableTypeName = Names.WasmTypeName.WasmStructTypeName.forITable(className)
val itableType = WasmStructType(
classInfo.methods.map { m =>
classInfo.tableEntries.map { methodName =>
WasmStructField(
Names.WasmFieldName(m.name.simpleName),
WasmRefType.nullable(m.toWasmFunctionType()),
Names.WasmFieldName.forMethodTableEntry(methodName),
WasmRefType(ctx.tableFunctionType(methodName)),
isMutable = false
)
}
)
ctx.mainRecType.addSubType(itableTypeName, itableType)
// typeName
// genITable
// generateVTable()
}

private def transformModuleClass(clazz: LinkedClass)(implicit ctx: WasmContext) = {
Expand Down
Loading

0 comments on commit 7fd86c1

Please sign in to comment.