Skip to content

Commit 5c2a8b2

Browse files
committed
Fix some test cases
TBH, my tree construction logic is quite weird.
1 parent cbc8688 commit 5c2a8b2

File tree

4 files changed

+66
-121
lines changed

4 files changed

+66
-121
lines changed

shared/src/main/scala/mlscript/Typer.scala

+5-3
Original file line numberDiff line numberDiff line change
@@ -722,12 +722,12 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool)
722722
con(s_ty, req, cs_ty)
723723
case iff @ If(body, fallback) =>
724724
try {
725-
val cnf = desugarIf(body)(ctx)
725+
val cnf = desugarIf(body, fallback)(ctx)
726726
IfBodyHelpers.showConjunctions(println, cnf)
727727
val caseTree = MutCaseOf.build(cnf)
728728
println("The mutable CaseOf tree")
729729
MutCaseOf.show(caseTree).foreach(println(_))
730-
val trm = caseTree.toTerm(fallback)
730+
val trm = caseTree.toTerm
731731
println(s"Desugared term: ${trm.print(false)}")
732732
iff.desugaredIf = S(trm)
733733
typeTerm(trm)
@@ -1037,7 +1037,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool)
10371037
case _ => throw new Exception(s"illegal pattern: $pattern")
10381038
}
10391039

1040-
def desugarIf(body: IfBody)(implicit ctx: Ctx): Ls[Ls[Condition] -> Term] = {
1040+
def desugarIf(body: IfBody, fallback: Opt[Term])(implicit ctx: Ctx): Ls[Ls[Condition] -> Term] = {
10411041
// We allocate temporary variable names for nested patterns.
10421042
// This prevents aliasing problems.
10431043
implicit val scrutineeFieldAliases: MutMap[Term, MutMap[Str, Var]] = MutMap.empty
@@ -1168,6 +1168,8 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool)
11681168
}
11691169
}
11701170
desugarIfBody(body)(N, Nil)
1171+
// Add the fallback case to conjunctions if there is any.
1172+
fallback.foreach { branches += Nil -> _ }
11711173
branches.toList
11721174
}
11731175
}

shared/src/main/scala/mlscript/helpers.scala

+48-36
Original file line numberDiff line numberDiff line change
@@ -866,14 +866,14 @@ object IfBodyHelpers {
866866
s"$scrutinee is $className"
867867
case IfBodyHelpers.Condition.MatchTuple(scrutinee, arity, fields) =>
868868
s"$scrutinee is Tuple#$arity"
869-
}.mkString("", " and ", s" => $term"))
869+
}.mkString("«", "» and «", s"» => $term"))
870870
}
871871
}
872872
}
873873

