diff --git a/crates/engine/primitives/src/config.rs b/crates/engine/primitives/src/config.rs
index 70763b6701f..9e2c8210f08 100644
--- a/crates/engine/primitives/src/config.rs
+++ b/crates/engine/primitives/src/config.rs
@@ -9,11 +9,14 @@ pub const DEFAULT_MEMORY_BLOCK_BUFFER_TARGET: u64 = 0;
/// Default maximum concurrency for on-demand proof tasks (blinded nodes)
pub const DEFAULT_MAX_PROOF_TASK_CONCURRENCY: u64 = 256;
+/// Floor for explicitly configured worker counts; also caps the default storage worker count.
+pub const MIN_WORKER_COUNT: usize = 32;
+
/// Returns the default number of storage worker threads based on available parallelism.
fn default_storage_worker_count() -> usize {
#[cfg(feature = "std")]
{
- std::thread::available_parallelism().map(|n| (n.get() * 2).clamp(2, 64)).unwrap_or(8)
+ std::thread::available_parallelism().map_or(8, |n| n.get() * 2).min(MIN_WORKER_COUNT)
}
#[cfg(not(feature = "std"))]
{
@@ -491,8 +494,8 @@ impl TreeConfig {
}
/// Setter for the number of storage proof worker threads.
+ ///
+ /// Values below [`MIN_WORKER_COUNT`] are raised to [`MIN_WORKER_COUNT`].
- pub const fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
- self.storage_worker_count = storage_worker_count;
+ pub fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
+ self.storage_worker_count = storage_worker_count.max(MIN_WORKER_COUNT);
self
}
@@ -502,8 +505,8 @@ impl TreeConfig {
}
/// Setter for the number of account proof worker threads.
+ ///
+ /// Values below [`MIN_WORKER_COUNT`] are raised to [`MIN_WORKER_COUNT`].
- pub const fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
- self.account_worker_count = account_worker_count;
+ pub fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
+ self.account_worker_count = account_worker_count.max(MIN_WORKER_COUNT);
self
}
}
diff --git a/crates/engine/tree/src/tree/payload_processor/mod.rs b/crates/engine/tree/src/tree/payload_processor/mod.rs
index c24b0d1fe16..f3ecdfa86d5 100644
--- a/crates/engine/tree/src/tree/payload_processor/mod.rs
+++ b/crates/engine/tree/src/tree/payload_processor/mod.rs
@@ -32,7 +32,7 @@ use reth_provider::{
use reth_revm::{db::BundleState, state::EvmState};
use reth_trie::TrieInput;
use reth_trie_parallel::{
- proof_task::{ProofTaskCtx, ProofTaskManager},
+ proof_task::{ProofTaskCtx, ProofWorkerHandle},
root::ParallelStateRootError,
};
use reth_trie_sparse::{
@@ -167,8 +167,7 @@ where
/// This returns a handle to await the final state root and to interact with the tasks (e.g.
/// canceling)
///
- /// Returns an error with the original transactions iterator if the proof task manager fails to
- /// initialize.
+ /// Returns an error with the original transactions iterator if proof worker spawning fails.
#[allow(clippy::type_complexity)]
pub fn spawn
>(
&mut self,
@@ -204,14 +203,14 @@ where
let storage_worker_count = config.storage_worker_count();
let account_worker_count = config.account_worker_count();
let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
- let proof_task = match ProofTaskManager::new(
+ let proof_handle = match ProofWorkerHandle::new(
self.executor.handle().clone(),
consistent_view,
task_ctx,
storage_worker_count,
account_worker_count,
) {
- Ok(task) => task,
+ Ok(handle) => handle,
Err(error) => {
return Err((error, transactions, env, provider_builder));
}
@@ -223,7 +222,7 @@ where
let multi_proof_task = MultiProofTask::new(
state_root_config,
self.executor.clone(),
- proof_task.handle(),
+ proof_handle.clone(),
to_sparse_trie,
max_multi_proof_task_concurrency,
config.multiproof_chunking_enabled().then_some(config.multiproof_chunk_size()),
@@ -252,19 +251,7 @@ where
let (state_root_tx, state_root_rx) = channel();
// Spawn the sparse trie task using any stored trie and parallel trie configuration.
- self.spawn_sparse_trie_task(sparse_trie_rx, proof_task.handle(), state_root_tx);
-
- // spawn the proof task
- self.executor.spawn_blocking(move || {
- if let Err(err) = proof_task.run() {
- // At least log if there is an error at any point
- tracing::error!(
- target: "engine::root",
- ?err,
- "Storage proof task returned an error"
- );
- }
- });
+ self.spawn_sparse_trie_task(sparse_trie_rx, proof_handle, state_root_tx);
Ok(PayloadHandle {
to_multi_proof,
@@ -406,7 +393,7 @@ where
fn spawn_sparse_trie_task(
&self,
sparse_trie_rx: mpsc::Receiver,
- proof_task_handle: BPF,
+ proof_worker_handle: BPF,
state_root_tx: mpsc::Sender>,
) where
BPF: TrieNodeProviderFactory + Clone + Send + Sync + 'static,
@@ -436,7 +423,7 @@ where
let task =
SparseTrieTask::<_, ConfiguredSparseTrie, ConfiguredSparseTrie>::new_with_cleared_trie(
sparse_trie_rx,
- proof_task_handle,
+ proof_worker_handle,
self.trie_metrics.clone(),
sparse_state_trie,
);
diff --git a/crates/engine/tree/src/tree/payload_processor/multiproof.rs b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
index f865312b83d..4a71bf620f7 100644
--- a/crates/engine/tree/src/tree/payload_processor/multiproof.rs
+++ b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
@@ -20,7 +20,7 @@ use reth_trie::{
};
use reth_trie_parallel::{
proof::ParallelProof,
- proof_task::{AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle},
+ proof_task::{AccountMultiproofInput, ProofWorkerHandle},
root::ParallelStateRootError,
};
use std::{
@@ -346,11 +346,8 @@ pub struct MultiproofManager {
pending: VecDeque,
/// Executor for tasks
executor: WorkloadExecutor,
- /// Handle to the proof task manager used for creating `ParallelProof` instances for storage
- /// proofs.
- storage_proof_task_handle: ProofTaskManagerHandle,
- /// Handle to the proof task manager used for account multiproofs.
- account_proof_task_handle: ProofTaskManagerHandle,
+ /// Handle to the proof worker pools (storage and account).
+ proof_worker_handle: ProofWorkerHandle,
/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
///
@@ -372,8 +369,7 @@ impl MultiproofManager {
fn new(
executor: WorkloadExecutor,
metrics: MultiProofTaskMetrics,
- storage_proof_task_handle: ProofTaskManagerHandle,
- account_proof_task_handle: ProofTaskManagerHandle,
+ proof_worker_handle: ProofWorkerHandle,
max_concurrent: usize,
) -> Self {
Self {
@@ -382,8 +378,7 @@ impl MultiproofManager {
executor,
inflight: 0,
metrics,
- storage_proof_task_handle,
- account_proof_task_handle,
+ proof_worker_handle,
missed_leaves_storage_roots: Default::default(),
}
}
@@ -452,7 +447,7 @@ impl MultiproofManager {
multi_added_removed_keys,
} = storage_multiproof_input;
- let storage_proof_task_handle = self.storage_proof_task_handle.clone();
+ let storage_proof_worker_handle = self.proof_worker_handle.clone();
let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
self.executor.spawn_blocking(move || {
@@ -471,7 +466,7 @@ impl MultiproofManager {
config.state_sorted,
config.prefix_sets,
missed_leaves_storage_roots,
- storage_proof_task_handle,
+ storage_proof_worker_handle,
)
.with_branch_node_masks(true)
.with_multi_added_removed_keys(Some(multi_added_removed_keys))
@@ -524,7 +519,7 @@ impl MultiproofManager {
state_root_message_sender,
multi_added_removed_keys,
} = multiproof_input;
- let account_proof_task_handle = self.account_proof_task_handle.clone();
+ let account_proof_worker_handle = self.proof_worker_handle.clone();
let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
self.executor.spawn_blocking(move || {
@@ -556,15 +551,10 @@ impl MultiproofManager {
missed_leaves_storage_roots,
};
- let (sender, receiver) = channel();
let proof_result: Result = (|| {
- account_proof_task_handle
- .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
- .map_err(|_| {
- ParallelStateRootError::Other(
- "Failed to queue account multiproof to worker pool".into(),
- )
- })?;
+ let receiver = account_proof_worker_handle
+ .queue_account_multiproof(input)
+ .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
receiver
.recv()
@@ -693,7 +683,7 @@ impl MultiProofTask {
pub(super) fn new(
config: MultiProofConfig,
executor: WorkloadExecutor,
- proof_task_handle: ProofTaskManagerHandle,
+ proof_worker_handle: ProofWorkerHandle,
to_sparse_trie: Sender,
max_concurrency: usize,
chunk_size: Option,
@@ -713,8 +703,7 @@ impl MultiProofTask {
multiproof_manager: MultiproofManager::new(
executor,
metrics.clone(),
- proof_task_handle.clone(), // handle for storage proof workers
- proof_task_handle, // handle for account proof workers
+ proof_worker_handle,
max_concurrency,
),
metrics,
@@ -1223,7 +1212,7 @@ mod tests {
DatabaseProviderFactory,
};
use reth_trie::{MultiProof, TrieInput};
- use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
+ use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofWorkerHandle};
use revm_primitives::{B256, U256};
fn create_test_state_root_task(factory: F) -> MultiProofTask
@@ -1238,12 +1227,12 @@ mod tests {
config.prefix_sets.clone(),
);
let consistent_view = ConsistentDbView::new(factory, None);
- let proof_task =
- ProofTaskManager::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
- .expect("Failed to create ProofTaskManager");
+ let proof_handle =
+ ProofWorkerHandle::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
+ .expect("Failed to spawn proof workers");
let channel = channel();
- MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)
+ MultiProofTask::new(config, executor, proof_handle, channel.0, 1, None)
}
#[test]
diff --git a/crates/engine/tree/src/tree/payload_validator.rs b/crates/engine/tree/src/tree/payload_validator.rs
index 51e669b8883..17dc511a445 100644
--- a/crates/engine/tree/src/tree/payload_validator.rs
+++ b/crates/engine/tree/src/tree/payload_validator.rs
@@ -892,13 +892,12 @@ where
(handle, StateRootStrategy::StateRootTask)
}
Err((error, txs, env, provider_builder)) => {
- // Failed to initialize proof task manager, fallback to parallel state
- // root
+ // Failed to spawn proof workers, fallback to parallel state root
error!(
target: "engine::tree",
block=?block_num_hash,
?error,
- "Failed to initialize proof task manager, falling back to parallel state root"
+ "Failed to spawn proof workers, falling back to parallel state root"
);
(
self.payload_processor.spawn_cache_exclusive(
diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs
index 7fc1f022a7e..0f29502f8c7 100644
--- a/crates/trie/parallel/src/proof.rs
+++ b/crates/trie/parallel/src/proof.rs
@@ -1,8 +1,6 @@
use crate::{
metrics::ParallelTrieMetrics,
- proof_task::{
- AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle, StorageProofInput,
- },
+ proof_task::{AccountMultiproofInput, ProofWorkerHandle, StorageProofInput},
root::ParallelStateRootError,
StorageRootTargets,
};
@@ -16,10 +14,7 @@ use reth_trie::{
DecodedMultiProof, DecodedStorageMultiProof, HashedPostStateSorted, MultiProofTargets, Nibbles,
};
use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
-use std::sync::{
- mpsc::{channel, Receiver},
- Arc,
-};
+use std::sync::{mpsc::Receiver, Arc};
use tracing::trace;
/// Parallel proof calculator.
@@ -41,8 +36,8 @@ pub struct ParallelProof {
collect_branch_node_masks: bool,
/// Provided by the user to give the necessary context to retain extra proofs.
multi_added_removed_keys: Option>,
- /// Handle to the proof task manager.
- proof_task_handle: ProofTaskManagerHandle,
+ /// Handle to the proof worker pools.
+ proof_worker_handle: ProofWorkerHandle,
/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
missed_leaves_storage_roots: Arc>,
@@ -57,7 +52,7 @@ impl ParallelProof {
state_sorted: Arc,
prefix_sets: Arc,
missed_leaves_storage_roots: Arc>,
- proof_task_handle: ProofTaskManagerHandle,
+ proof_worker_handle: ProofWorkerHandle,
) -> Self {
Self {
nodes_sorted,
@@ -66,7 +61,7 @@ impl ParallelProof {
missed_leaves_storage_roots,
collect_branch_node_masks: false,
multi_added_removed_keys: None,
- proof_task_handle,
+ proof_worker_handle,
#[cfg(feature = "metrics")]
metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
}
@@ -93,7 +88,10 @@ impl ParallelProof {
hashed_address: B256,
prefix_set: PrefixSet,
target_slots: B256Set,
- ) -> Receiver> {
+ ) -> Result<
+ Receiver>,
+ ParallelStateRootError,
+ > {
let input = StorageProofInput::new(
hashed_address,
prefix_set,
@@ -102,9 +100,9 @@ impl ParallelProof {
self.multi_added_removed_keys.clone(),
);
- let (sender, receiver) = std::sync::mpsc::channel();
- let _ = self.proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
- receiver
+ self.proof_worker_handle
+ .queue_storage_proof(input)
+ .map_err(|e| ParallelStateRootError::Other(e.to_string()))
}
/// Generate a storage multiproof according to the specified targets and hashed address.
@@ -124,7 +122,7 @@ impl ParallelProof {
"Starting storage proof generation"
);
- let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
+ let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots)?;
let proof_result = receiver.recv().map_err(|_| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
format!("channel closed for {hashed_address}"),
@@ -193,15 +191,10 @@ impl ParallelProof {
missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
};
- let (sender, receiver) = channel();
- self.proof_task_handle
- .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
- .map_err(|_| {
- ParallelStateRootError::Other(
- "Failed to queue account multiproof: account worker pool unavailable"
- .to_string(),
- )
- })?;
+ let receiver = self
+ .proof_worker_handle
+ .queue_account_multiproof(input)
+ .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
// Wait for account multiproof result from worker
let (multiproof, stats) = receiver.recv().map_err(|_| {
@@ -231,7 +224,7 @@ impl ParallelProof {
#[cfg(test)]
mod tests {
use super::*;
- use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
+ use crate::proof_task::{ProofTaskCtx, ProofWorkerHandle};
use alloy_primitives::{
keccak256,
map::{B256Set, DefaultHashBuilder, HashMap},
@@ -313,20 +306,15 @@ mod tests {
let task_ctx =
ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
- let proof_task =
- ProofTaskManager::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
- let proof_task_handle = proof_task.handle();
-
- // keep the join handle around to make sure it does not return any errors
- // after we compute the state root
- let join_handle = rt.spawn_blocking(move || proof_task.run());
+ let proof_worker_handle =
+ ProofWorkerHandle::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
let parallel_result = ParallelProof::new(
Default::default(),
Default::default(),
Default::default(),
Default::default(),
- proof_task_handle.clone(),
+ proof_worker_handle.clone(),
)
.decoded_multiproof(targets.clone())
.unwrap();
@@ -354,9 +342,7 @@ mod tests {
// then compare the entire thing for any mask differences
assert_eq!(parallel_result, sequential_result_decoded);
- // drop the handle to terminate the task and then block on the proof task handle to make
- // sure it does not return any errors
- drop(proof_task_handle);
- rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
+ // Workers shut down automatically once the last handle is dropped
+ drop(proof_worker_handle);
}
}
diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs
index 780839c238a..2d0f7e933c8 100644
--- a/crates/trie/parallel/src/proof_task.rs
+++ b/crates/trie/parallel/src/proof_task.rs
@@ -1,9 +1,14 @@
-//! A Task that manages sending proof requests to a number of tasks that have longer-running
-//! database transactions.
+//! Parallel proof computation using worker pools with dedicated database transactions.
//!
-//! The [`ProofTaskManager`] ensures that there are a max number of currently executing proof tasks,
-//! and is responsible for managing the fixed number of database transactions created at the start
-//! of the task.
+//!
+//! # Architecture
+//!
+//! - **Worker Pools**: Pre-spawned workers with dedicated database transactions
+//!   - Storage pool: Handles storage proofs and blinded storage node requests
+//!   - Account pool: Handles account multiproofs and blinded account node requests
+//! - **Direct Channel Access**: [`ProofWorkerHandle`] provides type-safe queue methods with direct
+//!   access to worker channels, eliminating routing overhead
+//! - **Automatic Shutdown**: Workers terminate gracefully when all handles are dropped
//!
//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
@@ -21,7 +26,7 @@ use alloy_rlp::{BufMut, Encodable};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use dashmap::DashMap;
use reth_db_api::transaction::DbTx;
-use reth_execution_errors::SparseTrieError;
+use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
ProviderResult,
@@ -47,7 +52,6 @@ use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
use std::{
sync::{
- atomic::{AtomicUsize, Ordering},
mpsc::{channel, Receiver, Sender},
Arc,
},
@@ -57,7 +61,7 @@ use tokio::runtime::Handle;
use tracing::trace;
#[cfg(feature = "metrics")]
-use crate::proof_task_metrics::ProofTaskMetrics;
+use crate::proof_task_metrics::ProofTaskTrieMetrics;
type StorageProofResult = Result;
type TrieNodeProviderResult = Result