Skip to content

Commit 6a5b1f2

Browse files
Merge pull request #223 from kakkun61/hooks
Introducing hooks
2 parents 0cd6de7 + 601703a commit 6a5b1f2

10 files changed

+204
-36
lines changed

hedis.cabal

+20-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ library
108108
Database.Redis.Commands,
109109
Database.Redis.ManualCommands,
110110
Database.Redis.URL,
111-
Database.Redis.ConnectionContext
111+
Database.Redis.ConnectionContext,
112+
Database.Redis.Hooks
112113
other-extensions: StrictData
113114

114115
benchmark hedis-benchmark
@@ -203,3 +204,21 @@ test-suite hedis-test-cluster
203204
ghc-options: -Werror
204205
if flag(dev)
205206
ghc-prof-options: -auto-all
207+
208+
test-suite hedis-test-hooks
209+
default-language: Haskell2010
210+
type: exitcode-stdio-1.0
211+
hs-source-dirs: test
212+
main-is: MainHooks.hs
213+
build-depends:
214+
base == 4.*,
215+
hedis,
216+
HUnit,
217+
test-framework,
218+
test-framework-hunit
219+
-- We use -O0 here, since GHC takes *very* long to compile so many constants
220+
ghc-options: -O0 -Wall -rtsopts -fno-warn-unused-do-bind -Wunused-packages
221+
if flag(dev)
222+
ghc-options: -Werror
223+
if flag(dev)
224+
ghc-prof-options: -auto-all

src/Database/Redis.hs

+3
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ module Database.Redis (
176176
-- * Pub\/Sub
177177
module Database.Redis.PubSub,
178178

179+
-- * Hooks
180+
Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook, defaultHooks,
181+
179182
-- * Low-Level Command API
180183
sendRequest,
181184
Reply(..), Status(..), RedisArg(..), RedisResult(..), ConnectionLostException(..),

src/Database/Redis/Cluster.hs

+17-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module Database.Redis.Cluster
1515
, disconnect
1616
, requestPipelined
1717
, nodes
18+
, hooks
1819
) where
1920

2021
import qualified Data.ByteString as B
@@ -36,6 +37,7 @@ import System.IO.Unsafe(unsafeInterleaveIO)
3637

3738
import Database.Redis.Protocol(Reply(Error), renderRequest, reply)
3839
import qualified Database.Redis.Cluster.Command as CMD
40+
import Database.Redis.Hooks (Hooks)
3941

4042
-- This module implements a clustered connection whilst maintaining
4143
-- compatibility with the original Hedis codebase. In particular it still
@@ -48,7 +50,7 @@ import qualified Database.Redis.Cluster.Command as CMD
4850

4951
-- | A connection to a redis cluster, it is compoesed of a map from Node IDs to
5052
-- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap'
51-
data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap
53+
data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap Hooks
5254

5355
-- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection'
5456
data NodeConnection = NodeConnection CC.ConnectionContext (IOR.IORef (Maybe B.ByteString)) NodeID
@@ -100,13 +102,13 @@ instance Exception UnsupportedClusterCommandException
100102
newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable)
101103
instance Exception CrossSlotException
102104

103-
connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection
104-
connect commandInfos shardMapVar timeoutOpt = do
105+
connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Hooks -> IO Connection
106+
connect commandInfos shardMapVar timeoutOpt hooks' = do
105107
shardMap <- readMVar shardMapVar
106108
stateVar <- newMVar $ Pending []
107109
pipelineVar <- newMVar $ Pipeline stateVar
108110
nodeConns <- nodeConnections shardMap
109-
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) where
111+
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) hooks' where
110112
nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection)
111113
nodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap)
112114
connectNode :: Node -> IO (NodeID, NodeConnection)
@@ -116,14 +118,14 @@ connect commandInfos shardMapVar timeoutOpt = do
116118
return (n, NodeConnection ctx ref n)
117119

