refactor(trie): reorder proof worker handle and task management

- Introduced `ProofWorkerHandle` for type-safe access to storage and account worker pools.
- Replaced `ProofResultContext` with direct senders for improved performance.
- Updated worker loops to handle proof tasks more efficiently, including better tracing and error handling.
- Added methods for dispatching storage and account multiproof computations, improving interleaved parallelism.
- Enhanced metrics tracking for storage and account processing.
This commit is contained in:
Yong Kang
2025-10-27 04:42:38 +00:00
parent e5db20471f
commit b879126a77

View File

@@ -77,15 +77,492 @@ use std::{
time::{Duration, Instant},
};
use tokio::runtime::Handle;
use tracing::{debug, debug_span, error, trace};
use tracing::{debug_span, error, trace};
#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskTrieMetrics;
/// Result of a storage multiproof computation, decoded from RLP.
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
/// Result of fetching a single trie node by path; the node may be absent (`None`).
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
/// Result of an account multiproof computation together with the trie statistics
/// collected while computing it.
type AccountMultiproofResult =
    Result<(DecodedMultiProof, ParallelTrieStats), ParallelStateRootError>;
/// A handle that provides type-safe access to proof worker pools.
///
/// The handle stores direct senders to both storage and account worker pools,
/// eliminating the need for a routing thread. All handles share reference-counted
/// channels, and workers shut down gracefully when all handles are dropped.
#[derive(Debug, Clone)]
pub struct ProofWorkerHandle {
    /// Direct sender to storage worker pool
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
    /// Direct sender to account worker pool
    account_work_tx: CrossbeamSender<AccountWorkerJob>,
    /// Counter tracking available storage workers. Workers decrement when starting work,
    /// increment when finishing. Used to determine whether to chunk multiproofs.
    storage_available_workers: Arc<AtomicUsize>,
    /// Counter tracking available account workers. Workers decrement when starting work,
    /// increment when finishing. Used to determine whether to chunk multiproofs.
    account_available_workers: Arc<AtomicUsize>,
}
impl ProofWorkerHandle {
    /// Spawns storage and account worker pools with dedicated database transactions.
    ///
    /// Returns a handle for submitting proof tasks to the worker pools.
    /// Workers run until the last handle is dropped.
    ///
    /// # Parameters
    /// - `executor`: Tokio runtime handle for spawning blocking tasks
    /// - `view`: Consistent database view for creating transactions
    /// - `task_ctx`: Shared context with trie updates and prefix sets
    /// - `storage_worker_count`: Number of storage workers to spawn
    /// - `account_worker_count`: Number of account workers to spawn
    pub fn new<Factory>(
        executor: Handle,
        view: ConsistentDbView<Factory>,
        task_ctx: ProofTaskCtx,
        storage_worker_count: usize,
        account_worker_count: usize,
    ) -> Self
    where
        Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
    {
        // Unbounded channels so dispatch never blocks; queue depth stays observable
        // through `pending_storage_tasks` / `pending_account_tasks`.
        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
        let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();
        // Initialize availability counters at zero. Each worker will increment when it
        // successfully initializes, ensuring only healthy workers are counted.
        let storage_available_workers = Arc::new(AtomicUsize::new(0));
        let account_available_workers = Arc::new(AtomicUsize::new(0));
        tracing::debug!(
            target: "trie::proof_task",
            storage_worker_count,
            account_worker_count,
            "Spawning proof worker pools"
        );
        // Enter the parent span so each per-worker span spawned below is its child.
        let storage_worker_parent =
            debug_span!(target: "trie::proof_task", "Storage worker tasks", ?storage_worker_count);
        let _guard = storage_worker_parent.enter();
        // Spawn storage workers
        for worker_id in 0..storage_worker_count {
            let parent_span = debug_span!(target: "trie::proof_task", "Storage worker", ?worker_id);
            let view_clone = view.clone();
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = storage_work_rx.clone();
            let storage_available_workers_clone = storage_available_workers.clone();
            // spawn_blocking: the worker loop performs blocking database reads, so it
            // must not run on the async executor threads.
            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                let _guard = parent_span.enter();
                storage_worker_loop(
                    view_clone,
                    task_ctx_clone,
                    work_rx_clone,
                    worker_id,
                    storage_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                )
            });
            tracing::debug!(
                target: "trie::proof_task",
                worker_id,
                "Storage worker spawned successfully"
            );
        }
        // Explicitly leave the storage-worker parent span before opening the account one.
        drop(_guard);
        let account_worker_parent =
            debug_span!(target: "trie::proof_task", "Account worker tasks", ?account_worker_count);
        let _guard = account_worker_parent.enter();
        // Spawn account workers
        for worker_id in 0..account_worker_count {
            let parent_span = debug_span!(target: "trie::proof_task", "Account worker", ?worker_id);
            let view_clone = view.clone();
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = account_work_rx.clone();
            // Account workers also get a storage sender so they can fan out storage
            // proofs for the accounts they walk.
            let storage_work_tx_clone = storage_work_tx.clone();
            let account_available_workers_clone = account_available_workers.clone();
            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                let _guard = parent_span.enter();
                account_worker_loop(
                    view_clone,
                    task_ctx_clone,
                    work_rx_clone,
                    storage_work_tx_clone,
                    worker_id,
                    account_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                )
            });
            tracing::debug!(
                target: "trie::proof_task",
                worker_id,
                "Account worker spawned successfully"
            );
        }
        drop(_guard);
        Self::new_handle(
            storage_work_tx,
            account_work_tx,
            storage_available_workers,
            account_available_workers,
        )
    }
    /// Creates a new [`ProofWorkerHandle`] with direct access to worker pools.
    ///
    /// This is an internal constructor used for creating handles.
    const fn new_handle(
        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
        account_work_tx: CrossbeamSender<AccountWorkerJob>,
        storage_available_workers: Arc<AtomicUsize>,
        account_available_workers: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            storage_work_tx,
            account_work_tx,
            storage_available_workers,
            account_available_workers,
        }
    }
    /// Returns true if there are available storage workers to process tasks.
    ///
    /// Relaxed load: the counter is advisory (used for chunking heuristics), so no
    /// ordering with other memory operations is required.
    pub fn has_available_storage_workers(&self) -> bool {
        self.storage_available_workers.load(Ordering::Relaxed) > 0
    }
    /// Returns true if there are available account workers to process tasks.
    pub fn has_available_account_workers(&self) -> bool {
        self.account_available_workers.load(Ordering::Relaxed) > 0
    }
    /// Returns the number of pending storage tasks in the queue.
    pub fn pending_storage_tasks(&self) -> usize {
        self.storage_work_tx.len()
    }
    /// Returns the number of pending account tasks in the queue.
    pub fn pending_account_tasks(&self) -> usize {
        self.account_work_tx.len()
    }
    /// Dispatch a storage proof computation to storage worker pool
    ///
    /// The result will be sent via the `proof_result_sender` channel.
    pub fn dispatch_storage_proof(
        &self,
        input: StorageProofInput,
        proof_result_sender: ProofResultContext,
    ) -> Result<(), ProviderError> {
        self.storage_work_tx
            .send(StorageWorkerJob::StorageProof { input, proof_result_sender })
            .map_err(|err| {
                // Send only fails when every worker has shut down. Recover the job from
                // the send error and deliver the failure on the result channel so the
                // consumer is not left waiting for a proof that will never arrive.
                let error =
                    ProviderError::other(std::io::Error::other("storage workers unavailable"));
                if let StorageWorkerJob::StorageProof { proof_result_sender, .. } = err.0 {
                    let ProofResultContext {
                        sender: result_tx,
                        sequence_number: seq,
                        state,
                        start_time: start,
                    } = proof_result_sender;
                    // Ignore a secondary failure: if the result receiver is also gone
                    // there is nobody left to notify.
                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }
                error
            })
    }
    /// Dispatch an account multiproof computation
    ///
    /// The result will be sent via the `result_sender` channel included in the input.
    pub fn dispatch_account_multiproof(
        &self,
        input: AccountMultiproofInput,
    ) -> Result<(), ProviderError> {
        self.account_work_tx
            .send(AccountWorkerJob::AccountMultiproof { input: Box::new(input) })
            .map_err(|err| {
                // Same recovery strategy as `dispatch_storage_proof`: surface the
                // shutdown as a failed result message to the waiting consumer.
                let error =
                    ProviderError::other(std::io::Error::other("account workers unavailable"));
                if let AccountWorkerJob::AccountMultiproof { input } = err.0 {
                    let AccountMultiproofInput {
                        proof_result_sender:
                            ProofResultContext {
                                sender: result_tx,
                                sequence_number: seq,
                                state,
                                start_time: start,
                            },
                        ..
                    } = *input;
                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }
                error
            })
    }
    /// Dispatch blinded storage node request to storage worker pool
    ///
    /// Returns the receiver on which the node result will be delivered.
    pub(crate) fn dispatch_blinded_storage_node(
        &self,
        account: B256,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.storage_work_tx
            .send(StorageWorkerJob::BlindedStorageNode { account, path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("storage workers unavailable"))
            })?;
        Ok(rx)
    }
    /// Dispatch blinded account node request to account worker pool
    ///
    /// Returns the receiver on which the node result will be delivered.
    pub(crate) fn dispatch_blinded_account_node(
        &self,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.account_work_tx
            .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("account workers unavailable"))
            })?;
        Ok(rx)
    }
}
/// Data used for initializing cursor factories that is shared across all storage proof instances.
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
    /// computation.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    prefix_sets: Arc<TriePrefixSetsMut>,
}
impl ProofTaskCtx {
    /// Creates a new [`ProofTaskCtx`] with the given sorted nodes and state.
    ///
    /// All fields are `Arc`-shared so cloning the context per worker is cheap.
    pub const fn new(
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
    ) -> Self {
        Self { nodes_sorted, state_sorted, prefix_sets }
    }
}
/// Type alias for the factory tuple returned by [`ProofTaskTx::create_factories`].
type ProofFactories<'a, Tx> = (
    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
);
/// This contains all information shared between all storage proof instances.
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The tx that is reused for proof calculations.
    // NOTE(review): presumably a long-lived read-only database transaction — confirm
    // against the worker-loop construction sites.
    tx: Tx,
    /// Trie updates, prefix sets, and state updates
    task_ctx: ProofTaskCtx,
    /// Identifier for the worker within the worker pool, used only for tracing.
    id: usize,
}
impl<Tx> ProofTaskTx<Tx> {
    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`]. The id is
    /// used only for tracing.
    const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
        Self { tx, task_ctx, id }
    }
}
impl<Tx> ProofTaskTx<Tx>
where
    Tx: DbTx,
{
    /// Builds the trie and hashed cursor factories, layering the in-memory overlays from
    /// `task_ctx` on top of the database transaction.
    #[inline]
    fn create_factories(&self) -> ProofFactories<'_, Tx> {
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(&self.tx),
            self.task_ctx.nodes_sorted.as_ref(),
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(&self.tx),
            self.task_ctx.state_sorted.as_ref(),
        );
        (trie_cursor_factory, hashed_cursor_factory)
    }
    /// Compute storage proof with pre-created factories.
    ///
    /// Accepts cursor factories as parameters to allow reuse across multiple proofs.
    /// Used by storage workers in the worker pool to avoid factory recreation
    /// overhead on each proof computation.
    #[inline]
    fn compute_storage_proof(
        &self,
        input: StorageProofInput,
        trie_cursor_factory: impl TrieCursorFactory,
        hashed_cursor_factory: impl HashedCursorFactory,
    ) -> StorageProofResult {
        // Consume the input so we can move large collections (e.g. target slots) without cloning.
        let StorageProofInput {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        } = input;
        // Get or create added/removed keys context
        let multi_added_removed_keys =
            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);
        let span = tracing::debug_span!(
            target: "trie::proof_task",
            "Storage proof calculation",
            hashed_address = ?hashed_address,
            worker_id = self.id,
        );
        let _span_guard = span.enter();
        let proof_start = Instant::now();
        // Compute raw storage multiproof
        let raw_proof_result =
            StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
                .with_branch_node_masks(with_branch_node_masks)
                .with_added_removed_keys(added_removed_keys)
                .storage_multiproof(target_slots)
                .map_err(|e| ParallelStateRootError::Other(e.to_string()));
        // Decode proof into DecodedStorageMultiProof; `and_then` skips decoding when the
        // proof computation itself already failed.
        let decoded_result = raw_proof_result.and_then(|raw_proof| {
            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                ParallelStateRootError::Other(format!(
                    "Failed to decode storage proof for {}: {}",
                    hashed_address, e
                ))
            })
        });
        trace!(
            target: "trie::proof_task",
            hashed_address = ?hashed_address,
            proof_time_us = proof_start.elapsed().as_micros(),
            worker_id = self.id,
            "Completed storage proof calculation"
        );
        decoded_result
    }
}
impl TrieNodeProviderFactory for ProofWorkerHandle {
    type AccountNodeProvider = ProofTaskTrieNodeProvider;
    type StorageNodeProvider = ProofTaskTrieNodeProvider;
    /// Returns a provider for blinded account trie nodes, backed by a clone of this handle.
    fn account_node_provider(&self) -> Self::AccountNodeProvider {
        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
    }
    /// Returns a provider for blinded storage trie nodes of `account`, backed by a clone
    /// of this handle.
    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
    }
}
/// Trie node provider for retrieving trie nodes by path.
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider {
    /// Blinded account trie node provider.
    AccountNode {
        /// Handle to the proof worker pools.
        handle: ProofWorkerHandle,
    },
    /// Blinded storage trie node provider.
    StorageNode {
        /// Target account.
        account: B256,
        /// Handle to the proof worker pools.
        handle: ProofWorkerHandle,
    },
}
impl TrieNodeProvider for ProofTaskTrieNodeProvider {
fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
match self {
Self::AccountNode { handle } => {
let rx = handle
.dispatch_blinded_account_node(*path)
.map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
}
Self::StorageNode { handle, account } => {
let rx = handle
.dispatch_blinded_storage_node(*account, *path)
.map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
}
}
}
}
/// Result of a proof calculation, which can be either an account multiproof or a storage proof.
///
/// Carried inside `ProofResultMessage`, allowing both worker pools to report through a
/// single result channel type.
#[derive(Debug)]
pub enum ProofResult {
    /// Account multiproof with statistics
    AccountMultiproof(DecodedMultiProof, ParallelTrieStats),
    /// Storage proof for a specific account
    StorageProof {
        /// The hashed address this storage proof belongs to
        hashed_address: B256,
        /// The storage multiproof
        proof: DecodedStorageMultiProof,
    },
}
/// Channel used by worker threads to deliver `ProofResultMessage` items back to
/// `MultiProofTask`.
@@ -101,8 +578,8 @@ pub type ProofResultSender = CrossbeamSender<ProofResultMessage>;
pub struct ProofResultMessage {
/// Sequence number for ordering proofs
pub sequence_number: u64,
/// The proof calculation result
pub result: AccountMultiproofResult,
/// The proof calculation result (either account multiproof or storage proof)
pub result: Result<ProofResult, ParallelStateRootError>,
/// Time taken for the entire proof calculation (from dispatch to completion)
pub elapsed: Duration,
/// Original state update that triggered this proof
@@ -248,18 +725,10 @@ fn storage_worker_loop<Factory>(
let proof_elapsed = proof_start.elapsed();
storage_proofs_processed += 1;
// Convert storage proof to account multiproof format
let result_msg = match result {
Ok(storage_proof) => {
let multiproof = reth_trie::DecodedMultiProof::from_storage_proof(
hashed_address,
storage_proof,
);
let stats = crate::stats::ParallelTrieTracker::default().finish();
Ok((multiproof, stats))
}
Err(e) => Err(e),
};
let result_msg = result.map(|storage_proof| ProofResult::StorageProof {
hashed_address,
proof: storage_proof,
});
if sender
.send(ProofResultMessage {
@@ -496,7 +965,7 @@ fn account_worker_loop<Factory>(
let proof_elapsed = proof_start.elapsed();
let total_elapsed = start.elapsed();
let stats = tracker.finish();
let result = result.map(|proof| (proof, stats));
let result = result.map(|proof| ProofResult::AccountMultiproof(proof, stats));
account_proofs_processed += 1;
// Send result to MultiProofTask
@@ -716,9 +1185,9 @@ where
// Consume remaining storage proof receivers for accounts not encountered during trie walk.
for (hashed_address, receiver) in storage_proof_receivers {
if let Ok(proof_msg) = receiver.recv() {
// Extract storage proof from the multiproof wrapper
if let Ok((mut multiproof, _stats)) = proof_msg.result &&
let Some(proof) = multiproof.storages.remove(&hashed_address)
// Extract storage proof from the result
if let Ok(ProofResult::StorageProof { hashed_address: addr, proof }) = proof_msg.result &&
addr == hashed_address
{
collected_decoded_storages.insert(hashed_address, proof);
}
@@ -807,119 +1276,6 @@ fn dispatch_storage_proofs(
Ok(storage_proof_receivers)
}
/// Type alias for the factory tuple returned by `create_factories`
type ProofFactories<'a, Tx> = (
    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
);
/// This contains all information shared between all storage proof instances.
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The tx that is reused for proof calculations.
    // NOTE(review): presumably a long-lived read-only database transaction — confirm
    // against the worker-loop construction sites.
    tx: Tx,
    /// Trie updates, prefix sets, and state updates
    task_ctx: ProofTaskCtx,
    /// Identifier for the worker within the worker pool, used only for tracing.
    id: usize,
}
impl<Tx> ProofTaskTx<Tx> {
    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`]. The id is
    /// used only for tracing.
    const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
        Self { tx, task_ctx, id }
    }
}
impl<Tx> ProofTaskTx<Tx>
where
    Tx: DbTx,
{
    /// Builds the trie and hashed cursor factories, layering the in-memory overlays from
    /// `task_ctx` on top of the database transaction.
    #[inline]
    fn create_factories(&self) -> ProofFactories<'_, Tx> {
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(&self.tx),
            self.task_ctx.nodes_sorted.as_ref(),
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(&self.tx),
            self.task_ctx.state_sorted.as_ref(),
        );
        (trie_cursor_factory, hashed_cursor_factory)
    }
    /// Compute storage proof with pre-created factories.
    ///
    /// Accepts cursor factories as parameters to allow reuse across multiple proofs.
    /// Used by storage workers in the worker pool to avoid factory recreation
    /// overhead on each proof computation.
    #[inline]
    fn compute_storage_proof(
        &self,
        input: StorageProofInput,
        trie_cursor_factory: impl TrieCursorFactory,
        hashed_cursor_factory: impl HashedCursorFactory,
    ) -> StorageProofResult {
        // Consume the input so we can move large collections (e.g. target slots) without cloning.
        let StorageProofInput {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        } = input;
        // Get or create added/removed keys context
        let multi_added_removed_keys =
            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);
        let span = debug_span!(
            target: "trie::proof_task",
            "Storage proof calculation",
            hashed_address = ?hashed_address,
            worker_id = self.id,
        );
        let _span_guard = span.enter();
        let proof_start = Instant::now();
        // Compute raw storage multiproof
        let raw_proof_result =
            StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
                .with_branch_node_masks(with_branch_node_masks)
                .with_added_removed_keys(added_removed_keys)
                .storage_multiproof(target_slots)
                .map_err(|e| ParallelStateRootError::Other(e.to_string()));
        // Decode proof into DecodedStorageMultiProof; `and_then` skips decoding when the
        // proof computation itself already failed.
        let decoded_result = raw_proof_result.and_then(|raw_proof| {
            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                ParallelStateRootError::Other(format!(
                    "Failed to decode storage proof for {}: {}",
                    hashed_address, e
                ))
            })
        });
        trace!(
            target: "trie::proof_task",
            hashed_address = ?hashed_address,
            proof_time_us = proof_start.elapsed().as_micros(),
            worker_id = self.id,
            "Completed storage proof calculation"
        );
        decoded_result
    }
}
/// Input parameters for storage proof computation.
#[derive(Debug)]
pub struct StorageProofInput {
@@ -1005,327 +1361,6 @@ enum AccountWorkerJob {
},
}
/// Data used for initializing cursor factories that is shared across all storage proof instances.
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
    /// computation.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    prefix_sets: Arc<TriePrefixSetsMut>,
}
impl ProofTaskCtx {
    /// Creates a new [`ProofTaskCtx`] with the given sorted nodes and state.
    ///
    /// All fields are `Arc`-shared so cloning the context per worker is cheap.
    pub const fn new(
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
    ) -> Self {
        Self { nodes_sorted, state_sorted, prefix_sets }
    }
}
/// A handle that provides type-safe access to proof worker pools.
///
/// The handle stores direct senders to both storage and account worker pools,
/// eliminating the need for a routing thread. All handles share reference-counted
/// channels, and workers shut down gracefully when all handles are dropped.
#[derive(Debug, Clone)]
pub struct ProofWorkerHandle {
    /// Direct sender to storage worker pool
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
    /// Direct sender to account worker pool
    account_work_tx: CrossbeamSender<AccountWorkerJob>,
    /// Counter tracking available storage workers. Workers decrement when starting work,
    /// increment when finishing. Used to determine whether to chunk multiproofs.
    storage_available_workers: Arc<AtomicUsize>,
    /// Counter tracking available account workers. Workers decrement when starting work,
    /// increment when finishing. Used to determine whether to chunk multiproofs.
    account_available_workers: Arc<AtomicUsize>,
}
impl ProofWorkerHandle {
    /// Spawns storage and account worker pools with dedicated database transactions.
    ///
    /// Returns a handle for submitting proof tasks to the worker pools.
    /// Workers run until the last handle is dropped.
    ///
    /// # Parameters
    /// - `executor`: Tokio runtime handle for spawning blocking tasks
    /// - `view`: Consistent database view for creating transactions
    /// - `task_ctx`: Shared context with trie updates and prefix sets
    /// - `storage_worker_count`: Number of storage workers to spawn
    /// - `account_worker_count`: Number of account workers to spawn
    pub fn new<Factory>(
        executor: Handle,
        view: ConsistentDbView<Factory>,
        task_ctx: ProofTaskCtx,
        storage_worker_count: usize,
        account_worker_count: usize,
    ) -> Self
    where
        Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
    {
        // Unbounded channels so dispatch never blocks; queue depth stays observable
        // through `pending_storage_tasks` / `pending_account_tasks`.
        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
        let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();
        // Initialize availability counters at zero. Each worker will increment when it
        // successfully initializes, ensuring only healthy workers are counted.
        let storage_available_workers = Arc::new(AtomicUsize::new(0));
        let account_available_workers = Arc::new(AtomicUsize::new(0));
        debug!(
            target: "trie::proof_task",
            storage_worker_count,
            account_worker_count,
            "Spawning proof worker pools"
        );
        let parent_span =
            debug_span!(target: "trie::proof_task", "storage proof workers", ?storage_worker_count)
                .entered();
        // Spawn storage workers
        for worker_id in 0..storage_worker_count {
            let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
            let view_clone = view.clone();
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = storage_work_rx.clone();
            let storage_available_workers_clone = storage_available_workers.clone();
            // spawn_blocking: the worker loop performs blocking database reads, so it
            // must not run on the async executor threads.
            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                let _guard = span.enter();
                storage_worker_loop(
                    view_clone,
                    task_ctx_clone,
                    work_rx_clone,
                    worker_id,
                    storage_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                )
            });
        }
        drop(parent_span);
        // Fix: record `account_worker_count` on the account-worker span. The previous
        // code recorded `storage_worker_count` here (copy-paste slip), mislabeling the
        // trace output for the account pool.
        let parent_span =
            debug_span!(target: "trie::proof_task", "account proof workers", ?account_worker_count)
                .entered();
        // Spawn account workers
        for worker_id in 0..account_worker_count {
            let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
            let view_clone = view.clone();
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = account_work_rx.clone();
            // Account workers also get a storage sender so they can fan out storage
            // proofs for the accounts they walk.
            let storage_work_tx_clone = storage_work_tx.clone();
            let account_available_workers_clone = account_available_workers.clone();
            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                let _guard = span.enter();
                account_worker_loop(
                    view_clone,
                    task_ctx_clone,
                    work_rx_clone,
                    storage_work_tx_clone,
                    worker_id,
                    account_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                )
            });
        }
        drop(parent_span);
        Self {
            storage_work_tx,
            account_work_tx,
            storage_available_workers,
            account_available_workers,
        }
    }
    /// Returns true if there are available storage workers to process tasks.
    ///
    /// Relaxed load: the counter is advisory (used for chunking heuristics), so no
    /// ordering with other memory operations is required.
    pub fn has_available_storage_workers(&self) -> bool {
        self.storage_available_workers.load(Ordering::Relaxed) > 0
    }
    /// Returns true if there are available account workers to process tasks.
    pub fn has_available_account_workers(&self) -> bool {
        self.account_available_workers.load(Ordering::Relaxed) > 0
    }
    /// Returns the number of pending storage tasks in the queue.
    pub fn pending_storage_tasks(&self) -> usize {
        self.storage_work_tx.len()
    }
    /// Returns the number of pending account tasks in the queue.
    pub fn pending_account_tasks(&self) -> usize {
        self.account_work_tx.len()
    }
    /// Dispatch a storage proof computation to storage worker pool
    ///
    /// The result will be sent via the `proof_result_sender` channel.
    pub fn dispatch_storage_proof(
        &self,
        input: StorageProofInput,
        proof_result_sender: ProofResultContext,
    ) -> Result<(), ProviderError> {
        self.storage_work_tx
            .send(StorageWorkerJob::StorageProof { input, proof_result_sender })
            .map_err(|err| {
                // Send only fails when every worker has shut down. Recover the job from
                // the send error and deliver the failure on the result channel so the
                // consumer is not left waiting for a proof that will never arrive.
                let error =
                    ProviderError::other(std::io::Error::other("storage workers unavailable"));
                if let StorageWorkerJob::StorageProof { proof_result_sender, .. } = err.0 {
                    let ProofResultContext {
                        sender: result_tx,
                        sequence_number: seq,
                        state,
                        start_time: start,
                    } = proof_result_sender;
                    // Ignore a secondary failure: if the result receiver is also gone
                    // there is nobody left to notify.
                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }
                error
            })
    }
    /// Dispatch an account multiproof computation
    ///
    /// The result will be sent via the `result_sender` channel included in the input.
    pub fn dispatch_account_multiproof(
        &self,
        input: AccountMultiproofInput,
    ) -> Result<(), ProviderError> {
        self.account_work_tx
            .send(AccountWorkerJob::AccountMultiproof { input: Box::new(input) })
            .map_err(|err| {
                // Same recovery strategy as `dispatch_storage_proof`: surface the
                // shutdown as a failed result message to the waiting consumer.
                let error =
                    ProviderError::other(std::io::Error::other("account workers unavailable"));
                if let AccountWorkerJob::AccountMultiproof { input } = err.0 {
                    let AccountMultiproofInput {
                        proof_result_sender:
                            ProofResultContext {
                                sender: result_tx,
                                sequence_number: seq,
                                state,
                                start_time: start,
                            },
                        ..
                    } = *input;
                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }
                error
            })
    }
    /// Dispatch blinded storage node request to storage worker pool
    ///
    /// Returns the receiver on which the node result will be delivered.
    pub(crate) fn dispatch_blinded_storage_node(
        &self,
        account: B256,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.storage_work_tx
            .send(StorageWorkerJob::BlindedStorageNode { account, path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("storage workers unavailable"))
            })?;
        Ok(rx)
    }
    /// Dispatch blinded account node request to account worker pool
    ///
    /// Returns the receiver on which the node result will be delivered.
    pub(crate) fn dispatch_blinded_account_node(
        &self,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.account_work_tx
            .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("account workers unavailable"))
            })?;
        Ok(rx)
    }
}
impl TrieNodeProviderFactory for ProofWorkerHandle {
    type AccountNodeProvider = ProofTaskTrieNodeProvider;
    type StorageNodeProvider = ProofTaskTrieNodeProvider;
    /// Returns a provider for blinded account trie nodes, backed by a clone of this handle.
    fn account_node_provider(&self) -> Self::AccountNodeProvider {
        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
    }
    /// Returns a provider for blinded storage trie nodes of `account`, backed by a clone
    /// of this handle.
    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
    }
}
/// Trie node provider for retrieving trie nodes by path.
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider {
    /// Blinded account trie node provider.
    AccountNode {
        /// Handle to the proof worker pools.
        handle: ProofWorkerHandle,
    },
    /// Blinded storage trie node provider.
    StorageNode {
        /// Target account.
        account: B256,
        /// Handle to the proof worker pools.
        handle: ProofWorkerHandle,
    },
}
impl TrieNodeProvider for ProofTaskTrieNodeProvider {
    /// Requests the blinded node at `path` from the matching worker pool and blocks on
    /// the response channel until the node (or an error) arrives.
    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
        match self {
            Self::AccountNode { handle } => {
                let rx = handle
                    .dispatch_blinded_account_node(*path)
                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
                // A recv error means the worker dropped the sender before responding.
                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
            }
            Self::StorageNode { handle, account } => {
                let rx = handle
                    .dispatch_blinded_storage_node(*account, *path)
                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;