Skip to content

Commit f691b84

Browse files
committed
Add support for field accesses and constructors
1 parent ae30c48 commit f691b84

File tree

10 files changed

+202
-43
lines changed

10 files changed

+202
-43
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
472472
def New(tpt: Tree, argss: List[List[Tree]])(using Context): Tree =
473473
ensureApplied(argss.foldLeft(makeNew(tpt))(Apply(_, _)))
474474

475-
/** A new expression with constrictor and possibly type arguments. See
475+
/** A new expression with constructor and possibly type arguments. See
476476
* `New(tpt, argss)` for details.
477477
*/
478478
def makeNew(tpt: Tree)(using Context): Tree = {

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ class Definitions {
662662
@tu lazy val StringClass: ClassSymbol = requiredClass("java.lang.String")
663663
def StringType: Type = StringClass.typeRef
664664
@tu lazy val StringModule: Symbol = StringClass.linkedClass
665+
@tu lazy val String_== : TermSymbol = enterMethod(StringClass, nme.EQ, methOfAnyRef(BooleanType), Final)
665666
@tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final)
666667
@tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match {
667668
case List(pt) => pt.isAny || pt.stripNull().isAnyRef

compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dotty.tools.dotc.qualified_types
22

33
import scala.collection.mutable
44
import scala.collection.mutable.ArrayBuffer
5+
import scala.collection.mutable.ListBuffer
56

67
import dotty.tools.dotc.ast.tpd.{
78
closureDef,
@@ -25,13 +26,16 @@ import dotty.tools.dotc.core.Constants.Constant
2526
import dotty.tools.dotc.core.Contexts.Context
2627
import dotty.tools.dotc.core.Contexts.ctx
2728
import dotty.tools.dotc.core.Decorators.i
29+
import dotty.tools.dotc.core.Flags
2830
import dotty.tools.dotc.core.Hashable.Binders
2931
import dotty.tools.dotc.core.Names.Designator
3032
import dotty.tools.dotc.core.StdNames.nme
3133
import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol}
3234
import dotty.tools.dotc.core.Types.{
35+
AppliedType,
3336
CachedProxyType,
3437
ConstantType,
38+
LambdaType,
3539
MethodType,
3640
NamedType,
3741
NoPrefix,
@@ -40,14 +44,14 @@ import dotty.tools.dotc.core.Types.{
4044
TermParamRef,
4145
TermRef,
4246
Type,
47+
TypeRef,
4348
TypeVar,
4449
ValueType
4550
}
4651
import dotty.tools.dotc.qualified_types.ENode.Op
4752
import dotty.tools.dotc.reporting.trace
4853
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp
4954
import dotty.tools.dotc.util.Spans.Span
50-
import scala.collection.mutable.ListBuffer
5155

5256
final class EGraph(rootCtx: Context):
5357

@@ -92,7 +96,7 @@ final class EGraph(rootCtx: Context):
9296
private val builtinOps = Map(
9397
d.Int_== -> Op.Equal,
9498
d.Boolean_== -> Op.Equal,
95-
d.Any_== -> Op.Equal,
99+
d.String_== -> Op.Equal,
96100
d.Boolean_&& -> Op.And,
97101
d.Boolean_|| -> Op.Or,
98102
d.Boolean_! -> Op.Not,
@@ -108,9 +112,8 @@ final class EGraph(rootCtx: Context):
108112

109113
def equiv(node1: ENode, node2: ENode)(using Context): Boolean =
110114
trace(i"EGraph.equiv", Printers.qualifiedTypes):
111-
val margin = ctx.base.indentTab * (ctx.base.indent)
115+
// val margin = ctx.base.indentTab * (ctx.base.indent)
112116
// println(s"$margin node1: $node1\n$margin node2: $node2")
113-
// Check if the representents of both nodes are the same
114117
val repr1 = representent(node1)
115118
val repr2 = representent(node2)
116119
repr1 eq repr2
@@ -121,8 +124,8 @@ final class EGraph(rootCtx: Context):
121124
node match
122125
case ENode.Atom(tp) =>
123126
()
124-
case ENode.New(clazz) =>
125-
addUse(clazz, node)
127+
case ENode.Constructor(sym) =>
128+
()
126129
case ENode.Select(qual, member) =>
127130
addUse(qual, node)
128131
case ENode.Apply(fn, args) =>
@@ -138,6 +141,7 @@ final class EGraph(rootCtx: Context):
138141
}
139142
).asInstanceOf[node.type]
140143

144+
// TODO(mbovel): Memoize this
141145
def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using
142146
Context
143147
): Option[ENode] =
@@ -165,16 +169,18 @@ final class EGraph(rootCtx: Context):
165169
tree match
166170
case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] =>
167171
Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType]))
168-
case New(clazz) =>
169-
for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode.New(clazzNode)
172+
case Select(New(_), nme.CONSTRUCTOR) =>
173+
constructorNode(tree.symbol)
174+
case tree: Select if isCaseClassApply(tree.symbol) =>
175+
constructorNode(tree.symbol.owner.linkedClass.primaryConstructor)
170176
case Select(qual, name) =>
171-
for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode.Select(qualNode, tree.symbol)
177+
for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect(qualNode, tree.symbol)
172178
case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) =>
173179
for
174180
lhsNode <- toNode(lhs, paramSyms, paramTps)
175181
rhsNode <- toNode(rhs, paramSyms, paramTps)
176182
yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode))
177-
case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] =>
183+
case BinaryOp(lhs, d.Int_-, rhs) =>
178184
for
179185
lhsNode <- toNode(lhs, paramSyms, paramTps)
180186
rhsNode <- toNode(rhs, paramSyms, paramTps)
@@ -192,7 +198,7 @@ final class EGraph(rootCtx: Context):
192198
case mt: MethodType =>
193199
assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?")
194200
val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol)
195-
val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty
201+
val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty
196202
val paramTpsSize = paramTps.size
197203
for myParamSym <- myParamSyms do
198204
val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
@@ -204,15 +210,38 @@ final class EGraph(rootCtx: Context):
204210
case _ =>
205211
None
206212

