Skip to content

Commit f6021f4

Browse files
committedMay 5, 2023
Merge remote-tracking branch 'lptk-fork/new-definition-typing' into ucs-paper
2 parents d168eb5 + 99d688b commit f6021f4

File tree

100 files changed

+4238
-1097
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

100 files changed

+4238
-1097
lines changed
 

‎compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala

+19-15
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
8989
private def getFields(etts: List[Statement]): Set[Var] = {
9090
etts.flatMap{
9191
case NuFunDef(_, nm, _, _) => Some(nm)
92-
case NuTypeDef(_, TypeName(nm), _, _, _, _) => Some(Var(nm))
92+
case nuty: NuTypeDef => Some(Var(nuty.name))
9393
case Let(_, name, _, _) => Some(name)
9494
case _ => None
9595
}.toSet
@@ -180,13 +180,13 @@ class ClassLifter(logDebugMsg: Boolean = false) {
180180
}.fold(emptyCtx)(_ ++ _)
181181
case TyApp(trm, tpLst) =>
182182
getFreeVars(trm).addT(tpLst.flatMap(_.collectTypeNames.map(TypeName(_))))
183-
case NuTypeDef(_, nm, tps, param, pars, body) =>
183+
case NuTypeDef(_, nm, tps, param, _, pars, _, _, body) =>
184184
val prmVs = getFreeVars(param)(using emptyCtx, Map(), None)
185185
val newVs = prmVs.vSet ++ getFields(body.entities) + Var(nm.name)
186-
val nCtx = ctx.addV(newVs).addT(nm).addT(tps)
186+
val nCtx = ctx.addV(newVs).addT(nm).addT(tps.map(_._2))
187187
val parVs = pars.map(getFreeVars(_)(using nCtx)).fold(emptyCtx)(_ ++ _)
188188
val bodyVs = body.entities.map(getFreeVars(_)(using nCtx)).fold(emptyCtx)(_ ++ _)
189-
(bodyVs ++ parVs -+ prmVs).extT(tps)
189+
(bodyVs ++ parVs -+ prmVs).extT(tps.map(_._2))
190190
case Blk(stmts) =>
191191
val newVs = getFields(stmts)
192192
stmts.map(getFreeVars(_)(using ctx.addV(newVs))).fold(emptyCtx)(_ ++ _)
@@ -197,7 +197,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
197197
}
198198

199199
private def collectClassInfo(cls: NuTypeDef, preClss: Set[TypeName])(using ctx: LocalContext, cache: ClassCache, outer: Option[ClassInfoCache]): ClassInfoCache = {
200-
val NuTypeDef(_, nm, tps, param, pars, body) = cls
200+
val NuTypeDef(_, nm, tps, param, _, pars, _, _, body) = cls
201201
log(s"grep context of ${cls.nme.name} under {\n$ctx\n$cache\n$outer\n}\n")
202202
val (clses, funcs, trms) = splitEntities(cls.body.entities)
203203
val (supNms, rcdFlds) = pars.map(getSupClsInfoByTerm).unzip
@@ -209,7 +209,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
209209
}.unzip
210210
log(s"par record: ${flds._2.flatten}")
211211
val fields = (param.fields.flatMap(tupleEntityToVar) ++ funcs.map(_.nme) ++ clses.map(x => Var(x.nme.name)) ++ trms.flatMap(grepFieldsInTrm) ++ flds._1).toSet
212-
val nCtx = ctx.addV(fields).addV(flds._1).extT(tps)
212+
val nCtx = ctx.addV(fields).addV(flds._1).extT(tps.map(_._2))
213213
val tmpCtx = ((body.entities.map(getFreeVars(_)(using nCtx)) ++ pars.map(getFreeVars(_)(using nCtx))).fold(emptyCtx)(_ ++ _).moveT2V(preClss)
214214
).addT(flds._2.flatten.toSet).extV(supNms.flatten.map(x => Var(x.name)))
215215

