Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2263d9f

Browse files
committedMar 15, 2023
No longer translate literal patterns to == applications
1 parent aab8d27 commit 2263d9f

12 files changed

+245
-182
lines changed
 

‎shared/src/main/scala/mlscript/ucs/Clause.scala

+27-18
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,46 @@ abstract class Clause {
2121
* @return
2222
*/
2323
val locations: Ls[Loc]
24+
25+
protected final def bindingsToString: String =
26+
(if (bindings.isEmpty) "" else " with " + Clause.showBindings(bindings))
2427
}
2528

2629
object Clause {
30+
final case class MatchLiteral(
31+
scrutinee: Scrutinee,
32+
literal: SimpleTerm
33+
)(override val locations: Ls[Loc]) extends Clause {
34+
override def toString(): String = s"«$scrutinee is $literal" + bindingsToString
35+
}
36+
2737
final case class MatchClass(
2838
scrutinee: Scrutinee,
2939
className: Var,
3040
fields: Ls[Str -> Var]
31-
)(override val locations: Ls[Loc]) extends Clause
41+
)(override val locations: Ls[Loc]) extends Clause {
42+
override def toString(): String = s"«$scrutinee is $className»" + bindingsToString
43+
}
3244

3345
final case class MatchTuple(
3446
scrutinee: Scrutinee,
3547
arity: Int,
3648
fields: Ls[Str -> Var]
37-
)(override val locations: Ls[Loc]) extends Clause
49+
)(override val locations: Ls[Loc]) extends Clause {
50+
override def toString(): String = s"«$scrutinee is Tuple#$arity»" + bindingsToString
51+
}
3852

39-
final case class BooleanTest(test: Term)(override val locations: Ls[Loc]) extends Clause
53+
final case class BooleanTest(test: Term)(
54+
override val locations: Ls[Loc]
55+
) extends Clause {
56+
override def toString(): String = s"«$test»" + bindingsToString
57+
}
4058

41-
final case class Binding(name: Var, term: Term)(override val locations: Ls[Loc]) extends Clause
59+
final case class Binding(name: Var, term: Term)(
60+
override val locations: Ls[Loc]
61+
) extends Clause {
62+
override def toString(): String = s"«$name = $term»" + bindingsToString
63+
}
4264

4365
def showBindings(bindings: Ls[(Bool, Var, Term)]): Str =
4466
bindings match {
@@ -48,20 +70,7 @@ object Clause {
4870
}.mkString("(", ", ", ")")
4971
}
5072

51-
52-
def showClauses(clauses: Iterable[Clause]): Str = {
53-
clauses.iterator.map { clause =>
54-
(clause match {
55-
case Clause.BooleanTest(test) => s"«$test»"
56-
case Clause.MatchClass(scrutinee, Var(className), fields) =>
57-
s"«$scrutinee is $className»"
58-
case Clause.MatchTuple(scrutinee, arity, fields) =>
59-
s"«$scrutinee is Tuple#$arity»"
60-
case Clause.Binding(Var(name), term) =>
61-
s"«$name = $term»"
62-
}) + (if (clause.bindings.isEmpty) "" else " with " + showBindings(clause.bindings))
63-
}.mkString("", " and ", "")
64-
}
73+
def showClauses(clauses: Iterable[Clause]): Str = clauses.mkString("", " and ", "")
6574

