perf(tree): worker pooling for storage in multiproof generation (#18887)

Co-authored-by: Brian Picciano <me@mediocregopher.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com>
This commit is contained in:
YK
2025-10-10 15:58:15 +08:00
committed by GitHub
parent d2070f4de3
commit 397a30defb
11 changed files with 626 additions and 197 deletions

1
Cargo.lock generated
View File

@@ -10739,6 +10739,7 @@ dependencies = [
"alloy-primitives",
"alloy-rlp",
"codspeed-criterion-compat",
"crossbeam-channel",
"dashmap 6.1.0",
"derive_more",
"itertools 0.14.0",

View File

@@ -6,9 +6,21 @@ pub const DEFAULT_PERSISTENCE_THRESHOLD: u64 = 2;
/// How close to the canonical head we persist blocks.
pub const DEFAULT_MEMORY_BLOCK_BUFFER_TARGET: u64 = 0;
/// Default maximum concurrency for proof tasks
/// Default maximum concurrency for on-demand proof tasks (blinded nodes)
pub const DEFAULT_MAX_PROOF_TASK_CONCURRENCY: u64 = 256;
/// Picks the default size of the storage proof worker pool.
///
/// With the `std` feature enabled this is twice the parallelism reported by
/// [`std::thread::available_parallelism`], kept within `2..=64`. When that query fails, or when
/// the `std` feature is disabled, the default is `8`.
fn default_storage_worker_count() -> usize {
    #[cfg(feature = "std")]
    {
        match std::thread::available_parallelism() {
            Ok(parallelism) => (parallelism.get() * 2).clamp(2, 64),
            Err(_) => 8,
        }
    }
    #[cfg(not(feature = "std"))]
    {
        8
    }
}
/// The size of proof targets chunk to spawn in one multiproof calculation.
pub const DEFAULT_MULTIPROOF_TASK_CHUNK_SIZE: usize = 10;
@@ -109,6 +121,8 @@ pub struct TreeConfig {
prewarm_max_concurrency: usize,
/// Whether to unwind canonical header to ancestor during forkchoice updates.
allow_unwind_canonical_header: bool,
/// Number of storage proof worker threads.
storage_worker_count: usize,
}
impl Default for TreeConfig {
@@ -135,6 +149,7 @@ impl Default for TreeConfig {
always_process_payload_attributes_on_canonical_head: false,
prewarm_max_concurrency: DEFAULT_PREWARM_MAX_CONCURRENCY,
allow_unwind_canonical_header: false,
storage_worker_count: default_storage_worker_count(),
}
}
}
@@ -164,7 +179,9 @@ impl TreeConfig {
always_process_payload_attributes_on_canonical_head: bool,
prewarm_max_concurrency: usize,
allow_unwind_canonical_header: bool,
storage_worker_count: usize,
) -> Self {
assert!(max_proof_task_concurrency > 0, "max_proof_task_concurrency must be at least 1");
Self {
persistence_threshold,
memory_block_buffer_target,
@@ -187,6 +204,7 @@ impl TreeConfig {
always_process_payload_attributes_on_canonical_head,
prewarm_max_concurrency,
allow_unwind_canonical_header,
storage_worker_count,
}
}
@@ -394,6 +412,7 @@ impl TreeConfig {
mut self,
max_proof_task_concurrency: u64,
) -> Self {
assert!(max_proof_task_concurrency > 0, "max_proof_task_concurrency must be at least 1");
self.max_proof_task_concurrency = max_proof_task_concurrency;
self
}
@@ -452,4 +471,15 @@ impl TreeConfig {
pub const fn prewarm_max_concurrency(&self) -> usize {
self.prewarm_max_concurrency
}
/// Returns the configured number of storage proof worker threads.
pub const fn storage_worker_count(&self) -> usize {
    self.storage_worker_count
}
/// Setter for the number of storage proof worker threads.
///
/// A requested count of `0` is raised to `1`. A pool without any worker cannot service a
/// single storage proof or blinded storage node request, so zero is never a usable setting.
pub const fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
    // `Ord::max` is not callable in a `const fn`, hence the explicit branch.
    self.storage_worker_count = if storage_worker_count == 0 { 1 } else { storage_worker_count };
    self
}
}

View File

@@ -228,16 +228,22 @@ fn bench_state_root(c: &mut Criterion) {
},
|(genesis_hash, mut payload_processor, provider, state_updates)| {
black_box({
let mut handle = payload_processor.spawn(
Default::default(),
core::iter::empty::<
Result<Recovered<TransactionSigned>, core::convert::Infallible>,
>(),
StateProviderBuilder::new(provider.clone(), genesis_hash, None),
ConsistentDbView::new_with_latest_tip(provider).unwrap(),
TrieInput::default(),
&TreeConfig::default(),
);
let mut handle = payload_processor
.spawn(
Default::default(),
core::iter::empty::<
Result<
Recovered<TransactionSigned>,
core::convert::Infallible,
>,
>(),
StateProviderBuilder::new(provider.clone(), genesis_hash, None),
ConsistentDbView::new_with_latest_tip(provider).unwrap(),
TrieInput::default(),
&TreeConfig::default(),
)
.map_err(|(err, ..)| err)
.expect("failed to spawn payload processor");
let mut state_hook = handle.state_hook();

View File

@@ -45,7 +45,7 @@ use std::sync::{
mpsc::{self, channel, Sender},
Arc,
};
use tracing::{debug, instrument};
use tracing::{debug, instrument, warn};
mod configured_sparse_trie;
pub mod executor;
@@ -166,6 +166,10 @@ where
///
/// This returns a handle to await the final state root and to interact with the tasks (e.g.
/// canceling)
///
/// Returns an error with the original transactions iterator if the proof task manager fails to
/// initialize.
#[allow(clippy::type_complexity)]
pub fn spawn<P, I: ExecutableTxIterator<Evm>>(
&mut self,
env: ExecutionEnv<Evm>,
@@ -174,7 +178,10 @@ where
consistent_view: ConsistentDbView<P>,
trie_input: TrieInput,
config: &TreeConfig,
) -> PayloadHandle<WithTxEnv<TxEnvFor<Evm>, I::Tx>, I::Error>
) -> Result<
PayloadHandle<WithTxEnv<TxEnvFor<Evm>, I::Tx>, I::Error>,
(reth_provider::ProviderError, I, ExecutionEnv<Evm>, StateProviderBuilder<N, P>),
>
where
P: DatabaseProviderFactory<Provider: BlockReader>
+ BlockReader
@@ -196,12 +203,19 @@ where
state_root_config.prefix_sets.clone(),
);
let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
let proof_task = ProofTaskManager::new(
let storage_worker_count = config.storage_worker_count();
let proof_task = match ProofTaskManager::new(
self.executor.handle().clone(),
state_root_config.consistent_view.clone(),
task_ctx,
max_proof_task_concurrency,
);
storage_worker_count,
) {
Ok(task) => task,
Err(error) => {
return Err((error, transactions, env, provider_builder));
}
};
// We set it to half of the proof task concurrency, because often for each multiproof we
// spawn one Tokio task for the account proof, and one Tokio task for the storage proof.
@@ -252,12 +266,12 @@ where
}
});
PayloadHandle {
Ok(PayloadHandle {
to_multi_proof,
prewarm_handle,
state_root: Some(state_root_rx),
transactions: execution_rx,
}
})
}
/// Spawns a task that exclusively handles cache prewarming for transaction execution.
@@ -857,14 +871,20 @@ mod tests {
PrecompileCacheMap::default(),
);
let provider = BlockchainProvider::new(factory).unwrap();
let mut handle = payload_processor.spawn(
Default::default(),
core::iter::empty::<Result<Recovered<TransactionSigned>, core::convert::Infallible>>(),
StateProviderBuilder::new(provider.clone(), genesis_hash, None),
ConsistentDbView::new_with_latest_tip(provider).unwrap(),
TrieInput::from_state(hashed_state),
&TreeConfig::default(),
);
let mut handle =
payload_processor
.spawn(
Default::default(),
core::iter::empty::<
Result<Recovered<TransactionSigned>, core::convert::Infallible>,
>(),
StateProviderBuilder::new(provider.clone(), genesis_hash, None),
ConsistentDbView::new_with_latest_tip(provider).unwrap(),
TrieInput::from_state(hashed_state),
&TreeConfig::default(),
)
.map_err(|(err, ..)| err)
.expect("failed to spawn payload processor");
let mut state_hook = handle.state_hook();