213+
// TODO(mbovel): Memoize this
214+
private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] =
215+
val clazz = constr.owner
216+
if hasStructuralEquality(clazz) then
217+
val isPrimaryConstructor = constr.denot.isPrimaryConstructor
218+
val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember)
219+
val constrParams = constr.paramSymss.flatten.filter(_.isTerm)
220+
val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol))
221+
Some(ENode.Constructor(constr)(fields))
222+
else
223+
None
224+
225+
private def hasStructuralEquality(clazz: Symbol)(using Context): Boolean =
226+
val equalsMethod = clazz.info.decls.lookup(nme.equals_)
227+
val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic)
228+
clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden
229+
230+
private def isCaseClassApply(meth: Symbol)(using Context): Boolean =
231+
meth.name == nme.apply
232+
&& meth.flags.is(Flags.Synthetic)
233+
&& meth.owner.linkedClass.is(Flags.Case)
234+
207235
private def canonicalize(node: ENode): ENode =
236+
// println(s"canonicalize $node")
208237
representent(unique(
209238
node match
210239
case ENode.Atom(tp) =>
211240
node
212-
case ENode.New(clazz) =>
213-
ENode.New(representent(clazz))
241+
case ENode.Constructor(sym) =>
242+
node
214243
case ENode.Select(qual, member) =>
215-
ENode.Select(representent(qual), member)
244+
normalizeSelect(representent(qual), member)
216245
case ENode.Apply(fn, args) =>
217246
ENode.Apply(representent(fn), args.map(representent))
218247
case ENode.OpApply(op, args) =>
@@ -223,6 +252,33 @@ final class EGraph(rootCtx: Context):
223252
ENode.Lambda(paramTps, retTp, representent(body))
224253
))
225254

