Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit a2bc825

Browse files
authored
Merge pull request #111 from sjrd/push-expected-types-more
Propagate expected types into more codegen functions.
2 parents 17365de + 43c740b commit a2bc825

File tree

1 file changed

+83
-65
lines changed

1 file changed

+83
-65
lines changed

wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+83-65
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ private class WasmExpressionBuilder private (
114114

115115
def genTree(tree: IRTrees.Tree, expectedType: IRTypes.Type): Unit = {
116116
val generatedType: IRTypes.Type = tree match {
117-
case t: IRTrees.Literal => genLiteral(t)
117+
case t: IRTrees.Literal => genLiteral(t, expectedType)
118118
case t: IRTrees.UnaryOp => genUnaryOp(t)
119119
case t: IRTrees.BinaryOp => genBinaryOp(t)
120120
case t: IRTrees.VarRef => genVarRef(t)
@@ -139,10 +139,10 @@ private class WasmExpressionBuilder private (
139139
case t: IRTrees.If => genIf(t, expectedType)
140140
case t: IRTrees.While => genWhile(t)
141141
case t: IRTrees.ForIn => genForIn(t)
142-
case t: IRTrees.TryCatch => genTryCatch(t)
143-
case t: IRTrees.TryFinally => unwinding.genTryFinally(t)
142+
case t: IRTrees.TryCatch => genTryCatch(t, expectedType)
143+
case t: IRTrees.TryFinally => unwinding.genTryFinally(t, expectedType)
144144
case t: IRTrees.Throw => genThrow(t)
145-
case t: IRTrees.Match => genMatch(t)
145+
case t: IRTrees.Match => genMatch(t, expectedType)
146146
case t: IRTrees.Debugger => IRTypes.NoType // ignore
147147
case t: IRTrees.Skip => IRTypes.NoType
148148
case t: IRTrees.Clone => genClone(t)
@@ -299,7 +299,7 @@ private class WasmExpressionBuilder private (
299299
instrs += CALL(WasmFunctionName.jsSuperSet)
300300

301301
case assign: IRTrees.JSGlobalRef =>
302-
genLiteral(IRTrees.StringLiteral(assign.name)(assign.pos))
302+
instrs ++= ctx.getConstantStringInstr(assign.name)
303303
genTree(t.rhs, IRTypes.AnyType)
304304
instrs += CALL(WasmFunctionName.jsGlobalRefSet)
305305

@@ -788,41 +788,50 @@ private class WasmExpressionBuilder private (
788788
}
789789
}
790790

791-
private def genLiteral(l: IRTrees.Literal): IRTypes.Type = {
792-
l match {
793-
case IRTrees.BooleanLiteral(v) => instrs += WasmInstr.I32_CONST(if (v) 1 else 0)
794-
case IRTrees.ByteLiteral(v) => instrs += WasmInstr.I32_CONST(v)
795-
case IRTrees.ShortLiteral(v) => instrs += WasmInstr.I32_CONST(v)
796-
case IRTrees.IntLiteral(v) => instrs += WasmInstr.I32_CONST(v)
797-
case IRTrees.CharLiteral(v) => instrs += WasmInstr.I32_CONST(v)
798-
case IRTrees.LongLiteral(v) => instrs += WasmInstr.I64_CONST(v)
799-
case IRTrees.FloatLiteral(v) => instrs += WasmInstr.F32_CONST(v)
800-
case IRTrees.DoubleLiteral(v) => instrs += WasmInstr.F64_CONST(v)
801-
802-
case v: IRTrees.Undefined =>
803-
instrs += CALL(WasmFunctionName.undef)
804-
case v: IRTrees.Null =>
805-
instrs += WasmInstr.REF_NULL(Types.WasmHeapType.None)
806-
807-
case v: IRTrees.StringLiteral =>
808-
instrs ++= ctx.getConstantStringInstr(v.value)
809-
810-
case v: IRTrees.ClassOf =>
811-
v.typeRef match {
812-
case typeRef: IRTypes.NonArrayTypeRef =>
813-
genClassOfFromTypeData(getNonArrayTypeDataInstr(typeRef))
814-
815-
case typeRef: IRTypes.ArrayTypeRef =>
816-
val typeDataType = Types.WasmRefType(WasmStructTypeName.typeData)
817-
val typeDataLocal = fctx.addSyntheticLocal(typeDataType)
791+
private def genLiteral(l: IRTrees.Literal, expectedType: IRTypes.Type): IRTypes.Type = {
792+
if (expectedType == IRTypes.NoType) {
793+
/* Since all primitives are pure, we can always get rid of them.
794+
* This is mostly useful for the argument of `Return` nodes that target a
795+
* `Labeled` in statement position, since they must have a non-`void`
796+
* type in the IR but they get a `void` expected type.
797+
*/
798+
expectedType
799+
} else {
800+
l match {
801+
case IRTrees.BooleanLiteral(v) => instrs += WasmInstr.I32_CONST(if (v) 1 else 0)
802+
case IRTrees.ByteLiteral(v) => instrs += WasmInstr.I32_CONST(v)
803+
case IRTrees.ShortLiteral(v) => instrs += WasmInstr.I32_CONST(v)
804+
case IRTrees.IntLiteral(v) => instrs += WasmInstr.I32_CONST(v)
805+
case IRTrees.CharLiteral(v) => instrs += WasmInstr.I32_CONST(v)
806+
case IRTrees.LongLiteral(v) => instrs += WasmInstr.I64_CONST(v)
807+
case IRTrees.FloatLiteral(v) => instrs += WasmInstr.F32_CONST(v)
808+
case IRTrees.DoubleLiteral(v) => instrs += WasmInstr.F64_CONST(v)
809+
810+
case v: IRTrees.Undefined =>
811+
instrs += CALL(WasmFunctionName.undef)
812+
case v: IRTrees.Null =>
813+
instrs += WasmInstr.REF_NULL(Types.WasmHeapType.None)
814+
815+
case v: IRTrees.StringLiteral =>
816+
instrs ++= ctx.getConstantStringInstr(v.value)
817+
818+
case v: IRTrees.ClassOf =>
819+
v.typeRef match {
820+
case typeRef: IRTypes.NonArrayTypeRef =>
821+
genClassOfFromTypeData(getNonArrayTypeDataInstr(typeRef))
822+
823+
case typeRef: IRTypes.ArrayTypeRef =>
824+
val typeDataType = Types.WasmRefType(WasmStructTypeName.typeData)
825+
val typeDataLocal = fctx.addSyntheticLocal(typeDataType)
826+
827+
genLoadArrayTypeData(typeRef)
828+
instrs += LOCAL_SET(typeDataLocal)
829+
genClassOfFromTypeData(LOCAL_GET(typeDataLocal))
830+
}
831+
}
818832

819-
genLoadArrayTypeData(typeRef)
820-
instrs += LOCAL_SET(typeDataLocal)
821-
genClassOfFromTypeData(LOCAL_GET(typeDataLocal))
822-
}
833+
l.tpe
823834
}
824-
825-
l.tpe
826835
}
827836

828837
private def getNonArrayTypeDataInstr(typeRef: IRTypes.NonArrayTypeRef): WasmInstr =
@@ -1194,7 +1203,7 @@ private class WasmExpressionBuilder private (
11941203
instrs += BR_ON_NON_NULL(labelDone)
11951204
}
11961205

1197-
genLiteral(IRTrees.StringLiteral("null")(tree.pos))
1206+
instrs ++= ctx.getConstantStringInstr("null")
11981207
}
11991208
} else {
12001209
/* Dispatch where the receiver can be a JS value.
@@ -1498,7 +1507,7 @@ private class WasmExpressionBuilder private (
14981507
targetTpe match {
14991508
case IRTypes.UndefType =>
15001509
instrs += DROP
1501-
genLiteral(IRTrees.Undefined())
1510+
instrs += CALL(WasmFunctionName.undef)
15021511
case IRTypes.StringType =>
15031512
instrs += REF_AS_NOT_NULL
15041513

@@ -1628,10 +1637,19 @@ private class WasmExpressionBuilder private (
16281637
private def genIf(t: IRTrees.If, expectedType: IRTypes.Type): IRTypes.Type = {
16291638
val ty = TypeTransformer.transformResultType(expectedType)(ctx)
16301639
genTree(t.cond, IRTypes.BooleanType)
1631-
fctx.ifThenElse(ty) {
1632-
genTree(t.thenp, expectedType)
1633-
} {
1634-
genTree(t.elsep, expectedType)
1640+
1641+
t.elsep match {
1642+
case IRTrees.Skip() =>
1643+
assert(expectedType == IRTypes.NoType)
1644+
fctx.ifThen() {
1645+
genTree(t.thenp, expectedType)
1646+
}
1647+
case _ =>
1648+
fctx.ifThenElse(ty) {
1649+
genTree(t.thenp, expectedType)
1650+
} {
1651+
genTree(t.elsep, expectedType)
1652+
}
16351653
}
16361654

16371655
if (expectedType == IRTypes.NothingType)
@@ -1706,17 +1724,17 @@ private class WasmExpressionBuilder private (
17061724
IRTypes.NoType
17071725
}
17081726

1709-
private def genTryCatch(t: IRTrees.TryCatch): IRTypes.Type = {
1710-
val resultType = TypeTransformer.transformResultType(t.tpe)(ctx)
1727+
private def genTryCatch(t: IRTrees.TryCatch, expectedType: IRTypes.Type): IRTypes.Type = {
1728+
val resultType = TypeTransformer.transformResultType(expectedType)(ctx)
17111729

17121730
if (UseLegacyExceptionsForTryCatch) {
17131731
instrs += TRY(fctx.sigToBlockType(WasmFunctionSignature(Nil, resultType)))
1714-
genTree(t.block, t.tpe)
1732+
genTree(t.block, expectedType)
17151733
instrs += CATCH(ctx.exceptionTagName)
17161734
fctx.withNewLocal(t.errVar.name, Types.WasmRefType.anyref) { exceptionLocal =>
17171735
instrs += ANY_CONVERT_EXTERN
17181736
instrs += LOCAL_SET(exceptionLocal)
1719-
genTree(t.handler, t.tpe)
1737+
genTree(t.handler, expectedType)
17201738
}
17211739
instrs += END
17221740
} else {
@@ -1731,22 +1749,22 @@ private class WasmExpressionBuilder private (
17311749
fctx.tryTable(Types.WasmRefType.externref)(
17321750
List(CatchClause.Catch(ctx.exceptionTagName, catchLabel))
17331751
) {
1734-
genTree(t.block, t.tpe)
1752+
genTree(t.block, expectedType)
17351753
instrs += BR(doneLabel)
17361754
}
17371755
} // end block $catch
17381756
fctx.withNewLocal(t.errVar.name, Types.WasmRefType.anyref) { exceptionLocal =>
17391757
instrs += ANY_CONVERT_EXTERN
17401758
instrs += LOCAL_SET(exceptionLocal)
1741-
genTree(t.handler, t.tpe)
1759+
genTree(t.handler, expectedType)
17421760
}
17431761
} // end block $done
17441762
}
17451763

1746-
if (t.tpe == IRTypes.NothingType)
1764+
if (expectedType == IRTypes.NothingType)
17471765
instrs += UNREACHABLE
17481766

1749-
t.tpe
1767+
expectedType
17501768
}
17511769

17521770
private def genThrow(tree: IRTrees.Throw): IRTypes.Type = {
@@ -2071,13 +2089,13 @@ private class WasmExpressionBuilder private (
20712089
}
20722090

20732091
private def genJSGlobalRef(tree: IRTrees.JSGlobalRef): IRTypes.Type = {
2074-
genLiteral(IRTrees.StringLiteral(tree.name)(tree.pos))
2092+
instrs ++= ctx.getConstantStringInstr(tree.name)
20752093
instrs += CALL(WasmFunctionName.jsGlobalRefGet)
20762094
IRTypes.AnyType
20772095
}
20782096

20792097
private def genJSTypeOfGlobalRef(tree: IRTrees.JSTypeOfGlobalRef): IRTypes.Type = {
2080-
genLiteral(IRTrees.StringLiteral(tree.globalRef.name)(tree.pos))
2098+
instrs ++= ctx.getConstantStringInstr(tree.globalRef.name)
20812099
instrs += CALL(WasmFunctionName.jsGlobalRefTypeof)
20822100
IRTypes.AnyType
20832101
}
@@ -2354,13 +2372,13 @@ private class WasmExpressionBuilder private (
23542372
t.tpe
23552373
}
23562374

2357-
private def genMatch(tree: IRTrees.Match): IRTypes.Type = {
2375+
private def genMatch(tree: IRTrees.Match, expectedType: IRTypes.Type): IRTypes.Type = {
23582376
val IRTrees.Match(selector, cases, defaultBody) = tree
23592377
val selectorLocal = fctx.addSyntheticLocal(TypeTransformer.transformType(selector.tpe)(ctx))
23602378
genTreeAuto(selector)
23612379
instrs += LOCAL_SET(selectorLocal)
23622380

2363-
fctx.block(TypeTransformer.transformResultType(tree.tpe)(ctx)) { doneLabel =>
2381+
fctx.block(TypeTransformer.transformResultType(expectedType)(ctx)) { doneLabel =>
23642382
fctx.block() { defaultLabel =>
23652383
val caseLabels = cases.map(c => c._1 -> fctx.genLabel())
23662384
for (caseLabel <- caseLabels)
@@ -2390,17 +2408,17 @@ private class WasmExpressionBuilder private (
23902408

23912409
for ((caseLabel, caze) <- caseLabels.zip(cases).reverse) {
23922410
instrs += END
2393-
genTree(caze._2, tree.tpe)
2411+
genTree(caze._2, expectedType)
23942412
instrs += BR(doneLabel)
23952413
}
23962414
}
2397-
genTree(defaultBody, tree.tpe)
2415+
genTree(defaultBody, expectedType)
23982416
}
23992417

2400-
if (tree.tpe == IRTypes.NothingType)
2418+
if (expectedType == IRTypes.NothingType)
24012419
instrs += UNREACHABLE
24022420

2403-
tree.tpe
2421+
expectedType
24042422
}
24052423

24062424
private def genCreateJSClass(tree: IRTrees.CreateJSClass): IRTypes.Type = {
@@ -2809,10 +2827,10 @@ private class WasmExpressionBuilder private (
28092827
expectedType
28102828
}
28112829

2812-
def genTryFinally(t: IRTrees.TryFinally): IRTypes.Type = {
2830+
def genTryFinally(t: IRTrees.TryFinally, expectedType: IRTypes.Type): IRTypes.Type = {
28132831
val entry = new TryFinallyEntry(currentUnwindingStackDepth)
28142832

2815-
val resultType = TypeTransformer.transformResultType(t.tpe)(ctx)
2833+
val resultType = TypeTransformer.transformResultType(expectedType)(ctx)
28162834
val resultLocals = resultType.map(fctx.addSyntheticLocal(_))
28172835

28182836
fctx.block() { doneLabel =>
@@ -2825,7 +2843,7 @@ private class WasmExpressionBuilder private (
28252843
fctx.tryTable()(List(CatchClause.CatchAllRef(catchLabel))) {
28262844
// try block
28272845
enterTryFinally(entry) {
2828-
genTree(t.block, t.tpe)
2846+
genTree(t.block, expectedType)
28292847
}
28302848

28312849
// store the result in locals during the finally block
@@ -2922,10 +2940,10 @@ private class WasmExpressionBuilder private (
29222940
for (resultLocal <- resultLocals)
29232941
instrs += LOCAL_GET(resultLocal)
29242942

2925-
if (t.tpe == IRTypes.NothingType)
2943+
if (expectedType == IRTypes.NothingType)
29262944
instrs += UNREACHABLE
29272945

2928-
t.tpe
2946+
expectedType
29292947
}
29302948

29312949
private def emitBRTable(

0 commit comments

Comments
 (0)