diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 70a77c9560b2..02bdb16ae217 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -37,6 +37,7 @@ object Feature: val modularity = experimental("modularity") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") val packageObjectValues = experimental("packageObjectValues") + val multiSpreads = experimental("multiSpreads") val subCases = experimental("subCases") def experimentalAutoEnableFeatures(using Context): List[TermName] = diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 365969d2f74e..dbe1602e2d82 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -468,6 +468,11 @@ class Definitions { @tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw, MethodType(List(ThrowableType), NothingType)) + @tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread, + PolyType(TypeBounds.empty :: Nil)( + tl => MethodType(AnyType :: Nil, tl.paramRefs(0)) + )) + @tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol( ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType)) def NothingType: TypeRef = NothingClass.typeRef @@ -519,6 +524,8 @@ class Definitions { @tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray") @tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray") + @tu lazy val VarArgsBuilderModule: Symbol = requiredModule("scala.runtime.VarArgsBuilder") + def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule // The set of all wrap{X, Ref}Array methods, where X is a value type @@ -563,11 +570,12 @@ class Definitions { @tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply) @tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head) @tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop) + @tu lazy val Seq_dropRight : Symbol = SeqClass.requiredMethod(nme.dropRight) + @tu lazy val Seq_takeRight : Symbol = SeqClass.requiredMethod(nme.takeRight) @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") @tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format) @@ -2234,7 +2242,7 @@ class Definitions { /** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */ @tu lazy val syntheticCoreMethods: List[TermSymbol] = - AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod) + AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod) @tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 18873dfa83af..323c59a5711d 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -470,6 +470,7 @@ object StdNames { val doubleHash: N = "doubleHash" val dotty: N = "dotty" val drop: N = "drop" + val dropRight: N = "dropRight" val dynamics: N = "dynamics" val elem: N = "elem" val elems: N = "elems" @@ -619,6 +620,7 @@ object StdNames { val setSymbol: N = "setSymbol" val setType: N = "setType" val setTypeSignature: N = "setTypeSignature" + val spread: N = "spread" val standardInterpolator: N = "standardInterpolator" val staticClass : N = "staticClass" val staticModule : N = "staticModule" @@ -801,6 +803,7 @@ object StdNames { val takeModulo: N = "takeModulo" val takeNot: N = "takeNot" val takeOr: N = "takeOr" + val takeRight: N = "takeRight" val takeXor: N = "takeXor" val testEqual: N = "testEqual" val testGreaterOrEqualThan: N = "testGreaterOrEqualThan" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index a31152ddcc6f..3a68703e4734 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1056,17 +1056,22 @@ object Parsers { } /** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not - syntactically valid, but we need to include them here for error recovery. */ + syntactically valid, but we need to include them here for error recovery. + Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally. + */ def followingIsVararg(): Boolean = in.isIdent(nme.raw.STAR) && { val lookahead = in.LookaheadScanner() lookahead.nextToken() lookahead.token == RPAREN || lookahead.token == COMMA - && { - lookahead.nextToken() - lookahead.token == RPAREN || lookahead.token == EOF - } + && ( + in.featureEnabled(Feature.multiSpreads) + || { + lookahead.nextToken() + lookahead.token == RPAREN || lookahead.token == EOF + } + ) } /** When encountering a `:`, is that in the binding of a lambda? @@ -3347,7 +3352,9 @@ object Parsers { if (in.token == RPAREN) Nil else patterns(location) /** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’ - * | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’ + * | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns] ‘)’ + * + * -- It is checked in Typer that there are no repeated PatVar arguments. */ def argumentPatterns(): List[Tree] = inParensWithCommas(patternsOpt(Location.InPatternArgs)) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 558cbd72dd43..273879f3c3cb 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -198,6 +198,8 @@ object PatternMatcher { case object NonNullTest extends Test // scrutinee ne null case object GuardTest extends Test // scrutinee + val noLengthTest = LengthTest(0, exact = false) + // ------- Generating plans from trees ------------------------ /** A set of variabes that are known to be not null */ @@ -291,38 +293,67 @@ object PatternMatcher { /** Plan for matching the sequence in `seqSym` against sequence elements `args`. * If `exact` is true, the sequence is not permitted to have any elements following `args`. */ - def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = { - val selectors = args.indices.toList.map(idx => - ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))) - TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span, - matchArgsPlan(selectors, args, onSuccess)) - } + def matchElemsPlan(seqSym: Symbol, args: List[Tree], lengthTest: LengthTest, onSuccess: Plan) = + val selectors = args.indices.toList.map: idx => + ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))) + if lengthTest.len == 0 && lengthTest.exact == false then // redundant test + matchArgsPlan(selectors, args, onSuccess) + else + TestPlan(lengthTest, seqSym, seqSym.span, + matchArgsPlan(selectors, args, onSuccess)) /** Plan for matching the sequence in `getResult` against sequence elements - * and a possible last varargs argument `args`. + * `args`. Sequence elements may contain a varargs argument. + * Example: + * + * lst match case Seq(1, xs*, 2, 3) => ... + * + * generates code which is equivalent to: + * + * if lst != null then + * if lst.lengthCompare >= 3 then + * if lst(0) == 1 then + * val x1 = lst.drop(1) + * val xs = x1.dropRight(2) + * val x2 = lst.takeRight(2) + * if x2(0) == 2 && x2(1) == 3 then + * return[matchResult] ... */ - def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match { - case Some(VarArgPattern(arg)) => - val matchRemaining = - if (args.length == 1) { - val toSeq = ref(getResult) - .select(defn.Seq_toSeq.matchingMember(getResult.info)) - letAbstract(toSeq) { toSeqResult => - patternPlan(toSeqResult, arg, onSuccess) - } - } - else { - val dropped = ref(getResult) - .select(defn.Seq_drop.matchingMember(getResult.info)) - .appliedTo(Literal(Constant(args.length - 1))) - letAbstract(dropped) { droppedResult => - patternPlan(droppedResult, arg, onSuccess) - } - } - matchElemsPlan(getResult, args.init, exact = false, matchRemaining) - case _ => - matchElemsPlan(getResult, args, exact = true, onSuccess) - } + def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = + val (leading, varargAndRest) = args.span: + case VarArgPattern(_) => false + case _ => true + varargAndRest match + case VarArgPattern(arg) :: trailing => + val remaining = + if leading.isEmpty then + ref(getResult) + .select(defn.Seq_toSeq.matchingMember(getResult.info)) + else + ref(getResult) + .select(defn.Seq_drop.matchingMember(getResult.info)) + .appliedTo(Literal(Constant(leading.length))) + val matchRemaining = + letAbstract(remaining): remainingResult => + if trailing.isEmpty then + patternPlan(remainingResult, arg, onSuccess) + else + val seq = ref(remainingResult) + .select(defn.Seq_dropRight.matchingMember(remainingResult.info)) + .appliedTo(Literal(Constant(trailing.length))) + letAbstract(seq): seqResult => + val rest = ref(remainingResult) + .select(defn.Seq_takeRight.matchingMember(remainingResult.info)) + .appliedTo(Literal(Constant(trailing.length))) + val matchTrailing = + letAbstract(rest): trailingResult => + matchElemsPlan(trailingResult, trailing, noLengthTest, onSuccess) + patternPlan(seqResult, arg, matchTrailing) + matchElemsPlan(getResult, leading, + LengthTest(leading.length + trailing.length, exact = false), + matchRemaining) + case _ => + matchElemsPlan(getResult, args, LengthTest(args.length, exact = true), onSuccess) /** Plan for matching the sequence in `getResult` * @@ -491,7 +522,7 @@ object PatternMatcher { case WildcardPattern() | This(_) => onSuccess case SeqLiteral(pats, _) => - matchElemsPlan(scrutinee, pats, exact = true, onSuccess) + matchElemsPlan(scrutinee, pats, LengthTest(pats.length, exact = true), onSuccess) case _ => TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess) } diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 9f79c063dc03..128655debc2e 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -18,10 +18,14 @@ import config.Printers.typr import config.Feature import util.{SrcPos, Stats} import reporting.* -import NameKinds.WildcardParamName +import NameKinds.{WildcardParamName, TempResultName} +import typer.Applications.{spread, HasSpreads} +import typer.Implicits.SearchFailureType +import Constants.Constant import cc.* import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation import dotty.tools.dotc.core.NameKinds.DefaultGetterName +import ast.TreeInfo object PostTyper { val name: String = "posttyper" @@ -376,6 +380,86 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case _ => tpt + /** If one of `trees` is a spread of an expression that is not idempotent, lift out all + * non-idempotent expressions (not just the spreads) and apply `within` to the resulting + * pure references. Otherwise apply `within` to the original trees. + */ + private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree = + if trees.exists: + case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent) + case _ => false + then + val lifted = new mutable.ListBuffer[ValDef] + def liftIfImpure(tree: Tree): Tree = tree match + case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod => + cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure)) + case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent => + tree + case _ => + val vdef = SyntheticValDef(TempResultName.fresh(), tree).withSpan(tree.span) + lifted += vdef + Ident(vdef.namedType).withSpan(tree.span) + val pureTrees = trees.mapConserve(liftIfImpure) + Block(lifted.toList, within(pureTrees)) + else within(trees) + + /** Translate sequence literal containing spread operators. Example: + * + * val xs, ys: List[Int] + * [1, xs*, 2, ys*] + * + * Here the sequence literal is translated at typer to + * + * [1, spread(xs), 2, spread(ys)] + * + * This then translates to + * + * scala.runtime.VarArgsBuilder.ofInt(2 + xs.length + ys.length) + * .add(1) + * .addSeq(xs) + * .add(2) + * .addSeq(ys) + * + * The reason for doing a two-step typer/postTyper translation is that + * at typer, we don't have all type variables instantiated yet. + */ + private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree = + val SeqLiteral(rawElems, elemtpt) = tree + val elemType = elemtpt.tpe + val elemCls = elemType.classSymbol + + evalSpreadsOnce(rawElems): elems => + val lengthCalls = elems.collect: + case spread(elem) => elem.select(nme.length) + val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length)) + val totalLength = + lengthCalls.foldLeft(singleElemCount): (acc, len) => + acc.select(defn.Int_+).appliedTo(len) + + def makeBuilder(name: String) = + ref(defn.VarArgsBuilderModule).select(name.toTermName) + + val builder = + if defn.ScalaValueClasses().contains(elemCls) then + makeBuilder(s"of${elemCls.name}") + else if elemCls.derivesFrom(defn.ObjectClass) then + makeBuilder("ofRef").appliedToType(elemType) + else + makeBuilder("generic").appliedToType(elemType) + + elems.foldLeft(builder.appliedTo(totalLength)): (bldr, elem) => + elem match + case spread(arg) => + if arg.tpe.derivesFrom(defn.SeqClass) then + bldr.select("addSeq".toTermName).appliedTo(arg) + else + bldr.select("addArray".toTermName).appliedTo( + arg.ensureConforms(defn.ArrayOf(elemType))) + case _ => bldr.select("add".toTermName).appliedTo(elem) + .select("result".toTermName) + .appliedToNone + end flattenSpreads + override def transform(tree: Tree)(using Context): Tree = try tree match { // TODO move CaseDef case lower: keep most probable trees first for performance @@ -592,6 +676,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case tree: RefinedTypeTree => Checking.checkPolyFunctionType(tree) super.transform(tree) + case tree: SeqLiteral if tree.hasAttachment(HasSpreads) => + flattenSpreads(tree) case _: Quote | _: QuotePattern => ctx.compilationUnit.needsStaging = true super.transform(tree) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 2929749e3f70..c70915b0f7db 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -24,6 +24,8 @@ import Inferencing.* import reporting.* import Nullables.*, NullOpsDecorator.* import config.{Feature, MigrationVersion, SourceVersion} +import util.Property +import util.chaining.tap import collection.mutable import config.Printers.{overload, typr, unapp} @@ -42,6 +44,17 @@ import dotty.tools.dotc.inlines.Inlines object Applications { import tpd.* + /** Attachment key for SeqLiterals containing spreads. Eliminated at PostTyper */ + val HasSpreads = new Property.StickyKey[Unit] + + /** An extractor for spreads in sequence literals */ + object spread: + def apply(arg: Tree, elemtpt: Tree)(using Context) = + ref(defn.spreadMethod).appliedToTypeTree(elemtpt).appliedTo(arg) + def unapply(arg: Apply)(using Context): Option[Tree] = arg match + case Apply(fn, x :: Nil) if fn.symbol == defn.spreadMethod => Some(x) + case _ => None + def extractorMember(tp: Type, name: Name)(using Context): SingleDenotation = tp.member(name).suchThat(sym => sym.info.isParameterless && sym.info.widenExpr.isValueType) @@ -291,6 +304,11 @@ object Applications { report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), pos) argTypes.take(args.length) ++ List.fill(argTypes.length - args.length)(WildcardType) + + val varArgs = alignedArgs.filter(untpd.isWildcardStarArg) + if varArgs.length >= 2 then + report.error(em"Ony one spread pattern allowed in sequence", varArgs(1).srcPos) + alignedArgs.lazyZip(alignedArgTypes).map(typer.typed(_, _)) .showing(i"unapply patterns = $result", unapp) @@ -797,14 +815,17 @@ trait Applications extends Compatibility { addTyped(arg) case _ => val elemFormal = formal.widenExpr.argTypesLo.head - val typedArgs = - harmonic(harmonizeArgs, elemFormal) { - args.map { arg => - checkNoVarArg(arg) + val typedVarArgs = util.HashSet[TypedArg]() + val typedArgs = harmonic(harmonizeArgs, elemFormal): + args.map: arg => + if isVarArg(arg) then + if !Feature.enabled(Feature.multiSpreads) || ctx.isAfterTyper then + checkNoVarArg(arg) + typedArg(arg, formal).tap(typedVarArgs += _) + else typedArg(arg, elemFormal) - } - } - typedArgs.foreach(addArg(_, elemFormal)) + typedArgs.foreach: targ => + addArg(targ, if typedVarArgs.contains(targ) then formal else elemFormal) makeVarArg(args.length, elemFormal) } else args match { @@ -944,12 +965,18 @@ trait Applications extends Compatibility { typedArgBuf += typedArg ok = ok & !typedArg.tpe.isError - def makeVarArg(n: Int, elemFormal: Type): Unit = { + def makeVarArg(n: Int, elemFormal: Type): Unit = val args = typedArgBuf.takeRight(n).toList typedArgBuf.dropRightInPlace(n) - val elemtpt = TypeTree(elemFormal.normalizedTupleType, inferred = true) - typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt)) - } + val elemTpe = elemFormal.normalizedTupleType + val elemtpt = TypeTree(elemTpe, inferred = true) + def wrapSpread(arg: Tree): Tree = arg match + case Typed(argExpr, tpt) if tpt.tpe.isRepeatedParam => spread(argExpr, elemtpt) + case _ => arg + val args1 = args.mapConserve(wrapSpread) + val seqLit = SeqLiteral(args1, elemtpt) + if args1 ne args then seqLit.putAttachment(HasSpreads, ()) + typedArgBuf += seqToRepeated(seqLit) def harmonizeArgs(args: List[TypedArg]): List[Tree] = // harmonize args only if resType depends on parameter types @@ -2676,7 +2703,8 @@ trait Applications extends Compatibility { case ConstantType(c: Constant) if c.tag == IntTag => targetClass(ts1, cls, true) case t => - val sym = t.classSymbol + val sym = + if t.isRepeatedParam then t.argTypesLo.head.classSymbol else t.classSymbol if (!sym.isNumericValueClass || cls.exists && cls != sym) NoSymbol else targetClass(ts1, sym, intLitSeen) } diff --git a/docs/_docs/internals/syntax.md b/docs/_docs/internals/syntax.md index 7fd7ec1be2e1..686d0551e0f6 100644 --- a/docs/_docs/internals/syntax.md +++ b/docs/_docs/internals/syntax.md @@ -365,7 +365,7 @@ Patterns ::= Pattern {‘,’ Pattern} NamedPattern ::= id '=' Pattern ArgumentPatterns ::= ‘(’ [Patterns] ‘)’ Apply(fn, pats) - | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’ + | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns]‘)’ ``` ### Type and Value Parameters diff --git a/library/src/scala/language.scala b/library/src/scala/language.scala index bacbb09ad615..63941a86bd67 100644 --- a/library/src/scala/language.scala +++ b/library/src/scala/language.scala @@ -350,11 +350,15 @@ object language { @compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements") object packageObjectValues + /** Experimental support for multiple spread arguments. + */ + @compileTimeOnly("`multiSpreads` can only be used at compile time in import statements") + object multiSpreads + /** Experimental support for match expressions with sub cases. */ @compileTimeOnly("`subCases` can only be used at compile time in import statements") object subCases - } /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/library/src/scala/runtime/VarArgsBuilder.scala b/library/src/scala/runtime/VarArgsBuilder.scala new file mode 100644 index 000000000000..c9aa2b3be556 --- /dev/null +++ b/library/src/scala/runtime/VarArgsBuilder.scala @@ -0,0 +1,223 @@ +package scala.runtime + +import scala.collection.immutable.ArraySeq +import scala.reflect.ClassTag + +sealed abstract class VarArgsBuilder[T]: + def add(elem: T): this.type + def addSeq(elems: Seq[T]): this.type + def addArray(elems: Array[T]): this.type + def result(): Seq[T] + +object VarArgsBuilder: + + def generic[T](n: Int): VarArgsBuilder[T] = new VarArgsBuilder[T]: + private val xs = new Array[AnyRef](n) + def result() = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] + private var i = 0 + def add(elem: T): this.type = + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + def addSeq(elems: Seq[T]): this.type = + for elem <- elems do + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + def addArray(elems: Array[T]): this.type = + for elem <- elems do + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + + def ofRef[T <: AnyRef](n: Int): VarArgsBuilder[T] = new VarArgsBuilder[T]: + private val xs = new Array[AnyRef](n) + def result() = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] + private var i = 0 + def add(elem: T): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[T]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[T]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofByte(n: Int): VarArgsBuilder[Byte] = new VarArgsBuilder[Byte]: + private val xs = new Array[Byte](n) + def result() = ArraySeq.ofByte(xs) + private var i = 0 + def add(elem: Byte): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Byte]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Byte]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofShort(n: Int): VarArgsBuilder[Short] = new VarArgsBuilder[Short]: + private val xs = new Array[Short](n) + def result() = ArraySeq.ofShort(xs) + private var i = 0 + def add(elem: Short): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Short]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Short]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofChar(n: Int): VarArgsBuilder[Char] = new VarArgsBuilder[Char]: + private val xs = new Array[Char](n) + def result() = ArraySeq.ofChar(xs) + private var i = 0 + def add(elem: Char): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Char]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Char]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofInt(n: Int): VarArgsBuilder[Int] = new VarArgsBuilder[Int]: + private val xs = new Array[Int](n) + def result() = ArraySeq.ofInt(xs) + private var i = 0 + def add(elem: Int): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Int]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Int]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofLong(n: Int): VarArgsBuilder[Long] = new VarArgsBuilder[Long]: + private val xs = new Array[Long](n) + def result() = ArraySeq.ofLong(xs) + private var i = 0 + def add(elem: Long): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Long]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Long]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofFloat(n: Int): VarArgsBuilder[Float] = new VarArgsBuilder[Float]: + private val xs = new Array[Float](n) + def result() = ArraySeq.ofFloat(xs) + private var i = 0 + def add(elem: Float): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Float]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Float]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofDouble(n: Int): VarArgsBuilder[Double] = new VarArgsBuilder[Double]: + private val xs = new Array[Double](n) + def result() = ArraySeq.ofDouble(xs) + private var i = 0 + def add(elem: Double): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Double]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Double]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofBoolean(n: Int): VarArgsBuilder[Boolean] = new VarArgsBuilder[Boolean]: + private val xs = new Array[Boolean](n) + def result() = ArraySeq.ofBoolean(xs) + private var i = 0 + def add(elem: Boolean): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Boolean]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Boolean]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofUnit(n: Int): VarArgsBuilder[Unit] = new VarArgsBuilder[Unit]: + private val xs = new Array[Unit](n) + def result() = ArraySeq.ofUnit(xs) + private var i = 0 + def add(elem: Unit): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Unit]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Unit]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + +end VarArgsBuilder \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 9d38ea4371ff..c4da436a78d8 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -157,6 +157,11 @@ object language: @compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements") object packageObjectValues + /** Experimental support for multiple spread arguments. + */ + @compileTimeOnly("`multiSpreads` can only be used at compile time in import statements") + object multiSpreads + /** Experimental support for match expressions with sub cases. */ @compileTimeOnly("`subCases` can only be used at compile time in import statements") diff --git a/project/Build.scala b/project/Build.scala index 4edd5a2ed45a..071f887dfeec 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1171,6 +1171,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"), ) ) lazy val `scala3-library-bootstrapped`: Project = project.in(file("library")).asDottyLibrary(Bootstrapped) @@ -1309,6 +1310,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"), ) ) diff --git a/tests/neg/i11419.scala b/tests/neg/i11419.scala index 6ccc65855755..d4bc40ac4cf4 100644 --- a/tests/neg/i11419.scala +++ b/tests/neg/i11419.scala @@ -1,7 +1,7 @@ object Test { def main(args: Array[String]): Unit = { val arr: Array[String] = Array("foo") - val lst = List("x", arr: _*) // error + val lst = List("x", arr*) // error println(lst) } } diff --git a/tests/neg/spread-patterns.scala b/tests/neg/spread-patterns.scala new file mode 100644 index 000000000000..eecb3ecfeb13 --- /dev/null +++ b/tests/neg/spread-patterns.scala @@ -0,0 +1,15 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6) + val xs = List(1, 2, 3, 4, 5, 6) + + xs match + case List(1, 2, xs*, ys*, 6) => println(xs) // error + + + diff --git a/tests/run/i11419.check b/tests/run/i11419.check new file mode 100644 index 000000000000..a95d8f696118 --- /dev/null +++ b/tests/run/i11419.check @@ -0,0 +1 @@ +List(x, foo) diff --git a/tests/run/i11419.scala b/tests/run/i11419.scala new file mode 100644 index 000000000000..be22b943861f --- /dev/null +++ b/tests/run/i11419.scala @@ -0,0 +1,9 @@ +import language.experimental.multiSpreads + +object Test { + def main(args: Array[String]): Unit = { + val arr: Array[String] = Array("foo") + val lst = List("x", arr*) + println(lst) + } +} diff --git a/tests/run/spread-patterns.check b/tests/run/spread-patterns.check new file mode 100644 index 000000000000..a69afe787df5 --- /dev/null +++ b/tests/run/spread-patterns.check @@ -0,0 +1,4 @@ +List(3, 4, 5) +ArraySeq(4, 5, 6) +ArraySeq(1, 2, 3) +ArraySeq(3, 4, 5, 6) diff --git a/tests/run/spread-patterns.scala b/tests/run/spread-patterns.scala new file mode 100644 index 000000000000..9001ec0b61ab --- /dev/null +++ b/tests/run/spread-patterns.scala @@ -0,0 +1,38 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6) + val lst = List(1, 2, 3, 4, 5, 6) + + lst match + case List(1, xs*, 2, 3, 4, 5, 6) => + assert(xs.isEmpty) + + lst match + case List(1, 2, xs*, 6) => println(xs) + + arr match + case Array(1, 2, xs*, 7) => assert(false) + case Array(1, 2, 3, xs*) => println(xs) + + arr match + case Array(xs*, 1, 2) => assert(false) + case Array(xs*, 4, 5, 6) => println(xs) + + arr match + case Array(1, 2, xs*) => println(xs) + + lst match + case List(1, 2, 3, 4, 5, 6, xs*) => assert(xs.isEmpty) + + lst match + case Seq(xs*, 1, 2, 3, 4, 5, 6) => assert(xs.isEmpty) + + + + + diff --git a/tests/run/spreads-subtyping.check b/tests/run/spreads-subtyping.check new file mode 100644 index 000000000000..6e4e9402b9c0 --- /dev/null +++ b/tests/run/spreads-subtyping.check @@ -0,0 +1 @@ +ooffoobar diff --git a/tests/run/spreads-subtyping.scala b/tests/run/spreads-subtyping.scala new file mode 100644 index 000000000000..3680b2c19562 --- /dev/null +++ b/tests/run/spreads-subtyping.scala @@ -0,0 +1,7 @@ +import language.experimental.multiSpreads + +def foo(x: CharSequence*): String = x.mkString +val strings: Array[String] = Array("foo", "bar") + +@main def Test = + println(foo(("oof": CharSequence), strings*)) \ No newline at end of file diff --git a/tests/run/spreads.check b/tests/run/spreads.check new file mode 100644 index 000000000000..bee363252638 --- /dev/null +++ b/tests/run/spreads.check @@ -0,0 +1,13 @@ +ArraySeq(1, 2, 3) +ArraySeq(1, 2, 3) +ArraySeq(1, 2, 1, 2, 3) +ArraySeq(1, 2, 1, 2, 3) +ArraySeq(1, 1, 2, 3, 2) +ArraySeq(1, 1, 2, 3, 2, 1, 2, 3, 3) +ArraySeq(1, 1, 2, 3, true, A, false) +ArraySeq(1, 1, 2, 3, 2) +one +one-two-three +two +ArraySeq(1, 1, 2, 3, 2) +13.0 diff --git a/tests/run/spreads.scala b/tests/run/spreads.scala new file mode 100644 index 000000000000..654f180ae1f7 --- /dev/null +++ b/tests/run/spreads.scala @@ -0,0 +1,47 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3) + use(arr*) + + val iarr: IArray[Int] = IArray(1, 2, 3) + use(iarr*) + + val xs = List(1, 2, 3) + val ys = List("A") + val ao = Option(1.0).toList + + val x: Unit = use[Int](1, 2, xs*) + val y = use(1, 2, xs*) + use(1, xs*, 2) + use(1, xs*, 2, xs*, 3) + use(1, xs*, true, ys*, false) + use(1, identity(xs)*, 2) + + def one = { println("one"); 1 } + def two = { println("two"); 2 } + def oneTwoThree = { println("one-two-three"); xs } + use(one, oneTwoThree*, two) + //use(1.0, ao*, 2.0) + + val numbers1 = Array(1, 2, 3) + val numbers2 = List(4, 5, 6) + + def sum(xs: Int*) = xs.sum + + assert(sum(0, numbers1*, numbers2*, 4) == 25) + + // Tests for harmonization with varargs + + val darr: Array[Double] = Array(1.0, 2) + val zs1 = Array(1, darr*, 2, darr*, 3) + val _: Array[Double] = zs1 + val d: Double = 4 + val zs2 = Array(1, darr*, 2, d, 3) + val _: Array[Double] = zs2 + println(zs2.sum) +