feat(trie): Decode storage proofs in parallel tasks (#16400)

Signed-off-by: 7suyash7 <suyashnyn1@gmail.com>
This commit is contained in:
Suyash Nayan
2025-06-18 01:58:07 +05:30
committed by GitHub
parent 243a523149
commit 1d01f2a46d
2 changed files with 60 additions and 53 deletions

View File

@@ -25,10 +25,10 @@ use reth_trie::{
trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
updates::TrieUpdatesSorted,
walker::TrieWalker,
DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted, MultiProof,
MultiProofTargets, Nibbles, StorageMultiProof, TRIE_ACCOUNT_RLP_MAX_SIZE,
DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted,
MultiProofTargets, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_common::proof::{DecodedProofNodes, ProofRetainer};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;
@@ -97,7 +97,7 @@ where
hashed_address: B256,
prefix_set: PrefixSet,
target_slots: B256Set,
) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
let input = StorageProofInput::new(
hashed_address,
prefix_set,
@@ -116,7 +116,7 @@ where
self,
hashed_address: B256,
target_slots: B256Set,
) -> Result<StorageMultiProof, ParallelStateRootError> {
) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
let total_targets = target_slots.len();
let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
let prefix_set = prefix_set.freeze();
@@ -152,19 +152,14 @@ where
hashed_address: B256,
target_slots: B256Set,
) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
let proof = self.storage_proof(hashed_address, target_slots)?;
// Now decode the nodes of the proof
let proof = proof.try_into()?;
Ok(proof)
self.storage_proof(hashed_address, target_slots)
}
/// Generate a state multiproof according to specified targets.
pub fn multiproof(
pub fn decoded_multiproof(
self,
targets: MultiProofTargets,
) -> Result<MultiProof, ParallelStateRootError> {
) -> Result<DecodedMultiProof, ParallelStateRootError> {
let mut tracker = ParallelTrieTracker::default();
// Extend prefix sets with targets
@@ -199,7 +194,7 @@ where
// stores the receiver for the storage proof outcome for the hashed addresses
// this way we can lazily await the outcome when we iterate over the map
let mut storage_proofs =
let mut storage_proof_receivers =
B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
for (hashed_address, prefix_set) in
@@ -210,7 +205,7 @@ where
// store the receiver for that result with the hashed address so we can await this in
// place when we iterate over the trie
storage_proofs.insert(hashed_address, receiver);
storage_proof_receivers.insert(hashed_address, receiver);
}
let provider_ro = self.view.provider_ro()?;
@@ -238,8 +233,8 @@ where
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
let mut storages: B256Map<_> =
targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
targets.keys().map(|key| (*key, DecodedStorageMultiProof::empty())).collect();
let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
let mut account_node_iter = TrieNodeIter::state_trie(
walker,
@@ -253,11 +248,13 @@ where
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
let storage_multiproof = match storage_proofs.remove(&hashed_address) {
Some(rx) => rx.recv().map_err(|_| {
let decoded_storage_multiproof = match storage_proof_receivers
.remove(&hashed_address)
{
Some(rx) => rx.recv().map_err(|e| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
DatabaseError::Other(format!(
"channel closed for {hashed_address}"
"channel closed for {hashed_address}: {e}"
)),
))
})??,
@@ -265,7 +262,8 @@ where
// be a possibility of re-adding a non-modified leaf to the hash builder.
None => {
tracker.inc_missed_leaves();
StorageProof::new_hashed(
let raw_fallback_proof = StorageProof::new_hashed(
trie_cursor_factory.clone(),
hashed_cursor_factory.clone(),
hashed_address,
@@ -278,20 +276,23 @@ where
ParallelStateRootError::StorageRoot(StorageRootError::Database(
DatabaseError::Other(e.to_string()),
))
})?
})?;
raw_fallback_proof.try_into()?
}
};
// Encode account
account_rlp.clear();
let account = account.into_trie_account(storage_multiproof.root);
let account = account.into_trie_account(decoded_storage_multiproof.root);
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
// We might be adding leaves that are not necessarily our proof targets.
if targets.contains_key(&hashed_address) {
storages.insert(hashed_address, storage_multiproof);
collected_decoded_storages
.insert(hashed_address, decoded_storage_multiproof);
}
}
}
@@ -302,7 +303,9 @@ where
#[cfg(feature = "metrics")]
self.metrics.record(stats);
let account_subtree = hash_builder.take_proof_nodes();
let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
@@ -327,25 +330,15 @@ where
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
precomputed_storage_roots = stats.precomputed_storage_roots(),
"Calculated proof"
"Calculated decoded proof"
);
Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
}
/// Returns a [`DecodedMultiProof`] for the given proof.
///
/// Uses `multiproof` first to get the proof, and then decodes the nodes of the multiproof.
pub fn decoded_multiproof(
self,
targets: MultiProofTargets,
) -> Result<DecodedMultiProof, ParallelStateRootError> {
let multiproof = self.multiproof(targets)?;
// Now decode the nodes of the multiproof
let multiproof = multiproof.try_into()?;
Ok(multiproof)
Ok(DecodedMultiProof {
account_subtree: decoded_account_subtree,
branch_node_hash_masks,
branch_node_tree_masks,
storages: collected_decoded_storages,
})
}
}
@@ -446,26 +439,31 @@ mod tests {
Default::default(),
proof_task_handle.clone(),
)
.multiproof(targets.clone())
.decoded_multiproof(targets.clone())
.unwrap();
let sequential_result =
Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();
let sequential_result_raw = Proof::new(trie_cursor_factory, hashed_cursor_factory)
.multiproof(targets.clone())
.unwrap(); // targets might be consumed by parallel_result
let sequential_result_decoded: DecodedMultiProof = sequential_result_raw
.try_into()
.expect("Failed to decode sequential_result for test comparison");
// to help narrow down what is wrong - first compare account subtries
assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);
assert_eq!(parallel_result.account_subtree, sequential_result_decoded.account_subtree);
// then compare length of all storage subtries
assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());
assert_eq!(parallel_result.storages.len(), sequential_result_decoded.storages.len());
// then compare each storage subtrie
for (hashed_address, storage_proof) in &parallel_result.storages {
let sequential_storage_proof = sequential_result.storages.get(hashed_address).unwrap();
let sequential_storage_proof =
sequential_result_decoded.storages.get(hashed_address).unwrap();
assert_eq!(storage_proof, sequential_storage_proof);
}
// then compare the entire thing for any mask differences
assert_eq!(parallel_result, sequential_result);
assert_eq!(parallel_result, sequential_result_decoded);
// drop the handle to terminate the task and then block on the proof task handle to make
// sure it does not return any errors

