Skip to content

Commit a8b1052

Browse files
committed
Handle IfOpsApp at the pattern position
1 parent 3fe4565 commit a8b1052

File tree

5 files changed

+174
-120
lines changed

5 files changed

+174
-120
lines changed

shared/src/main/scala/mlscript/Typer.scala

+20-8
Original file line numberDiff line numberDiff line change
@@ -1170,14 +1170,26 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool)
11701170
desugarMatchBranch(scrutinee, L(consequent), partialPattern, acc)
11711171
}
11721172
case L(IfOpsApp(patLhs, opsRhss)) =>
1173-
patLhs match {
1174-
case App(
1175-
App(Var("and"),
1176-
Tup((_ -> Fld(_, _, realPat)) :: Nil)),
1177-
Tup((_ -> Fld(_, _, restCond)) :: Nil)
1178-
) =>
1179-
???
1180-
case pattern => ???
1173+
separatePattern(patLhs) match {
1174+
case (patternPart, N) =>
1175+
val partialPattern = addTerm(pat, patternPart)
1176+
opsRhss.foreach { case op -> consequent =>
1177+
desugarMatchBranch(scrutinee, L(consequent), addOp(partialPattern, op), acc)
1178+
}
1179+
case (patternPart, S(extraTests)) =>
1180+
addTerm(pat, patternPart) match {
1181+
case N => ??? // Error: cannot be empty
1182+
case S(R(_)) => ??? // Error: cannot be incomplete
1183+
case S(L(pattern)) =>
1184+
val patternConditions = destructPattern(scrutinee, pattern, ctx)
1185+
val testTerms = splitAnd(extraTests)
1186+
val middleConditions = desugarConditions(testTerms.init)
1187+
val accumulatedConditions = acc ::: patternConditions ::: middleConditions
1188+
opsRhss.foreach { case op -> consequent =>
1189+
// TODO: Use lastOption
1190+
desugarIfBody(consequent)(S(L(testTerms.last)), accumulatedConditions)
1191+
}
1192+
}
11811193
}
11821194
case L(IfElse(consequent)) =>
11831195
// Because this pattern matching is incomplete, it's not included in

shared/src/main/scala/mlscript/helpers.scala

+16-16
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,10 @@ object IfBodyHelpers {
825825
case S(extraTest) =>
826826
S(L(mkBinOp(assertion, Var("and"), extraTest)))
827827
}
828-
case S(R(lhs -> op)) => S(L(mkBinOp(lhs, op, rhs)))
828+
case S(R(lhs -> op)) =>
829+
val (realRhs, extraExprOpt) = separatePattern(rhs)
830+
val leftmost = mkBinOp(lhs, op, realRhs)
831+
S(L(extraExprOpt.fold(leftmost){ mkBinOp(leftmost, Var("and"), _) }))
829832
}
830833

831834
// Add an operator to a partial term.
@@ -905,7 +908,9 @@ abstract class MutCaseOf {
905908
def merge
906909
(branch: Ls[IfBodyHelpers.Condition] -> Term)
907910
(implicit raise: Diagnostic => Unit): Unit
908-
def mergeDefault(default: Term): Unit
911+
def mergeDefault
912+
(default: Term)
913+
(implicit raise: Diagnostic => Unit): Unit
909914
def toTerm: Term
910915
}
911916

