Skip to content
Merged
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ object Feature:
val modularity = experimental("modularity")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
val packageObjectValues = experimental("packageObjectValues")
val multiSpreads = experimental("multiSpreads")
val subCases = experimental("subCases")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
Expand Down
12 changes: 10 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ class Definitions {
@tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw,
MethodType(List(ThrowableType), NothingType))

@tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread,
PolyType(TypeBounds.empty :: Nil)(
tl => MethodType(AnyType :: Nil, tl.paramRefs(0))
))

@tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol(
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
def NothingType: TypeRef = NothingClass.typeRef
Expand Down Expand Up @@ -519,6 +524,8 @@ class Definitions {
@tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray")
@tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray")

@tu lazy val VarArgsBuilderModule: Symbol = requiredModule("scala.runtime.VarArgsBuilder")

def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule

// The set of all wrap{X, Ref}Array methods, where X is a value type
Expand Down Expand Up @@ -563,11 +570,12 @@ class Definitions {
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
@tu lazy val Seq_dropRight : Symbol = SeqClass.requiredMethod(nme.dropRight)
@tu lazy val Seq_takeRight : Symbol = SeqClass.requiredMethod(nme.takeRight)
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)

Expand Down Expand Up @@ -2234,7 +2242,7 @@ class Definitions {

/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
@tu lazy val syntheticCoreMethods: List[TermSymbol] =
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod)

@tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ object StdNames {
val doubleHash: N = "doubleHash"
val dotty: N = "dotty"
val drop: N = "drop"
val dropRight: N = "dropRight"
val dynamics: N = "dynamics"
val elem: N = "elem"
val elems: N = "elems"
Expand Down Expand Up @@ -619,6 +620,7 @@ object StdNames {
val setSymbol: N = "setSymbol"
val setType: N = "setType"
val setTypeSignature: N = "setTypeSignature"
val spread: N = "spread"
val standardInterpolator: N = "standardInterpolator"
val staticClass : N = "staticClass"
val staticModule : N = "staticModule"
Expand Down Expand Up @@ -801,6 +803,7 @@ object StdNames {
val takeModulo: N = "takeModulo"
val takeNot: N = "takeNot"
val takeOr: N = "takeOr"
val takeRight: N = "takeRight"
val takeXor: N = "takeXor"
val testEqual: N = "testEqual"
val testGreaterOrEqualThan: N = "testGreaterOrEqualThan"
Expand Down
19 changes: 13 additions & 6 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1056,17 +1056,22 @@ object Parsers {
}

/** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not
syntactically valid, but we need to include them here for error recovery. */
syntactically valid, but we need to include them here for error recovery.
Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally.
*/
def followingIsVararg(): Boolean =
in.isIdent(nme.raw.STAR) && {
val lookahead = in.LookaheadScanner()
lookahead.nextToken()
lookahead.token == RPAREN
|| lookahead.token == COMMA
&& {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
&& (
in.featureEnabled(Feature.multiSpreads)
|| {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
)
}

/** When encountering a `:`, is that in the binding of a lambda?
Expand Down Expand Up @@ -3347,7 +3352,9 @@ object Parsers {
if (in.token == RPAREN) Nil else patterns(location)

/** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns] ‘)’
*
* -- It is checked in Typer that there are no repeated PatVar arguments.
*/
def argumentPatterns(): List[Tree] =
inParensWithCommas(patternsOpt(Location.InPatternArgs))
Expand Down
91 changes: 61 additions & 30 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ object PatternMatcher {
case object NonNullTest extends Test // scrutinee ne null
case object GuardTest extends Test // scrutinee

val noLengthTest = LengthTest(0, exact = false)

// ------- Generating plans from trees ------------------------

/** A set of variabes that are known to be not null */
Expand Down Expand Up @@ -291,38 +293,67 @@ object PatternMatcher {
/** Plan for matching the sequence in `seqSym` against sequence elements `args`.
* If `exact` is true, the sequence is not permitted to have any elements following `args`.
*/
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
val selectors = args.indices.toList.map(idx =>
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span,
matchArgsPlan(selectors, args, onSuccess))
}
def matchElemsPlan(seqSym: Symbol, args: List[Tree], lengthTest: LengthTest, onSuccess: Plan) =
val selectors = args.indices.toList.map: idx =>
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))
if lengthTest.len == 0 && lengthTest.exact == false then // redundant test
matchArgsPlan(selectors, args, onSuccess)
else
TestPlan(lengthTest, seqSym, seqSym.span,
matchArgsPlan(selectors, args, onSuccess))

/** Plan for matching the sequence in `getResult` against sequence elements
* and a possible last varargs argument `args`.
* `args`. Sequence elements may contain a varargs argument.
* Example:
*
* lst match case Seq(1, xs*, 2, 3) => ...
*
* generates code which is equivalent to:
*
* if lst != null then
* if lst.lengthCompare >= 3 then
* if lst(0) == 1 then
* val x1 = lst.drop(1)
* val xs = x1.dropRight(2)
* val x2 = lst.takeRight(2)
* if x2(0) == 2 && x2(1) == 3 then
* return[matchResult] ...
*/
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
case Some(VarArgPattern(arg)) =>
val matchRemaining =
if (args.length == 1) {
val toSeq = ref(getResult)
.select(defn.Seq_toSeq.matchingMember(getResult.info))
letAbstract(toSeq) { toSeqResult =>
patternPlan(toSeqResult, arg, onSuccess)
}
}
else {
val dropped = ref(getResult)
.select(defn.Seq_drop.matchingMember(getResult.info))
.appliedTo(Literal(Constant(args.length - 1)))
letAbstract(dropped) { droppedResult =>
patternPlan(droppedResult, arg, onSuccess)
}
}
matchElemsPlan(getResult, args.init, exact = false, matchRemaining)
case _ =>
matchElemsPlan(getResult, args, exact = true, onSuccess)
}
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan =
val (leading, varargAndRest) = args.span:
case VarArgPattern(_) => false
case _ => true
varargAndRest match
case VarArgPattern(arg) :: trailing =>
val remaining =
if leading.isEmpty then
ref(getResult)
.select(defn.Seq_toSeq.matchingMember(getResult.info))
else
ref(getResult)
.select(defn.Seq_drop.matchingMember(getResult.info))
.appliedTo(Literal(Constant(leading.length)))
val matchRemaining =
letAbstract(remaining): remainingResult =>
if trailing.isEmpty then
patternPlan(remainingResult, arg, onSuccess)
else
val seq = ref(remainingResult)
.select(defn.Seq_dropRight.matchingMember(remainingResult.info))
.appliedTo(Literal(Constant(trailing.length)))
letAbstract(seq): seqResult =>
val rest = ref(remainingResult)
.select(defn.Seq_takeRight.matchingMember(remainingResult.info))
.appliedTo(Literal(Constant(trailing.length)))
val matchTrailing =
letAbstract(rest): trailingResult =>
matchElemsPlan(trailingResult, trailing, noLengthTest, onSuccess)
patternPlan(seqResult, arg, matchTrailing)
matchElemsPlan(getResult, leading,
LengthTest(leading.length + trailing.length, exact = false),
matchRemaining)
case _ =>
matchElemsPlan(getResult, args, LengthTest(args.length, exact = true), onSuccess)

/** Plan for matching the sequence in `getResult`
*
Expand Down Expand Up @@ -491,7 +522,7 @@ object PatternMatcher {
case WildcardPattern() | This(_) =>
onSuccess
case SeqLiteral(pats, _) =>
matchElemsPlan(scrutinee, pats, exact = true, onSuccess)
matchElemsPlan(scrutinee, pats, LengthTest(pats.length, exact = true), onSuccess)
case _ =>
TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess)
}
Expand Down
88 changes: 87 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ import config.Printers.typr
import config.Feature
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import NameKinds.{WildcardParamName, TempResultName}
import typer.Applications.{spread, HasSpreads}
import typer.Implicits.SearchFailureType
import Constants.Constant
import cc.*
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
import ast.TreeInfo

object PostTyper {
val name: String = "posttyper"
Expand Down Expand Up @@ -376,6 +380,86 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
tpt

/** If one of `trees` is a spread of an expression that is not idempotent, lift out all
* non-idempotent expressions (not just the spreads) and apply `within` to the resulting
* pure references. Otherwise apply `within` to the original trees.
*/
private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree =
if trees.exists:
case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent)
case _ => false
then
val lifted = new mutable.ListBuffer[ValDef]
def liftIfImpure(tree: Tree): Tree = tree match
case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod =>
cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure))
case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent =>
tree
case _ =>
val vdef = SyntheticValDef(TempResultName.fresh(), tree).withSpan(tree.span)
lifted += vdef
Ident(vdef.namedType).withSpan(tree.span)
val pureTrees = trees.mapConserve(liftIfImpure)
Block(lifted.toList, within(pureTrees))
else within(trees)

/** Translate sequence literal containing spread operators. Example:
*
* val xs, ys: List[Int]
* [1, xs*, 2, ys*]
*
* Here the sequence literal is translated at typer to
*
* [1, spread(xs), 2, spread(ys)]
*
* This then translates to
*
* scala.runtime.VarArgsBuilder.ofInt(2 + xs.length + ys.length)
* .add(1)
* .addSeq(xs)
* .add(2)
* .addSeq(ys)
*
* The reason for doing a two-step typer/postTyper translation is that
* at typer, we don't have all type variables instantiated yet.
*/
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
val SeqLiteral(rawElems, elemtpt) = tree
val elemType = elemtpt.tpe
val elemCls = elemType.classSymbol

evalSpreadsOnce(rawElems): elems =>
val lengthCalls = elems.collect:
case spread(elem) => elem.select(nme.length)
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
val totalLength =
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
acc.select(defn.Int_+).appliedTo(len)

def makeBuilder(name: String) =
ref(defn.VarArgsBuilderModule).select(name.toTermName)

val builder =
if defn.ScalaValueClasses().contains(elemCls) then
makeBuilder(s"of${elemCls.name}")
else if elemCls.derivesFrom(defn.ObjectClass) then
makeBuilder("ofRef").appliedToType(elemType)
else
makeBuilder("generic").appliedToType(elemType)

elems.foldLeft(builder.appliedTo(totalLength)): (bldr, elem) =>
elem match
case spread(arg) =>
if arg.tpe.derivesFrom(defn.SeqClass) then
bldr.select("addSeq".toTermName).appliedTo(arg)
else
bldr.select("addArray".toTermName).appliedTo(
arg.ensureConforms(defn.ArrayOf(elemType)))
case _ => bldr.select("add".toTermName).appliedTo(elem)
.select("result".toTermName)
.appliedToNone
end flattenSpreads

override def transform(tree: Tree)(using Context): Tree =
try tree match {
// TODO move CaseDef case lower: keep most probable trees first for performance
Expand Down Expand Up @@ -592,6 +676,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case tree: SeqLiteral if tree.hasAttachment(HasSpreads) =>
flattenSpreads(tree)
case _: Quote | _: QuotePattern =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
Loading
Loading