Compare commits

..

1 Commits

Author SHA1 Message Date
Georgios Konstantopoulos
7840143e6d feat(trie): add batch node fetch API to reduce O(n²) proof walks
Amp-Thread-ID: https://ampcode.com/threads/T-019bfe25-43f3-75ac-98f7-32bf937b69e1
Co-authored-by: Amp <amp@ampcode.com>
2026-01-27 07:40:11 +00:00
3 changed files with 94 additions and 8 deletions

View File

@@ -36,15 +36,13 @@ pub trait TrieNodeProvider {
/// Retrieve trie node by path.
fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError>;
/// Batch retrieve trie nodes by paths.
///
/// Returns a map from path to revealed node. Missing nodes are omitted.
/// Default implementation falls back to sequential fetching.
/// Batch retrieve trie nodes by paths. Default: sequential fallback.
fn trie_nodes_batch(
&self,
paths: &[Nibbles],
) -> Result<HashMap<Nibbles, RevealedNode>, SparseTrieError> {
let mut result = HashMap::with_capacity_and_hasher(paths.len(), Default::default());
let mut result = HashMap::default();
result.reserve(paths.len());
for path in paths {
if let Some(node) = self.trie_node(path)? {
result.insert(path.clone(), node);

View File

@@ -1417,7 +1417,7 @@ impl SerialSparseTrie {
while let Some((mut path, level)) = paths.pop() {
match self.nodes.get(&path).unwrap() {
SparseNode::Empty | SparseNode::Hash(_) => {}
SparseNode::Leaf { key: _, hash, .. } => {
SparseNode::Leaf { key: _, hash } => {
if hash.is_some() && !prefix_set.contains(&path) {
continue
}
@@ -1438,7 +1438,7 @@ impl SerialSparseTrie {
paths.push((path, level + 1));
}
}
SparseNode::Branch { state_mask, hash, .. } => {
SparseNode::Branch { state_mask, hash, store_in_db_trie: _ } => {
if hash.is_some() && !prefix_set.contains(&path) {
continue
}

View File

@@ -1,6 +1,6 @@
use super::{Proof, StorageProof};
use crate::{hashed_cursor::HashedCursorFactory, trie_cursor::TrieCursorFactory};
use alloy_primitives::{map::HashSet, B256};
use alloy_primitives::{map::{B256Set, HashMap, HashSet}, B256};
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_trie_common::{MultiProofTargets, Nibbles};
use reth_trie_sparse::provider::{
@@ -94,6 +94,48 @@ where
);
Ok(node.map(|node| RevealedNode { node, tree_mask, hash_mask }))
}
fn trie_nodes_batch(
&self,
paths: &[Nibbles],
) -> Result<HashMap<Nibbles, RevealedNode>, SparseTrieError> {
if paths.is_empty() {
return Ok(HashMap::default());
}
let start = enabled!(target: "trie::proof::blinded", Level::TRACE).then(Instant::now);
let targets: MultiProofTargets =
paths.iter().map(|p| (pad_path_to_key(p), HashSet::default())).collect();
let mut proof = Proof::new(&self.trie_cursor_factory, &self.hashed_cursor_factory)
.with_branch_node_masks(true)
.multiproof(targets)
.map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
let mut account_subtree = proof.account_subtree.into_inner();
let mut result = HashMap::default();
result.reserve(paths.len());
for path in paths {
if let Some(node) = account_subtree.remove(path) {
let masks = proof.branch_node_masks.remove(path);
let hash_mask = masks.map(|m| m.hash_mask);
let tree_mask = masks.map(|m| m.tree_mask);
result.insert(path.clone(), RevealedNode { node, tree_mask, hash_mask });
}
}
trace!(
target: "trie::proof::blinded",
elapsed = ?start.unwrap().elapsed(),
paths_count = paths.len(),
results_count = result.len(),
"Batch blinded nodes for account trie"
);
Ok(result)
}
}
/// Blinded provider for retrieving storage trie nodes by path.
@@ -148,4 +190,50 @@ where
);
Ok(node.map(|node| RevealedNode { node, tree_mask, hash_mask }))
}
fn trie_nodes_batch(
&self,
paths: &[Nibbles],
) -> Result<HashMap<Nibbles, RevealedNode>, SparseTrieError> {
if paths.is_empty() {
return Ok(HashMap::default());
}
let start = enabled!(target: "trie::proof::blinded", Level::TRACE).then(Instant::now);
let targets: B256Set = paths.iter().map(pad_path_to_key).collect();
let mut proof = StorageProof::new_hashed(
&self.trie_cursor_factory,
&self.hashed_cursor_factory,
self.account,
)
.with_branch_node_masks(true)
.storage_multiproof(targets)
.map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
let mut subtree = proof.subtree.into_inner();
let mut result = HashMap::default();
result.reserve(paths.len());
for path in paths {
if let Some(node) = subtree.remove(path) {
let masks = proof.branch_node_masks.remove(path);
let hash_mask = masks.map(|m| m.hash_mask);
let tree_mask = masks.map(|m| m.tree_mask);
result.insert(path.clone(), RevealedNode { node, tree_mask, hash_mask });
}
}
trace!(
target: "trie::proof::blinded",
account = ?self.account,
elapsed = ?start.unwrap().elapsed(),
paths_count = paths.len(),
results_count = result.len(),
"Batch blinded nodes for storage trie"
);
Ok(result)
}
}