@@ -48,6 +48,8 @@ module Control.Monad.IOSim.Internal
48
48
49
49
import Prelude hiding (read )
50
50
51
+ import Data.Deque.Strict (Deque )
52
+ import qualified Data.Deque.Strict as Deque
51
53
import Data.Dynamic
52
54
import Data.Foldable (foldlM , toList , traverse_ )
53
55
import qualified Data.List as List
@@ -60,8 +62,6 @@ import qualified Data.OrdPSQ as PSQ
60
62
import Data.Set (Set )
61
63
import qualified Data.Set as Set
62
64
import Data.Time (UTCTime (.. ), fromGregorian )
63
- import Data.Deque.Strict (Deque )
64
- import qualified Data.Deque.Strict as Deque
65
65
66
66
import Control.Exception (NonTermination (.. ), assert , throw )
67
67
import Control.Monad (join , when )
@@ -76,13 +76,16 @@ import Control.Monad.Class.MonadSTM hiding (STM)
76
76
import Control.Monad.Class.MonadSTM.Internal (TMVarDefault (TMVar ))
77
77
import Control.Monad.Class.MonadThrow hiding (getMaskingState )
78
78
import Control.Monad.Class.MonadTime
79
- import Control.Monad.Class.MonadTimer.SI (TimeoutState (.. ), DiffTime , diffTimeToMicrosecondsAsInt , microsecondsAsIntToDiffTime )
79
+ import Control.Monad.Class.MonadTimer.SI (DiffTime , TimeoutState (.. ),
80
+ diffTimeToMicrosecondsAsInt , microsecondsAsIntToDiffTime )
80
81
81
82
import Control.Monad.IOSim.InternalTypes
82
83
import Control.Monad.IOSim.Types hiding (SimEvent (SimPOREvent ),
83
84
Trace (SimPORTrace ))
84
85
import Control.Monad.IOSim.Types (SimEvent )
85
- import System.Random (StdGen , randomR , split )
86
+ import Data.Bifunctor (first )
87
+ import Data.Ord (comparing )
88
+ import System.Random (StdGen , randomR , split )
86
89
87
90
--
88
91
-- Simulation interpreter
@@ -849,31 +852,47 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
849
852
timeoutSTMAction TimerTimeout {} = return ()
850
853
851
854
unblockThreads :: Bool -> [IOSimThreadId ] -> SimState s a -> ([IOSimThreadId ], SimState s a )
852
- unblockThreads ! onlySTM ! wakeup ! simstate@ SimState {runqueue, threads, stdGen} =
855
+ unblockThreads ! onlySTM ! wakeup simstate@ SimState {runqueue, threads, stdGen} =
853
856
-- To preserve our invariants (that threadBlocked is correct)
854
857
-- we update the runqueue and threads together here
855
858
(unblocked, simstate {
856
- runqueue = Deque. fromList (shuffledRunqueue ++ rest) ,
859
+ runqueue = runqueue <> Deque. fromList unblocked ,
857
860
threads = threads',
858
861
stdGen = stdGen''
859
862
})
860
863
where
861
- ! (shuffledRunqueue, stdGen'') = fisherYatesShuffle stdGen' toShuffle
862
- ! ((toShuffle, rest), stdGen') =
863
- let runqueueList = Deque. toList $ runqueue <> Deque. fromList unblocked
864
- runqueueListLength = max 1 (length runqueueList)
865
- (ix, newGen) = randomR (0 , runqueueListLength `div` 2 ) stdGen
866
- in (splitAt ix runqueueList, newGen)
867
864
-- can only unblock if the thread exists and is blocked (not running)
868
- ! unblocked = [ tid
869
- | tid <- wakeup
870
- , case Map. lookup tid threads of
871
- Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
872
- -> True
873
- Just Thread { threadStatus = ThreadBlocked _ }
874
- -> not onlySTM
875
- _ -> False
876
- ]
865
+ ! blockedOnOther = [ (tid, ix)
866
+ | (tid, ix) <- zip wakeup [0 :: Int .. ]
867
+ , case Map. lookup tid threads of
868
+ Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
869
+ -> False
870
+ Just Thread { threadStatus = ThreadBlocked _ }
871
+ -> not onlySTM
872
+ _ -> False
873
+ ]
874
+
875
+ ! blockedOnSTM = [ (tid, ix)
876
+ | (tid, ix) <- zip wakeup [0 :: Int .. ]
877
+ , case Map. lookup tid threads of
878
+ Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
879
+ -> True
880
+ _ -> False
881
+ ]
882
+
883
+ mergeByIndex :: Ord a => [(b , a )] -> [(b , a )] -> [b ]
884
+ mergeByIndex a b = map fst $ List. sortBy (comparing snd ) (a ++ b)
885
+
886
+ -- Shuffle only 1/5th of the time
887
+ (shouldShuffle, ! stdGen') =
888
+ first (== 0 ) $ randomR (0 :: Int , 5 ) stdGen
889
+
890
+ (! shuffledBlockedOnSTM, ! stdGen'')
891
+ | shouldShuffle = fisherYatesShuffle stdGen' blockedOnSTM
892
+ | otherwise = (blockedOnSTM, stdGen')
893
+
894
+ ! unblocked = mergeByIndex blockedOnOther shuffledBlockedOnSTM
895
+
877
896
-- and in which case we mark them as now running
878
897
! threads' = List. foldl'
879
898
(flip (Map. adjust (\ t -> t { threadStatus = ThreadRunning })))
@@ -889,8 +908,8 @@ unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
889
908
where
890
909
go 0 lst g = (lst, g)
891
910
go n lst g = let (k, newGen) = randomR (0 , n) g
892
- (x: xs) = drop k lst
893
- swapped = take k lst ++ [lst !! n] ++ drop (k + 1 ) lst
911
+ (x: xs) = drop k lst
912
+ swapped = take k lst ++ [lst !! n] ++ drop (k + 1 ) lst
894
913
in go (n - 1 ) (take n swapped ++ [x] ++ drop n xs) newGen
895
914
896
915
-- | This function receives a list of TimerTimeout values that represent threads
0 commit comments