View File

@@ -1236,7 +1236,9 @@ mod tests {
config.consistent_view.clone(),
task_ctx,
1,
);
1,
)
.expect("Failed to create ProofTaskManager");
let channel = channel();
MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)

View File

@@ -877,17 +877,37 @@ where
// too expensive because it requires walking all paths in every proof.
let spawn_start = Instant::now();
let (handle, strategy) = if trie_input.prefix_sets.is_empty() {
(
self.payload_processor.spawn(
env,
txs,
provider_builder,
consistent_view,
trie_input,
&self.config,
),
StateRootStrategy::StateRootTask,
)
match self.payload_processor.spawn(
env,
txs,
provider_builder,
consistent_view,
trie_input,
&self.config,
) {
Ok(handle) => {
// Successfully spawned with state root task support
(handle, StateRootStrategy::StateRootTask)
}
Err((error, txs, env, provider_builder)) => {
// Failed to initialize proof task manager, fallback to parallel state
// root
error!(
target: "engine::tree",
block=?block_num_hash,
?error,
"Failed to initialize proof task manager, falling back to parallel state root"
);
(
self.payload_processor.spawn_cache_exclusive(
env,
txs,
provider_builder,
),
StateRootStrategy::Parallel,
)
}
}
// if prefix sets are not empty, we spawn a task that exclusively handles cache
// prewarming for transaction execution
} else {

View File

@@ -108,6 +108,11 @@ pub struct EngineArgs {
/// See `TreeConfig::unwind_canonical_header` for more details.
#[arg(long = "engine.allow-unwind-canonical-header", default_value = "false")]
pub allow_unwind_canonical_header: bool,
/// Configure the number of storage proof workers in the Tokio blocking pool.
/// If not specified, defaults to 2x available parallelism, clamped between 2 and 64.
#[arg(long = "engine.storage-worker-count")]
pub storage_worker_count: Option<usize>,
}
#[allow(deprecated)]
@@ -134,6 +139,7 @@ impl Default for EngineArgs {
state_root_fallback: false,
always_process_payload_attributes_on_canonical_head: false,
allow_unwind_canonical_header: false,
storage_worker_count: None,
}
}
}
@@ -141,7 +147,7 @@ impl Default for EngineArgs {
impl EngineArgs {
/// Creates a [`TreeConfig`] from the engine arguments.
pub fn tree_config(&self) -> TreeConfig {
TreeConfig::default()
let mut config = TreeConfig::default()
.with_persistence_threshold(self.persistence_threshold)
.with_memory_block_buffer_target(self.memory_block_buffer_target)
.with_legacy_state_root(self.legacy_state_root_task_enabled)
@@ -159,7 +165,13 @@ impl EngineArgs {
.with_always_process_payload_attributes_on_canonical_head(
self.always_process_payload_attributes_on_canonical_head,
)
.with_unwind_canonical_header(self.allow_unwind_canonical_header)
.with_unwind_canonical_header(self.allow_unwind_canonical_header);
if let Some(count) = self.storage_worker_count {
config = config.with_storage_worker_count(count);
}
config
}
}

View File

@@ -36,6 +36,7 @@ derive_more.workspace = true
rayon.workspace = true
itertools.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread"] }
crossbeam-channel.workspace = true
# `metrics` feature
reth-metrics = { workspace = true, optional = true }

View File

@@ -448,7 +448,8 @@ mod tests {
let task_ctx =
ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
let proof_task =
ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1, 1)
.unwrap();
let proof_task_handle = proof_task.handle();
// keep the join handle around to make sure it does not return any errors

View File

@@ -10,17 +10,18 @@
use crate::root::ParallelStateRootError;
use alloy_primitives::{map::B256Set, B256};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use reth_db_api::transaction::DbTx;
use reth_execution_errors::SparseTrieError;
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
ProviderResult,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
prefix_set::TriePrefixSetsMut,
proof::{ProofTrieNodeProviderFactory, StorageProof},
trie_cursor::InMemoryTrieCursorFactory,
trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
updates::TrieUpdatesSorted,
DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
};
@@ -40,7 +41,7 @@ use std::{
time::Instant,
};
use tokio::runtime::Handle;
use tracing::{debug, trace};
use tracing::trace;
#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskMetrics;
@@ -48,65 +49,333 @@ use crate::proof_task_metrics::ProofTaskMetrics;
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
/// A task that manages sending multiproof requests to a number of tasks that have longer-running
/// database transactions
/// A unit of work handed to a storage worker.
///
/// This type stays private to the module. External callers submit
/// `ProofTaskKind::StorageProof` or `ProofTaskKind::BlindedStorageNode` through the manager's
/// `std::mpsc` channel, and the manager turns those into a `StorageWorkerJob`.
#[derive(Debug)]
enum StorageWorkerJob {
    /// Compute a storage proof.
    StorageProof {
        /// Parameters of the storage proof to compute.
        input: StorageProofInput,
        /// Where the worker delivers the proof (or the error) for the original caller.
        result_sender: Sender<StorageProofResult>,
    },
    /// Look up a blinded storage trie node.
    BlindedStorageNode {
        /// Account whose storage trie is queried.
        account: B256,
        /// Path of the node within that storage trie.
        path: Nibbles,
        /// Where the worker delivers the node (or the error) for the original caller.
        result_sender: Sender<TrieNodeProviderResult>,
    },
}
impl StorageWorkerJob {
    /// Tells the job's original caller that the storage worker pool cannot take the job.
    ///
    /// The error is delivered through the job's own `result_sender`. For blinded storage node
    /// jobs it is wrapped into a [`SparseTrieError`] so it matches that channel's result type.
    ///
    /// Returns `Ok(())` when the error was handed to the channel and `Err(())` when the
    /// receiving side has already been dropped.
    fn send_worker_unavailable_error(&self) -> Result<(), ()> {
        let unavailable =
            ParallelStateRootError::Other("Storage proof worker pool unavailable".to_string());

        let delivered = match self {
            Self::StorageProof { result_sender, .. } => {
                result_sender.send(Err(unavailable)).is_ok()
            }
            Self::BlindedStorageNode { result_sender, .. } => {
                let wrapped =
                    SparseTrieError::from(SparseTrieErrorKind::Other(Box::new(unavailable)));
                result_sender.send(Err(wrapped)).is_ok()
            }
        };

        if delivered {
            Ok(())
        } else {
            Err(())
        }
    }
}
/// Manager for coordinating proof request execution across different task types.
///
/// # Architecture
///
/// This manager handles two distinct execution paths:
///
/// 1. **Storage Worker Pool** (for storage trie operations):
/// - Pre-spawned workers with dedicated long-lived transactions
/// - Handles `StorageProof` and `BlindedStorageNode` requests
/// - Tasks queued via crossbeam unbounded channel
/// - Workers continuously process without transaction overhead
/// - Unbounded queue ensures all storage proofs benefit from transaction reuse
///
/// 2. **On-Demand Execution** (for account trie operations):
/// - Lazy transaction creation for `BlindedAccountNode` requests
/// - Transactions returned to pool after use for reuse
///
/// # Public Interface
///
/// The public interface through `ProofTaskManagerHandle` allows external callers to:
/// - Submit tasks via `queue_task(ProofTaskKind)`
/// - Use standard `std::mpsc` message passing
/// - Receive consistent return types and error handling
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
/// Max number of database transactions to create
/// Sender for storage worker jobs to worker pool.
storage_work_tx: CrossbeamSender<StorageWorkerJob>,
/// Number of storage workers successfully spawned.
///
/// Equal to the count requested in `new`: if creating the transaction for any worker fails,
/// construction returns an error instead of continuing with fewer workers.
storage_worker_count: usize,
/// Max number of database transactions to create for on-demand account trie operations.
max_concurrency: usize,
/// Number of database transactions created
/// Number of database transactions created for on-demand operations.
total_transactions: usize,
/// Consistent view provider used for creating transactions on-demand
view: ConsistentDbView<Factory>,
/// Proof task context shared across all proof tasks
task_ctx: ProofTaskCtx,
/// Proof tasks pending execution
/// Proof tasks pending execution (account trie operations only).
pending_tasks: VecDeque<ProofTaskKind>,
/// The underlying handle from which to spawn proof tasks
executor: Handle,
/// The proof task transactions, containing owned cursor factories that are reused for proof
/// calculation.
/// calculation (account trie operations only).
proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
/// A receiver for new proof tasks.
/// Consistent view provider used for creating transactions on-demand.
view: ConsistentDbView<Factory>,
/// Proof task context shared across all proof tasks.
task_ctx: ProofTaskCtx,
/// The underlying handle from which to spawn proof tasks.
executor: Handle,
/// Receives proof task requests from [`ProofTaskManagerHandle`].
proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
/// A sender for sending back transactions.
/// Internal channel for on-demand tasks to return transactions after use.
tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
/// The number of active handles.
///
/// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
/// [`ProofTaskManagerHandle::drop`].
active_handles: Arc<AtomicUsize>,
/// Metrics tracking blinded node fetches.
/// Metrics tracking proof task operations.
#[cfg(feature = "metrics")]
metrics: ProofTaskMetrics,
}
impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
/// Creates a new [`ProofTaskManager`] with the given max concurrency, creating that number of
/// cursor factories.
/// Body of a single storage worker: serves storage trie jobs until the job queue closes.
///
/// # Lifecycle
///
/// 1. Build the trie/hashed cursor factories and the blinded node provider factory once, on top
///    of the transaction owned by `proof_tx`.
/// 2. Block on `work_rx` for the next [`StorageWorkerJob`].
/// 3. Compute the result and send it straight to the job's `result_sender`. A dropped receiver
///    is only logged; the worker keeps going.
/// 4. Return once `work_rx` reports that the channel is empty and every sender is gone.
///
/// # Parameters
///
/// - `proof_tx`: the transaction wrapper this worker uses for every job it processes.
/// - `work_rx`: receiving end of the crossbeam job queue.
/// - `worker_id`: identifier that is attached to every log line emitted by this worker.
fn storage_worker_loop<Tx>(
    proof_tx: ProofTaskTx<Tx>,
    work_rx: CrossbeamReceiver<StorageWorkerJob>,
    worker_id: usize,
) where
    Tx: DbTx,
{
    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        "Storage worker started"
    );

    // Built a single time; each job below works on clones of these factories.
    let (trie_cursor_factory, hashed_cursor_factory) = proof_tx.create_factories();

    // Shared by all blinded storage node lookups handled by this worker.
    let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
        trie_cursor_factory.clone(),
        hashed_cursor_factory.clone(),
        proof_tx.task_ctx.prefix_sets.clone(),
    );

    // Per-worker counters, reported in the log lines only.
    let mut proofs_done = 0u64;
    let mut nodes_done = 0u64;

    // The iterator blocks for the next job and ends once the channel is empty and disconnected.
    for job in work_rx.iter() {
        match job {
            StorageWorkerJob::StorageProof { input, result_sender } => {
                // Copied out up front because `input` is moved into the proof computation.
                let hashed_address = input.hashed_address;

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    prefix_set_len = input.prefix_set.len(),
                    target_slots = input.target_slots.len(),
                    "Processing storage proof"
                );

                let started = Instant::now();
                let proof = proof_tx.compute_storage_proof(
                    input,
                    trie_cursor_factory.clone(),
                    hashed_cursor_factory.clone(),
                );
                let took = started.elapsed();
                proofs_done += 1;

                let receiver_gone = result_sender.send(proof).is_err();
                if receiver_gone {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        hashed_address = ?hashed_address,
                        storage_proofs_processed = proofs_done,
                        "Storage proof receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    proof_time_us = took.as_micros(),
                    total_processed = proofs_done,
                    "Storage proof completed"
                );
            }
            StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    "Processing blinded storage node"
                );

                let started = Instant::now();
                let node_provider = blinded_provider_factory.storage_node_provider(account);
                let node = node_provider.trie_node(&path);
                let elapsed = started.elapsed();
                nodes_done += 1;

                let receiver_gone = result_sender.send(node).is_err();
                if receiver_gone {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        ?account,
                        ?path,
                        storage_nodes_processed = nodes_done,
                        "Blinded storage node receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    elapsed_us = elapsed.as_micros(),
                    total_processed = nodes_done,
                    "Blinded storage node completed"
                );
            }
        }
    }

    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        storage_proofs_processed = proofs_done,
        storage_nodes_processed = nodes_done,
        "Storage worker shutting down"
    );
}
impl<Factory> ProofTaskManager<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
/// Creates a new [`ProofTaskManager`] with pre-spawned storage proof workers.
///
/// Returns an error if the consistent view provider fails to create a read-only transaction.
/// The `storage_worker_count` determines how many storage workers to spawn, and
/// `max_concurrency` determines the limit for on-demand operations (blinded account nodes).
/// These are now independent - storage workers are spawned as requested, and on-demand
/// operations use a separate concurrency pool for blinded account nodes.
/// Returns an error if the underlying provider fails to create the transactions required for
/// spawning workers.
pub fn new(
executor: Handle,
view: ConsistentDbView<Factory>,
task_ctx: ProofTaskCtx,
max_concurrency: usize,
) -> Self {
storage_worker_count: usize,
) -> ProviderResult<Self> {
let (tx_sender, proof_task_rx) = channel();
Self {
// Use unbounded channel to ensure all storage operations are queued to workers.
// This maintains transaction reuse benefits and avoids fallback to on-demand execution.
let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
tracing::info!(
target: "trie::proof_task",
storage_worker_count,
max_concurrency,
"Initializing storage worker pool with unbounded queue"
);
let mut spawned_workers = 0;
for worker_id in 0..storage_worker_count {
let provider_ro = view.provider_ro()?;
let tx = provider_ro.into_tx();
let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
let work_rx = storage_work_rx.clone();
executor.spawn_blocking(move || storage_worker_loop(proof_task_tx, work_rx, worker_id));
spawned_workers += 1;
tracing::debug!(
target: "trie::proof_task",
worker_id,
spawned_workers,
"Storage worker spawned successfully"
);
}
Ok(Self {
storage_work_tx,
storage_worker_count: spawned_workers,
max_concurrency,
total_transactions: 0,
pending_tasks: VecDeque::new(),
proof_task_txs: Vec::with_capacity(max_concurrency),
view,
task_ctx,
pending_tasks: VecDeque::new(),
executor,
proof_task_txs: Vec::new(),
proof_task_rx,
tx_sender,
active_handles: Arc::new(AtomicUsize::new(0)),
#[cfg(feature = "metrics")]
metrics: ProofTaskMetrics::default(),
}
})
}
/// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
@@ -158,14 +427,12 @@ where
let tx_sender = self.tx_sender.clone();
self.executor.spawn_blocking(move || match task {
ProofTaskKind::StorageProof(input, sender) => {
proof_task_tx.storage_proof(input, sender, tx_sender);
}
ProofTaskKind::BlindedAccountNode(path, sender) => {
proof_task_tx.blinded_account_node(path, sender, tx_sender);
}
ProofTaskKind::BlindedStorageNode(account, path, sender) => {
proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
// Storage trie operations should never reach here as they're routed to worker pool
ProofTaskKind::BlindedStorageNode(_, _, _) | ProofTaskKind::StorageProof(_, _) => {
unreachable!("Storage trie operations should be routed to worker pool")
}
});
@@ -173,42 +440,121 @@ where
}
/// Loops, managing the proof tasks, and sending new tasks to the executor.
///
/// # Task Routing
///
/// - **Storage Trie Operations** (`StorageProof` and `BlindedStorageNode`): Routed to
/// pre-spawned worker pool via unbounded channel.
/// - **Account Trie Operations** (`BlindedAccountNode`): Queued for on-demand execution via
/// `pending_tasks`.
///
/// # Shutdown
///
/// On termination, `storage_work_tx` is dropped, closing the channel and
/// signaling all workers to shut down gracefully.
pub fn run(mut self) -> ProviderResult<()> {
loop {
match self.proof_task_rx.recv() {
Ok(message) => match message {
ProofTaskMessage::QueueTask(task) => {
// Track metrics for blinded node requests
#[cfg(feature = "metrics")]
match &task {
Ok(message) => {
match message {
ProofTaskMessage::QueueTask(task) => match task {
ProofTaskKind::StorageProof(input, sender) => {
match self.storage_work_tx.send(StorageWorkerJob::StorageProof {
input,
result_sender: sender,
}) {
Ok(_) => {
tracing::trace!(
target: "trie::proof_task",
"Storage proof dispatched to worker pool"
);
}
Err(crossbeam_channel::SendError(job)) => {
tracing::error!(
target: "trie::proof_task",
storage_worker_count = self.storage_worker_count,
"Worker pool disconnected, cannot process storage proof"
);
// Send error back to caller
let _ = job.send_worker_unavailable_error();
}
}
}
ProofTaskKind::BlindedStorageNode(account, path, sender) => {
#[cfg(feature = "metrics")]
{
self.metrics.storage_nodes += 1;
}
match self.storage_work_tx.send(
StorageWorkerJob::BlindedStorageNode {
account,
path,
result_sender: sender,
},
) {
Ok(_) => {
tracing::trace!(
target: "trie::proof_task",
?account,
?path,
"Blinded storage node dispatched to worker pool"
);
}
Err(crossbeam_channel::SendError(job)) => {
tracing::warn!(
target: "trie::proof_task",
storage_worker_count = self.storage_worker_count,
?account,
?path,
"Worker pool disconnected, cannot process blinded storage node"
);
// Send error back to caller
let _ = job.send_worker_unavailable_error();
}
}
}
ProofTaskKind::BlindedAccountNode(_, _) => {
self.metrics.account_nodes += 1;
// Route account trie operations to pending_tasks
#[cfg(feature = "metrics")]
{
self.metrics.account_nodes += 1;
}
self.queue_proof_task(task);
}
ProofTaskKind::BlindedStorageNode(_, _, _) => {
self.metrics.storage_nodes += 1;
}
_ => {}
},
ProofTaskMessage::Transaction(tx) => {
// Return transaction to pending_tasks pool
self.proof_task_txs.push(tx);
}
ProofTaskMessage::Terminate => {
// Drop storage_work_tx to signal workers to shut down
drop(self.storage_work_tx);
tracing::debug!(
target: "trie::proof_task",
storage_worker_count = self.storage_worker_count,
"Shutting down proof task manager, signaling workers to terminate"
);
// Record metrics before terminating
#[cfg(feature = "metrics")]
self.metrics.record();
return Ok(())
}
// queue the task
self.queue_proof_task(task)
}
ProofTaskMessage::Transaction(tx) => {
// return the transaction to the pool
self.proof_task_txs.push(tx);
}
ProofTaskMessage::Terminate => {
// Record metrics before terminating
#[cfg(feature = "metrics")]
self.metrics.record();
return Ok(())
}
},
}
// All senders are disconnected, so we can terminate
// However this should never happen, as this struct stores a sender
Err(_) => return Ok(()),
};
// try spawning the next task
// Try spawning pending account trie tasks
self.try_spawn_next()?;
}
}
@@ -246,6 +592,7 @@ impl<Tx> ProofTaskTx<Tx>
where
Tx: DbTx,
{
#[inline]
fn create_factories(&self) -> ProofFactories<'_, Tx> {
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(&self.tx),
@@ -260,82 +607,70 @@ where
(trie_cursor_factory, hashed_cursor_factory)
}
/// Calculates a storage proof for the given hashed address, and desired prefix set.
fn storage_proof(
self,
/// Compute storage proof with pre-created factories.
///
/// Accepts cursor factories as parameters to allow reuse across multiple proofs.
/// Used by storage workers in the worker pool to avoid factory recreation
/// overhead on each proof computation.
#[inline]
fn compute_storage_proof(
&self,
input: StorageProofInput,
result_sender: Sender<StorageProofResult>,
tx_sender: Sender<ProofTaskMessage<Tx>>,
) {
trace!(
target: "trie::proof_task",
hashed_address=?input.hashed_address,
"Starting storage proof task calculation"
);
trie_cursor_factory: impl TrieCursorFactory,
hashed_cursor_factory: impl HashedCursorFactory,
) -> StorageProofResult {
// Consume the input so we can move large collections (e.g. target slots) without cloning.
let StorageProofInput {
hashed_address,
prefix_set,
target_slots,
with_branch_node_masks,
multi_added_removed_keys,
} = input;
let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
let multi_added_removed_keys = input
.multi_added_removed_keys
.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
let added_removed_keys = multi_added_removed_keys.get_storage(&input.hashed_address);
// Get or create added/removed keys context
let multi_added_removed_keys =
multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);
let span = tracing::trace_span!(
target: "trie::proof_task",
"Storage proof calculation",
hashed_address=?input.hashed_address,
// Add a unique id because we often have parallel storage proof calculations for the
// same hashed address, and we want to differentiate them during trace analysis.
span_id=self.id,
hashed_address = ?hashed_address,
worker_id = self.id,
);
let span_guard = span.enter();
let _span_guard = span.enter();
let target_slots_len = input.target_slots.len();
let proof_start = Instant::now();
let raw_proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
input.hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().copied()))
.with_branch_node_masks(input.with_branch_node_masks)
.with_added_removed_keys(added_removed_keys)
.storage_multiproof(input.target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
drop(span_guard);
// Compute raw storage multiproof
let raw_proof_result =
StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
.with_branch_node_masks(with_branch_node_masks)
.with_added_removed_keys(added_removed_keys)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
// Decode proof into DecodedStorageMultiProof
let decoded_result = raw_proof_result.and_then(|raw_proof| {
raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
ParallelStateRootError::Other(format!(
"Failed to decode storage proof for {}: {}",
input.hashed_address, e
hashed_address, e
))
})
});
trace!(
target: "trie::proof_task",
hashed_address=?input.hashed_address,
prefix_set = ?input.prefix_set.len(),
target_slots = ?target_slots_len,
proof_time = ?proof_start.elapsed(),
"Completed storage proof task calculation"
hashed_address = ?hashed_address,
proof_time_us = proof_start.elapsed().as_micros(),
worker_id = self.id,
"Completed storage proof calculation"
);
// send the result back
if let Err(error) = result_sender.send(decoded_result) {
debug!(
target: "trie::proof_task",
hashed_address = ?input.hashed_address,
?error,
task_time = ?proof_start.elapsed(),
"Storage proof receiver is dropped, discarding the result"
);
}
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
decoded_result
}
/// Retrieves blinded account node by path.
@@ -380,53 +715,6 @@ where
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
}
/// Retrieves blinded storage node of the given account by path.
fn blinded_storage_node(
self,
account: B256,
path: Nibbles,
result_sender: Sender<TrieNodeProviderResult>,
tx_sender: Sender<ProofTaskMessage<Tx>>,
) {
trace!(
target: "trie::proof_task",
?account,
?path,
"Starting blinded storage node retrieval"
);
let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
trie_cursor_factory,
hashed_cursor_factory,
self.task_ctx.prefix_sets.clone(),
);
let start = Instant::now();
let result = blinded_provider_factory.storage_node_provider(account).trie_node(&path);
trace!(
target: "trie::proof_task",
?account,
?path,
elapsed = ?start.elapsed(),
"Completed blinded storage node retrieval"
);
if let Err(error) = result_sender.send(result) {
tracing::error!(
target: "trie::proof_task",
?account,
?path,
?error,
"Failed to send blinded storage node result"
);
}
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
}
}
/// This represents an input for a storage proof.
@@ -607,3 +895,48 @@ impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
rx.recv().unwrap()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::map::B256Map;
    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
    use reth_trie_common::{
        prefix_set::TriePrefixSetsMut, updates::TrieUpdatesSorted, HashedAccountsSorted,
        HashedPostStateSorted,
    };
    use std::sync::Arc;
    use tokio::{runtime::Builder, task};

    /// Builds a [`ProofTaskCtx`] from default-constructed trie updates, hashed post state and
    /// prefix sets.
    fn test_ctx() -> ProofTaskCtx {
        let trie_updates = Arc::new(TrieUpdatesSorted::default());
        let hashed_state = Arc::new(HashedPostStateSorted::new(
            HashedAccountsSorted::default(),
            B256Map::default(),
        ));
        let prefix_sets = Arc::new(TriePrefixSetsMut::default());

        ProofTaskCtx::new(trie_updates, hashed_state, prefix_sets)
    }

    /// The storage worker count and `max_concurrency` are stored independently of each other.
    #[test]
    fn proof_task_manager_independent_pools() {
        let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();

        runtime.block_on(async {
            let view = ConsistentDbView::new(create_test_provider_factory(), None);

            let manager = ProofTaskManager::new(
                tokio::runtime::Handle::current(),
                view,
                test_ctx(),
                1,
                5,
            )
            .unwrap();

            // Five storage workers were requested, so five must be recorded.
            assert_eq!(manager.storage_worker_count, 5);
            // The on-demand limit keeps the value that was passed in.
            assert_eq!(manager.max_concurrency, 1);

            // Dropping the manager drops its job sender, which closes the worker queue.
            drop(manager);
            task::yield_now().await;
        });
    }
}

View File

@@ -864,6 +864,9 @@ Engine:
--engine.allow-unwind-canonical-header
Allow unwinding canonical header to ancestor during forkchoice updates. See `TreeConfig::unwind_canonical_header` for more details
--engine.storage-worker-count <STORAGE_WORKER_COUNT>
Configure the number of storage proof workers in the Tokio blocking pool. If not specified, defaults to 2x available parallelism, clamped between 2 and 64
ERA:
--era.enable
Enable import from ERA1 files