perf(tree): worker pooling for account proofs (#18901)

Co-authored-by: Brian Picciano <me@mediocregopher.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com>
Authored by YK on 2025-10-15 08:26:02 +08:00; committed by GitHub
parent 169a1fb97b
commit e0b7a86313
8 changed files with 904 additions and 574 deletions


@@ -21,6 +21,14 @@ fn default_storage_worker_count() -> usize {
}
}
/// Returns the default number of account worker threads.
///
/// Account workers coordinate storage proof collection and account trie traversal.
/// They are set to the same count as storage workers for simplicity.
fn default_account_worker_count() -> usize {
default_storage_worker_count()
}
/// The size of the proof targets chunk to spawn in one multiproof calculation.
pub const DEFAULT_MULTIPROOF_TASK_CHUNK_SIZE: usize = 10;
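For reference, a minimal sketch (illustrative, not part of this diff, assuming the behavior described in the CLI help later in this commit) of the storage-worker default that the new account default falls back to: twice the available parallelism, clamped between 2 and 64. The actual body of `default_storage_worker_count` is elided from this hunk.
fn default_storage_worker_count() -> usize {
    // Documented behavior: 2x available parallelism, clamped to the range [2, 64].
    std::thread::available_parallelism()
        .map(|parallelism| (parallelism.get() * 2).clamp(2, 64))
        .unwrap_or(2)
}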
@@ -123,6 +131,8 @@ pub struct TreeConfig {
allow_unwind_canonical_header: bool,
/// Number of storage proof worker threads.
storage_worker_count: usize,
/// Number of account proof worker threads.
account_worker_count: usize,
}
impl Default for TreeConfig {
@@ -150,6 +160,7 @@ impl Default for TreeConfig {
prewarm_max_concurrency: DEFAULT_PREWARM_MAX_CONCURRENCY,
allow_unwind_canonical_header: false,
storage_worker_count: default_storage_worker_count(),
account_worker_count: default_account_worker_count(),
}
}
}
@@ -180,6 +191,7 @@ impl TreeConfig {
prewarm_max_concurrency: usize,
allow_unwind_canonical_header: bool,
storage_worker_count: usize,
account_worker_count: usize,
) -> Self {
assert!(max_proof_task_concurrency > 0, "max_proof_task_concurrency must be at least 1");
Self {
@@ -205,6 +217,7 @@ impl TreeConfig {
prewarm_max_concurrency,
allow_unwind_canonical_header,
storage_worker_count,
account_worker_count,
}
}
@@ -482,4 +495,15 @@ impl TreeConfig {
self.storage_worker_count = storage_worker_count;
self
}
/// Return the number of account proof worker threads.
pub const fn account_worker_count(&self) -> usize {
self.account_worker_count
}
/// Setter for the number of account proof worker threads.
pub const fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
self.account_worker_count = account_worker_count;
self
}
}
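A minimal usage sketch (illustrative, not part of this diff) of the builders added above; both counts fall back to their defaults when the builders are not called:
let config = TreeConfig::default()
    .with_storage_worker_count(16)
    .with_account_worker_count(16);
assert_eq!(config.storage_worker_count(), 16);
assert_eq!(config.account_worker_count(), 16);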


@@ -192,8 +192,7 @@ where
{
let (to_sparse_trie, sparse_trie_rx) = channel();
// spawn multiproof task, save the trie input
let (trie_input, state_root_config) =
MultiProofConfig::new_from_input(consistent_view, trie_input);
let (trie_input, state_root_config) = MultiProofConfig::from_input(trie_input);
self.trie_input = Some(trie_input);
// Create and spawn the storage proof task
@@ -202,14 +201,15 @@ where
state_root_config.state_sorted.clone(),
state_root_config.prefix_sets.clone(),
);
let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
let storage_worker_count = config.storage_worker_count();
let account_worker_count = config.account_worker_count();
let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
let proof_task = match ProofTaskManager::new(
self.executor.handle().clone(),
state_root_config.consistent_view.clone(),
consistent_view,
task_ctx,
max_proof_task_concurrency,
storage_worker_count,
account_worker_count,
) {
Ok(task) => task,
Err(error) => {


@@ -12,14 +12,17 @@ use derive_more::derive::Deref;
use metrics::Histogram;
use reth_errors::ProviderError;
use reth_metrics::Metrics;
use reth_provider::{providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, FactoryTx};
use reth_revm::state::EvmState;
use reth_trie::{
added_removed_keys::MultiAddedRemovedKeys, prefix_set::TriePrefixSetsMut,
updates::TrieUpdatesSorted, DecodedMultiProof, HashedPostState, HashedPostStateSorted,
HashedStorage, MultiProofTargets, TrieInput,
};
use reth_trie_parallel::{proof::ParallelProof, proof_task::ProofTaskManagerHandle};
use reth_trie_parallel::{
proof::ParallelProof,
proof_task::{AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle},
root::ParallelStateRootError,
};
use std::{
collections::{BTreeMap, VecDeque},
ops::DerefMut,
@@ -62,9 +65,7 @@ impl SparseTrieUpdate {
/// Common configuration for multi proof tasks
#[derive(Debug, Clone)]
pub(super) struct MultiProofConfig<Factory> {
/// View over the state in the database.
pub consistent_view: ConsistentDbView<Factory>,
pub(super) struct MultiProofConfig {
/// The sorted collection of cached in-memory intermediate trie nodes that
/// can be reused for computation.
pub nodes_sorted: Arc<TrieUpdatesSorted>,
@@ -76,17 +77,13 @@ pub(super) struct MultiProofConfig<Factory> {
pub prefix_sets: Arc<TriePrefixSetsMut>,
}
impl<Factory> MultiProofConfig<Factory> {
/// Creates a new state root config from the consistent view and the trie input.
impl MultiProofConfig {
/// Creates a new state root config from the trie input.
///
/// This returns a cleared [`TrieInput`] so that we can reuse any allocated space in the
/// [`TrieInput`].
pub(super) fn new_from_input(
consistent_view: ConsistentDbView<Factory>,
mut input: TrieInput,
) -> (TrieInput, Self) {
pub(super) fn from_input(mut input: TrieInput) -> (TrieInput, Self) {
let config = Self {
consistent_view,
nodes_sorted: Arc::new(input.nodes.drain_into_sorted()),
state_sorted: Arc::new(input.state.drain_into_sorted()),
prefix_sets: Arc::new(input.prefix_sets.clone()),
@@ -245,14 +242,14 @@ pub(crate) fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostStat
/// A pending multiproof task, either [`StorageMultiproofInput`] or [`MultiproofInput`].
#[derive(Debug)]
enum PendingMultiproofTask<Factory> {
enum PendingMultiproofTask {
/// A storage multiproof task input.
Storage(StorageMultiproofInput<Factory>),
Storage(StorageMultiproofInput),
/// A regular multiproof task input.
Regular(MultiproofInput<Factory>),
Regular(MultiproofInput),
}
impl<Factory> PendingMultiproofTask<Factory> {
impl PendingMultiproofTask {
/// Returns the proof sequence number of the task.
const fn proof_sequence_number(&self) -> u64 {
match self {
@@ -278,22 +275,22 @@ impl<Factory> PendingMultiproofTask<Factory> {
}
}
impl<Factory> From<StorageMultiproofInput<Factory>> for PendingMultiproofTask<Factory> {
fn from(input: StorageMultiproofInput<Factory>) -> Self {
impl From<StorageMultiproofInput> for PendingMultiproofTask {
fn from(input: StorageMultiproofInput) -> Self {
Self::Storage(input)
}
}
impl<Factory> From<MultiproofInput<Factory>> for PendingMultiproofTask<Factory> {
fn from(input: MultiproofInput<Factory>) -> Self {
impl From<MultiproofInput> for PendingMultiproofTask {
fn from(input: MultiproofInput) -> Self {
Self::Regular(input)
}
}
/// Input parameters for spawning a dedicated storage multiproof calculation.
#[derive(Debug)]
struct StorageMultiproofInput<Factory> {
config: MultiProofConfig<Factory>,
struct StorageMultiproofInput {
config: MultiProofConfig,
source: Option<StateChangeSource>,
hashed_state_update: HashedPostState,
hashed_address: B256,
@@ -303,7 +300,7 @@ struct StorageMultiproofInput<Factory> {
multi_added_removed_keys: Arc<MultiAddedRemovedKeys>,
}
impl<Factory> StorageMultiproofInput<Factory> {
impl StorageMultiproofInput {
/// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
fn send_empty_proof(self) {
let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
@@ -315,8 +312,8 @@ impl<Factory> StorageMultiproofInput<Factory> {
/// Input parameters for spawning a multiproof calculation.
#[derive(Debug)]
struct MultiproofInput<Factory> {
config: MultiProofConfig<Factory>,
struct MultiproofInput {
config: MultiProofConfig,
source: Option<StateChangeSource>,
hashed_state_update: HashedPostState,
proof_targets: MultiProofTargets,
@@ -325,7 +322,7 @@ struct MultiproofInput<Factory> {
multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}
impl<Factory> MultiproofInput<Factory> {
impl MultiproofInput {
/// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
fn send_empty_proof(self) {
let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
@@ -340,17 +337,20 @@ impl<Factory> MultiproofInput<Factory> {
/// concurrency, further calculation requests are queued and spawned later, after
/// availability has been signaled.
#[derive(Debug)]
pub struct MultiproofManager<Factory: DatabaseProviderFactory> {
pub struct MultiproofManager {
/// Maximum number of concurrent calculations.
max_concurrent: usize,
/// Currently running calculations.
inflight: usize,
/// Queued calculations.
pending: VecDeque<PendingMultiproofTask<Factory>>,
pending: VecDeque<PendingMultiproofTask>,
/// Executor for tasks
executor: WorkloadExecutor,
/// Sender to the storage proof task.
storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
/// Handle to the proof task manager used for creating `ParallelProof` instances for storage
/// proofs.
storage_proof_task_handle: ProofTaskManagerHandle,
/// Handle to the proof task manager used for account multiproofs.
account_proof_task_handle: ProofTaskManagerHandle,
/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
///
@@ -367,15 +367,13 @@ pub struct MultiproofManager<Factory: DatabaseProviderFactory> {
metrics: MultiProofTaskMetrics,
}
impl<Factory> MultiproofManager<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
impl MultiproofManager {
/// Creates a new [`MultiproofManager`].
fn new(
executor: WorkloadExecutor,
metrics: MultiProofTaskMetrics,
storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
storage_proof_task_handle: ProofTaskManagerHandle,
account_proof_task_handle: ProofTaskManagerHandle,
max_concurrent: usize,
) -> Self {
Self {
@@ -385,6 +383,7 @@ where
inflight: 0,
metrics,
storage_proof_task_handle,
account_proof_task_handle,
missed_leaves_storage_roots: Default::default(),
}
}
@@ -395,7 +394,7 @@ where
/// Spawns a new multiproof calculation or enqueues it for later if
/// `max_concurrent` calculations are already inflight.
fn spawn_or_queue(&mut self, input: PendingMultiproofTask<Factory>) {
fn spawn_or_queue(&mut self, input: PendingMultiproofTask) {
// If there are no proof targets, we can just send an empty multiproof back immediately
if input.proof_targets_is_empty() {
debug!(
@@ -429,7 +428,7 @@ where
/// Spawns a multiproof task, dispatching to `spawn_storage_proof` if the input is a storage
/// multiproof, and dispatching to `spawn_multiproof` otherwise.
fn spawn_multiproof_task(&mut self, input: PendingMultiproofTask<Factory>) {
fn spawn_multiproof_task(&mut self, input: PendingMultiproofTask) {
match input {
PendingMultiproofTask::Storage(storage_input) => {
self.spawn_storage_proof(storage_input);
@@ -441,7 +440,7 @@ where
}
/// Spawns a single storage proof calculation task.
fn spawn_storage_proof(&mut self, storage_multiproof_input: StorageMultiproofInput<Factory>) {
fn spawn_storage_proof(&mut self, storage_multiproof_input: StorageMultiproofInput) {
let StorageMultiproofInput {
config,
source,
@@ -468,12 +467,11 @@ where
);
let start = Instant::now();
let proof_result = ParallelProof::new(
config.consistent_view,
config.nodes_sorted,
config.state_sorted,
config.prefix_sets,
missed_leaves_storage_roots,
storage_proof_task_handle.clone(),
storage_proof_task_handle,
)
.with_branch_node_masks(true)
.with_multi_added_removed_keys(Some(multi_added_removed_keys))
@@ -516,7 +514,7 @@ where
}
/// Spawns a single multiproof calculation task.
fn spawn_multiproof(&mut self, multiproof_input: MultiproofInput<Factory>) {
fn spawn_multiproof(&mut self, multiproof_input: MultiproofInput) {
let MultiproofInput {
config,
source,
@@ -526,7 +524,7 @@ where
state_root_message_sender,
multi_added_removed_keys,
} = multiproof_input;
let storage_proof_task_handle = self.storage_proof_task_handle.clone();
let account_proof_task_handle = self.account_proof_task_handle.clone();
let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
self.executor.spawn_blocking(move || {
@@ -544,17 +542,37 @@ where
);
let start = Instant::now();
let proof_result = ParallelProof::new(
config.consistent_view,
config.nodes_sorted,
config.state_sorted,
config.prefix_sets,
// Extend prefix sets with targets
let frozen_prefix_sets =
ParallelProof::extend_prefix_sets_with_targets(&config.prefix_sets, &proof_targets);
// Queue account multiproof to worker pool
let input = AccountMultiproofInput {
targets: proof_targets,
prefix_sets: frozen_prefix_sets,
collect_branch_node_masks: true,
multi_added_removed_keys,
missed_leaves_storage_roots,
storage_proof_task_handle.clone(),
)
.with_branch_node_masks(true)
.with_multi_added_removed_keys(multi_added_removed_keys)
.decoded_multiproof(proof_targets);
};
let (sender, receiver) = channel();
let proof_result: Result<DecodedMultiProof, ParallelStateRootError> = (|| {
account_proof_task_handle
.queue_task(ProofTaskKind::AccountMultiproof(input, sender))
.map_err(|_| {
ParallelStateRootError::Other(
"Failed to queue account multiproof to worker pool".into(),
)
})?;
receiver
.recv()
.map_err(|_| {
ParallelStateRootError::Other("Account multiproof channel closed".into())
})?
.map(|(proof, _stats)| proof)
})();
let elapsed = start.elapsed();
trace!(
target: "engine::root",
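Condensed, the new dispatch path in `spawn_multiproof` amounts to the sketch below (illustrative, using the names from this diff; types and error messages simplified): build an `AccountMultiproofInput`, queue it on the account worker pool, and block on the result channel, discarding the stats.
fn queue_account_multiproof(
    handle: &ProofTaskManagerHandle,
    input: AccountMultiproofInput,
) -> Result<DecodedMultiProof, ParallelStateRootError> {
    let (sender, receiver) = std::sync::mpsc::channel();
    // Hand the request to the account worker pool.
    handle
        .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
        .map_err(|_| ParallelStateRootError::Other("account worker pool unavailable".into()))?;
    // Block until a worker sends back the decoded multiproof; stats are dropped here.
    receiver
        .recv()
        .map_err(|_| ParallelStateRootError::Other("account multiproof channel closed".into()))?
        .map(|(proof, _stats)| proof)
}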
@@ -645,13 +663,13 @@ pub(crate) struct MultiProofTaskMetrics {
/// Then it updates relevant leaves according to the result of the transaction.
/// This feeds updates to the sparse trie task.
#[derive(Debug)]
pub(super) struct MultiProofTask<Factory: DatabaseProviderFactory> {
pub(super) struct MultiProofTask {
/// The size of the proof targets chunk to spawn in one calculation.
///
/// If [`None`], then chunking is disabled.
chunk_size: Option<usize>,
/// Task configuration.
config: MultiProofConfig<Factory>,
config: MultiProofConfig,
/// Receiver for state root related messages.
rx: Receiver<MultiProofMessage>,
/// Sender for state root related messages.
@@ -665,20 +683,17 @@ pub(super) struct MultiProofTask<Factory: DatabaseProviderFactory> {
/// Proof sequencing handler.
proof_sequencer: ProofSequencer,
/// Manages calculation of multiproofs.
multiproof_manager: MultiproofManager<Factory>,
multiproof_manager: MultiproofManager,
/// multi proof task metrics
metrics: MultiProofTaskMetrics,
}
impl<Factory> MultiProofTask<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
impl MultiProofTask {
/// Creates a new multi proof task with the unified message channel
pub(super) fn new(
config: MultiProofConfig<Factory>,
config: MultiProofConfig,
executor: WorkloadExecutor,
proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
proof_task_handle: ProofTaskManagerHandle,
to_sparse_trie: Sender<SparseTrieUpdate>,
max_concurrency: usize,
chunk_size: Option<usize>,
@@ -698,7 +713,8 @@ where
multiproof_manager: MultiproofManager::new(
executor,
metrics.clone(),
proof_task_handle,
proof_task_handle.clone(), // handle for storage proof workers
proof_task_handle, // handle for account proof workers
max_concurrency,
),
metrics,
@@ -1202,43 +1218,29 @@ fn get_proof_targets(
mod tests {
use super::*;
use alloy_primitives::map::B256Set;
use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
use reth_provider::{
providers::ConsistentDbView, test_utils::create_test_provider_factory, BlockReader,
DatabaseProviderFactory,
};
use reth_trie::{MultiProof, TrieInput};
use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
use revm_primitives::{B256, U256};
use std::sync::Arc;
fn create_state_root_config<F>(factory: F, input: TrieInput) -> MultiProofConfig<F>
where
F: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
let consistent_view = ConsistentDbView::new(factory, None);
let nodes_sorted = Arc::new(input.nodes.clone().into_sorted());
let state_sorted = Arc::new(input.state.clone().into_sorted());
let prefix_sets = Arc::new(input.prefix_sets);
MultiProofConfig { consistent_view, nodes_sorted, state_sorted, prefix_sets }
}
fn create_test_state_root_task<F>(factory: F) -> MultiProofTask<F>
fn create_test_state_root_task<F>(factory: F) -> MultiProofTask
where
F: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
let executor = WorkloadExecutor::default();
let config = create_state_root_config(factory, TrieInput::default());
let (_trie_input, config) = MultiProofConfig::from_input(TrieInput::default());
let task_ctx = ProofTaskCtx::new(
config.nodes_sorted.clone(),
config.state_sorted.clone(),
config.prefix_sets.clone(),
);
let proof_task = ProofTaskManager::new(
executor.handle().clone(),
config.consistent_view.clone(),
task_ctx,
1,
1,
)
.expect("Failed to create ProofTaskManager");
let consistent_view = ConsistentDbView::new(factory, None);
let proof_task =
ProofTaskManager::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
.expect("Failed to create ProofTaskManager");
let channel = channel();
MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)


@@ -113,6 +113,11 @@ pub struct EngineArgs {
/// If not specified, defaults to 2x available parallelism, clamped between 2 and 64.
#[arg(long = "engine.storage-worker-count")]
pub storage_worker_count: Option<usize>,
/// Configure the number of account proof workers in the Tokio blocking pool.
/// If not specified, defaults to the same count as storage workers.
#[arg(long = "engine.account-worker-count")]
pub account_worker_count: Option<usize>,
}
#[allow(deprecated)]
@@ -140,6 +145,7 @@ impl Default for EngineArgs {
always_process_payload_attributes_on_canonical_head: false,
allow_unwind_canonical_header: false,
storage_worker_count: None,
account_worker_count: None,
}
}
}
@@ -171,6 +177,10 @@ impl EngineArgs {
config = config.with_storage_worker_count(count);
}
if let Some(count) = self.account_worker_count {
config = config.with_account_worker_count(count);
}
config
}
}


@@ -1,40 +1,25 @@
use crate::{
metrics::ParallelTrieMetrics,
proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
proof_task::{
AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle, StorageProofInput,
},
root::ParallelStateRootError,
stats::ParallelTrieTracker,
StorageRootTargets,
};
use alloy_primitives::{
map::{B256Map, B256Set, HashMap},
B256,
};
use alloy_rlp::{BufMut, Encodable};
use alloy_primitives::{map::B256Set, B256};
use dashmap::DashMap;
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
ProviderError,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
node_iter::{TrieElement, TrieNodeIter},
prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
proof::StorageProof,
trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSets, TriePrefixSetsMut},
updates::TrieUpdatesSorted,
walker::TrieWalker,
DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
DecodedMultiProof, DecodedStorageMultiProof, HashedPostStateSorted, MultiProofTargets, Nibbles,
};
use reth_trie_common::{
added_removed_keys::MultiAddedRemovedKeys,
proof::{DecodedProofNodes, ProofRetainer},
use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
use std::sync::{
mpsc::{channel, Receiver},
Arc,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::trace;
/// Parallel proof calculator.
@@ -42,9 +27,7 @@ use tracing::trace;
/// This can collect proofs for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
/// Consistent view of the database.
view: ConsistentDbView<Factory>,
pub struct ParallelProof {
/// The sorted collection of cached in-memory intermediate trie nodes that
/// can be reused for computation.
pub nodes_sorted: Arc<TrieUpdatesSorted>,
@@ -58,8 +41,8 @@ pub struct ParallelProof<Factory: DatabaseProviderFactory> {
collect_branch_node_masks: bool,
/// Provided by the user to give the necessary context to retain extra proofs.
multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
/// Handle to the storage proof task.
storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
/// Handle to the proof task manager.
proof_task_handle: ProofTaskManagerHandle,
/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
@@ -67,25 +50,23 @@ pub struct ParallelProof<Factory: DatabaseProviderFactory> {
metrics: ParallelTrieMetrics,
}
impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
impl ParallelProof {
/// Create new state proof generator.
pub fn new(
view: ConsistentDbView<Factory>,
nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>,
missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
proof_task_handle: ProofTaskManagerHandle,
) -> Self {
Self {
view,
nodes_sorted,
state_sorted,
prefix_sets,
missed_leaves_storage_roots,
collect_branch_node_masks: false,
multi_added_removed_keys: None,
storage_proof_task_handle,
proof_task_handle,
#[cfg(feature = "metrics")]
metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
}
@@ -106,12 +87,6 @@ impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
self.multi_added_removed_keys = multi_added_removed_keys;
self
}
}
impl<Factory> ParallelProof<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
/// Queues a storage proof task and returns a receiver for the result.
fn queue_storage_proof(
&self,
@@ -128,8 +103,7 @@ where
);
let (sender, receiver) = std::sync::mpsc::channel();
let _ =
self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
let _ = self.proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
receiver
}
@@ -167,16 +141,16 @@ where
proof_result
}
/// Generate a state multiproof according to specified targets.
pub fn decoded_multiproof(
self,
targets: MultiProofTargets,
) -> Result<DecodedMultiProof, ParallelStateRootError> {
let mut tracker = ParallelTrieTracker::default();
// Extend prefix sets with targets
let mut prefix_sets = (*self.prefix_sets).clone();
prefix_sets.extend(TriePrefixSetsMut {
/// Extends prefix sets with the given multiproof targets and returns the frozen result.
///
/// This is a helper function used to prepare prefix sets before computing multiproofs.
/// Returns frozen (immutable) prefix sets ready for use in proof computation.
pub fn extend_prefix_sets_with_targets(
base_prefix_sets: &TriePrefixSetsMut,
targets: &MultiProofTargets,
) -> TriePrefixSets {
let mut extended = base_prefix_sets.clone();
extended.extend(TriePrefixSetsMut {
account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
storage_prefix_sets: targets
.iter()
@@ -187,13 +161,21 @@ where
.collect(),
destroyed_accounts: Default::default(),
});
let prefix_sets = prefix_sets.freeze();
extended.freeze()
}
let storage_root_targets = StorageRootTargets::new(
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
/// Generate a state multiproof according to specified targets.
pub fn decoded_multiproof(
self,
targets: MultiProofTargets,
) -> Result<DecodedMultiProof, ParallelStateRootError> {
// Extend prefix sets with targets
let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
let storage_root_targets_len = StorageRootTargets::count(
&prefix_sets.account_prefix_set,
&prefix_sets.storage_prefix_sets,
);
let storage_root_targets_len = storage_root_targets.len();
trace!(
target: "trie::parallel_proof",
@@ -201,150 +183,36 @@ where
"Starting parallel proof generation"
);
// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
// Queue account multiproof request to account worker pool
// stores the receiver for the storage proof outcome for the hashed addresses
// this way we can lazily await the outcome when we iterate over the map
let mut storage_proof_receivers =
B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
let input = AccountMultiproofInput {
targets,
prefix_sets,
collect_branch_node_masks: self.collect_branch_node_masks,
multi_added_removed_keys: self.multi_added_removed_keys.clone(),
missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
};
for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
let (sender, receiver) = channel();
self.proof_task_handle
.queue_task(ProofTaskKind::AccountMultiproof(input, sender))
.map_err(|_| {
ParallelStateRootError::Other(
"Failed to queue account multiproof: account worker pool unavailable"
.to_string(),
)
})?;
// store the receiver for that result with the hashed address so we can await this in
// place when we iterate over the trie
storage_proof_receivers.insert(hashed_address, receiver);
}
// Wait for account multiproof result from worker
let (multiproof, stats) = receiver.recv().map_err(|_| {
ParallelStateRootError::Other(
"Account multiproof channel dropped: worker died or pool shutdown".to_string(),
)
})??;
let provider_ro = self.view.provider_ro()?;
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&self.nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&self.state_sorted,
);
let accounts_added_removed_keys =
self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());
// Create the walker.
let walker = TrieWalker::<_>::state_trie(
trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
prefix_sets.account_prefix_set,
)
.with_added_removed_keys(accounts_added_removed_keys)
.with_deletions_retained(true);
// Create a hash builder to rebuild the root node since it is not available in the database.
let retainer = targets
.keys()
.map(Nibbles::unpack)
.collect::<ProofRetainer>()
.with_added_removed_keys(accounts_added_removed_keys);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_masks);
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
let mut account_node_iter = TrieNodeIter::state_trie(
walker,
hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
);
while let Some(account_node) =
account_node_iter.try_next().map_err(ProviderError::Database)?
{
match account_node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
let root = match storage_proof_receivers.remove(&hashed_address) {
Some(rx) => {
let decoded_storage_multiproof = rx.recv().map_err(|e| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
DatabaseError::Other(format!(
"channel closed for {hashed_address}: {e}"
)),
))
})??;
let root = decoded_storage_multiproof.root;
collected_decoded_storages
.insert(hashed_address, decoded_storage_multiproof);
root
}
// Since we do not store all intermediate nodes in the database, there might
// be a possibility of re-adding a non-modified leaf to the hash builder.
None => {
tracker.inc_missed_leaves();
match self.missed_leaves_storage_roots.entry(hashed_address) {
dashmap::Entry::Occupied(occ) => *occ.get(),
dashmap::Entry::Vacant(vac) => {
let root = StorageProof::new_hashed(
trie_cursor_factory.clone(),
hashed_cursor_factory.clone(),
hashed_address,
)
.with_prefix_set_mut(Default::default())
.storage_multiproof(
targets.get(&hashed_address).cloned().unwrap_or_default(),
)
.map_err(|e| {
ParallelStateRootError::StorageRoot(
StorageRootError::Database(DatabaseError::Other(
e.to_string(),
)),
)
})?
.root;
vac.insert(root);
root
}
}
}
};
// Encode account
account_rlp.clear();
let account = account.into_trie_account(root);
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
}
}
}
let _ = hash_builder.root();
let stats = tracker.finish();
#[cfg(feature = "metrics")]
self.metrics.record(stats);
let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
)
} else {
(HashMap::default(), HashMap::default())
};
trace!(
target: "trie::parallel_proof",
total_targets = storage_root_targets_len,
@@ -356,12 +224,7 @@ where
"Calculated decoded proof"
);
Ok(DecodedMultiProof {
account_subtree: decoded_account_subtree,
branch_node_hash_masks,
branch_node_tree_masks,
storages: collected_decoded_storages,
})
Ok(multiproof)
}
}
@@ -371,13 +234,16 @@ mod tests {
use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
use alloy_primitives::{
keccak256,
map::{B256Set, DefaultHashBuilder},
map::{B256Set, DefaultHashBuilder, HashMap},
Address, U256,
};
use rand::Rng;
use reth_primitives_traits::{Account, StorageEntry};
use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
use reth_provider::{
providers::ConsistentDbView, test_utils::create_test_provider_factory, HashingWriter,
};
use reth_trie::proof::Proof;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use tokio::runtime::Runtime;
#[test]
@@ -448,8 +314,7 @@ mod tests {
let task_ctx =
ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
let proof_task =
ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1, 1)
.unwrap();
ProofTaskManager::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
let proof_task_handle = proof_task.handle();
// keep the join handle around to make sure it does not return any errors
@@ -457,7 +322,6 @@ mod tests {
let join_handle = rt.spawn_blocking(move || proof_task.run());
let parallel_result = ParallelProof::new(
consistent_view,
Default::default(),
Default::default(),
Default::default(),

File diff suppressed because it is too large.


@@ -24,6 +24,23 @@ impl StorageRootTargets {
.collect(),
)
}
/// Returns the total number of unique storage root targets without allocating new maps.
pub fn count(
account_prefix_set: &PrefixSet,
storage_prefix_sets: &B256Map<PrefixSet>,
) -> usize {
let mut count = storage_prefix_sets.len();
for nibbles in account_prefix_set {
let hashed_address = B256::from_slice(&nibbles.pack());
if !storage_prefix_sets.contains_key(&hashed_address) {
count += 1;
}
}
count
}
}
impl IntoIterator for StorageRootTargets {


@@ -867,6 +867,9 @@ Engine:
--engine.storage-worker-count <STORAGE_WORKER_COUNT>
Configure the number of storage proof workers in the Tokio blocking pool. If not specified, defaults to 2x available parallelism, clamped between 2 and 64
--engine.account-worker-count <ACCOUNT_WORKER_COUNT>
Configure the number of account proof workers in the Tokio blocking pool. If not specified, defaults to the same count as storage workers
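For example (illustrative invocation, not from this diff), both pools can be pinned explicitly: `reth node --engine.storage-worker-count 16 --engine.account-worker-count 16`.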
ERA:
--era.enable
Enable import from ERA1 files