@@ -379,13 +379,14 @@ class ClassLifter(logDebugMsg: Boolean = false) {
379379
val nTpNm = TypeName(genAnoName(t.name))
380380
val cls = cache.get(t).get
381381
val supArgs = Tup(cls.body.params.fields.flatMap(tupleEntityToVar).map(toFldsEle))
382-
val anoCls = NuTypeDef(Cls, nTpNm, Nil, cls.body.params, List(App(Var(t.name), supArgs)), tu)
382+
val anoCls = NuTypeDef(Cls, nTpNm, Nil, cls.body.params, None,
383+
List(App(Var(t.name), supArgs)), None, None, tu)(None)
383384
val nSta = New(Some((nTpNm, prm)), TypingUnit(Nil))
384385
val ret = liftEntities(List(anoCls, nSta))
385386
(Blk(ret._1), ret._2)
386387
case New(None, tu) =>
387388
val nTpNm = TypeName(genAnoName())
388-
val anoCls = NuTypeDef(Cls, nTpNm, Nil, Tup(Nil), Nil, tu)
389+
val anoCls = NuTypeDef(Cls, nTpNm, Nil, Tup(Nil), None, Nil, None, None, tu)(None)
389390
val nSta = New(Some((nTpNm, Tup(Nil))), TypingUnit(Nil))
390391
val ret = liftEntities(List(anoCls, nSta))
391392
(Blk(ret._1), ret._2)
@@ -464,7 +465,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
464465
val nlhs = liftType(lb)
465466
val nrhs = liftType(ub)
466467
Bounds(nlhs._1, nrhs._1) -> (nlhs._2 ++ nrhs._2)
467-
case Constrained(base, bounds, where) =>
468+
case Constrained(base: Type, bounds, where) =>
468469
val (nTargs, nCtx) = bounds.map { case (tv, Bounds(lb, ub)) =>
469470
val nlhs = liftType(lb)
470471
val nrhs = liftType(ub)
@@ -478,6 +479,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
478479
val (nBase, bCtx) = liftType(base)
479480
Constrained(nBase, nTargs, bounds2) ->
480481
((nCtx ++ nCtx2).fold(emptyCtx)(_ ++ _) ++ bCtx)
482+
case Constrained(_, _, _) => die
481483
case Function(lhs, rhs) =>
482484
val nlhs = liftType(lhs)
483485
val nrhs = liftType(rhs)
@@ -531,7 +533,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
531533
case PolyType(targs, body) =>
532534
val (body2, ctx) = liftType(body)
533535
PolyType(targs, body2) -> ctx
534-
case Top | Bot | _: Literal | _: TypeTag | _: TypeVar => target -> emptyCtx
536+
case Top | Bot | _: Literal | _: TypeTag | _: TypeVar => target.asInstanceOf[Type] -> emptyCtx
535537
}
536538

537539

@@ -541,14 +543,14 @@ class ClassLifter(logDebugMsg: Boolean = false) {
541543
body match{
542544
case Left(value) =>
543545
val ret = liftTerm(value)(using ctx.addV(nm).addT(tpVs))
544-
(func.copy(rhs = Left(ret._1)), ret._2)
546+
(func.copy(rhs = Left(ret._1))(None), ret._2)
545547
case Right(PolyType(targs, body)) =>
546548
val nBody = liftType(body)(using ctx.addT(tpVs))
547549
val nTargs = targs.map {
548550
case L(tp) => liftTypeName(tp)(using ctx.addT(tpVs)).mapFirst(Left.apply)
549551
case R(tv) => R(tv) -> emptyCtx
550552
}.unzip
551-
(func.copy(rhs = Right(PolyType(nTargs._1, nBody._1))), nTargs._2.fold(nBody._2)(_ ++ _))
553+
(func.copy(rhs = Right(PolyType(nTargs._1, nBody._1)))(None), nTargs._2.fold(nBody._2)(_ ++ _))
552554
}
553555
}
554556

@@ -624,14 +626,14 @@ class ClassLifter(logDebugMsg: Boolean = false) {
624626
).flatten.toMap
625627
}
626628
log("lift type " + target.toString() + " with cache " + cache.toString())
627-
val NuTypeDef(kind, nme, tps, params, pars, body) = target
629+
val NuTypeDef(kind, nme, tps, params, sig, pars, supAnn, thisAnn, body) = target
628630
val nOuter = cache.get(nme)
629631
val ClassInfoCache(_, nName, freeVs, flds, inners, sups, _, _, _) = nOuter.get
630632
val (clsList, funcList, termList) = splitEntities(body.entities)
631633
val innerNmsSet = clsList.map(_.nme).toSet
632634

633635
val nCache = cache ++ inners ++ getAllInners(sups)
634-
val nTps = (tps ++ (freeVs.tSet -- nCache.keySet).toList).distinct
636+
val nTps = (tps.map(_._2) ++ (freeVs.tSet -- nCache.keySet).toList).distinct
635637
val nCtx = freeVs.addT(nTps)
636638
val nParams =
637639
outer.map(x => List(toFldsEle(Var(genParName(x.liftedNm.name))))).getOrElse(Nil)
@@ -641,7 +643,9 @@ class ClassLifter(logDebugMsg: Boolean = false) {
641643
val nFuncs = funcList.map(liftFunc(_)(using emptyCtx, nCache, nOuter)).unzip
642644
val nTerms = termList.map(liftTerm(_)(using emptyCtx, nCache, nOuter)).unzip
643645
clsList.foreach(x => liftTypeDefNew(x)(using nCache, nOuter))
644-
retSeq = retSeq.appended(NuTypeDef(kind, nName, nTps, Tup(nParams), nPars._1, TypingUnit(nFuncs._1 ++ nTerms._1)))
646+
retSeq = retSeq.appended(NuTypeDef(
647+
kind, nName, nTps.map((None, _)), Tup(nParams), None, nPars._1,
648+
None, None, TypingUnit(nFuncs._1 ++ nTerms._1))(None))
645649
}
646650

647651
def liftTypingUnit(rawUnit: TypingUnit): TypingUnit = {

‎compiler/shared/main/scala/mlscript/compiler/PrettyPrinter.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ object PrettyPrinter:
3636
case Some(true) => "let'"
3737
}
3838
s"$st ${funDef.nme.name}"
39-
+ (if funDef.targs.isEmpty
39+
+ (if funDef.tparams.isEmpty
4040
then ""
41-
else funDef.targs.map(_.name).mkString("[", ", ", "]"))
41+
else funDef.tparams.map(_.name).mkString("[", ", ", "]"))
4242
+ " = "
4343
+ funDef.rhs.fold(_.toString, _.show)
4444

4545
def showTypeDef(tyDef: NuTypeDef, indent: Int = 0): String =
4646
s"${tyDef.kind.str} ${tyDef.nme.name}"
4747
+ (if tyDef.tparams.isEmpty
4848
then ""
49-
else tyDef.tparams.map(_.name).mkString("[", ",", "]"))
49+
else tyDef.tparams.map(_._2.name).mkString("[", ",", "]"))
5050
+ "(" + tyDef.params + ")"
5151
+ (if tyDef.parents.isEmpty
5252
then ""

0 commit comments

Comments
 (0)
Please sign in to comment.