Skip to content

Commit c216dd2

Browse files
committed
Print infix operations in infix form
1 parent 59b076b commit c216dd2

10 files changed

+84
-22
lines changed

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+12
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import dotty.tools.dotc.util.SourcePosition
3232
import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef}
3333
import cc.*
3434
import dotty.tools.dotc.parsing.JavaParsers
35+
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp
3536

3637
class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
3738

@@ -387,6 +388,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
387388

388389
def optDotPrefix(tree: This) = optText(tree.qual)(_ ~ ".") provided !isLocalThis(tree)
389390

391+
/** Should a binary operation with this operator be printed infix? */
392+
def isInfix(op: Symbol) =
393+
op.exists && (op.isDeclaredInfix || op.name.isOperatorName)
394+
390395
def caseBlockText(tree: Tree): Text = tree match {
391396
case Block(stats, expr) => toText(stats :+ expr, "\n")
392397
case expr => toText(expr)
@@ -478,6 +483,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
478483
optDotPrefix(tree) ~ keywordStr("this") ~ idText(tree)
479484
case Super(qual: This, mix) =>
480485
optDotPrefix(qual) ~ keywordStr("super") ~ optText(mix)("[" ~ _ ~ "]")
486+
case BinaryOp(l, op, r) if isInfix(op) =>
487+
val isRightAssoc = op.name.endsWith(":")
488+
val opPrec = parsing.precedence(op.name)
489+
val leftPrec = if isRightAssoc then opPrec + 1 else opPrec
490+
val rightPrec = if !isRightAssoc then opPrec + 1 else opPrec
491+
changePrec(opPrec):
492+
atPrec(leftPrec)(toText(l)) ~ " " ~ toText(op.name) ~ " " ~ atPrec(rightPrec)(toText(r))
481493
case app @ Apply(fun, args) =>
482494
if (fun.hasType && fun.symbol == defn.throwMethod)
483495
changePrec (GlobalPrec) {

tests/printing/export-param-flags.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
package <empty> {
33
final lazy module val A: A = new A()
44
final module class A() extends Object() { this: A.type =>
5-
inline def inlinedParam(inline x: Int): Int = x.+(x):Int
5+
inline def inlinedParam(inline x: Int): Int = x + x:Int
66
}
77
final lazy module val Exported: Exported = new Exported()
88
final module class Exported() extends Object() { this: Exported.type =>

tests/printing/infix-operations.check

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[[syntax trees at end of typer]] // tests/printing/infix-operations.scala
2+
package <empty> {
3+
class C(a: Int) extends Object() {
4+
private[this] val a: Int
5+
def foo(b: Int): C = this
6+
def +(b: Int): C = this
7+
}
8+
final lazy module val infix-operations$package: infix-operations$package =
9+
new infix-operations$package()
10+
final module class infix-operations$package() extends Object() {
11+
this: infix-operations$package.type =>
12+
@main def main: Unit =
13+
{
14+
val v1: Int = 1 + 2 + 3
15+
val v2: Int = 1 + (2 + 3)
16+
val v3: Int = 1 + 2 * 3
17+
val v4: Int = (1 + 2) * 3
18+
val v5: Int = 1 + 2:Int
19+
val v6: Int = 1 + 2:Int
20+
val c: C = new C(2)
21+
val v7: C = c.foo(3)
22+
val v8: C = c + 3
23+
()
24+
}
25+
}
26+
final class main() extends Object() {
27+
<static> def main(args: Array[String]): Unit =
28+
try main catch
29+
{
30+
case error @ _:scala.util.CommandLineParser.ParseError =>
31+
scala.util.CommandLineParser.showError(error)
32+
}
33+
}
34+
}
35+

tests/printing/infix-operations.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
class C(a: Int):
2+
def foo(b: Int): C = this
3+
def +(b: Int): C = this
4+
5+
@main def main =
6+
val v1 = 1 + 2 + 3
7+
val v2 = 1 + (2 + 3)
8+
val v3 = 1 + 2 * 3
9+
val v4 = (1 + 2) * 3
10+
val v5 = (1 + 2):Int
11+
val v6 = 1 + 2:Int // same as above
12+
13+
val c = new C(2) // must not be printed in infix form
14+
val v7 = c.foo(3) // must not be printed in infix form
15+
val v8 = c + 3

tests/printing/lambdas.check

+12-12
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,26 @@
22
package <empty> {
33
final lazy module val Main: Main = new Main()
44
final module class Main() extends Object() { this: Main.type =>
5-
val f1: Int => Int = (x: Int) => x.+(1)
6-
val f2: (Int, Int) => Int = (x: Int, y: Int) => x.+(y)
7-
val f3: Int => Int => Int = (x: Int) => (y: Int) => x.+(y)
8-
val f4: [T] => (x: Int) => Int = [T >: Nothing <: Any] => (x: Int) => x.+(1)
5+
val f1: Int => Int = (x: Int) => x + 1
6+
val f2: (Int, Int) => Int = (x: Int, y: Int) => x + y
7+
val f3: Int => Int => Int = (x: Int) => (y: Int) => x + y
8+
val f4: [T] => (x: Int) => Int = [T >: Nothing <: Any] => (x: Int) => x + 1
99
val f5: [T] => (x: Int) => Int => Int = [T >: Nothing <: Any] => (x: Int)
10-
=> (y: Int) => x.+(y)
10+
=> (y: Int) => x + y
1111
val f6: Int => Int = (x: Int) =>
1212
{
13-
val x2: Int = x.+(1)
14-
x2.+(1)
13+
val x2: Int = x + 1
14+
x2 + 1
1515
}
16-
def f7(x: Int): Int = x.+(1)
16+
def f7(x: Int): Int = x + 1
1717
val f8: Int => Int = (x: Int) => Main.f7(x)
1818
val l: List[Int] = List.apply[Int]([1,2,3 : Int]*)
19-
Main.l.map[Int]((_$1: Int) => _$1.+(1))
20-
Main.l.map[Int]((x: Int) => x.+(1))
19+
Main.l.map[Int]((_$1: Int) => _$1 + 1)
20+
Main.l.map[Int]((x: Int) => x + 1)
2121
Main.l.map[Int]((x: Int) =>
2222
{
23-
val x2: Int = x.+(1)
24-
x2.+(1)
23+
val x2: Int = x + 1
24+
x2 + 1
2525
}
2626
)
2727
Main.l.map[Int]((x: Int) => Main.f7(x))

tests/printing/posttyper/i22533.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ package <empty> {
1010
override def equals(x$0: Any): Boolean =
1111
x$0 match
1212
{
13-
case x$0 @ _:Foo @unchecked => this.u.==(x$0.u)
13+
case x$0 @ _:Foo @unchecked => this.u == x$0.u
1414
case _ => false
1515
}
1616
private[this] val u: Int

tests/printing/transformed/lazy-vals-legacy.check

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ package <empty> {
2626
{
2727
val flag: Long = scala.runtime.LazyVals.get(this, A.OFFSET$_m_0)
2828
val state: Long = scala.runtime.LazyVals.STATE(flag, 0)
29-
if state.==(3) then return A.x$lzy1 else
30-
if state.==(0) then
29+
if state == 3 then return A.x$lzy1 else
30+
if state == 0 then
3131
if scala.runtime.LazyVals.CAS(this, A.OFFSET$_m_0, flag, 1, 0)
3232
then
3333
try

tests/printing/transformed/lazy-vals-new.check

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ package <empty> {
2424
{
2525
val result: Object = A.x$lzy1
2626
if result.isInstanceOf[Int] then Int.unbox(result) else
27-
if result.eq(scala.runtime.LazyVals.NullValue) then Int.unbox(null)
27+
if result eq scala.runtime.LazyVals.NullValue then Int.unbox(null)
2828
else Int.unbox(A.x$lzyINIT1())
2929
}
3030
private def x$lzyINIT1(): Object =
3131
while <empty> do
3232
{
3333
val current: Object = A.x$lzy1
34-
if current.eq(null) then
34+
if current eq null then
3535
if
3636
scala.runtime.LazyVals.objCAS(this, A.OFFSET$_m_0, null,
3737
scala.runtime.LazyVals.Evaluating)
@@ -42,7 +42,7 @@ package <empty> {
4242
try
4343
{
4444
resultNullable = Int.box(2)
45-
if resultNullable.eq(null) then
45+
if resultNullable eq null then
4646
result = scala.runtime.LazyVals.NullValue else
4747
result = resultNullable
4848
()
@@ -69,7 +69,7 @@ package <empty> {
6969
current.isInstanceOf[
7070
scala.runtime.LazyVals.LazyVals$LazyValControlState]
7171
then
72-
if current.eq(scala.runtime.LazyVals.Evaluating) then
72+
if current eq scala.runtime.LazyVals.Evaluating then
7373
{
7474
scala.runtime.LazyVals.objCAS(this, A.OFFSET$_m_0, current,
7575
new scala.runtime.LazyVals.LazyVals$Waiting())

tests/run/i13181.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
@main def Test = assert(scala.compiletime.codeOf(1+2) == "1.+(2)")
1+
@main def Test = assert(scala.compiletime.codeOf(1+2) == "1 + 2")

tests/run/typeCheckErrors.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
List(Error(value check is not a member of Unit,compileError("1" * 2).check(""),22,Typer), Error(argument to compileError must be a statically known String but was: augmentString("1").*(2),compileError("1" * 2).check(""),13,Typer))
1+
List(Error(value check is not a member of Unit,compileError("1" * 2).check(""),22,Typer), Error(argument to compileError must be a statically known String but was: augmentString("1") * 2,compileError("1" * 2).check(""),13,Typer))

0 commit comments

Comments
 (0)