1
- {-# LANGUAGE NamedFieldPuns #-}
2
- {-# LANGUAGE ScopedTypeVariables #-}
1
+ {-# LANGUAGE DisambiguateRecordFields #-}
2
+ {-# LANGUAGE NamedFieldPuns #-}
3
+ {-# LANGUAGE ScopedTypeVariables #-}
4
+ {-# LANGUAGE TupleSections #-}
3
5
4
6
-- | The module should be imported qualified.
5
7
--
@@ -16,19 +18,19 @@ module Ouroboros.Network.TxSubmission.Mempool.Simple
16
18
import Prelude hiding (read , seq )
17
19
18
20
import Control.Concurrent.Class.MonadSTM.Strict
19
-
21
+ import Data.Bitraversable
22
+ import Data.Either
20
23
import Data.Foldable (toList )
21
24
import Data.Foldable qualified as Foldable
22
- import Data.Function (on )
23
- import Data.List (find , nubBy )
25
+ import Data.List (find )
24
26
import Data.Maybe (isJust )
25
27
import Data.Sequence (Seq )
26
28
import Data.Sequence qualified as Seq
27
29
import Data.Set (Set )
28
30
import Data.Set qualified as Set
29
31
32
+ import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (.. ))
30
33
import Ouroboros.Network.SizeInBytes
31
- import Ouroboros.Network.TxSubmission.Inbound.V2.Types
32
34
import Ouroboros.Network.TxSubmission.Mempool.Reader
33
35
34
36
@@ -98,31 +100,65 @@ getReader getTxId getTxSize (Mempool mempool) =
98
100
f idx tx = (getTxId tx, idx, getTxSize tx)
99
101
100
102
101
- -- | A simple mempool writer.
103
+ -- | A mempool writer which generalizes the tx submission mempool writer
104
+ -- TODO: We could replace TxSubmissionMempoolWriter with this at some point
105
+ --
106
+ data MempoolWriter txid tx failure idx m =
107
+ MempoolWriter {
108
+
109
+ -- | Compute the transaction id from a transaction.
110
+ --
111
+ -- This is used in the protocol handler to verify a full transaction
112
+ -- matches a previously given transaction id.
113
+ --
114
+ txId :: tx -> txid ,
115
+
116
+ -- | Supply a batch of transactions to the mempool. They are either
117
+ -- accepted or rejected individually, but in the order supplied.
118
+ --
119
+ -- The 'txid's of all transactions that were added successfully are
120
+ -- returned.
121
+ mempoolAddTxs :: [tx ] -> m [SubmitResult failure ]
122
+ }
123
+
124
+
125
+ -- | A mempool writer with validation harness
126
+ -- PRECONDITION: no duplicates given to mempoolAddTxs
102
127
--
103
- getWriter :: forall tx txid m .
128
+ getWriter :: forall tx txid tx' failure m .
104
129
( MonadSTM m
105
130
, Ord txid
106
131
)
107
132
=> (tx -> txid )
108
- -> (tx -> Bool )
109
- -- ^ validate a tx
110
- -> Mempool m tx
111
- -> TxSubmissionMempoolWriter txid tx Int m
112
- getWriter getTxId validateTx (Mempool mempool) =
113
- TxSubmissionMempoolWriter {
114
- txId = getTxId,
115
-
116
- mempoolAddTxs = \ txs -> do
117
- atomically $ do
118
- mempoolTxs <- readTVar mempool
119
- let currentIds = Set. fromList (map getTxId (toList mempoolTxs))
120
- validTxs = nubBy (on (==) getTxId)
121
- $ filter
122
- (\ tx -> validateTx tx
123
- && getTxId tx `Set.notMember` currentIds)
124
- txs
125
- mempoolTxs' = Foldable. foldl' (Seq. |>) mempoolTxs validTxs
126
- writeTVar mempool mempoolTxs'
127
- return (map getTxId validTxs)
128
- }
133
+ -- ^ get txid of a tx
134
+ -> ([tx ] -> m [tx' ])
135
+ -- ^ monadic validation context, acquired once prior to all work
136
+ -> (tx' -> Bool -> Either failure () )
137
+ -- ^ validate a tx in an atomic block, any failing `tx` throws an exception.
138
+ -> (failure -> STM m failure )
139
+ -- ^ return `True` when a failure should throw an exception
140
+ -> Mempool m txid tx
141
+ -> MempoolWriter txid tx failure Int m
142
+ getWriter getTxId withContext validateTx failureFilterFn (Mempool mempool) =
143
+ MempoolWriter {
144
+ txId = getTxId,
145
+
146
+ mempoolAddTxs = \ txs -> do
147
+ txs' <- withContext txs
148
+ atomically $ do
149
+ MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
150
+ result <- sequence
151
+ [bimapM (fmap SubmitFail . failureFilterFn) (pure . const (txid, tx)) validated
152
+ | (tx, tx') <- zip txs txs'
153
+ , let txid = getTxId tx
154
+ validated =
155
+ validateTx tx' (txid `Set.member` mempoolSet)
156
+ ]
157
+ let (validIds, validTxs) = unzip . rights $ result
158
+ mempoolTxs' = MempoolSeq {
159
+ mempoolSet = Set. union mempoolSet (Set. fromList validIds),
160
+ mempoolSeq = Foldable. foldl' (Seq. |>) mempoolSeq validTxs
161
+ }
162
+ writeTVar mempool mempoolTxs'
163
+ return $ fromLeft SubmitSuccess <$> result
164
+ }
0 commit comments