@@ -1011,13 +1016,7 @@ object MutCaseOf {
10111016
final case class IfThenElse(condition: Term, var whenTrue: MutCaseOf, var whenFalse: MutCaseOf) extends MutCaseOf {
10121017
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
10131018
branch match {
1014-
case Nil -> term =>
1015-
whenFalse match {
1016-
case Consequent(_) =>
1017-
raise(WarningReport(Message.fromStr("duplicated else branch") -> N :: Nil))
1018-
case MissingCase => whenFalse = Consequent(term)
1019-
case _ => whenFalse.merge(branch)
1020-
}
1019+
case Nil -> term => this.mergeDefault(term)
10211020
case (Condition.BooleanTest(test) :: tail) -> term =>
10221021
if (test === condition) {
10231022
whenTrue.merge(tail -> term)
@@ -1036,11 +1035,13 @@ object MutCaseOf {
10361035
}
10371036
}
10381037

1039-
override def mergeDefault(default: Term): Unit = {
1038+
override def mergeDefault(default: Term)(implicit raise: Diagnostic => Unit): Unit = {
10401039
whenTrue.mergeDefault(default)
10411040
whenFalse match {
1041+
case Consequent(_) =>
1042+
raise(WarningReport(Message.fromStr("duplicated else branch") -> N :: Nil))
10421043
case MissingCase => whenFalse = Consequent(default)
1043-
case _ => whenFalse.mergeDefault(default)
1044+
case _: IfThenElse | _: Match => whenFalse.mergeDefault(default)
10441045
}
10451046
}
10461047

@@ -1084,7 +1085,7 @@ object MutCaseOf {
10841085
}
10851086
}
10861087

1087-
override def mergeDefault(default: Term): Unit = {
1088+
override def mergeDefault(default: Term)(implicit raise: Diagnostic => Unit): Unit = {
10881089
var hasWildcard = false
10891090
branches.foreach {
10901091
case Branch(_, _: Consequent | MissingCase) => ()
@@ -1114,14 +1115,14 @@ object MutCaseOf {
11141115
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit =
11151116
raise(WarningReport(Message.fromStr("duplicated branch") -> N :: Nil))
11161117

1117-
override def mergeDefault(default: Term): Unit = ()
1118+
override def mergeDefault(default: Term)(implicit raise: Diagnostic => Unit): Unit = ()
11181119

11191120
override def toTerm: Term = term
11201121
}
11211122
final case object MissingCase extends MutCaseOf {
11221123
override def merge(branch: Ls[Condition] -> Term)(implicit raise: Diagnostic => Unit): Unit = ???
11231124

1124-
override def mergeDefault(default: Term): Unit = ()
1125+
override def mergeDefault(default: Term)(implicit raise: Diagnostic => Unit): Unit = ()
11251126

11261127
override def toTerm: Term =
11271128
throw new IfDesugaringException("missing a default branch")
@@ -1139,8 +1140,7 @@ object MutCaseOf {
11391140
case Condition.MatchTuple(scrutinee, arity, fields) =>
11401141
val branches = Buffer.empty[Branch]
11411142
val tupleClassName = Var(s"Tuple#$arity")
1142-
val indexFields = fields.map { case (index, alias) => index.toString -> alias }
1143-
branches += Branch(S(tupleClassName -> Buffer.from(indexFields)), rec(next))
1143+
branches += Branch(S(tupleClassName -> Buffer.from(fields)), rec(next))
11441144
Match(scrutinee, branches)
11451145
}
11461146
case Nil => Consequent(term)

shared/src/test/diff/nu/DirectLines.mls

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ fun f(x, y, z) =
4444
_ then "bruh"
4545
3 then "y = 3"
4646
_ then "bruh"
47+
//│ ╔══[WARNING] duplicated else branch
48+
//│ ╙──
4749
//│ f: (number, number, number,) -> ("bruh" | "x" | "y = 1" | "y = 3" | "z = 0" | "z = 9")
4850
//│ = [Function: f2]
4951

shared/src/test/diff/nu/InterleavedLet.mls

+20-20
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ if
5454
B then "B"
5555
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
5656
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
57-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
58-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
57+
//│ at: mlscript.Typer.$anonfun$desugarIf$12(Typer.scala:1256)
58+
//│ at: mlscript.Typer.$anonfun$desugarIf$12$adapted(Typer.scala:1254)
5959
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
60-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
61-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
60+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1254)
61+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1262)
6262
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
6363
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
6464
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -102,11 +102,11 @@ if x ==
102102
y * y then 0
103103
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
104104
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
105-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
106-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
105+
//│ at: mlscript.Typer.$anonfun$desugarIf$12(Typer.scala:1256)
106+
//│ at: mlscript.Typer.$anonfun$desugarIf$12$adapted(Typer.scala:1254)
107107
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
108-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
109-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
108+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1254)
109+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1262)
110110
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
111111
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
112112
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -117,11 +117,11 @@ if
117117
let y = 0
118118
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
119119
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
120-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
121-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
120+
//│ at: mlscript.Typer.$anonfun$desugarIf$12(Typer.scala:1256)
121+
//│ at: mlscript.Typer.$anonfun$desugarIf$12$adapted(Typer.scala:1254)
122122
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
123-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
124-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
123+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1254)
124+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1262)
125125
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
126126
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
127127
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -135,11 +135,11 @@ if
135135
z == 2 then 2
136136
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
137137
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
138-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
139-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
138+
//│ at: mlscript.Typer.$anonfun$desugarIf$12(Typer.scala:1256)
139+
//│ at: mlscript.Typer.$anonfun$desugarIf$12$adapted(Typer.scala:1254)
140140
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
141-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
142-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
141+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1254)
142+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1262)
143143
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
144144
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
145145
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)
@@ -167,11 +167,11 @@ if q(y) and
167167
y == z * z then "bruh"
168168
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
169169
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
170-
//│ at: mlscript.Typer.$anonfun$desugarIf$10(Typer.scala:1244)
171-
//│ at: mlscript.Typer.$anonfun$desugarIf$10$adapted(Typer.scala:1242)
170+
//│ at: mlscript.Typer.$anonfun$desugarIf$12(Typer.scala:1256)
171+
//│ at: mlscript.Typer.$anonfun$desugarIf$12$adapted(Typer.scala:1254)
172172
//│ at: scala.collection.immutable.List.foreach(List.scala:333)
173-
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1242)
174-
//│ at: mlscript.Typer.desugarIf(Typer.scala:1250)
173+
//│ at: mlscript.Typer.desugarIfBody$1(Typer.scala:1254)
174+
//│ at: mlscript.Typer.desugarIf(Typer.scala:1262)
175175
//│ at: mlscript.Typer.$anonfun$typeTerm$2(Typer.scala:725)
176176
//│ at: mlscript.TyperHelpers.trace(TyperHelpers.scala:30)
177177
//│ at: mlscript.Typer.typeTerm(Typer.scala:750)

0 commit comments

Comments
 (0)