118120
disconnect :: Connection -> IO ()
119-
disconnect (Connection nodeConnMap _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where
121+
disconnect (Connection nodeConnMap _ _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where
120122
disconnectNode (NodeConnection nodeCtx _ _) = CC.disconnect nodeCtx
121123

122124
-- Add a request to the current pipeline for this connection. The pipeline will
123125
-- be executed implicitly as soon as any result returned from this function is
124126
-- evaluated.
125127
requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply
126-
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
128+
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
127129
(newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case
128130
Pending requests | isMulti nextRequest -> do
129131
replies <- evaluatePipeline shardMapVar refreshAction conn requests
@@ -228,7 +230,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
228230
-- there is one.
229231
case last replies of
230232
(Error errString) | B.isPrefixOf "MOVED" errString -> do
231-
let (Connection _ _ _ infoMap) = conn
233+
let (Connection _ _ _ infoMap _) = conn
232234
keys <- mconcat <$> mapM (requestKeys infoMap) requests
233235
hashSlot <- hashSlotForKeys (CrossSlotException requests) keys
234236
nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot
@@ -250,7 +252,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
250252
evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply]
251253
evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do
252254
let requests = reverse requests'
253-
let (Connection _ _ _ infoMap) = conn
255+
let (Connection _ _ _ infoMap _) = conn
254256
keys <- mconcat <$> mapM (requestKeys infoMap) requests
255257
-- In cluster mode Redis expects commands in transactions to all work on the
256258
-- same hashslot. We find that hashslot here.
@@ -296,7 +298,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d
296298

297299
nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection
298300
nodeConnForHashSlot shardMapVar conn exception hashSlot = do
299-
let (Connection nodeConns _ _ _) = conn
301+
let (Connection nodeConns _ _ _ _) = conn
300302
(ShardMap shardMap) <- hasLocked $ readMVar shardMapVar
301303
node <-
302304
case IntMap.lookup (fromEnum hashSlot) shardMap of
@@ -339,12 +341,12 @@ moved _ = False
339341

340342

341343
nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection
342-
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _) host port = do
344+
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do
343345
node <- nodeWithHostAndPort shardMap host port
344346
HM.lookup (nodeId node) nodeConns
345347

346348
nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection]
347-
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shardMap) request =
349+
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request =
348350
case request of
349351
("FLUSHALL" : _) -> allNodes
350352
("FLUSHDB" : _) -> allNodes
@@ -364,7 +366,7 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shard
364366
Just allNodes' -> return allNodes'
365367

366368
allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection]
367-
allMasterNodes (Connection nodeConns _ _ _) (ShardMap shardMap) =
369+
allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) =
368370
mapM (flip HM.lookup nodeConns . nodeId) masterNodes
369371
where
370372
masterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap)
@@ -410,3 +412,6 @@ hasLocked action =
410412
action `catches`
411413
[ Handler $ \exc@BlockedIndefinitelyOnMVar -> throwIO exc
412414
]
415+
416+
hooks :: Connection -> Hooks
417+
hooks (Connection _ _ _ _ h) = h

src/Database/Redis/Connection.hs

+8-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import qualified Network.Socket as NS
3131
import qualified Data.HashMap.Strict as HM
3232

3333
import qualified Database.Redis.ProtocolPipelining as PP
34-
import Database.Redis.Core(Redis, runRedisInternal, runRedisClusteredInternal)
34+
import Database.Redis.Core(Redis, Hooks, runRedisInternal, runRedisClusteredInternal, defaultHooks)
3535
import Database.Redis.Protocol(Reply(..))
3636
import Database.Redis.Cluster(ShardMap(..), Node, Shard(..))
3737
import qualified Database.Redis.Cluster as Cluster
@@ -97,6 +97,7 @@ data ConnectInfo = ConnInfo
9797
-- get connected in this interval of time.
9898
, connectTLSParams :: Maybe ClientParams
9999
-- ^ Optional TLS parameters. TLS will be enabled if this is provided.
100+
, connectHooks :: Hooks
100101
} deriving Show
101102

