Skip to content

Commit 3fe4565

Browse files
committedOct 26, 2022
Suppor tuple pattern syntax and improve defualt case propagation
1 parent 5e3fce0 commit 3fe4565

File tree

6 files changed

+102
-111
lines changed

6 files changed

+102
-111
lines changed
 

‎shared/src/main/scala/mlscript/Typer.scala

+7-13
Original file line numberDiff line numberDiff line change
@@ -1063,22 +1063,16 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool)
10631063
}
10641064
// This case handles tuple destructions.
10651065
// x is (a, b, c)
1066-
case Tup(elems) =>
1067-
// temporary binding name -> pattern
1068-
val subPatterns = Buffer.empty[(Var, Term)]
1069-
val bindings = elems.iterator.zipWithIndex.flatMap {
1070-
// x is (_, _, _) : ignore this binding
1071-
case (_ -> Fld(_, _, Var("_")), _) => N
1072-
case (_ -> Fld(_, _, name: Var), index) => S(index -> name)
1073-
case (_ -> Fld(_, _, pattern: Term), index) =>
1074-
val alias = Var(freshName)
1075-
subPatterns += ((alias, pattern))
1076-
S(index -> alias)
1077-
}.toList
1066+
case Bra(_, Tup(elems)) =>
1067+
val (subPatterns, bindings) = desugarPositionals(
1068+
scrutinee,
1069+
elems.iterator.map(_._2.value),
1070+
1.to(elems.length).map("_" + _).toList
1071+
)
10781072
Condition.MatchTuple(scrutinee, elems.length, bindings) ::
10791073
destructSubPatterns(subPatterns, ctx)
10801074
// What else?
1081-
case _ => throw new Exception(s"illegal pattern: $pattern")
1075+
case _ => throw new Exception(s"illegal pattern: ${mlscript.codegen.Helpers.inspect(pattern)}")
10821076
}
10831077

10841078

‎shared/src/main/scala/mlscript/helpers.scala

