Skip to content

Commit 7479fd6

Browse files
test(mempool): add function to assert equality and implement Eq,PartialEq for related structs
1 parent 724ac68 commit 7479fd6

File tree

3 files changed

+28
-27
lines changed

3 files changed

+28
-27
lines changed

crates/mempool/src/mempool_test.rs

+22-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::cmp::Reverse;
12
use std::collections::HashMap;
23

34
use assert_matches::assert_matches;
@@ -33,6 +34,11 @@ impl MempoolState {
3334
let tx_queue: TransactionQueue = queue_txs.into_iter().collect();
3435
MempoolState { tx_pool, tx_queue }
3536
}
37+
38+
fn assert_eq_mempool_state(&self, mempool: &Mempool) {
39+
assert_eq!(self.tx_pool, mempool.tx_pool);
40+
assert_eq!(self.tx_queue, mempool.tx_queue);
41+
}
3642
}
3743

3844
impl From<MempoolState> for Mempool {
@@ -139,35 +145,31 @@ fn assert_eq_mempool_queue(mempool: &Mempool, expected_queue: &[ThinTransaction]
139145
#[case::test_get_more_than_all_eligible_txs(5)]
140146
#[case::test_get_less_than_all_eligible_txs(2)]
141147
fn test_get_txs(#[case] requested_txs: usize) {
142-
// TODO(Ayelet): Avoid cloning the transactions in the test.
143-
let add_tx_inputs = [
144-
add_tx_input!(tip: 50, tx_hash: 1),
145-
add_tx_input!(tip: 100, tx_hash: 2, sender_address: "0x1"),
146-
add_tx_input!(tip: 10, tx_hash: 3, sender_address: "0x2"),
147-
];
148-
let tx_references_iterator =
149-
add_tx_inputs.iter().map(|input| TransactionReference::new(&input.tx));
150-
let txs_iterator = add_tx_inputs.iter().map(|input| input.tx.clone());
148+
let tx1 = add_tx_input!(tip: 50, tx_hash: 1).tx;
149+
let tx2 = add_tx_input!(tip: 100, tx_hash: 2, sender_address: "0x1").tx;
150+
let tx3 = add_tx_input!(tip: 10, tx_hash: 3, sender_address: "0x2").tx;
151+
152+
let mut tx_inputs = vec![tx1, tx2, tx3];
153+
let tx_references_iterator = tx_inputs.iter().map(TransactionReference::new);
154+
let txs_iterator = tx_inputs.iter().cloned();
151155

152156
let mut mempool: Mempool = MempoolState::new(txs_iterator, tx_references_iterator).into();
153157

154158
let txs = mempool.get_txs(requested_txs).unwrap();
155159

156-
let sorted_txs = [
157-
add_tx_inputs[1].tx.clone(), // tip 100
158-
add_tx_inputs[0].tx.clone(), // tip 50
159-
add_tx_inputs[2].tx.clone(), // tip 10
160-
];
160+
tx_inputs.sort_by_key(|tx| Reverse(tx.tip));
161161

162-
// This ensures we do not exceed the number of transactions available in the mempool.
163-
let max_requested_txs = requested_txs.min(add_tx_inputs.len());
162+
// Ensure we do not exceed the number of transactions available in the mempool.
163+
let max_requested_txs = requested_txs.min(tx_inputs.len());
164164

165-
// checks that the returned transactions are the ones with the highest priority.
166-
let (expected_queue, remaining_txs) = sorted_txs.split_at(max_requested_txs);
165+
// Check that the returned transactions are the ones with the highest priority.
166+
let (expected_queue, remaining_txs) = tx_inputs.split_at(max_requested_txs);
167167
assert_eq!(txs, expected_queue);
168168

169-
// checks that the transactions that were not returned are still in the mempool.
170-
assert_eq_mempool_queue(&mempool, remaining_txs);
169+
// Check that the transactions that were not returned are still in the mempool.
170+
let remaining_tx_references = remaining_txs.iter().map(TransactionReference::new);
171+
let mempool_state = MempoolState::new(remaining_txs.to_vec(), remaining_tx_references);
172+
mempool_state.assert_eq_mempool_state(&mempool);
171173
}
172174

173175
#[rstest]

crates/mempool/src/transaction_pool.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type HashToTransaction = HashMap<TransactionHash, ThinTransaction>;
1313
/// Invariant: both data structures are consistent regarding the existence of transactions:
1414
/// A transaction appears in one if and only if it appears in the other.
1515
/// No duplicate transactions appear in the pool.
16-
#[derive(Debug, Default)]
16+
#[derive(Debug, Default, Eq, PartialEq)]
1717
pub struct TransactionPool {
1818
// Holds the complete transaction objects; it should be the sole entity that does so.
1919
tx_pool: HashToTransaction,
@@ -79,7 +79,7 @@ impl TransactionPool {
7979
}
8080
}
8181

82-
#[derive(Debug, Default)]
82+
#[derive(Debug, Default, Eq, PartialEq)]
8383
struct AccountTransactionIndex(HashMap<ContractAddress, BTreeMap<Nonce, TransactionReference>>);
8484

8585
impl AccountTransactionIndex {

crates/mempool/src/transaction_queue.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ use starknet_api::core::{ContractAddress, Nonce};
55
use starknet_api::transaction::TransactionHash;
66

77
use crate::mempool::TransactionReference;
8-
// Assumption: for the MVP only one transaction from the same contract class can be in the mempool
9-
// at a time. When this changes, saving the transactions themselves on the queu might no longer be
10-
// appropriate, because we'll also need to stores transactions without indexing them. For example,
11-
// transactions with future nonces will need to be stored, and potentially indexed on block commits.
12-
#[derive(Debug, Default)]
8+
9+
// Note: the derived comparison functionality considers the order guaranteed by the data structures
10+
// used.
11+
#[derive(Debug, Default, Eq, PartialEq)]
1312
pub struct TransactionQueue {
1413
// Priority queue of transactions with associated priority.
1514
queue: BTreeSet<QueuedTransaction>,

0 commit comments

Comments
 (0)