Skip to content

Commit 15c6ea1

Browse files
committed
Add interesting "array programming" test + fixes to make it work
1 parent 4f99999 commit 15c6ea1

File tree

7 files changed

+379
-6
lines changed

7 files changed

+379
-6
lines changed

shared/src/main/scala/mlscript/NewParser.scala

+11
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,10 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], raiseFun: D
573573
case (KEYWORD("super"), l0) :: _ =>
574574
consume
575575
exprCont(Super().withLoc(S(l0)), prec, allowNewlines = false)
576+
case (IDENT("~", _), l0) :: _ =>
577+
consume
578+
val rest = expr(prec, allowSpace = true)
579+
exprCont(App(Var("~").withLoc(S(l0)), rest).withLoc(S(l0 ++ rest.toLoc)), prec, allowNewlines = false)
576580
case (br @ BRACKETS(bk @ (Round | Square | Curly), toks), loc) :: _ =>
577581
consume
578582
val res = rec(toks, S(br.innerLoc), br.describe).concludeWith(_.argsMaybeIndented()) // TODO
@@ -723,6 +727,13 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], raiseFun: D
723727

724728
final def exprCont(acc: Term, prec: Int, allowNewlines: Bool)(implicit et: ExpectThen, fe: FoundErr, l: Line): IfBody \/ Term = wrap(prec, s"`$acc`", allowNewlines) { l =>
725729
cur match {
730+
case (IDENT(".", _), l0) :: (br @ BRACKETS(Square, toks), l1) :: _ =>
731+
consume
732+
consume
733+
val idx = rec(toks, S(br.innerLoc), br.describe)
734+
.concludeWith(_.expr(0, allowSpace = true))
735+
val newAcc = Subs(acc, idx).withLoc(S(l0 ++ l1 ++ idx.toLoc))
736+
exprCont(newAcc, prec, allowNewlines)
726737
case (IDENT(opStr, true), l0) :: _ if /* isInfix(opStr) && */ opPrec(opStr)._1 > prec =>
727738
consume
728739
val v = Var(opStr).withLoc(S(l0))

shared/src/main/scala/mlscript/NuTypeDefs.scala

+26-5
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
483483
case s => R(s)
484484
}
485485
val funSigs = MutMap.empty[Str, NuFunDef]
486-
val implems = if (topLevel) decls else decls.filter {
486+
val implems = decls.filter {
487487
case fd @ NuFunDef(N, nme, tparams, R(rhs)) =>
488488
funSigs.updateWith(nme.name) {
489489
case S(s) =>
@@ -494,6 +494,11 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
494494
false // There will already be typed in DelayedTypeInfo
495495
case _ => true
496496
}
497+
498+
val sigInfos = if (topLevel) funSigs.map { case (nme, decl) =>
499+
val lti = new DelayedTypeInfo(decl, implicitly)
500+
decl.name -> lti
501+
} else Nil
497502
val infos = implems.map {
498503
case _decl: NuDecl =>
499504
val decl = _decl match {
@@ -530,13 +535,29 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
530535
decl.name -> lti
531536
}
532537
ctx ++= infos
538+
ctx ++= sigInfos
539+
540+
val tpdFunSigs = sigInfos.mapValues(_.complete() match {
541+
case res: TypedNuFun if res.fd.isDecl =>
542+
TopType
543+
case res: TypedNuFun =>
544+
res.typeSignature
545+
case _ => die
546+
}).toMap
533547

534548
// * Complete typing of block definitions and add results to context
535-
val completedInfos = infos.mapValues(_.complete() match {
549+
val completedInfos = (infos ++ sigInfos).mapValues(_.complete() match {
536550
case res: TypedNuFun =>
537-
// * Generalize functions as they are typed.
538-
// * Note: eventually we'll want to first reorder their typing topologically so as to maximize polymorphism.
539-
ctx += res.name -> VarSymbol(res.typeSignature, res.fd.nme)
551+
tpdFunSigs.get(res.name) match {
552+
case S(expected) =>
553+
implicit val prov: TP =
554+
TypeProvenance(res.fd.toLoc, res.fd.describe)
555+
subsume(res.typeSignature, expected)
556+
case _ =>
557+
// * Generalize functions as they are typed.
558+
// * Note: eventually we'll want to first reorder their typing topologically so as to maximize polymorphism.
559+
ctx += res.name -> VarSymbol(res.typeSignature, res.fd.nme)
560+
}
540561
CompletedTypeInfo(res)
541562
case res => CompletedTypeInfo(res)
542563
})

shared/src/main/scala/mlscript/Typer.scala

+2
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, var ne
233233
NuTypeDef(Mod, TN("true"), Nil, N, N, N, Var("Bool") :: Nil, N, N, TypingUnit(Nil))(N, N),
234234
NuTypeDef(Mod, TN("false"), Nil, N, N, N, Var("Bool") :: Nil, N, N, TypingUnit(Nil))(N, N),
235235
NuTypeDef(Cls, TN("Str"), Nil, N, N, N, Nil, N, N, TypingUnit(Nil))(N, S(preludeLoc)),
236+
NuTypeDef(Als, TN("undefined"), Nil, N, N, S(Literal(UnitLit(true))), Nil, N, N, TypingUnit(Nil))(N, S(preludeLoc)),
237+
NuTypeDef(Als, TN("null"), Nil, N, N, S(Literal(UnitLit(false))), Nil, N, N, TypingUnit(Nil))(N, S(preludeLoc)),
236238
)
237239
val builtinTypes: Ls[TypeDef] =
238240
TypeDef(Cls, TN("int"), Nil, TopType, Nil, Nil, sing(TN("number")), N, Nil) ::

shared/src/main/scala/mlscript/helpers.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@ trait TermImpl extends StatementImpl { self: Term =>
573573
case App(App(Var("->"), lhs), tup: Tup) => Function(lhs.toType_!, tup.toType_!)
574574
case App(App(Var("|"), lhs), rhs) => Union(lhs.toType_!, rhs.toType_!)
575575
case App(App(Var("&"), lhs), rhs) => Inter(lhs.toType_!, rhs.toType_!)
576+
case App(Var("~"), rhs) => Neg(rhs.toType_!)
576577
case Lam(lhs, rhs) => Function(lhs.toType_!, rhs.toType_!)
577578
case App(lhs, rhs) => lhs.toType_! match {
578579
case AppliedType(base, targs) => AppliedType(base, targs :+ rhs.toType_!)
@@ -635,7 +636,7 @@ trait LitImpl { self: Lit =>
635636
case _: IntLit => Set.single(TypeName("Int")) + TypeName("Num") + TypeName("Object")
636637
case _: StrLit => Set.single(TypeName("Str")) + TypeName("Object")
637638
case _: DecLit => Set.single(TypeName("Num")) + TypeName("Object")
638-
case _: UnitLit => Set.empty
639+
case _: UnitLit => Set.single(TypeName("Object"))
639640
}
640641
}
641642

shared/src/test/diff/nu/ArrayProg.mls

+254
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
:NewDefs
2+
3+
4+
5+
fun cast(x) = x
6+
declare fun cast: anything -> nothing
7+
//│ fun cast: forall 'a. 'a -> 'a
8+
//│ fun cast: anything -> nothing
9+
10+
11+
fun mapi: (Array['a], ('a, Int) -> 'b) -> Array['b]
12+
fun mapi(xs, f) = cast(xs).map(f)
13+
//│ fun mapi: (anything, anything,) -> nothing
14+
//│ fun mapi: forall 'a 'b. (Array['a], ('a, Int,) -> 'b,) -> Array['b]
15+
16+
mapi of ["a", "", "bb"], (x, i) => [i, length of x]
17+
//│ Array[(Int, Int,)]
18+
//│ res
19+
//│ = [ [ 0, 1 ], [ 1, 0 ], [ 2, 2 ] ]
20+
21+
22+
fun map(xs, f) = mapi(xs, (x, i) => f(x))
23+
//│ fun map: forall 'a 'b. (Array['a], 'a -> 'b,) -> Array['b]
24+
25+
map of ["a", "", "bb"], x => length of x
26+
//│ Array[Int]
27+
//│ res
28+
//│ = [ 1, 0, 2 ]
29+
30+
31+
fun zip: (Array['a], Array['b & ~undefined & Object]) -> Array[['a, 'b]]
32+
fun zip(xs, ys) = mapi of xs, (x, i) =>
33+
if ys.[i] is
34+
undefined then error
35+
y then [x, y]
36+
//│ fun zip: forall 'c 'd. (Array['c], Array[Object & 'd & ~undefined],) -> Array[('c, 'd,)]
37+
//│ fun zip: forall 'a 'b. (Array['a], Array[Object & 'b & ~undefined],) -> Array[('a, 'b,)]
38+
39+
40+
zip
41+
//│ forall 'a 'b. (Array['a], Array[Object & 'b & ~undefined],) -> Array[('a, 'b,)]
42+
//│ res
43+
//│ = [Function: zip1]
44+
45+
46+
47+
class Numbr(n: Int)
48+
class Vectr(xs: Array[Numbr | Vectr])
49+
//│ class Numbr(n: Int)
50+
//│ class Vectr(xs: Array[Numbr | Vectr])
51+
52+
class Pair[A, B](a: A, b: B)
53+
//│ class Pair[A, B](a: A, b: B)
54+
55+
56+
fun unbox(x) = if x is
57+
Numbr(n) then n
58+
Vectr(xs) then map of xs, unbox
59+
//│ fun unbox: forall 'a. (Numbr | Vectr) -> 'a
60+
//│ where
61+
//│ 'a :> Int | Array['a]
62+
63+
fun add(e) =
64+
if e is
65+
Pair(Numbr(n), Numbr(m)) then Numbr(n + m)
66+
Pair(Vectr(xs), Vectr(ys)) then
67+
Vectr of map of zip(xs, ys), ([x, y]) => add of Pair of x, y
68+
Pair(Vectr(xs), Numbr(n)) then
69+
Vectr of map of xs, x => add of Pair of x, Numbr(n)
70+
Pair(Numbr(n), Vectr(xs)) then
71+
Vectr of map of xs, x => add of Pair of Numbr(n), x
72+
//│ fun add: Pair[Numbr | Vectr, Numbr | Vectr] -> (Numbr | Vectr)
73+
74+
75+
add(Pair(Numbr(0), Numbr(1)))
76+
//│ Numbr | Vectr
77+
//│ res
78+
//│ = Numbr {}
79+
80+
add(Pair(Vectr([]), Vectr([])))
81+
//│ Numbr | Vectr
82+
//│ res
83+
//│ = Vectr {}
84+
85+
let v = Vectr of [Numbr(10), Numbr(20), Numbr(30)]
86+
//│ let v: Vectr
87+
//│ v
88+
//│ = Vectr {}
89+
90+
unbox(v)
91+
//│ forall 'a. 'a
92+
//│ where
93+
//│ 'a :> Int | Array['a]
94+
//│ res
95+
//│ = [ 10, 20, 30 ]
96+
97+
98+
let res = add of Pair of (Vectr of [Numbr(1), Numbr(2)]), (Vectr of [Numbr(3), v])
99+
//│ let res: Numbr | Vectr
100+
//│ res
101+
//│ = Vectr {}
102+
103+
unbox(res)
104+
//│ forall 'a. 'a
105+
//│ where
106+
//│ 'a :> Int | Array['a]
107+
//│ res
108+
//│ = [ 4, [ 12, 22, 32 ] ]
109+
110+
111+
fun add2(e) =
112+
if e is
113+
Pair(Numbr(n), Numbr(m)) then Numbr(m + m)
114+
Pair(Numbr(n), Vectr(n)) then n
115+
//│ fun add2: Pair[Numbr, Numbr | Vectr] -> (Numbr | Array[Numbr | Vectr])
116+
117+
add2(Pair(Numbr(0), Numbr(1)))
118+
//│ Numbr | Array[Numbr | Vectr]
119+
//│ res
120+
//│ = Numbr {}
121+
122+
123+
124+
// * Playing with approximated unions/intersections
125+
126+
127+
fun t: ([Numbr,Numbr]|[Vectr,Vectr]) -> Int
128+
//│ fun t: (Numbr | Vectr, Numbr | Vectr,) -> Int
129+
130+
131+
fun s: (([Numbr,Numbr] -> Int) & ([Vectr,Vectr] -> Int),)
132+
//│ fun s: (Numbr | Vectr, Numbr | Vectr,) -> Int
133+
134+
// FIXME why does the above parse the same as:
135+
136+
fun s: ([Numbr,Numbr] -> Int) & ([Vectr,Vectr] -> Int)
137+
//│ fun s: (Numbr | Vectr, Numbr | Vectr,) -> Int
138+
139+
140+
s(Vectr([]),Vectr([]))
141+
//│ Int
142+
//│ res
143+
//│ = <no result>
144+
//│ s is not implemented
145+
146+
147+
module A {
148+
fun g: (Int -> Int) & (Str -> Str)
149+
fun g(x) = x
150+
}
151+
// g: (Int | Str) -> (Int & Str) -- under-approx
152+
// g: (Int & Str) -> (Int | Str) -- over-approx
153+
//│ module A {
154+
//│ fun g: Int -> Int & Str -> Str
155+
//│ }
156+
157+
158+
159+
160+
// === === === ERROR CASES === === === //
161+
162+
163+
:ShowRelativeLineNums
164+
:AllowTypeErrors
165+
166+
167+
:e
168+
s([Numbr(0),Numbr(0)])
169+
//│ ╔══[ERROR] Type mismatch in application:
170+
//│ ║ l.+1: s([Numbr(0),Numbr(0)])
171+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^
172+
//│ ╟── argument of type `((?a, ?b,),)` does not match type `(Numbr | Vectr, Numbr | Vectr,)`
173+
//│ ║ l.+1: s([Numbr(0),Numbr(0)])
174+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^
175+
//│ ╟── Note: constraint arises from tuple type:
176+
//│ ║ l.136: fun s: ([Numbr,Numbr] -> Int) & ([Vectr,Vectr] -> Int)
177+
//│ ╙── ^^^^^^^^^^^
178+
//│ Int | error
179+
180+
:e
181+
A.g(0)
182+
// g <: 0 -> 'a
183+
//│ ╔══[ERROR] Type mismatch in application:
184+
//│ ║ l.+1: A.g(0)
185+
//│ ║ ^^^^^^
186+
//│ ╟── integer literal of type `0` is not an instance of type `Str`
187+
//│ ║ l.+1: A.g(0)
188+
//│ ╙── ^
189+
//│ Int | Str | error
190+
191+
:e
192+
fun add(e) =
193+
if e is
194+
Pair(Numbr(n), Numbr(m)) then 0
195+
Pair(Vectr(xs), Vectr(ys)) then 1
196+
Pair(Vectr(xs), Numbr(n)) then 2
197+
//│ ╔══[ERROR] The match is not exhaustive.
198+
//│ ║ l.+2: if e is
199+
//│ ║ ^^^^
200+
//│ ╟── The scrutinee at this position misses 1 case.
201+
//│ ║ l.+3: Pair(Numbr(n), Numbr(m)) then 0
202+
//│ ║ ^^^^^^^^
203+
//│ ╟── [Missing Case 1/1] `Vectr`
204+
//│ ╟── It first appears here.
205+
//│ ║ l.+4: Pair(Vectr(xs), Vectr(ys)) then 1
206+
//│ ╙── ^^^^^^^^^
207+
//│ fun add: anything -> error
208+
209+
:e
210+
fun add(e) =
211+
if e is
212+
Pair(Numbr(n), Numbr(m)) then 0
213+
Pair(Vectr(xs), Vectr(ys)) then 1
214+
//│ ╔══[ERROR] The match is not exhaustive.
215+
//│ ║ l.+2: if e is
216+
//│ ║ ^^^^
217+
//│ ╟── The scrutinee at this position misses 1 case.
218+
//│ ║ l.+3: Pair(Numbr(n), Numbr(m)) then 0
219+
//│ ║ ^^^^^^^^
220+
//│ ╟── [Missing Case 1/1] `Vectr`
221+
//│ ╟── It first appears here.
222+
//│ ║ l.+4: Pair(Vectr(xs), Vectr(ys)) then 1
223+
//│ ╙── ^^^^^^^^^
224+
//│ fun add: anything -> error
225+
226+
:e
227+
add2(Pair(Vectr(0), Numbr(1)))
228+
//│ ╔══[ERROR] Type mismatch in application:
229+
//│ ║ l.+1: add2(Pair(Vectr(0), Numbr(1)))
230+
//│ ║ ^^^^^^^^
231+
//│ ╟── integer literal of type `0` does not match type `Array[Numbr | Vectr]`
232+
//│ ║ l.+1: add2(Pair(Vectr(0), Numbr(1)))
233+
//│ ║ ^
234+
//│ ╟── Note: constraint arises from applied type reference:
235+
//│ ║ l.48: class Vectr(xs: Array[Numbr | Vectr])
236+
//│ ╙── ^^^^^^^^^^^^^^^^^^^^
237+
//│ ╔══[ERROR] Type mismatch in application:
238+
//│ ║ l.+1: add2(Pair(Vectr(0), Numbr(1)))
239+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
240+
//│ ╟── application of type `Vectr` is not an instance of type `Numbr`
241+
//│ ║ l.+1: add2(Pair(Vectr(0), Numbr(1)))
242+
//│ ║ ^^^^^^^^
243+
//│ ╟── Note: constraint arises from class pattern:
244+
//│ ║ l.113: Pair(Numbr(n), Numbr(m)) then Numbr(m + m)
245+
//│ ║ ^^^^^
246+
//│ ╟── from reference:
247+
//│ ║ l.112: if e is
248+
//│ ║ ^
249+
//│ ╟── Note: type parameter A is defined at:
250+
//│ ║ l.52: class Pair[A, B](a: A, b: B)
251+
//│ ╙── ^
252+
//│ Numbr | error | Array[Numbr | Vectr]
253+
254+

0 commit comments

Comments
 (0)