+39-24
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ object IfBodyHelpers {
799799
final case class MatchTuple(
800800
scrutinee: Term,
801801
arity: Int,
802-
fields: List[Int -> Var]
802+
fields: List[Str -> Var]
803803
) extends Condition
804804
final case class BooleanTest(test: Term) extends Condition
805805
}
@@ -904,7 +904,8 @@ object IfBodyHelpers {
904904
abstract class MutCaseOf {
905905
def merge
906906
(branch: Ls[IfBodyHelpers.Condition] -> Term)
907-
(implicit raise: Diagnostic => Unit, allowDuplicate: Boolean = false): Unit
907+
(implicit raise: Diagnostic => Unit): Unit
908+
def mergeDefault(default: Term): Unit
908909
def toTerm: Term
909910
}
910911

@@ -1008,11 +1009,10 @@ object MutCaseOf {
10081009

10091010
// A short-hand for pattern matchings with only true and false branches.
10101011
final case class IfThenElse(condition: Term, var whenTrue: MutCaseOf, var whenFalse: MutCaseOf) extends MutCaseOf {
1011-
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit, allowDuplicate: Bool): Unit =
1012+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
10121013
branch match {
10131014
case Nil -> term =>
10141015
whenFalse match {
1015-
case Consequent(_) if allowDuplicate => ()
10161016
case Consequent(_) =>
10171017
raise(WarningReport(Message.fromStr("duplicated else branch") -> N :: Nil))
10181018
case MissingCase => whenFalse = Consequent(term)
@@ -1035,14 +1035,23 @@ object MutCaseOf {
10351035
case _ => whenFalse.merge(branch)
10361036
}
10371037
}
1038+
1039+
override def mergeDefault(default: Term): Unit = {
1040+
whenTrue.mergeDefault(default)
1041+
whenFalse match {
1042+
case MissingCase => whenFalse = Consequent(default)
1043+
case _ => whenFalse.mergeDefault(default)
1044+
}
1045+
}
1046+
10381047
override def toTerm: Term = {
10391048
val falseBranch = Wildcard(whenFalse.toTerm)
10401049
val trueBranch = Case(Var("true"), whenTrue.toTerm, falseBranch)
10411050
CaseOf(condition, trueBranch)
10421051
}
10431052
}
10441053
final case class Match(scrutinee: Term, val branches: Buffer[Branch]) extends MutCaseOf {
1045-
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit, allowDuplicate: Bool): Unit = {
1054+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit = {
10461055
branch match {
10471056
case (Condition.MatchClass(scrutinee2, className, fields) :: tail) -> term if scrutinee2 === scrutinee =>
10481057
branches.find(_.matches(className)) match {
@@ -1054,25 +1063,16 @@ object MutCaseOf {
10541063
branch.consequent.merge(tail -> term)
10551064
}
10561065
case (Condition.MatchTuple(scrutinee2, arity, fields) :: tail) -> term if scrutinee2 === scrutinee =>
1057-
val tupleClassName = Var(s"Tuple#$arity")
1058-
val indexFields = fields.map { case (index, alias) => index.toString -> alias }
1066+
val tupleClassName = Var(s"Tuple#$arity") // TODO: Find a name known by Typer.
10591067
branches.find(_.matches(tupleClassName)) match {
10601068
case N =>
1061-
branches += Branch(S(tupleClassName -> Buffer.from(indexFields)), buildBranch(tail, term))
1069+
branches += Branch(S(tupleClassName -> Buffer.from(fields)), buildBranch(tail, term))
10621070
case S(branch) =>
1063-
branch.addFields(indexFields)
1071+
branch.addFields(fields)
10641072
branch.consequent.merge(tail -> term)
10651073
}
10661074
// A wild card case. We should propagate wildcard to every default positions.
1067-
case Nil -> term =>
1068-
var hasWildcard = false
1069-
branches.foreach {
1070-
case Branch(_, _: Consequent | MissingCase) => ()
1071-
case Branch(patternFields, consequent) =>
1072-
consequent.merge(branch)(implicitly, true)
1073-
hasWildcard &&= patternFields.isEmpty
1074-
}
1075-
if (!hasWildcard) branches += Branch(N, Consequent(term))
1075+
case Nil -> term => mergeDefault(term)
10761076
case conditions -> term =>
10771077
// Locate the wildcard case.
10781078
branches.find(_.isWildcard) match {
@@ -1083,6 +1083,19 @@ object MutCaseOf {
10831083
}
10841084
}
10851085
}
1086+
1087+
override def mergeDefault(default: Term): Unit = {
1088+
var hasWildcard = false
1089+
branches.foreach {
1090+
case Branch(_, _: Consequent | MissingCase) => ()
1091+
case Branch(patternFields, consequent) =>
1092+
consequent.mergeDefault(default)
1093+
hasWildcard &&= patternFields.isEmpty
1094+
}
1095+
// If this match doesn't have a default case, we make one.
1096+
if (!hasWildcard) branches += Branch(N, Consequent(default))
1097+
}
1098+
10861099
override def toTerm: Term = {
10871100
def rec(xs: Ls[Branch]): CaseBranches =
10881101
xs match {
@@ -1098,16 +1111,18 @@ object MutCaseOf {
10981111
}
10991112
}
11001113
final case class Consequent(term: Term) extends MutCaseOf {
1101-
override def merge
1102-
(branch: Ls[Condition] -> Term)
1103-
(implicit raise: Diagnostic => Unit, allowDuplicate: Bool): Unit =
1114+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
11041115
raise(WarningReport(Message.fromStr("duplicated branch") -> N :: Nil))
1116+
1117+
override def mergeDefault(default: Term): Unit = ()
1118+
11051119
override def toTerm: Term = term
11061120
}
11071121
final case object MissingCase extends MutCaseOf {
1108-
override def merge
1109-
(branch: Ls[Condition] -> Term)
1110-
(implicit raise: Diagnostic => Unit, allowDuplicate: Bool): Unit = ???
1122+
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit = ???
1123+
1124+
override def mergeDefault(default: Term): Unit = ()
1125+
11111126
override def toTerm: Term =
11121127
throw new IfDesugaringException("missing a default branch")
11131128
}

‎shared/src/test/diff/nu/Humiliation.mls

+12-20
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,12 @@ if x is
5454
Pair(1, 1) then "ones"
5555
Pair(y, 1) then x
5656
_ then "nah"
57-
//│ ╔══[ERROR] missing a default branch
57+
//│ ╔══[ERROR] identifier not found: x
5858
//│ ║ l.52: if x is
59-
//│ ║ ^^^^
60-
//│ ║ l.53: Pair(0, 0) then "zeros"
61-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^
62-
//│ ║ l.54: Pair(1, 1) then "ones"
63-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^
64-
//│ ║ l.55: Pair(y, 1) then x
65-
//│ ║ ^^^^^^^^^^^^^^^^^^^
66-
//│ ║ l.56: _ then "nah"
67-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^
68-
//│ res: error
59+
//│ ╙── ^
60+
//│ res: "nah" | "ones" | "zeros"
6961
//│ Code generation encountered an error:
70-
//│ if expression has not been not desugared
62+
//│ unresolved symbol x
7163

7264
class Z()
7365
class O()
@@ -83,11 +75,11 @@ fun foo(x) = if x is
8375
Pair(Z(), Z()) then "zeros"
8476
Pair(O(), O()) then "ones"
8577
//│ ╔══[ERROR] not exhaustive
86-
//│ ║ l.82: fun foo(x) = if x is
78+
//│ ║ l.74: fun foo(x) = if x is
8779
//│ ║ ^^^^
88-
//│ ║ l.83: Pair(Z(), Z()) then "zeros"
80+
//│ ║ l.75: Pair(Z(), Z()) then "zeros"
8981
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
90-
//│ ║ l.84: Pair(O(), O()) then "ones"
82+
//│ ║ l.76: Pair(O(), O()) then "ones"
9183
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9284
//│ foo: anything -> error
9385
//│ Code generation encountered an error:
@@ -150,13 +142,13 @@ fun foo(x) = if x is
150142
Pair(O(), O()) then "ones"
151143
Pair(y, O()) then x
152144
//│ ╔══[ERROR] not exhaustive
153-
//│ ║ l.148: fun foo(x) = if x is
145+
//│ ║ l.140: fun foo(x) = if x is
154146
//│ ║ ^^^^
155-
//│ ║ l.149: Pair(Z(), Z()) then "zeros"
147+
//│ ║ l.141: Pair(Z(), Z()) then "zeros"
156148
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
157-
//│ ║ l.150: Pair(O(), O()) then "ones"
149+
//│ ║ l.142: Pair(O(), O()) then "ones"
158150
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
159-
//│ ║ l.151: Pair(y, O()) then x
151+
//│ ║ l.143: Pair(y, O()) then x
160152
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^
161153
//│ foo: anything -> error
162154
//│ Code generation encountered an error:
@@ -170,7 +162,7 @@ fun foo(x, y) = if x is Z() and y is O() then 0 else 1
170162
fun foo(x, y) = if x is
171163
Z() and y is O() then 0 else 1
172164
//│ ╔══[PARSE ERROR] Unexpected 'else' keyword here
173-
//│ ║ l.171: Z() and y is O() then 0 else 1
165+
//│ ║ l.163: Z() and y is O() then 0 else 1
174166
//│ ╙── ^^^^
175167
//│ foo: (Z, O,) -> 0
176168
//│ = [Function: foo7]

‎shared/src/test/diff/nu/InterleavedLet.mls

+20-20
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ if
5454
B then "B"
5555
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
5656
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
57-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1250)
58-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1248)
57+
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
58+
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
5959
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
60-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1248)
61-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
60+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
61+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
6262
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
6363
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
6464
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -102,11 +102,11 @@ if x ==
102102
y * y then 0
103103
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
104104
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
105-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1250)
106-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1248)
105+
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
106+
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
107107
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
108-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1248)
109-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
108+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
109+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
110110
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
111111
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
112112
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -117,11 +117,11 @@ if
117117
let y = 0
118118
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
119119
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
120-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1250)
121-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1248)
120+
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
121+
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
122122
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
123-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1248)
124-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
123+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
124+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
125125
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
126126
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
127127
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -135,11 +135,11 @@ if
135135
z == 2 then 2
136136
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
137137
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
138-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1250)
139-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1248)
138+
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
139+
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
140140
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
141-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1248)
142-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
141+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
142+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
143143
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
144144
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
145145
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -167,11 +167,11 @@ if q(y) and
167167
y == z * z then "bruh"
168168
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
169169
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
170-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1250)
171-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1248)
170+
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
171+
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
172172
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
173-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1248)
174-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
173+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
174+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
175175
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
176176
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
177177
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)

