Skip to content

Commit d073a5b

Browse files
yogeshsajanikarcoot
andcommitted
Add support for MonadCatch
- Add support for Catch in IOSim and IOSimPOR - Add support for Catch in Test/STM.hs Co-authored-by: Marcin Szamotulski <[email protected]>
1 parent 6b81d7c commit d073a5b

File tree

5 files changed

+155
-42
lines changed

5 files changed

+155
-42
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -926,19 +926,42 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
926926

927927
ThrowStm e ->
928928
{-# SCC "execAtomically.go.ThrowStm" #-} do
929-
-- Revert all the TVar writes
929+
-- Rollback `TVar`s written since catch handler was installed
930930
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
931-
k0 $ StmTxAborted [] (toException e)
931+
case ctl of
932+
AtomicallyFrame -> do
933+
k0 $ StmTxAborted (Map.elems read) (toException e)
934+
935+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
936+
{-# SCC "execAtomically.go.BranchFrame" #-} do
937+
-- Execute the left side in a new frame with an empty written set.
938+
-- Rollback `TVar`s written since catch handler was installed,
939+
-- but preserve ones that were set prior to it, as specified in the
940+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
941+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
942+
go ctl'' read Map.empty [] [] nextVid (h e)
943+
--
944+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
945+
{-# SCC "execAtomically.go.BranchFrame" #-} do
946+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
947+
948+
CatchStm a h k ->
949+
{-# SCC "execAtomically.go.ThrowStm" #-} do
950+
-- Execute the catch handler with an empty written set.
951+
-- but preserve ones that were set prior to it, as specified in the
952+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
953+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
954+
go ctl' read Map.empty [] [] nextVid a
955+
932956

933957
Retry ->
934-
{-# SCC "execAtomically.go.Retry" #-}
935-
do
958+
{-# SCC "execAtomically.go.Retry" #-} do
936959
-- Always revert all the TVar writes for the retry
937960
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
938961
case ctl of
939962
AtomicallyFrame -> do
940963
-- Return vars read, so the thread can block on them
941-
k0 $! StmTxBlocked $! (Map.elems read)
964+
k0 $! StmTxBlocked $! Map.elems read
942965

943966
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
944967
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ runSTM (STM k) = k ReturnStm
184184
data StmA s a where
185185
ReturnStm :: a -> StmA s a
186186
ThrowStm :: SomeException -> StmA s a
187+
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b
187188

188189
NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
189190
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
@@ -322,6 +323,32 @@ instance MonadThrow (STM s) where
322323
instance Exceptions.MonadThrow (STM s) where
323324
throwM = MonadThrow.throwIO
324325

326+
327+
instance MonadCatch (STM s) where
328+
329+
catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
330+
where
331+
-- Get a total handler from the given handler
332+
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
333+
fromHandler h e = case fromException e of
334+
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
335+
Just e' -> h e'
336+
337+
-- STM actions are always run inside `execAtomically` and behave as if masked
338+
-- Another point to note that the default implementation of `generalBracket` needs
339+
-- mask, and is part of `MonadThrow`. For STM, we don't need masking because
340+
-- async exceptions are handled outside of `execAtomically`.
341+
generalBracket acquire release use = do
342+
resource <- acquire
343+
b <- use resource `catch` \e -> do
344+
_ <- release resource (ExitCaseException e)
345+
throwIO e
346+
c <- release resource (ExitCaseSuccess b)
347+
return (b, c)
348+
349+
instance Exceptions.MonadCatch (STM s) where
350+
catch = MonadThrow.catch
351+
325352
instance MonadCatch (IOSim s) where
326353
catch action handler =
327354
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
@@ -857,9 +884,10 @@ data StmTxResult s a =
857884
| StmTxAborted [SomeTVar s] SomeException
858885

859886

860-
-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
861-
-- can be a NoOp
862-
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
887+
-- | A branch is an alternative of a `OrElse` or a `CatchStm` statement
888+
data BranchStmA s a = OrElseStmA (StmA s a)
889+
| CatchStmA (SomeException -> StmA s a)
890+
| NoOpStmA
863891

864892
data StmStack s b a where
865893
-- | Executing in the context of a top level 'atomically'.

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,32 +1174,51 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
11741174
{-# SCC "execAtomically.go.ThrowStm" #-} do
11751175
-- Revert all the TVar writes
11761176
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1177-
k0 $ StmTxAborted (Map.elems read) (toException e)
1177+
case ctl of
1178+
AtomicallyFrame -> do
1179+
k0 $ StmTxAborted (Map.elems read) (toException e)
1180+
1181+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1182+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1183+
-- Execute the left side in a new frame with an empty written set.
1184+
-- Rollback `TVar`s written since catch handler was installed,
1185+
-- but preserve ones that were set prior to it, as specified in the
1186+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
1187+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1188+
go ctl'' read Map.empty [] [] nextVid (h e)
1189+
1190+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1191+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1192+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1193+
1194+
CatchStm a h k ->
1195+
{-# SCC "execAtomically.go.ThrowStm" #-} do
1196+
-- Execute the left side in a new frame with an empty written set
1197+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
1198+
go ctl' read Map.empty [] [] nextVid a
11781199

11791200
Retry ->
1180-
{-# SCC "execAtomically.go.Retry" #-}
1181-
do
1182-
-- Always revert all the TVar writes for the retry
1183-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1184-
case ctl of
1185-
AtomicallyFrame -> do
1186-
-- Return vars read, so the thread can block on them
1187-
k0 $! StmTxBlocked $! Map.elems read
1188-
1189-
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1190-
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1191-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1192-
-- Execute the orElse right hand with an empty written set
1193-
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1194-
go ctl'' read Map.empty [] [] nextVid b
1195-
1196-
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1197-
{-# SCC "execAtomically.go.BranchFrame" #-} do
1198-
-- Retry makes sense only within a OrElse context. If it is a branch other than
1199-
-- OrElse left side, then bubble up the `retry` to the frame above.
1200-
-- Skip the continuation and propagate the retry into the outer frame
1201-
-- using the written set for the outer frame
1202-
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
1201+
{-# SCC "execAtomically.go.Retry" #-} do
1202+
-- Always revert all the TVar writes for the retry
1203+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1204+
case ctl of
1205+
AtomicallyFrame -> do
1206+
-- Return vars read, so the thread can block on them
1207+
k0 $! StmTxBlocked $! Map.elems read
1208+
1209+
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1210+
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1211+
-- Execute the orElse right hand with an empty written set
1212+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1213+
go ctl'' read Map.empty [] [] nextVid b
1214+
1215+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1216+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1217+
-- Retry makes sense only within a OrElse context. If it is a branch other than
1218+
-- OrElse left side, then bubble up the `retry` to the frame above.
1219+
-- Skip the continuation and propagate the retry into the outer frame
1220+
-- using the written set for the outer frame
1221+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
12031222

12041223
OrElse a b k ->
12051224
{-# SCC "execAtomically.go.OrElse" #-} do

io-sim/test/Test/IOSim.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ prop_stm_referenceSim t =
12211221
-- | Compare the behaviour of the STM reference operational semantics with
12221222
-- the behaviour of any 'MonadSTM' STM implementation.
12231223
--
1224-
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
1224+
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
12251225
=> SomeTerm -> m Property
12261226
prop_stm_referenceM (SomeTerm _tyrep t) = do
12271227
let (r1, _heap) = evalAtomically t

io-sim/test/Test/STM.hs

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ data Term (t :: Type) where
6868

6969
Return :: Expr t -> Term t
7070
Throw :: Expr a -> Term t
71+
Catch :: Term t -> Term t -> Term t
7172
Retry :: Term t
7273

7374
ReadTVar :: Name (TyVar t) -> Term t
@@ -297,7 +298,7 @@ deriving instance Show (NfTerm t)
297298
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
298299
--
299300
-- Compare the implementation of this against the operational semantics in
300-
-- Figure 4 in the paper. Note that @catch@ is not included.
301+
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
301302
--
302303
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
303304
evalTerm !env !heap !allocs term = case term of
@@ -310,6 +311,30 @@ evalTerm !env !heap !allocs term = case term of
310311
where
311312
e' = evalExpr env e
312313

314+
-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
315+
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
316+
Catch t1 t2 ->
317+
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
318+
319+
-- Rule XSTM1
320+
-- M; heap, {} => return P; heap', allocs'
321+
-- --------------------------------------------------------
322+
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs'
323+
NfReturn v -> (NfReturn v, heap', allocs <> allocs')
324+
325+
-- Rule XSTM2
326+
-- M; heap, {} => throw P; heap', allocs'
327+
-- --------------------------------------------------------
328+
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
329+
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
330+
331+
-- Rule XSTM3
332+
-- M; heap, {} => retry; heap', allocs'
333+
-- --------------------------------------------------------
334+
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
335+
NfRetry -> (NfRetry, heap, allocs)
336+
337+
313338
Retry -> (NfRetry, heap, allocs)
314339

315340
-- Rule READ
@@ -438,7 +463,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
438463

439464
-- | Execute an STM 'Term' in the 'STM' monad.
440465
--
441-
execTerm :: (MonadSTM m, MonadThrow (STM m))
466+
execTerm :: (MonadSTM m, MonadCatch (STM m))
442467
=> ExecEnv m
443468
-> Term t
444469
-> STM m (ExecValue m t)
@@ -452,6 +477,8 @@ execTerm env t =
452477
let e' = execExpr env e
453478
throwSTM =<< snapshotExecValue e'
454479

480+
Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2
481+
455482
Retry -> retry
456483

457484
ReadTVar n -> do
@@ -492,7 +519,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
492519
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
493520
(snapshotExecValue =<< readTVar v)
494521

495-
execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
522+
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
496523
=> Term t -> m TxResult
497524
execAtomically t =
498525
toTxResult <$> try (atomically action')
@@ -658,7 +685,7 @@ genTerm env tyrep =
658685
Nothing)
659686
]
660687

661-
binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
688+
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]
662689

663690
bindTerm =
664691
sized $ \sz -> do
@@ -672,10 +699,15 @@ genTerm env tyrep =
672699
return (Bind t1 name t2)
673700

674701
orElseTerm =
675-
sized $ \sz -> resize (sz `div` 2) $
702+
scale (`div` 2) $
676703
OrElse <$> genTerm env tyrep
677704
<*> genTerm env tyrep
678705

706+
catchTerm =
707+
scale (`div` 2) $
708+
Catch <$> genTerm env tyrep
709+
<*> genTerm env tyrep
710+
679711
genSomeExpr :: GenEnv -> Gen SomeExpr
680712
genSomeExpr env =
681713
oneof'
@@ -714,6 +746,8 @@ shrinkTerm t =
714746
case t of
715747
Return e -> [Return e' | e' <- shrinkExpr e]
716748
Throw e -> [Throw e' | e' <- shrinkExpr e]
749+
Catch t1 t2 -> [t1, t2]
750+
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
717751
Retry -> []
718752
ReadTVar _ -> []
719753

@@ -722,12 +756,10 @@ shrinkTerm t =
722756
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
723757

724758
Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
725-
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
726-
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
759+
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
727760

728761
OrElse t1 t2 -> [t1, t2]
729-
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
730-
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
762+
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
731763

732764
shrinkExpr :: Expr t -> [Expr t]
733765
shrinkExpr ExprUnit = []
@@ -739,6 +771,12 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
739771
freeNamesTerm :: Term t -> Set NameId
740772
freeNamesTerm (Return e) = freeNamesExpr e
741773
freeNamesTerm (Throw e) = freeNamesExpr e
774+
-- A catch handler should actually have an argument, and then the implementation
775+
-- should handle it. But since current implementation of catch never binds the
776+
-- variable, the following implementation is correct as of now. It needs to be
777+
-- tackled once nested exceptions are handled.
778+
-- TODO: Correctly handle free names when the handler also binds a variable.
779+
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
742780
freeNamesTerm Retry = Set.empty
743781
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
744782
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
@@ -769,6 +807,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
769807
termSize :: Term a -> Int
770808
termSize Return{} = 1
771809
termSize Throw{} = 1
810+
termSize (Catch a b) = 1 + termSize a + termSize b
772811
termSize Retry{} = 1
773812
termSize ReadTVar{} = 1
774813
termSize WriteTVar{} = 1
@@ -779,6 +818,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
779818
termDepth :: Term a -> Int
780819
termDepth Return{} = 1
781820
termDepth Throw{} = 1
821+
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
782822
termDepth Retry{} = 1
783823
termDepth ReadTVar{} = 1
784824
termDepth WriteTVar{} = 1
@@ -791,6 +831,9 @@ showTerm p (Return e) = showParen (p > 10) $
791831
showString "return " . showExpr 11 e
792832
showTerm p (Throw e) = showParen (p > 10) $
793833
showString "throwSTM " . showExpr 11 e
834+
showTerm p (Catch t1 t2) = showParen (p > 9) $
835+
showTerm 10 t1 . showString " `catch` "
836+
. showTerm 10 t2
794837
showTerm _ Retry = showString "retry"
795838
showTerm p (ReadTVar n) = showParen (p > 10) $
796839
showString "readTVar " . showName n

0 commit comments

Comments
 (0)