255+
private def normalizeSelect(qual: ENode, member: Symbol): ENode =
256+
getAppliedConstructor(qual) match
257+
case Some(constr) =>
258+
val memberIndex = constr.fields.indexOf(member)
259+
260+
if memberIndex >= 0 then
261+
val args = getTermArguments(qual)
262+
assert(args.size == constr.fields.size)
263+
args(memberIndex)
264+
else
265+
ENode.Select(qual, member)
266+
case None =>
267+
ENode.Select(qual, member)
268+
269+
private def getAppliedConstructor(node: ENode): Option[ENode.Constructor] =
270+
node match
271+
case ENode.Apply(fn, args) => getAppliedConstructor(fn)
272+
case ENode.TypeApply(fn, args) => getAppliedConstructor(fn)
273+
case node: ENode.Constructor => Some(node)
274+
case _ => None
275+
276+
private def getTermArguments(node: ENode): List[ENode] =
277+
node match
278+
case ENode.Apply(fn, args) => getTermArguments(fn) ::: args
279+
case ENode.TypeApply(fn, args) => getTermArguments(fn)
280+
case _ => Nil
281+
226282
private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode =
227283
op match
228284
case Op.Equal =>
@@ -316,12 +372,10 @@ final class EGraph(rootCtx: Context):
316372
(a, b) match
317373
case (ENode.Atom(_: ConstantType), _) => (a, b)
318374
case (_, ENode.Atom(_: ConstantType)) => (b, a)
319-
case (ENode.Atom(_: SkolemType), _) => (a, b)
320-
case (_, ENode.Atom(_: SkolemType)) => (b, a)
375+
case (_: ENode.Constructor, _) => (a, b)
376+
case (_, _: ENode.Constructor) => (b, a)
321377
case (_: ENode.Atom, _) => (a, b)
322378
case (_, _: ENode.Atom) => (b, a)
323-
case (_: ENode.New, _) => (a, b)
324-
case (_, _: ENode.New) => (b, a)
325379
case (_: ENode.Select, _) => (a, b)
326380
case (_, _: ENode.Select) => (b, a)
327381
case (_: ENode.Apply, _) => (a, b)
@@ -336,8 +390,6 @@ final class EGraph(rootCtx: Context):
336390
if aRepr eq bRepr then return
337391
assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`")
338392

339-
// TODO(mbovel): if both nodes are objects, recursively merge their arguments
340-
341393
/// Update represententOf and usedBy maps
342394
val (newRepr, oldRepr) = order(aRepr, bRepr)
343395
represententOf(oldRepr) = newRepr
@@ -371,8 +423,9 @@ final class EGraph(rootCtx: Context):
371423
node match
372424
case ENode.Atom(tp) =>
373425
singleton(tp)
374-
case ENode.New(clazz) =>
375-
New(toTree(clazz, paramRefs))
426+
case ENode.Constructor(sym) =>
427+
val tycon = sym.owner.info.typeConstructor
428+
New(tycon).select(TermRef(tycon, sym))
376429
case ENode.Select(qual, member) =>
377430
toTree(qual, paramRefs).select(member)
378431
case ENode.Apply(fn, args) =>

compiler/src/dotty/tools/dotc/qualified_types/ENode.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ enum ENode:
2323
import ENode.*
2424

2525
case Atom(tp: SingletonType)
26-
case New(clazz: ENode)
26+
case Constructor(constr: Symbol)(val fields: List[Symbol])
2727
case Select(qual: ENode, member: Symbol)
2828
case Apply(fn: ENode, args: List[ENode])
2929
case OpApply(fn: ENode.Op, args: List[ENode])
@@ -33,7 +33,7 @@ enum ENode:
3333
override def toString(): String =
3434
this match
3535
case Atom(tp) => typeToString(tp)
36-
case New(clazz) => s"new $clazz"
36+
case Constructor(constr) => s"new ${designatorToString(constr.lastKnownDenotation.owner)}"
3737
case Select(qual, member) => s"$qual.${designatorToString(member)}"
3838
case Apply(fn, args) => s"$fn(${args.mkString(", ")})"
3939
case OpApply(op, args) => s"(${args.mkString(op.operatorString())})"

compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ class QualifierSolver(using Context):
5353
case _ => ()
5454

5555
val egraph = EGraph(ctx)
56-
//println(s"tree implies $tree1 -> $tree2")
56+
// println(s"tree implies $tree1 -> $tree2")
5757
(egraph.toNode(QualifierEvaluator.evaluate(tree1)), egraph.toNode(QualifierEvaluator.evaluate(tree2))) match
5858
case (Some(node1), Some(node2)) =>
59-
//println(s"node implies $node1 -> $node2")
59+
// println(s"node implies $node1 -> $node2")
6060
egraph.merge(node1, egraph.trueNode)
6161
egraph.repair()
6262
egraph.equiv(node2, egraph.trueNode)

tests/neg-custom-args/qualified-types/adapt_neg.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ def test: Unit =
1212
val v3: {v: Int with v == x + 1} = x + 2 // error
1313
val v4: {v: Int with v == f(x)} = g(x) // error
1414
val v5: {v: Int with v == g(x)} = f(x) // error
15-
//val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented
16-
//val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented
15+
val v6: {v: IntBox with v == IntBox(x)} = IntBox(x + 1) // error
16+
val v7: {v: Box[Int] with v == Box(x)} = Box(x + 1) // error
1717
val v8: {v: Int with v == x + f(x)} = x + g(x) // error
1818
val v9: {v: Int with v == x + g(x)} = x + f(x) // error
1919
val v10: {v: Int with v == f(x + 1)} = f(x + 2) // error
2020
val v11: {v: Int with v == g(x + 1)} = g(x + 2) // error
21-
//val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented
22-
//val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented
21+
val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x) // error
22+
val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x) // error
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
class Box[T](val x: T)
2+
3+
class BoxMutable[T](var x: T)
4+
5+
class Foo(val id: String):
6+
def this(x: Int) = this(x.toString)
7+
8+
class Person(val name: String, val age: Int)
9+
10+
class PersonCurried(val name: String)(val age: Int)
11+
12+
class PersonMutable(val name: String, var age: Int)
13+
14+
case class PersonCaseMutable(name: String, var age: Int)
15+
16+
case class PersonCaseSecondary(name: String, age: Int):
17+
def this(name: String) = this(name, 0)
18+
19+
case class PersonCaseEqualsOverriden(name: String, age: Int):
20+
override def equals(that: Object): Boolean = this eq that
21+
22+
def test: Unit =
23+
summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error
24+
25+
summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error
26+
// TODO(mbovel): restrict selection to stable members
27+
//summon[{b: BoxMutable[Int] with b.x == 3} =:= {b: BoxMutable[Int] with b.x == 3}]
28+
29+
summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error
30+
summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error
31+
summon[{s: String with Foo("hello").id == s} =:= {s: String with s == "hello"}] // error
32+
33+
summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error
34+
summon[{s: String with Person("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
35+
summon[{n: Int with Person("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
36+
37+
summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error
38+
summon[{s: String with PersonCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] // error
39+
summon[{n: Int with PersonCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] // error
40+
41+
summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error
42+
summon[{s: String with PersonMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
43+
summon[{n: Int with PersonMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
44+
45+
summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
46+
47+
summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error
48+
summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error
49+
50+
summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error
51+
summon[{s: String with PersonCaseEqualsOverriden("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
52+
summon[{n: Int with PersonCaseEqualsOverriden("Alice", 30).age == n} =:= {n: Int with n == 30}] // error

tests/pos-custom-args/qualified-types/adapt.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ def f(x: Int): Int = ???
22
case class IntBox(x: Int)
33
case class Box[T](x: T)
44

5-
65
def f(x: Int, y: Int): {r: Int with r == x + y} = x + y
76

87
def test: Unit =
@@ -12,11 +11,11 @@ def test: Unit =
1211
val v1: {v: Int with v == x + 1} = x + 1
1312
val v2: {v: Int with v == f(x)} = f(x)
1413
val v3: {v: Int with v == g(x)} = g(x)
15-
//val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented
16-
//val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented
17-
val v4: {v: Int with v == x + f(x)} = x + f(x)
18-
val v5: {v: Int with v == x + g(x)} = x + g(x)
19-
val v6: {v: Int with v == f(x + 1)} = f(x + 1)
20-
val v7: {v: Int with v == g(x + 1)} = g(x + 1)
21-
//val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented
22-
//val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented
14+
val v4: {v: IntBox with v == IntBox(x)} = IntBox(x)
15+
val v5: {v: Box[Int] with v == Box(x)} = Box(x)
16+
val v6: {v: Int with v == x + f(x)} = x + f(x)
17+
val v7: {v: Int with v == x + g(x)} = x + g(x)
18+
val v8: {v: Int with v == f(x + 1)} = f(x + 1)
19+
val v9: {v: Int with v == g(x + 1)} = g(x + 1)
20+
val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x + 1)
21+
val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x + 1)

tests/pos-custom-args/qualified-types/sized_lists.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
2-
31
def size(v: Vec): Int = ???
42
type Vec
53

6-
74
def vec(s: Int): {v: Vec with size(v) == s} = ???
85
def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ???
96
def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ???

0 commit comments

Comments
 (0)