Skip to content

Commit 9fe5c5e

Browse files
committed
Only shuffle 20% of the time
1 parent 6a781b9 commit 9fe5c5e

File tree

1 file changed

+42
-23
lines changed

1 file changed

+42
-23
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ module Control.Monad.IOSim.Internal
4848

4949
import Prelude hiding (read)
5050

51+
import Data.Deque.Strict (Deque)
52+
import qualified Data.Deque.Strict as Deque
5153
import Data.Dynamic
5254
import Data.Foldable (foldlM, toList, traverse_)
5355
import qualified Data.List as List
@@ -60,8 +62,6 @@ import qualified Data.OrdPSQ as PSQ
6062
import Data.Set (Set)
6163
import qualified Data.Set as Set
6264
import Data.Time (UTCTime (..), fromGregorian)
63-
import Data.Deque.Strict (Deque)
64-
import qualified Data.Deque.Strict as Deque
6565

6666
import Control.Exception (NonTermination (..), assert, throw)
6767
import Control.Monad (join, when)
@@ -76,13 +76,16 @@ import Control.Monad.Class.MonadSTM hiding (STM)
7676
import Control.Monad.Class.MonadSTM.Internal (TMVarDefault (TMVar))
7777
import Control.Monad.Class.MonadThrow hiding (getMaskingState)
7878
import Control.Monad.Class.MonadTime
79-
import Control.Monad.Class.MonadTimer.SI (TimeoutState (..), DiffTime, diffTimeToMicrosecondsAsInt, microsecondsAsIntToDiffTime)
79+
import Control.Monad.Class.MonadTimer.SI (DiffTime, TimeoutState (..),
80+
diffTimeToMicrosecondsAsInt, microsecondsAsIntToDiffTime)
8081

8182
import Control.Monad.IOSim.InternalTypes
8283
import Control.Monad.IOSim.Types hiding (SimEvent (SimPOREvent),
8384
Trace (SimPORTrace))
8485
import Control.Monad.IOSim.Types (SimEvent)
85-
import System.Random (StdGen, randomR, split)
86+
import Data.Bifunctor (first)
87+
import Data.Ord (comparing)
88+
import System.Random (StdGen, randomR, split)
8689

8790
--
8891
-- Simulation interpreter
@@ -849,31 +852,47 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
849852
timeoutSTMAction TimerTimeout{} = return ()
850853

851854
unblockThreads :: Bool -> [IOSimThreadId] -> SimState s a -> ([IOSimThreadId], SimState s a)
852-
unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
855+
unblockThreads !onlySTM !wakeup simstate@SimState {runqueue, threads, stdGen} =
853856
-- To preserve our invariants (that threadBlocked is correct)
854857
-- we update the runqueue and threads together here
855858
(unblocked, simstate {
856-
runqueue = Deque.fromList (shuffledRunqueue ++ rest),
859+
runqueue = runqueue <> Deque.fromList unblocked,
857860
threads = threads',
858861
stdGen = stdGen''
859862
})
860863
where
861-
!(shuffledRunqueue, stdGen'') = fisherYatesShuffle stdGen' toShuffle
862-
!((toShuffle, rest), stdGen') =
863-
let runqueueList = Deque.toList $ runqueue <> Deque.fromList unblocked
864-
runqueueListLength = max 1 (length runqueueList)
865-
(ix, newGen) = randomR (0, runqueueListLength `div` 2) stdGen
866-
in (splitAt ix runqueueList, newGen)
867864
-- can only unblock if the thread exists and is blocked (not running)
868-
!unblocked = [ tid
869-
| tid <- wakeup
870-
, case Map.lookup tid threads of
871-
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
872-
-> True
873-
Just Thread { threadStatus = ThreadBlocked _ }
874-
-> not onlySTM
875-
_ -> False
876-
]
865+
!blockedOnOther = [ (tid, ix)
866+
| (tid, ix) <- zip wakeup [0 :: Int ..]
867+
, case Map.lookup tid threads of
868+
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
869+
-> False
870+
Just Thread { threadStatus = ThreadBlocked _ }
871+
-> not onlySTM
872+
_ -> False
873+
]
874+
875+
!blockedOnSTM = [ (tid, ix)
876+
| (tid, ix) <- zip wakeup [0 :: Int ..]
877+
, case Map.lookup tid threads of
878+
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
879+
-> True
880+
_ -> False
881+
]
882+
883+
mergeByIndex :: Ord a => [(b, a)] -> [(b, a)] -> [b]
884+
mergeByIndex a b = map fst $ List.sortBy (comparing snd) (a ++ b)
885+
886+
-- Shuffle only 1/5th of the time
887+
(shouldShuffle, !stdGen') =
888+
first (== 0) $ randomR (0 :: Int, 5) stdGen
889+
890+
(!shuffledBlockedOnSTM, !stdGen'')
891+
| shouldShuffle = fisherYatesShuffle stdGen' blockedOnSTM
892+
| otherwise = (blockedOnSTM, stdGen')
893+
894+
!unblocked = mergeByIndex blockedOnOther shuffledBlockedOnSTM
895+
877896
-- and in which case we mark them as now running
878897
!threads' = List.foldl'
879898
(flip (Map.adjust (\t -> t { threadStatus = ThreadRunning })))
@@ -889,8 +908,8 @@ unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
889908
where
890909
go 0 lst g = (lst, g)
891910
go n lst g = let (k, newGen) = randomR (0, n) g
892-
(x:xs) = drop k lst
893-
swapped = take k lst ++ [lst !! n] ++ drop (k + 1) lst
911+
(x:xs) = drop k lst
912+
swapped = take k lst ++ [lst !! n] ++ drop (k + 1) lst
894913
in go (n - 1) (take n swapped ++ [x] ++ drop n xs) newGen
895914

896915
-- | This function receives a list of TimerTimeout values that represent threads

0 commit comments

Comments
 (0)