diff --git a/Cargo.lock b/Cargo.lock
index 8350347b6b4..fde6f2dc3aa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10739,6 +10739,7 @@ dependencies = [
  "alloy-primitives",
  "alloy-rlp",
  "codspeed-criterion-compat",
+ "crossbeam-channel",
  "dashmap 6.1.0",
  "derive_more",
  "itertools 0.14.0",
diff --git a/crates/engine/tree/benches/state_root_task.rs b/crates/engine/tree/benches/state_root_task.rs
index 9f61e62d2f9..02fbd693034 100644
--- a/crates/engine/tree/benches/state_root_task.rs
+++ b/crates/engine/tree/benches/state_root_task.rs
@@ -228,16 +228,21 @@ fn bench_state_root(c: &mut Criterion) {
             },
             |(genesis_hash, mut payload_processor, provider, state_updates)| {
                 black_box({
-                    let mut handle = payload_processor.spawn(
-                        Default::default(),
-                        core::iter::empty::<
-                            Result<Recovered<TransactionSigned>, core::convert::Infallible>,
-                        >(),
-                        StateProviderBuilder::new(provider.clone(), genesis_hash, None),
-                        ConsistentDbView::new_with_latest_tip(provider).unwrap(),
-                        TrieInput::default(),
-                        &TreeConfig::default(),
-                    );
+                    let mut handle = payload_processor
+                        .spawn(
+                            Default::default(),
+                            core::iter::empty::<
+                                Result<
+                                    Recovered<TransactionSigned>,
+                                    core::convert::Infallible,
+                                >,
+                            >(),
+                            StateProviderBuilder::new(provider.clone(), genesis_hash, None),
+                            ConsistentDbView::new_with_latest_tip(provider).unwrap(),
+                            TrieInput::default(),
+                            &TreeConfig::default(),
+                        )
+                        .expect("failed to spawn payload processor task");

                     let mut state_hook = handle.state_hook();
diff --git a/crates/engine/tree/src/tree/payload_processor/mod.rs b/crates/engine/tree/src/tree/payload_processor/mod.rs
index 8d9bd1ba2e0..511b860d98d 100644
--- a/crates/engine/tree/src/tree/payload_processor/mod.rs
+++ b/crates/engine/tree/src/tree/payload_processor/mod.rs
@@ -26,13 +26,13 @@ use reth_evm::{
 };
 use reth_primitives_traits::NodePrimitives;
 use reth_provider::{
-    providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateProviderFactory,
-    StateReader,
+    providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, ProviderResult,
+    StateProviderFactory, StateReader,
 };
 use reth_revm::{db::BundleState, state::EvmState};
 use reth_trie::TrieInput;
 use reth_trie_parallel::{
-    proof_task::{ProofTaskCtx, ProofTaskManager},
+    proof_task::{new_proof_task_handle, ProofTaskCtx},
     root::ParallelStateRootError,
 };
 use reth_trie_sparse::{
@@ -58,7 +58,7 @@ use configured_sparse_trie::ConfiguredSparseTrie;
 /// Default parallelism thresholds to use with the [`ParallelSparseTrie`].
 ///
 /// These values were determined by performing benchmarks using gradually increasing values to judge
-/// the affects. Below 100 throughput would generally be equal or slightly less, while above 150 it
+/// the effects. Below 100 throughput would generally be equal or slightly less, while above 150 it
 /// would deteriorate to the point where PST might as well not be used.
 pub const PARALLEL_SPARSE_TRIE_PARALLELISM_THRESHOLDS: ParallelismThresholds =
     ParallelismThresholds { min_revealed_nodes: 100, min_updated_nodes: 100 };
@@ -69,7 +69,7 @@ pub struct PayloadProcessor<N, Evm>
 where
     Evm: ConfigureEvm<Primitives = N>,
 {
-    /// The executor used by to spawn tasks.
+    /// The executor used to spawn tasks.
     executor: WorkloadExecutor,
     /// The most recent cache used for execution.
     execution_cache: ExecutionCache,
@@ -163,9 +163,9 @@
     ///
     /// This task runs until there are no further updates to process.
     ///
-    ///
     /// This returns a handle to await the final state root and to interact with the tasks (e.g.
     /// canceling)
+    #[allow(clippy::type_complexity)]
     pub fn spawn<I: ExecutableTxIterator<Evm>>(
         &mut self,
         env: ExecutionEnv<Evm>,
@@ -174,7 +174,7 @@ where
         consistent_view: ConsistentDbView<P>,
         trie_input: TrieInput,
         config: &TreeConfig,
-    ) -> PayloadHandle<WithTxEnv<TxEnvFor<Evm>, I::Tx>, I::Error>
+    ) -> ProviderResult<PayloadHandle<WithTxEnv<TxEnvFor<Evm>, I::Tx>, I::Error>>
     where
         P: DatabaseProviderFactory<Provider: BlockReader>
             + BlockReader
@@ -196,20 +196,20 @@ where
             state_root_config.prefix_sets.clone(),
         );
         let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
-        let proof_task = ProofTaskManager::new(
+        let proof_task_handle = new_proof_task_handle(
             self.executor.handle().clone(),
             state_root_config.consistent_view.clone(),
             task_ctx,
             max_proof_task_concurrency,
-        );
+        )?;

         // We set it to half of the proof task concurrency, because often for each multiproof we
         // spawn one Tokio task for the account proof, and one Tokio task for the storage proof.
-        let max_multi_proof_task_concurrency = max_proof_task_concurrency / 2;
+        let max_multi_proof_task_concurrency = (max_proof_task_concurrency / 2).max(1);
         let multi_proof_task = MultiProofTask::new(
             state_root_config,
             self.executor.clone(),
-            proof_task.handle(),
+            proof_task_handle.clone(),
             to_sparse_trie,
             max_multi_proof_task_concurrency,
             config.multiproof_chunking_enabled().then_some(config.multiproof_chunk_size()),
@@ -238,26 +238,14 @@
         let (state_root_tx, state_root_rx) = channel();

         // Spawn the sparse trie task using any stored trie and parallel trie configuration.
-        self.spawn_sparse_trie_task(sparse_trie_rx, proof_task.handle(), state_root_tx);
+        self.spawn_sparse_trie_task(sparse_trie_rx, proof_task_handle, state_root_tx);

-        // spawn the proof task
-        self.executor.spawn_blocking(move || {
-            if let Err(err) = proof_task.run() {
-                // At least log if there is an error at any point
-                tracing::error!(
-                    target: "engine::root",
-                    ?err,
-                    "Storage proof task returned an error"
-                );
-            }
-        });
-
-        PayloadHandle {
+        Ok(PayloadHandle {
             to_multi_proof,
             prewarm_handle,
             state_root: Some(state_root_rx),
             transactions: execution_rx,
-        }
+        })
     }

     /// Spawns a task that exclusively handles cache prewarming for transaction execution.
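
Reviewer note on the concurrency change above: the new `.max(1)` guard matters when `max_proof_task_concurrency` is configured as 1, where the old integer division produced a multiproof concurrency budget of 0. A standalone sketch of the arithmetic (the free function and `main` are illustrative, not part of this diff):

    fn max_multi_proof_task_concurrency(max_proof_task_concurrency: usize) -> usize {
        // Half of the proof-task budget: each multiproof typically spawns one Tokio
        // task for the account proof and one for the storage proof.
        (max_proof_task_concurrency / 2).max(1)
    }

    fn main() {
        assert_eq!(max_multi_proof_task_concurrency(8), 4);
        // Before this change, a configured concurrency of 1 yielded 1 / 2 == 0,
        // leaving no budget for multiproof tasks.
        assert_eq!(max_multi_proof_task_concurrency(1), 1);
    }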
@@ -857,14 +845,19 @@ mod tests {
             PrecompileCacheMap::default(),
         );
         let provider = BlockchainProvider::new(factory).unwrap();
-        let mut handle = payload_processor.spawn(
-            Default::default(),
-            core::iter::empty::<Result<Recovered<TransactionSigned>, core::convert::Infallible>>(),
-            StateProviderBuilder::new(provider.clone(), genesis_hash, None),
-            ConsistentDbView::new_with_latest_tip(provider).unwrap(),
-            TrieInput::from_state(hashed_state),
-            &TreeConfig::default(),
-        );
+        let mut handle =
+            payload_processor
+                .spawn(
+                    Default::default(),
+                    core::iter::empty::<
+                        Result<Recovered<TransactionSigned>, core::convert::Infallible>,
+                    >(),
+                    StateProviderBuilder::new(provider.clone(), genesis_hash, None),
+                    ConsistentDbView::new_with_latest_tip(provider).unwrap(),
+                    TrieInput::from_state(hashed_state),
+                    &TreeConfig::default(),
+                )
+                .expect("failed to spawn payload processor task");

         let mut state_hook = handle.state_hook();
diff --git a/crates/engine/tree/src/tree/payload_processor/multiproof.rs b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
index 6c7f5de40a3..2feba6ba992 100644
--- a/crates/engine/tree/src/tree/payload_processor/multiproof.rs
+++ b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
@@ -1204,7 +1204,7 @@ mod tests {
     use alloy_primitives::map::B256Set;
     use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
     use reth_trie::{MultiProof, TrieInput};
-    use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use reth_trie_parallel::proof_task::{new_proof_task_handle, ProofTaskCtx};
     use revm_primitives::{B256, U256};
     use std::sync::Arc;

@@ -1231,15 +1231,16 @@
             config.state_sorted.clone(),
             config.prefix_sets.clone(),
         );
-        let proof_task = ProofTaskManager::new(
+        let proof_task_handle = new_proof_task_handle(
             executor.handle().clone(),
             config.consistent_view.clone(),
             task_ctx,
-            1,
-        );
+            1, // max_concurrency for test
+        )
+        .expect("Failed to create proof task handle for multiproof test");
         let channel = channel();

-        MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)
+        MultiProofTask::new(config, executor, proof_task_handle, channel.0, 1, None)
     }

     #[test]
diff --git a/crates/engine/tree/src/tree/payload_validator.rs b/crates/engine/tree/src/tree/payload_validator.rs
index cd2c37d1e91..87cb12268e2 100644
--- a/crates/engine/tree/src/tree/payload_validator.rs
+++ b/crates/engine/tree/src/tree/payload_validator.rs
@@ -885,7 +885,7 @@ where
                     consistent_view,
                     trie_input,
                     &self.config,
-                ),
+                )?,
                 StateRootStrategy::StateRootTask,
             )
             // if prefix sets are not empty, we spawn a task that exclusively handles cache
diff --git a/crates/trie/parallel/Cargo.toml b/crates/trie/parallel/Cargo.toml
index c9f625a1500..b4463d9ede3 100644
--- a/crates/trie/parallel/Cargo.toml
+++ b/crates/trie/parallel/Cargo.toml
@@ -36,6 +36,7 @@ derive_more.workspace = true
 rayon.workspace = true
 itertools.workspace = true
 tokio = { workspace = true, features = ["rt-multi-thread"] }
+crossbeam-channel.workspace = true

 # `metrics` feature
 reth-metrics = { workspace = true, optional = true }
diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs
index d6e1b57ed9b..587edd11f5c 100644
--- a/crates/trie/parallel/src/proof.rs
+++ b/crates/trie/parallel/src/proof.rs
@@ -14,8 +14,7 @@ use dashmap::DashMap;
 use itertools::Itertools;
 use reth_execution_errors::StorageRootError;
 use reth_provider::{
-    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
-    ProviderError,
+    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
 };
 use reth_storage_errors::db::DatabaseError;
 use reth_trie::{
@@ -34,7 +33,7 @@ use reth_trie_common::{
     proof::{DecodedProofNodes, ProofRetainer},
 };
 use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
-use std::sync::{mpsc::Receiver, Arc};
+use std::sync::Arc;
 use tracing::trace;

 /// Parallel proof calculator.
@@ -59,7 +58,7 @@ pub struct ParallelProof<Factory> {
     /// Provided by the user to give the necessary context to retain extra proofs.
     multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
     /// Handle to the storage proof task.
-    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
+    storage_proof_task_handle: ProofTaskManagerHandle<<Factory as DatabaseProviderFactory>::Tx>,
     /// Cached storage proof roots for missed leaves; this maps
     /// hashed (missed) addresses to their storage proof roots.
     missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
@@ -75,7 +74,7 @@ impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
         state_sorted: Arc<HashedPostStateSorted>,
         prefix_sets: Arc<TriePrefixSetsMut>,
         missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
-        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
+        storage_proof_task_handle: ProofTaskManagerHandle<<Factory as DatabaseProviderFactory>::Tx>,
     ) -> Self {
         Self {
             view,
@@ -118,7 +117,7 @@ where
         hashed_address: B256,
         prefix_set: PrefixSet,
         target_slots: B256Set,
-    ) -> Receiver<StorageProofResult> {
+    ) -> crossbeam_channel::Receiver<StorageProofResult> {
         let input = StorageProofInput::new(
             hashed_address,
             prefix_set,
             target_slots,
             self.collect_branch_node_masks,
             self.multi_added_removed_keys.clone(),
         );

-        let (sender, receiver) = std::sync::mpsc::channel();
-        let _ =
-            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
+        let (sender, receiver) = crossbeam_channel::unbounded();
+        self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
         receiver
     }

@@ -323,6 +321,17 @@ where
                 }
             }
         }
+
+        // Drain receivers for accounts the walker never touched (e.g. destroyed targets) so
+        // workers can deliver their results without hitting a closed channel.
+        for (hashed_address, rx) in storage_proof_receivers {
+            let decoded_storage_multiproof = rx.recv().map_err(|e| {
+                ParallelStateRootError::StorageRoot(StorageRootError::Database(
+                    DatabaseError::Other(format!("channel closed for {hashed_address}: {e}")),
+                ))
+            })??;
+            collected_decoded_storages.insert(hashed_address, decoded_storage_multiproof);
+        }

         let _ = hash_builder.root();

         let stats = tracker.finish();
@@ -368,7 +377,7 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use crate::proof_task::{new_proof_task_handle, ProofTaskCtx};
     use alloy_primitives::{
         keccak256,
         map::{B256Set, DefaultHashBuilder},
@@ -447,13 +456,13 @@
         let task_ctx =
             ProofTaskCtx::new(Default::default(), Default::default(), Default::default());

-        let proof_task =
-            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
-        let proof_task_handle = proof_task.handle();
-
-        // keep the join handle around to make sure it does not return any errors
-        // after we compute the state root
-        let join_handle = rt.spawn_blocking(move || proof_task.run());
+        let proof_task_handle = new_proof_task_handle(
+            rt.handle().clone(),
+            consistent_view.clone(),
+            task_ctx,
+            1, // max_concurrency for test
+        )
+        .expect("Failed to create proof task");

         let parallel_result = ParallelProof::new(
             consistent_view,
@@ -489,9 +498,7 @@
         // then compare the entire thing for any mask differences
         assert_eq!(parallel_result, sequential_result_decoded);

-        // drop the handle to terminate the task and then block on the proof task handle to make
-        // sure it does not return any errors
+        // Drop the handle to release transaction pool resources
         drop(proof_task_handle);
-        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
     }
 }
diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs
index 9bb96d4b19e..386a65104df 100644
--- a/crates/trie/parallel/src/proof_task.rs
+++ b/crates/trie/parallel/src/proof_task.rs
@@ -1,20 +1,15 @@
-//! A Task that manages sending proof requests to a number of tasks that have longer-running
-//! database transactions.
+//! Proof task management using a pool of pre-warmed database transactions.
 //!
-//! The [`ProofTaskManager`] ensures that there are a max number of currently executing proof tasks,
-//! and is responsible for managing the fixed number of database transactions created at the start
-//! of the task.
-//!
-//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
-//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
+//! This module provides proof computation using Tokio's blocking threadpool with
+//! transaction reuse via a crossbeam channel pool.

 use crate::root::ParallelStateRootError;
 use alloy_primitives::{map::B256Set, B256};
+use crossbeam_channel::{bounded, unbounded, Receiver, SendError, Sender, TrySendError};
 use reth_db_api::transaction::DbTx;
 use reth_execution_errors::SparseTrieError;
 use reth_provider::{
-    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
-    ProviderResult,
+    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderResult,
 };
 use reth_trie::{
     hashed_cursor::HashedPostStateCursorFactory,
@@ -31,16 +26,16 @@ use reth_trie_common::{
 use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
 use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
 use std::{
-    collections::VecDeque,
+    fmt,
+    marker::PhantomData,
     sync::{
         atomic::{AtomicUsize, Ordering},
-        mpsc::{channel, Receiver, SendError, Sender},
         Arc,
     },
     time::Instant,
 };
-use tokio::runtime::Handle;
-use tracing::{debug, trace};
+use tokio::{runtime::Handle, task};
+use tracing::{error, trace};

 #[cfg(feature = "metrics")]
 use crate::proof_task_metrics::ProofTaskMetrics;
@@ -48,195 +43,85 @@ use crate::proof_task_metrics::ProofTaskMetrics;
 type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
 type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;

-/// A task that manages sending multiproof requests to a number of tasks that have longer-running
-/// database transactions
-#[derive(Debug)]
-pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
-    /// Max number of database transactions to create
-    max_concurrency: usize,
-    /// Number of database transactions created
-    total_transactions: usize,
-    /// Consistent view provider used for creating transactions on-demand
+/// Type alias for the factory tuple returned by `create_factories`
+type ProofFactories<'a, Tx> = (
+    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
+    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
+);
+
+/// Builds the worker-pool handle that queues proof tasks and reuses a fixed set of database
+/// transactions.
+pub fn new_proof_task_handle<Factory>(
+    executor: Handle,
     view: ConsistentDbView<Factory>,
-    /// Proof task context shared across all proof tasks
     task_ctx: ProofTaskCtx,
-    /// Proof tasks pending execution
-    pending_tasks: VecDeque<ProofTaskKind>,
-    /// The underlying handle from which to spawn proof tasks
-    executor: Handle,
-    /// The proof task transactions, containing owned cursor factories that are reused for proof
-    /// calculation.
-    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
-    /// A receiver for new proof tasks.
-    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
-    /// A sender for sending back transactions.
-    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
-    /// The number of active handles.
-    ///
-    /// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
-    /// [`ProofTaskManagerHandle::drop`].
-    active_handles: Arc<AtomicUsize>,
-    /// Metrics tracking blinded node fetches.
-    #[cfg(feature = "metrics")]
-    metrics: ProofTaskMetrics,
-}
+    max_concurrency: usize,
+) -> ProviderResult<ProofTaskManagerHandle<<Factory as DatabaseProviderFactory>::Tx>>
+where
+    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + Send + Sync + 'static,
+{
+    let queue_capacity = max_concurrency.max(1);
+    let worker_count = queue_capacity; // TODO: Update this when we have account proof workers.
+
+    let (task_sender, task_receiver) = bounded(queue_capacity);
+
+    // Spawn blocking workers upfront; each owns a reusable transaction from the consistent view.
+    for worker_id in 0..worker_count {
+        let provider_ro = view.provider_ro()?;
+        let tx = provider_ro.into_tx();
+        let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
+        let receiver = task_receiver.clone();
+        executor.spawn_blocking(move || worker_loop(proof_task_tx, receiver));
+    }

-impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
-    /// Creates a new [`ProofTaskManager`] with the given max concurrency, creating that number of
-    /// cursor factories.
-    ///
-    /// Returns an error if the consistent view provider fails to create a read-only transaction.
-    pub fn new(
-        executor: Handle,
-        view: ConsistentDbView<Factory>,
-        task_ctx: ProofTaskCtx,
-        max_concurrency: usize,
-    ) -> Self {
-        let (tx_sender, proof_task_rx) = channel();
-        Self {
-            max_concurrency,
-            total_transactions: 0,
-            view,
-            task_ctx,
-            pending_tasks: VecDeque::new(),
+    let handle: ProofTaskManagerHandle<<Factory as DatabaseProviderFactory>::Tx> =
+        ProofTaskManagerHandle::new(
             executor,
-            proof_task_txs: Vec::new(),
-            proof_task_rx,
-            tx_sender,
-            active_handles: Arc::new(AtomicUsize::new(0)),
+            task_sender,
+            Arc::new(AtomicUsize::new(0)),
             #[cfg(feature = "metrics")]
-            metrics: ProofTaskMetrics::default(),
-        }
-    }
+            Arc::new(ProofTaskMetrics::default()),
+        );

-    /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
-    pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
-        ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
-    }
+    Ok(handle)
 }

-impl<Factory> ProofTaskManager<Factory>
+fn worker_loop<Tx>(proof_tx: ProofTaskTx<Tx>, receiver: Receiver<ProofTaskKind>)
 where
-    Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
+    Tx: DbTx,
 {
-    /// Inserts the task into the pending tasks queue.
-    pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
-        self.pending_tasks.push_back(task);
-    }
-
-    /// Gets either the next available transaction, or creates a new one if all are in use and the
-    /// total number of transactions created is less than the max concurrency.
-    pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
-        if let Some(proof_task_tx) = self.proof_task_txs.pop() {
-            return Ok(Some(proof_task_tx));
-        }
-
-        // if we can create a new tx within our concurrency limits, create one on-demand
-        if self.total_transactions < self.max_concurrency {
-            let provider_ro = self.view.provider_ro()?;
-            let tx = provider_ro.into_tx();
-            self.total_transactions += 1;
-            return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
-        }
-
-        Ok(None)
-    }
-
-    /// Spawns the next queued proof task on the executor with the given input, if there are any
-    /// transactions available.
-    ///
-    /// This will return an error if a transaction must be created on-demand and the consistent view
-    /// provider fails.
-    pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
-        let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
-
-        let Some(proof_task_tx) = self.get_or_create_tx()? else {
-            // if there are no txs available, requeue the proof task
-            self.pending_tasks.push_front(task);
-            return Ok(())
-        };
-
-        let tx_sender = self.tx_sender.clone();
-        self.executor.spawn_blocking(move || match task {
+    while let Ok(task) = receiver.recv() {
+        match task {
             ProofTaskKind::StorageProof(input, sender) => {
-                proof_task_tx.storage_proof(input, sender, tx_sender);
+                proof_tx.storage_proof(input, &sender);
             }
             ProofTaskKind::BlindedAccountNode(path, sender) => {
-                proof_task_tx.blinded_account_node(path, sender, tx_sender);
+                proof_tx.blinded_account_node(&path, &sender);
             }
             ProofTaskKind::BlindedStorageNode(account, path, sender) => {
-                proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
+                proof_tx.blinded_storage_node(&account, &path, &sender);
+            }
+            #[cfg(test)]
+            ProofTaskKind::Test(task) => {
+                (task)();
             }
-        });
-
-        Ok(())
-    }
-
-    /// Loops, managing the proof tasks, and sending new tasks to the executor.
-    pub fn run(mut self) -> ProviderResult<()> {
-        loop {
-            match self.proof_task_rx.recv() {
-                Ok(message) => match message {
-                    ProofTaskMessage::QueueTask(task) => {
-                        // Track metrics for blinded node requests
-                        #[cfg(feature = "metrics")]
-                        match &task {
-                            ProofTaskKind::BlindedAccountNode(_, _) => {
-                                self.metrics.account_nodes += 1;
-                            }
-                            ProofTaskKind::BlindedStorageNode(_, _, _) => {
-                                self.metrics.storage_nodes += 1;
-                            }
-                            _ => {}
-                        }
-                        // queue the task
-                        self.queue_proof_task(task)
-                    }
-                    ProofTaskMessage::Transaction(tx) => {
-                        // return the transaction to the pool
-                        self.proof_task_txs.push(tx);
-                    }
-                    ProofTaskMessage::Terminate => {
-                        // Record metrics before terminating
-                        #[cfg(feature = "metrics")]
-                        self.metrics.record();
-                        return Ok(())
-                    }
-                },
-                // All senders are disconnected, so we can terminate
-                // However this should never happen, as this struct stores a sender
-                Err(_) => return Ok(()),
-            };
-
-            // try spawning the next task
-            self.try_spawn_next()?;
         }
     }
 }

-/// Type alias for the factory tuple returned by `create_factories`
-type ProofFactories<'a, Tx> = (
-    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
-    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
-);
-
 /// This contains all information shared between all storage proof instances.
 #[derive(Debug)]
 pub struct ProofTaskTx<Tx> {
     /// The tx that is reused for proof calculations.
     tx: Tx,
-
     /// Trie updates, prefix sets, and state updates
     task_ctx: ProofTaskCtx,
-
-    /// Identifier for the tx within the context of a single [`ProofTaskManager`], used only for
-    /// tracing.
+    /// Identifier for the tx within the pool, used only for tracing.
     id: usize,
 }

 impl<Tx> ProofTaskTx<Tx> {
-    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`]. The id is
-    /// used only for tracing.
+    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`].
     const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
         Self { tx, task_ctx, id }
     }
@@ -261,92 +146,85 @@ where
     }

     /// Calculates a storage proof for the given hashed address, and desired prefix set.
-    fn storage_proof(
-        self,
-        input: StorageProofInput,
-        result_sender: Sender<StorageProofResult>,
-        tx_sender: Sender<ProofTaskMessage<Tx>>,
-    ) {
+    fn storage_proof(&self, input: StorageProofInput, result_sender: &Sender<StorageProofResult>) {
+        let StorageProofInput {
+            hashed_address,
+            prefix_set,
+            target_slots,
+            with_branch_node_masks,
+            multi_added_removed_keys,
+        } = input;
+
         trace!(
             target: "trie::proof_task",
-            hashed_address=?input.hashed_address,
+            worker_id = self.id,
+            hashed_address = ?hashed_address,
             "Starting storage proof task calculation"
         );

         let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();

-        let multi_added_removed_keys = input
-            .multi_added_removed_keys
-            .unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
-        let added_removed_keys = multi_added_removed_keys.get_storage(&input.hashed_address);
+        let multi_added_removed_keys =
+            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
+        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);

         let span = tracing::trace_span!(
             target: "trie::proof_task",
             "Storage proof calculation",
-            hashed_address=?input.hashed_address,
+            hashed_address = ?hashed_address,
             // Add a unique id because we often have parallel storage proof calculations for the
             // same hashed address, and we want to differentiate them during trace analysis.
-            span_id=self.id,
+            span_id = self.id,
         );
         let span_guard = span.enter();

-        let target_slots_len = input.target_slots.len();
+        let target_slots_len = target_slots.len();
         let proof_start = Instant::now();
-        let raw_proof_result = StorageProof::new_hashed(
-            trie_cursor_factory,
-            hashed_cursor_factory,
-            input.hashed_address,
-        )
-        .with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().copied()))
-        .with_branch_node_masks(input.with_branch_node_masks)
-        .with_added_removed_keys(added_removed_keys)
-        .storage_multiproof(input.target_slots)
-        .map_err(|e| ParallelStateRootError::Other(e.to_string()));
-
-        drop(span_guard);
+        let raw_proof_result =
+            StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
+                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
+                .with_branch_node_masks(with_branch_node_masks)
+                .with_added_removed_keys(added_removed_keys)
+                .storage_multiproof(target_slots)
+                .map_err(|e| ParallelStateRootError::Other(e.to_string()));

         let decoded_result = raw_proof_result.and_then(|raw_proof| {
             raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                 ParallelStateRootError::Other(format!(
                     "Failed to decode storage proof for {}: {}",
-                    input.hashed_address, e
+                    hashed_address, e
                 ))
             })
         });

         trace!(
             target: "trie::proof_task",
-            hashed_address=?input.hashed_address,
-            prefix_set = ?input.prefix_set.len(),
-            target_slots = ?target_slots_len,
+            worker_id = self.id,
+            hashed_address = ?hashed_address,
+            prefix_set_len = prefix_set.len(),
+            target_slots = target_slots_len,
             proof_time = ?proof_start.elapsed(),
             "Completed storage proof task calculation"
         );

-        // send the result back
-        if let Err(error) = result_sender.send(decoded_result) {
-            debug!(
+        drop(span_guard);
+
+        // Send the result back (log error if receiver dropped)
+        if let Err(e) = result_sender.send(decoded_result) {
+            error!(
                 target: "trie::proof_task",
-                hashed_address = ?input.hashed_address,
-                ?error,
-                task_time = ?proof_start.elapsed(),
-                "Storage proof receiver is dropped, discarding the result"
+                worker_id = self.id,
+                "Failed to send storage proof result: {:?}",
+                e
             );
         }
-
-        // send the tx back
-        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
     }

     /// Retrieves blinded account node by path.
-    fn blinded_account_node(
-        self,
-        path: Nibbles,
-        result_sender: Sender<TrieNodeProviderResult>,
-        tx_sender: Sender<ProofTaskMessage<Tx>>,
-    ) {
+    fn blinded_account_node(&self, path: &Nibbles, result_sender: &Sender<TrieNodeProviderResult>) {
         trace!(
             target: "trie::proof_task",
+            worker_id = self.id,
             ?path,
             "Starting blinded account node retrieval"
         );
@@ -360,37 +238,35 @@ where
         );

         let start = Instant::now();
-        let result = blinded_provider_factory.account_node_provider().trie_node(&path);
+        let result = blinded_provider_factory.account_node_provider().trie_node(path);
         trace!(
             target: "trie::proof_task",
+            worker_id = self.id,
             ?path,
             elapsed = ?start.elapsed(),
             "Completed blinded account node retrieval"
         );

-        if let Err(error) = result_sender.send(result) {
-            tracing::error!(
+        if let Err(e) = result_sender.send(result) {
+            error!(
                 target: "trie::proof_task",
-                ?path,
-                ?error,
-                "Failed to send blinded account node result"
+                worker_id = self.id,
+                "Failed to send account node result: {:?}",
+                e
             );
         }
-
-        // send the tx back
-        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
     }

     /// Retrieves blinded storage node of the given account by path.
     fn blinded_storage_node(
-        self,
-        account: B256,
-        path: Nibbles,
-        result_sender: Sender<TrieNodeProviderResult>,
-        tx_sender: Sender<ProofTaskMessage<Tx>>,
+        &self,
+        account: &B256,
+        path: &Nibbles,
+        result_sender: &Sender<TrieNodeProviderResult>,
     ) {
         trace!(
             target: "trie::proof_task",
+            worker_id = self.id,
             ?account,
             ?path,
             "Starting blinded storage node retrieval"
@@ -405,9 +281,10 @@ where
         );

         let start = Instant::now();
-        let result = blinded_provider_factory.storage_node_provider(account).trie_node(&path);
+        let result = blinded_provider_factory.storage_node_provider(*account).trie_node(path);
         trace!(
             target: "trie::proof_task",
+            worker_id = self.id,
             ?account,
             ?path,
             elapsed = ?start.elapsed(),
@@ -415,17 +292,16 @@ where
         );

         if let Err(error) = result_sender.send(result) {
-            tracing::error!(
+            error!(
                 target: "trie::proof_task",
                 ?account,
                 ?path,
+                worker_id = self.id,
                 ?error,
-                "Failed to send blinded storage node result"
+                "Failed to send storage node result"
             );
         }
-
-        // send the tx back
-        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
     }
 }

@@ -433,20 +309,19 @@ where
 #[derive(Debug)]
 pub struct StorageProofInput {
     /// The hashed address for which the proof is calculated.
-    hashed_address: B256,
+    pub hashed_address: B256,
     /// The prefix set for the proof calculation.
-    prefix_set: PrefixSet,
+    pub prefix_set: PrefixSet,
     /// The target slots for the proof calculation.
-    target_slots: B256Set,
+    pub target_slots: B256Set,
     /// Whether or not to collect branch node masks
-    with_branch_node_masks: bool,
+    pub with_branch_node_masks: bool,
     /// Provided by the user to give the necessary context to retain extra proofs.
-    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
+    pub multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
 }

 impl StorageProofInput {
-    /// Creates a new [`StorageProofInput`] with the given hashed address, prefix set, and target
-    /// slots.
+    /// Creates a new [`StorageProofInput`] with the given parameters.
     pub const fn new(
         hashed_address: B256,
         prefix_set: PrefixSet,
@@ -489,22 +364,7 @@ impl ProofTaskCtx {
     }
 }

-/// Message used to communicate with [`ProofTaskManager`].
-#[derive(Debug)]
-pub enum ProofTaskMessage<Tx> {
-    /// A request to queue a proof task.
-    QueueTask(ProofTaskKind),
-    /// A returned database transaction.
-    Transaction(ProofTaskTx<Tx>),
-    /// A request to terminate the proof task manager.
-    Terminate,
-}
-
-/// Proof task kind.
-///
-/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
-/// specifies the type of proof task to be executed.
-#[derive(Debug)]
+/// Proof task kind dispatched via [`ProofTaskManagerHandle::queue_task`].
 pub enum ProofTaskKind {
     /// A storage proof request.
     StorageProof(StorageProofInput, Sender<StorageProofResult>),
@@ -512,98 +372,301 @@
     BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
     /// A blinded storage node request.
     BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
+    /// Test-only hook for exercising the worker pool.
+    #[cfg(test)]
+    Test(Box<dyn FnOnce() + Send>),
 }

-/// A handle that wraps a single proof task sender that sends a terminate message on `Drop` if the
-/// number of active handles went to zero.
-#[derive(Debug)]
+impl fmt::Debug for ProofTaskKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::StorageProof(_, _) => f.write_str("StorageProof"),
+            Self::BlindedAccountNode(_, _) => f.write_str("BlindedAccountNode"),
+            Self::BlindedStorageNode(_, _, _) => f.write_str("BlindedStorageNode"),
+            #[cfg(test)]
+            Self::Test(_) => f.write_str("Test"),
+        }
+    }
+}
+
+/// A handle for dispatching proof tasks using a transaction pool and Tokio's blocking threadpool.
+///
+/// Tasks are dispatched directly without an intermediate manager loop.
 pub struct ProofTaskManagerHandle<Tx> {
-    /// The sender for the proof task manager.
-    sender: Sender<ProofTaskMessage<Tx>>,
-    /// The number of active handles.
+    /// Tokio executor for spawning helper tasks.
+    executor: Handle,
+    /// Sender used to dispatch tasks to the persistent worker pool.
+    task_sender: Sender<ProofTaskKind>,
+    /// The number of active handles (for metrics).
     active_handles: Arc<AtomicUsize>,
+    /// Metrics tracking blinded node fetches.
+    #[cfg(feature = "metrics")]
+    metrics: Arc<ProofTaskMetrics>,
+    /// Marker to retain the database transaction type parameter.
+    _marker: PhantomData<Tx>,
 }

-impl<Tx> ProofTaskManagerHandle<Tx> {
-    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
-    pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
-        active_handles.fetch_add(1, Ordering::SeqCst);
-        Self { sender, active_handles }
+// Manual Debug impl since Tx may not be Debug
+impl<Tx> std::fmt::Debug for ProofTaskManagerHandle<Tx> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ProofTaskManagerHandle")
+            .field("executor", &self.executor)
+            .field("active_handles", &self.active_handles)
+            .finish()
     }
+}

-    /// Queues a task to the proof task manager.
-    pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
-        self.sender.send(ProofTaskMessage::QueueTask(task))
+impl<Tx> ProofTaskManagerHandle<Tx>
+where
+    Tx: DbTx + Send + 'static,
+{
+    /// Creates a new [`ProofTaskManagerHandle`].
+    pub fn new(
+        executor: Handle,
+        task_sender: Sender<ProofTaskKind>,
+        active_handles: Arc<AtomicUsize>,
+        #[cfg(feature = "metrics")] metrics: Arc<ProofTaskMetrics>,
+    ) -> Self {
+        active_handles.fetch_add(1, Ordering::SeqCst);
+        Self {
+            executor,
+            task_sender,
+            active_handles,
+            #[cfg(feature = "metrics")]
+            metrics,
+            _marker: PhantomData,
+        }
     }

-    /// Terminates the proof task manager.
-    pub fn terminate(&self) {
-        let _ = self.sender.send(ProofTaskMessage::Terminate);
+    /// Queues a proof task by enqueuing it onto the worker channel.
+    pub fn queue_task(&self, task: ProofTaskKind) {
+        #[cfg(feature = "metrics")]
+        {
+            match &task {
+                ProofTaskKind::BlindedAccountNode(_, _) => {
+                    self.metrics.account_nodes.fetch_add(1, Ordering::Relaxed);
+                }
+                ProofTaskKind::BlindedStorageNode(_, _, _) => {
+                    self.metrics.storage_nodes.fetch_add(1, Ordering::Relaxed);
+                }
+                _ => {}
+            }
+        }
+
+        match self.task_sender.try_send(task) {
+            Ok(()) => {}
+            Err(TrySendError::Full(task)) => {
+                let sender = self.task_sender.clone();
+                let executor = self.executor.clone();
+                executor.spawn(async move {
+                    let send_result = task::spawn_blocking(move || sender.send(task)).await;
+                    match send_result {
+                        Ok(Ok(())) => {}
+                        Ok(Err(SendError(_))) => {
+                            error!(
+                                target: "trie::proof_task",
+                                "Worker channel disconnected while enqueueing proof task"
+                            );
+                        }
+                        Err(join_error) => {
+                            error!(
+                                target: "trie::proof_task",
+                                ?join_error,
+                                "Failed to enqueue proof task: blocking send panicked"
+                            );
+                        }
+                    }
+                });
+            }
+            Err(TrySendError::Disconnected(_)) => {
+                error!(
+                    target: "trie::proof_task",
+                    "Worker channel disconnected, dropping proof task"
+                );
+            }
+        }
     }
 }

-impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
+impl<Tx> Clone for ProofTaskManagerHandle<Tx>
+where
+    Tx: DbTx + Send + 'static,
+{
     fn clone(&self) -> Self {
-        Self::new(self.sender.clone(), self.active_handles.clone())
+        Self::new(
+            self.executor.clone(),
+            self.task_sender.clone(),
+            Arc::clone(&self.active_handles),
+            #[cfg(feature = "metrics")]
+            Arc::clone(&self.metrics),
+        )
     }
 }

 impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
     fn drop(&mut self) {
-        // Decrement the number of active handles and terminate the manager if it was the last
-        // handle.
+        // Record metrics if this is the last handle
         if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
-            self.terminate();
+            #[cfg(feature = "metrics")]
+            self.metrics.record();
         }
     }
 }

-impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
+impl<Tx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx>
+where
+    Tx: DbTx + Send + 'static,
+{
     type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
     type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;

     fn account_node_provider(&self) -> Self::AccountNodeProvider {
-        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
     }

     fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
-        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
     }
 }

 /// Trie node provider for retrieving trie nodes by path.
-#[derive(Debug)]
 pub enum ProofTaskTrieNodeProvider<Tx> {
     /// Blinded account trie node provider.
     AccountNode {
-        /// Sender to the proof task.
-        sender: Sender<ProofTaskMessage<Tx>>,
+        /// Handle to the transaction pool
+        handle: ProofTaskManagerHandle<Tx>,
     },
     /// Blinded storage trie node provider.
     StorageNode {
         /// Target account.
         account: B256,
-        /// Sender to the proof task.
-        sender: Sender<ProofTaskMessage<Tx>>,
+        /// Handle to the transaction pool
+        handle: ProofTaskManagerHandle<Tx>,
     },
 }

-impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
+impl<Tx> std::fmt::Debug for ProofTaskTrieNodeProvider<Tx> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::AccountNode { .. } => f.debug_struct("AccountNode").finish(),
+            Self::StorageNode { account, .. } => {
+                f.debug_struct("StorageNode").field("account", account).finish()
+            }
+        }
+    }
+}
+
+impl<Tx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx>
+where
+    Tx: DbTx + Send + 'static,
+{
     fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
-        let (tx, rx) = channel();
+        let (tx, rx) = unbounded();
         match self {
-            Self::AccountNode { sender } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedAccountNode(*path, tx),
-                ));
+            Self::AccountNode { handle } => {
+                handle.queue_task(ProofTaskKind::BlindedAccountNode(*path, tx));
             }
-            Self::StorageNode { sender, account } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
-                ));
+            Self::StorageNode { handle, account } => {
+                handle.queue_task(ProofTaskKind::BlindedStorageNode(*account, *path, tx));
             }
         }

         rx.recv().unwrap()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use alloy_primitives::map::{B256Map, B256Set};
+    use crossbeam_channel::bounded;
+    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
+    use reth_trie_common::{
+        updates::TrieUpdatesSorted, HashedAccountsSorted, HashedPostStateSorted,
+    };
+    use std::{sync::Arc, time::Duration};
+    use tokio::runtime::Runtime;
+
+    fn empty_task_ctx() -> ProofTaskCtx {
+        ProofTaskCtx::new(
+            Arc::new(TrieUpdatesSorted {
+                account_nodes: Vec::new(),
+                storage_tries: B256Map::default(),
+            }),
+            Arc::new(HashedPostStateSorted::new(
+                HashedAccountsSorted::default(),
+                B256Map::default(),
+            )),
+            Arc::new(TriePrefixSetsMut {
+                account_prefix_set: PrefixSetMut::default(),
+                storage_prefix_sets: B256Map::default(),
+                destroyed_accounts: B256Set::default(),
+            }),
+        )
+    }
+
+    #[test]
+    fn worker_pool_respects_storage_worker_limit() {
+        let factory = create_test_provider_factory();
+        let consistent_view = ConsistentDbView::new(factory, None);
+        let runtime = Runtime::new().expect("failed to construct runtime");
+
+        let task_ctx = empty_task_ctx();
+        let handle = new_proof_task_handle(
+            runtime.handle().clone(),
+            consistent_view,
+            task_ctx,
+            2, // max_concurrency (results in worker_count = 2)
+        )
+        .expect("failed to create proof task handle");
+
+        let (entered_tx, entered_rx) = bounded::<usize>(10);
+        let (release_tx, release_rx) = bounded::<()>(10);
+        let release_rx = Arc::new(release_rx);
+
+        for id in 0..2 {
+            let entered_tx = entered_tx.clone();
+            let release_rx = Arc::clone(&release_rx);
+            handle.queue_task(ProofTaskKind::Test(Box::new(move || {
+                entered_tx.send(id).unwrap();
+                release_rx.recv().unwrap();
+            })));
+        }
+
+        {
+            let entered_tx = entered_tx.clone();
+            let release_rx = Arc::clone(&release_rx);
+            handle.queue_task(ProofTaskKind::Test(Box::new(move || {
+                entered_tx.send(2).unwrap();
+                release_rx.recv().unwrap();
+            })));
+        }
+
+        drop(entered_tx);
+
+        let first =
+            entered_rx.recv_timeout(Duration::from_secs(1)).expect("first task not started");
+        let second =
+            entered_rx.recv_timeout(Duration::from_secs(1)).expect("second task not started");
+        assert_ne!(first, second, "tasks should be executed by distinct workers");
+
+        assert!(
+            entered_rx.recv_timeout(Duration::from_millis(200)).is_err(),
+            "third task started before workers were released"
+        );
+
+        release_tx.send(()).unwrap();
+
+        let third =
+            entered_rx.recv_timeout(Duration::from_secs(1)).expect("third task never started");
+        assert_eq!(third, 2);
+
+        release_tx.send(()).unwrap();
+        release_tx.send(()).unwrap();
+
+        drop(handle);
+        drop(release_tx);
+        drop(release_rx);
+        drop(runtime);
+    }
+}
diff --git a/crates/trie/parallel/src/proof_task_metrics.rs b/crates/trie/parallel/src/proof_task_metrics.rs
index cdb59d078d8..196fa5753ed 100644
--- a/crates/trie/parallel/src/proof_task_metrics.rs
+++ b/crates/trie/parallel/src/proof_task_metrics.rs
@@ -1,21 +1,35 @@
 use reth_metrics::{metrics::Histogram, Metrics};
+use std::sync::atomic::{AtomicUsize, Ordering};

 /// Metrics for blinded node fetching for the duration of the proof task manager.
-#[derive(Clone, Debug, Default)]
+/// `AtomicUsize` because we want to be able to add to these counts from concurrent workers.
+#[derive(Debug, Default)]
 pub struct ProofTaskMetrics {
     /// The actual metrics for blinded nodes.
     pub task_metrics: ProofTaskTrieMetrics,
     /// Count of blinded account node requests.
-    pub account_nodes: usize,
+    pub account_nodes: AtomicUsize,
     /// Count of blinded storage node requests.
-    pub storage_nodes: usize,
+    pub storage_nodes: AtomicUsize,
+}
+
+/// Implements `Clone` for `ProofTaskMetrics`.
+/// Uses `Ordering::Relaxed` since the counters are independent and need no cross-thread ordering.
+impl Clone for ProofTaskMetrics {
+    fn clone(&self) -> Self {
+        Self {
+            task_metrics: self.task_metrics.clone(),
+            account_nodes: AtomicUsize::new(self.account_nodes.load(Ordering::Relaxed)),
+            storage_nodes: AtomicUsize::new(self.storage_nodes.load(Ordering::Relaxed)),
+        }
+    }
 }

 impl ProofTaskMetrics {
     /// Record the blinded node counts into the histograms.
     pub fn record(&self) {
-        self.task_metrics.record_account_nodes(self.account_nodes);
-        self.task_metrics.record_storage_nodes(self.storage_nodes);
+        self.task_metrics.record_account_nodes(self.account_nodes.load(Ordering::Relaxed));
+        self.task_metrics.record_storage_nodes(self.storage_nodes.load(Ordering::Relaxed));
     }
 }
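
End-to-end, the pieces above compose as follows. This is a condensed usage sketch for reviewers, not code from this diff: it borrows the test scaffolding (`create_test_provider_factory`, the all-default `ProofTaskCtx`), the `demo` function name is hypothetical, and the `Default` values used to build the proof input are illustrative assumptions.

    use crossbeam_channel::unbounded;
    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
    use reth_trie_parallel::proof_task::{
        new_proof_task_handle, ProofTaskCtx, ProofTaskKind, StorageProofInput,
    };
    use tokio::runtime::Runtime;

    fn demo() -> Result<(), Box<dyn std::error::Error>> {
        let rt = Runtime::new()?;
        let view = ConsistentDbView::new(create_test_provider_factory(), None);
        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());

        // Four workers are spawned upfront, each pre-warmed with its own read-only tx;
        // creation is now fallible because every worker opens a provider transaction.
        let handle = new_proof_task_handle(rt.handle().clone(), view, task_ctx, 4)?;

        // Queue a storage proof; the result arrives over a crossbeam channel, and
        // computation errors travel inside the message rather than through a manager loop.
        let (sender, receiver) = unbounded();
        let input = StorageProofInput::new(
            Default::default(), // hashed_address
            Default::default(), // prefix_set
            Default::default(), // target_slots
            false,              // with_branch_node_masks
            None,               // multi_added_removed_keys
        );
        handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        let _proof = receiver.recv()??;

        // Dropping the last handle records metrics; workers exit once all senders are gone.
        drop(handle);
        Ok(())
    }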