diff --git a/io-classes/src/Control/Monad/Class/MonadThrow.hs b/io-classes/src/Control/Monad/Class/MonadThrow.hs index 541f6b04..ad655fed 100644 --- a/io-classes/src/Control/Monad/Class/MonadThrow.hs +++ b/io-classes/src/Control/Monad/Class/MonadThrow.hs @@ -241,6 +241,7 @@ instance MonadEvaluate IO where instance MonadThrow STM where throwIO = STM.throwSTM + instance MonadCatch STM where catch = STM.catchSTM diff --git a/io-sim/src/Control/Monad/IOSim/Internal.hs b/io-sim/src/Control/Monad/IOSim/Internal.hs index a3b889e1..17e425f5 100644 --- a/io-sim/src/Control/Monad/IOSim/Internal.hs +++ b/io-sim/src/Control/Monad/IOSim/Internal.hs @@ -19,6 +19,8 @@ -- incomplete uni patterns in 'schedule' (when interpreting 'StmTxCommitted') -- and 'reschedule'. {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} module Control.Monad.IOSim.Internal ( IOSim (..) @@ -71,9 +73,7 @@ import qualified Deque.Strict as Deque import GHC.Exts (fromList) import Control.Exception (NonTermination (..), assert, throw) -import Control.Monad (join) - -import Control.Monad (when) +import Control.Monad (join, when) import Control.Monad.ST.Lazy import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST) import Data.STRef.Lazy @@ -828,6 +828,35 @@ runSimTraceST mainAction = schedule mainThread initialState } +data StmControl s a where + StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a + + +-- Unwind the STM control stack till the matching exception is found +unwindControlStmStack :: forall s a. + SomeException + -> StmControl s a + -> Either Bool + ( StmControl s a + , [Map TVarId (SomeTVar s)] + ) +unwindControlStmStack e (StmControl _ frame) = unwindFrame [] frame + + where + unwindFrame :: forall s' b. [Map TVarId (SomeTVar s')] -> StmStack s' b a -> Either Bool (StmControl s' a, [Map TVarId (SomeTVar s')]) + unwindFrame _ AtomicallyFrame = Left True + unwindFrame ws (OrElseLeftFrame _ _ w _ _ ctl) = unwindFrame (w:ws) ctl + unwindFrame ws (OrElseRightFrame _ w _ _ ctl) = unwindFrame (w:ws) ctl + unwindFrame ws (CatchHandlerStmFrame _ _w _ _ ctl) = unwindFrame ws ctl -- Should not happen + unwindFrame ws (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) = + case fromException e of + -- Continue to unwind till we find a handler which can handle this exception. + Nothing -> unwindFrame (writtenOuter:ws) ctl + Just e' -> + let action' = handler e' + ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl + in Right $ (StmControl action' ctl', reverse ws) + -- -- Executing STM Transactions -- @@ -910,11 +939,47 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 = -- Continue with the k continuation go ctl' read written' writtenSeq' createdSeq' nextVid (k x) - ThrowStm e -> - {-# SCC "execAtomically.go.ThrowStm" #-} do + CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do + -- Successful main catch action + -- Merge allocations with outer sequence + !_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar) + (Map.intersection written writtenOuter) + -- Merge the written set of the inner with the outer + let written' = Map.union written writtenOuter + writtenSeq' = filter (\(SomeTVar tvar) -> + tvarId tvar `Map.notMember` writtenOuter) + writtenSeq + ++ writtenOuterSeq + -- Skip the orElse right hand and continue with the k continuation + go ctl' read written' writtenSeq' createdOuterSeq nextVid (k x) + + CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do + -- Undo all written tvars + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (k x) + + ThrowStm e -> {-# SCC "execAtomically.go.ThrowStm" #-} do + revertThem written + case unwindControlStmStack e (StmControl action ctl) of + + -- Unwind to the nearest matching exception + Right (StmControl action' ctl', ws) -> do + mapM_ revertThem ws + go ctl' read written writtenSeq createdSeq nextVid action' + + -- Abort if no matching exception is found + Left{} -> + k0 $ StmTxAborted [] (toException e) + -- Revert all the TVar writes - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted [] (toException e) + where + revertThem x = + traverse_ (\(SomeTVar tvar) -> revertTVar tvar) x + + CatchStm act handler k -> + {-# SCC "execAtomically.go.CatchStm" #-} do + let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid act Retry -> {-# SCC "execAtomically.go.Retry" #-} @@ -941,6 +1006,19 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 = -- using the written set for the outer frame go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.catchStmFrame" #-} do + -- This is XSTM3 test case from the STM paper. + -- Revert all the TVar writes within this catch action branch + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + + CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.catchHandlerStmFrame" #-} do + -- Undo all written tvars + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + OrElse a b k -> {-# SCC "execAtomically.go.OrElse" #-} do -- Execute the left side in a new frame with an empty written set diff --git a/io-sim/src/Control/Monad/IOSim/Types.hs b/io-sim/src/Control/Monad/IOSim/Types.hs index 2dc46fc7..91f2d935 100644 --- a/io-sim/src/Control/Monad/IOSim/Types.hs +++ b/io-sim/src/Control/Monad/IOSim/Types.hs @@ -175,6 +175,8 @@ runSTM (STM k) = k ReturnStm data StmA s a where ReturnStm :: a -> StmA s a ThrowStm :: SomeException -> StmA s a + -- Catch with continuation + CatchStm :: Exception e => StmA s a -> (e -> StmA s a) -> (a -> StmA s b) -> StmA s b NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b @@ -228,6 +230,8 @@ instance Monad (IOSim s) where fail = Fail.fail #endif + + instance Semigroup a => Semigroup (IOSim s a) where (<>) = liftA2 (<>) @@ -238,6 +242,8 @@ instance Monoid a => Monoid (IOSim s a) where mappend = liftA2 mappend #endif + + instance Fail.MonadFail (IOSim s) where fail msg = IOSim $ oneShot $ \_ -> Throw (toException (IO.Error.userError msg)) @@ -273,6 +279,8 @@ instance Monad (STM s) where fail = Fail.fail #endif + + instance Fail.MonadFail (STM s) where fail msg = STM $ oneShot $ \_ -> ThrowStm (toException (ErrorCall msg)) @@ -313,6 +321,23 @@ instance MonadThrow (STM s) where instance Exceptions.MonadThrow (STM s) where throwM = MonadThrow.throwIO +instance MonadCatch (STM s) where + + catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . handler) k + + -- Default implmentation uses mask. For STM, mask is not necessary. + generalBracket acquire release use = do + resource <- acquire + b <- use resource `catch` \e -> do + _ <- release resource (ExitCaseException e) + throwIO e + c <- release resource (ExitCaseSuccess b) + return (b, c) + +instance Exceptions.MonadCatch (STM s) where + + catch = MonadThrow.catch + instance MonadCatch (IOSim s) where catch action handler = IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k @@ -853,6 +878,23 @@ data StmStack s b a where -> StmStack s b c -> StmStack s a c + -- | Executing in the context of the /action/ part of the 'catch' + CatchStmFrame :: Exception e + => (e -> StmA s a) -- exception handler + -> (a -> StmA s b) -- subsequent continuation + -> Map TVarId (SomeTVar s) -- saved written vars set + -> [SomeTVar s] -- saved written vars list + -> [SomeTVar s] -- created vars list (allocations) + -> StmStack s b c + -> StmStack s a c + + -- | A continuation frame + CatchHandlerStmFrame :: (b -> StmA s c) -- subsequent continuation + -> Map TVarId (SomeTVar s) -- saved written vars set + -> [SomeTVar s] -- saved written vars list + -> [SomeTVar s] -- created vars list (allocations) + -> !(StmStack s c a) + -> StmStack s b a --- --- Schedules --- diff --git a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs index e0deae98..22d4efa3 100644 --- a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs +++ b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs @@ -1039,6 +1039,32 @@ controlSimTraceST limit control mainAction = } +data StmControl s a where + StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a + + +-- Unwind the STM control stack till the matching exception is found +unwindControlStmStack :: forall s a. + SomeException + -> StmControl s a + -> Either Bool (StmControl s a) +unwindControlStmStack e (StmControl _ frame) = unwindFrame frame + + where + unwindFrame :: forall s' b. StmStack s' b a -> Either Bool (StmControl s' a) + unwindFrame AtomicallyFrame = Left True + unwindFrame (OrElseLeftFrame _ _ _ _ _ ctl) = unwindFrame ctl + unwindFrame (OrElseRightFrame _ _ _ _ ctl) = unwindFrame ctl + unwindFrame (CatchHandlerStmFrame _ _ _ _ ctl) = unwindFrame ctl + unwindFrame (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) = + case fromException e of + -- Continue to unwind till we find a handler which can handle this exception. + Nothing -> unwindFrame ctl + Just e' -> + let action' = handler e' + ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl + in Right $ StmControl action' ctl' + -- -- Executing STM Transactions -- @@ -1121,11 +1147,44 @@ execAtomically time tid tlbl nextVid0 action0 k0 = -- Continue with the k continuation go ctl' read written' writtenSeq' createdSeq' nextVid (k x) + CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do + let written' = Map.union written writtenOuter + writtenSeq' = filter (\(SomeTVar tvar) -> + tvarId tvar `Map.notMember` writtenOuter) + writtenSeq + ++ writtenOuterSeq + createdSeq' = createdSeq ++ createdOuterSeq + go ctl' read written' writtenSeq' createdSeq' nextVid (k x) + + CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do + !_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar) + (Map.intersection written writtenOuter) + let written' = Map.union written writtenOuter + writtenSeq' = filter (\(SomeTVar tvar) -> + tvarId tvar `Map.notMember` writtenOuter) + writtenSeq + ++ writtenOuterSeq + createdSeq' = createdSeq ++ createdOuterSeq + go ctl' read written' writtenSeq' createdSeq' nextVid (k x) + + + ThrowStm e -> - {-# SCC "execAtomically.go.ThrowStm" #-} do - -- Revert all the TVar writes - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted (Map.elems read) (toException e) + {-# SCC "execAtomically.go.ThrowStm" #-} + + let abort = do + -- Revert all the TVar writes + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + k0 $ StmTxAborted (Map.elems read) (toException e) + + in case unwindControlStmStack e (StmControl action ctl) of + Left _ -> abort + Right (StmControl action' ctl') -> go ctl' read written writtenSeq createdSeq nextVid action' + + CatchStm act handler k -> + {-# SCC "execAtomically.go.ThrowStm" #-} do + let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid act Retry -> {-# SCC "execAtomically.go.Retry" #-} @@ -1152,6 +1211,19 @@ execAtomically time tid tlbl nextVid0 action0 k0 = -- using the written set for the outer frame go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.catchStmFrame" #-} do + -- Revert all the TVar writes within this catch action branch + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + + CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.catchHandlerStmFrame" #-} do + -- Revert all the TVar writes within this catch action branch + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + + OrElse a b k -> {-# SCC "execAtomically.go.OrElse" #-} do -- Execute the left side in a new frame with an empty written set diff --git a/io-sim/test/Main.hs b/io-sim/test/Main.hs index 87db2244..eb170f08 100644 --- a/io-sim/test/Main.hs +++ b/io-sim/test/Main.hs @@ -3,6 +3,7 @@ module Main (main) where import Test.Tasty import qualified Test.IOSim (tests) +import qualified Test.STM (tests) main :: IO () main = defaultMain tests @@ -10,5 +11,7 @@ main = defaultMain tests tests :: TestTree tests = testGroup "IO Sim" - [ Test.IOSim.tests + [ + Test.IOSim.tests + , Test.STM.tests ] diff --git a/io-sim/test/Test/IOSim.hs b/io-sim/test/Test/IOSim.hs index 53b817ac..8adce19d 100644 --- a/io-sim/test/Test/IOSim.hs +++ b/io-sim/test/Test/IOSim.hs @@ -33,7 +33,7 @@ import Control.Monad.Class.MonadTime import Control.Monad.Class.MonadTimer import Control.Monad.IOSim -import Test.STM +import Test.STM hiding (tests) import Test.QuickCheck import Test.Tasty @@ -134,8 +134,8 @@ tests = , testProperty "16" unit_async_16 ] , testGroup "STM reference semantics" - [ testProperty "Reference vs IO" prop_stm_referenceIO - , testProperty "Reference vs Sim" prop_stm_referenceSim + [ testProperty "Reference vs IO" (withMaxSuccess 10000 prop_stm_referenceIO) + , testProperty "Reference vs Sim" (withMaxSuccess 10000 prop_stm_referenceSim) ] , testGroup "MonadFix instance" [ testProperty "purity" prop_mfix_purity @@ -1049,7 +1049,7 @@ prop_stm_referenceSim t = -- | Compare the behaviour of the STM reference operational semantics with -- the behaviour of any 'MonadSTM' STM implementation. -- -prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m) +prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m, LazySTM.MonadSTM m, MonadCatch (LazySTM.STM m)) => SomeTerm -> m Property prop_stm_referenceM (SomeTerm _tyrep t) = do let (r1, _heap) = evalAtomically t diff --git a/io-sim/test/Test/STM.hs b/io-sim/test/Test/STM.hs index 27b5d5a5..6765a3b5 100644 --- a/io-sim/test/Test/STM.hs +++ b/io-sim/test/Test/STM.hs @@ -34,6 +34,8 @@ import Control.Monad.Class.MonadSTM as STM import Control.Monad.Class.MonadThrow import Test.QuickCheck +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.QuickCheck (testProperty) -- | The type level structure of types in our STM 'Term's. This is kept simple, @@ -67,6 +69,7 @@ data Term (t :: Type) where Return :: Expr t -> Term t Throw :: Expr a -> Term t + Catch :: Term t -> Term t -> Term t Retry :: Term t ReadTVar :: Name (TyVar t) -> Term t @@ -309,6 +312,30 @@ evalTerm !env !heap !allocs term = case term of where e' = evalExpr env e + -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of + -- + Catch t1 t2 -> + let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of + + -- Rule XSTM1 + -- M; heap, {} => return P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[return P]; heap', allocs' + NfReturn v -> (NfReturn v, heap', allocs `mappend` allocs') + + -- Rule XSTM2 + -- M; heap, {} => throw P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs' + NfThrow _ -> evalTerm env (heap `mappend` allocs') (allocs `mappend` allocs') t2 + + -- Rule XSTM3 + -- M; heap, {} => retry; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[retry]; heap, allocs + NfRetry -> (NfRetry, heap, allocs) + + Retry -> (NfRetry, heap, allocs) -- Rule READ @@ -437,7 +464,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) = -- | Execute an STM 'Term' in the 'STM' monad. -- -execTerm :: (MonadSTM m, MonadThrow (STM m)) +execTerm :: (MonadSTM m, MonadThrow (STM m), MonadCatch (STM m)) => ExecEnv m -> Term t -> STM m (ExecValue m t) @@ -451,6 +478,8 @@ execTerm env t = let e' = execExpr env e throwSTM =<< snapshotExecValue e' + Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2 + Retry -> retry ReadTVar n -> do @@ -491,7 +520,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x) snapshotExecValue (ExecValVar v _) = fmap ImmValVar (snapshotExecValue =<< readTVar v) -execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m) +execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m, MonadCatch (STM m)) => Term t -> m TxResult execAtomically t = toTxResult <$> try (atomically action') @@ -657,7 +686,7 @@ genTerm env tyrep = Nothing) ] - binTerm = frequency [ (2, bindTerm), (1, orElseTerm)] + binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)] bindTerm = sized $ \sz -> do @@ -675,6 +704,11 @@ genTerm env tyrep = OrElse <$> genTerm env tyrep <*> genTerm env tyrep + catchTerm = + sized $ \sz -> resize (sz `div` 2) $ + Catch <$> genTerm env tyrep + <*> genTerm env tyrep + genSomeExpr :: GenEnv -> Gen SomeExpr genSomeExpr env = oneof' @@ -713,6 +747,9 @@ shrinkTerm t = case t of Return e -> [Return e' | e' <- shrinkExpr e] Throw e -> [Throw e' | e' <- shrinkExpr e] + Catch t1 t2 -> [t1, t2] + ++ [Catch t1' t2 | t1' <- shrinkTerm t1 ] + ++ [Catch t1 t2' | t2' <- shrinkTerm t2 ] Retry -> [] ReadTVar _ -> [] @@ -738,6 +775,7 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = [] freeNamesTerm :: Term t -> Set NameId freeNamesTerm (Return e) = freeNamesExpr e freeNamesTerm (Throw e) = freeNamesExpr e +freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2 freeNamesTerm Retry = Set.empty freeNamesTerm (ReadTVar n) = Set.singleton (nameId n) freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e @@ -768,6 +806,7 @@ prop_genSomeTerm (SomeTerm tyrep term) = termSize :: Term a -> Int termSize Return{} = 1 termSize Throw{} = 1 +termSize (Catch a b) = 1 + termSize a + termSize b termSize Retry{} = 1 termSize ReadTVar{} = 1 termSize WriteTVar{} = 1 @@ -778,6 +817,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b termDepth :: Term a -> Int termDepth Return{} = 1 termDepth Throw{} = 1 +termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b) termDepth Retry{} = 1 termDepth ReadTVar{} = 1 termDepth WriteTVar{} = 1 @@ -790,6 +830,9 @@ showTerm p (Return e) = showParen (p > 10) $ showString "return " . showExpr 11 e showTerm p (Throw e) = showParen (p > 10) $ showString "throwSTM " . showExpr 11 e +showTerm p (Catch t1 t2) = showParen (p > 9) $ + showTerm 10 t1 . showString " `catch` " + . showTerm 10 t2 showTerm _ Retry = showString "retry" showTerm p (ReadTVar n) = showParen (p > 10) $ showString "readTVar " . showName n @@ -824,3 +867,10 @@ showTyRep _ TyRepUnit = showString "()" showTyRep _ TyRepInt = showString "Int" showTyRep p (TyRepVar t) = showParen (p > 10) $ showString "TVar " . showTyRep 11 t + +tests :: TestTree +tests = + testGroup "Test STM" + [ + testProperty "Term generation" prop_genSomeTerm + ]