View File

@@ -22,7 +22,7 @@ use reth_trie::{
proof::{ProofBlindedProviderFactory, StorageProof},
trie_cursor::InMemoryTrieCursorFactory,
updates::TrieUpdatesSorted,
HashedPostStateSorted, Nibbles, StorageMultiProof,
DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
};
use reth_trie_common::prefix_set::{PrefixSet, PrefixSetMut};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
@@ -39,7 +39,7 @@ use std::{
use tokio::runtime::Handle;
use tracing::debug;
type StorageProofResult = Result<StorageMultiProof, ParallelStateRootError>;
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
type BlindedNodeResult = Result<Option<RevealedNode>, SparseTrieError>;
/// A task that manages sending multiproof requests to a number of tasks that have longer-running
@@ -244,7 +244,7 @@ where
let target_slots_len = input.target_slots.len();
let proof_start = Instant::now();
let result = StorageProof::new_hashed(
let raw_proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
input.hashed_address,
@@ -254,6 +254,15 @@ where
.storage_multiproof(input.target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
let decoded_result = raw_proof_result.and_then(|raw_proof| {
raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
ParallelStateRootError::Other(format!(
"Failed to decode storage proof for {}: {}",
input.hashed_address, e
))
})
});
debug!(
target: "trie::proof_task",
hashed_address=?input.hashed_address,
@@ -264,7 +273,7 @@ where
);
// send the result back
if let Err(error) = result_sender.send(result) {
if let Err(error) = result_sender.send(decoded_result) {
debug!(
target: "trie::proof_task",
hashed_address = ?input.hashed_address,