Skip to content

Add MonadCatch instance for STM #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions io-classes/src/Control/Monad/Class/MonadThrow.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module Control.Monad.Class.MonadThrow
, ExitCase (..)
, Handler (..)
, catches
, generalBracketSTM
-- * Deprecated interfaces
, throwM
) where
Expand Down Expand Up @@ -241,17 +242,21 @@ instance MonadEvaluate IO where
instance MonadThrow STM where
throwIO = STM.throwSTM

instance MonadCatch STM where
catch = STM.catchSTM

generalBracket acquire release use = do
generalBracketSTM :: MonadCatch m => m a -> (a -> ExitCase b -> m c) -> (a -> m b) -> m (b, c)
generalBracketSTM acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance MonadCatch STM where
catch = STM.catchSTM

generalBracket = generalBracketSTM


--
-- Instances for ReaderT
Expand Down
53 changes: 48 additions & 5 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
-- incomplete uni patterns in 'schedule' (when interpreting 'StmTxCommitted')
-- and 'reschedule'.
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}

module Control.Monad.IOSim.Internal
( IOSim (..)
Expand Down Expand Up @@ -70,7 +72,7 @@ import qualified Deque.Strict as Deque

import GHC.Exts (fromList)

import Control.Exception (NonTermination (..), assert, throw)
import Control.Exception (NonTermination (..), assert, throw, SomeException (SomeException))
import Control.Monad (join)

import Control.Monad (when)
Expand Down Expand Up @@ -828,6 +830,7 @@ runSimTraceST mainAction = schedule mainThread initialState
}



--
-- Executing STM Transactions
--
Expand All @@ -843,6 +846,22 @@ execAtomically :: forall s a c.
execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
go AtomicallyFrame Map.empty Map.empty [] [] nextVid0 action0
where
catchStackFrameOrAbort :: forall b.
StmStack s b a
-> Map TVarId (SomeTVar s)
-> SomeException
-> TVarId
-> ST s (SimTrace c)
-> ST s (SimTrace c)
catchStackFrameOrAbort ctl read exc nextVid abort =
case ctl of
AtomicallyFrame -> abort
OrElseLeftFrame _ _ _ _ _ ctl' -> catchStackFrameOrAbort ctl' read exc nextVid abort
OrElseRightFrame _ _ _ _ ctl' -> catchStackFrameOrAbort ctl' read exc nextVid abort
CatchStmFrame handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (handler exc)


