refactor: streamline account multiproof generation

- Removed unnecessary intermediate variables and streamlined the logic for calculating storage root targets.
- Introduced a new helper function `build_account_multiproof_with_storage_roots` to encapsulate the account multiproof construction process.
- Enhanced error handling for account multiproof task queuing and result retrieval.
- Improved documentation for clarity on the multiproof generation workflow and related methods.
This commit is contained in:
Yong Kang
2025-10-08 11:26:10 +00:00
parent c48085dfbc
commit 66bf6945fd
2 changed files with 79 additions and 54 deletions

View File

@@ -191,16 +191,14 @@ where
self,
targets: MultiProofTargets,
) -> Result<DecodedMultiProof, ParallelStateRootError> {
let mut tracker = ParallelTrieTracker::default();
// Extend prefix sets with targets
let prefix_sets = Self::extend_prefix_sets_with_targets(&self.prefix_sets, &targets);
let storage_root_targets = StorageRootTargets::new(
let storage_root_targets_len = StorageRootTargets::new(
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
);
let storage_root_targets_len = storage_root_targets.len();
)
.len();
trace!(
target: "trie::parallel_proof",
@@ -208,42 +206,80 @@ where
"Starting parallel proof generation"
);
// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
// Queue account multiproof request to account worker pool
// stores the receiver for the storage proof outcome for the hashed addresses
// this way we can lazily await the outcome when we iterate over the map
let mut storage_proof_receivers =
B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
let input = AccountMultiproofInput {
targets,
prefix_sets,
collect_branch_node_masks: self.collect_branch_node_masks,
multi_added_removed_keys: self.multi_added_removed_keys.clone(),
missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
};
for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
let (sender, receiver) = channel();
self.storage_proof_task_handle
.queue_task(ProofTaskKind::AccountMultiproof(input, sender))
.map_err(|_| {
ParallelStateRootError::Other(
"Failed to queue account multiproof: account worker pool unavailable"
.to_string(),
)
})?;
// store the receiver for that result with the hashed address so we can await this in
// place when we iterate over the trie
storage_proof_receivers.insert(hashed_address, receiver);
}
// Wait for account multiproof result from worker
let (multiproof, stats) = receiver.recv().map_err(|_| {
ParallelStateRootError::Other(
"Account multiproof channel dropped: worker died or pool shutdown".to_string(),
)
})??;
let provider_ro = self.view.provider_ro()?;
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&self.nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&self.state_sorted,
#[cfg(feature = "metrics")]
self.metrics.record(stats);
trace!(
target: "trie::parallel_proof",
total_targets = storage_root_targets_len,
duration = ?stats.duration(),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
precomputed_storage_roots = stats.precomputed_storage_roots(),
"Calculated decoded proof"
);
Ok(multiproof)
}
}
/// Builds an account multiproof given pre-collected storage proofs.
///
/// This is a helper function used by both `decoded_multiproof` and account workers to build
/// the account subtree proof after storage proofs have been collected.
///
/// Returns a `DecodedMultiProof` containing the account subtree and storage proofs.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_account_multiproof_with_storage_roots<C, H>(
trie_cursor_factory: C,
hashed_cursor_factory: H,
targets: &MultiProofTargets,
prefix_set: PrefixSet,
collect_branch_node_masks: bool,
multi_added_removed_keys: Option<&Arc<MultiAddedRemovedKeys>>,
mut storage_proofs: B256Map<DecodedStorageMultiProof>,
missed_leaves_storage_roots: &DashMap<B256, B256>,
tracker: &mut ParallelTrieTracker,
) -> Result<DecodedMultiProof, ParallelStateRootError>
where
C: TrieCursorFactory + Clone,
H: HashedCursorFactory + Clone,
{
let accounts_added_removed_keys =
self.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());
multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());
// Create the walker.
let walker = TrieWalker::<_>::state_trie(
trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
prefix_sets.account_prefix_set,
prefix_set,
)
.with_added_removed_keys(accounts_added_removed_keys)
.with_deletions_retained(true);
@@ -256,7 +292,7 @@ where
.with_added_removed_keys(accounts_added_removed_keys);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_masks);
.with_updates(collect_branch_node_masks);
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
@@ -267,9 +303,8 @@ where
walker,
hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
);
while let Some(account_node) =
account_node_iter.try_next().map_err(ProviderError::Database)?
{
while let Some(account_node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
match account_node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
@@ -339,37 +374,22 @@ where
let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let (branch_node_hash_masks, branch_node_tree_masks) = if collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
updated_branch_nodes.into_iter().map(|(path, node)| (path, node.tree_mask)).collect(),
)
} else {
(HashMap::default(), HashMap::default())
(Default::default(), Default::default())
};
trace!(
target: "trie::parallel_proof",
total_targets = storage_root_targets_len,
duration = ?stats.duration(),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
precomputed_storage_roots = stats.precomputed_storage_roots(),
"Calculated decoded proof"
);
Ok(DecodedMultiProof {
account_subtree: decoded_account_subtree,
branch_node_hash_masks,
branch_node_tree_masks,
storages: collected_decoded_storages,
})
}
}
#[cfg(test)]

View File

@@ -8,7 +8,11 @@
//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
use crate::root::ParallelStateRootError;
use crate::{
root::ParallelStateRootError,
stats::{ParallelTrieStats, ParallelTrieTracker},
StorageRootTargets,
};
use alloy_primitives::{
map::{B256Map, B256Set},
B256,
@@ -52,7 +56,8 @@ use crate::proof_task_metrics::ProofTaskMetrics;
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
type AccountMultiproofResult = Result<DecodedMultiProof, ParallelStateRootError>;
type AccountMultiproofResult =
Result<(DecodedMultiProof, ParallelTrieStats), ParallelStateRootError>;
/// Worker type identifier
#[derive(Debug)]