refactor(trie): enhance proof task management with transaction pooling

- Replaced the `ProofTaskManager` event loop with a pool of pre-initialized database transactions shared through `crossbeam` channels.
- Transactions are now created once, up to `max_concurrency`, and reused across proof tasks instead of being created on demand.
- Task dispatching happens directly from `ProofTaskManagerHandle`: a transaction is checked out of the pool, the proof runs on Tokio's blocking threadpool, and the transaction is returned to the pool when the task completes (see the sketch below).
- Improved error handling and logging around transaction checkout and return to aid debugging.
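For reference, a minimal, self-contained sketch of the checkout/return pattern described above, using `crossbeam_channel` and `tokio`. The `TxPool` type and the plain `u64` standing in for the pooled database transaction are illustrative assumptions only, not the actual reth types:

// Sketch of the pooled-transaction pattern: a bounded crossbeam channel holds the
// pre-created transactions, a tokio Notify wakes tasks waiting for one, and the
// blocking work runs on Tokio's blocking threadpool.
use std::sync::Arc;

use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use tokio::{runtime::Handle, sync::Notify};

#[derive(Clone)]
struct TxPool {
    sender: Sender<u64>,     // returns transactions to the pool
    receiver: Receiver<u64>, // checks transactions out of the pool
    notify: Arc<Notify>,     // wakes tasks waiting for a free transaction
}

impl TxPool {
    /// Pre-create `capacity` "transactions" up front, mirroring the pool built in
    /// `new_proof_task_handle`.
    fn new(capacity: usize) -> Self {
        let (sender, receiver) = bounded(capacity);
        for id in 0..capacity as u64 {
            sender.send(id).expect("pool channel has capacity");
        }
        Self { sender, receiver, notify: Arc::new(Notify::new()) }
    }

    /// Check a transaction out, run the blocking work with it, and return it.
    async fn run_blocking<F>(&self, executor: Handle, work: F)
    where
        F: FnOnce(u64) + Send + 'static,
    {
        // Wait asynchronously until a transaction becomes available.
        let tx = loop {
            match self.receiver.try_recv() {
                Ok(tx) => break tx,
                Err(TryRecvError::Empty) => self.notify.notified().await,
                Err(TryRecvError::Disconnected) => return,
            }
        };
        // The transaction is moved into the blocking closure and handed back as
        // the join result once the work is done.
        let tx = executor
            .spawn_blocking(move || {
                work(tx);
                tx
            })
            .await
            .expect("blocking task panicked");
        // Return it to the pool and wake one waiter.
        let _ = self.sender.try_send(tx);
        self.notify.notify_one();
    }
}

#[tokio::main]
async fn main() {
    let pool = TxPool::new(2);
    pool.run_blocking(Handle::current(), |tx| println!("computed proof with tx {tx}")).await;
}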
Author: Yong Kang
Date:   2025-10-07 02:57:10 +00:00
Parent: f2dab42b7b
Commit: 883b38bb43


@@ -1,20 +1,15 @@
//! A Task that manages sending proof requests to a number of tasks that have longer-running
//! database transactions.
//! Proof task management using a pool of pre-warmed database transactions.
//!
//! The [`ProofTaskManager`] ensures that there are a max number of currently executing proof tasks,
//! and is responsible for managing the fixed number of database transactions created at the start
//! of the task.
//!
//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
//! This module provides proof computation using Tokio's blocking threadpool with
//! transaction reuse via a crossbeam channel pool.
use crate::root::ParallelStateRootError;
use alloy_primitives::{map::B256Set, B256};
use crossbeam_channel::{bounded, unbounded, Receiver, Sender, TryRecvError, TrySendError};
use reth_db_api::transaction::DbTx;
use reth_execution_errors::SparseTrieError;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
ProviderResult,
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderResult,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
@@ -31,16 +26,14 @@ use reth_trie_common::{
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
use std::{
collections::VecDeque,
sync::{
atomic::{AtomicUsize, Ordering},
mpsc::{channel, Receiver, SendError, Sender},
Arc,
},
time::Instant,
};
use tokio::runtime::Handle;
use tracing::{debug, trace};
use tokio::{runtime::Handle, sync::Notify};
use tracing::{error, trace};
#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskMetrics;
@@ -48,170 +41,41 @@ use crate::proof_task_metrics::ProofTaskMetrics;
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
/// A task that manages sending multiproof requests to a number of tasks that have longer-running
/// database transactions
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
/// Max number of database transactions to create
max_concurrency: usize,
/// Number of database transactions created
total_transactions: usize,
/// Consistent view provider used for creating transactions on-demand
view: ConsistentDbView<Factory>,
/// Proof task context shared across all proof tasks
task_ctx: ProofTaskCtx,
/// Proof tasks pending execution
pending_tasks: VecDeque<ProofTaskKind>,
/// The underlying handle from which to spawn proof tasks
/// Creates a new proof task handle with a pre-initialized transaction pool.
///
/// This function creates a pool of database transactions that will be reused across
/// multiple proof tasks. Tasks are queued asynchronously and coordinated via notification,
/// with actual computation dispatched to Tokio's blocking threadpool using the pooled transactions.
pub fn new_proof_task_handle<Factory>(
executor: Handle,
/// The proof task transactions, containing owned cursor factories that are reused for proof
/// calculation.
proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
/// A receiver for new proof tasks.
proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
/// A sender for sending back transactions.
tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
/// The number of active handles.
///
/// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
/// [`ProofTaskManagerHandle::drop`].
active_handles: Arc<AtomicUsize>,
/// Metrics tracking blinded node fetches.
#[cfg(feature = "metrics")]
metrics: ProofTaskMetrics,
}
impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
/// Creates a new [`ProofTaskManager`] with the given max concurrency, creating that number of
/// cursor factories.
///
/// Returns an error if the consistent view provider fails to create a read-only transaction.
pub fn new(
executor: Handle,
view: ConsistentDbView<Factory>,
task_ctx: ProofTaskCtx,
max_concurrency: usize,
) -> Self {
let (tx_sender, proof_task_rx) = channel();
Self {
max_concurrency,
total_transactions: 0,
view,
task_ctx,
pending_tasks: VecDeque::new(),
executor,
proof_task_txs: Vec::new(),
proof_task_rx,
tx_sender,
active_handles: Arc::new(AtomicUsize::new(0)),
#[cfg(feature = "metrics")]
metrics: ProofTaskMetrics::default(),
}
}
/// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
}
}
impl<Factory> ProofTaskManager<Factory>
view: ConsistentDbView<Factory>,
task_ctx: ProofTaskCtx,
max_concurrency: usize,
) -> ProviderResult<ProofTaskManagerHandle<<Factory::Provider as DBProvider>::Tx>>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + Send + Sync + 'static,
{
/// Inserts the task into the pending tasks queue.
pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
self.pending_tasks.push_back(task);
let max_concurrency = max_concurrency.max(1);
let (tx_pool_sender, tx_pool_receiver) = bounded(max_concurrency);
let pool_notify = Arc::new(Notify::new());
// Pre-create all transactions upfront
for worker_id in 0..max_concurrency {
let provider_ro = view.provider_ro()?;
let tx = provider_ro.into_tx();
let proof_task_tx = Arc::new(ProofTaskTx::new(tx, task_ctx.clone(), worker_id));
tx_pool_sender.send(proof_task_tx).expect("pool channel should have capacity");
}
/// Gets either the next available transaction, or creates a new one if all are in use and the
/// total number of transactions created is less than the max concurrency.
pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
if let Some(proof_task_tx) = self.proof_task_txs.pop() {
return Ok(Some(proof_task_tx));
}
// if we can create a new tx within our concurrency limits, create one on-demand
if self.total_transactions < self.max_concurrency {
let provider_ro = self.view.provider_ro()?;
let tx = provider_ro.into_tx();
self.total_transactions += 1;
return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
}
Ok(None)
}
/// Spawns the next queued proof task on the executor with the given input, if there are any
/// transactions available.
///
/// This will return an error if a transaction must be created on-demand and the consistent view
/// provider fails.
pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
let Some(proof_task_tx) = self.get_or_create_tx()? else {
// if there are no txs available, requeue the proof task
self.pending_tasks.push_front(task);
return Ok(())
};
let tx_sender = self.tx_sender.clone();
self.executor.spawn_blocking(move || match task {
ProofTaskKind::StorageProof(input, sender) => {
proof_task_tx.storage_proof(input, sender, tx_sender);
}
ProofTaskKind::BlindedAccountNode(path, sender) => {
proof_task_tx.blinded_account_node(path, sender, tx_sender);
}
ProofTaskKind::BlindedStorageNode(account, path, sender) => {
proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
}
});
Ok(())
}
/// Loops, managing the proof tasks, and sending new tasks to the executor.
pub fn run(mut self) -> ProviderResult<()> {
loop {
match self.proof_task_rx.recv() {
Ok(message) => match message {
ProofTaskMessage::QueueTask(task) => {
// Track metrics for blinded node requests
#[cfg(feature = "metrics")]
match &task {
ProofTaskKind::BlindedAccountNode(_, _) => {
self.metrics.account_nodes += 1;
}
ProofTaskKind::BlindedStorageNode(_, _, _) => {
self.metrics.storage_nodes += 1;
}
_ => {}
}
// queue the task
self.queue_proof_task(task)
}
ProofTaskMessage::Transaction(tx) => {
// return the transaction to the pool
self.proof_task_txs.push(tx);
}
ProofTaskMessage::Terminate => {
// Record metrics before terminating
#[cfg(feature = "metrics")]
self.metrics.record();
return Ok(())
}
},
// All senders are disconnected, so we can terminate
// However this should never happen, as this struct stores a sender
Err(_) => return Ok(()),
};
// try spawning the next task
self.try_spawn_next()?;
}
}
Ok(ProofTaskManagerHandle::new(
tx_pool_sender,
tx_pool_receiver,
pool_notify,
executor,
Arc::new(AtomicUsize::new(0)),
#[cfg(feature = "metrics")]
Arc::new(ProofTaskMetrics::default()),
))
}
/// Type alias for the factory tuple returned by `create_factories`
@@ -257,92 +121,83 @@ where
}
/// Calculates a storage proof for the given hashed address, and desired prefix set.
fn storage_proof(
self,
input: StorageProofInput,
result_sender: Sender<StorageProofResult>,
tx_sender: Sender<ProofTaskMessage<Tx>>,
) {
fn storage_proof(&self, input: StorageProofInput, result_sender: &Sender<StorageProofResult>) {
let StorageProofInput {
hashed_address,
prefix_set,
target_slots,
with_branch_node_masks,
multi_added_removed_keys,
} = input;
trace!(
target: "trie::proof_task",
hashed_address=?input.hashed_address,
worker_id = self.id,
hashed_address = ?hashed_address,
"Starting storage proof task calculation"
);
let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
let multi_added_removed_keys = input
.multi_added_removed_keys
.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
let added_removed_keys = multi_added_removed_keys.get_storage(&input.hashed_address);
let multi_added_removed_keys =
multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);
let span = tracing::trace_span!(
target: "trie::proof_task",
"Storage proof calculation",
hashed_address=?input.hashed_address,
hashed_address = ?hashed_address,
// Add a unique id because we often have parallel storage proof calculations for the
// same hashed address, and we want to differentiate them during trace analysis.
span_id=self.id,
span_id = self.id,
);
let span_guard = span.enter();
let _span_guard: tracing::span::Entered<'_> = span.enter();
let target_slots_len = input.target_slots.len();
let target_slots_len = target_slots.len();
let proof_start = Instant::now();
let raw_proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
input.hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().copied()))
.with_branch_node_masks(input.with_branch_node_masks)
.with_added_removed_keys(added_removed_keys)
.storage_multiproof(input.target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
drop(span_guard);
let raw_proof_result =
StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
.with_branch_node_masks(with_branch_node_masks)
.with_added_removed_keys(added_removed_keys)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
let decoded_result = raw_proof_result.and_then(|raw_proof| {
raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
ParallelStateRootError::Other(format!(
"Failed to decode storage proof for {}: {}",
input.hashed_address, e
hashed_address, e
))
})
});
trace!(
target: "trie::proof_task",
hashed_address=?input.hashed_address,
prefix_set = ?input.prefix_set.len(),
target_slots = ?target_slots_len,
worker_id = self.id,
hashed_address = ?hashed_address,
prefix_set_len = prefix_set.len(),
target_slots = target_slots_len,
proof_time = ?proof_start.elapsed(),
"Completed storage proof task calculation"
);
// send the result back
if let Err(error) = result_sender.send(decoded_result) {
debug!(
// Send the result back (log error if receiver dropped)
if let Err(e) = result_sender.send(decoded_result) {
error!(
target: "trie::proof_task",
hashed_address = ?input.hashed_address,
?error,
task_time = ?proof_start.elapsed(),
"Storage proof receiver is dropped, discarding the result"
worker_id = self.id,
"Failed to send storage proof result: {:?}",
e
);
}
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
}
/// Retrieves blinded account node by path.
fn blinded_account_node(
self,
path: Nibbles,
result_sender: Sender<TrieNodeProviderResult>,
tx_sender: Sender<ProofTaskMessage<Tx>>,
) {
fn blinded_account_node(&self, path: &Nibbles, result_sender: &Sender<TrieNodeProviderResult>) {
trace!(
target: "trie::proof_task",
worker_id = self.id,
?path,
"Starting blinded account node retrieval"
);
@@ -356,37 +211,35 @@ where
);
let start = Instant::now();
let result = blinded_provider_factory.account_node_provider().trie_node(&path);
let result = blinded_provider_factory.account_node_provider().trie_node(path);
trace!(
target: "trie::proof_task",
worker_id = self.id,
?path,
elapsed = ?start.elapsed(),
"Completed blinded account node retrieval"
);
if let Err(error) = result_sender.send(result) {
tracing::error!(
if let Err(e) = result_sender.send(result) {
error!(
target: "trie::proof_task",
?path,
?error,
"Failed to send blinded account node result"
worker_id = self.id,
"Failed to send account node result: {:?}",
e
);
}
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
}
/// Retrieves blinded storage node of the given account by path.
fn blinded_storage_node(
self,
account: B256,
path: Nibbles,
result_sender: Sender<TrieNodeProviderResult>,
tx_sender: Sender<ProofTaskMessage<Tx>>,
&self,
account: &B256,
path: &Nibbles,
result_sender: &Sender<TrieNodeProviderResult>,
) {
trace!(
target: "trie::proof_task",
worker_id = self.id,
?account,
?path,
"Starting blinded storage node retrieval"
@@ -401,9 +254,10 @@ where
);
let start = Instant::now();
let result = blinded_provider_factory.storage_node_provider(account).trie_node(&path);
let result = blinded_provider_factory.storage_node_provider(*account).trie_node(path);
trace!(
target: "trie::proof_task",
worker_id = self.id,
?account,
?path,
elapsed = ?start.elapsed(),
@@ -411,17 +265,19 @@ where
);
if let Err(error) = result_sender.send(result) {
tracing::error!(
error!(
target: "trie::proof_task",
?account,
?path,
worker_id = self.id,
?error,
"Failed to send blinded storage node result"
"Failed to send storage node result"
);
}
// send the tx back
let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
// Note: Transaction return to pool is handled by dispatch_task() after spawn_blocking
// completes. The Arc<ProofTaskTx> is moved into the closure and returned as the result,
// then sent back to the pool automatically. No explicit return needed here.
}
}
@@ -429,15 +285,15 @@ where
#[derive(Debug)]
pub struct StorageProofInput {
/// The hashed address for which the proof is calculated.
hashed_address: B256,
pub hashed_address: B256,
/// The prefix set for the proof calculation.
prefix_set: PrefixSet,
pub prefix_set: PrefixSet,
/// The target slots for the proof calculation.
target_slots: B256Set,
pub target_slots: B256Set,
/// Whether or not to collect branch node masks
with_branch_node_masks: bool,
pub with_branch_node_masks: bool,
/// Provided by the user to give the necessary context to retain extra proofs.
multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
pub multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}
impl StorageProofInput {
@@ -483,21 +339,7 @@ impl ProofTaskCtx {
}
}
/// Message used to communicate with [`ProofTaskManager`].
#[derive(Debug)]
pub enum ProofTaskMessage<Tx> {
/// A request to queue a proof task.
QueueTask(ProofTaskKind),
/// A returned database transaction.
Transaction(ProofTaskTx<Tx>),
/// A request to terminate the proof task manager.
Terminate,
}
/// Proof task kind.
///
/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
/// specifies the type of proof task to be executed.
/// Proof task kind dispatched via [`ProofTaskManagerHandle::queue_task`].
#[derive(Debug)]
pub enum ProofTaskKind {
/// A storage proof request.
@@ -512,89 +354,237 @@ pub enum ProofTaskKind {
///
/// Tasks are dispatched directly without an intermediate manager loop.
pub struct ProofTaskManagerHandle<Tx> {
/// The sender for the proof task manager.
sender: Sender<ProofTaskMessage<Tx>>,
/// The number of active handles.
/// Transaction pool sender (for returning transactions)
tx_pool_sender: Sender<Arc<ProofTaskTx<Tx>>>,
/// Transaction pool receiver (for checking out transactions)
tx_pool_receiver: Receiver<Arc<ProofTaskTx<Tx>>>,
/// Notifies waiters when a transaction is returned to the pool
pool_notify: Arc<Notify>,
/// Tokio executor for spawning blocking tasks
executor: Handle,
/// The number of active handles (for metrics).
active_handles: Arc<AtomicUsize>,
/// Metrics tracking blinded node fetches.
#[cfg(feature = "metrics")]
metrics: Arc<ProofTaskMetrics>,
}
impl<Tx> ProofTaskManagerHandle<Tx> {
/// Creates a new [`ProofTaskManagerHandle`] with the given sender.
pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
// Manual Debug impl since Tx may not be Debug
impl<Tx> std::fmt::Debug for ProofTaskManagerHandle<Tx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ProofTaskManagerHandle")
.field("executor", &self.executor)
.field("active_handles", &self.active_handles)
.finish()
}
}
impl<Tx> ProofTaskManagerHandle<Tx>
where
Tx: DbTx + Send + 'static,
{
/// Creates a new [`ProofTaskManagerHandle`].
pub fn new(
tx_pool_sender: Sender<Arc<ProofTaskTx<Tx>>>,
tx_pool_receiver: Receiver<Arc<ProofTaskTx<Tx>>>,
pool_notify: Arc<Notify>,
executor: Handle,
active_handles: Arc<AtomicUsize>,
#[cfg(feature = "metrics")] metrics: Arc<ProofTaskMetrics>,
) -> Self {
active_handles.fetch_add(1, Ordering::SeqCst);
Self { sender, active_handles }
Self {
tx_pool_sender,
tx_pool_receiver,
pool_notify,
executor,
active_handles,
#[cfg(feature = "metrics")]
metrics,
}
}
/// Queues a task to the proof task manager.
pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
self.sender.send(ProofTaskMessage::QueueTask(task))
}
/// Queues a task by checking out a transaction from the pool and spawning it
/// directly in Tokio's blocking threadpool.
pub fn queue_task(&self, task: ProofTaskKind) {
let tx_pool_receiver = self.tx_pool_receiver.clone();
let tx_pool_sender = self.tx_pool_sender.clone();
let executor = self.executor.clone();
let pool_notify = Arc::clone(&self.pool_notify);
/// Terminates the proof task manager.
pub fn terminate(&self) {
let _ = self.sender.send(ProofTaskMessage::Terminate);
#[cfg(feature = "metrics")]
let metrics = Arc::clone(&self.metrics);
// Track metrics for blinded node requests
#[cfg(feature = "metrics")]
match &task {
ProofTaskKind::BlindedAccountNode(_, _) => {
metrics.account_nodes.fetch_add(1, Ordering::Relaxed);
}
ProofTaskKind::BlindedStorageNode(_, _, _) => {
metrics.storage_nodes.fetch_add(1, Ordering::Relaxed);
}
_ => {}
}
self.executor.spawn(async move {
// Wait asynchronously until a transaction becomes available.
let proof_tx = loop {
match tx_pool_receiver.try_recv() {
Ok(tx) => break tx,
Err(TryRecvError::Empty) => {
pool_notify.notified().await;
}
Err(TryRecvError::Disconnected) => {
error!(target: "trie::proof_task", "Transaction pool closed");
return;
}
}
};
// Execute task in blocking threadpool
let result = executor
.spawn_blocking(move || {
match task {
ProofTaskKind::StorageProof(input, sender) => {
proof_tx.storage_proof(input, &sender);
}
ProofTaskKind::BlindedAccountNode(path, sender) => {
proof_tx.blinded_account_node(&path, &sender);
}
ProofTaskKind::BlindedStorageNode(account, path, sender) => {
proof_tx.blinded_storage_node(&account, &path, &sender);
}
}
proof_tx
})
.await;
// Return transaction to pool
match result {
Ok(proof_tx) => {
match tx_pool_sender.try_send(proof_tx) {
Ok(()) => {
pool_notify.notify_one();
}
Err(TrySendError::Full(tx)) => {
// Should never happen - we're returning what we took
error!(target: "trie::proof_task",
"Pool full on return. This should not happen.");
// Fallback: Use spawn_blocking to retry the send operation
// This prevents losing the transaction from the pool
// The send() call blocks a blocking-pool thread, NOT the async worker
let tx_pool_sender = tx_pool_sender.clone();
let pool_notify = Arc::clone(&pool_notify);
executor.spawn_blocking(move || {
// Retry the send in a blocking context
if tx_pool_sender.send(tx).is_ok() {
pool_notify.notify_one();
} else {
error!(target: "trie::proof_task",
"Failed to return transaction to pool even after blocking retry");
}
});
}
Err(TrySendError::Disconnected(_)) => {
// Pool closed, ignore
}
}
}
Err(e) => {
error!(target: "trie::proof_task", ?e, "Proof task panicked, transaction lost from pool");
}
}
});
}
}
impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
impl<Tx> Clone for ProofTaskManagerHandle<Tx>
where
Tx: DbTx + Send + 'static,
{
fn clone(&self) -> Self {
Self::new(self.sender.clone(), self.active_handles.clone())
Self::new(
self.tx_pool_sender.clone(),
self.tx_pool_receiver.clone(),
Arc::clone(&self.pool_notify),
self.executor.clone(),
Arc::clone(&self.active_handles),
#[cfg(feature = "metrics")]
Arc::clone(&self.metrics),
)
}
}
impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
fn drop(&mut self) {
// Decrement the number of active handles and terminate the manager if it was the last
// handle.
// Wake any tasks waiting on a transaction so they can observe shutdown.
self.pool_notify.notify_waiters();
// Record metrics if this is the last handle
if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
self.terminate();
#[cfg(feature = "metrics")]
self.metrics.record();
}
}
}
impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
impl<Tx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx>
where
Tx: DbTx + Send + 'static,
{
type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;
fn account_node_provider(&self) -> Self::AccountNodeProvider {
ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
}
fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
}
}
/// Trie node provider for retrieving trie nodes by path.
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider<Tx> {
/// Blinded account trie node provider.
AccountNode {
/// Sender to the proof task.
sender: Sender<ProofTaskMessage<Tx>>,
/// Handle to the transaction pool
handle: ProofTaskManagerHandle<Tx>,
},
/// Blinded storage trie node provider.
StorageNode {
/// Target account.
account: B256,
/// Sender to the proof task.
sender: Sender<ProofTaskMessage<Tx>>,
/// Handle to the transaction pool
handle: ProofTaskManagerHandle<Tx>,
},
}
impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
let (tx, rx) = channel();
impl<Tx> std::fmt::Debug for ProofTaskTrieNodeProvider<Tx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AccountNode { sender } => {
let _ = sender.send(ProofTaskMessage::QueueTask(
ProofTaskKind::BlindedAccountNode(*path, tx),
));
Self::AccountNode { .. } => f.debug_struct("AccountNode").finish(),
Self::StorageNode { account, .. } => {
f.debug_struct("StorageNode").field("account", account).finish()
}
Self::StorageNode { sender, account } => {
let _ = sender.send(ProofTaskMessage::QueueTask(
ProofTaskKind::BlindedStorageNode(*account, *path, tx),
));
}
}
}
impl<Tx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx>
where
Tx: DbTx + Send + 'static,
{
fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
let (tx, rx) = unbounded();
match self {
Self::AccountNode { handle } => {
handle.queue_task(ProofTaskKind::BlindedAccountNode(*path, tx));
}
Self::StorageNode { handle, account } => {
handle.queue_task(ProofTaskKind::BlindedStorageNode(*account, *path, tx));
}
}