44
55package scala .async .internal
66
7+ import scala .collection .mutable .ListBuffer
78import scala .reflect .macros .Context
89import scala .collection .mutable
910
@@ -53,14 +54,13 @@ trait AsyncAnalysis {
5354 }
5455
5556 override def traverse (tree : Tree ) {
56- def containsAwait = tree exists isAwait
5757 tree match {
58- case Try (_, _, _) if containsAwait =>
58+ case Try (_, _, _) if containsAwait(tree) =>
5959 reportUnsupportedAwait(tree, " try/catch" )
6060 super .traverse(tree)
6161 case Return (_) =>
6262 c.abort(tree.pos, " return is illegal within a async block" )
63- case DefDef (mods, _, _, _, _, _) if mods.hasFlag(Flag .LAZY ) && containsAwait =>
63+ case DefDef (mods, _, _, _, _, _) if mods.hasFlag(Flag .LAZY ) && containsAwait(tree) =>
6464 reportUnsupportedAwait(tree, " lazy val initializer" )
6565 case CaseDef (_, guard, _) if guard exists isAwait =>
6666 // TODO lift this restriction
@@ -74,9 +74,19 @@ trait AsyncAnalysis {
7474 * @return true, if the tree contained an unsupported await.
7575 */
7676 private def reportUnsupportedAwait (tree : Tree , whyUnsupported : String ): Boolean = {
77- val badAwaits : List [RefTree ] = tree collect {
78- case rt : RefTree if isAwait(rt) => rt
77+ val badAwaits = ListBuffer [Tree ]()
78+ object traverser extends Traverser {
79+ override def traverse (tree : Tree ): Unit = {
80+ if (! isAsync(tree))
81+ super .traverse(tree)
82+ tree match {
83+ case rt : RefTree if isAwait(rt) =>
84+ badAwaits += rt
85+ case _ =>
86+ }
87+ }
7988 }
89+ traverser(tree)
8090 badAwaits foreach {
8191 tree =>
8292 reportError(tree.pos, s " await must not be used under a $whyUnsupported. " )
0 commit comments