102103
data ConnectError = ConnectAuthError Reply
@@ -117,6 +118,7 @@ instance Exception ConnectError
117118
-- connectMaxIdleTime = 30 -- Keep open for 30 seconds
118119
-- connectTimeout = Nothing -- Don't add timeout logic
119120
-- connectTLSParams = Nothing -- Do not use TLS
121+
-- connectHooks = defaultHooks -- Do nothing
120122
-- @
121123
--
122124
defaultConnectInfo :: ConnectInfo
@@ -130,13 +132,14 @@ defaultConnectInfo = ConnInfo
130132
, connectMaxIdleTime = 30
131133
, connectTimeout = Nothing
132134
, connectTLSParams = Nothing
135+
, connectHooks = defaultHooks
133136
}
134137

135138
createConnection :: ConnectInfo -> IO PP.Connection
136139
createConnection ConnInfo{..} = do
137140
let timeoutOptUs =
138141
round . (1000000 *) <$> connectTimeout
139-
conn <- PP.connect connectHost connectPort timeoutOptUs
142+
conn <- PP.connectWithHooks connectHost connectPort timeoutOptUs connectHooks
140143
conn' <- case connectTLSParams of
141144
Nothing -> return conn
142145
Just tlsParams -> PP.enableTLS tlsParams conn
@@ -231,9 +234,9 @@ connectCluster bootstrapConnInfo = do
231234
Left e -> throwIO $ ClusterConnectError e
232235
Right infos -> do
233236
#if MIN_VERSION_resource_pool(0,3,0)
234-
pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
237+
pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
235238
#else
236-
pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
239+
pool <- createPool (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
237240
#endif
238241
return $ ClusteredConnection shardMapVar pool
239242

@@ -255,7 +258,7 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m
255258
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort)
256259

257260
refreshShardMap :: Cluster.Connection -> IO ShardMap
258-
refreshShardMap (Cluster.Connection nodeConns _ _ _) = do
261+
refreshShardMap (Cluster.Connection nodeConns _ _ _ _) = do
259262
let (Cluster.NodeConnection ctx _ _) = head $ HM.elems nodeConns
260263
pipelineConn <- PP.fromCtx ctx
261264
_ <- PP.beginReceiving pipelineConn

src/Database/Redis/Core.hs

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
module Database.Redis.Core (
66
Redis(), unRedis, reRedis,
77
RedisCtx(..), MonadRedis(..),
8+
Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook,
89
send, recv, sendRequest,
910
runRedisInternal,
1011
runRedisClusteredInternal,
12+
defaultHooks,
1113
RedisEnv(..),
1214
) where
1315

@@ -24,6 +26,7 @@ import qualified Database.Redis.ProtocolPipelining as PP
2426
import Database.Redis.Types
2527
import Database.Redis.Cluster(ShardMap)
2628
import qualified Database.Redis.Cluster as Cluster
29+
import Database.Redis.Hooks
2730

2831
--------------------------------------------------------------------------------
2932
-- The Redis Monad
@@ -118,8 +121,8 @@ sendRequest req = do
118121
env <- ask
119122
case env of
120123
NonClusteredEnv{..} -> do
121-
r <- liftIO $ PP.request envConn (renderRequest req)
124+
r <- liftIO $ sendRequestHook (PP.hooks envConn) (PP.request envConn . renderRequest) req
122125
setLastReply r
123126
return r
124-
ClusteredEnv{..} -> liftIO $ Cluster.requestPipelined refreshAction connection req
127+
ClusteredEnv{..} -> liftIO $ sendRequestHook (Cluster.hooks connection) (Cluster.requestPipelined refreshAction connection) req
125128
returnDecode r'

