@@ -2,6 +2,7 @@ package dotty.tools.dotc.qualified_types
2
2
3
3
import scala .collection .mutable
4
4
import scala .collection .mutable .ArrayBuffer
5
+ import scala .collection .mutable .ListBuffer
5
6
6
7
import dotty .tools .dotc .ast .tpd .{
7
8
closureDef ,
@@ -25,13 +26,16 @@ import dotty.tools.dotc.core.Constants.Constant
25
26
import dotty .tools .dotc .core .Contexts .Context
26
27
import dotty .tools .dotc .core .Contexts .ctx
27
28
import dotty .tools .dotc .core .Decorators .i
29
+ import dotty .tools .dotc .core .Flags
28
30
import dotty .tools .dotc .core .Hashable .Binders
29
31
import dotty .tools .dotc .core .Names .Designator
30
32
import dotty .tools .dotc .core .StdNames .nme
31
33
import dotty .tools .dotc .core .Symbols .{defn , NoSymbol , Symbol }
32
34
import dotty .tools .dotc .core .Types .{
35
+ AppliedType ,
33
36
CachedProxyType ,
34
37
ConstantType ,
38
+ LambdaType ,
35
39
MethodType ,
36
40
NamedType ,
37
41
NoPrefix ,
@@ -40,14 +44,14 @@ import dotty.tools.dotc.core.Types.{
40
44
TermParamRef ,
41
45
TermRef ,
42
46
Type ,
47
+ TypeRef ,
43
48
TypeVar ,
44
49
ValueType
45
50
}
46
51
import dotty .tools .dotc .qualified_types .ENode .Op
47
52
import dotty .tools .dotc .reporting .trace
48
53
import dotty .tools .dotc .transform .TreeExtractors .BinaryOp
49
54
import dotty .tools .dotc .util .Spans .Span
50
- import scala .collection .mutable .ListBuffer
51
55
52
56
final class EGraph (rootCtx : Context ):
53
57
@@ -92,7 +96,7 @@ final class EGraph(rootCtx: Context):
92
96
private val builtinOps = Map (
93
97
d.Int_== -> Op .Equal ,
94
98
d.Boolean_== -> Op .Equal ,
95
- d.Any_ == -> Op .Equal ,
99
+ d.String_ == -> Op .Equal ,
96
100
d.Boolean_&& -> Op .And ,
97
101
d.Boolean_|| -> Op .Or ,
98
102
d.Boolean_! -> Op .Not ,
@@ -108,9 +112,8 @@ final class EGraph(rootCtx: Context):
108
112
109
113
def equiv (node1 : ENode , node2 : ENode )(using Context ): Boolean =
110
114
trace(i " EGraph.equiv " , Printers .qualifiedTypes):
111
- val margin = ctx.base.indentTab * (ctx.base.indent)
115
+ // val margin = ctx.base.indentTab * (ctx.base.indent)
112
116
// println(s"$margin node1: $node1\n$margin node2: $node2")
113
- // Check if the representents of both nodes are the same
114
117
val repr1 = representent(node1)
115
118
val repr2 = representent(node2)
116
119
repr1 eq repr2
@@ -121,8 +124,8 @@ final class EGraph(rootCtx: Context):
121
124
node match
122
125
case ENode .Atom (tp) =>
123
126
()
124
- case ENode .New (clazz ) =>
125
- addUse(clazz, node )
127
+ case ENode .Constructor (sym ) =>
128
+ ( )
126
129
case ENode .Select (qual, member) =>
127
130
addUse(qual, node)
128
131
case ENode .Apply (fn, args) =>
@@ -138,6 +141,7 @@ final class EGraph(rootCtx: Context):
138
141
}
139
142
).asInstanceOf [node.type ]
140
143
144
+ // TODO(mbovel): Memoize this
141
145
def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramTps : List [ENode .ArgRefType ] = Nil )(using
142
146
Context
143
147
): Option [ENode ] =
@@ -165,16 +169,18 @@ final class EGraph(rootCtx: Context):
165
169
tree match
166
170
case Literal (_) | Ident (_) | This (_) if tree.tpe.isInstanceOf [SingletonType ] =>
167
171
Some (ENode .Atom (mapType(tree.tpe).asInstanceOf [SingletonType ]))
168
- case New (clazz) =>
169
- for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode .New (clazzNode)
172
+ case Select (New (_), nme.CONSTRUCTOR ) =>
173
+ constructorNode(tree.symbol)
174
+ case tree : Select if isCaseClassApply(tree.symbol) =>
175
+ constructorNode(tree.symbol.owner.linkedClass.primaryConstructor)
170
176
case Select (qual, name) =>
171
- for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode . Select (qualNode, tree.symbol)
177
+ for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect (qualNode, tree.symbol)
172
178
case BinaryOp (lhs, op, rhs) if builtinOps.contains(op) =>
173
179
for
174
180
lhsNode <- toNode(lhs, paramSyms, paramTps)
175
181
rhsNode <- toNode(rhs, paramSyms, paramTps)
176
182
yield normalizeOp(builtinOps(op), List (lhsNode, rhsNode))
177
- case BinaryOp (lhs, d.Int_- , rhs) if lhs.tpe. isInstanceOf [ ValueType ] && rhs.tpe. isInstanceOf [ ValueType ] =>
183
+ case BinaryOp (lhs, d.Int_- , rhs) =>
178
184
for
179
185
lhsNode <- toNode(lhs, paramSyms, paramTps)
180
186
rhsNode <- toNode(rhs, paramSyms, paramTps)
@@ -192,7 +198,7 @@ final class EGraph(rootCtx: Context):
192
198
case mt : MethodType =>
193
199
assert(defDef.termParamss.size == 1 , " closures have a single parameter list, right?" )
194
200
val myParamSyms : List [Symbol ] = defDef.termParamss.head.map(_.symbol)
195
- val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
201
+ val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
196
202
val paramTpsSize = paramTps.size
197
203
for myParamSym <- myParamSyms do
198
204
val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
@@ -204,15 +210,38 @@ final class EGraph(rootCtx: Context):
204
210
case _ =>
205
211
None
206
212
213
+ // TODO(mbovel): Memoize this
214
+ private def constructorNode (constr : Symbol )(using Context ): Option [ENode .Constructor ] =
215
+ val clazz = constr.owner
216
+ if hasStructuralEquality(clazz) then
217
+ val isPrimaryConstructor = constr.denot.isPrimaryConstructor
218
+ val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember)
219
+ val constrParams = constr.paramSymss.flatten.filter(_.isTerm)
220
+ val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol ))
221
+ Some (ENode .Constructor (constr)(fields))
222
+ else
223
+ None
224
+
225
+ private def hasStructuralEquality (clazz : Symbol )(using Context ): Boolean =
226
+ val equalsMethod = clazz.info.decls.lookup(nme.equals_)
227
+ val equalsNotOverriden = ! equalsMethod.exists || equalsMethod.is(Flags .Synthetic )
228
+ clazz.isClass && clazz.is(Flags .Case ) && equalsNotOverriden
229
+
230
+ private def isCaseClassApply (meth : Symbol )(using Context ): Boolean =
231
+ meth.name == nme.apply
232
+ && meth.flags.is(Flags .Synthetic )
233
+ && meth.owner.linkedClass.is(Flags .Case )
234
+
207
235
private def canonicalize (node : ENode ): ENode =
236
+ // println(s"canonicalize $node")
208
237
representent(unique(
209
238
node match
210
239
case ENode .Atom (tp) =>
211
240
node
212
- case ENode .New (clazz ) =>
213
- ENode . New (representent(clazz))
241
+ case ENode .Constructor (sym ) =>
242
+ node
214
243
case ENode .Select (qual, member) =>
215
- ENode . Select (representent(qual), member)
244
+ normalizeSelect (representent(qual), member)
216
245
case ENode .Apply (fn, args) =>
217
246
ENode .Apply (representent(fn), args.map(representent))
218
247
case ENode .OpApply (op, args) =>
@@ -223,6 +252,33 @@ final class EGraph(rootCtx: Context):
223
252
ENode .Lambda (paramTps, retTp, representent(body))
224
253
))
225
254
255
+ private def normalizeSelect (qual : ENode , member : Symbol ): ENode =
256
+ getAppliedConstructor(qual) match
257
+ case Some (constr) =>
258
+ val memberIndex = constr.fields.indexOf(member)
259
+
260
+ if memberIndex >= 0 then
261
+ val args = getTermArguments(qual)
262
+ assert(args.size == constr.fields.size)
263
+ args(memberIndex)
264
+ else
265
+ ENode .Select (qual, member)
266
+ case None =>
267
+ ENode .Select (qual, member)
268
+
269
+ private def getAppliedConstructor (node : ENode ): Option [ENode .Constructor ] =
270
+ node match
271
+ case ENode .Apply (fn, args) => getAppliedConstructor(fn)
272
+ case ENode .TypeApply (fn, args) => getAppliedConstructor(fn)
273
+ case node : ENode .Constructor => Some (node)
274
+ case _ => None
275
+
276
+ private def getTermArguments (node : ENode ): List [ENode ] =
277
+ node match
278
+ case ENode .Apply (fn, args) => getTermArguments(fn) ::: args
279
+ case ENode .TypeApply (fn, args) => getTermArguments(fn)
280
+ case _ => Nil
281
+
226
282
private def normalizeOp (op : ENode .Op , args : List [ENode ]): ENode =
227
283
op match
228
284
case Op .Equal =>
@@ -316,12 +372,10 @@ final class EGraph(rootCtx: Context):
316
372
(a, b) match
317
373
case (ENode .Atom (_ : ConstantType ), _) => (a, b)
318
374
case (_, ENode .Atom (_ : ConstantType )) => (b, a)
319
- case (ENode . Atom ( _ : SkolemType ) , _) => (a, b)
320
- case (_, ENode . Atom ( _ : SkolemType )) => (b, a)
375
+ case (_ : ENode . Constructor , _) => (a, b)
376
+ case (_, _ : ENode . Constructor ) => (b, a)
321
377
case (_ : ENode .Atom , _) => (a, b)
322
378
case (_, _ : ENode .Atom ) => (b, a)
323
- case (_ : ENode .New , _) => (a, b)
324
- case (_, _ : ENode .New ) => (b, a)
325
379
case (_ : ENode .Select , _) => (a, b)
326
380
case (_, _ : ENode .Select ) => (b, a)
327
381
case (_ : ENode .Apply , _) => (a, b)
@@ -336,8 +390,6 @@ final class EGraph(rootCtx: Context):
336
390
if aRepr eq bRepr then return
337
391
assert(aRepr != bRepr, s " $aRepr and $bRepr are `equals` but not `eq` " )
338
392
339
- // TODO(mbovel): if both nodes are objects, recursively merge their arguments
340
-
341
393
// / Update represententOf and usedBy maps
342
394
val (newRepr, oldRepr) = order(aRepr, bRepr)
343
395
represententOf(oldRepr) = newRepr
@@ -371,8 +423,9 @@ final class EGraph(rootCtx: Context):
371
423
node match
372
424
case ENode .Atom (tp) =>
373
425
singleton(tp)
374
- case ENode .New (clazz) =>
375
- New (toTree(clazz, paramRefs))
426
+ case ENode .Constructor (sym) =>
427
+ val tycon = sym.owner.info.typeConstructor
428
+ New (tycon).select(TermRef (tycon, sym))
376
429
case ENode .Select (qual, member) =>
377
430
toTree(qual, paramRefs).select(member)
378
431
case ENode .Apply (fn, args) =>
0 commit comments