@@ -68,6 +68,7 @@ data Term (t :: Type) where
68
68
69
69
Return :: Expr t -> Term t
70
70
Throw :: Expr a -> Term t
71
+ Catch :: Term t -> Term t -> Term t
71
72
Retry :: Term t
72
73
73
74
ReadTVar :: Name (TyVar t ) -> Term t
@@ -297,7 +298,7 @@ deriving instance Show (NfTerm t)
297
298
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
298
299
--
299
300
-- Compare the implementation of this against the operational semantics in
300
- -- Figure 4 in the paper. Note that @catch@ is not included .
301
+ -- Figure 4 in the paper including the `Catch` semantics from the Appendix A .
301
302
--
302
303
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t , Heap , Allocs )
303
304
evalTerm ! env ! heap ! allocs term = case term of
@@ -310,6 +311,30 @@ evalTerm !env !heap !allocs term = case term of
310
311
where
311
312
e' = evalExpr env e
312
313
314
+ -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
315
+ -- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
316
+ Catch t1 t2 ->
317
+ let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
318
+
319
+ -- Rule XSTM1
320
+ -- M; heap, {} => return P; heap', allocs'
321
+ -- --------------------------------------------------------
322
+ -- S[catch M N]; heap, allocs => S[return P]; heap', allocs'
323
+ NfReturn v -> (NfReturn v, heap', allocs <> allocs')
324
+
325
+ -- Rule XSTM2
326
+ -- M; heap, {} => throw P; heap', allocs'
327
+ -- --------------------------------------------------------
328
+ -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
329
+ NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
330
+
331
+ -- Rule XSTM3
332
+ -- M; heap, {} => retry; heap', allocs'
333
+ -- --------------------------------------------------------
334
+ -- S[catch M N]; heap, allocs => S[retry]; heap, allocs
335
+ NfRetry -> (NfRetry , heap, allocs)
336
+
337
+
313
338
Retry -> (NfRetry , heap, allocs)
314
339
315
340
-- Rule READ
@@ -438,7 +463,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
438
463
439
464
-- | Execute an STM 'Term' in the 'STM' monad.
440
465
--
441
- execTerm :: (MonadSTM m , MonadThrow (STM m ))
466
+ execTerm :: (MonadSTM m , MonadCatch (STM m ))
442
467
=> ExecEnv m
443
468
-> Term t
444
469
-> STM m (ExecValue m t )
@@ -452,6 +477,8 @@ execTerm env t =
452
477
let e' = execExpr env e
453
478
throwSTM =<< snapshotExecValue e'
454
479
480
+ Catch t1 t2 -> execTerm env t1 `catch` \ (_ :: ImmValue ) -> execTerm env t2
481
+
455
482
Retry -> retry
456
483
457
484
ReadTVar n -> do
@@ -492,7 +519,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
492
519
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
493
520
(snapshotExecValue =<< readTVar v)
494
521
495
- execAtomically :: forall m t . (MonadSTM m , MonadThrow (STM m ), MonadCatch m )
522
+ execAtomically :: forall m t . (MonadSTM m , MonadCatch (STM m ), MonadCatch m )
496
523
=> Term t -> m TxResult
497
524
execAtomically t =
498
525
toTxResult <$> try (atomically action')
@@ -658,7 +685,7 @@ genTerm env tyrep =
658
685
Nothing )
659
686
]
660
687
661
- binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm)]
688
+ binTerm = frequency [ (2 , bindTerm), (1 , orElseTerm), ( 1 , catchTerm) ]
662
689
663
690
bindTerm =
664
691
sized $ \ sz -> do
@@ -672,10 +699,15 @@ genTerm env tyrep =
672
699
return (Bind t1 name t2)
673
700
674
701
orElseTerm =
675
- sized $ \ sz -> resize (sz `div` 2 ) $
702
+ scale ( `div` 2 ) $
676
703
OrElse <$> genTerm env tyrep
677
704
<*> genTerm env tyrep
678
705
706
+ catchTerm =
707
+ scale (`div` 2 ) $
708
+ Catch <$> genTerm env tyrep
709
+ <*> genTerm env tyrep
710
+
679
711
genSomeExpr :: GenEnv -> Gen SomeExpr
680
712
genSomeExpr env =
681
713
oneof'
@@ -714,6 +746,8 @@ shrinkTerm t =
714
746
case t of
715
747
Return e -> [Return e' | e' <- shrinkExpr e]
716
748
Throw e -> [Throw e' | e' <- shrinkExpr e]
749
+ Catch t1 t2 -> [t1, t2]
750
+ ++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
717
751
Retry -> []
718
752
ReadTVar _ -> []
719
753
@@ -722,12 +756,10 @@ shrinkTerm t =
722
756
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
723
757
724
758
Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
725
- ++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
726
- ++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
759
+ ++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
727
760
728
761
OrElse t1 t2 -> [t1, t2]
729
- ++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
730
- ++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
762
+ ++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
731
763
732
764
shrinkExpr :: Expr t -> [Expr t ]
733
765
shrinkExpr ExprUnit = []
@@ -739,6 +771,12 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
739
771
freeNamesTerm :: Term t -> Set NameId
740
772
freeNamesTerm (Return e) = freeNamesExpr e
741
773
freeNamesTerm (Throw e) = freeNamesExpr e
774
+ -- A catch handler should actually have an argument, and then the implementation
775
+ -- should handle it. But since current implementation of catch never binds the
776
+ -- variable, the following implementation is correct as of now. It needs to be
777
+ -- tackled once nested exceptions are handled.
778
+ -- TODO: Correctly handle free names when the handler also binds a variable.
779
+ freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
742
780
freeNamesTerm Retry = Set. empty
743
781
freeNamesTerm (ReadTVar n) = Set. singleton (nameId n)
744
782
freeNamesTerm (WriteTVar n e) = Set. singleton (nameId n) <> freeNamesExpr e
@@ -769,6 +807,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
769
807
termSize :: Term a -> Int
770
808
termSize Return {} = 1
771
809
termSize Throw {} = 1
810
+ termSize (Catch a b) = 1 + termSize a + termSize b
772
811
termSize Retry {} = 1
773
812
termSize ReadTVar {} = 1
774
813
termSize WriteTVar {} = 1
@@ -779,6 +818,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
779
818
termDepth :: Term a -> Int
780
819
termDepth Return {} = 1
781
820
termDepth Throw {} = 1
821
+ termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
782
822
termDepth Retry {} = 1
783
823
termDepth ReadTVar {} = 1
784
824
termDepth WriteTVar {} = 1
@@ -791,6 +831,9 @@ showTerm p (Return e) = showParen (p > 10) $
791
831
showString " return " . showExpr 11 e
792
832
showTerm p (Throw e) = showParen (p > 10 ) $
793
833
showString " throwSTM " . showExpr 11 e
834
+ showTerm p (Catch t1 t2) = showParen (p > 9 ) $
835
+ showTerm 10 t1 . showString " `catch` "
836
+ . showTerm 10 t2
794
837
showTerm _ Retry = showString " retry"
795
838
showTerm p (ReadTVar n) = showParen (p > 10 ) $
796
839
showString " readTVar " . showName n
0 commit comments