Skip to content

Add MonadCatch instance for STM #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions io-classes/src/Control/Monad/Class/MonadThrow.hs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ instance MonadEvaluate IO where
instance MonadThrow STM where
throwIO = STM.throwSTM


instance MonadCatch STM where
catch = STM.catchSTM

Expand Down
91 changes: 84 additions & 7 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
-- incomplete uni patterns in 'schedule' (when interpreting 'StmTxCommitted')
-- and 'reschedule'.
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}

module Control.Monad.IOSim.Internal
( IOSim (..)
Expand Down Expand Up @@ -71,9 +73,7 @@ import qualified Deque.Strict as Deque
import GHC.Exts (fromList)

import Control.Exception (NonTermination (..), assert, throw)
import Control.Monad (join)

import Control.Monad (when)
import Control.Monad (join, when)
import Control.Monad.ST.Lazy
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
import Data.STRef.Lazy
Expand Down Expand Up @@ -828,6 +828,32 @@ runSimTraceST mainAction = schedule mainThread initialState
}


data StmControl s a where
StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a


-- Unwind the STM control stack till the matching exception is found
unwindControlStmStack :: forall s a.
SomeException
-> StmControl s a
-> Either Bool (StmControl s a)
unwindControlStmStack e (StmControl _ frame) = unwindFrame frame

where
unwindFrame :: forall s' b. StmStack s' b a -> Either Bool (StmControl s' a)
unwindFrame AtomicallyFrame = Left True
unwindFrame (OrElseLeftFrame _ _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (OrElseRightFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchHandlerStmFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) =
case fromException e of
-- Continue to unwind till we find a handler which can handle this exception.
Nothing -> unwindFrame ctl
Just e' ->
let action' = handler e'
ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl
in Right $ StmControl action' ctl'

--
-- Executing STM Transactions
--
Expand Down Expand Up @@ -910,11 +936,47 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- Continue with the k continuation
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
-- Successful main catch action
-- Merge allocations with outer sequence
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
-- Successful completion of catch handler
!_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar)
(Map.intersection written writtenOuter)
-- Merge allocations with outer sequence
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
{-# SCC "execAtomically.go.ThrowStm" #-}
-- Abort if no matching exception is found
let abort = do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)

-- Unwind to the nearest matching exception
in case unwindControlStmStack e (StmControl action ctl) of
Left _ -> abort
Right (StmControl action' ctl') -> go ctl' read written writtenSeq createdSeq nextVid action'

CatchStm act handler k ->
{-# SCC "execAtomically.go.CatchStm" #-} do
let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid act

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
Expand All @@ -941,6 +1003,21 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- This is XSTM3 test case from the STM paper.
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- This is XSTM3 test case from the STM paper.
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry


OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
-- Execute the left side in a new frame with an empty written set
Expand Down
55 changes: 46 additions & 9 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-partial-fields #-}
{-# LANGUAGE MultiWayIf #-}

module Control.Monad.IOSim.Types
( IOSim (..)
Expand Down Expand Up @@ -175,6 +176,8 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
-- Catch with continuation
CatchStm :: Exception e => StmA s a -> (e -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -224,19 +227,19 @@ instance Monad (IOSim s) where
{-# INLINE (>>) #-}
(>>) = (*>)

#if !(MIN_VERSION_base(4,13,0))
fail = Fail.fail
#endif




instance Semigroup a => Semigroup (IOSim s a) where
(<>) = liftA2 (<>)

instance Monoid a => Monoid (IOSim s a) where
mempty = pure mempty

#if !(MIN_VERSION_base(4,11,0))
mappend = liftA2 mappend
#endif




instance Fail.MonadFail (IOSim s) where
fail msg = IOSim $ oneShot $ \_ -> Throw (toException (IO.Error.userError msg))
Expand Down Expand Up @@ -269,9 +272,9 @@ instance Monad (STM s) where
{-# INLINE (>>) #-}
(>>) = (*>)

#if !(MIN_VERSION_base(4,13,0))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reinstantiate these changes.

fail = Fail.fail
#endif




instance Fail.MonadFail (STM s) where
fail msg = STM $ oneShot $ \_ -> ThrowStm (toException (ErrorCall msg))
Expand Down Expand Up @@ -313,6 +316,23 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO

instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . handler) k

-- Default implmentation uses mask. For STM, mask is not necessary.
generalBracket acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance Exceptions.MonadCatch (STM s) where

catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -853,6 +873,23 @@ data StmStack s b a where
-> StmStack s b c
-> StmStack s a c

-- | Executing in the context of the /action/ part of the 'catch'
CatchStmFrame :: Exception e
=> (e -> StmA s a) -- exception handler
-> (a -> StmA s b) -- subsequent continuation
-> Map TVarId (SomeTVar s) -- saved written vars set
-> [SomeTVar s] -- saved written vars list
-> [SomeTVar s] -- created vars list (allocations)
-> StmStack s b c
-> StmStack s a c

-- | A continuation frame
CatchHandlerStmFrame :: (b -> StmA s c) -- subsequent continuation
-> Map TVarId (SomeTVar s) -- saved written vars set
-> [SomeTVar s] -- saved written vars list
-> [SomeTVar s] -- created vars list (allocations)
-> !(StmStack s c a)
-> StmStack s b a
---
--- Schedules
---
Expand Down
80 changes: 76 additions & 4 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,32 @@ controlSimTraceST limit control mainAction =
}


data StmControl s a where
StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a


-- Unwind the STM control stack till the matching exception is found
unwindControlStmStack :: forall s a.
SomeException
-> StmControl s a
-> Either Bool (StmControl s a)
unwindControlStmStack e (StmControl _ frame) = unwindFrame frame

where
unwindFrame :: forall s' b. StmStack s' b a -> Either Bool (StmControl s' a)
unwindFrame AtomicallyFrame = Left True
unwindFrame (OrElseLeftFrame _ _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (OrElseRightFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchHandlerStmFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) =
case fromException e of
-- Continue to unwind till we find a handler which can handle this exception.
Nothing -> unwindFrame ctl
Just e' ->
let action' = handler e'
ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl
in Right $ StmControl action' ctl'

--
-- Executing STM Transactions
--
Expand Down Expand Up @@ -1121,11 +1147,44 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
-- Continue with the k continuation
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
!_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar)
(Map.intersection written writtenOuter)
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)



ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
{-# SCC "execAtomically.go.ThrowStm" #-}

let abort = do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)

in case unwindControlStmStack e (StmControl action ctl) of
Left _ -> abort
Right (StmControl action' ctl') -> go ctl' read written writtenSeq createdSeq nextVid action'

CatchStm act handler k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid act

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
Expand All @@ -1152,6 +1211,19 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchHandlerStmFrame" #-} do
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry


OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
-- Execute the left side in a new frame with an empty written set
Expand Down
5 changes: 4 additions & 1 deletion io-sim/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ module Main (main) where
import Test.Tasty

import qualified Test.IOSim (tests)
import qualified Test.STM (tests)

main :: IO ()
main = defaultMain tests

tests :: TestTree
tests =
testGroup "IO Sim"
[ Test.IOSim.tests
[
Test.IOSim.tests
, Test.STM.tests
]
4 changes: 2 additions & 2 deletions io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.IOSim

import Test.STM
import Test.STM hiding (tests)

import Test.QuickCheck
import Test.Tasty
Expand Down Expand Up @@ -1049,7 +1049,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m, LazySTM.MonadSTM m, MonadCatch (LazySTM.STM m))
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
Loading