Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit b256724

Browse files
committed
Make the receiver type non-nullable.
This surfaced a few places when instantiating new objects where we need to manipulate non-null values.
1 parent 9ae87d6 commit b256724

File tree

4 files changed

+44
-41
lines changed

4 files changed

+44
-41
lines changed

Diff for: wasm/src/main/scala/ir2wasm/TypeTransformer.scala

+10-15
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,23 @@ object TypeTransformer {
4343
Types.WasmRefType.nullable(
4444
Names.WasmTypeName.WasmStructTypeName.forArrayClass(tpe.arrayTypeRef)
4545
)
46-
case IRTypes.ClassType(className) => transformClassByName(className)
46+
case IRTypes.ClassType(className) => transformClassType(className)
4747
case IRTypes.RecordType(fields) => ???
4848
case IRTypes.StringType | IRTypes.UndefType =>
4949
Types.WasmRefType.any
5050
case p: IRTypes.PrimTypeWithRef => transformPrimType(p)
5151
}
5252

53-
private def transformClassByName(
53+
def transformClassType(
5454
className: IRNames.ClassName
55-
)(implicit ctx: ReadOnlyWasmContext): Types.WasmType = {
56-
className match {
57-
case _ =>
58-
val info = ctx.getClassInfo(className)
59-
if (info.isAncestorOfHijackedClass)
60-
Types.WasmRefType.anyref
61-
else if (info.isInterface)
62-
Types.WasmRefType.nullable(Types.WasmHeapType.ObjectType)
63-
else
64-
Types.WasmRefType.nullable(
65-
Names.WasmTypeName.WasmStructTypeName.forClass(className)
66-
)
67-
}
55+
)(implicit ctx: ReadOnlyWasmContext): Types.WasmRefType = {
56+
val info = ctx.getClassInfo(className)
57+
if (info.isAncestorOfHijackedClass)
58+
Types.WasmRefType.anyref
59+
else if (info.isInterface)
60+
Types.WasmRefType.nullable(Types.WasmHeapType.ObjectType)
61+
else
62+
Types.WasmRefType.nullable(WasmTypeName.WasmStructTypeName.forClass(className))
6863
}
6964

7065
private def transformPrimType(

Diff for: wasm/src/main/scala/ir2wasm/WasmBuilder.scala

+23-17
Original file line numberDiff line numberDiff line change
@@ -489,29 +489,35 @@ class WasmBuilder(coreSpec: CoreSpec) {
489489
ctor.name.name
490490
)
491491

492+
val resultTyp = WasmRefType(typeName)
493+
492494
implicit val fctx = WasmFunctionContext(
493495
WasmFunctionName.loadModule(clazz.className),
494496
Nil,
495-
List(WasmRefType.nullable(typeName))
497+
List(resultTyp)
496498
)
497499

500+
val instanceLocal = fctx.addLocal("instance", resultTyp)
501+
498502
import fctx.instrs
499503

500-
// global.get $module_name
501-
// ref.if_null
502-
// ref.null $module_type
503-
// call $module_init ;; should set to global
504-
// end
505-
// global.get $module_name
506-
instrs += GLOBAL_GET(globalInstanceName) // [rt]
507-
instrs += REF_IS_NULL // [rt] -> [i32] (bool)
508-
instrs += IF(BlockType.ValueType())
509-
instrs += CALL(WasmFunctionName.newDefault(clazz.name.name))
510-
instrs += GLOBAL_SET(globalInstanceName)
511-
instrs += GLOBAL_GET(globalInstanceName)
512-
instrs += CALL(ctorName)
513-
instrs += END
514-
instrs += GLOBAL_GET(globalInstanceName) // [rt]
504+
fctx.block(resultTyp) { nonNullLabel =>
505+
// load global, return if not null
506+
instrs += GLOBAL_GET(globalInstanceName)
507+
instrs += BR_ON_NON_NULL(nonNullLabel)
508+
509+
// create an instance and call its constructor
510+
instrs += CALL(WasmFunctionName.newDefault(clazz.name.name))
511+
instrs += LOCAL_TEE(instanceLocal)
512+
instrs += CALL(ctorName)
513+
514+
// store it in the global
515+
instrs += LOCAL_GET(instanceLocal)
516+
instrs += GLOBAL_SET(globalInstanceName)
517+
518+
// return it
519+
instrs += LOCAL_GET(instanceLocal)
520+
}
515521

516522
fctx.buildAndAddToContext()
517523
}
@@ -1093,7 +1099,7 @@ class WasmBuilder(coreSpec: CoreSpec) {
10931099
else if (isHijackedClass)
10941100
Some(transformType(IRTypes.BoxedClassToPrimType(className)))
10951101
else
1096-
Some(transformType(IRTypes.ClassType(className)))
1102+
Some(transformClassType(className).toNonNullable)
10971103

10981104
// Prepare for function context, set receiver and parameters
10991105
implicit val fctx = WasmFunctionContext(

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+7-8
Original file line numberDiff line numberDiff line change
@@ -1893,8 +1893,7 @@ private class WasmExpressionBuilder private (
18931893
* if the given class is an ancestor of hijacked classes (which in practice
18941894
* is only the case for j.l.Object).
18951895
*/
1896-
val instanceTyp =
1897-
Types.WasmRefType.nullable(WasmStructTypeName.forClass(n.className))
1896+
val instanceTyp = Types.WasmRefType(WasmStructTypeName.forClass(n.className))
18981897
val localInstance = fctx.addSyntheticLocal(instanceTyp)
18991898

19001899
fctx.markPosition(n)
@@ -1926,7 +1925,7 @@ private class WasmExpressionBuilder private (
19261925
val primLocal = fctx.addSyntheticLocal(primTyp)
19271926

19281927
val boxClassType = IRTypes.ClassType(boxClassName)
1929-
val boxTyp = TypeTransformer.transformType(boxClassType)(ctx)
1928+
val boxTyp = TypeTransformer.transformClassType(boxClassName)(ctx).toNonNullable
19301929
val instanceLocal = fctx.addSyntheticLocal(boxTyp)
19311930

19321931
/* The generated code is as follows. Before the codegen, the stack contains
@@ -1971,12 +1970,12 @@ private class WasmExpressionBuilder private (
19711970

19721971
private def genWrapAsThrowable(tree: IRTrees.WrapAsThrowable): IRTypes.Type = {
19731972
val throwableClassType = IRTypes.ClassType(IRNames.ThrowableClass)
1974-
val throwableTyp = TypeTransformer.transformType(throwableClassType)(ctx)
1973+
val nonNullThrowableTyp = Types.WasmRefType(Types.WasmHeapType.ThrowableType)
19751974

1976-
val jsExceptionClassType = IRTypes.ClassType(SpecialNames.JSExceptionClass)
1977-
val jsExceptionTyp = TypeTransformer.transformType(jsExceptionClassType)(ctx)
1975+
val jsExceptionTyp =
1976+
TypeTransformer.transformClassType(SpecialNames.JSExceptionClass)(ctx).toNonNullable
19781977

1979-
fctx.block(throwableTyp) { doneLabel =>
1978+
fctx.block(nonNullThrowableTyp) { doneLabel =>
19801979
genTree(tree.expr, IRTypes.AnyType)
19811980

19821981
fctx.markPosition(tree)
@@ -1985,7 +1984,7 @@ private class WasmExpressionBuilder private (
19851984
instrs += BR_ON_CAST(
19861985
doneLabel,
19871986
Types.WasmRefType.anyref,
1988-
Types.WasmRefType(Types.WasmHeapType.ThrowableType)
1987+
nonNullThrowableTyp
19891988
)
19901989

19911990
// otherwise, wrap in a new JavaScriptException

Diff for: wasm/src/main/scala/wasm4s/Types.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ object Types {
2323
case object WasmInt8 extends WasmPackedType("i8", 0x78)
2424
case object WasmInt16 extends WasmPackedType("i16", 0x77)
2525

26-
final case class WasmRefType(nullable: Boolean, heapType: WasmHeapType) extends WasmType
26+
final case class WasmRefType(nullable: Boolean, heapType: WasmHeapType) extends WasmType {
27+
def toNullable: WasmRefType = WasmRefType(true, heapType)
28+
def toNonNullable: WasmRefType = WasmRefType(false, heapType)
29+
}
2730

2831
object WasmRefType {
2932

0 commit comments

Comments
 (0)