src/Database/Redis/Hooks.hs

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module Database.Redis.Hooks where
2+
3+
import Data.ByteString (ByteString)
4+
import Database.Redis.Protocol (Reply)
5+
import {-# SOURCE #-} Database.Redis.PubSub (Message, PubSub)
6+
7+
data Hooks =
8+
Hooks
9+
{ sendRequestHook :: SendRequestHook
10+
, sendPubSubHook :: SendPubSubHook
11+
, callbackHook :: CallbackHook
12+
, sendHook :: SendHook
13+
, receiveHook :: ReceiveHook
14+
}
15+
16+
-- | A hook for sending commands to the server and receiving replys from the server.
17+
type SendRequestHook = ([ByteString] -> IO Reply) -> [ByteString] -> IO Reply
18+
19+
-- | A hook for sending pub/sub messages to the server.
20+
type SendPubSubHook = ([ByteString] -> IO ()) -> [ByteString] -> IO ()
21+
22+
-- | A hook for invoking callbacks with pub/sub messages.
23+
type CallbackHook = (Message -> IO PubSub) -> Message -> IO PubSub
24+
25+
-- | A hook for just sending raw data to the server.
26+
type SendHook = (ByteString -> IO ()) -> ByteString -> IO ()
27+
28+
-- | A hook for receiving raw data from the server.
29+
type ReceiveHook = IO Reply -> IO Reply
30+
31+
-- | The default hooks.
32+
-- Every hook is the identity function.
33+
defaultHooks :: Hooks
34+
defaultHooks =
35+
Hooks
36+
{ sendRequestHook = id
37+
, sendPubSubHook = id
38+
, callbackHook = id
39+
, sendHook = id
40+
, receiveHook = id
41+
}
42+
43+
instance Show Hooks where
44+
show _ = "Hooks {sendRequestHook = _, sendPubSubHook = _, callbackHook = _, sendHook = _, receiveHook = _}"

src/Database/Redis/ProtocolPipelining.hs

+14-8
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
--
1818
module Database.Redis.ProtocolPipelining (
1919
Connection,
20-
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx
20+
connect, connectWithHooks, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx, hooks
2121
) where
2222

2323
import Prelude
@@ -31,6 +31,7 @@ import System.IO.Unsafe
3131

3232
import Database.Redis.Protocol
3333
import qualified Database.Redis.ConnectionContext as CC
34+
import Database.Redis.Hooks
3435

3536
data Connection = Conn
3637
{ connCtx :: CC.ConnectionContext -- ^ Connection socket-handle.
@@ -42,14 +43,18 @@ data Connection = Conn
4243
-- ^ Number of pending replies and thus the difference length between
4344
-- 'connReplies' and 'connPending'.
4445
-- length connPending - pendingCount = length connReplies
46+
, hooks :: Hooks
4547
}
4648

4749

4850
fromCtx :: CC.ConnectionContext -> IO Connection
49-
fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0
51+
fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0 <*> pure defaultHooks
5052

5153
connect :: NS.HostName -> CC.PortID -> Maybe Int -> IO Connection
52-
connect hostName portId timeoutOpt = do
54+
connect hostName portId timeoutOpt = connectWithHooks hostName portId timeoutOpt defaultHooks
55+
56+
connectWithHooks :: NS.HostName -> CC.PortID -> Maybe Int -> Hooks -> IO Connection
57+
connectWithHooks hostName portId timeoutOpt hooks = do
5358
connCtx <- CC.connect hostName portId timeoutOpt
5459
connReplies <- newIORef []
5560
connPending <- newIORef []
@@ -74,7 +79,7 @@ disconnect Conn{..} = CC.disconnect connCtx
7479
-- The 'Handle' is 'hFlush'ed when reading replies from the 'connCtx'.
7580
send :: Connection -> S.ByteString -> IO ()
7681
send Conn{..} s = do
77-
CC.send connCtx s
82+
sendHook hooks (CC.send connCtx) s
7883

7984
-- Signal that we expect one more reply from Redis.
8085
n <- atomicModifyIORef' connPendingCnt $ \n -> let n' = n+1 in (n', n')
@@ -88,10 +93,11 @@ send Conn{..} s = do
8893

8994
-- |Take a reply-thunk from the list of future replies.
9095
recv :: Connection -> IO Reply
91-
recv Conn{..} = do
92-
(r:rs) <- readIORef connReplies
93-
writeIORef connReplies rs
94-
return r
96+
recv Conn{..} =
97+
receiveHook hooks $ do
98+
(r:rs) <- readIORef connReplies
99+
writeIORef connReplies rs
100+
return r
95101

96102
-- | Flush the socket. Normally, the socket is flushed in 'recv' (actually 'conGetReplies'), but
97103
-- for the multithreaded pub/sub code, the sending thread needs to explicitly flush the subscription

0 commit comments

Comments
 (0)