diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index 780d4559a..0d43335b7 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -1,5 +1,10 @@ # Changelog for persistent-postgresql +## 2.13.6.0 + +* [#1482](https://github.com/yesodweb/persistent/pull/1482) + * Add `isSerializationFailure` and `isDeadlockDetected` exception predicates + ## 2.13.5.2 * [#1471](https://github.com/yesodweb/persistent/pull/1471) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 4dd2dcad5..cf1d2748d 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -38,6 +38,10 @@ module Database.Persist.Postgresql , createPostgresqlPoolModified , createPostgresqlPoolModifiedWithVersion , createPostgresqlPoolWithConf + + , isSerializationFailure + , isDeadlockDetected + , module Database.Persist.Sql , ConnectionString , HandleUpdateCollision @@ -77,13 +81,14 @@ import qualified Database.PostgreSQL.Simple.Transaction as PG import qualified Database.PostgreSQL.Simple.Types as PG import Control.Arrow -import Control.Exception (Exception, throw, throwIO) +import Control.Exception + (Exception(fromException), SomeException, throw, throwIO) import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift (MonadIO(..), MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) -import Control.Monad.Trans.Reader (ReaderT(..), asks, runReaderT) import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Reader (ReaderT(..), asks, runReaderT) #if !MIN_VERSION_base(4,12,0) import Control.Monad.Trans.Reader (withReaderT) #endif @@ -102,8 +107,8 @@ import qualified Data.Conduit.List as CL import Data.Data (Data) import Data.Either (partitionEithers) import Data.Function (on) -import Data.IORef import Data.Int (Int64) +import Data.IORef import Data.List (find, foldl', groupBy, sort) import qualified Data.List as List import Data.List.NonEmpty (NonEmpty) @@ -122,12 +127,13 @@ import System.Environment (getEnvironment) #if MIN_VERSION_base(4,12,0) import Database.Persist.Compatible #endif +import qualified Data.Vault.Strict as Vault import Database.Persist.Postgresql.Internal import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util import Database.Persist.SqlBackend -import Database.Persist.SqlBackend.StatementCache (StatementCache, mkSimpleStatementCache, mkStatementCache) -import qualified Data.Vault.Strict as Vault +import Database.Persist.SqlBackend.StatementCache + (StatementCache, mkSimpleStatementCache, mkStatementCache) import System.IO.Unsafe (unsafePerformIO) -- | A @libpq@ connection string. A simple example of connection @@ -1953,6 +1959,31 @@ createRawPostgresqlPoolWithConf conf hooks = do modConn = pgConfHooksAfterCreate hooks createSqlPoolWithConfig (open' modConn getVer withRawConnection (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf) +-- | An exception predicate checking for a PostgreSQL serialization error, i.e. +-- a @SQLSTATE@ error code of @"40001"@ (@serialization_failure@). +-- +-- This error can occur when concurrent transactions modify the same row(s) at +-- serializable isolation level. +-- +-- This predicate is intended for use with 'runSqlPoolWithExtensibleHooksRetry'. +-- +-- @since 2.13.6.0 +isSerializationFailure :: SomeException -> Bool +isSerializationFailure ex + | Just sqlError <- fromException ex = PG.isSerializationError sqlError + | otherwise = False + +-- | An exception predicate checking for a PostgreSQL deadlock detected error, +-- i.e. a @SQLSTATE@ error code of @"40P01"@ (@deadlock_detected@). +-- +-- This predicate is intended for use with 'runSqlPoolWithExtensibleHooksRetry'. +-- +-- @since 2.13.6.0 +isDeadlockDetected :: SomeException -> Bool +isDeadlockDetected ex + | Just sqlError <- fromException ex = PG.sqlState sqlError == "40P01" + | otherwise = False + #if MIN_VERSION_base(4,12,0) instance (PersistCore b) => PersistCore (RawPostgresql b) where newtype BackendKey (RawPostgresql b) = RawPostgresqlKey { unRawPostgresqlKey :: BackendKey (Compatible b (RawPostgresql b)) } diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index 05d4dbb4c..1b1a28446 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -1,5 +1,5 @@ name: persistent-postgresql -version: 2.13.5.2 +version: 2.13.6.0 license: MIT license-file: LICENSE author: Felipe Lessa, Michael Snoyman @@ -58,6 +58,8 @@ test-suite test UpsertWhere ImplicitUuidSpec MigrationReferenceSpec + AsyncExceptionsTest + RetryableTransactionsTest ghc-options: -Wall build-depends: base >= 4.9 && < 5 @@ -86,6 +88,7 @@ test-suite test , unliftio , unordered-containers , vector + , postgresql-simple default-language: Haskell2010 executable conn-kill diff --git a/persistent-postgresql/test/AsyncExceptionsTest.hs b/persistent-postgresql/test/AsyncExceptionsTest.hs new file mode 100644 index 000000000..51d9276cf --- /dev/null +++ b/persistent-postgresql/test/AsyncExceptionsTest.hs @@ -0,0 +1,196 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module AsyncExceptionsTest + ( specs + ) where + +import Control.Concurrent + ( ThreadId + , forkIO + , killThread + , myThreadId + , newEmptyMVar + , putMVar + , takeMVar + ) +import Control.Exception (MaskingState(MaskedUninterruptible), getMaskingState) +import Data.Function ((&)) +import Database.Persist.SqlBackend.SqlPoolHooks (modifyRunOnException) +import GHC.Stack (SrcLoc, callStack, getCallStack) +import HookCounts + ( HookCountRefs(..) + , HookCounts(..) + , hookCountsShouldBe + , newHookCountRefs + , trackHookCounts + ) +import Init (aroundAll_) +import PgInit + ( Filter + , HasCallStack + , MonadIO(..) + , PersistQueryWrite(deleteWhere) + , PersistStoreWrite(insert_) + , ReaderT + , RunConnArgs(sqlPoolHooks) + , Spec + , SqlBackend + , Text + , defaultRunConnArgs + , describe + , it + , mkMigrate + , mkPersist + , persistLowerCase + , runConnUsing + , runConn_ + , runMigrationSilent + , share + , sqlSettings + , void + ) +import Test.HUnit.Lang (FailureReason(Reason), HUnitFailure(HUnitFailure)) +import UnliftIO.Exception (bracket_, throwTo) + +share + [mkPersist sqlSettings, mkMigrate "asyncExceptionsTestMigrate"] + [persistLowerCase| + AsyncExceptionTestData + stuff Text + Primary stuff + deriving Eq Show + |] + +setup :: IO () +setup = runConn_ $ void $ runMigrationSilent asyncExceptionsTestMigrate + +teardown :: IO () +teardown = runConn_ cleanDB + +cleanDB :: forall m. (MonadIO m) => ReaderT SqlBackend m () +cleanDB = deleteWhere ([] :: [Filter AsyncExceptionTestData]) + +specs :: Spec +specs = aroundAll_ (bracket_ setup teardown) $ do + describe "Testing async exceptions" $ do + it "runOnException hook is executed" $ do + insertDoneRef <- newEmptyMVar + shouldProceedRef <- newEmptyMVar + + hookCountRefs <- newHookCountRefs + runConnArgs <- mkRunConnArgs hookCountRefs + + threadId <- forkIO $ do + runConnUsing runConnArgs $ do + insert_ $ AsyncExceptionTestData "bloorp" + liftIO $ do + -- "Child" thread signals to the main thread that the insert was + -- executed. + putMVar insertDoneRef () + -- "Child" thread waits around indefinitely on this @MVar@. + -- @shouldProceedRef@ is intentionally never written to in this test + -- so that the "child" thread is blocked here until the main thread + -- kills it via async exception. See the remaining comments in this + -- test for more detail. + takeMVar shouldProceedRef + + -- Main thread waits here for the signal from the "child" thread telling + -- us the DB insert has been performed. More specifically, we know the + -- following events have occurred in the "child" thread after this + -- @takeMVar@ call succeeds: + -- + -- 1) The @alterBackend@ hook was executed + -- 2) The @runBefore@ hook was executed + -- 3) The insert of our test data was executed + -- 4) Execution is blocked right after the insert, so either of the + -- @runOnException@ or @runAfter@ hooks have not yet been executed. + takeMVar insertDoneRef + + -- Verify that the actual hook execution in the "child" thread is as + -- described previously. + hookCountRefs `hookCountsShouldBe` + HookCounts + { alterBackendCount = 1 + , runBeforeCount = 1 + , runOnExceptionCount = 0 + , runAfterCount = 0 + } + + -- Main thread kills the "child" thread via async exception while the + -- "child" thread is still in its user-specified DB action, which should + -- cause the @runOnException@ hook to fire, rolling back the transaction. + -- + -- Note that the @runOnException@ hook produced by @mkRunConnArgs@ also + -- ensures the handler's masking state is uninterruptible. See + -- @mkRunConnArgs@ for that check's implementation. + killThread threadId + + -- Verify that the @runOnException@ hook was indeed executed. + hookCountRefs `hookCountsShouldBe` + HookCounts + { alterBackendCount = 1 + , runBeforeCount = 1 + , runOnExceptionCount = 1 + , runAfterCount = 0 + } + +-- | Build a 'RunConnArgs' value for use in this module's specs. +-- +-- This function should only be called from the main thread. +mkRunConnArgs + :: forall m + . (MonadIO m) + => HookCountRefs + -> m (RunConnArgs m) +mkRunConnArgs hookCountRefs = do + threadId <- liftIO myThreadId + pure $ (defaultRunConnArgs @m) + { sqlPoolHooks = + trackHookCounts hookCountRefs (sqlPoolHooks defaultRunConnArgs) + & flip modifyRunOnException (\origRunOnException conn level ex -> do + -- It's sneaky to make this masking state assertion here rather + -- than explicitly in a spec. At this time, it feels a bit cleaner + -- to keep this assertion tucked away in here. The downside is + -- that this function does not run in the main thread, so we must + -- throw an expectation failure into the main thread on assertion + -- failure to have it reported by Hspec. + liftIO $ + getMaskingState >>= \case + MaskedUninterruptible -> pure () + _ -> + throwExpectationFailureTo + threadId + "Expected runOnException masking to be uninterruptible" + + origRunOnException conn level ex + ) + } + +throwExpectationFailureTo + :: HasCallStack + => ThreadId + -> String + -> IO () +throwExpectationFailureTo threadId msg = + throwTo threadId $ HUnitFailure location $ Reason msg + +location :: HasCallStack => Maybe SrcLoc +location = case reverse $ getCallStack callStack of + (_, loc) : _ -> Just loc + [] -> Nothing diff --git a/persistent-postgresql/test/PgInit.hs b/persistent-postgresql/test/PgInit.hs index cdad410fb..31a50232e 100644 --- a/persistent-postgresql/test/PgInit.hs +++ b/persistent-postgresql/test/PgInit.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -10,6 +11,10 @@ module PgInit , runConnAssert , runConnAssertUseConf + , runConnUsing + , defaultRunConnArgs + , RunConnArgs(..) + , MonadIO , persistSettings , MkPersistSettings (..) @@ -119,12 +124,15 @@ import Data.Maybe (fromMaybe) import Data.Monoid ((<>)) import Data.Text (Text) import Data.Vector (Vector) +import Database.PostgreSQL.Simple (SqlError(SqlError)) import System.Environment (getEnvironment) import System.Log.FastLogger (fromLogStr) import Database.Persist import Database.Persist.Postgresql import Database.Persist.Sql +import Database.Persist.SqlBackend.SqlPoolHooks + (SqlPoolHooks, defaultSqlPoolHooks) import Database.Persist.TH () _debugOn :: Bool @@ -144,7 +152,7 @@ runConn :: MonadUnliftIO m => SqlPersistT (LoggingT m) t -> m () runConn f = runConn_ f >>= const (return ()) runConn_ :: MonadUnliftIO m => SqlPersistT (LoggingT m) t -> m t -runConn_ f = runConnInternal RunConnBasic f +runConn_ f = runConnUsing defaultRunConnArgs f -- | Data type to switch between pool creation functions, to ease testing both. data RunConnType = @@ -152,8 +160,28 @@ data RunConnType = | RunConnConf -- ^ Use 'withPostgresqlPoolWithConf' deriving (Show, Eq) -runConnInternal :: MonadUnliftIO m => RunConnType -> SqlPersistT (LoggingT m) t -> m t -runConnInternal connType f = do +data RunConnArgs m = RunConnArgs + { connType :: RunConnType + , sqlPoolHooks :: SqlPoolHooks (LoggingT m) SqlBackend + , level :: Maybe IsolationLevel + , shouldRetry :: SomeException -> Bool + } + +defaultRunConnArgs :: forall m . (MonadIO m) => RunConnArgs m +defaultRunConnArgs = + RunConnArgs + { connType = RunConnBasic + , sqlPoolHooks = defaultSqlPoolHooks + , level = Nothing + , shouldRetry = const False + } + +runConnUsing + :: MonadUnliftIO m + => RunConnArgs m + -> SqlPersistT (LoggingT m) t + -> m t +runConnUsing RunConnArgs { connType, sqlPoolHooks, level, shouldRetry } action = do travis <- liftIO isTravis let debugPrint = not travis && _debugOn printDebug = if debugPrint then print . fromLogStr else void . return @@ -170,7 +198,13 @@ runConnInternal connType f = do let go = case connType of RunConnBasic -> - withPostgresqlPool connString poolSize $ runSqlPool f + withPostgresqlPool connString poolSize $ \pool -> do + runSqlPoolWithExtensibleHooksRetry + shouldRetry + action + pool + level + sqlPoolHooks RunConnConf -> do let conf = PostgresConf { pgConnStr = connString @@ -178,22 +212,30 @@ runConnInternal connType f = do , pgPoolIdleTimeout = 60 , pgPoolSize = poolSize } - hooks = defaultPostgresConfHooks - withPostgresqlPoolWithConf conf hooks (runSqlPool f) + pgConfHooks = defaultPostgresConfHooks + withPostgresqlPoolWithConf conf pgConfHooks $ \pool -> do + runSqlPoolWithExtensibleHooksRetry + shouldRetry + action + pool + level + sqlPoolHooks -- horrifying hack :( postgresql is having weird connection failures in -- CI, for no reason that i can determine. see this PR for notes: -- https://github.com/yesodweb/persistent/pull/1197 eres <- try go case eres of - Left (err :: SomeException) -> do - eres' <- try go - case eres' of - Left (err' :: SomeException) -> - if show err == show err' - then throwIO err - else throwIO err' - Right a -> - pure a + Left (err :: SomeException) + | isSqlError err -> throwIO err -- throw, rather than trying the action again + | otherwise -> do + eres' <- try go + case eres' of + Left (err' :: SomeException) -> + if show err == show err' + then throwIO err + else throwIO err' + Right a -> + pure a Right a -> pure a @@ -204,7 +246,14 @@ runConnAssert actions = do -- | Like runConnAssert, but uses the "conf" flavor of functions to test that code path. runConnAssertUseConf :: SqlPersistT (LoggingT (ResourceT IO)) () -> Assertion runConnAssertUseConf actions = do - runResourceT $ runConnInternal RunConnConf (actions >> transactionUndo) + runResourceT + $ runConnUsing defaultRunConnArgs { connType = RunConnConf } + $ actions >> transactionUndo + +isSqlError :: SomeException -> Bool +isSqlError ex + | Just SqlError {} <- fromException ex = True + | otherwise = False newtype AValue = AValue { getValue :: Value } diff --git a/persistent-postgresql/test/RetryableTransactionsTest.hs b/persistent-postgresql/test/RetryableTransactionsTest.hs new file mode 100644 index 000000000..7d5b99ef9 --- /dev/null +++ b/persistent-postgresql/test/RetryableTransactionsTest.hs @@ -0,0 +1,249 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module RetryableTransactionsTest + ( specs + ) where + +import Control.Concurrent (threadDelay) +import Data.Foldable (find) +import qualified Data.Text as Text +import Database.Persist.Postgresql (isSerializationFailure) +import HookCounts + ( HookCounts(HookCounts, alterBackendCount, runAfterCount, runBeforeCount, runOnExceptionCount) + , hookCountsShouldBe + , newHookCountRefs + , trackHookCounts + ) +import Init (IsolationLevel(Serializable), aroundAll_, guard) +import PgInit + ( Filter + , MonadIO(..) + , PersistQueryWrite(deleteWhere) + , ReaderT + , RunConnArgs(level, shouldRetry, sqlPoolHooks) + , Single(unSingle) + , Spec + , SqlBackend + , Text + , defaultRunConnArgs + , describe + , expectationFailure + , get + , insert + , it + , mkMigrate + , mkPersist + , persistLowerCase + , rawSql + , runConnUsing + , runConn_ + , runMigrationSilent + , share + , shouldReturn + , sqlSettings + , update + , void + , (+=.) + , (-=.) + ) +import UnliftIO.Async (Concurrently(Concurrently, runConcurrently)) +import UnliftIO.Exception (bracket_) +import UnliftIO.STM (atomically, newTVarIO, readTVar, writeTVar) +import UnliftIO.Timeout (timeout) + +share + [mkPersist sqlSettings, mkMigrate "retryableTransactionsTestMigrate"] + [persistLowerCase| + RetryableTransactionsTestData + stuff Text + things Int + Primary stuff + deriving Eq Show + |] + +setup :: IO () +setup = runConn_ $ void $ runMigrationSilent retryableTransactionsTestMigrate + +teardown :: IO () +teardown = runConn_ cleanDB + +cleanDB :: forall m. (MonadIO m) => ReaderT SqlBackend m () +cleanDB = deleteWhere ([] :: [Filter RetryableTransactionsTestData]) + +specs :: Spec +specs = aroundAll_ (bracket_ setup teardown) $ do + describe "Testing retryable transactions" $ do + it "serializable isolation" $ do + hookCountRefs <- newHookCountRefs + let runConnArgs = + (defaultRunConnArgs @IO) + { level = Just Serializable + , shouldRetry = isSerializationFailure + , sqlPoolHooks = + trackHookCounts hookCountRefs + $ sqlPoolHooks defaultRunConnArgs + } + + child1WithinTxRef <- newTVarIO False + child1ShouldUpdateRef <- newTVarIO False + child1UpdateDoneRef <- newTVarIO False + child1ShouldCommitRef <- newTVarIO False + + child2WithinTxRef <- newTVarIO False + child2ShouldUpdateRef <- newTVarIO False + + -- From the main thread, insert a row for subsequent use from the spawned + -- threads. + key <- runConnUsing runConnArgs $ do + insert $ RetryableTransactionsTestData "bloorp" 42 + + -- This test launches and waits for three threads. The first two threads + -- perform an update on the same row. The third thread coordinates when + -- the first two threads can proceed with their steps. This test will + -- reproduce a database-level serialization error from thread 2, and so + -- thread 2's transaction should be retried. While it isn't much code, the + -- sequence is nuanced. The exact sequence is as follows: + -- + -- 1) Threads 1 and 2 each start up a serializable transaction and + -- indicate to thread 3 that they have done so. + -- 2) Threads 1 and 2 await a go-ahead from thread 3 for them to proceed. + -- 3) When thread 3 receives the signals indicating threads 1 and 2 are + -- currently within transactions, it signals to thread 1 that it may + -- proceed with its update. + -- 4) Thread 1 performs its update and indicates to thread 3 that it has + -- done so. + -- 5) Thread 1 awaits a go-ahead from thread 3 for it to commit its + -- transaction. + -- 6) When thread 3 receives the signal indicating thread 1 has performed + -- its update, it signals to thread 2 that it may proceed with its update. + -- 7) Thread 2 attempts to perform its update. This update is blocking due + -- to serializable isolation, i.e. thread 1 has already performed an + -- update on the row and now thread 2 is attempting to update the same + -- row. + -- 8) Thread 3 polls the database for a signal indicating that thread 2's + -- update statement is blocked. + -- 9) When thread 3 receives the signal indicating thread 2's update + -- statement is indeed blocked, it signals to thread 1 that it may commit + -- its transaction. + -- 10) Thread 1 commmits its transaction. The database then reports a + -- serialization error in thread 2's open transaction. With support in + -- persistent for retryable transactions, this error can be detected and + -- thread 2's transaction can be retried immediately. The subsequent retry + -- of the transaction will complete successfully, as there will be no + -- other concurrent transactions trying to update the same row this time + -- around. + mTimeoutRes <- timeout 10000000 $ runConcurrently $ + (\() () () -> ()) + <$> Concurrently + ( runConnUsing runConnArgs $ do + liftIO $ atomically $ writeTVar child1WithinTxRef True + liftIO $ atomically $ guard =<< readTVar child1ShouldUpdateRef + update key [RetryableTransactionsTestDataThings -=. 1] + liftIO $ atomically $ writeTVar child1UpdateDoneRef True + liftIO $ atomically $ guard =<< readTVar child1ShouldCommitRef + ) + <*> Concurrently + ( runConnUsing runConnArgs $ do + liftIO $ atomically $ writeTVar child2WithinTxRef True + liftIO $ atomically $ guard =<< readTVar child2ShouldUpdateRef + update key [RetryableTransactionsTestDataThings +=. 1] + ) + <*> Concurrently + ( do + atomically $ do + child1WithinTx <- readTVar child1WithinTxRef + child2WithinTx <- readTVar child2WithinTxRef + guard $ child1WithinTx && child2WithinTx + writeTVar child1ShouldUpdateRef True + + -- Check hook execution counts, verifying the following: + -- 1) The main thread completed a full transaction when it + -- inserted the test data. This contributes 1 towards + -- @alterBackendCount@, 1 towards @runBeforeCount@, and 1 + -- towards @runAfterCount@). + -- 2) Child threads 1 and 2 both have started a transaction. + -- These child threads each contribute 1 towards + -- @alterBackendCount@ and 1 towards @runBeforeCount@. + hookCountRefs `hookCountsShouldBe` + HookCounts + { alterBackendCount = 3 + , runBeforeCount = 3 + , runOnExceptionCount = 0 + , runAfterCount = 1 + } + + atomically $ do + guard =<< readTVar child1UpdateDoneRef + writeTVar child2ShouldUpdateRef True + + pollForBlockedQuery runConnArgs $ Text.unwords + [ "UPDATE \"retryable_transactions_test_data\"" + , "SET \"things\"=\"things\"+1" + , "WHERE \"stuff\"='bloorp' " + ] + + atomically $ do + writeTVar child1ShouldCommitRef True + + -- Check hook execution counts, verifying the following: + -- 1) The counts checked previously were preserved. + -- 2) Child thread 1 completed its transaction, contributing + -- 1 towards @runAfterCount@. + -- 3) Child thread 2 retried its transaction on encountering + -- a serialization failure, so it contributes 1 to + -- @runOnExceptionCount@, then 1 each for @runBeforeCount@ + -- and @runAfterCount@ for the new transaction. + hookCountRefs `hookCountsShouldBe` + HookCounts + { alterBackendCount = 4 + , runBeforeCount = 4 + , runOnExceptionCount = 1 + , runAfterCount = 3 + } + ) + + case mTimeoutRes of + Nothing -> + expectationFailure "Serializable isolation test threads took too long" + Just () -> pure () + + runConnUsing runConnArgs (get key) + `shouldReturn` Just (RetryableTransactionsTestData "bloorp" 42) + +pollForBlockedQuery :: RunConnArgs IO -> Text -> IO () +pollForBlockedQuery runConnArgs targetBlockedQuery = do + timeout 10000000 go >>= \case + Nothing -> expectationFailure "pollForBlockedQuery: took too long" + Just () -> pure () + where + go = do + blockedQueries :: [Text] <- + fmap (fmap unSingle) $ runConnUsing runConnArgs $ do + rawSql query [] + case find (== targetBlockedQuery) blockedQueries of + Nothing -> + threadDelay 200000 *> pollForBlockedQuery runConnArgs targetBlockedQuery + Just {} -> pure () + + query = + Text.unwords + [ "select query" + , "from pg_stat_activity" + , "where cardinality(pg_blocking_pids(pid)) > 0" + ] diff --git a/persistent-postgresql/test/main.hs b/persistent-postgresql/test/main.hs index c00650ac0..6e7db85b6 100644 --- a/persistent-postgresql/test/main.hs +++ b/persistent-postgresql/test/main.hs @@ -23,6 +23,7 @@ import Data.Time import Test.QuickCheck import qualified ArrayAggTest +import qualified AsyncExceptionsTest import qualified CompositeTest import qualified CustomConstraintTest import qualified CustomPersistFieldTest @@ -47,14 +48,15 @@ import qualified MigrationReferenceSpec import qualified MigrationTest import qualified MpsCustomPrefixTest import qualified MpsNoPrefixTest -import qualified PersistUniqueTest import qualified PersistentTest +import qualified PersistUniqueTest import qualified PgIntervalTest import qualified PrimaryTest import qualified RawSqlTest import qualified ReadWriteTest import qualified Recursive import qualified RenameTest +import qualified RetryableTransactionsTest import qualified SumTypeTest import qualified TransactionLevelTest import qualified TreeTest @@ -214,3 +216,5 @@ main = do PgIntervalTest.specs ArrayAggTest.specs GeneratedColumnTestSQL.specsWith runConnAssert + AsyncExceptionsTest.specs + RetryableTransactionsTest.specs diff --git a/persistent-test/persistent-test.cabal b/persistent-test/persistent-test.cabal index 72e900b80..9221d6c56 100644 --- a/persistent-test/persistent-test.cabal +++ b/persistent-test/persistent-test.cabal @@ -57,6 +57,7 @@ library UniqueTest UpsertTest LongIdentifierTest + HookCounts hs-source-dirs: src diff --git a/persistent-test/src/HookCounts.hs b/persistent-test/src/HookCounts.hs new file mode 100644 index 000000000..9280ed225 --- /dev/null +++ b/persistent-test/src/HookCounts.hs @@ -0,0 +1,112 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +module HookCounts + ( hookCountsShouldBe + + , HookCountRefs(..) + , newHookCountRefs + + , HookCounts(..) + + , trackHookCounts + ) where + +import Control.Monad.IO.Unlift (MonadIO(liftIO)) +import Data.Function ((&)) +import Database.Persist.SqlBackend.SqlPoolHooks + ( SqlPoolHooks + , modifyAlterBackend + , modifyRunAfter + , modifyRunBefore + , modifyRunOnException + ) +import Init (Expectation, HasCallStack, expectationFailure, guard) +import UnliftIO.STM (STM, TVar, atomically, modifyTVar', newTVarIO, readTVar) +import UnliftIO.Timeout (timeout) + +hookCountsShouldBe :: HasCallStack => HookCountRefs -> HookCounts -> Expectation +hookCountsShouldBe hookCountRefs hookCounts = + checkHookCounts hookCountRefs (== hookCounts) + +checkHookCounts + :: HasCallStack + => HookCountRefs + -> (HookCounts -> Bool) + -> Expectation +checkHookCounts hookCountRefs p = do + -- The input predicate can cause the STM transaction to retry, so the STM + -- computation is wrapped in a timeout of 10 seconds in case the STM + -- transaction never completes. + mResult <- timeout 10000000 $ atomically $ do + hookCounts <- hookCountsSTM hookCountRefs + guard $ p hookCounts + case mResult of + Nothing -> expectationFailure "checkHookCounts: took too long" + Just () -> pure () + +data HookCountRefs = HookCountRefs + { alterBackendCountRef :: TVar Int + , runBeforeCountRef :: TVar Int + , runOnExceptionCountRef :: TVar Int + , runAfterCountRef :: TVar Int + } + +newHookCountRefs :: IO HookCountRefs +newHookCountRefs = + HookCountRefs + <$> newTVarIO 0 + <*> newTVarIO 0 + <*> newTVarIO 0 + <*> newTVarIO 0 + +hookCountsSTM :: HookCountRefs -> STM HookCounts +hookCountsSTM hookCountRefs = + HookCounts + <$> readTVar (alterBackendCountRef hookCountRefs) + <*> readTVar (runBeforeCountRef hookCountRefs) + <*> readTVar (runOnExceptionCountRef hookCountRefs) + <*> readTVar (runAfterCountRef hookCountRefs) + +data HookCounts = HookCounts + { alterBackendCount :: Int + , runBeforeCount :: Int + , runOnExceptionCount :: Int + , runAfterCount :: Int + } deriving stock (Eq, Show) + +trackHookCounts + :: forall m backend + . (MonadIO m) + => HookCountRefs + -> SqlPoolHooks m backend + -> SqlPoolHooks m backend +trackHookCounts hookCountRefs sqlPoolHooks = + sqlPoolHooks + & flip modifyAlterBackend (\origAlterBackend conn -> do + bumpCount alterBackendCountRef + origAlterBackend conn + ) + & flip modifyRunBefore (\origRunBefore conn level -> do + bumpCount runBeforeCountRef + origRunBefore conn level + ) + & flip modifyRunOnException (\origRunOnException conn level ex -> do + bumpCount runOnExceptionCountRef + origRunOnException conn level ex + ) + & flip modifyRunAfter (\origRunAfter conn level -> do + bumpCount runAfterCountRef + origRunAfter conn level + ) + where + bumpCount :: TVar Int -> m () + bumpCount countRef = do + liftIO $ atomically $ modifyTVar' countRef (+ 1) + + HookCountRefs + { alterBackendCountRef + , runBeforeCountRef + , runOnExceptionCountRef + , runAfterCountRef + } = hookCountRefs diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index 2de90fae7..ec59058fc 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -1,5 +1,13 @@ # Changelog for persistent +## 2.14.6.0 + +* [#1482](https://github.com/yesodweb/persistent/pull/1482) + * Ensure `runOnException` hook is run when user-specified database action is + interrupted via async exception + * Add `runSqlPoolWithExtensibleHooksRetry` to support automatic transaction + retrying on user-specified synchronous exceptions + ## 2.14.5.1 * [#1496](https://github.com/yesodweb/persistent/pull/1496) diff --git a/persistent/Database/Persist/Sql/Run.hs b/persistent/Database/Persist/Sql/Run.hs index 46ea85df0..dc6e3acd9 100644 --- a/persistent/Database/Persist/Sql/Run.hs +++ b/persistent/Database/Persist/Sql/Run.hs @@ -1,11 +1,12 @@ -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} module Database.Persist.Sql.Run where +import Control.Monad (void) import Control.Monad.IO.Unlift import Control.Monad.Logger.CallStack -import Control.Monad (void) import Control.Monad.Reader (MonadReader) import qualified Control.Monad.Reader as MonadReader import Control.Monad.Trans.Reader hiding (local) @@ -19,8 +20,9 @@ import Database.Persist.Class.PersistStore import Database.Persist.Sql.Raw import Database.Persist.Sql.Types import Database.Persist.Sql.Types.Internal -import Database.Persist.SqlBackend.Internal.StatementCache import Database.Persist.SqlBackend.Internal.SqlPoolHooks +import Database.Persist.SqlBackend.Internal.StatementCache +import Database.Persist.SqlBackend.SqlPoolHooks (mapSqlPoolHooks) -- | Get a connection from the pool, run the given action, and then return the -- connection to the pool. @@ -105,9 +107,7 @@ runSqlPoolWithHooks r pconn i before after onException = -- | This function is how 'runSqlPoolWithHooks' is defined. -- --- It's currently the most general function for using a SQL pool. --- --- @since 2.13.0.0 +-- @since 2.13.3.0 runSqlPoolWithExtensibleHooks :: forall backend m a. (MonadUnliftIO m, BackendCompatible SqlBackend backend) => ReaderT backend m a @@ -115,18 +115,55 @@ runSqlPoolWithExtensibleHooks -> Maybe IsolationLevel -> SqlPoolHooks m backend -> m a -runSqlPoolWithExtensibleHooks r pconn i SqlPoolHooks{..} = - withRunInIO $ \runInIO -> - withResource pconn $ \conn -> - UE.mask $ \restore -> do - conn' <- restore $ runInIO $ alterBackend conn - _ <- restore $ runInIO $ runBefore conn' i - a <- restore (runInIO (runReaderT r conn')) - `UE.catchAny` \e -> do - _ <- restore $ runInIO $ runOnException conn' i e - UE.throwIO e - _ <- restore $ runInIO $ runAfter conn' i - pure a +runSqlPoolWithExtensibleHooks = runSqlPoolWithExtensibleHooksRetry $ const False + +-- | This function is equivalent to 'runSqlPoolWithExtensibleHooks' but +-- additionally allows specifying an exception predicate. On encountering an +-- exception during a transaction, this predicate decides whether or not the +-- transaction should be retried. This can be used to build various retrying +-- schemes, such as retrying on serialization/deadlock errors when running +-- transactions at serializable isolation level. +-- +-- Note that even though the predicate operates on 'UE.SomeException', it is +-- only applied to synchronous exceptions. Asynchronous exceptions are always +-- rethrown and will never trigger a retry of the transaction. +-- +-- Considering @persistent@ abstracts over specific SQL backends, you will +-- likely need to reach for a backend-specific exception type when defining an +-- exception predicate. +-- +-- @since 2.14.6.0 +runSqlPoolWithExtensibleHooksRetry + :: forall backend m a. (MonadUnliftIO m, BackendCompatible SqlBackend backend) + => (UE.SomeException -> Bool) + -> ReaderT backend m a + -> Pool backend + -> Maybe IsolationLevel + -> SqlPoolHooks m backend + -> m a +runSqlPoolWithExtensibleHooksRetry shouldRetry r pconn i hooks = + withRunInIO $ \runInIO -> do + let hooksIO = mapSqlPoolHooks runInIO hooks + withResource pconn $ \conn -> do + UE.mask $ \restore -> do + conn' <- restore $ alterBackend hooksIO conn + loop (runBefore hooksIO conn' i) $ UE.try $ do + a <- restore (runInIO $ runReaderT r conn') + `UE.withException` \e -> do + runOnException hooksIO conn' i e + runAfter hooksIO conn' i + pure a + where + loop :: IO () -> IO (Either UE.SomeException a) -> IO a + loop begin action = go + where + go = begin >> action >>= \case + Left ex -> + if shouldRetry ex + then go + else UE.throwIO ex + Right x -> + pure x rawAcquireSqlConn :: forall backend m diff --git a/persistent/Database/Persist/SqlBackend/Internal/SqlPoolHooks.hs b/persistent/Database/Persist/SqlBackend/Internal/SqlPoolHooks.hs index 556bd736e..52726a243 100644 --- a/persistent/Database/Persist/SqlBackend/Internal/SqlPoolHooks.hs +++ b/persistent/Database/Persist/SqlBackend/Internal/SqlPoolHooks.hs @@ -18,4 +18,8 @@ data SqlPoolHooks m backend = SqlPoolHooks -- ^ This action is performed when an exception is received. The -- exception is provided as a convenience - it is rethrown once this -- cleanup function is complete. + -- + -- Note that this action is run in an @uninterruptibleMask@. If you are + -- overriding this hook, be sure your action can complete in a timely + -- manner. } diff --git a/persistent/Database/Persist/SqlBackend/SqlPoolHooks.hs b/persistent/Database/Persist/SqlBackend/SqlPoolHooks.hs index c180a1d1a..805409ef8 100644 --- a/persistent/Database/Persist/SqlBackend/SqlPoolHooks.hs +++ b/persistent/Database/Persist/SqlBackend/SqlPoolHooks.hs @@ -1,6 +1,8 @@ +{-# LANGUAGE RankNTypes #-} module Database.Persist.SqlBackend.SqlPoolHooks ( SqlPoolHooks , defaultSqlPoolHooks + , mapSqlPoolHooks , getAlterBackend , modifyAlterBackend , setAlterBackend @@ -11,16 +13,18 @@ module Database.Persist.SqlBackend.SqlPoolHooks , modifyRunAfter , setRunAfter , getRunOnException + , modifyRunOnException + , setRunOnException ) where import Control.Exception import Control.Monad.IO.Class +import Database.Persist.Class.PersistStore import Database.Persist.Sql.Raw import Database.Persist.SqlBackend.Internal -import Database.Persist.SqlBackend.Internal.SqlPoolHooks import Database.Persist.SqlBackend.Internal.IsolationLevel -import Database.Persist.Class.PersistStore +import Database.Persist.SqlBackend.Internal.SqlPoolHooks -- | Lifecycle hooks that may be altered to extend SQL pool behavior -- in a backwards compatible fashion. @@ -50,6 +54,18 @@ defaultSqlPoolHooks = SqlPoolHooks liftIO $ connRollback sqlBackend getter } +mapSqlPoolHooks + :: (forall x. m x -> n x) + -> SqlPoolHooks m backend + -> SqlPoolHooks n backend +mapSqlPoolHooks f hooks = SqlPoolHooks + { alterBackend = f . alterBackend hooks + , runBefore = \conn mLevel -> f $ runBefore hooks conn mLevel + , runAfter = \conn mLevel -> f $ runAfter hooks conn mLevel + , runOnException = \conn mLevel ex -> + f $ runOnException hooks conn mLevel ex + } + getAlterBackend :: SqlPoolHooks m backend -> (backend -> m backend) getAlterBackend = alterBackend diff --git a/persistent/persistent.cabal b/persistent/persistent.cabal index 857cd8b47..845eb530c 100644 --- a/persistent/persistent.cabal +++ b/persistent/persistent.cabal @@ -1,5 +1,5 @@ name: persistent -version: 2.14.5.1 +version: 2.14.6.0 license: MIT license-file: LICENSE author: Michael Snoyman