Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit b3701b9

Browse files
authored
Merge pull request #10 from sjrd/enable-hijacked-classes
Support transformation of hijacked classes.
2 parents 4c6609b + 45a43a9 commit b3701b9

File tree

9 files changed

+155
-48
lines changed

9 files changed

+155
-48
lines changed

cli/src/main/scala/TestSuites.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ object TestSuites {
44
case class TestSuite(className: String, methodName: String)
55
val suites = List(
66
TestSuite("testsuite.core.simple.Simple", "simple"),
7-
TestSuite("testsuite.core.add.Add", "add")
7+
TestSuite("testsuite.core.add.Add", "add"),
8+
TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"),
9+
TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono")
810
)
911
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package testsuite.core.asinstanceof
2+
3+
import scala.scalajs.js.annotation._
4+
5+
object AsInstanceOfTest {
6+
def main(): Unit = { val _ = test() }
7+
8+
@JSExportTopLevel("asInstanceOf")
9+
def test(): Boolean = {
10+
testInt(5) &&
11+
testClasses(new Child()) &&
12+
testString("foo", true)
13+
}
14+
15+
def testClasses(c: Child): Boolean = {
16+
val c1 = c.asInstanceOf[Child]
17+
val c2 = c.asInstanceOf[Parent]
18+
c1.foo() == 5 && c2.foo() == 5
19+
}
20+
21+
def testInt(x: Int): Boolean = {
22+
val x1 = x.asInstanceOf[Int]
23+
x1 == 5
24+
}
25+
26+
def testString(s: String, b: Boolean): Boolean = {
27+
val s1 = s.asInstanceOf[String]
28+
val s2 = ("" + b).asInstanceOf[String]
29+
s1.length() == 3 && s2.length() == 4
30+
}
31+
32+
class Parent {
33+
def foo(): Int = 5
34+
}
35+
class Child extends Parent
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package testsuite.core.hijackedclassesmono
2+
3+
import scala.scalajs.js.annotation._
4+
5+
object HijackedClassesMonoTest {
6+
def main(): Unit = { val _ = test() }
7+
8+
@JSExportTopLevel("hijackedClassesMono")
9+
def test(): Boolean = {
10+
testInteger(5) &&
11+
testString("foo")
12+
}
13+
14+
def testInteger(x: Int): Boolean = {
15+
x.hashCode() == 5
16+
}
17+
18+
def testString(foo: String): Boolean = {
19+
foo.length() == 3 &&
20+
foo.hashCode() == 101574
21+
}
22+
}

wasm/src/main/scala/Compiler.scala

+5-14
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ object Compiler {
4949
} yield {
5050
val onlyModule = moduleSet.modules.head
5151

52-
val filteredClasses = onlyModule.classDefs.filter { c =>
53-
!ExcludedClasses.contains(c.className)
54-
}
52+
// Sort for stability
53+
val sortedClasses = onlyModule.classDefs.sortBy(_.className)
5554

56-
filteredClasses.sortBy(_.className).foreach(showLinkedClass(_))
55+
sortedClasses.foreach(showLinkedClass(_))
5756

58-
Preprocessor.preprocess(filteredClasses)(context)
57+
Preprocessor.preprocess(sortedClasses)(context)
5958
println("preprocessed")
60-
filteredClasses.foreach { clazz =>
59+
sortedClasses.foreach { clazz =>
6160
builder.transformClassDef(clazz)
6261
}
6362
onlyModule.topLevelExports.foreach { tle =>
@@ -71,14 +70,6 @@ object Compiler {
7170
}
7271
}
7372

74-
private val ExcludedClasses: Set[ir.Names.ClassName] = {
75-
import ir.Names._
76-
HijackedClasses ++ // hijacked classes
77-
Set(
78-
ClassClass // java.lang.Class
79-
)
80-
}
81-
8273
private def showLinkedClass(clazz: LinkedClass): Unit = {
8374
val writer = new java.io.PrintWriter(System.out)
8475
val printer = new LinkedClassPrinter(writer)

wasm/src/main/scala/ir2wasm/LibraryPatches.scala

+9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ object LibraryPatches {
6565

6666
private val MethodPatches: Map[ClassName, List[MethodDef]] = {
6767
Map(
68+
ObjectClass -> List(
69+
// TODO Remove this patch when we support getClass() and full string concatenation
70+
MethodDef(
71+
EMF, m("toString", Nil, T), NON,
72+
Nil, ClassType(BoxedStringClass),
73+
Some(StringLiteral("[object]"))
74+
)(EOH, NOV)
75+
),
76+
6877
BoxedCharacterClass.withSuffix("$") -> List(
6978
MethodDef(
7079
EMF, m("toString", List(C), T), NON,

wasm/src/main/scala/ir2wasm/Preprocessor.scala

+9-13
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,24 @@ object Preprocessor {
1717
for (clazz <- classes)
1818
preprocess(clazz)
1919

20-
for (clazz <- classes) {
21-
if (clazz.className != IRNames.ObjectClass)
22-
collectAbstractMethodCalls(clazz)
23-
}
20+
for (clazz <- classes)
21+
collectAbstractMethodCalls(clazz)
2422
}
2523

2624
private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
2725
clazz.kind match {
28-
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface =>
26+
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass =>
2927
collectMethods(clazz)
3028
case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass |
31-
ClassKind.AbstractJSType | ClassKind.NativeJSClass | ClassKind.HijackedClass =>
29+
ClassKind.AbstractJSType | ClassKind.NativeJSClass =>
3230
???
3331
}
3432
}
3533

3634
private def collectMethods(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
37-
val infos =
38-
if (clazz.name.name == IRNames.ObjectClass) Nil
39-
else
40-
clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
41-
makeWasmFunctionInfo(clazz, method)
42-
}
35+
val infos = clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
36+
makeWasmFunctionInfo(clazz, method)
37+
}
4338
ctx.putClassInfo(
4439
clazz.name.name,
4540
new WasmClassInfo(
@@ -48,7 +43,8 @@ object Preprocessor {
4843
infos,
4944
clazz.fields.collect { case f: IRTrees.FieldDef => Names.WasmFieldName(f.name.name) },
5045
clazz.superClass.map(_.name),
51-
clazz.interfaces.map(_.name)
46+
clazz.interfaces.map(_.name),
47+
clazz.ancestors
5248
)
5349
)
5450
}

wasm/src/main/scala/ir2wasm/WasmBuilder.scala

+18-14
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ class WasmBuilder {
2424

2525
def transformClassDef(clazz: LinkedClass)(implicit ctx: WasmContext) = {
2626
clazz.kind match {
27-
case ClassKind.ModuleClass => transformModuleClass(clazz)
28-
case ClassKind.Class => transformClass(clazz)
29-
case ClassKind.Interface => transformInterface(clazz)
30-
case _ =>
27+
case ClassKind.ModuleClass => transformModuleClass(clazz)
28+
case ClassKind.Class => transformClass(clazz)
29+
case ClassKind.HijackedClass => transformHijackedClass(clazz)
30+
case ClassKind.Interface => transformInterface(clazz)
31+
case _ => ???
3132
}
3233
}
3334

@@ -68,15 +69,10 @@ class WasmBuilder {
6869
)
6970
ctx.addGCType(structType)
7071

71-
// Do not generate methods in Object for now
72-
if (clazz.name.name == IRNames.ObjectClass)
73-
clazz.methods.filter(_.name.name == IRNames.NoArgConstructorName).foreach { method =>
74-
genFunction(clazz, method)
75-
}
76-
else
77-
clazz.methods.foreach { method =>
78-
genFunction(clazz, method)
79-
}
72+
// implementation of methods
73+
clazz.methods.foreach { method =>
74+
genFunction(clazz, method)
75+
}
8076

8177
structType
8278
}
@@ -242,6 +238,12 @@ class WasmBuilder {
242238
transformClassCommon(clazz)
243239
}
244240

241+
private def transformHijackedClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
242+
clazz.methods.foreach { method =>
243+
genFunction(clazz, method)
244+
}
245+
}
246+
245247
private def transformInterface(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
246248
assert(clazz.kind == ClassKind.Interface)
247249
// gen itable type
@@ -389,7 +391,9 @@ class WasmBuilder {
389391
// Receiver type for non-constructor methods needs to be Object type because params are invariant
390392
// Otherwise, vtable can't be a subtype of the supertype's subtype
391393
// Constructor can use the exact type because it won't be registered to vtables.
392-
if (method.flags.namespace.isConstructor)
394+
if (clazz.kind == ClassKind.HijackedClass)
395+
transformType(IRTypes.BoxedClassToPrimType(clazz.name.name))
396+
else if (method.flags.namespace.isConstructor)
393397
WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name)))
394398
else
395399
WasmRefNullType(WasmHeapType.ObjectType),

wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+45-5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
6363
transformApplyStatically(t)
6464
case t: IRTrees.Apply => transformApply(t)
6565
case t: IRTrees.ApplyDynamicImport => ???
66+
case t: IRTrees.AsInstanceOf => transformAsInstanceOf(t)
6667
case t: IRTrees.Block => transformBlock(t)
6768
case t: IRTrees.Labeled => transformLabeled(t)
6869
case t: IRTrees.Return => transformReturn(t)
@@ -180,8 +181,9 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
180181
val wasmArgs = t.args.flatMap(transformTree)
181182

182183
val receiverClassName = t.receiver.tpe match {
183-
case ClassType(className) => className
184-
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
184+
case ClassType(className) => className
185+
case prim: IRTypes.PrimType => IRTypes.PrimTypeToBoxedClass(prim)
186+
case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}")
185187
}
186188
val receiverClassInfo = ctx.getClassInfo(receiverClassName)
187189

@@ -228,6 +230,18 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
228230
TypeIdx(method.toWasmFunctionType()(ctx).name)
229231
)
230232
)
233+
} else if (receiverClassInfo.kind == ClassKind.HijackedClass) {
234+
// statically resolved call
235+
val info = receiverClassInfo.getMethodInfo(t.method.name)
236+
val castIfNeeded =
237+
if (receiverClassName == IRNames.BoxedStringClass && t.receiver.tpe == ClassType(IRNames.BoxedStringClass))
238+
List(REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string))))
239+
else
240+
Nil
241+
pushReceiver ++ castIfNeeded ++ wasmArgs ++
242+
List(
243+
CALL(FuncIdx(info.name))
244+
)
231245
} else { // virtual dispatch
232246
val (methodIdx, info) = ctx
233247
.calculateVtable(receiverClassName)
@@ -401,6 +415,17 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
401415
case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U)
402416
case BinaryOp.Long_>> => longShiftOp(I64_SHR_S)
403417