go :: forall b.
StmStack s b a
-> Map TVarId (SomeTVar s) -- set of vars read
Expand Down Expand Up @@ -910,11 +929,28 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- Continue with the k continuation
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
-- Merge allocations with outer sequence
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
{-# SCC "execAtomically.go.ThrowStm" #-}
let abort = do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
in catchStackFrameOrAbort ctl read (toException e) nextVid abort

CatchStm act handler k ->
{-# SCC "execAtomically.go.CatchStm" #-} do
let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid act

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
Expand All @@ -941,6 +977,13 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- This is XSTM3 test case from the STM paper.
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
-- Execute the left side in a new frame with an empty written set
Expand Down
46 changes: 37 additions & 9 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-partial-fields #-}
{-# LANGUAGE MultiWayIf #-}

module Control.Monad.IOSim.Types
( IOSim (..)
Expand Down Expand Up @@ -175,6 +176,8 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
-- Catch with continuation
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -224,19 +227,19 @@ instance Monad (IOSim s) where
{-# INLINE (>>) #-}
(>>) = (*>)

#if !(MIN_VERSION_base(4,13,0))
fail = Fail.fail
#endif




instance Semigroup a => Semigroup (IOSim s a) where
(<>) = liftA2 (<>)

instance Monoid a => Monoid (IOSim s a) where
mempty = pure mempty

#if !(MIN_VERSION_base(4,11,0))
mappend = liftA2 mappend
#endif




instance Fail.MonadFail (IOSim s) where
fail msg = IOSim $ oneShot $ \_ -> Throw (toException (IO.Error.userError msg))
Expand Down Expand Up @@ -269,9 +272,9 @@ instance Monad (STM s) where
{-# INLINE (>>) #-}
(>>) = (*>)

#if !(MIN_VERSION_base(4,13,0))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reinstantiate these changes.

fail = Fail.fail
#endif




instance Fail.MonadFail (STM s) where
fail msg = STM $ oneShot $ \_ -> ThrowStm (toException (ErrorCall msg))
Expand Down Expand Up @@ -313,6 +316,23 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO

instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . handler') k
where
handler' :: SomeException -> STM s a
handler' exc =
if
| Just exc' <- fromException exc -> handler exc'
| otherwise -> throwIO exc

-- Default implmentation uses mask. For STM, mask is not necessary.
generalBracket = generalBracketSTM

instance Exceptions.MonadCatch (STM s) where

catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -853,6 +873,14 @@ data StmStack s b a where
-> StmStack s b c
-> StmStack s a c

-- | Executing in the context of the /action/ part of the 'catch'
CatchStmFrame :: (SomeException -> StmA s a) -- exception handler
-> (a -> StmA s b) -- subsequent continuation
-> Map TVarId (SomeTVar s) -- saved written vars set
-> [SomeTVar s] -- saved written vars list
-> [SomeTVar s] -- created vars list (allocations)
-> StmStack s b c
-> StmStack s a c
---
--- Schedules
---
Expand Down
5 changes: 4 additions & 1 deletion io-sim/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ module Main (main) where
import Test.Tasty

import qualified Test.IOSim (tests)
import qualified Test.STM (tests)

main :: IO ()
main = defaultMain tests

tests :: TestTree
tests =
testGroup "IO Sim"
[ Test.IOSim.tests
[
Test.IOSim.tests
, Test.STM.tests
]
4 changes: 2 additions & 2 deletions io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.IOSim

import Test.STM
import Test.STM hiding (tests)

import Test.QuickCheck
import Test.Tasty
Expand Down Expand Up @@ -1049,7 +1049,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m, LazySTM.MonadSTM m, MonadCatch (LazySTM.STM m))
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
49 changes: 47 additions & 2 deletions io-sim/test/Test/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import Control.Monad.Class.MonadSTM as STM
import Control.Monad.Class.MonadThrow

import Test.QuickCheck
import Test.Tasty (testGroup, TestTree)
import Test.Tasty.QuickCheck (testProperty)


-- | The type level structure of types in our STM 'Term's. This is kept simple,
Expand Down Expand Up @@ -67,6 +69,7 @@ data Term (t :: Type) where

Return :: Expr t -> Term t
Throw :: Expr a -> Term t
Catch :: Term t -> Term t -> Term t
Retry :: Term t

ReadTVar :: Name (TyVar t) -> Term t
Expand Down Expand Up @@ -309,6 +312,30 @@ evalTerm !env !heap !allocs term = case term of
where
e' = evalExpr env e

-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
Catch t1 t2 ->
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of

-- Rule XSTM1
-- M; heap, {} => return P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[return P]; heap, allocs
NfReturn v -> (NfReturn v, heap', allocs `mappend` allocs')

-- Rule XSTM2
-- M; heap, {} => throw P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[N P]; heap, allocs
NfThrow _ -> evalTerm env (heap `mappend` allocs') (allocs `mappend` allocs') t2

-- Rule XSTM3
-- M; heap, {} => retry; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
NfRetry -> (NfRetry, heap, allocs)


Retry -> (NfRetry, heap, allocs)

-- Rule READ
Expand Down Expand Up @@ -437,7 +464,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =

-- | Execute an STM 'Term' in the 'STM' monad.
--
execTerm :: (MonadSTM m, MonadThrow (STM m))
execTerm :: (MonadSTM m, MonadThrow (STM m), MonadCatch (STM m))
=> ExecEnv m
-> Term t
-> STM m (ExecValue m t)
Expand All @@ -451,6 +478,8 @@ execTerm env t =
let e' = execExpr env e
throwSTM =<< snapshotExecValue e'

Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2

Retry -> retry

ReadTVar n -> do
Expand Down Expand Up @@ -491,7 +520,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
(snapshotExecValue =<< readTVar v)

execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m, MonadCatch (STM m))
=> Term t -> m TxResult
execAtomically t =
toTxResult <$> try (atomically action')
Expand Down Expand Up @@ -713,6 +742,9 @@ shrinkTerm t =
case t of
Return e -> [Return e' | e' <- shrinkExpr e]
Throw e -> [Throw e' | e' <- shrinkExpr e]
Catch t1 t2 -> [t1, t2]
++ [Catch t1' t2 | t1' <- shrinkTerm t1 ]
++ [Catch t1 t2' | t2' <- shrinkTerm t2 ]
Retry -> []
ReadTVar _ -> []

Expand All @@ -738,6 +770,7 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
freeNamesTerm :: Term t -> Set NameId
freeNamesTerm (Return e) = freeNamesExpr e
freeNamesTerm (Throw e) = freeNamesExpr e
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
freeNamesTerm Retry = Set.empty
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
Expand Down Expand Up @@ -768,6 +801,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
termSize :: Term a -> Int
termSize Return{} = 1
termSize Throw{} = 1
termSize (Catch a b) = 1 + termSize a + termSize b
termSize Retry{} = 1
termSize ReadTVar{} = 1
termSize WriteTVar{} = 1
Expand All @@ -778,6 +812,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
termDepth :: Term a -> Int
termDepth Return{} = 1
termDepth Throw{} = 1
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
termDepth Retry{} = 1
termDepth ReadTVar{} = 1
termDepth WriteTVar{} = 1
Expand All @@ -790,6 +825,9 @@ showTerm p (Return e) = showParen (p > 10) $
showString "return " . showExpr 11 e
showTerm p (Throw e) = showParen (p > 10) $
showString "throwSTM " . showExpr 11 e
showTerm p (Catch t1 t2) = showParen (p > 9) $
showTerm 10 t1 . showString " `catch` "
. showTerm 10 t2
showTerm _ Retry = showString "retry"
showTerm p (ReadTVar n) = showParen (p > 10) $
showString "readTVar " . showName n
Expand Down Expand Up @@ -824,3 +862,10 @@ showTyRep _ TyRepUnit = showString "()"
showTyRep _ TyRepInt = showString "Int"
showTyRep p (TyRepVar t) = showParen (p > 10) $
showString "TVar " . showTyRep 11 t

tests :: TestTree
tests =
testGroup "Test STM"
[
testProperty "Term generation" prop_genSomeTerm
]