refactor: update proof task management to use spawn_proof_workers

- Replaced the ProofTaskManager with a new spawn_proof_workers function that returns a handle directly, removing the need for a separate routing task (see the usage sketch below).
- Updated call sites to use the new function, simplifying worker spawning.
- Switched metrics tracking for storage and account proof requests to shared atomic counters, making updates thread-safe.
- Improved error handling and code structure across the proof task implementations.
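
For reviewers, a minimal sketch of the new call flow, assembled from the diff below rather than quoted from it. The function name `start_proof_workers` and the worker counts (4 storage, 2 account) are illustrative; the generic bounds mirror those declared on `spawn_proof_workers`.

```rust
use reth_provider::{providers::ConsistentDbView, BlockReader, DatabaseProviderFactory};
use reth_trie_parallel::proof_task::{spawn_proof_workers, ProofTaskCtx};
use tokio::runtime::Handle;

// Sketch only: `start_proof_workers` and the worker counts are illustrative.
fn start_proof_workers<Factory>(
    executor: Handle,
    view: ConsistentDbView<Factory>,
    task_ctx: ProofTaskCtx,
) -> Result<(), Box<dyn std::error::Error>>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
    // One call replaces ProofTaskManager::new + .handle() + spawn_blocking(run).
    let handle = spawn_proof_workers(executor, view, task_ctx, 4, 2)?;

    // Clones share the worker channels and bump the shared refcount.
    let for_multiproof = handle.clone();
    let for_sparse_trie = handle;

    // ... hand the clones to the multiproof and sparse trie tasks ...

    // When the last clone is dropped, the channels disconnect and the workers
    // exit on their own; there is no run() loop to join and no terminate() call.
    drop(for_multiproof);
    drop(for_sparse_trie);
    Ok(())
}
```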
Yong Kang
2025-10-10 09:49:05 +00:00
parent c610bc9ea5
commit 1ca1637f2a
5 changed files with 420 additions and 239 deletions

View File

@@ -32,7 +32,7 @@ use reth_provider::{
 use reth_revm::{db::BundleState, state::EvmState};
 use reth_trie::TrieInput;
 use reth_trie_parallel::{
-    proof_task::{ProofTaskCtx, ProofTaskManager},
+    proof_task::{spawn_proof_workers, ProofTaskCtx},
     root::ParallelStateRootError,
 };
 use reth_trie_sparse::{
use reth_trie_sparse::{
@@ -204,14 +204,14 @@ where
         let storage_worker_count = config.storage_worker_count();
         let account_worker_count = config.account_worker_count();
         let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
-        let proof_task = match ProofTaskManager::new(
+        let proof_handle = match spawn_proof_workers(
             self.executor.handle().clone(),
             consistent_view,
             task_ctx,
             storage_worker_count,
             account_worker_count,
         ) {
-            Ok(task) => task,
+            Ok(handle) => handle,
             Err(error) => {
                 return Err((error, transactions, env, provider_builder));
             }
@@ -223,7 +223,7 @@ where
         let multi_proof_task = MultiProofTask::new(
             state_root_config,
             self.executor.clone(),
-            proof_task.handle(),
+            proof_handle.clone(),
             to_sparse_trie,
             max_multi_proof_task_concurrency,
             config.multiproof_chunking_enabled().then_some(config.multiproof_chunk_size()),
@@ -252,19 +252,7 @@ where
         let (state_root_tx, state_root_rx) = channel();

         // Spawn the sparse trie task using any stored trie and parallel trie configuration.
-        self.spawn_sparse_trie_task(sparse_trie_rx, proof_task.handle(), state_root_tx);
-
-        // spawn the proof task
-        self.executor.spawn_blocking(move || {
-            if let Err(err) = proof_task.run() {
-                // At least log if there is an error at any point
-                tracing::error!(
-                    target: "engine::root",
-                    ?err,
-                    "Storage proof task returned an error"
-                );
-            }
-        });
+        self.spawn_sparse_trie_task(sparse_trie_rx, proof_handle, state_root_tx);

         Ok(PayloadHandle {
             to_multi_proof,

View File

@@ -20,7 +20,7 @@ use reth_trie::{
 };
 use reth_trie_parallel::{
     proof::ParallelProof,
-    proof_task::{AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle},
+    proof_task::{AccountMultiproofInput, ProofTaskManagerHandle},
     root::ParallelStateRootError,
 };
 use std::{
@@ -556,15 +556,10 @@ impl MultiproofManager {
                 missed_leaves_storage_roots,
             };
-            let (sender, receiver) = channel();
             let proof_result: Result<DecodedMultiProof, ParallelStateRootError> = (|| {
-                account_proof_task_handle
-                    .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
-                    .map_err(|_| {
-                        ParallelStateRootError::Other(
-                            "Failed to queue account multiproof to worker pool".into(),
-                        )
-                    })?;
+                let receiver = account_proof_task_handle
+                    .queue_account_multiproof(input)
+                    .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
                 receiver
                     .recv()
@@ -1223,7 +1218,7 @@ mod tests {
         DatabaseProviderFactory,
     };
     use reth_trie::{MultiProof, TrieInput};
-    use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use reth_trie_parallel::proof_task::{spawn_proof_workers, ProofTaskCtx};
     use revm_primitives::{B256, U256};

    fn create_test_state_root_task<F>(factory: F) -> MultiProofTask
@@ -1238,12 +1233,12 @@
             config.prefix_sets.clone(),
         );
         let consistent_view = ConsistentDbView::new(factory, None);
-        let proof_task =
-            ProofTaskManager::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
-                .expect("Failed to create ProofTaskManager");
+        let proof_handle =
+            spawn_proof_workers(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
+                .expect("Failed to spawn proof workers");
         let channel = channel();

-        MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)
+        MultiProofTask::new(config, executor, proof_handle, channel.0, 1, None)
     }

     #[test]
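
The call site above now receives the result `Receiver` from the handle and only has to wait on it. A standalone sketch of that queue-then-recv shape, with illustrative types and error strings (not reth's actual types):

```rust
use std::sync::mpsc::Receiver;

// Sketch: the worker sends one Result over the channel; a dropped sender
// (worker gone before replying) is mapped to a descriptive error.
fn wait_for_result(rx: Receiver<Result<u64, String>>) -> Result<u64, String> {
    rx.recv().map_err(|_| "worker channel closed".to_string())?
}
```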

View File

@@ -1,8 +1,6 @@
 use crate::{
     metrics::ParallelTrieMetrics,
-    proof_task::{
-        AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle, StorageProofInput,
-    },
+    proof_task::{AccountMultiproofInput, ProofTaskManagerHandle, StorageProofInput},
     root::ParallelStateRootError,
     StorageRootTargets,
 };
@@ -16,10 +14,7 @@ use reth_trie::{
     DecodedMultiProof, DecodedStorageMultiProof, HashedPostStateSorted, MultiProofTargets, Nibbles,
 };
 use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
-use std::sync::{
-    mpsc::{channel, Receiver},
-    Arc,
-};
+use std::sync::{mpsc::Receiver, Arc};
 use tracing::trace;

 /// Parallel proof calculator.
@@ -93,7 +88,10 @@ impl ParallelProof {
         hashed_address: B256,
         prefix_set: PrefixSet,
         target_slots: B256Set,
-    ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
+    ) -> Result<
+        Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>>,
+        ParallelStateRootError,
+    > {
         let input = StorageProofInput::new(
             hashed_address,
             prefix_set,
@@ -102,9 +100,9 @@
             self.multi_added_removed_keys.clone(),
         );
-        let (sender, receiver) = std::sync::mpsc::channel();
-        let _ = self.proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
-        receiver
+        self.proof_task_handle
+            .queue_storage_proof(input)
+            .map_err(|e| ParallelStateRootError::Other(e.to_string()))
     }

     /// Generate a storage multiproof according to the specified targets and hashed address.
@@ -124,7 +122,7 @@
             "Starting storage proof generation"
         );

-        let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
+        let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots)?;

         let proof_result = receiver.recv().map_err(|_| {
             ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                 format!("channel closed for {hashed_address}"),
@@ -193,15 +191,10 @@
             missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
         };

-        let (sender, receiver) = channel();
-        self.proof_task_handle
-            .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
-            .map_err(|_| {
-                ParallelStateRootError::Other(
-                    "Failed to queue account multiproof: account worker pool unavailable"
-                        .to_string(),
-                )
-            })?;
+        let receiver = self
+            .proof_task_handle
+            .queue_account_multiproof(input)
+            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

         // Wait for account multiproof result from worker
         let (multiproof, stats) = receiver.recv().map_err(|_| {
@@ -231,7 +224,7 @@
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use crate::proof_task::{spawn_proof_workers, ProofTaskCtx};
     use alloy_primitives::{
         keccak256,
         map::{B256Set, DefaultHashBuilder, HashMap},
@@ -313,13 +306,8 @@
         let task_ctx =
             ProofTaskCtx::new(Default::default(), Default::default(), Default::default());

-        let proof_task =
-            ProofTaskManager::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
-        let proof_task_handle = proof_task.handle();
-
-        // keep the join handle around to make sure it does not return any errors
-        // after we compute the state root
-        let join_handle = rt.spawn_blocking(move || proof_task.run());
+        let proof_task_handle =
+            spawn_proof_workers(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();

         let parallel_result = ParallelProof::new(
             Default::default(),
@@ -354,9 +342,7 @@
         // then compare the entire thing for any mask differences
         assert_eq!(parallel_result, sequential_result_decoded);

-        // drop the handle to terminate the task and then block on the proof task handle to make
-        // sure it does not return any errors
+        // Workers shut down automatically when handle is dropped
         drop(proof_task_handle);
-        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
     }
 }
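
The signature change above, returning `Result<Receiver<...>, ...>` from `queue_storage_proof`, surfaces pool unavailability at queue time instead of as a later `recv()` failure. A generic, self-contained illustration of that shape; `WorkerPool` and `queue_job` are hypothetical names, not part of the commit:

```rust
use std::sync::mpsc::{channel, Receiver, Sender};

struct WorkerPool {
    // Each job carries its own reply sender, mirroring result_sender in the diff.
    job_tx: Sender<(u64, Sender<u64>)>,
}

impl WorkerPool {
    // Queueing hands back the result receiver; send() fails as soon as all
    // workers (the receiving side) are gone, so the caller learns immediately.
    fn queue_job(&self, job: u64) -> Result<Receiver<u64>, String> {
        let (result_tx, result_rx) = channel();
        self.job_tx
            .send((job, result_tx))
            .map_err(|_| "worker pool unavailable".to_string())?;
        Ok(result_rx)
    }
}
```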

View File

@@ -21,7 +21,7 @@ use alloy_rlp::{BufMut, Encodable};
 use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
 use dashmap::DashMap;
 use reth_db_api::transaction::DbTx;
-use reth_execution_errors::SparseTrieError;
+use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
 use reth_provider::{
     providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
     ProviderResult,
@@ -88,24 +88,148 @@ enum StorageWorkerJob {
     },
 }

-impl StorageWorkerJob {
-    /// Sends an error back to the caller when worker pool is unavailable.
-    ///
-    /// Returns `Ok(())` if the error was sent successfully, or `Err(())` if the receiver was
-    /// dropped.
-    fn send_worker_unavailable_error(&self) -> Result<(), ()> {
-        let error =
-            ParallelStateRootError::Other("Storage proof worker pool unavailable".to_string());
+/// Spawns storage and account worker pools with dedicated database transactions.
+///
+/// Returns a handle for submitting proof tasks to the worker pools.
+/// Workers run until the last handle is dropped.
+///
+/// # Parameters
+/// - `executor`: Tokio runtime handle for spawning blocking tasks
+/// - `view`: Consistent database view for creating transactions
+/// - `task_ctx`: Shared context with trie updates and prefix sets
+/// - `storage_worker_count`: Number of storage workers to spawn
+/// - `account_worker_count`: Number of account workers to spawn
+pub fn spawn_proof_workers<Factory>(
+    executor: Handle,
+    view: ConsistentDbView<Factory>,
+    task_ctx: ProofTaskCtx,
+    storage_worker_count: usize,
+    account_worker_count: usize,
+) -> ProviderResult<ProofTaskManagerHandle>
+where
+    Factory: DatabaseProviderFactory<Provider: BlockReader>,
+{
+    let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
+    let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();

-        match self {
-            Self::StorageProof { result_sender, .. } => {
-                result_sender.send(Err(error)).map_err(|_| ())
-            }
-            Self::BlindedStorageNode { result_sender, .. } => result_sender
-                .send(Err(SparseTrieError::from(SparseTrieErrorKind::Other(Box::new(error)))))
-                .map_err(|_| ()),
-        }
-    }
-}
+    tracing::info!(
+        target: "trie::proof_task",
+        storage_worker_count,
+        account_worker_count,
+        "Spawning proof worker pools"
+    );
+
+    // Spawn storage workers (reuse existing spawn_storage_workers logic)
+    spawn_storage_workers_internal(
+        &executor,
+        &view,
+        &task_ctx,
+        storage_worker_count,
+        storage_work_rx,
+    )?;
+
+    // Spawn account workers (reuse existing spawn_account_workers logic)
+    spawn_account_workers_internal(
+        &executor,
+        &view,
+        &task_ctx,
+        account_worker_count,
+        account_work_rx,
+        storage_work_tx.clone(),
+    )?;
+
+    Ok(ProofTaskManagerHandle::new(
+        storage_work_tx,
+        account_work_tx,
+        Arc::new(AtomicUsize::new(0)),
+        #[cfg(feature = "metrics")]
+        Arc::new(ProofTaskMetrics::default()),
+    ))
+}
+
+/// Spawns a pool of storage workers with dedicated database transactions.
+///
+/// Each worker receives `StorageWorkerJob` from the channel and processes storage proofs
+/// and blinded storage node requests using a dedicated long-lived transaction.
+///
+/// # Parameters
+/// - `executor`: Tokio runtime handle for spawning blocking tasks
+/// - `view`: Consistent database view for creating transactions
+/// - `task_ctx`: Shared context with trie updates and prefix sets
+/// - `worker_count`: Number of storage workers to spawn
+/// - `work_rx`: Receiver for storage worker jobs
+fn spawn_storage_workers_internal<Factory>(
+    executor: &Handle,
+    view: &ConsistentDbView<Factory>,
+    task_ctx: &ProofTaskCtx,
+    worker_count: usize,
+    work_rx: CrossbeamReceiver<StorageWorkerJob>,
+) -> ProviderResult<()>
+where
+    Factory: DatabaseProviderFactory<Provider: BlockReader>,
+{
+    for worker_id in 0..worker_count {
+        let provider_ro = view.provider_ro()?;
+        let tx = provider_ro.into_tx();
+        let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
+        let work_rx_clone = work_rx.clone();
+
+        executor
+            .spawn_blocking(move || storage_worker_loop(proof_task_tx, work_rx_clone, worker_id));
+
+        tracing::debug!(
+            target: "trie::proof_task",
+            worker_id,
+            "Storage worker spawned successfully"
+        );
+    }
+
+    Ok(())
+}
+
+/// Spawns a pool of account workers with dedicated database transactions.
+///
+/// Each worker receives `AccountWorkerJob` from the channel and processes account multiproofs
+/// and blinded account node requests using a dedicated long-lived transaction. Account workers
+/// can delegate storage proof computation to the storage worker pool.
+///
+/// # Parameters
+/// - `executor`: Tokio runtime handle for spawning blocking tasks
+/// - `view`: Consistent database view for creating transactions
+/// - `task_ctx`: Shared context with trie updates and prefix sets
+/// - `worker_count`: Number of account workers to spawn
+/// - `work_rx`: Receiver for account worker jobs
+/// - `storage_work_tx`: Sender to delegate storage proofs to storage worker pool
+fn spawn_account_workers_internal<Factory>(
+    executor: &Handle,
+    view: &ConsistentDbView<Factory>,
+    task_ctx: &ProofTaskCtx,
+    worker_count: usize,
+    work_rx: CrossbeamReceiver<AccountWorkerJob>,
+    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+) -> ProviderResult<()>
+where
+    Factory: DatabaseProviderFactory<Provider: BlockReader>,
+{
+    for worker_id in 0..worker_count {
+        let provider_ro = view.provider_ro()?;
+        let tx = provider_ro.into_tx();
+        let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
+        let work_rx_clone = work_rx.clone();
+        let storage_work_tx_clone = storage_work_tx.clone();
+
+        executor.spawn_blocking(move || {
+            account_worker_loop(proof_task_tx, work_rx_clone, storage_work_tx_clone, worker_id)
+        });
+
+        tracing::debug!(
+            target: "trie::proof_task",
+            worker_id,
+            "Account worker spawned successfully"
+        );
+    }
+
+    Ok(())
+}
/// Manager for coordinating proof request execution across different task types.
@@ -129,6 +253,13 @@
 /// - Submit tasks via `queue_task(ProofTaskKind)`
 /// - Use standard `std::mpsc` message passing
 /// - Receive consistent return types and error handling
+///
+/// # Deprecation Notice
+///
+/// This struct is deprecated. Use `spawn_proof_workers()` instead, which returns
+/// a `ProofTaskManagerHandle` directly without requiring a separate manager instance.
+#[deprecated(note = "Use spawn_proof_workers() instead")]
+#[allow(deprecated)]
 #[derive(Debug)]
 pub struct ProofTaskManager {
     /// Sender for storage worker jobs to worker pool.
@@ -149,6 +280,7 @@ pub struct ProofTaskManager {
     proof_task_rx: CrossbeamReceiver<ProofTaskMessage>,

     /// Sender for creating handles that can queue tasks.
+    #[allow(dead_code)]
     proof_task_tx: CrossbeamSender<ProofTaskMessage>,

     /// The number of active handles.
@@ -677,6 +809,7 @@ fn queue_storage_proofs(
     Ok(storage_proof_receivers)
 }

+#[allow(deprecated)]
 impl ProofTaskManager {
     /// Creates a new [`ProofTaskManager`] with pre-spawned storage and account proof workers.
     ///
@@ -716,7 +849,7 @@
         );

         // Spawn storage workers
-        let spawned_storage_workers = Self::spawn_storage_workers(
+        spawn_storage_workers_internal(
             &executor,
             &view,
             &task_ctx,
@@ -725,7 +858,7 @@
         )?;

         // Spawn account workers with direct access to the storage worker queue
-        let spawned_account_workers = Self::spawn_account_workers(
+        spawn_account_workers_internal(
             &executor,
             &view,
             &task_ctx,
@@ -736,9 +869,9 @@
         Ok(Self {
             storage_work_tx,
-            storage_worker_count: spawned_storage_workers,
+            storage_worker_count,
             account_work_tx,
-            account_worker_count: spawned_account_workers,
+            account_worker_count,
             proof_task_rx,
             proof_task_tx,
             active_handles: Arc::new(AtomicUsize::new(0)),
@@ -749,110 +882,18 @@
     }

     /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
     ///
+    /// DEPRECATED: This method returns a handle that uses the deprecated routing mechanism.
+    /// Use `spawn_proof_workers()` instead for direct worker pool access.
+    #[deprecated(note = "Use spawn_proof_workers() instead")]
     pub fn handle(&self) -> ProofTaskManagerHandle {
-        ProofTaskManagerHandle::new(self.proof_task_tx.clone(), self.active_handles.clone())
-    }
-
-    /// Spawns a pool of storage workers with dedicated database transactions.
-    ///
-    /// Each worker receives `StorageWorkerJob` from the channel and processes storage proofs
-    /// and blinded storage node requests using a dedicated long-lived transaction.
-    ///
-    /// # Parameters
-    /// - `executor`: Tokio runtime handle for spawning blocking tasks
-    /// - `view`: Consistent database view for creating transactions
-    /// - `task_ctx`: Shared context with trie updates and prefix sets
-    /// - `worker_count`: Number of storage workers to spawn
-    /// - `work_rx`: Receiver for storage worker jobs
-    ///
-    /// # Returns
-    /// The number of storage workers successfully spawned
-    fn spawn_storage_workers<Factory>(
-        executor: &Handle,
-        view: &ConsistentDbView<Factory>,
-        task_ctx: &ProofTaskCtx,
-        worker_count: usize,
-        work_rx: CrossbeamReceiver<StorageWorkerJob>,
-    ) -> ProviderResult<usize>
-    where
-        Factory: DatabaseProviderFactory<Provider: BlockReader>,
-    {
-        let mut spawned_workers = 0;
-        for worker_id in 0..worker_count {
-            let provider_ro = view.provider_ro()?;
-            let tx = provider_ro.into_tx();
-            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
-            let work_rx_clone = work_rx.clone();
-
-            executor.spawn_blocking(move || {
-                storage_worker_loop(proof_task_tx, work_rx_clone, worker_id)
-            });
-
-            spawned_workers += 1;
-
-            tracing::debug!(
-                target: "trie::proof_task",
-                worker_id,
-                spawned_workers,
-                "Storage worker spawned successfully"
-            );
-        }
-
-        Ok(spawned_workers)
-    }
-
-    /// Spawns a pool of account workers with dedicated database transactions.
-    ///
-    /// Each worker receives `AccountWorkerJob` from the channel and processes account multiproofs
-    /// and blinded account node requests using a dedicated long-lived transaction. Account workers
-    /// can delegate storage proof computation to the storage worker pool.
-    ///
-    /// # Parameters
-    /// - `executor`: Tokio runtime handle for spawning blocking tasks
-    /// - `view`: Consistent database view for creating transactions
-    /// - `task_ctx`: Shared context with trie updates and prefix sets
-    /// - `worker_count`: Number of account workers to spawn
-    /// - `work_rx`: Receiver for account worker jobs
-    /// - `storage_work_tx`: Sender to delegate storage proofs to storage worker pool
-    ///
-    /// # Returns
-    /// The number of account workers successfully spawned
-    fn spawn_account_workers<Factory>(
-        executor: &Handle,
-        view: &ConsistentDbView<Factory>,
-        task_ctx: &ProofTaskCtx,
-        worker_count: usize,
-        work_rx: CrossbeamReceiver<AccountWorkerJob>,
-        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
-    ) -> ProviderResult<usize>
-    where
-        Factory: DatabaseProviderFactory<Provider: BlockReader>,
-    {
-        let mut spawned_workers = 0;
-        for worker_id in 0..worker_count {
-            let provider_ro = view.provider_ro()?;
-            let tx = provider_ro.into_tx();
-            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
-            let work_rx_clone = work_rx.clone();
-            let storage_work_tx_clone = storage_work_tx.clone();
-
-            executor.spawn_blocking(move || {
-                account_worker_loop(proof_task_tx, work_rx_clone, storage_work_tx_clone, worker_id)
-            });
-
-            spawned_workers += 1;
-
-            tracing::debug!(
-                target: "trie::proof_task",
-                worker_id,
-                spawned_workers,
-                "Account worker spawned successfully"
-            );
-        }
-
-        Ok(spawned_workers)
+        ProofTaskManagerHandle::new(
+            self.storage_work_tx.clone(),
+            self.account_work_tx.clone(),
+            self.active_handles.clone(),
+            #[cfg(feature = "metrics")]
+            Arc::new(ProofTaskMetrics::default()),
+        )
     }

     /// Loops, managing the proof tasks, routing them to the appropriate worker pools.
@@ -870,7 +911,14 @@
     ///
     /// On termination, `storage_work_tx` and `account_work_tx` are dropped, closing the channels
     /// and signaling all workers to shut down gracefully.
-    pub fn run(mut self) -> ProviderResult<()> {
+    ///
+    /// # Deprecation Notice
+    ///
+    /// This method is deprecated. With `spawn_proof_workers()`, workers are spawned directly
+    /// and no routing thread is needed. Workers shut down automatically when all handles are
+    /// dropped.
+    #[deprecated(note = "Use spawn_proof_workers() instead - no routing thread needed")]
+    pub fn run(self) -> ProviderResult<()> {
         loop {
             match self.proof_task_rx.recv() {
                 Ok(message) => {
@@ -893,7 +941,7 @@
                     ProofTaskKind::BlindedStorageNode(account, path, sender) => {
                         #[cfg(feature = "metrics")]
                         {
-                            self.metrics.storage_nodes += 1;
+                            self.metrics.storage_nodes.fetch_add(1, Ordering::Relaxed);
                         }
                         self.storage_work_tx
@@ -915,7 +963,7 @@
                     ProofTaskKind::BlindedAccountNode(path, sender) => {
                         #[cfg(feature = "metrics")]
                         {
-                            self.metrics.account_nodes += 1;
+                            self.metrics.account_nodes.fetch_add(1, Ordering::Relaxed);
                         }
                         self.account_work_tx
@@ -1202,6 +1250,11 @@ impl ProofTaskCtx {
 }

 /// Message used to communicate with [`ProofTaskManager`].
+///
+/// DEPRECATED: No longer needed with `spawn_proof_workers()` which provides direct
+/// worker pool access. Use explicit queue methods on `ProofTaskManagerHandle` instead.
+#[deprecated(note = "Use explicit queue methods on ProofTaskManagerHandle instead")]
+#[allow(deprecated)]
 #[derive(Debug)]
 pub enum ProofTaskMessage {
     /// A request to queue a proof task.
@@ -1214,6 +1267,12 @@
 ///
 /// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
 /// specifies the type of proof task to be executed.
+///
+/// DEPRECATED: Use explicit queue methods on `ProofTaskManagerHandle` instead:
+/// - `queue_storage_proof()` for storage proofs
+/// - `queue_account_multiproof()` for account multiproofs
+#[deprecated(note = "Use explicit queue methods on ProofTaskManagerHandle instead")]
+#[allow(deprecated)]
 #[derive(Debug)]
 pub enum ProofTaskKind {
     /// A storage proof request.
@@ -1226,53 +1285,183 @@
     AccountMultiproof(AccountMultiproofInput, Sender<AccountMultiproofResult>),
 }

-/// A handle that wraps a single proof task sender that sends a terminate message on `Drop` if the
-/// number of active handles went to zero.
+/// A handle that provides type-safe access to proof worker pools.
+///
+/// The handle stores direct senders to both storage and account worker pools,
+/// eliminating the need for a routing thread. All handles share reference-counted
+/// channels, and workers shut down gracefully when all handles are dropped.
 #[derive(Debug)]
 pub struct ProofTaskManagerHandle {
-    /// The sender for the proof task manager.
-    sender: CrossbeamSender<ProofTaskMessage>,
-    /// The number of active handles.
+    /// Direct sender to storage worker pool
+    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+    /// Direct sender to account worker pool
+    account_work_tx: CrossbeamSender<AccountWorkerJob>,
+    /// Active handle reference count for auto-termination
     active_handles: Arc<AtomicUsize>,
+    /// Metrics tracking (lock-free)
+    #[cfg(feature = "metrics")]
+    metrics: Arc<ProofTaskMetrics>,
 }

 impl ProofTaskManagerHandle {
-    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
+    /// Creates a new [`ProofTaskManagerHandle`] with direct access to worker pools.
+    #[allow(private_interfaces)]
     pub fn new(
-        sender: CrossbeamSender<ProofTaskMessage>,
+        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+        account_work_tx: CrossbeamSender<AccountWorkerJob>,
         active_handles: Arc<AtomicUsize>,
+        #[cfg(feature = "metrics")] metrics: Arc<ProofTaskMetrics>,
     ) -> Self {
         active_handles.fetch_add(1, Ordering::SeqCst);
-        Self { sender, active_handles }
+        Self {
+            storage_work_tx,
+            account_work_tx,
+            active_handles,
+            #[cfg(feature = "metrics")]
+            metrics,
+        }
     }

+    /// Queue a storage proof computation
+    pub fn queue_storage_proof(
+        &self,
+        input: StorageProofInput,
+    ) -> Result<Receiver<StorageProofResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.storage_work_tx
+            .send(StorageWorkerJob::StorageProof { input, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("storage workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.storage_proofs.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
+    }
+
+    /// Queue an account multiproof computation
+    pub fn queue_account_multiproof(
+        &self,
+        input: AccountMultiproofInput,
+    ) -> Result<Receiver<AccountMultiproofResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.account_work_tx
+            .send(AccountWorkerJob::AccountMultiproof { input, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("account workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.account_proofs.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
+    }
+
+    /// Internal: Queue blinded storage node request
+    fn queue_blinded_storage_node(
+        &self,
+        account: B256,
+        path: Nibbles,
+    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.storage_work_tx
+            .send(StorageWorkerJob::BlindedStorageNode { account, path, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("storage workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.storage_nodes.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
+    }
+
+    /// Internal: Queue blinded account node request
+    fn queue_blinded_account_node(
+        &self,
+        path: Nibbles,
+    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.account_work_tx
+            .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("account workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.account_nodes.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
+    }
+
     /// Queues a task to the proof task manager.
+    ///
+    /// DEPRECATED: Use explicit methods like `queue_storage_proof` or `queue_account_multiproof`
+    /// instead. This method is kept temporarily for backwards compatibility during the
+    /// migration.
+    #[allow(deprecated)]
+    #[deprecated(note = "Use explicit queue methods instead")]
     pub fn queue_task(
         &self,
         task: ProofTaskKind,
     ) -> Result<(), crossbeam_channel::SendError<ProofTaskMessage>> {
-        self.sender.send(ProofTaskMessage::QueueTask(task))
+        match task {
+            ProofTaskKind::StorageProof(input, sender) => {
+                self.storage_work_tx
+                    .send(StorageWorkerJob::StorageProof { input, result_sender: sender })
+                    .expect("storage workers should be available");
+            }
+            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
+                self.storage_work_tx
+                    .send(StorageWorkerJob::BlindedStorageNode {
+                        account,
+                        path,
+                        result_sender: sender,
+                    })
+                    .expect("storage workers should be available");
+            }
+            ProofTaskKind::BlindedAccountNode(path, sender) => {
+                self.account_work_tx
+                    .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: sender })
+                    .expect("account workers should be available");
+            }
+            ProofTaskKind::AccountMultiproof(input, sender) => {
+                self.account_work_tx
+                    .send(AccountWorkerJob::AccountMultiproof { input, result_sender: sender })
+                    .expect("account workers should be available");
+            }
+        }
+        Ok(())
     }

     /// Terminates the proof task manager.
-    pub fn terminate(&self) {
-        let _ = self.sender.send(ProofTaskMessage::Terminate);
+    ///
+    /// DEPRECATED: Workers now shut down automatically when all handles are dropped.
+    /// This method is kept for backwards compatibility but does nothing.
+    #[deprecated(note = "Workers shut down automatically when all handles are dropped")]
+    pub const fn terminate(&self) {
+        // No-op: workers shut down when all handles are dropped
     }
 }

 impl Clone for ProofTaskManagerHandle {
     fn clone(&self) -> Self {
-        Self::new(self.sender.clone(), self.active_handles.clone())
+        Self::new(
+            self.storage_work_tx.clone(),
+            self.account_work_tx.clone(),
+            self.active_handles.clone(),
+            #[cfg(feature = "metrics")]
+            self.metrics.clone(),
+        )
     }
 }

 impl Drop for ProofTaskManagerHandle {
     fn drop(&mut self) {
-        // Decrement the number of active handles and terminate the manager if it was the last
-        // handle.
-        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
-            self.terminate();
-        }
+        // Decrement the number of active handles.
+        // When the last handle is dropped, the channels are dropped and workers shut down.
+        self.active_handles.fetch_sub(1, Ordering::SeqCst);
     }
 }
@@ -1281,11 +1470,11 @@ impl TrieNodeProviderFactory for ProofTaskManagerHandle {
     type StorageNodeProvider = ProofTaskTrieNodeProvider;

     fn account_node_provider(&self) -> Self::AccountNodeProvider {
-        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
     }

     fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
-        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
     }
 }
@@ -1294,35 +1483,38 @@
 pub enum ProofTaskTrieNodeProvider {
     /// Blinded account trie node provider.
     AccountNode {
-        /// Sender to the proof task.
-        sender: CrossbeamSender<ProofTaskMessage>,
+        /// Handle to the proof task manager.
+        handle: ProofTaskManagerHandle,
     },
     /// Blinded storage trie node provider.
     StorageNode {
         /// Target account.
         account: B256,
-        /// Sender to the proof task.
-        sender: CrossbeamSender<ProofTaskMessage>,
+        /// Handle to the proof task manager.
+        handle: ProofTaskManagerHandle,
     },
 }

 impl TrieNodeProvider for ProofTaskTrieNodeProvider {
     fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
-        let (tx, rx) = channel();
         match self {
-            Self::AccountNode { sender } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedAccountNode(*path, tx),
-                ));
+            Self::AccountNode { handle } => {
+                let rx = handle.queue_blinded_account_node(*path).map_err(|e| {
+                    SparseTrieErrorKind::Other(Box::new(std::io::Error::other(e.to_string())))
+                })?;
+                rx.recv().map_err(|_| {
+                    SparseTrieErrorKind::Other(Box::new(std::io::Error::other("channel closed")))
+                })?
             }
-            Self::StorageNode { sender, account } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
-                ));
+            Self::StorageNode { handle, account } => {
+                let rx = handle.queue_blinded_storage_node(*account, *path).map_err(|e| {
+                    SparseTrieErrorKind::Other(Box::new(std::io::Error::other(e.to_string())))
+                })?;
+                rx.recv().map_err(|_| {
+                    SparseTrieErrorKind::Other(Box::new(std::io::Error::other("channel closed")))
+                })?
             }
         }
-        rx.recv().unwrap()
     }
 }
@@ -1349,9 +1541,9 @@
         )
     }

-    /// Ensures `max_concurrency` is independent of storage and account workers.
+    /// Ensures `spawn_proof_workers` spawns workers correctly.
     #[test]
-    fn proof_task_manager_independent_pools() {
+    fn spawn_proof_workers_creates_handle() {
         let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
         runtime.block_on(async {
             let handle = tokio::runtime::Handle::current();
@@ -1359,13 +1551,13 @@
             let view = ConsistentDbView::new(factory, None);
             let ctx = test_ctx();

-            let manager = ProofTaskManager::new(handle.clone(), view, ctx, 5, 3).unwrap();
-
-            // With storage_worker_count=5, we get exactly 5 storage workers
-            assert_eq!(manager.storage_worker_count, 5);
-            // With account_worker_count=3, we get exactly 3 account workers
-            assert_eq!(manager.account_worker_count, 3);
+            let proof_handle = spawn_proof_workers(handle.clone(), view, ctx, 5, 3).unwrap();

-            drop(manager);
+            // Verify handle can be cloned
+            let _cloned_handle = proof_handle.clone();
+
+            // Workers shut down automatically when handle is dropped
+            drop(proof_handle);

             task::yield_now().await;
         });
     }
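
The diff above removes both the routing loop and the explicit Terminate message; shutdown now rides entirely on channel disconnection. A self-contained demo of that mechanism using crossbeam-channel directly (the job type and printed messages are illustrative, not reth code):

```rust
use crossbeam_channel::unbounded;
use std::thread;

fn main() {
    let (work_tx, work_rx) = unbounded::<u32>();

    let worker = thread::spawn(move || {
        // recv() returns Err(RecvError) once every sender has been dropped,
        // which is exactly how the workers learn that the last handle is gone.
        while let Ok(job) = work_rx.recv() {
            println!("processed job {job}");
        }
        println!("channel disconnected; worker exiting");
    });

    work_tx.send(1).unwrap();
    work_tx.send(2).unwrap();

    // No Terminate message: dropping the last sender shuts the worker down.
    drop(work_tx);
    worker.join().unwrap();
}
```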

View File

@@ -1,21 +1,41 @@
 use reth_metrics::{metrics::Histogram, Metrics};
+use std::sync::{
+    atomic::{AtomicU64, Ordering},
+    Arc,
+};

 /// Metrics for blinded node fetching for the duration of the proof task manager.
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug)]
 pub struct ProofTaskMetrics {
     /// The actual metrics for blinded nodes.
     pub task_metrics: ProofTaskTrieMetrics,
-    /// Count of blinded account node requests.
-    pub account_nodes: usize,
-    /// Count of blinded storage node requests.
-    pub storage_nodes: usize,
+    /// Count of storage proof requests (lock-free).
+    pub storage_proofs: Arc<AtomicU64>,
+    /// Count of account proof requests (lock-free).
+    pub account_proofs: Arc<AtomicU64>,
+    /// Count of blinded account node requests (lock-free).
+    pub account_nodes: Arc<AtomicU64>,
+    /// Count of blinded storage node requests (lock-free).
+    pub storage_nodes: Arc<AtomicU64>,
 }

+impl Default for ProofTaskMetrics {
+    fn default() -> Self {
+        Self {
+            task_metrics: ProofTaskTrieMetrics::default(),
+            storage_proofs: Arc::new(AtomicU64::new(0)),
+            account_proofs: Arc::new(AtomicU64::new(0)),
+            account_nodes: Arc::new(AtomicU64::new(0)),
+            storage_nodes: Arc::new(AtomicU64::new(0)),
+        }
+    }
+}
+
 impl ProofTaskMetrics {
     /// Record the blinded node counts into the histograms.
     pub fn record(&self) {
-        self.task_metrics.record_account_nodes(self.account_nodes);
-        self.task_metrics.record_storage_nodes(self.storage_nodes);
+        self.task_metrics
+            .record_account_nodes(self.account_nodes.load(Ordering::Relaxed) as usize);
+        self.task_metrics
+            .record_storage_nodes(self.storage_nodes.load(Ordering::Relaxed) as usize);
     }
 }
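
The metrics change above swaps plain `usize` counters for shared atomics so that cloned handles can increment them concurrently without locks. A standalone sketch of the pattern with a hypothetical `Counters` struct; the commit's real struct carries four such counters plus the histogram recorder:

```rust
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};

// Clones share the same Arc<AtomicU64>, so any clone's increment is visible
// to every other clone and to the reader that snapshots the total.
#[derive(Clone, Default)]
struct Counters {
    storage_proofs: Arc<AtomicU64>,
}

fn main() {
    let metrics = Counters::default();
    let handle_copy = metrics.clone();

    // Relaxed ordering suffices for statistical counters: only atomicity is
    // needed, not ordering with respect to other memory operations.
    handle_copy.storage_proofs.fetch_add(1, Ordering::Relaxed);
    metrics.storage_proofs.fetch_add(1, Ordering::Relaxed);

    assert_eq!(metrics.storage_proofs.load(Ordering::Relaxed), 2);
}
```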
}