From 0bc9f43019d0fc3ceb1a3b1c06a142dece1de856 Mon Sep 17 00:00:00 2001 From: odersky Date: Sun, 31 Aug 2025 14:23:50 +0200 Subject: [PATCH 01/11] Allow multiple spreads in function arguments --- .../src/dotty/tools/dotc/config/Feature.scala | 1 + .../dotty/tools/dotc/core/Definitions.scala | 9 +- .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotty/tools/dotc/parsing/Parsers.scala | 15 +- .../tools/dotc/transform/PostTyper.scala | 72 ++++++ .../dotty/tools/dotc/typer/Applications.scala | 43 +++- library/src/scala/compiletime/Spread.scala | 0 library/src/scala/language.scala | 6 +- .../src/scala/runtime/ArraySeqBuilder.scala | 223 ++++++++++++++++++ .../runtime/stdLibPatches/language.scala | 5 + project/Build.scala | 2 + tests/neg/i11419.scala | 2 +- tests/run/i11419.check | 1 + tests/run/i11419.scala | 9 + tests/run/spreads.scala | 21 ++ 15 files changed, 392 insertions(+), 18 deletions(-) create mode 100644 library/src/scala/compiletime/Spread.scala create mode 100644 library/src/scala/runtime/ArraySeqBuilder.scala create mode 100644 tests/run/i11419.check create mode 100644 tests/run/i11419.scala create mode 100644 tests/run/spreads.scala diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 70a77c9560b2..02bdb16ae217 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -37,6 +37,7 @@ object Feature: val modularity = experimental("modularity") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") val packageObjectValues = experimental("packageObjectValues") + val multiSpreads = experimental("multiSpreads") val subCases = experimental("subCases") def experimentalAutoEnableFeatures(using Context): List[TermName] = diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 365969d2f74e..1c9bb9edf593 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -468,6 +468,11 @@ class Definitions { @tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw, MethodType(List(ThrowableType), NothingType)) + @tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread, + PolyType(TypeBounds.empty :: Nil)( + tl => MethodType(AnyType :: Nil, tl.paramRefs(0)) + )) + @tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol( ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType)) def NothingType: TypeRef = NothingClass.typeRef @@ -519,6 +524,8 @@ class Definitions { @tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray") @tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray") + @tu lazy val ArraySeqBuilderModule: Symbol = requiredModule("scala.runtime.ArraySeqBuilder") + def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule // The set of all wrap{X, Ref}Array methods, where X is a value type @@ -2234,7 +2241,7 @@ class Definitions { /** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */ @tu lazy val syntheticCoreMethods: List[TermSymbol] = - AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod) + AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod) @tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 18873dfa83af..23c60d19fe0e 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -619,6 +619,7 @@ object StdNames { val setSymbol: N = "setSymbol" val setType: N = "setType" val setTypeSignature: N = "setTypeSignature" + val spread: N = "spread" val standardInterpolator: N = "standardInterpolator" val staticClass : N = "staticClass" val staticModule : N = "staticModule" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index a31152ddcc6f..e5e1248eb396 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1056,17 +1056,22 @@ object Parsers { } /** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not - syntactically valid, but we need to include them here for error recovery. */ + syntactically valid, but we need to include them here for error recovery. + Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally. + */ def followingIsVararg(): Boolean = in.isIdent(nme.raw.STAR) && { val lookahead = in.LookaheadScanner() lookahead.nextToken() lookahead.token == RPAREN || lookahead.token == COMMA - && { - lookahead.nextToken() - lookahead.token == RPAREN || lookahead.token == EOF - } + && ( + in.featureEnabled(Feature.multiSpreads) + || { + lookahead.nextToken() + lookahead.token == RPAREN || lookahead.token == EOF + } + ) } /** When encountering a `:`, is that in the binding of a lambda? diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 9f79c063dc03..377088c54314 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -19,6 +19,9 @@ import config.Feature import util.{SrcPos, Stats} import reporting.* import NameKinds.WildcardParamName +import typer.Applications.{spread, HasSpreads} +import typer.Implicits.SearchFailureType +import Constants.Constant import cc.* import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation import dotty.tools.dotc.core.NameKinds.DefaultGetterName @@ -376,6 +379,73 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case _ => tpt + /** Translate sequence literal containing spread operators. Example: + * + * val xs, ys: List[Int] + * [1, xs*, 2, ys*] + * + * Here the sequence literal is translated at typer to + * + * [1, spread(xs), 2, spread(ys)] + * + * This then translates to + * + * scala.runtime.ArraySeqBuilcder.ofInt(2 + xs.length + ys.length) + * .add(1) + * .addSeq(xs) + * .add(2) + * .addSeq(ys) + * + * The reason for doing a two-step typer/postTyper translation is that + * at typer, we don't have all type variables instantiated yet. + */ + private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree = + val SeqLiteral(elems, elemtpt) = tree + val elemType = elemtpt.tpe + val elemCls = elemType.classSymbol + + val lengthCalls = elems.collect: + case spread(elem) => elem.select(nme.length) + val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length)) + val totalLength = + lengthCalls.foldLeft(singleElemCount): (acc, len) => + acc.select(defn.Int_+).appliedTo(len) + + def makeBuilder(name: String) = + ref(defn.ArraySeqBuilderModule).select(name.toTermName) + def genericBuilder = makeBuilder("generic") + .appliedToType(elemType) + .appliedTo(totalLength) + + val builder = + if defn.ScalaValueClasses().contains(elemCls) then + makeBuilder(s"of${elemCls.name}").appliedTo(totalLength) + else if elemCls.derivesFrom(defn.ObjectClass) then + val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType) + val classTag = atPhase(Phases.typerPhase): + ctx.typer.inferImplicitArg(classTagType, tree.span.startPos) + classTag.tpe match + case _: SearchFailureType => + genericBuilder + case _ => + makeBuilder("ofRef") + .appliedToType(elemType) + .appliedTo(totalLength) + .appliedTo(classTag) + else + genericBuilder + + elems.foldLeft(builder): (bldr, elem) => + elem match + case spread(arg) => + val selector = + if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq" + else "addArray" + bldr.select(selector.toTermName).appliedTo(arg) + case _ => bldr.select("add".toTermName).appliedTo(elem) + .select("result".toTermName) + end flattenSpreads + override def transform(tree: Tree)(using Context): Tree = try tree match { // TODO move CaseDef case lower: keep most probable trees first for performance @@ -592,6 +662,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case tree: RefinedTypeTree => Checking.checkPolyFunctionType(tree) super.transform(tree) + case tree: SeqLiteral if tree.hasAttachment(HasSpreads) => + flattenSpreads(tree) case _: Quote | _: QuotePattern => ctx.compilationUnit.needsStaging = true super.transform(tree) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 2929749e3f70..8e13c91a8d16 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -24,6 +24,7 @@ import Inferencing.* import reporting.* import Nullables.*, NullOpsDecorator.* import config.{Feature, MigrationVersion, SourceVersion} +import util.Property import collection.mutable import config.Printers.{overload, typr, unapp} @@ -42,6 +43,17 @@ import dotty.tools.dotc.inlines.Inlines object Applications { import tpd.* + /** Attachment key for SeqLiterals containing spreads. Eliminated at PostTyper */ + val HasSpreads = new Property.StickyKey[Unit] + + /** An extractor for spreads in sequence literals */ + object spread: + def apply(arg: Tree, elemtpt: Tree)(using Context) = + ref(defn.spreadMethod).appliedToTypeTree(elemtpt).appliedTo(arg) + def unapply(arg: Apply)(using Context): Option[Tree] = arg match + case Apply(fn, x :: Nil) if fn.symbol == defn.spreadMethod => Some(x) + case _ => None + def extractorMember(tp: Type, name: Name)(using Context): SingleDenotation = tp.member(name).suchThat(sym => sym.info.isParameterless && sym.info.widenExpr.isValueType) @@ -797,14 +809,19 @@ trait Applications extends Compatibility { addTyped(arg) case _ => val elemFormal = formal.widenExpr.argTypesLo.head - val typedArgs = - harmonic(harmonizeArgs, elemFormal) { - args.map { arg => + if Feature.enabled(Feature.multiSpreads) + && !ctx.isAfterTyper && args.exists(isVarArg) + then + args.foreach: arg => + if isVarArg(arg) + then addArg(typedArg(arg, formal), formal) + else addArg(typedArg(arg, elemFormal), elemFormal) + else + val typedArgs = harmonic(harmonizeArgs, elemFormal): + args.map: arg => checkNoVarArg(arg) typedArg(arg, elemFormal) - } - } - typedArgs.foreach(addArg(_, elemFormal)) + typedArgs.foreach(addArg(_, elemFormal)) makeVarArg(args.length, elemFormal) } else args match { @@ -944,12 +961,18 @@ trait Applications extends Compatibility { typedArgBuf += typedArg ok = ok & !typedArg.tpe.isError - def makeVarArg(n: Int, elemFormal: Type): Unit = { + def makeVarArg(n: Int, elemFormal: Type): Unit = val args = typedArgBuf.takeRight(n).toList typedArgBuf.dropRightInPlace(n) - val elemtpt = TypeTree(elemFormal.normalizedTupleType, inferred = true) - typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt)) - } + val elemTpe = elemFormal.normalizedTupleType + val elemtpt = TypeTree(elemTpe, inferred = true) + def wrapSpread(arg: Tree): Tree = arg match + case Typed(argExpr, tpt) if tpt.tpe.isRepeatedParam => spread(argExpr, elemtpt) + case _ => arg + val args1 = args.mapConserve(wrapSpread) + val seqLit = SeqLiteral(args1, elemtpt) + if args1 ne args then seqLit.putAttachment(HasSpreads, ()) + typedArgBuf += seqToRepeated(seqLit) def harmonizeArgs(args: List[TypedArg]): List[Tree] = // harmonize args only if resType depends on parameter types diff --git a/library/src/scala/compiletime/Spread.scala b/library/src/scala/compiletime/Spread.scala new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/library/src/scala/language.scala b/library/src/scala/language.scala index bacbb09ad615..63941a86bd67 100644 --- a/library/src/scala/language.scala +++ b/library/src/scala/language.scala @@ -350,11 +350,15 @@ object language { @compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements") object packageObjectValues + /** Experimental support for multiple spread arguments. + */ + @compileTimeOnly("`multiSpreads` can only be used at compile time in import statements") + object multiSpreads + /** Experimental support for match expressions with sub cases. */ @compileTimeOnly("`subCases` can only be used at compile time in import statements") object subCases - } /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/library/src/scala/runtime/ArraySeqBuilder.scala b/library/src/scala/runtime/ArraySeqBuilder.scala new file mode 100644 index 000000000000..399912c07269 --- /dev/null +++ b/library/src/scala/runtime/ArraySeqBuilder.scala @@ -0,0 +1,223 @@ +package scala.runtime + +import scala.collection.immutable.ArraySeq +import scala.reflect.ClassTag + +sealed abstract class ArraySeqBuilder[T]: + def add(elem: T): this.type + def addSeq(elems: Seq[T]): this.type + def addArray(elems: Array[T]): this.type + def result: ArraySeq[T] + +object ArraySeqBuilder: + + def generic[T](n: Int) = new ArraySeqBuilder[T]: + private val xs = new Array[AnyRef](n) + def result = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] + private var i = 0 + def add(elem: T): this.type = + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + def addSeq(elems: Seq[T]): this.type = + for elem <- elems do + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + def addArray(elems: Array[T]): this.type = + for elem <- elems do + xs(i) = elem.asInstanceOf[AnyRef] + i += 1 + this + + def ofRef[T <: AnyRef](n: Int)(using ClassTag[T]) = new ArraySeqBuilder[T]: + private val xs = new Array[T](n) + def result = ArraySeq.ofRef(xs) + private var i = 0 + def add(elem: T): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[T]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[T]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofByte(n: Int) = new ArraySeqBuilder[Byte]: + private val xs = new Array[Byte](n) + def result = ArraySeq.ofByte(xs) + private var i = 0 + def add(elem: Byte): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Byte]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Byte]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofShort(n: Int) = new ArraySeqBuilder[Short]: + private val xs = new Array[Short](n) + def result = ArraySeq.ofShort(xs) + private var i = 0 + def add(elem: Short): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Short]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Short]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofChar(n: Int) = new ArraySeqBuilder[Char]: + private val xs = new Array[Char](n) + def result = ArraySeq.ofChar(xs) + private var i = 0 + def add(elem: Char): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Char]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Char]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofInt(n: Int) = new ArraySeqBuilder[Int]: + private val xs = new Array[Int](n) + def result = ArraySeq.ofInt(xs) + private var i = 0 + def add(elem: Int): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Int]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Int]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofLong(n: Int) = new ArraySeqBuilder[Long]: + private val xs = new Array[Long](n) + def result = ArraySeq.ofLong(xs) + private var i = 0 + def add(elem: Long): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Long]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Long]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofFloat(n: Int) = new ArraySeqBuilder[Float]: + private val xs = new Array[Float](n) + def result = ArraySeq.ofFloat(xs) + private var i = 0 + def add(elem: Float): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Float]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Float]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofDouble(n: Int) = new ArraySeqBuilder[Double]: + private val xs = new Array[Double](n) + def result = ArraySeq.ofDouble(xs) + private var i = 0 + def add(elem: Double): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Double]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Double]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofBoolean(n: Int) = new ArraySeqBuilder[Boolean]: + private val xs = new Array[Boolean](n) + def result = ArraySeq.ofBoolean(xs) + private var i = 0 + def add(elem: Boolean): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Boolean]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Boolean]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + + def ofUnit(n: Int) = new ArraySeqBuilder[Unit]: + private val xs = new Array[Unit](n) + def result = ArraySeq.ofUnit(xs) + private var i = 0 + def add(elem: Unit): this.type = + xs(i) = elem + i += 1 + this + def addSeq(elems: Seq[Unit]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + def addArray(elems: Array[Unit]): this.type = + for elem <- elems do + xs(i) = elem + i += 1 + this + +end ArraySeqBuilder \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 9d38ea4371ff..c4da436a78d8 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -157,6 +157,11 @@ object language: @compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements") object packageObjectValues + /** Experimental support for multiple spread arguments. + */ + @compileTimeOnly("`multiSpreads` can only be used at compile time in import statements") + object multiSpreads + /** Experimental support for match expressions with sub cases. */ @compileTimeOnly("`subCases` can only be used at compile time in import statements") diff --git a/project/Build.scala b/project/Build.scala index 4edd5a2ed45a..8045f58d468b 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1171,6 +1171,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/ArraySeqBuilder.scala"), ) ) lazy val `scala3-library-bootstrapped`: Project = project.in(file("library")).asDottyLibrary(Bootstrapped) @@ -1309,6 +1310,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/ArraySeqBuilder.scala"), ) ) diff --git a/tests/neg/i11419.scala b/tests/neg/i11419.scala index 6ccc65855755..d4bc40ac4cf4 100644 --- a/tests/neg/i11419.scala +++ b/tests/neg/i11419.scala @@ -1,7 +1,7 @@ object Test { def main(args: Array[String]): Unit = { val arr: Array[String] = Array("foo") - val lst = List("x", arr: _*) // error + val lst = List("x", arr*) // error println(lst) } } diff --git a/tests/run/i11419.check b/tests/run/i11419.check new file mode 100644 index 000000000000..a95d8f696118 --- /dev/null +++ b/tests/run/i11419.check @@ -0,0 +1 @@ +List(x, foo) diff --git a/tests/run/i11419.scala b/tests/run/i11419.scala new file mode 100644 index 000000000000..1f0a0591dbaf --- /dev/null +++ b/tests/run/i11419.scala @@ -0,0 +1,9 @@ +import language.experimental.multiSpreads + +object Test { + def main(args: Array[String]): Unit = { + val arr: Array[String] = Array("foo") + val lst = List("x", arr*) // error + println(lst) + } +} diff --git a/tests/run/spreads.scala b/tests/run/spreads.scala new file mode 100644 index 000000000000..c879da923f04 --- /dev/null +++ b/tests/run/spreads.scala @@ -0,0 +1,21 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3) + use(arr*) + + val iarr: IArray[Int] = IArray(1, 2, 3) + use(iarr*) + + val xs = List(1, 2, 3) + val ys = List("A") + + val x: Unit = use[Int](1, 2, xs*) + val y = use(1, 2, xs*) + use(1, xs*, 2) + use(1, xs*, 2, xs*, 3) + use(1, xs*, true, ys*, false) From 1f333bc4716b817d2bbedc7e0cf56785e89ef3c3 Mon Sep 17 00:00:00 2001 From: odersky Date: Tue, 2 Sep 2025 14:14:41 +0200 Subject: [PATCH 02/11] Make sure spreads are evaluated only once We need to access them twice because we first need to take their length, then append them to the buffer. If a spread might have side effects, lift all side-effecting arguments out in the order of occurrence. --- .../tools/dotc/transform/PostTyper.scala | 105 +++++++++++------- tests/run/spreads.check | 12 ++ tests/run/spreads.scala | 11 ++ 3 files changed, 86 insertions(+), 42 deletions(-) create mode 100644 tests/run/spreads.check diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 377088c54314..17172e9d58b1 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -18,13 +18,14 @@ import config.Printers.typr import config.Feature import util.{SrcPos, Stats} import reporting.* -import NameKinds.WildcardParamName +import NameKinds.{WildcardParamName, TempResultName} import typer.Applications.{spread, HasSpreads} import typer.Implicits.SearchFailureType import Constants.Constant import cc.* import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation import dotty.tools.dotc.core.NameKinds.DefaultGetterName +import ast.TreeInfo object PostTyper { val name: String = "posttyper" @@ -379,6 +380,25 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case _ => tpt + private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree = + if trees.exists: + case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent) + case _ => false + then + val lifted = new mutable.ListBuffer[ValDef] + def liftIfImpure(tree: Tree): Tree = tree match + case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod => + cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure)) + case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent => + tree + case _ => + val vdef = SyntheticValDef(TempResultName.fresh(), tree) + lifted += vdef + Ident(vdef.namedType) + val pureTrees = trees.mapConserve(liftIfImpure) + Block(lifted.toList, within(pureTrees)) + else within(trees) + /** Translate sequence literal containing spread operators. Example: * * val xs, ys: List[Int] @@ -400,50 +420,51 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => * at typer, we don't have all type variables instantiated yet. */ private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree = - val SeqLiteral(elems, elemtpt) = tree + val SeqLiteral(rawElems, elemtpt) = tree val elemType = elemtpt.tpe val elemCls = elemType.classSymbol - val lengthCalls = elems.collect: - case spread(elem) => elem.select(nme.length) - val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length)) - val totalLength = - lengthCalls.foldLeft(singleElemCount): (acc, len) => - acc.select(defn.Int_+).appliedTo(len) - - def makeBuilder(name: String) = - ref(defn.ArraySeqBuilderModule).select(name.toTermName) - def genericBuilder = makeBuilder("generic") - .appliedToType(elemType) - .appliedTo(totalLength) - - val builder = - if defn.ScalaValueClasses().contains(elemCls) then - makeBuilder(s"of${elemCls.name}").appliedTo(totalLength) - else if elemCls.derivesFrom(defn.ObjectClass) then - val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType) - val classTag = atPhase(Phases.typerPhase): - ctx.typer.inferImplicitArg(classTagType, tree.span.startPos) - classTag.tpe match - case _: SearchFailureType => - genericBuilder - case _ => - makeBuilder("ofRef") - .appliedToType(elemType) - .appliedTo(totalLength) - .appliedTo(classTag) - else - genericBuilder - - elems.foldLeft(builder): (bldr, elem) => - elem match - case spread(arg) => - val selector = - if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq" - else "addArray" - bldr.select(selector.toTermName).appliedTo(arg) - case _ => bldr.select("add".toTermName).appliedTo(elem) - .select("result".toTermName) + evalSpreadsOnce(rawElems): elems => + val lengthCalls = elems.collect: + case spread(elem) => elem.select(nme.length) + val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length)) + val totalLength = + lengthCalls.foldLeft(singleElemCount): (acc, len) => + acc.select(defn.Int_+).appliedTo(len) + + def makeBuilder(name: String) = + ref(defn.ArraySeqBuilderModule).select(name.toTermName) + def genericBuilder = makeBuilder("generic") + .appliedToType(elemType) + .appliedTo(totalLength) + + val builder = + if defn.ScalaValueClasses().contains(elemCls) then + makeBuilder(s"of${elemCls.name}").appliedTo(totalLength) + else if elemCls.derivesFrom(defn.ObjectClass) then + val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType) + val classTag = atPhase(Phases.typerPhase): + ctx.typer.inferImplicitArg(classTagType, tree.span.startPos) + classTag.tpe match + case _: SearchFailureType => + genericBuilder + case _ => + makeBuilder("ofRef") + .appliedToType(elemType) + .appliedTo(totalLength) + .appliedTo(classTag) + else + genericBuilder + + elems.foldLeft(builder): (bldr, elem) => + elem match + case spread(arg) => + val selector = + if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq" + else "addArray" + bldr.select(selector.toTermName).appliedTo(arg) + case _ => bldr.select("add".toTermName).appliedTo(elem) + .select("result".toTermName) end flattenSpreads override def transform(tree: Tree)(using Context): Tree = diff --git a/tests/run/spreads.check b/tests/run/spreads.check new file mode 100644 index 000000000000..48bf0d8e1f32 --- /dev/null +++ b/tests/run/spreads.check @@ -0,0 +1,12 @@ +ArraySeq(1, 2, 3) +ArraySeq(1, 2, 3) +ArraySeq(1, 2, 1, 2, 3) +ArraySeq(1, 2, 1, 2, 3) +ArraySeq(1, 1, 2, 3, 2) +ArraySeq(1, 1, 2, 3, 2, 1, 2, 3, 3) +ArraySeq(1, 1, 2, 3, true, A, false) +ArraySeq(1, 1, 2, 3, 2) +one +one-two-three +two +ArraySeq(1, 1, 2, 3, 2) diff --git a/tests/run/spreads.scala b/tests/run/spreads.scala index c879da923f04..116f08d76f9f 100644 --- a/tests/run/spreads.scala +++ b/tests/run/spreads.scala @@ -13,9 +13,20 @@ def useInt(xs: Int*) = ??? val xs = List(1, 2, 3) val ys = List("A") + val ao = Option(1.0).toList val x: Unit = use[Int](1, 2, xs*) val y = use(1, 2, xs*) use(1, xs*, 2) use(1, xs*, 2, xs*, 3) use(1, xs*, true, ys*, false) + use(1, identity(xs)*, 2) + + def one = { println("one"); 1 } + def two = { println("two"); 2 } + def oneTwoThree = { println("one-two-three"); xs } + use(one, oneTwoThree*, two) + //use(1.0, ao*, 2.0) + + + From 98cf146f0d2f0cd03a13ae671ae5d0bf0e1549f7 Mon Sep 17 00:00:00 2001 From: odersky Date: Tue, 2 Sep 2025 16:42:25 +0200 Subject: [PATCH 03/11] Fix test-pickling problem --- compiler/src/dotty/tools/dotc/transform/PostTyper.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 17172e9d58b1..06bd1b34b5ff 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -392,9 +392,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent => tree case _ => - val vdef = SyntheticValDef(TempResultName.fresh(), tree) + val vdef = SyntheticValDef(TempResultName.fresh(), tree).withSpan(tree.span) lifted += vdef - Ident(vdef.namedType) + Ident(vdef.namedType).withSpan(tree.span) val pureTrees = trees.mapConserve(liftIfImpure) Block(lifted.toList, within(pureTrees)) else within(trees) From 20ae6c144a3f88015b7074c3c57406cbf59dccf4 Mon Sep 17 00:00:00 2001 From: odersky Date: Wed, 3 Sep 2025 09:28:41 +0200 Subject: [PATCH 04/11] Add comment --- compiler/src/dotty/tools/dotc/transform/PostTyper.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 06bd1b34b5ff..c8e2de1d321c 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -380,6 +380,10 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case _ => tpt + /** If one of `trees` is a spread of an expression that is not idempotent, lift out all + * non-idempotent expressions (not just the spreads) and apply `within` to the resulting + * pure references. Otherwise apply `within` to the original trees. + */ private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree = if trees.exists: case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent) From ba00c03a2a576b42be39fab03cee5b099e42f163 Mon Sep 17 00:00:00 2001 From: odersky Date: Thu, 18 Sep 2025 23:06:44 +0200 Subject: [PATCH 05/11] Implement spreads in the middle of pattern sequences --- .../dotty/tools/dotc/core/Definitions.scala | 3 +- .../src/dotty/tools/dotc/core/StdNames.scala | 2 + .../dotty/tools/dotc/parsing/Parsers.scala | 4 +- .../tools/dotc/transform/PatternMatcher.scala | 73 +++++++++++++------ .../dotty/tools/dotc/typer/Applications.scala | 5 ++ docs/_docs/internals/syntax.md | 2 +- tests/neg/spread-patterns.scala | 15 ++++ tests/run/spread-patterns.check | 4 + tests/run/spread-patterns.scala | 38 ++++++++++ 9 files changed, 120 insertions(+), 26 deletions(-) create mode 100644 tests/neg/spread-patterns.scala create mode 100644 tests/run/spread-patterns.check create mode 100644 tests/run/spread-patterns.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1c9bb9edf593..4b8e42a33bc0 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -570,11 +570,12 @@ class Definitions { @tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply) @tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head) @tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop) + @tu lazy val Seq_dropRight : Symbol = SeqClass.requiredMethod(nme.dropRight) + @tu lazy val Seq_takeRight : Symbol = SeqClass.requiredMethod(nme.takeRight) @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) - @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") @tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 23c60d19fe0e..323c59a5711d 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -470,6 +470,7 @@ object StdNames { val doubleHash: N = "doubleHash" val dotty: N = "dotty" val drop: N = "drop" + val dropRight: N = "dropRight" val dynamics: N = "dynamics" val elem: N = "elem" val elems: N = "elems" @@ -802,6 +803,7 @@ object StdNames { val takeModulo: N = "takeModulo" val takeNot: N = "takeNot" val takeOr: N = "takeOr" + val takeRight: N = "takeRight" val takeXor: N = "takeXor" val testEqual: N = "testEqual" val testGreaterOrEqualThan: N = "testGreaterOrEqualThan" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index e5e1248eb396..3a68703e4734 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -3352,7 +3352,9 @@ object Parsers { if (in.token == RPAREN) Nil else patterns(location) /** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’ - * | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’ + * | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns] ‘)’ + * + * -- It is checked in Typer that there are no repeated PatVar arguments. */ def argumentPatterns(): List[Tree] = inParensWithCommas(patternsOpt(Location.InPatternArgs)) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 558cbd72dd43..af0dcea31edc 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -299,30 +299,57 @@ object PatternMatcher { } /** Plan for matching the sequence in `getResult` against sequence elements - * and a possible last varargs argument `args`. + * `args`. Sequence elements may contain a varargs argument. + * Example: + * + * lst match case Seq(1, xs*, 2, 3) => ... + * + * generates code which is equivalent to: + * + * if lst != null then + * if lst.lengthCompare >= 1 then + * if lst(0) == 1 then + * val x1 = lst.drop(1) + * val xs = x1.dropRight(2) + * val x2 = lst.takeRight(2) + * if x2.lengthCompare >= 2 then + * if x2(0) == 2 then + * if x2(1) == 3 then + * return[matchResult] ... */ - def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match { - case Some(VarArgPattern(arg)) => - val matchRemaining = - if (args.length == 1) { - val toSeq = ref(getResult) - .select(defn.Seq_toSeq.matchingMember(getResult.info)) - letAbstract(toSeq) { toSeqResult => - patternPlan(toSeqResult, arg, onSuccess) - } - } - else { - val dropped = ref(getResult) - .select(defn.Seq_drop.matchingMember(getResult.info)) - .appliedTo(Literal(Constant(args.length - 1))) - letAbstract(dropped) { droppedResult => - patternPlan(droppedResult, arg, onSuccess) - } - } - matchElemsPlan(getResult, args.init, exact = false, matchRemaining) - case _ => - matchElemsPlan(getResult, args, exact = true, onSuccess) - } + def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = + val (leading, varargAndRest) = args.span: + case VarArgPattern(_) => false + case _ => true + varargAndRest match + case VarArgPattern(arg) :: trailing => + val remaining = + if leading.isEmpty then + ref(getResult) + .select(defn.Seq_toSeq.matchingMember(getResult.info)) + else + ref(getResult) + .select(defn.Seq_drop.matchingMember(getResult.info)) + .appliedTo(Literal(Constant(leading.length))) + val matchRemaining = + letAbstract(remaining): remainingResult => + if trailing.isEmpty then + patternPlan(remainingResult, arg, onSuccess) + else + val seq = ref(remainingResult) + .select(defn.Seq_dropRight.matchingMember(remainingResult.info)) + .appliedTo(Literal(Constant(trailing.length))) + letAbstract(seq): seqResult => + val rest = ref(remainingResult) + .select(defn.Seq_takeRight.matchingMember(remainingResult.info)) + .appliedTo(Literal(Constant(trailing.length))) + val matchTrailing = + letAbstract(rest): trailingResult => + matchElemsPlan(trailingResult, trailing, exact = true, onSuccess) + patternPlan(seqResult, arg, matchTrailing) + matchElemsPlan(getResult, leading, exact = false, matchRemaining) + case _ => + matchElemsPlan(getResult, args, exact = true, onSuccess) /** Plan for matching the sequence in `getResult` * diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 8e13c91a8d16..592376413c73 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -303,6 +303,11 @@ object Applications { report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), pos) argTypes.take(args.length) ++ List.fill(argTypes.length - args.length)(WildcardType) + + val varArgs = alignedArgs.filter(untpd.isWildcardStarArg) + if varArgs.length >= 2 then + report.error(em"Ony one spread pattern allowed in sequence", varArgs(1).srcPos) + alignedArgs.lazyZip(alignedArgTypes).map(typer.typed(_, _)) .showing(i"unapply patterns = $result", unapp) diff --git a/docs/_docs/internals/syntax.md b/docs/_docs/internals/syntax.md index 7fd7ec1be2e1..686d0551e0f6 100644 --- a/docs/_docs/internals/syntax.md +++ b/docs/_docs/internals/syntax.md @@ -365,7 +365,7 @@ Patterns ::= Pattern {‘,’ Pattern} NamedPattern ::= id '=' Pattern ArgumentPatterns ::= ‘(’ [Patterns] ‘)’ Apply(fn, pats) - | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’ + | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns]‘)’ ``` ### Type and Value Parameters diff --git a/tests/neg/spread-patterns.scala b/tests/neg/spread-patterns.scala new file mode 100644 index 000000000000..eecb3ecfeb13 --- /dev/null +++ b/tests/neg/spread-patterns.scala @@ -0,0 +1,15 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6) + val xs = List(1, 2, 3, 4, 5, 6) + + xs match + case List(1, 2, xs*, ys*, 6) => println(xs) // error + + + diff --git a/tests/run/spread-patterns.check b/tests/run/spread-patterns.check new file mode 100644 index 000000000000..a69afe787df5 --- /dev/null +++ b/tests/run/spread-patterns.check @@ -0,0 +1,4 @@ +List(3, 4, 5) +ArraySeq(4, 5, 6) +ArraySeq(1, 2, 3) +ArraySeq(3, 4, 5, 6) diff --git a/tests/run/spread-patterns.scala b/tests/run/spread-patterns.scala new file mode 100644 index 000000000000..9001ec0b61ab --- /dev/null +++ b/tests/run/spread-patterns.scala @@ -0,0 +1,38 @@ +import language.experimental.multiSpreads + +def use[T](xs: T*) = println(xs) + +def useInt(xs: Int*) = ??? + +@main def Test() = + val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6) + val lst = List(1, 2, 3, 4, 5, 6) + + lst match + case List(1, xs*, 2, 3, 4, 5, 6) => + assert(xs.isEmpty) + + lst match + case List(1, 2, xs*, 6) => println(xs) + + arr match + case Array(1, 2, xs*, 7) => assert(false) + case Array(1, 2, 3, xs*) => println(xs) + + arr match + case Array(xs*, 1, 2) => assert(false) + case Array(xs*, 4, 5, 6) => println(xs) + + arr match + case Array(1, 2, xs*) => println(xs) + + lst match + case List(1, 2, 3, 4, 5, 6, xs*) => assert(xs.isEmpty) + + lst match + case Seq(xs*, 1, 2, 3, 4, 5, 6) => assert(xs.isEmpty) + + + + + From 54cea7bae5c3b718af772bb1c21df9fbf373bff8 Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 19 Sep 2025 09:51:17 +0200 Subject: [PATCH 06/11] Optimize length testing for sequence matches --- .../tools/dotc/transform/PatternMatcher.scala | 34 +++++++++++-------- tests/run/spreads.scala | 7 ++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index af0dcea31edc..273879f3c3cb 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -198,6 +198,8 @@ object PatternMatcher { case object NonNullTest extends Test // scrutinee ne null case object GuardTest extends Test // scrutinee + val noLengthTest = LengthTest(0, exact = false) + // ------- Generating plans from trees ------------------------ /** A set of variabes that are known to be not null */ @@ -291,12 +293,14 @@ object PatternMatcher { /** Plan for matching the sequence in `seqSym` against sequence elements `args`. * If `exact` is true, the sequence is not permitted to have any elements following `args`. */ - def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = { - val selectors = args.indices.toList.map(idx => - ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))) - TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span, - matchArgsPlan(selectors, args, onSuccess)) - } + def matchElemsPlan(seqSym: Symbol, args: List[Tree], lengthTest: LengthTest, onSuccess: Plan) = + val selectors = args.indices.toList.map: idx => + ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))) + if lengthTest.len == 0 && lengthTest.exact == false then // redundant test + matchArgsPlan(selectors, args, onSuccess) + else + TestPlan(lengthTest, seqSym, seqSym.span, + matchArgsPlan(selectors, args, onSuccess)) /** Plan for matching the sequence in `getResult` against sequence elements * `args`. Sequence elements may contain a varargs argument. @@ -307,15 +311,13 @@ object PatternMatcher { * generates code which is equivalent to: * * if lst != null then - * if lst.lengthCompare >= 1 then + * if lst.lengthCompare >= 3 then * if lst(0) == 1 then * val x1 = lst.drop(1) * val xs = x1.dropRight(2) * val x2 = lst.takeRight(2) - * if x2.lengthCompare >= 2 then - * if x2(0) == 2 then - * if x2(1) == 3 then - * return[matchResult] ... + * if x2(0) == 2 && x2(1) == 3 then + * return[matchResult] ... */ def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = val (leading, varargAndRest) = args.span: @@ -345,11 +347,13 @@ object PatternMatcher { .appliedTo(Literal(Constant(trailing.length))) val matchTrailing = letAbstract(rest): trailingResult => - matchElemsPlan(trailingResult, trailing, exact = true, onSuccess) + matchElemsPlan(trailingResult, trailing, noLengthTest, onSuccess) patternPlan(seqResult, arg, matchTrailing) - matchElemsPlan(getResult, leading, exact = false, matchRemaining) + matchElemsPlan(getResult, leading, + LengthTest(leading.length + trailing.length, exact = false), + matchRemaining) case _ => - matchElemsPlan(getResult, args, exact = true, onSuccess) + matchElemsPlan(getResult, args, LengthTest(args.length, exact = true), onSuccess) /** Plan for matching the sequence in `getResult` * @@ -518,7 +522,7 @@ object PatternMatcher { case WildcardPattern() | This(_) => onSuccess case SeqLiteral(pats, _) => - matchElemsPlan(scrutinee, pats, exact = true, onSuccess) + matchElemsPlan(scrutinee, pats, LengthTest(pats.length, exact = true), onSuccess) case _ => TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess) } diff --git a/tests/run/spreads.scala b/tests/run/spreads.scala index 116f08d76f9f..8d33dc054102 100644 --- a/tests/run/spreads.scala +++ b/tests/run/spreads.scala @@ -28,5 +28,12 @@ def useInt(xs: Int*) = ??? use(one, oneTwoThree*, two) //use(1.0, ao*, 2.0) + val numbers1 = Array(1, 2, 3) + val numbers2 = List(4, 5, 6) + + def sum(xs: Int*) = xs.sum + + assert(sum(0, numbers1*, numbers2*, 4) == 25) + From c9601795717caa17778b51be1d61396481fd0abb Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 19 Sep 2025 16:27:49 +0200 Subject: [PATCH 07/11] Perform harmonization also in the presence of spread arguments --- .../dotty/tools/dotc/typer/Applications.scala | 26 +++++++++---------- tests/run/spreads.check | 1 + tests/run/spreads.scala | 10 ++++++- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 592376413c73..c70915b0f7db 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -25,6 +25,7 @@ import reporting.* import Nullables.*, NullOpsDecorator.* import config.{Feature, MigrationVersion, SourceVersion} import util.Property +import util.chaining.tap import collection.mutable import config.Printers.{overload, typr, unapp} @@ -814,19 +815,17 @@ trait Applications extends Compatibility { addTyped(arg) case _ => val elemFormal = formal.widenExpr.argTypesLo.head - if Feature.enabled(Feature.multiSpreads) - && !ctx.isAfterTyper && args.exists(isVarArg) - then - args.foreach: arg => - if isVarArg(arg) - then addArg(typedArg(arg, formal), formal) - else addArg(typedArg(arg, elemFormal), elemFormal) - else - val typedArgs = harmonic(harmonizeArgs, elemFormal): - args.map: arg => - checkNoVarArg(arg) + val typedVarArgs = util.HashSet[TypedArg]() + val typedArgs = harmonic(harmonizeArgs, elemFormal): + args.map: arg => + if isVarArg(arg) then + if !Feature.enabled(Feature.multiSpreads) || ctx.isAfterTyper then + checkNoVarArg(arg) + typedArg(arg, formal).tap(typedVarArgs += _) + else typedArg(arg, elemFormal) - typedArgs.foreach(addArg(_, elemFormal)) + typedArgs.foreach: targ => + addArg(targ, if typedVarArgs.contains(targ) then formal else elemFormal) makeVarArg(args.length, elemFormal) } else args match { @@ -2704,7 +2703,8 @@ trait Applications extends Compatibility { case ConstantType(c: Constant) if c.tag == IntTag => targetClass(ts1, cls, true) case t => - val sym = t.classSymbol + val sym = + if t.isRepeatedParam then t.argTypesLo.head.classSymbol else t.classSymbol if (!sym.isNumericValueClass || cls.exists && cls != sym) NoSymbol else targetClass(ts1, sym, intLitSeen) } diff --git a/tests/run/spreads.check b/tests/run/spreads.check index 48bf0d8e1f32..bee363252638 100644 --- a/tests/run/spreads.check +++ b/tests/run/spreads.check @@ -10,3 +10,4 @@ one one-two-three two ArraySeq(1, 1, 2, 3, 2) +13.0 diff --git a/tests/run/spreads.scala b/tests/run/spreads.scala index 8d33dc054102..654f180ae1f7 100644 --- a/tests/run/spreads.scala +++ b/tests/run/spreads.scala @@ -35,5 +35,13 @@ def useInt(xs: Int*) = ??? assert(sum(0, numbers1*, numbers2*, 4) == 25) - + // Tests for harmonization with varargs + + val darr: Array[Double] = Array(1.0, 2) + val zs1 = Array(1, darr*, 2, darr*, 3) + val _: Array[Double] = zs1 + val d: Double = 4 + val zs2 = Array(1, darr*, 2, d, 3) + val _: Array[Double] = zs2 + println(zs2.sum) From c8d50cf865d2ec05326561e152fad958f36c1ade Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 19 Sep 2025 16:55:03 +0200 Subject: [PATCH 08/11] Rename ArraySeqBuilder to VarArgsBuilder and apply review suggestions --- .../dotty/tools/dotc/core/Definitions.scala | 2 +- .../tools/dotc/transform/PostTyper.scala | 24 +++------ ...ySeqBuilder.scala => VarArgsBuilder.scala} | 54 +++++++++---------- project/Build.scala | 4 +- 4 files changed, 36 insertions(+), 48 deletions(-) rename library/src/scala/runtime/{ArraySeqBuilder.scala => VarArgsBuilder.scala} (75%) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 4b8e42a33bc0..dbe1602e2d82 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -524,7 +524,7 @@ class Definitions { @tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray") @tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray") - @tu lazy val ArraySeqBuilderModule: Symbol = requiredModule("scala.runtime.ArraySeqBuilder") + @tu lazy val VarArgsBuilderModule: Symbol = requiredModule("scala.runtime.VarArgsBuilder") def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index c8e2de1d321c..a347a9dd8afa 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -437,30 +437,17 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => acc.select(defn.Int_+).appliedTo(len) def makeBuilder(name: String) = - ref(defn.ArraySeqBuilderModule).select(name.toTermName) - def genericBuilder = makeBuilder("generic") - .appliedToType(elemType) - .appliedTo(totalLength) + ref(defn.VarArgsBuilderModule).select(name.toTermName) val builder = if defn.ScalaValueClasses().contains(elemCls) then - makeBuilder(s"of${elemCls.name}").appliedTo(totalLength) + makeBuilder(s"of${elemCls.name}") else if elemCls.derivesFrom(defn.ObjectClass) then - val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType) - val classTag = atPhase(Phases.typerPhase): - ctx.typer.inferImplicitArg(classTagType, tree.span.startPos) - classTag.tpe match - case _: SearchFailureType => - genericBuilder - case _ => - makeBuilder("ofRef") - .appliedToType(elemType) - .appliedTo(totalLength) - .appliedTo(classTag) + makeBuilder("ofRef").appliedToType(elemType) else - genericBuilder + makeBuilder("generic").appliedToType(elemType) - elems.foldLeft(builder): (bldr, elem) => + elems.foldLeft(builder.appliedTo(totalLength)): (bldr, elem) => elem match case spread(arg) => val selector = @@ -469,6 +456,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => bldr.select(selector.toTermName).appliedTo(arg) case _ => bldr.select("add".toTermName).appliedTo(elem) .select("result".toTermName) + .appliedToNone end flattenSpreads override def transform(tree: Tree)(using Context): Tree = diff --git a/library/src/scala/runtime/ArraySeqBuilder.scala b/library/src/scala/runtime/VarArgsBuilder.scala similarity index 75% rename from library/src/scala/runtime/ArraySeqBuilder.scala rename to library/src/scala/runtime/VarArgsBuilder.scala index 399912c07269..c9aa2b3be556 100644 --- a/library/src/scala/runtime/ArraySeqBuilder.scala +++ b/library/src/scala/runtime/VarArgsBuilder.scala @@ -3,17 +3,17 @@ package scala.runtime import scala.collection.immutable.ArraySeq import scala.reflect.ClassTag -sealed abstract class ArraySeqBuilder[T]: +sealed abstract class VarArgsBuilder[T]: def add(elem: T): this.type def addSeq(elems: Seq[T]): this.type def addArray(elems: Array[T]): this.type - def result: ArraySeq[T] + def result(): Seq[T] -object ArraySeqBuilder: +object VarArgsBuilder: - def generic[T](n: Int) = new ArraySeqBuilder[T]: + def generic[T](n: Int): VarArgsBuilder[T] = new VarArgsBuilder[T]: private val xs = new Array[AnyRef](n) - def result = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] + def result() = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] private var i = 0 def add(elem: T): this.type = xs(i) = elem.asInstanceOf[AnyRef] @@ -30,9 +30,9 @@ object ArraySeqBuilder: i += 1 this - def ofRef[T <: AnyRef](n: Int)(using ClassTag[T]) = new ArraySeqBuilder[T]: - private val xs = new Array[T](n) - def result = ArraySeq.ofRef(xs) + def ofRef[T <: AnyRef](n: Int): VarArgsBuilder[T] = new VarArgsBuilder[T]: + private val xs = new Array[AnyRef](n) + def result() = ArraySeq.ofRef(xs).asInstanceOf[ArraySeq[T]] private var i = 0 def add(elem: T): this.type = xs(i) = elem @@ -49,9 +49,9 @@ object ArraySeqBuilder: i += 1 this - def ofByte(n: Int) = new ArraySeqBuilder[Byte]: + def ofByte(n: Int): VarArgsBuilder[Byte] = new VarArgsBuilder[Byte]: private val xs = new Array[Byte](n) - def result = ArraySeq.ofByte(xs) + def result() = ArraySeq.ofByte(xs) private var i = 0 def add(elem: Byte): this.type = xs(i) = elem @@ -68,9 +68,9 @@ object ArraySeqBuilder: i += 1 this - def ofShort(n: Int) = new ArraySeqBuilder[Short]: + def ofShort(n: Int): VarArgsBuilder[Short] = new VarArgsBuilder[Short]: private val xs = new Array[Short](n) - def result = ArraySeq.ofShort(xs) + def result() = ArraySeq.ofShort(xs) private var i = 0 def add(elem: Short): this.type = xs(i) = elem @@ -87,9 +87,9 @@ object ArraySeqBuilder: i += 1 this - def ofChar(n: Int) = new ArraySeqBuilder[Char]: + def ofChar(n: Int): VarArgsBuilder[Char] = new VarArgsBuilder[Char]: private val xs = new Array[Char](n) - def result = ArraySeq.ofChar(xs) + def result() = ArraySeq.ofChar(xs) private var i = 0 def add(elem: Char): this.type = xs(i) = elem @@ -106,9 +106,9 @@ object ArraySeqBuilder: i += 1 this - def ofInt(n: Int) = new ArraySeqBuilder[Int]: + def ofInt(n: Int): VarArgsBuilder[Int] = new VarArgsBuilder[Int]: private val xs = new Array[Int](n) - def result = ArraySeq.ofInt(xs) + def result() = ArraySeq.ofInt(xs) private var i = 0 def add(elem: Int): this.type = xs(i) = elem @@ -125,9 +125,9 @@ object ArraySeqBuilder: i += 1 this - def ofLong(n: Int) = new ArraySeqBuilder[Long]: + def ofLong(n: Int): VarArgsBuilder[Long] = new VarArgsBuilder[Long]: private val xs = new Array[Long](n) - def result = ArraySeq.ofLong(xs) + def result() = ArraySeq.ofLong(xs) private var i = 0 def add(elem: Long): this.type = xs(i) = elem @@ -144,9 +144,9 @@ object ArraySeqBuilder: i += 1 this - def ofFloat(n: Int) = new ArraySeqBuilder[Float]: + def ofFloat(n: Int): VarArgsBuilder[Float] = new VarArgsBuilder[Float]: private val xs = new Array[Float](n) - def result = ArraySeq.ofFloat(xs) + def result() = ArraySeq.ofFloat(xs) private var i = 0 def add(elem: Float): this.type = xs(i) = elem @@ -163,9 +163,9 @@ object ArraySeqBuilder: i += 1 this - def ofDouble(n: Int) = new ArraySeqBuilder[Double]: + def ofDouble(n: Int): VarArgsBuilder[Double] = new VarArgsBuilder[Double]: private val xs = new Array[Double](n) - def result = ArraySeq.ofDouble(xs) + def result() = ArraySeq.ofDouble(xs) private var i = 0 def add(elem: Double): this.type = xs(i) = elem @@ -182,9 +182,9 @@ object ArraySeqBuilder: i += 1 this - def ofBoolean(n: Int) = new ArraySeqBuilder[Boolean]: + def ofBoolean(n: Int): VarArgsBuilder[Boolean] = new VarArgsBuilder[Boolean]: private val xs = new Array[Boolean](n) - def result = ArraySeq.ofBoolean(xs) + def result() = ArraySeq.ofBoolean(xs) private var i = 0 def add(elem: Boolean): this.type = xs(i) = elem @@ -201,9 +201,9 @@ object ArraySeqBuilder: i += 1 this - def ofUnit(n: Int) = new ArraySeqBuilder[Unit]: + def ofUnit(n: Int): VarArgsBuilder[Unit] = new VarArgsBuilder[Unit]: private val xs = new Array[Unit](n) - def result = ArraySeq.ofUnit(xs) + def result() = ArraySeq.ofUnit(xs) private var i = 0 def add(elem: Unit): this.type = xs(i) = elem @@ -220,4 +220,4 @@ object ArraySeqBuilder: i += 1 this -end ArraySeqBuilder \ No newline at end of file +end VarArgsBuilder \ No newline at end of file diff --git a/project/Build.scala b/project/Build.scala index 8045f58d468b..071f887dfeec 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1171,7 +1171,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), - file(s"${baseDirectory.value}/src/scala/runtime/ArraySeqBuilder.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"), ) ) lazy val `scala3-library-bootstrapped`: Project = project.in(file("library")).asDottyLibrary(Bootstrapped) @@ -1310,7 +1310,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/quoted/runtime/StopMacroExpansion.scala"), file(s"${baseDirectory.value}/src/scala/compiletime/Erased.scala"), file(s"${baseDirectory.value}/src/scala/annotation/internal/onlyCapability.scala"), - file(s"${baseDirectory.value}/src/scala/runtime/ArraySeqBuilder.scala"), + file(s"${baseDirectory.value}/src/scala/runtime/VarArgsBuilder.scala"), ) ) From 5907afb151bf146926890ec0e1320fd61f34f835 Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 19 Sep 2025 18:38:25 +0200 Subject: [PATCH 09/11] Survive -Ycheck if widening of array element types is needed --- .../src/dotty/tools/dotc/transform/PostTyper.scala | 11 ++++++----- tests/run/spreads-subtyping.check | 1 + tests/run/spreads-subtyping.scala | 7 +++++++ 3 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 tests/run/spreads-subtyping.check create mode 100644 tests/run/spreads-subtyping.scala diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index a347a9dd8afa..128655debc2e 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -414,7 +414,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => * * This then translates to * - * scala.runtime.ArraySeqBuilcder.ofInt(2 + xs.length + ys.length) + * scala.runtime.VarArgsBuilder.ofInt(2 + xs.length + ys.length) * .add(1) * .addSeq(xs) * .add(2) @@ -450,10 +450,11 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => elems.foldLeft(builder.appliedTo(totalLength)): (bldr, elem) => elem match case spread(arg) => - val selector = - if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq" - else "addArray" - bldr.select(selector.toTermName).appliedTo(arg) + if arg.tpe.derivesFrom(defn.SeqClass) then + bldr.select("addSeq".toTermName).appliedTo(arg) + else + bldr.select("addArray".toTermName).appliedTo( + arg.ensureConforms(defn.ArrayOf(elemType))) case _ => bldr.select("add".toTermName).appliedTo(elem) .select("result".toTermName) .appliedToNone diff --git a/tests/run/spreads-subtyping.check b/tests/run/spreads-subtyping.check new file mode 100644 index 000000000000..6e4e9402b9c0 --- /dev/null +++ b/tests/run/spreads-subtyping.check @@ -0,0 +1 @@ +ooffoobar diff --git a/tests/run/spreads-subtyping.scala b/tests/run/spreads-subtyping.scala new file mode 100644 index 000000000000..3680b2c19562 --- /dev/null +++ b/tests/run/spreads-subtyping.scala @@ -0,0 +1,7 @@ +import language.experimental.multiSpreads + +def foo(x: CharSequence*): String = x.mkString +val strings: Array[String] = Array("foo", "bar") + +@main def Test = + println(foo(("oof": CharSequence), strings*)) \ No newline at end of file From 3af6d6f1a8b58c710f8afcfba7a92bd1f0803f49 Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 26 Sep 2025 15:48:12 +0200 Subject: [PATCH 10/11] Drop empty file which was added by accident --- library/src/scala/compiletime/Spread.scala | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 library/src/scala/compiletime/Spread.scala diff --git a/library/src/scala/compiletime/Spread.scala b/library/src/scala/compiletime/Spread.scala deleted file mode 100644 index e69de29bb2d1..000000000000 From 61d0a781ce6db85d801928d5c0f43d3ee3fb933e Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 26 Sep 2025 18:28:24 +0200 Subject: [PATCH 11/11] Drop stray `// error` --- tests/run/i11419.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/run/i11419.scala b/tests/run/i11419.scala index 1f0a0591dbaf..be22b943861f 100644 --- a/tests/run/i11419.scala +++ b/tests/run/i11419.scala @@ -3,7 +3,7 @@ import language.experimental.multiSpreads object Test { def main(args: Array[String]): Unit = { val arr: Array[String] = Array("foo") - val lst = List("x", arr*) // error + val lst = List("x", arr*) println(lst) } }