diff --git a/KLR/NKI/Annotations.lean b/KLR/NKI/Annotations.lean index dcccd604..e025c913 100644 --- a/KLR/NKI/Annotations.lean +++ b/KLR/NKI/Annotations.lean @@ -131,7 +131,7 @@ private def stmts (l : List Stmt) : Ann (List Stmt) := do private def stmt' (s : Stmt') : Ann Stmt' := do match s with | .expr e => return .expr (<- expr e) - | .assert e => return .assert (<- expr e) + | .assert e msg => return .assert (<- expr e) (<- msg.mapM expr) | .ret e => return .ret (<- expr e) | .declare n e => return .declare n (<- expr e) | .letM p ty e => return .letM p (<- optExpr ty) (<- expr e) diff --git a/KLR/NKI/Basic.lean b/KLR/NKI/Basic.lean index ba9a8919..81dbfc71 100644 --- a/KLR/NKI/Basic.lean +++ b/KLR/NKI/Basic.lean @@ -138,7 +138,7 @@ structure Stmt where @[serde tag = 11] inductive Stmt' where | expr (e : Expr) - | assert (e : Expr) + | assert (e : Expr) (msg : Option Expr) | ret (e : Expr) | declare (x : Name) (ty : Expr) | letM (p : Pattern) (ty : Option Expr) (e : Expr) diff --git a/KLR/NKI/Pretty.lean b/KLR/NKI/Pretty.lean index 3946d0f7..66e21c12 100644 --- a/KLR/NKI/Pretty.lean +++ b/KLR/NKI/Pretty.lean @@ -131,7 +131,7 @@ private def stmts (l : List Stmt) : Format := private def stmt' (s : Stmt') : Format := match s with | .expr e => expr e - | .assert e => "assert " ++ expr e + | .assert e msg => "assert " ++ expr e ++ (msg.map (fun m => ", " ++ expr m)).getD "" | .ret e => "ret " ++ expr e | .declare x ty => x.toString ++ " : " ++ expr ty | .letM p none e => format p ++ " = " ++ expr e diff --git a/KLR/NKI/Simplify.lean b/KLR/NKI/Simplify.lean index 18008a17..e6f5760d 100644 --- a/KLR/NKI/Simplify.lean +++ b/KLR/NKI/Simplify.lean @@ -339,7 +339,8 @@ private def stmt' (s : Python.Stmt') : Simplify (List Stmt') := do match s with | .pass => return [] | .expr e => return [.expr (<- expr e)] - | .assert e => return [.assert (<- expr e)] + | .assert e msg => + return [.assert (<- expr e) (<- msg.mapM expr)] | .ret e => return [.ret (<- expr e)] | .assign xs e => do assign (<- exprs xs) (<- expr e) none | .augAssign x op e => do diff --git a/KLR/NKI/SimplifyOperators.lean b/KLR/NKI/SimplifyOperators.lean index d304d586..14dc2173 100644 --- a/KLR/NKI/SimplifyOperators.lean +++ b/KLR/NKI/SimplifyOperators.lean @@ -110,7 +110,7 @@ private def stmt' (s : Stmt') : SimplifyOp Stmt' := do | .whileLoop test body => return .whileLoop test (<- stmts body) | .dynWhile t body => return .dynWhile t (<- stmts body) -- statments that only contain expressions don't need to be considered and can be simply passed back - | .expr _ | .assert _ | .ret _ | .declare _ _ | .breakLoop | .continueLoop => return s + | .expr _ | .assert _ _ | .ret _ | .declare _ _ | .breakLoop | .continueLoop => return s termination_by sizeOf s end diff --git a/KLR/Python.lean b/KLR/Python.lean index a778aa69..1f42cb20 100644 --- a/KLR/Python.lean +++ b/KLR/Python.lean @@ -137,7 +137,7 @@ structure Stmt where inductive Stmt' where | pass | expr (e : Expr) - | assert (e : Expr) + | assert (e : Expr) (msg : Option Expr) | ret (e : Expr) | assign (xs : List Expr) (e: Expr) | augAssign (x : Expr) (op : BinOp) (e : Expr) diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index d93cee18..b939427e 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -401,9 +401,13 @@ partial def dynamic (l : List Stmt) : Trace Unit := do partial def stmt' (s' : Stmt') : Trace Result := do match s' with | .expr e => let _ <- expr e; return .next - | .assert e => + | .assert e msg => if <- (<- expr e).isFalse then - throw "assertion failed" + let msg <- msg.mapM expr + match msg with + | some $ .string m => + throw s!"assertion failed, {m}" + | _ => throw "assertion failed" return .next | .ret e => return .ret (<- expr e) | .declare .. => return .next diff --git a/interop/klr/gather.c b/interop/klr/gather.c index 37d063a6..9db9bb45 100644 --- a/interop/klr/gather.c +++ b/interop/klr/gather.c @@ -1181,8 +1181,8 @@ static lean_object* stmt(struct state *st, struct _stmt *python) { break; } case Assert_kind: { - // TODO capture message - s = Python_Stmt_assert(expr(st, python->v.Assert.test)); + lean_object *msg = python->v.Assert.msg ? mkSome(expr(st, python->v.Assert.msg)) : mkNone(); + s = Python_Stmt_assert(expr(st, python->v.Assert.test), msg); break; } case Return_kind: { @@ -1699,8 +1699,11 @@ PyObject* specialize(struct kernel *k, PyObject *args, PyObject *kws, PyObject * // save the constructed kernel if (k->lean_kernel) { - if (k->lean_kernel->kernel) - lean_dec(k->lean_kernel->kernel); + if (k->lean_kernel->kernel) { + if (k->lean_kernel->kernel->m_rc != 0) { + lean_dec(k->lean_kernel->kernel); + } + } } else { k->lean_kernel = calloc(1, sizeof(struct lean_kernel)); } diff --git a/interop/klr/lean_ast.h b/interop/klr/lean_ast.h index d90ad4d6..702acee3 100644 --- a/interop/klr/lean_ast.h +++ b/interop/klr/lean_ast.h @@ -15,7 +15,7 @@ lean_object* Python_Expr_mk(lean_object*,lean_object*); lean_object* Python_Kernel_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*); lean_object* Python_Class_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*); lean_object* Python_Stmt_expr(lean_object*); -lean_object* Python_Stmt_assert(lean_object*); +lean_object* Python_Stmt_assert(lean_object*,lean_object*); lean_object* Python_Stmt_ret(lean_object*); lean_object* Python_Stmt_assign(lean_object*,lean_object*); lean_object* Python_Stmt_augAssign(lean_object*,uint8_t,lean_object*);