6675
def print(println: (=> Any) => Unit, conjunctions: Iterable[Conjunction -> Term]): Unit = {
6776
println("Flattened conjunctions")

‎shared/src/main/scala/mlscript/ucs/Conjunction.scala

+11-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mlscript.ucs
33
import mlscript._, utils._, shorthands._
44
import Clause._, helpers._
55
import scala.collection.mutable.Buffer
6+
import scala.annotation.tailrec
67

78
/**
89
* A `Conjunction` represents a list of `Clause`s.
@@ -53,13 +54,20 @@ final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[(Bool, Va
5354
def +(lastBinding: (Bool, Var, Term)): Conjunction =
5455
Conjunction(clauses, trailingBindings :+ lastBinding)
5556

56-
def separate(expectedScrutinee: Scrutinee): Opt[(MatchClass, Conjunction)] = {
57-
def rec(past: Ls[Clause], upcoming: Ls[Clause]): Opt[(Ls[Clause], MatchClass, Ls[Clause])] = {
57+
def separate(expectedScrutinee: Scrutinee): Opt[(MatchClass \/ MatchLiteral, Conjunction)] = {
58+
@tailrec
59+
def rec(past: Ls[Clause], upcoming: Ls[Clause]): Opt[(Ls[Clause], MatchClass \/ MatchLiteral, Ls[Clause])] = {
5860
upcoming match {
5961
case Nil => N
62+
case (head @ MatchLiteral(scrutinee, _)) :: tail =>
63+
if (scrutinee === expectedScrutinee) {
64+
S((past, R(head), tail))
65+
} else {
66+
rec(past :+ head, tail)
67+
}
6068
case (head @ MatchClass(scrutinee, _, _)) :: tail =>
6169
if (scrutinee === expectedScrutinee) {
62-
S((past, head, tail))
70+
S((past, L(head), tail))
6371
} else {
6472
rec(past :+ head, tail)
6573
}

‎shared/src/main/scala/mlscript/ucs/Desugarer.scala

+93-52
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ class Desugarer extends TypeDefs { self: Typer =>
9999
def makeScrutinee(term: Term, matchRootLoc: Opt[Loc])(implicit ctx: Ctx): Scrutinee =
100100
traceUCS(s"Making a scrutinee for `$term`") {
101101
term match {
102-
case _: SimpleTerm => Scrutinee(N, term)(matchRootLoc)
103-
case _ => Scrutinee(S(makeLocalizedName(term)), term)(matchRootLoc)
102+
case _: Var =>
103+
printlnUCS(s"The scrutinee does not need an alias.")
104+
Scrutinee(N, term)(matchRootLoc)
105+
case _ =>
106+
val localizedName = makeLocalizedName(term)
107+
printlnUCS(s"The scrutinee needs an alias: $localizedName")
108+
Scrutinee(S(localizedName), term)(matchRootLoc)
104109
}
105110
}()
106111

@@ -160,9 +165,13 @@ class Desugarer extends TypeDefs { self: Typer =>
160165
case Var("_") => Nil
161166
// This case handles literals.
162167
// x is true | x is false | x is 0 | x is "text" | ...
163-
case literal @ (Var("true") | Var("false") | _: Lit) =>
164-
val test = mkBinOp(scrutinee.reference, Var("=="), literal)
165-
val clause = Clause.BooleanTest(test)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
168+
case literal: Var if literal.name === "true" || literal.name === "false" =>
169+
val clause = Clause.MatchLiteral(scrutinee, literal)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
170+
clause.bindings = scrutinee.asBinding.toList
171+
printlnUCS(s"Add bindings to the clause: ${scrutinee.asBinding}")
172+
clause :: Nil
173+
case literal: Lit =>
174+
val clause = Clause.MatchLiteral(scrutinee, literal)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
166175
clause.bindings = scrutinee.asBinding.toList
167176
printlnUCS(s"Add bindings to the clause: ${scrutinee.asBinding}")
168177
clause :: Nil
@@ -515,22 +524,34 @@ class Desugarer extends TypeDefs { self: Typer =>
515524
*/
516525
type ExhaustivenessMap = Map[Str \/ Int, Map[Var, MutCase]]
517526

518-
def getScurtineeKey(scrutinee: Scrutinee)(implicit ctx: Ctx, raise: Raise): Str \/ Int = {
519-
scrutinee.term match {
520-
// The original scrutinee is an reference.
521-
case v @ Var(name) =>
522-
ctx.env.get(name) match {
523-
case S(VarSymbol(_, defVar)) => defVar.uid.fold[Str \/ Int](L(v.name))(R(_))
524-
case S(_) | N => L(v.name)
525-
}
526-
// Otherwise, the scrutinee has a temporary name.
527-
case _ =>
528-
scrutinee.local match {
529-
case N => throw new Error("check your `makeScrutinee`")
530-
case S(localNameVar) => L(localNameVar.name)
531-
}
532-
}
533-
}
527+
/**
528+
* This method obtains a proper key of the given scrutinee
529+
* for memorizing patterns belongs to the scrutinee.
530+
*
531+
* @param scrutinee the scrutinee
532+
* @param ctx the context
533+
* @param raise we need this to raise errors.
534+
* @return the variable name or the variable ID
535+
*/
536+
def getScurtineeKey(scrutinee: Scrutinee)(implicit ctx: Ctx, raise: Raise): Str \/ Int =
537+
traceUCS(s"[getScrutineeKey] $scrutinee") {
538+
scrutinee.term match {
539+
// The original scrutinee is an reference.
540+
case v @ Var(name) =>
541+
printlnUCS("The original scrutinee is an reference.")
542+
ctx.env.get(name) match {
543+
case S(VarSymbol(_, defVar)) => defVar.uid.fold[Str \/ Int](L(v.name))(R(_))
544+
case S(_) | N => L(v.name)
545+
}
546+
// Otherwise, the scrutinee was localized because it might be effectful.
547+
case _ =>
548+
printlnUCS("The scrutinee was localized because it might be effectful.")
549+
scrutinee.local match {
550+
case N => throw new Error("check your `makeScrutinee`")
551+
case S(localNameVar) => L(localNameVar.name)
552+
}
553+
}
554+
}()
534555

535556
/**
536557
* Check the exhaustiveness of the given `MutCaseOf`.
@@ -542,10 +563,8 @@ class Desugarer extends TypeDefs { self: Typer =>
542563
def checkExhaustive
543564
(t: MutCaseOf, parentOpt: Opt[MutCaseOf])
544565
(implicit scrutineePatternMap: ExhaustivenessMap, ctx: Ctx, raise: Raise)
545-
: Unit = {
546-
printlnUCS(s"Check exhaustiveness of ${t.describe}")
547-
indent += 1
548-
try t match {
566+
: Unit = traceUCS(s"[checkExhaustive] ${t.describe}") {
567+
t match {
549568
case _: Consequent => ()
550569
case MissingCase =>
551570
parentOpt match {
@@ -567,18 +586,26 @@ class Desugarer extends TypeDefs { self: Typer =>
567586
case S(_) if default.isDefined =>
568587
printlnUCS("The match has a default branch. So, it is always safe.")
569588
case S(patternMap) =>
570-
printlnUCS(s"The exhaustiveness map is ${scrutineePatternMap}")
589+
printlnUCS(s"The exhaustiveness map is")
590+
scrutineePatternMap.foreach { case (key, matches) =>
591+
printlnUCS(s"- $key -> ${matches.keysIterator.mkString(", ")}")
592+
}
571593
printlnUCS(s"The scrutinee key is ${getScurtineeKey(scrutinee)}")
572594
printlnUCS("Pattern map of the scrutinee:")
573595
if (patternMap.isEmpty)
574596
printlnUCS("<Empty>")
575597
else
576598
patternMap.foreach { case (key, mutCase) => printlnUCS(s"- $key => $mutCase")}
577599
// Filter out missing cases in `branches`.
578-
val missingCases = patternMap.removedAll(branches.iterator.map {
579-
case MutCase(classNameVar -> _, _) => classNameVar
600+
val missingCases = patternMap.removedAll(branches.iterator.flatMap {
601+
case MutCase.Literal(tof @ Var(n), _) if n === "true" || n === "false" => Some(tof)
602+
case MutCase.Literal(_, _) => None
603+
case MutCase.Constructor(classNameVar -> _, _) => Some(classNameVar)
580604
})
581-
printlnUCS(s"Number of missing cases: ${missingCases.size}")
605+
printlnUCS("Missing cases")
606+
missingCases.foreach { case (key, m) =>
607+
printlnUCS(s"- $key -> ${m}")
608+
}
582609
if (!missingCases.isEmpty) {
583610
throw new DesugaringException({
584611
val numMissingCases = missingCases.size
@@ -597,53 +624,67 @@ class Desugarer extends TypeDefs { self: Typer =>
597624
}
598625
}
599626
default.foreach(checkExhaustive(_, S(t)))
600-
branches.foreach { case MutCase(_, consequent) =>
601-
checkExhaustive(consequent, S(t))
627+
branches.foreach { branch =>
628+
checkExhaustive(branch.consequent, S(t))
602629
}
603-
} finally indent -= 1
604-
}
630+
}
631+
}()
605632

606-
def summarizePatterns(t: MutCaseOf)(implicit ctx: Ctx, raise: Raise): ExhaustivenessMap = {
633+
def summarizePatterns(t: MutCaseOf)(implicit ctx: Ctx, raise: Raise): ExhaustivenessMap = traceUCS("[summarizePatterns]") {
607634
val m = MutMap.empty[Str \/ Int, MutMap[Var, MutCase]]
608-
def rec(t: MutCaseOf): Unit = {
609-
printlnUCS(s"Summarize pattern of ${t.describe}")
610-
indent += 1
611-
try t match {
635+
def rec(t: MutCaseOf): Unit = traceUCS(s"[rec] ${t.describe}") {
636+
t match {
612637
case Consequent(term) => ()
613638
case MissingCase => ()
614639
case IfThenElse(_, whenTrue, whenFalse) =>
615640
rec(whenTrue)
616641
rec(whenFalse)
617642
case Match(scrutinee, branches, default) =>
618643
val key = getScurtineeKey(scrutinee)
619-
branches.foreach { mutCase =>
620-
val patternMap = m.getOrElseUpdate( key, MutMap.empty)
621-
if (!patternMap.contains(mutCase.patternFields._1)) {
622-
patternMap += ((mutCase.patternFields._1, mutCase))
623-
}
624-
rec(mutCase.consequent)
644+
val patternMap = m.getOrElseUpdate(key, MutMap.empty)
645+
branches.foreach {
646+
case mutCase @ MutCase.Literal(literal, consequent) =>
647+
literal match {
648+
case tof @ Var(n) if n === "true" || n === "false" =>
649+
if (!patternMap.contains(tof)) {
650+
patternMap += ((tof, mutCase))
651+
}
652+
case _ => () // TODO: Summarize literals.
653+
}
654+
rec(consequent)
655+
case mutCase @ MutCase.Constructor((className, _), consequent) =>
656+
if (!patternMap.contains(className)) {
657+
patternMap += ((className, mutCase))
658+
}
659+
rec(consequent)
625660
}
626661
default.foreach(rec)
627-
} finally indent -= 1
628-
}
662+
}
663+
}()
629664
rec(t)
630-
printlnUCS("Exhaustiveness map")
631-
m.foreach { case (scrutinee, patterns) =>
632-
printlnUCS(s"- $scrutinee => " + patterns.keys.mkString(", "))
633-
}
665+
printlnUCS("Summarized patterns")
666+
if (m.isEmpty)
667+
printlnUCS("<Empty>")
668+
else
669+
m.foreach { case (scrutinee, patterns) =>
670+
printlnUCS(s"- $scrutinee => " + patterns.keysIterator.mkString(", "))
671+
}
634672
Map.from(m.iterator.map { case (key, patternMap) => key -> Map.from(patternMap) })
635-
}
673+
}()
636674

637675
protected def constructTerm(m: MutCaseOf)(implicit ctx: Ctx): Term = {
638676
def rec(m: MutCaseOf)(implicit defs: Set[Var]): Term = m match {
639677
case Consequent(term) => term
640678
case Match(scrutinee, branches, wildcard) =>
641679
def rec2(xs: Ls[MutCase]): CaseBranches =
642680
xs match {
643-
case MutCase(className -> fields, cases) :: next =>
681+
case MutCase.Constructor(className -> fields, cases) :: next =>
644682
// TODO: expand bindings here
645683
val consequent = rec(cases)(defs ++ fields.iterator.map(_._2))
646684
Case(className, mkLetFromFields(scrutinee, fields.toList, consequent), rec2(next))
685+
case MutCase.Literal(literal, cases) :: next =>
686+
val consequent = rec(cases)
687+
Case(literal, consequent, rec2(next))
647688
case Nil =>
648689
wildcard.fold[CaseBranches](NoCases) { rec(_) |> Wildcard }
649690
}

‎shared/src/main/scala/mlscript/ucs/MutCaseOf.scala

+80-30
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,16 @@ object MutCaseOf {
7979
rec(whenFalse, indent + 1, "")
8080
case Match(scrutinee, branches, default) =>
8181
lines += baseIndent + leading + bindingNames + showScrutinee(scrutinee) + " match"
82-
branches.foreach { case MutCase(Var(className) -> fields, consequent) =>
83-
lines += s"$baseIndent case $className =>"
84-
fields.foreach { case (field, Var(alias)) =>
85-
lines += s"$baseIndent let $alias = .$field"
86-
}
87-
rec(consequent, indent + 2, "")
82+
branches.foreach {
83+
case MutCase.Literal(literal, consequent) =>
84+
lines += s"$baseIndent case $literal =>"
85+
rec(consequent, indent + 1, "")
86+
case MutCase.Constructor(Var(className) -> fields, consequent) =>
87+
lines += s"$baseIndent case $className =>"
88+
fields.foreach { case (field, Var(alias)) =>
89+
lines += s"$baseIndent let $alias = .$field"
90+
}
91+
rec(consequent, indent + 2, "")
8892
}
8993
default.foreach { consequent =>
9094
lines += s"$baseIndent default"
@@ -100,20 +104,12 @@ object MutCaseOf {
100104
lines.toList
101105
}
102106

103-
/**
104-
* MutCase is a _mutable_ representation of a case in `MutCaseOf.Match`.
105-
*
106-
* @param patternFields the alias to the fields
107-
* @param consequent the consequential `MutCaseOf`
108-
*/
109-
final case class MutCase(
110-
val patternFields: Var -> Buffer[Str -> Var],
111-
var consequent: MutCaseOf,
112-
) {
113-
def matches(expected: Var): Bool = matches(expected.name)
114-
def matches(expected: Str): Bool = patternFields._1.name === expected
115-
def addFields(fields: Iterable[Str -> Var]): Unit =
116-
patternFields._2 ++= fields.iterator.filter(!patternFields._2.contains(_))
107+
sealed abstract class MutCase {
108+
var consequent: MutCaseOf
109+
110+
def matches(expected: Var): Bool
111+
def matches(expected: Str): Bool
112+
def matches(expected: Lit): Bool
117113

118114
// Note 1
119115
// ======
@@ -142,7 +138,38 @@ object MutCaseOf {
142138
}
143139
}
144140

145-
import Clause.{MatchClass, MatchTuple, BooleanTest, Binding}
141+
object MutCase {
142+
final case class Literal(
143+
val literal: SimpleTerm,
144+
var consequent: MutCaseOf,
145+
) extends MutCase {
146+
override def matches(expected: Var): Bool = literal match {
147+
case tof @ Var(n) if n === "true" || n === "false" => expected === tof
148+
case _ => false
149+
}
150+
override def matches(expected: Str): Bool = false
151+
override def matches(expected: Lit): Bool = literal === expected
152+
}
153+
154+
/**
155+
* MutCase is a _mutable_ representation of a case in `MutCaseOf.Match`.
156+
*
157+
* @param patternFields the alias to the fields
158+
* @param consequent the consequential `MutCaseOf`
159+
*/
160+
final case class Constructor(
161+
val patternFields: Var -> Buffer[Str -> Var],
162+
var consequent: MutCaseOf,
163+
) extends MutCase {
164+
override def matches(expected: Var): Bool = matches(expected.name)
165+
override def matches(expected: Str): Bool = patternFields._1.name === expected
166+
override def matches(expected: Lit): Bool = false
167+
def addFields(fields: Iterable[Str -> Var]): Unit =
168+
patternFields._2 ++= fields.iterator.filter(!patternFields._2.contains(_))
169+
}
170+
}
171+
172+
import Clause.{MatchLiteral, MatchClass, MatchTuple, BooleanTest, Binding}
146173

147174
// A short-hand for pattern matchings with only true and false branches.
148175
final case class IfThenElse(condition: Term, var whenTrue: MutCaseOf, var whenFalse: MutCaseOf) extends MutCaseOf {
@@ -223,16 +250,17 @@ object MutCaseOf {
223250
case N =>
224251
val newBranch = buildFirst(Conjunction(tail, trailingBindings), term)
225252
newBranch.addBindings(head.bindings)
226-
branches += MutCase(tupleClassName -> Buffer.from(fields), newBranch)
253+
branches += MutCase.Constructor(tupleClassName -> Buffer.from(fields), newBranch)
227254
.withLocations(head.locations)
228255
// Found existing pattern.
229-
case S(branch) =>
256+
case S(branch: MutCase.Constructor) =>
230257
branch.consequent.addBindings(head.bindings)
231258
branch.addFields(fields)
232259
branch.consequent.merge(Conjunction(tail, trailingBindings) -> term)
233260
}
234261
// A wild card case. We should propagate wildcard to every default positions.
235-
case Conjunction(Nil, trailingBindings) -> term => mergeDefault(trailingBindings, term)
262+
case Conjunction(Nil, trailingBindings) -> term =>
263+
mergeDefault(trailingBindings, term) // TODO: Handle the int result here.
236264
// The conditions to be inserted does not overlap with me.
237265
case conjunction -> term =>
238266
wildcard match {
@@ -243,27 +271,44 @@ object MutCaseOf {
243271
}
244272
}
245273
// Found a match condition against the same scrutinee
246-
case S((head @ MatchClass(_, className, fields), remainingConditions)) =>
274+
case S(L(head @ MatchClass(_, className, fields)) -> remainingConditions) =>
247275
branches.find(_.matches(className)) match {
248276
// No such pattern. We should create a new one.
249277
case N =>
250278
val newBranch = buildFirst(remainingConditions, branch._2)
251279
newBranch.addBindings(head.bindings)
252-
branches += MutCase(className -> Buffer.from(fields), newBranch)
280+
branches += MutCase.Constructor(className -> Buffer.from(fields), newBranch)
253281
.withLocations(head.locations)
254282
// Found existing pattern.
255-
case S(matchCase) =>
283+
case S(matchCase: MutCase.Constructor) =>
256284
// Merge interleaved bindings.
257285
matchCase.consequent.addBindings(head.bindings)
258286
matchCase.addFields(fields)
259287
matchCase.consequent.merge(remainingConditions -> branch._2)
260288
}
289+
case S(R(head @ MatchLiteral(_, literal)) -> remainingConditions) =>
290+
branches.find(branch => literal match {
291+
case v: Var => branch.matches(v)
292+
case l: Lit => branch.matches(l)
293+
}) match {
294+
// No such pattern. We should create a new one.
295+
case N =>
296+
val newConsequent = buildFirst(remainingConditions, branch._2)
297+
newConsequent.addBindings(head.bindings)
298+
branches += MutCase.Literal(literal, newConsequent)
299+
.withLocations(head.locations)
300+
case S(matchCase: MutCase.Literal) =>
301+
// Merge interleaved bindings.
302+
matchCase.consequent.addBindings(head.bindings)
303+
matchCase.consequent.merge(remainingConditions -> branch._2)
304+
}
261305
}
262306
}
263307

264308
def mergeDefault(bindings: Ls[(Bool, Var, Term)], default: Term)(implicit raise: Diagnostic => Unit): Int = {
265309
branches.iterator.map {
266-
case MutCase(_, consequent) => consequent.mergeDefault(bindings, default)
310+
case MutCase.Constructor(_, consequent) => consequent.mergeDefault(bindings, default)
311+
case MutCase.Literal(_, consequent) => consequent.mergeDefault(bindings, default)
267312
}.sum + {
268313
wildcard match {
269314
case N =>
@@ -296,16 +341,21 @@ object MutCaseOf {
296341
case Conjunction(head :: tail, trailingBindings) =>
297342
val realTail = Conjunction(tail, trailingBindings)
298343
(head match {
344+
case MatchLiteral(scrutinee, literal) =>
345+
val branches = Buffer(
346+
MutCase.Literal(literal, rec(realTail)).withLocation(literal.toLoc)
347+
)
348+
Match(scrutinee, branches, N)
299349
case BooleanTest(test) => IfThenElse(test, rec(realTail), MissingCase)
300350
case MatchClass(scrutinee, className, fields) =>
301351
val branches = Buffer(
302-
MutCase(className -> Buffer.from(fields), rec(realTail))
352+
MutCase.Constructor(className -> Buffer.from(fields), rec(realTail))
303353
.withLocations(head.locations)
304354
)
305355
Match(scrutinee, branches, N)
306356
case MatchTuple(scrutinee, arity, fields) =>
307357
val branches = Buffer(
308-
MutCase(Var(s"Tuple#$arity") -> Buffer.from(fields), rec(realTail))
358+
MutCase.Constructor(Var(s"Tuple#$arity") -> Buffer.from(fields), rec(realTail))
309359
.withLocations(head.locations)
310360
)
311361
Match(scrutinee, branches, N)

‎shared/src/test/diff/codegen/Mixin.mls

+7-5
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,16 @@ fun mk(n) = if n is
373373
1 then Neg(mk(n))
374374
_ then Add(mk(n), mk(n))
375375
TestLang.eval(mk(0))
376-
//│ fun mk: forall 'E. number -> 'E
376+
//│ fun mk: forall 'E. anything -> 'E
377377
//│ int
378378
//│ where
379379
//│ 'E :> Add['E] | Lit | Neg['E]
380380
//│ // Prelude
381381
//│ let typing_unit6 = { cache: {} };
382382
//│ // Query 1
383383
//│ globalThis.mk = function mk(n) {
384-
//│ return n == 0 === true ? Lit(0) : n == 1 === true ? Neg(mk(n)) : Add(mk(n), mk(n));
384+
//│ let a;
385+
//│ return a = n, a === 0 ? Lit(0) : a === 1 ? Neg(mk(n)) : Add(mk(n), mk(n));
385386
//│ };
386387
//│ // Query 2
387388
//│ res = TestLang.eval(mk(0));
@@ -396,7 +397,8 @@ TestLang.eval(mk(0))
396397
//│ │ ├── Prelude: <empty>
397398
//│ │ ├── Code:
398399
//│ │ ├── globalThis.mk = function mk(n) {
399-
//│ │ ├── return n == 0 === true ? Lit(0) : n == 1 === true ? Neg(mk(n)) : Add(mk(n), mk(n));
400+
//│ │ ├── let a;
401+
//│ │ ├── return a = n, a === 0 ? Lit(0) : a === 1 ? Neg(mk(n)) : Add(mk(n), mk(n));
400402
//│ │ ├── };
401403
//│ │ ├── Intermediate: [Function: mk]
402404
//│ │ └── Reply: [success] [Function: mk]
@@ -417,7 +419,7 @@ class Foo(x: int)
417419
:e
418420
class Bar(x: int, y: int) extends Foo(x + y)
419421
//│ ╔══[ERROR] Class inheritance is not supported yet (use mixins)
420-
//│ ║ l.418: class Bar(x: int, y: int) extends Foo(x + y)
422+
//│ ║ l.420: class Bar(x: int, y: int) extends Foo(x + y)
421423
//│ ╙── ^^^^^^^^^^
422424
//│ class Bar(x: int, y: int)
423425

@@ -530,7 +532,7 @@ mixin Base {
530532
fun x = y
531533
}
532534
//│ ╔══[ERROR] identifier not found: y
533-
//│ ║ l.530: fun x = y
535+
//│ ║ l.532: fun x = y
534536
//│ ╙── ^
535537
//│ mixin Base() {
536538
//│ fun x: error

‎shared/src/test/diff/ecoop23/ExpressionProblem.mls

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ fun mk(n) = if n is
201201
0 then Lit(0)
202202
1 then Neg(mk(n))
203203
_ then Add(mk(n), mk(n))
204-
//│ fun mk: forall 'E. number -> 'E
204+
//│ fun mk: forall 'E. anything -> 'E
205205
//│ where
206206
//│ 'E :> Add['E] | Lit | Neg['E]
207207

‎shared/src/test/diff/ecoop23/SimpleRegionDSL.mls

+5-5
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ fun go(x, offset) =
4545
else
4646
let shared = go(x - 1, round(offset / 2))
4747
Union(Translate(Vector(0 - offset, 0), shared), Translate(Vector(offset, 0), shared))
48-
//│ fun go: forall 'Region. (int, int,) -> 'Region
48+
//│ fun go: forall 'Region. (0 | int & ~0, int,) -> 'Region
4949
//│ where
5050
//│ 'Region :> Circle | Union[Translate['Region]]
5151

@@ -116,7 +116,7 @@ TestSize.size(Scale(Vector(1, 1), circles))
116116
fun pow(x, a) =
117117
if a is 0 then 1
118118
else x * pow(x, a - 1)
119-
//│ fun pow: (int, int,) -> int
119+
//│ fun pow: (int, 0 | int & ~0,) -> int
120120

121121
mixin Contains {
122122
fun contains(a, p) =
@@ -327,7 +327,7 @@ module TestElim extends Eliminate
327327
TestElim.eliminate(Outside(Outside(Univ())))
328328
//│ 'a
329329
//│ where
330-
//│ 'a :> Univ | Outside['a] | Union['a] | Intersect['a] | Translate['a] | Scale['a]
330+
//│ 'a :> Scale['a] | Univ | Outside['a] | Union['a] | Intersect['a] | Translate['a]
331331
//│ res
332332
//│ = Univ {}
333333

@@ -344,7 +344,7 @@ fun mk(n) = if n is
344344
3 then Intersect(mk(n), mk(n))
345345
4 then Translate(Vector(0, 0), mk(n))
346346
_ then Scale(Vector(0, 0), mk(n))
347-
//│ fun mk: forall 'Region. number -> 'Region
347+
//│ fun mk: forall 'Region. anything -> 'Region
348348
//│ where
349349
//│ 'Region :> Intersect['Region] | Outside['Region] | Scale['Region] | Translate['Region] | Union['Region]
350350

@@ -376,7 +376,7 @@ module Lang extends SizeBase, SizeExt, Contains, Text, IsUniv, IsEmpty, Eliminat
376376
//│ 'd <: Intersect['d] | Outside['e] | Scale['d] | Translate['d] | Union['d] | Univ | ~Intersect[anything] & ~Outside[anything] & ~Scale[anything] & ~Translate[anything] & ~Union[anything] & ~Univ
377377
//│ 'e <: Intersect['e] | Outside['d] | Scale['e] | Translate['e] | Union['e] | Univ | ~Intersect[anything] & ~Outside[anything] & ~Scale[anything] & ~Translate[anything] & ~Union[anything] & ~Univ
378378
//│ 'b <: Intersect['b] | Outside['b & (Outside['b] | ~#Outside)] | Scale['b] | Translate['b] | Union['b] | 'c & ~#Intersect & ~#Outside & ~#Scale & ~#Translate & ~#Union
379-
//│ 'c :> Translate['c] | Scale['c] | Outside['c] | Union['c] | Intersect['c]
379+
//│ 'c :> Outside['c] | Union['c] | Intersect['c] | Translate['c] | Scale['c]
380380
//│ 'a <: Circle | Intersect['a] | Outside['a] | Translate['a] | Union['a]
381381

382382
// TODO investigate

‎shared/src/test/diff/nu/EvalNegNeg.mls

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fun mk(n) = if n is
6161
1 then Neg(mk(n))
6262
_ then Add(mk(n), mk(n))
6363
TestLang.eval(mk(0))
64-
//│ fun mk: forall 'E. number -> 'E
64+
//│ fun mk: forall 'E. anything -> 'E
6565
//│ int
6666
//│ where
6767
//│ 'E :> Add['E] | Lit | Neg['E]

‎shared/src/test/diff/nu/FilterMap.mls

+2-20
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,9 @@ fun filtermap(f, xs) = if xs is
3131
//│ ╔══[ERROR] identifier not found: ys
3232
//│ ║ l.27: Cons(y, ys) and f(ys) is
3333
//│ ╙── ^^
34-
//│ ╔══[ERROR] Type mismatch in application:
35-
//│ ║ l.27: Cons(y, ys) and f(ys) is
36-
//│ ║ ^^^^^^^^
37-
//│ ║ l.28: false then filtermap(f, ys)
38-
//│ ║ ^^^^^^^^^
39-
//│ ╟── reference of type `false` is not an instance of type `number`
40-
//│ ║ l.28: false then filtermap(f, ys)
41-
//│ ╙── ^^^^^
42-
//│ ╔══[ERROR] Type mismatch in application:
43-
//│ ║ l.27: Cons(y, ys) and f(ys) is
44-
//│ ║ ^^^^^^^^
45-
//│ ║ l.28: false then filtermap(f, ys)
46-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
47-
//│ ║ l.29: true then Cons(y, filtermap(f, ys))
48-
//│ ║ ^^^^^^^^
49-
//│ ╟── reference of type `true` is not an instance of type `number`
50-
//│ ║ l.29: true then Cons(y, filtermap(f, ys))
51-
//│ ╙── ^^^^
5234
//│ ╔══[ERROR] type identifier not found: Tuple#2
5335
//│ ╙──
54-
//│ fun filtermap: ((Cons[nothing] | error | Nil) -> number & (Cons[nothing] | Nil) -> error, Cons[anything] | Nil,) -> (Cons[nothing] | Nil | error)
36+
//│ fun filtermap: ((Cons[nothing] | error | Nil) -> anything & (Cons[nothing] | Nil) -> (error | false | true), Cons[anything] | Nil,) -> (Cons[nothing] | Nil | error)
5537
//│ Code generation encountered an error:
5638
//│ unknown match case: Tuple#2
5739

@@ -70,6 +52,6 @@ fun filtermap(f, xs) = if xs is
7052
True then filtermap(f, ys)
7153
False then Cons(y, filtermap(f, ys))
7254
Pair(True, z) then Cons(z, filtermap(f, ys))
73-
//│ fun filtermap: forall 'head 'A. ('head -> (False | Pair[anything, 'A] | True), Cons['head & 'A] | Nil,) -> (Cons['A] | Nil)
55+
//│ fun filtermap: forall 'A 'head. ('head -> (False | Pair[anything, 'A] | True), Cons['head & 'A] | Nil,) -> (Cons['A] | Nil)
7456

7557

‎shared/src/test/diff/ucs/Humiliation.mls

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if 1 is 1 then 1 else 0
1111
//│ = 1
1212

1313
fun test(x) = if x is 1 then 0 else 1
14-
//│ test: number -> (0 | 1)
14+
//│ test: anything -> (0 | 1)
1515
//│ = [Function: test]
1616

1717
// It should report duplicated branches.
@@ -47,7 +47,7 @@ fun f(x) =
4747
Pair(1, 1) then "ones"
4848
Pair(y, 1) then x
4949
_ then "nah"
50-
//│ f: (Pair & {fst: number, snd: number} & 'a | ~Pair) -> ("nah" | "ones" | "zeros" | 'a)
50+
//│ f: (Pair & 'a | ~Pair) -> ("nah" | "ones" | "zeros" | 'a)
5151
//│ = [Function: f]
5252

5353
class Z()

‎shared/src/test/diff/ucs/LitUCS.mls

+7-33
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,11 @@
44
module A
55
//│ module A()
66

7-
// FIXME
87
// This one is easy to fix but what about the next one?
98
// The following example can better reveal the essence of the problem.
109
fun test(x: 0 | A) = if x is
1110
0 then 0
1211
A then A
13-
//│ ╔══[ERROR] Type mismatch in application:
14-
//│ ║ l.10: fun test(x: 0 | A) = if x is
15-
//│ ║ ^
16-
//│ ╟── type `A` is not an instance of type `number`
17-
//│ ║ l.10: fun test(x: 0 | A) = if x is
18-
//│ ║ ^
19-
//│ ╟── but it flows into reference with expected type `number`
20-
//│ ║ l.10: fun test(x: 0 | A) = if x is
21-
//│ ╙── ^
22-
//│ ╔══[ERROR] Type mismatch in `case` expression:
23-
//│ ║ l.10: fun test(x: 0 | A) = if x is
24-
//│ ║ ^^^^
25-
//│ ║ l.11: 0 then 0
26-
//│ ║ ^^^^^^^^^^
27-
//│ ║ l.12: A then A
28-
//│ ║ ^^^^^^^^^^
29-
//│ ╟── type `0` is not an instance of type `A`
30-
//│ ║ l.10: fun test(x: 0 | A) = if x is
31-
//│ ║ ^
32-
//│ ╟── but it flows into reference with expected type `A`
33-
//│ ║ l.10: fun test(x: 0 | A) = if x is
34-
//│ ║ ^
35-
//│ ╟── Note: constraint arises from class pattern:
36-
//│ ║ l.12: A then A
37-
//│ ╙── ^
3812
//│ fun test: (x: 0 | A,) -> (0 | A)
3913

4014
// FIXME
@@ -43,24 +17,24 @@ fun test(x: 0 | A) =
4317
x == 0 then 0
4418
x is A then A
4519
//│ ╔══[ERROR] Type mismatch in operator application:
46-
//│ ║ l.43: x == 0 then 0
20+
//│ ║ l.18: x == 0 then 0
4721
//│ ║ ^^^^
4822
//│ ╟── type `A` is not an instance of type `number`
49-
//│ ║ l.41: fun test(x: 0 | A) =
23+
//│ ║ l.16: fun test(x: 0 | A) =
5024
//│ ║ ^
5125
//│ ╟── but it flows into reference with expected type `number`
52-
//│ ║ l.43: x == 0 then 0
26+
//│ ║ l.18: x == 0 then 0
5327
//│ ╙── ^
5428
//│ ╔══[ERROR] Type mismatch in `case` expression:
55-
//│ ║ l.44: x is A then A
29+
//│ ║ l.19: x is A then A
5630
//│ ║ ^^^^^^^^^^^^^
5731
//│ ╟── type `0` is not an instance of type `A`
58-
//│ ║ l.41: fun test(x: 0 | A) =
32+
//│ ║ l.16: fun test(x: 0 | A) =
5933
//│ ║ ^
6034
//│ ╟── but it flows into reference with expected type `A`
61-
//│ ║ l.44: x is A then A
35+
//│ ║ l.19: x is A then A
6236
//│ ║ ^
6337
//│ ╟── Note: constraint arises from class pattern:
64-
//│ ║ l.44: x is A then A
38+
//│ ║ l.19: x is A then A
6539
//│ ╙── ^
6640
//│ fun test: (x: 0 | A,) -> (0 | A)

‎shared/src/test/diff/ucs/WeirdIf.mls

+9-12
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,17 @@ fun f(x) =
7979
//│ f: anything -> "bruh"
8080
//│ = [Function: f3]
8181

82-
:e
83-
:ge
8482
// Hmmmmmm, this one is valid but how to get it work?
8583
fun boolToStr(x) =
8684
if x is
8785
true then "yah"
8886
false then "nah"
89-
//│ ╔══[ERROR] The case when this is false is not handled: == (x,) (false,)
90-
//│ ║ l.86: if x is
91-
//│ ║ ^^^^
92-
//│ ║ l.87: true then "yah"
93-
//│ ║ ^^^^^^^^^^^^^^^^^^^
94-
//│ ║ l.88: false then "nah"
95-
//│ ╙── ^^^^^^^^^
96-
//│ boolToStr: anything -> error
97-
//│ Code generation encountered an error:
98-
//│ if expression was not desugared
87+
//│ boolToStr: bool -> ("nah" | "yah")
88+
//│ = [Function: boolToStr]
89+
90+
boolToStr of true
91+
boolToStr of false
92+
//│ res: "nah" | "yah"
93+
//│ = 'yah'
94+
//│ res: "nah" | "yah"
95+
//│ = 'nah'

0 commit comments

Comments
 (0)
Please sign in to comment.