@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
1616 import c .internal ._
1717 import decorators ._
1818
19- def anfTransform (tree : Tree ): Block = {
19+ def anfTransform (tree : Tree , owner : Symbol ): Block = {
2020 // Must prepend the () for issue #31.
21- val block = c.typecheck(atPos(tree.pos)(Block (List (Literal (Constant (()))), tree))).setType(tree.tpe)
21+ val block = c.typecheck(atPos(tree.pos)(newBlock (List (Literal (Constant (()))), tree))).setType(tree.tpe)
2222
2323 sealed abstract class AnfMode
2424 case object Anf extends AnfMode
2525 case object Linearizing extends AnfMode
2626
27+ val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
28+
2729 var mode : AnfMode = Anf
28- typingTransform(block )((tree, api) => {
30+ typingTransform(tree1, owner )((tree, api) => {
2931 def blockToList (tree : Tree ): List [Tree ] = tree match {
3032 case Block (stats, expr) => stats :+ expr
3133 case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
3436 def listToBlock (trees : List [Tree ]): Block = trees match {
3537 case trees @ (init :+ last) =>
3638 val pos = trees.map(_.pos).reduceLeft(_ union _)
37- Block (init, last).setType(last.tpe).setPos(pos)
39+ newBlock (init, last).setType(last.tpe).setPos(pos)
3840 }
3941
4042 object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
6668 stats :+ valDef :+ atPos(tree.pos)(ref1)
6769
6870 case If (cond, thenp, elsep) =>
71+ // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
72+ // as though it was typed with `Unit`.
73+ def isPatMatGeneratedJump (t : Tree ): Boolean = t match {
74+ case Block (_, expr) => isPatMatGeneratedJump(expr)
75+ case If (_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
76+ case _ : Apply if isLabel(t.symbol) => true
77+ case _ => false
78+ }
79+ if (isPatMatGeneratedJump(expr)) {
80+ internal.setType(expr, definitions.UnitTpe )
81+ }
6982 // if type of if-else is Unit don't introduce assignment,
7083 // but add Unit value to bring it into form expected by async transform
7184 if (expr.tpe =:= definitions.UnitTpe ) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
7790 def branchWithAssign (orig : Tree ) = api.typecheck(atPos(orig.pos) {
7891 def cast (t : Tree ) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
7992 orig match {
80- case Block (thenStats, thenExpr) => Block (thenStats, Assign (Ident (varDef.symbol), cast(thenExpr)))
93+ case Block (thenStats, thenExpr) => newBlock (thenStats, Assign (Ident (varDef.symbol), cast(thenExpr)))
8194 case _ => Assign (Ident (varDef.symbol), cast(orig))
8295 }
8396 })
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
115128 }
116129 }
117130
118- private def defineVar (prefix : String , tp : Type , pos : Position ): ValDef = {
131+ def defineVar (prefix : String , tp : Type , pos : Position ): ValDef = {
119132 val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC ).setInfo(uncheckedBounds(tp))
120133 valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType ).setPos(pos)
121134 }
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
152165 }
153166
154167 def _transformToList (tree : Tree ): List [Tree ] = trace(tree) {
155- val containsAwait = tree exists isAwait
156- if (! containsAwait) {
168+ if (! containsAwait(tree)) {
157169 tree match {
158170 case Block (stats, expr) =>
159171 // avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
207219 funStats ++ argStatss.flatten.flatten :+ typedNewApply
208220
209221 case Block (stats, expr) =>
210- (stats :+ expr).flatMap(linearize.transformToList)
222+ val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
223+ eliminateMatchEndLabelParameter(trees)
211224
212225 case ValDef (mods, name, tpt, rhs) =>
213- if (rhs exists isAwait ) {
226+ if (containsAwait( rhs) ) {
214227 val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
215228 stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
216229 stats :+ treeCopy.ValDef (tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
247260 scrutStats :+ treeCopy.Match (tree, scrutExpr, caseDefs)
248261
249262 case LabelDef (name, params, rhs) =>
250- List (LabelDef (name, params, Block (linearize.transformToList(rhs), Literal (Constant (())))).setSymbol(tree.symbol))
263+ List (LabelDef (name, params, newBlock (linearize.transformToList(rhs), Literal (Constant (())))).setSymbol(tree.symbol))
251264
252265 case TypeApply (fun, targs) =>
253266 val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,52 @@ private[async] trait AnfTransform {
259272 }
260273 }
261274
275+ // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
276+ //
277+ // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
278+ // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
279+ //
280+ // For our purposes, it is easier to:
281+ // - extract a `matchRes` variable
282+ // - rewrite the terminal label def to take no parameters, and instead read this temp variable
283+ // - change jumps to the terminal label to an assignment and a no-arg label application
284+ def eliminateMatchEndLabelParameter (statsExpr : List [Tree ]): List [Tree ] = {
285+ import internal .{methodType , setInfo }
286+ val caseDefToMatchResult = collection.mutable.Map [Symbol , Symbol ]()
287+
288+ val matchResults = collection.mutable.Buffer [Tree ]()
289+ val statsExpr0 = statsExpr.reverseMap {
290+ case ld @ LabelDef (_, param :: Nil , body) =>
291+ val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
292+ matchResults += matchResult
293+ caseDefToMatchResult(ld.symbol) = matchResult.symbol
294+ val ld2 = treeCopy.LabelDef (ld, ld.name, Nil , body.substituteSymbols(param.symbol :: Nil , matchResult.symbol :: Nil ))
295+ setInfo(ld.symbol, methodType(Nil , ld.symbol.info.resultType))
296+ ld2
297+ case t =>
298+ if (caseDefToMatchResult.isEmpty) t
299+ else typingTransform(t)((tree, api) =>
300+ tree match {
301+ case Apply (fun, arg :: Nil ) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
302+ api.typecheck(atPos(tree.pos)(newBlock(Assign (Ident (caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil , treeCopy.Apply (tree, fun, Nil ))))
303+ case Block (stats, expr) =>
304+ api.default(tree) match {
305+ case Block (stats, Block (stats1, expr)) =>
306+ treeCopy.Block (tree, stats ::: stats1, expr)
307+ case t => t
308+ }
309+ case _ =>
310+ api.default(tree)
311+ }
312+ )
313+ }
314+ matchResults.toList match {
315+ case Nil => statsExpr
316+ case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
317+ case _ => c.error(macroPos, " Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
318+ }
319+ }
320+
262321 def anfLinearize (tree : Tree ): Block = {
263322 val trees : List [Tree ] = mode match {
264323 case Anf => anf._transformToList(tree)
0 commit comments