‎shared/src/test/diff/nu/NestedPattern.mls

+19-29
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,23 @@ fun crazy(v) =
3535
_ then "lol"
3636
//│ crazy: anything -> ("bruh!" | "lol")
3737

38-
if x is
39-
(0, 0) then "zeros"
40-
(1, 1) then "ones"
41-
//│ /!!!\ Uncaught error: java.lang.Exception: illegal pattern: '(' 0, 0, ')'
42-
//│ at: mlscript.Typer.destructPattern(Typer.scala:1081)
43-
//│ at: mlscript.Typer.desugarMatchBranch$1(Typer.scala:1137)
44-
//│ at: mlscript.Typer.$anonfun$desugarIf$9(Typer.scala:1229)
45-
//│ at: mlscript.Typer.$anonfun$desugarIf$9$adapted(Typer.scala:1229)
46-
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
47-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1229)
48-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
49-
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
50-
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
51-
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
38+
:e
39+
fun f(x) =
40+
if x is
41+
(0, 0) then "zeros"
42+
(1, 1) then "ones"
43+
_ then "bruh"
44+
//│ ╔══[ERROR] type identifier not found: Tuple#2
45+
//│ ╙──
46+
//│ f: error -> error
5247

53-
if x is
54-
(0, 0) then "zeros"
55-
(1, 1) then "ones"
56-
(y, 1) then x
57-
//│ /!!!\ Uncaught error: java.lang.Exception: illegal pattern: '(' 0, 0, ')'
58-
//│ at: mlscript.Typer.destructPattern(Typer.scala:1081)
59-
//│ at: mlscript.Typer.desugarMatchBranch$1(Typer.scala:1137)
60-
//│ at: mlscript.Typer.$anonfun$desugarIf$9(Typer.scala:1229)
61-
//│ at: mlscript.Typer.$anonfun$desugarIf$9$adapted(Typer.scala:1229)
62-
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
63-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1229)
64-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
65-
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
66-
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
67-
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
48+
:e
49+
fun f(x) =
50+
if x is
51+
(0, 0) then "zeros"
52+
(1, 1) then "ones"
53+
(y, 1) then x
54+
_ then "que?"
55+
//│ ╔══[ERROR] type identifier not found: Tuple#2
56+
//│ ╙──
57+
//│ f: error -> error

‎shared/src/test/diff/nu/SplitAroundOp.mls

+5-5
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ if x is
7676
< 0 then 2
7777
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
7878
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
79-
//│ at: mlscript.Typer.desugarMatchBranch$1(Typer.scala:1186)
80-
//│ at: mlscript.Typer.$anonfun$desugarIf$9(Typer.scala:1229)
81-
//│ at: mlscript.Typer.$anonfun$desugarIf$9$adapted(Typer.scala:1229)
79+
//│ at: mlscript.Typer.desugarMatchBranch$1(Typer.scala:1180)
80+
//│ at: mlscript.Typer.$anonfun$desugarIf$9(Typer.scala:1223)
81+
//│ at: mlscript.Typer.$anonfun$desugarIf$9$adapted(Typer.scala:1223)
8282
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
83-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1229)
84-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1256)
83+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1223)
84+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
8585
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
8686
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
8787
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)

0 commit comments

Comments
 (0)
Please sign in to comment.