diff --git a/.ghci b/.ghci index f92ef3b77..ac8589060 100644 --- a/.ghci +++ b/.ghci @@ -13,3 +13,6 @@ import qualified System.Environment as E :def amc-trace \xs -> pure $ "E.setEnv \"AMC_TRACE\" \"" ++ xs ++ "\"" :set -fobject-code +:def l const (pure ":list") +:def c const (pure ":continue") +:def s const (pure ":step") diff --git a/src/Core/Lint.hs b/src/Core/Lint.hs index 814870849..316c8b3cd 100644 --- a/src/Core/Lint.hs +++ b/src/Core/Lint.hs @@ -39,6 +39,7 @@ data CoreError | InfoMismatch CoVar VarInfo VarInfo | InfoIllegal CoVar VarInfo VarInfo | NoSuchVar CoVar + | Duplicate CoVar | IllegalUnbox | InvalidCoercion Coercion | PatternMismatch [(CoVar, Type)] [(CoVar, Type)] @@ -73,6 +74,7 @@ instance Pretty CoreError where text " got var info" <+> string (show r) text "for" <+> pretty v pretty (NoSuchVar a) = text "No such variable" <+> pretty a + pretty (Duplicate a) = text "Duplicate declaration of" <+> pretty a pretty IllegalUnbox = text "Illegal unboxed type" pretty (InvalidCoercion a) = text "Illegal coercion" <+> pretty a pretty (PatternMismatch l r) = text "Expected vars" <+> pVs l @@ -115,8 +117,8 @@ checkStmt s (Foreign v ty b:xs) = do es <- gatherError' . liftError $ -- Ensure we're declaring a value unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v))) - -- And the type is well formed *> checkType s ty + *> checkNodup v (vars s) ((Foreign v ty b, es):) <$> checkStmt (s { vars = insertVar v ty (vars s) }) xs @@ -130,6 +132,7 @@ checkStmt s (StmtLet (One (v, ty, e)):xs) = do _ -> pure ()) *> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v))) *> checkType s ty + *> checkNodup v (vars s) ((StmtLet (One (v, ty, e')), es):) <$> checkStmt (s { vars = insertVar v ty (vars s) }) xs checkStmt s (StmtLet (Many vs):xs) = do @@ -144,6 +147,7 @@ checkStmt s (StmtLet (Many vs):xs) = do _ -> pure ()) *> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v))) *> checkType s ty + *> checkNodup v (vars s) pure ((v, ty, e'), es) ((StmtLet (Many vs'), es):) <$> checkStmt s' xs @@ -197,6 +201,7 @@ checkTerm s (Lam arg@(TermArgument a ty) bod) = do -- Ensure type is valid and we're declaring a value unless (varInfo a == ValueVar) (pushError (InfoIllegal (toVar a) ValueVar (varInfo a))) *> checkType s ty + *> checkNodup a (vars s) (bty, bod') <- checkTerm (s { vars = insertVar a ty (vars s) }) bod pure ( ForallTy Irrelevant ty <$> bty @@ -206,6 +211,7 @@ checkTerm s (Lam arg@(TypeArgument a ty) bod) = do -- Ensure type is valid and we're declaring a tyvar unless (varInfo a == TypeVar) (pushError (InfoIllegal (toVar a) TypeVar (varInfo a))) *> checkType (s { tyVars = VarSet.insert (toVar a) (tyVars s) }) ty + *> checkNodup' a (tyVars s) (bty, bod') <- checkTerm (s { tyVars = VarSet.insert (toVar a) (tyVars s) }) bod pure ( ForallTy (Relevant (toVar a)) ty <$> bty @@ -222,6 +228,7 @@ checkTerm s (Let (One (v, ty, e)) r) = do _ -> pure ()) *> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v))) *> checkType s ty + *> checkNodup v (vars s) pure ( tyr, AnnLet es (One (v, ty, e')) r') @@ -236,6 +243,7 @@ checkTerm s (Let (Many vs) r) = do Just ty' | ty `apart` ty' -> pushError (TypeMismatch ty ty') _ -> pure ()) *> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v))) + *> checkNodup v (vars s) *> checkType s ty pure ((v, ty, e'), es) @@ -257,6 +265,7 @@ checkTerm s (Match e bs) = do _ -> pure ()) *> when (vs /= patternVars p) (pushError (PatternMismatch (first toVar <$> patternVars p) (first toVar <$> vs))) *> checkPattern s ty p + *> traverse_ (\(x, _) -> checkNodup x (vars s)) pVars pure ((tyr, Arm p ty r' vs tvs), es) -- Verify the types are consistent @@ -495,6 +504,13 @@ liftError m = case runErrors m of Left e -> throwError e Right x -> pure x +checkNodup :: IsVar a => a -> VarMap.Map b -> Errors CoreErrors () +checkNodup v m = when (toVar v `VarMap.member` m) (pushError (Duplicate (toVar v))) + +checkNodup' :: IsVar a => a -> VarSet.Set -> Errors CoreErrors () +checkNodup' v m = when (toVar v `VarSet.member` m) (pushError (Duplicate (toVar v))) + + gatherError :: MonadWriter LintResult m => ExceptT CoreErrors m b -> m (Maybe b, CoreErrors) gatherError m = do res <- runExceptT m diff --git a/src/Core/Optimise/Reduce.hs b/src/Core/Optimise/Reduce.hs index fd5138ad7..579994086 100644 --- a/src/Core/Optimise/Reduce.hs +++ b/src/Core/Optimise/Reduce.hs @@ -18,10 +18,10 @@ import Control.Arrow hiding ((<+>)) import qualified Data.Map.Strict as Map import qualified Data.VarMap as VarMap import qualified Data.VarSet as VarSet +import Data.Foldable import Data.Triple import Data.Graph import Data.Maybe -import Data.List import Core.Optimise.Reduce.Pattern import Core.Optimise.Reduce.Inline @@ -471,15 +471,17 @@ reduceTermK _ (AnnMatch _ test arms) cont = do . (armBody .~ substituteInTys tySubst body') reduceBody :: (Term a -> m (Term a)) -> Subst a -> AnnTerm VarSet.Set (OccursVar a) -> m (Term a) - reduceBody cont subst body = - let (sub, binds) = foldr - (\(var, a) (sub, binds) -> - (VarMap.insert (toVar var) (basicDef var (Atom a)) sub, - if isTrivialAtom a - then binds - else Let (One (var, approximateAtomType a, Atom a)) . binds)) - (mempty, id) subst - in binds <$> local (varScope %~ VarMap.union sub) (reduceTermK UsedOther body cont) + reduceBody cont subst body = do + (sub, binds) <- foldrM (\(var, a) (sub, binds) -> + if isTrivialAtom a + then pure ( VarMap.insert (toVar var) (basicDef var (Atom a)) sub, binds ) + else do + let ty = approximateAtomType a + v <- freshFrom' var + pure ( VarMap.insert (toVar var) (basicDef var (Atom (Ref (toVar v) ty))) sub + , Let (One (v, ty, Atom a)) . binds )) + (mempty, id) subst + binds <$> local (varScope %~ VarMap.union sub) (reduceTermK UsedOther body cont) foldVar :: [(a, Type)] -> (a, Atom) -> Maybe (VarMap.Map Type) -> Maybe (VarMap.Map Type) foldVar _ _ Nothing = Nothing diff --git a/src/Core/Optimise/SAT.hs b/src/Core/Optimise/SAT.hs index 2d5bb8dda..fdb04671e 100644 --- a/src/Core/Optimise/SAT.hs +++ b/src/Core/Optimise/SAT.hs @@ -60,7 +60,6 @@ import Data.Semigroup import Data.Maybe import Data.List - -- | Do the static argument transformation on a whole program. staticArgsPass :: (MonadNamey m, IsVar a) => [Stmt a] -> m [Stmt a] staticArgsPass = traverse staticArg_stmt @@ -174,13 +173,13 @@ doStaticArgs the_func the_type the_body = mkShadow worker = let go_dynamic args = do inside <- mkApps (Ref worker worker_ty) worker_ty args - pure $ foldr Lam inside args + refresh $ foldr Lam inside args go (Static (TypeArgument _ k):xs) = do - x <- fromVar . mkTyvar <$> genName + x <- fresh' TypeVar Lam (TypeArgument x k) <$> go xs go (Static (TermArgument _ k):xs) = do - x <- fromVar . mkVal <$> genName + x <- fresh' ValueVar Lam (TermArgument x k) <$> go xs go [] = go_dynamic non_static_bndrs go _ = error "NonStatic binder in static_binders" @@ -229,11 +228,11 @@ isStatic _ _ = NonStatic mkApps :: forall a m. (IsVar a, MonadNamey m) => Atom -> Type -> [Argument a] -> m (Term a) mkApps at _ [] = pure $ Atom at mkApps at (ForallTy Irrelevant _ t) (TermArgument x tau:xs) = do - this_app <- fromVar . mkVal <$> genName + this_app <- fresh' ValueVar Let (One (this_app, t, App at (Ref (toVar x) tau))) <$> mkApps (Ref (toVar this_app) t) t xs mkApps at (ForallTy r _ t) (TypeArgument v _:xs) = do - this_app <- fromVar . mkVal <$> genName + this_app <- fresh' ValueVar let t' = case r of Relevant binder -> substituteInType (VarMap.singleton binder (VarTy (toVar v))) t diff --git a/tests/lua/opt_sat_inline.lua b/tests/lua/opt_sat_inline.lua new file mode 100644 index 000000000..b31f9d5d1 --- /dev/null +++ b/tests/lua/opt_sat_inline.lua @@ -0,0 +1,10 @@ +do + local E = { __tag = "E" } + local function foldr_sat(zero, x) + if x.__tag ~= "T" then return zero end + local tmp = x[1] + local tmp0 = tmp._2._2 + return foldr_sat(tmp._1 + foldr_sat(zero, tmp0._2), tmp0._1) + end + foldr_sat(0, E) +end diff --git a/tests/lua/opt_sat_inline.ml b/tests/lua/opt_sat_inline.ml new file mode 100644 index 000000000..12d203846 --- /dev/null +++ b/tests/lua/opt_sat_inline.ml @@ -0,0 +1,11 @@ +external val (+) : int -> int -> int = "function(x, y) return x + y end" + +type sz_tree 'a = + | E + | T of 'a * int * sz_tree 'a * sz_tree 'a + +let rec foldr f zero = function + | E -> zero + | T (x, _, l, r) -> foldr f (f x (foldr f zero r)) l + +let _ = foldr (+) 0 E