874874
abstract class MutCaseOf {
875-
def append(branch: Ls[IfBodyHelpers.Condition] -> Term): Unit
876-
def toTerm(implicit fallback: Opt[Term]): Term
875+
def merge(branch: Ls[IfBodyHelpers.Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit
876+
def toTerm: Term
877877
}
878878

879879
object MutCaseOf {
@@ -933,38 +933,39 @@ object MutCaseOf {
933933

934934
// A short-hand for pattern matchings with only true and false branches.
935935
final case class IfThenElse(condition: Term, var whenTrue: MutCaseOf, var whenFalse: MutCaseOf) extends MutCaseOf {
936-
override def append(branch: Ls[Condition] -> Term): Unit = branch match {
937-
case Nil -> term =>
938-
whenFalse match {
939-
case Consequent(_) => ??? // duplicated branch
940-
case MissingCase => whenFalse = Consequent(term)
941-
case _ => whenFalse.append(branch)
942-
}
943-
case (Condition.BooleanTest(test) :: tail) -> term =>
944-
if (test === condition) {
945-
whenTrue.append(tail -> term)
946-
} else {
936+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
937+
branch match {
938+
case Nil -> term =>
947939
whenFalse match {
948940
case Consequent(_) => ??? // duplicated branch
949941
case MissingCase => whenFalse = Consequent(term)
950-
case _ => whenFalse.append(branch)
942+
case _ => whenFalse.merge(branch)
951943
}
952-
}
953-
case _ =>
954-
whenFalse match {
955-
case Consequent(_) => ??? // duplicated branch
956-
case MissingCase => whenFalse = buildBranch(branch._1, branch._2)
957-
case _ => whenFalse.append(branch)
958-
}
959-
}
960-
override def toTerm(implicit fallback: Opt[Term]): Term = {
944+
case (Condition.BooleanTest(test) :: tail) -> term =>
945+
if (test === condition) {
946+
whenTrue.merge(tail -> term)
947+
} else {
948+
whenFalse match {
949+
case Consequent(_) => ??? // duplicated branch
950+
case MissingCase => whenFalse = Consequent(term)
951+
case _ => whenFalse.merge(branch)
952+
}
953+
}
954+
case _ =>
955+
whenFalse match {
956+
case Consequent(_) => ??? // duplicated branch
957+
case MissingCase => whenFalse = buildBranch(branch._1, branch._2)
958+
case _ => whenFalse.merge(branch)
959+
}
960+
}
961+
override def toTerm: Term = {
961962
val falseBranch = Wildcard(whenFalse.toTerm)
962963
val trueBranch = Case(Var("true"), whenTrue.toTerm, falseBranch)
963964
CaseOf(condition, trueBranch)
964965
}
965966
}
966967
final case class Match(scrutinee: Term, val branches: Buffer[Branch]) extends MutCaseOf {
967-
override def append(branch: Ls[Condition] -> Term): Unit = {
968+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit = {
968969
branch match {
969970
case (Condition.MatchClass(scrutinee2, className, fields) :: tail) -> term if scrutinee2 === scrutinee =>
970971
branches.find(_.matches(className)) match {
@@ -973,7 +974,7 @@ object MutCaseOf {
973974
branches += Branch(S(className -> Buffer.from(fields)), buildBranch(tail, term))
974975
case S(branch) =>
975976
branch.addFields(fields)
976-
branch.consequent.append(tail -> term)
977+
branch.consequent.merge(tail -> term)
977978
}
978979
case (Condition.MatchTuple(scrutinee2, arity, fields) :: tail) -> term if scrutinee2 === scrutinee =>
979980
val tupleClassName = Var(s"Tuple#$arity")
@@ -983,19 +984,29 @@ object MutCaseOf {
983984
branches += Branch(S(tupleClassName -> Buffer.from(indexFields)), buildBranch(tail, term))
984985
case S(branch) =>
985986
branch.addFields(indexFields)
986-
branch.consequent.append(tail -> term)
987+
branch.consequent.merge(tail -> term)
988+
}
989+
// A wild card case. We should propagate wildcard to every default positions.
990+
case Nil -> term =>
991+
var hasWildcard = false
992+
branches.foreach {
993+
case Branch(_, _: Consequent | MissingCase) => ()
994+
case Branch(patternFields, consequent) =>
995+
consequent.merge(branch)
996+
hasWildcard &&= patternFields.isEmpty
987997
}
998+
if (!hasWildcard) branches += Branch(N, Consequent(term))
988999
case conditions -> term =>
9891000
// Locate the wildcard case.
9901001
branches.find(_.isWildcard) match {
9911002
// No wildcard. We will create a new one.
9921003
case N => branches += Branch(N, buildBranch(conditions, term))
993-
case S(Branch(N, consequent)) => consequent.append(conditions -> term)
1004+
case S(Branch(N, consequent)) => consequent.merge(conditions -> term)
9941005
case S(_) => ??? // Impossible case. What we find should be N.
9951006
}
9961007
}
9971008
}
998-
override def toTerm(implicit fallback: Opt[Term]): Term = {
1009+
override def toTerm: Term = {
9991010
def rec(xs: Ls[Branch]): CaseBranches =
10001011
xs match {
10011012
case Nil => NoCases
@@ -1010,13 +1021,14 @@ object MutCaseOf {
10101021
}
10111022
}
10121023
final case class Consequent(term: Term) extends MutCaseOf {
1013-
override def append(branch: Ls[Condition] -> Term): Unit = ???
1014-
override def toTerm(implicit fallback: Opt[Term]): Term = term
1024+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
1025+
raise(WarningReport(Message.fromStr("duplicated branch") -> N :: Nil))
1026+
override def toTerm: Term = term
10151027
}
10161028
final case object MissingCase extends MutCaseOf {
1017-
override def append(branch: Ls[Condition] -> Term): Unit = ???
1018-
override def toTerm(implicit fallback: Opt[Term]): Term =
1019-
fallback.getOrElse(throw new IfDesugaringException("missing a default branch"))
1029+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit = ???
1030+
override def toTerm: Term =
1031+
throw new IfDesugaringException("missing a default branch")
10201032
}
10211033

10221034
private def buildBranch(conditions: Ls[Condition], term: Term): MutCaseOf = {
@@ -1040,12 +1052,12 @@ object MutCaseOf {
10401052
rec(conditions)
10411053
}
10421054

1043-
def build(cnf: Ls[Ls[Condition] -> Term]): MutCaseOf = {
1055+
def build(cnf: Ls[Ls[Condition] -> Term])(implicit raise: Diagnostic => Unit): MutCaseOf = {
10441056
cnf match {
10451057
case Nil => ???
10461058
case (conditions -> term) :: next =>
10471059
val root = MutCaseOf.buildBranch(conditions, term)
1048-
next.foreach(root.append(_))
1060+
next.foreach(root.merge(_))
10491061
root
10501062
}
10511063
}

shared/src/test/diff/nu/Humiliation.mls

+13-20
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,13 @@ fun test(x) = if x is 1 then 0 else 1
1919
//│ test: number -> (0 | 1)
2020
//│ = [Function: test]
2121

22-
if f is
22+
fun testF(x) = if x is
2323
Foo(a) then a
2424
Foo(a) then a
25-
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
26-
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
27-
//│ at: mlscript.MutCaseOf$Consequent.append(helpers.scala:1013)
28-
//│ at: mlscript.MutCaseOf$Match.append(helpers.scala:976)
29-
//│ at: mlscript.MutCaseOf$.$anonfun$build$1(helpers.scala:1048)
30-
//│ at: mlscript.MutCaseOf$.$anonfun$build$1$adapted(helpers.scala:1048)
31-
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
32-
//│ at: mlscript.MutCaseOf$.build(helpers.scala:1048)
33-
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:727)
34-
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
35-
//│ at: mlscript.Typer.typeTerm(Typer.scala:744)
25+
//│ ╔══[WARNING] duplicated branch
26+
//│ ╙──
27+
//│ testF: (Foo with {x: 'x}) -> 'x
28+
//│ = [Function: testF]
3629

3730

3831
class Bar(y, z)
@@ -58,12 +51,12 @@ if x is
5851
Pair(y, 1) then x
5952
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
6053
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
61-
//│ at: mlscript.MutCaseOf$IfThenElse.append(helpers.scala:948)
62-
//│ at: mlscript.MutCaseOf$Match.append(helpers.scala:976)
63-
//│ at: mlscript.MutCaseOf$.$anonfun$build$1(helpers.scala:1048)
64-
//│ at: mlscript.MutCaseOf$.$anonfun$build$1$adapted(helpers.scala:1048)
54+
//│ at: mlscript.MutCaseOf$IfThenElse.merge(helpers.scala:949)
55+
//│ at: mlscript.MutCaseOf$Match.merge(helpers.scala:977)
56+
//│ at: mlscript.MutCaseOf$.$anonfun$build$1(helpers.scala:1060)
57+
//│ at: mlscript.MutCaseOf$.$anonfun$build$1$adapted(helpers.scala:1060)
6558
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
66-
//│ at: mlscript.MutCaseOf$.build(helpers.scala:1048)
59+
//│ at: mlscript.MutCaseOf$.build(helpers.scala:1060)
6760
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:727)
6861
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
6962
//│ at: mlscript.Typer.typeTerm(Typer.scala:744)
@@ -141,7 +134,7 @@ fun foo(x) = if x is
141134
//│ foo: (Pair & {snd: nothing} & 'a) -> ("ones" | "zeros" | 'a)
142135
//│ = [Function: foo5]
143136

144-
137+
// Note: the parser produces a wrong syntax tree.
145138
fun foo(x, y) = if x is Z() and y is O() then 0 else 1
146139
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
147140
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
@@ -158,7 +151,7 @@ fun foo(x, y) = if x is Z() and y is O() then 0 else 1
158151
fun foo(x, y) = if x is
159152
Z() and y is O() then 0 else 1
160153
//│ ╔══[PARSE ERROR] Unexpected 'else' keyword here
161-
//│ ║ l.159: Z() and y is O() then 0 else 1
154+
//│ ║ l.152: Z() and y is O() then 0 else 1
162155
//│ ╙── ^^^^
163156
//│ foo: (Z, O,) -> 0
164157
//│ = [Function: foo6]
@@ -167,6 +160,6 @@ fun foo(x, y) =
167160
if x is
168161
Z() and y is O() then 0
169162
else 1
170-
//│ foo: (anything, O,) -> (0 | 1)
163+
//│ foo: (anything, anything,) -> (0 | 1)
171164
//│ = [Function: foo7]
172165

shared/src/test/diff/nu/TrivialIf.mls

-62
Original file line numberDiff line numberDiff line change
@@ -14,72 +14,10 @@ class None: Option
1414
//│ Some: 'value -> (Some with {value: 'value})
1515
//│ None: () -> None
1616

17-
:d
1817
fun getOrElse(opt, default) =
1918
if opt is
2019
Some(value) then value
2120
None then default
22-
//│ 1. Typing term opt, default, => {if opt is ‹(Some (value,)) then value; (None) then default›}
23-
//│ | 1. Typing pattern opt, default,
24-
//│ | | 1. Typing pattern opt
25-
//│ | | 1. : α48'
26-
//│ | | 1. Typing pattern default
27-
//│ | | 1. : α49'
28-
//│ | 1. : (α48', α49',)
29-
//│ | 1. Typing term {if opt is ‹(Some (value,)) then value; (None) then default›}
30-
//│ | | 1. Typing term if opt is ‹(Some (value,)) then value; (None) then default›
31-
//│ | | | Flattened conjunctions
32-
//│ | | | + opt is Some => value
33-
//│ | | | + opt is None => default
34-
//│ | | | The mutable CaseOf tree
35-
//│ | | | «opt» match
36-
//│ | | | case Some =>
37-
//│ | | | let value = .value
38-
//│ | | | «value»
39-
//│ | | | case None =>
40-
//│ | | | «default»
41-
//│ | | | Desugared term: case opt of { Some => let value = (opt).value in value; None => default }
42-
//│ | | | 1. Typing term case opt of { Some => let value = (opt).value in value; None => default }
43-
//│ | | | | 1. Typing term opt
44-
//│ | | | | 1. : α48'
45-
//│ | | | | 1. Typing term let value = (opt).value in value
46-
//│ | | | | | 2. Typing term (opt).value
47-
//│ | | | | | | 2. Typing term opt
48-
//│ | | | | | | 2. : α50'
49-
//│ | | | | | | CONSTRAIN α50' <! {value: value51''}
50-
//│ | | | | | | where
51-
//│ | | | | | | C α50' <! {value: value51''} (0)
52-
//│ | | | | | | | EXTR RHS {value: value51''} ~> {value: value52'} to 1
53-
//│ | | | | | | | where
54-
//│ | | | | | | | and
55-
//│ value51'' :> value52'
56-
//│ | | | | | | | C α50' <! {value: value52'} (1)
57-
//│ | | | | | 2. : value51''
58-
//│ | | | | | 1. Typing term value
59-
//│ | | | | | 1. : value53'
60-
//│ | | | | 1. : value53'
61-
//│ | | | | 1. Typing term default
62-
//│ | | | | 1. : α49'
63-
//│ | | | | CONSTRAIN α48' <! ((some<option> & α50') | ((none<option> & α54') & ~(some<option>)))
64-
//│ | | | | where
65-
//│ α50' <: {value: value52'}
66-
//│ | | | | C α48' <! ((some<option> & α50') | ((none<option> & α54') & ~(some<option>))) (0)
67-
//│ | | | 1. : (value53' | α49')
68-
//│ | | 1. : (value53' | α49')
69-
//│ | 1. : (value53' | α49')
70-
//│ 1. : ((α48', α49',) -> (value53' | α49'))
71-
//│ CONSTRAIN ((α48', α49',) -> (value53' | α49')) <! getOrElse47'
72-
//│ where
73-
//│ α48' <: ((some<option> & α50') | ((none<option> & α54') & ~(some<option>)))
74-
//│ α50' <: {value: value52'}
75-
//│ value53' :> value52'
76-
//│ C ((α48', α49',) -> (value53' | α49')) <! getOrElse47' (0)
77-
//│ ⬤ Typed as: getOrElse47'
78-
//│ where:
79-
//│ getOrElse47' :> ((α48', α49',) -> (value53' | α49'))
80-
//│ α48' <: ((some<option> & α50') | ((none<option> & α54') & ~(some<option>)))
81-
//│ α50' <: {value: value52'}
82-
//│ value53' :> value52'
8321
//│ getOrElse: (None | (Some with {value: 'value}), 'value,) -> 'value
8422

8523
getOrElse(None(), 0)

0 commit comments

Comments
 (0)