418+
// New in 1.11
419+
case BinaryOp.String_charAt =>
420+
transformTree(binary.lhs) ++ // push the string
421+
List(
422+
STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array
423+
) ++
424+
transformTree(binary.rhs) ++ // push the index
425+
List(
426+
ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array
427+
)
428+
404429
case _ => transformElementaryBinaryOp(binary)
405430
}
406431
}
@@ -479,9 +504,6 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
479504
case BinaryOp.Double_<= => F64_LE
480505
case BinaryOp.Double_> => F64_GT
481506
case BinaryOp.Double_>= => F64_GE
482-
483-
// // New in 1.11
484-
case BinaryOp.String_charAt => ??? // TODO
485507
}
486508
lhsInstrs ++ rhsInstrs :+ operation
487509
}
@@ -539,6 +561,24 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti
539561
}
540562
}
541563

564+
private def transformAsInstanceOf(tree: IRTrees.AsInstanceOf): List[WasmInstr] = {
565+
val exprInstrs = transformTree(tree.expr)
566+
567+
val sourceTpe = tree.expr.tpe
568+
val targetTpe = tree.tpe
569+
570+
if (IRTypes.isSubtype(sourceTpe, targetTpe)(isSubclass(_, _))) {
571+
// Common case where no cast is necessary
572+
exprInstrs
573+
} else {
574+
println(tree)
575+
???
576+
}
577+
}
578+
579+
private def isSubclass(subClass: IRNames.ClassName, superClass: IRNames.ClassName): Boolean =
580+
ctx.getClassInfo(subClass).ancestors.contains(superClass)
581+
542582
private def transformVarRef(r: IRTrees.VarRef): LOCAL_GET = {
543583
val name = WasmLocalName.fromIR(r.ident.name)
544584
LOCAL_GET(LocalIdx(name))

wasm/src/main/scala/wasm4s/WasmContext.scala

+8-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ object WasmContext {
129129
private var _methods: List[WasmFunctionInfo],
130130
private val fields: List[WasmFieldName],
131131
val superClass: Option[IRNames.ClassName],
132-
val interfaces: List[IRNames.ClassName]
132+
val interfaces: List[IRNames.ClassName],
133+
val ancestors: List[IRNames.ClassName]
133134
) {
134135

135136
def isInterface = kind == ClassKind.Interface
@@ -145,6 +146,12 @@ object WasmContext {
145146
}
146147
}
147148

149+
def getMethodInfo(methodName: IRNames.MethodName): WasmFunctionInfo = {
150+
methods.find(_.name.methodName == methodName.nameString).getOrElse {
151+
throw new IllegalArgumentException(s"Cannot find method ${methodName.nameString} in class ${name.nameString}")
152+
}
153+
}
154+
148155
def getFieldIdx(name: WasmFieldName): WasmImmediate.StructFieldIdx =
149156
fields.indexWhere(_ == name) match {
150157
case i if i < 0 => throw new Error(s"Field not found: $name")

0 commit comments

Comments
 (0)