From a9e36923e1f8efc20d986bbe54ae1dca02b448ed Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Mon, 15 Dec 2025 16:27:04 +0100 Subject: [PATCH] feat(trie): Proof Rewrite: Use cached branch nodes (#20075) Co-authored-by: YK Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com> --- Cargo.lock | 1 + crates/trie/trie/Cargo.toml | 2 + crates/trie/trie/benches/proof_v2.rs | 54 +- crates/trie/trie/src/hashed_cursor/mock.rs | 7 +- crates/trie/trie/src/proof_v2/mod.rs | 1032 ++++++++++++++++---- crates/trie/trie/src/proof_v2/node.rs | 26 +- crates/trie/trie/src/proof_v2/value.rs | 3 +- 7 files changed, 907 insertions(+), 218 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 770a7f26d1..5cfc730416 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10895,6 +10895,7 @@ dependencies = [ "pretty_assertions", "proptest", "proptest-arbitrary-interop", + "rand 0.9.2", "reth-ethereum-primitives", "reth-execution-errors", "reth-metrics", diff --git a/crates/trie/trie/Cargo.toml b/crates/trie/trie/Cargo.toml index 504f1bc6c2..d3540adda8 100644 --- a/crates/trie/trie/Cargo.toml +++ b/crates/trie/trie/Cargo.toml @@ -64,6 +64,7 @@ parking_lot.workspace = true pretty_assertions.workspace = true proptest-arbitrary-interop.workspace = true proptest.workspace = true +rand.workspace = true [features] metrics = ["reth-metrics", "dep:metrics"] @@ -84,6 +85,7 @@ serde = [ "revm-state/serde", "parking_lot/serde", "reth-ethereum-primitives/serde", + "rand/serde", ] test-utils = [ "triehash", diff --git a/crates/trie/trie/benches/proof_v2.rs b/crates/trie/trie/benches/proof_v2.rs index d592091cee..e5123ddc9a 100644 --- a/crates/trie/trie/benches/proof_v2.rs +++ b/crates/trie/trie/benches/proof_v2.rs @@ -11,20 +11,19 @@ use reth_trie::{ proof_v2::StorageProofCalculator, trie_cursor::{mock::MockTrieCursorFactory, TrieCursorFactory}, }; -use reth_trie_common::{HashedPostState, HashedStorage, Nibbles}; -use std::collections::BTreeMap; +use reth_trie_common::{HashedPostState, HashedStorage}; /// Generate test data for benchmarking. /// /// Returns a tuple of: /// - Hashed address for the storage trie /// - `HashedPostState` with random storage slots -/// - Proof targets (Nibbles) that are 80% from existing slots, 20% random +/// - Proof targets as B256 (sorted) for V2 implementation /// - Equivalent [`B256Set`] for legacy implementation fn generate_test_data( dataset_size: usize, num_targets: usize, -) -> (B256, HashedPostState, Vec, B256Set) { +) -> (B256, HashedPostState, Vec, B256Set) { let mut runner = TestRunner::deterministic(); // Use a fixed hashed address for the storage trie @@ -68,14 +67,8 @@ fn generate_test_data( let target_b256s = targets_strategy.new_tree(&mut runner).unwrap().current(); - // Convert B256 targets to sorted Nibbles for V2 - let mut targets: Vec = target_b256s - .iter() - .map(|b256| { - // SAFETY: B256 is exactly 32 bytes - unsafe { Nibbles::unpack_unchecked(b256.as_slice()) } - }) - .collect(); + // Sort B256 targets for V2 (storage_proof expects sorted targets) + let mut targets: Vec = target_b256s.clone(); targets.sort(); // Create B256Set for legacy @@ -86,19 +79,42 @@ fn generate_test_data( /// Create cursor factories from a `HashedPostState` for storage trie testing. /// -/// This mimics the test harness pattern from the `proof_v2` tests. +/// This mimics the test harness pattern from the `proof_v2` tests by using `StateRoot` +/// to generate `TrieUpdates` from the `HashedPostState`. 
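+///
+/// Roughly, as a sketch (`empty_tries` and `hashed_factory` stand in for the factories
+/// built below; the body is the real wiring):
+///
+/// ```ignore
+/// let (_root, updates) = StateRoot::new(empty_tries, hashed_factory.clone())
+///     .root_with_updates()?;
+/// let trie_cursor_factory = MockTrieCursorFactory::from_trie_updates(updates);
+/// ```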
fn create_cursor_factories( post_state: &HashedPostState, ) -> (MockTrieCursorFactory, MockHashedCursorFactory) { - // Ensure that there's a storage trie dataset for every storage trie, even if empty - let storage_trie_nodes: B256Map> = - post_state.storages.keys().copied().map(|addr| (addr, Default::default())).collect(); + use reth_trie::{updates::StorageTrieUpdates, StateRoot}; + + // Create empty trie cursor factory to serve as the initial state for StateRoot + // Ensure that there's a storage trie dataset for every storage account + let storage_tries: B256Map<_> = post_state + .storages + .keys() + .copied() + .map(|addr| (addr, StorageTrieUpdates::default())) + .collect(); + + let empty_trie_cursor_factory = + MockTrieCursorFactory::from_trie_updates(reth_trie_common::updates::TrieUpdates { + storage_tries: storage_tries.clone(), + ..Default::default() + }); // Create mock hashed cursor factory from the post state let hashed_cursor_factory = MockHashedCursorFactory::from_hashed_post_state(post_state.clone()); - // Create empty trie cursor factory (leaf-only calculator doesn't need trie nodes) - let trie_cursor_factory = MockTrieCursorFactory::new(BTreeMap::new(), storage_trie_nodes); + // Generate TrieUpdates using StateRoot + let (_root, mut trie_updates) = + StateRoot::new(empty_trie_cursor_factory, hashed_cursor_factory.clone()) + .root_with_updates() + .expect("StateRoot should succeed"); + + // Continue using empty storage tries for each account + trie_updates.storage_tries = storage_tries; + + // Initialize trie cursor factory from the generated TrieUpdates + let trie_cursor_factory = MockTrieCursorFactory::from_trie_updates(trie_updates); (trie_cursor_factory, hashed_cursor_factory) } @@ -148,7 +164,7 @@ fn bench_proof_algos(c: &mut Criterion) { || targets.clone(), |targets| { proof_calculator - .storage_proof(hashed_address, targets.into_iter()) + .storage_proof(hashed_address, targets) .expect("Proof generation failed"); }, BatchSize::SmallInput, diff --git a/crates/trie/trie/src/hashed_cursor/mock.rs b/crates/trie/trie/src/hashed_cursor/mock.rs index 63f1b138fe..15edd97ade 100644 --- a/crates/trie/trie/src/hashed_cursor/mock.rs +++ b/crates/trie/trie/src/hashed_cursor/mock.rs @@ -48,7 +48,7 @@ impl MockHashedCursorFactory { .collect(); // Extract storages from post state - let hashed_storages: B256Map> = post_state + let mut hashed_storages: B256Map> = post_state .storages .into_iter() .map(|(addr, hashed_storage)| { @@ -62,6 +62,11 @@ impl MockHashedCursorFactory { }) .collect(); + // Ensure all accounts have at least an empty storage + for account in hashed_accounts.keys() { + hashed_storages.entry(*account).or_default(); + } + Self::new(hashed_accounts, hashed_storages) } diff --git a/crates/trie/trie/src/proof_v2/mod.rs b/crates/trie/trie/src/proof_v2/mod.rs index 1d818a90e9..7606164dda 100644 --- a/crates/trie/trie/src/proof_v2/mod.rs +++ b/crates/trie/trie/src/proof_v2/mod.rs @@ -13,7 +13,7 @@ use crate::{ }; use alloy_primitives::{B256, U256}; use alloy_rlp::Encodable; -use alloy_trie::TrieMask; +use alloy_trie::{BranchNodeCompact, TrieMask}; use reth_execution_errors::trie::StateProofError; use reth_trie_common::{BranchNode, Nibbles, ProofTrieNode, RlpNode, TrieMasks, TrieNode}; use std::{cmp::Ordering, iter::Peekable}; @@ -31,6 +31,17 @@ static TRACE_TARGET: &str = "trie::proof_v2"; /// Number of bytes to pre-allocate for [`ProofCalculator`]'s `rlp_encode_buf` field. const RLP_ENCODE_BUF_SIZE: usize = 1024; +/// A [`Nibbles`] which contains 64 zero nibbles. 
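+///
+/// Built with a `const`-evaluated loop so it can live in a `static`. Its use below is as a
+/// cheap "is this path all zeros?" test: a path `p` consists solely of zero nibbles exactly
+/// when `PATH_ALL_ZEROS.starts_with(&p)` holds (trie paths never exceed 64 nibbles).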
+static PATH_ALL_ZEROS: Nibbles = {
+    let mut path = Nibbles::new();
+    let mut i = 0;
+    while i < 64 {
+        path.push_unchecked(0);
+        i += 1;
+    }
+    path
+};
+
 /// A proof calculator that generates merkle proofs using only leaf data.
 ///
 /// The calculator:
@@ -66,6 +77,10 @@ pub struct ProofCalculator {
     /// ever be modified, and therefore all children besides the last are expected to be
     /// [`ProofTrieBranchChild::RlpNode`]s.
     child_stack: Vec>,
+    /// Cached branch data pulled from the `trie_cursor`. The calculator will use the cached
+    /// [`BranchNodeCompact::hashes`] to skip over the calculation of sub-tries in the overall
+    /// trie. The cached hashes cannot be used for any paths which are prefixes of a proof target.
+    cached_branch_stack: Vec<(Nibbles, BranchNodeCompact)>,
     /// The proofs which will be returned from the calculation. This gets taken at the end of every
     /// proof call.
     retained_proofs: Vec<ProofTrieNode>,
@@ -85,9 +100,10 @@ impl ProofCalculator {
         Self {
             trie_cursor,
             hashed_cursor,
-            branch_stack: Vec::<_>::new(),
+            branch_stack: Vec::<_>::with_capacity(64),
             branch_path: Nibbles::new(),
             child_stack: Vec::<_>::new(),
+            cached_branch_stack: Vec::<_>::with_capacity(64),
             retained_proofs: Vec::<_>::new(),
             rlp_nodes_bufs: Vec::<_>::new(),
             rlp_encode_buf: Vec::<_>::with_capacity(RLP_ENCODE_BUF_SIZE),
@@ -118,6 +134,17 @@ where
             .unwrap_or_else(|| Vec::with_capacity(16))
     }
 
+    // Returns zero if `branch_stack` is empty, one otherwise.
+    //
+    // This is used when working with the `ext_len` field of [`ProofTrieBranch`]. The `ext_len` is
+    // calculated by taking the difference of the current `branch_path` and the new branch's path;
+    // if the new branch has a parent branch (i.e. `branch_stack` is not empty) then 1 is
+    // subtracted from the `ext_len` to account for the child's nibble on the parent.
+    #[inline]
+    const fn maybe_parent_nibble(&self) -> usize {
+        !self.branch_stack.is_empty() as usize
+    }
+
     /// Returns true if the proof of a node at the given path should be retained.
     /// A node is retained if its path is a prefix of any target.
     /// This may move the `targets` iterator forward.
@@ -173,17 +200,23 @@ where
 
         let &(mut lower, mut upper) = targets.peek().expect("targets is never exhausted");
 
-        // If the path isn't in the current range then iterate forward until it is (or until there
-        // is no upper bound, indicating unbounded).
-        while upper.is_some_and(|upper| depth_first::cmp(path, &upper) != Ordering::Less) {
-            targets.next();
-            trace!(target: TRACE_TARGET, target = ?targets.peek(), "upper target <= path, next target");
-            let &(l, u) = targets.peek().expect("targets is never exhausted");
-            (lower, upper) = (l, u);
-        }
+        loop {
+            // If the node in question is a prefix of the target then we retain
+            if lower.starts_with(path) {
+                return true
+            }
 
-        // If the node in question is a prefix of the target then we retain
-        lower.starts_with(path)
+            // If the path isn't in the current range then iterate forward until it is (or until
+            // there is no upper bound, indicating unbounded).
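+            //
+            // Illustrative walk-through (hypothetical targets 0x11.. and 0x33..): for
+            // path = 0x1 the window is (0x11.., Some(0x33..)) and `lower` starts with
+            // the path, so the node is retained by the check above. For path = 0x4 the
+            // upper target 0x33.. sorts at-or-before the path, so the branch below
+            // advances the window to (0x33.., None) and the prefix check is retried.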
+ if upper.is_some_and(|upper| depth_first::cmp(path, &upper) != Ordering::Less) { + targets.next(); + trace!(target: TRACE_TARGET, target = ?targets.peek(), "upper target <= path, next target"); + let &(l, u) = targets.peek().expect("targets is never exhausted"); + (lower, upper) = (l, u); + } else { + return false + } + } } /// Takes a child which has been removed from the `child_stack` and converts it to an @@ -236,6 +269,27 @@ where Ok(child_rlp_node) } + /// Returns the path of the child of the currently under-construction branch at the given + /// nibble. + #[inline] + fn child_path_at(&self, nibble: u8) -> Nibbles { + let mut child_path = self.branch_path; + debug_assert!(child_path.len() < 64); + child_path.push_unchecked(nibble); + child_path + } + + /// Returns index of the highest nibble which is set in the mask. + /// + /// # Panics + /// + /// Will panic in debug mode if the mask is empty. + #[inline] + fn highest_set_nibble(mask: TrieMask) -> u8 { + debug_assert!(!mask.is_empty()); + (u16::BITS - mask.leading_zeros() - 1) as u8 + } + /// Returns the path of the child on top of the `child_stack`, or the root path if the stack is /// empty. fn last_child_path(&self) -> Nibbles { @@ -244,34 +298,23 @@ where return Nibbles::new(); }; - debug_assert_ne!(branch.state_mask.get(), 0, "branch.state_mask can never be zero"); - let last_nibble = u16::BITS - branch.state_mask.leading_zeros() - 1; - - let mut child_path = self.branch_path; - debug_assert!(child_path.len() < 64); - child_path.push_unchecked(last_nibble as u8); - child_path + self.child_path_at(Self::highest_set_nibble(branch.state_mask)) } /// Calls [`Self::commit_child`] on the last child of `child_stack`, replacing it with a /// [`ProofTrieBranchChild::RlpNode`]. /// + /// If `child_stack` is empty then this is a no-op. + /// /// NOTE that this method call relies on the `state_mask` of the top branch of the /// `branch_stack` to determine the last child's path. When committing the last child prior to /// pushing a new child, it's important to set the new child's `state_mask` bit _after_ the call /// to this method. - /// - /// # Panics - /// - /// This method panics if the `child_stack` is empty. fn commit_last_child( &mut self, targets: &mut TargetsIter>, ) -> Result<(), StateProofError> { - let child = self - .child_stack - .pop() - .expect("`commit_last_child` cannot be called with empty `child_stack`"); + let Some(child) = self.child_stack.pop() else { return Ok(()) }; // If the child is already an `RlpNode` then there is nothing to do, push it back on with no // changes. @@ -281,6 +324,9 @@ where } let child_path = self.last_child_path(); + // TODO theoretically `commit_child` only needs to convert to an `RlpNode` if it's going to + // retain the proof, otherwise we could leave the child as-is on the stack and convert it + // when popping the branch, giving more time to the DeferredEncoder to do async work. let child_rlp_node = self.commit_child(targets, child_path, child)?; // Replace the child on the stack @@ -303,12 +349,11 @@ where leaf_val: VE::DeferredEncoder, ) -> Result<(), StateProofError> { // Before pushing the new leaf onto the `child_stack` we need to commit the previous last - // child (ie the first child of this new branch), so that only `child_stack`'s final child - // is a non-RlpNode. + // child, so that only `child_stack`'s final child is a non-RlpNode. 
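+        //
+        // E.g. (illustrative): with [RlpNode, Leaf] on the child stack, pushing another
+        // leaf first converts the trailing Leaf into an RlpNode, so the stack becomes
+        // [RlpNode, RlpNode, Leaf] afterwards; only the tail entry is ever "open".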
        self.commit_last_child(targets)?;
 
-        // Once the first child is committed we set the new child's bit on the top branch's
-        // `state_mask` and push that child.
+        // Once the last child is committed we set the new child's bit on the top branch's
+        // `state_mask` and push that new child.
         let branch = self.branch_stack.last_mut().expect("branch_stack cannot be empty");
         debug_assert!(!branch.state_mask.is_bit_set(leaf_nibble));
@@ -320,37 +365,25 @@ where
         Ok(())
     }
 
-    /// Pushes a new branch onto the `branch_stack`, while also pushing the given leaf onto the
-    /// `child_stack`.
+    /// Pushes a new branch onto the `branch_stack` based on the path and short key of the last
+    /// child on the `child_stack` and the path of the next child which will be pushed onto the
+    /// stack after this call.
     ///
-    /// This method expects that there already exists a child on the `child_stack`, and that that
-    /// child has a non-zero short key. The new branch is constructed based on the top child from
-    /// the `child_stack` and the given leaf.
-    fn push_new_branch(
-        &mut self,
-        targets: &mut TargetsIter>,
-        leaf_key: Nibbles,
-        leaf_val: VE::DeferredEncoder,
-    ) -> Result<(), StateProofError> {
-        // First determine the new leaf's shortkey relative to the current branch. If there is no
-        // current branch then the short key is the full key.
-        let leaf_short_key = if self.branch_stack.is_empty() {
-            leaf_key
+    /// Returns the nibble of the branch's `state_mask` which should be set for the new child, and
+    /// the short key that the next child should use.
+    fn push_new_branch(&mut self, new_child_path: Nibbles) -> (u8, Nibbles) {
+        // First determine the new child's shortkey relative to the current branch. If there is no
+        // current branch then the short key is the full path.
+        let new_child_short_key = if self.branch_stack.is_empty() {
+            new_child_path
         } else {
             // When there is a current branch then trim off its path as well as the nibble that it
            // has set for this leaf.
-            trim_nibbles_prefix(&leaf_key, self.branch_path.len() + 1)
+            trim_nibbles_prefix(&new_child_path, self.branch_path.len() + 1)
        };
 
-        trace!(
-            target: TRACE_TARGET,
-            ?leaf_short_key,
-            branch_path = ?self.branch_path,
-            "push_new_branch: called",
-        );
-
         // Get the new branch's first child, which is the child on the top of the stack with which
-        // the new leaf shares the same nibble on the current branch.
+        // the new child shares the same nibble on the current branch.
         let first_child = self
             .child_stack
             .last_mut()
@@ -363,8 +396,8 @@ where
         );
 
         // Determine how many nibbles are shared between the new branch's first child and the new
-        // leaf. This common prefix will be the extension of the new branch
-        let common_prefix_len = first_child_short_key.common_prefix_length(&leaf_short_key);
+        // child. This common prefix will be the extension of the new branch.
+        let common_prefix_len = first_child_short_key.common_prefix_length(&new_child_short_key);
 
         // Trim off the common prefix from the first child's short key, plus one nibble which will
         // be stored by the new branch itself in its state mask.
         first_child.trim_short_key_prefix(common_prefix_len + 1);
 
         // Similarly, trim off the common prefix, plus one nibble for the new branch, from the new
-        // leaf's short key.
+ let new_child_nibble = new_child_short_key.get_unchecked(common_prefix_len); + let new_child_short_key = trim_nibbles_prefix(&new_child_short_key, common_prefix_len + 1); - // Push the new branch onto the branch stack. We do not yet set the `state_mask` bit of the - // new leaf; `push_new_leaf` will do that. + // Update the branch path to reflect the new branch about to be pushed. Its path will be + // the path of the previous branch, plus the nibble shared by each child, plus the parent + // extension (denoted by a non-zero `ext_len`). Since the new branch's path is a prefix of + // the original new_child_path we can just slice that. + // + // If the new branch is the first branch then we do not add the extra 1, as there is no + // nibble in a parent branch to account for. + let branch_path_len = + self.branch_path.len() + common_prefix_len + self.maybe_parent_nibble(); + self.branch_path = new_child_path.slice_unchecked(0, branch_path_len); + + // Push the new branch onto the `branch_stack`. We do not yet set the `state_mask` bit of + // the new child; whatever actually pushes the child onto the `child_stack` is expected to + // do that. self.branch_stack.push(ProofTrieBranch { ext_len: common_prefix_len as u8, state_mask: TrieMask::new(1 << first_child_nibble), - tree_mask: TrieMask::default(), - hash_mask: TrieMask::default(), + masks: TrieMasks::none(), }); - // Update the branch path to reflect the new branch which was just pushed. Its path will be - // the path of the previous branch, plus the nibble shared by each child, plus the parent - // extension (denoted by a non-zero `ext_len`). Since the new branch's path is a prefix of - // the original leaf_key we can just slice that. - // - // If the branch is the first branch then we do not add the extra 1, as there is no nibble - // in a parent branch to account for. - let branch_path_len = self.branch_path.len() + - common_prefix_len + - if self.branch_stack.len() == 1 { 0 } else { 1 }; - self.branch_path = leaf_key.slice_unchecked(0, branch_path_len); - - // Push the new leaf onto the new branch. This step depends on the top branch being in the - // correct state, so must be done last. - self.push_new_leaf(targets, leaf_nibble, leaf_short_key, leaf_val)?; - trace!( target: TRACE_TARGET, - ?leaf_short_key, + ?new_child_path, ?common_prefix_len, - new_branch = ?self.branch_stack.last().expect("branch_stack was just pushed to"), - ?branch_path_len, + ?first_child_nibble, branch_path = ?self.branch_path, - "push_new_branch: returning", + "Pushed new branch", ); - Ok(()) + (new_child_nibble, new_child_short_key) } + /// Pops the top branch off of the `branch_stack`, hashes its children on the `child_stack`, and /// replaces those children on the `child_stack`. The `branch_path` field will be updated /// accordingly. @@ -449,10 +477,12 @@ where ); // Collect children into an `RlpNode` Vec by committing and pushing each of them. - for child in self.child_stack.drain(self.child_stack.len() - num_children..) { + for (idx, child) in + self.child_stack.drain(self.child_stack.len() - num_children..).enumerate() + { let ProofTrieBranchChild::RlpNode(child_rlp_node) = child else { panic!( - "all branch child must have been committed, found {}", + "all branch children must have been committed, found {} at index {idx:?}", std::any::type_name_of_val(&child) ); }; @@ -473,8 +503,10 @@ where ); // Wrap the `BranchNode` so it can be pushed onto the child stack. 
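+        // The tree/hash masks travel with the branch so that, if this node is later
+        // retained for a proof, `into_proof_trie_node` can report the same masks that a
+        // cached `BranchNodeCompact` would have carried.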
- let mut branch_as_child = - ProofTrieBranchChild::Branch(BranchNode::new(rlp_nodes_buf, branch.state_mask)); + let mut branch_as_child = ProofTrieBranchChild::Branch { + node: BranchNode::new(rlp_nodes_buf, branch.state_mask), + masks: branch.masks, + }; // If there is an extension then encode the branch as an `RlpNode` and use it to construct // the extension in its place @@ -487,9 +519,8 @@ where // Update the branch_path. If this branch is the only branch then only its extension needs // to be trimmed, otherwise we also need to remove its nibble from its parent. - let new_path_len = self.branch_path.len() - - branch.ext_len as usize - - if self.branch_stack.is_empty() { 0 } else { 1 }; + let new_path_len = + self.branch_path.len() - branch.ext_len as usize - self.maybe_parent_nibble(); debug_assert!(self.branch_path.len() >= new_path_len); self.branch_path = self.branch_path.slice_unchecked(0, new_path_len); @@ -499,7 +530,7 @@ where /// Adds a single leaf for a key to the stack, possibly collapsing an existing branch and/or /// creating a new one depending on the path of the key. - fn add_leaf( + fn push_leaf( &mut self, targets: &mut TargetsIter>, key: Nibbles, @@ -512,12 +543,12 @@ where branch_stack_len = ?self.branch_stack.len(), branch_path = ?self.branch_path, child_stack_len = ?self.child_stack.len(), - "add_leaf: loop", + "push_leaf: loop", ); - // Get the `state_mask` of the branch currently being built. If there are no branches on - // the stack then it means either the trie is empty or only a single leaf has been added - // previously. + // Get the `state_mask` of the branch currently being built. If there are no branches + // on the stack then it means either the trie is empty or only a single leaf has been + // added previously. let curr_branch_state_mask = match self.branch_stack.last() { Some(curr_branch) => curr_branch.state_mask, None if self.child_stack.is_empty() => { @@ -536,7 +567,8 @@ where .expect("already checked for emptiness") .short_key() .is_empty()); - self.push_new_branch(targets, key, val)?; + let (nibble, short_key) = self.push_new_branch(key); + self.push_new_leaf(targets, nibble, short_key, val)?; return Ok(()) } }; @@ -559,8 +591,11 @@ where // existing child. let nibble = key.get_unchecked(common_prefix_len); if curr_branch_state_mask.is_bit_set(nibble) { - // This method will also push the new leaf onto the `child_stack`. - self.push_new_branch(targets, key, val)?; + // Push a new branch which splits the short key of the existing child at this + // nibble. + let (nibble, short_key) = self.push_new_branch(key); + // Push the new leaf onto the new branch. + self.push_new_leaf(targets, nibble, short_key, val)?; } else { let short_key = key.slice_unchecked(common_prefix_len + 1, key.len()); self.push_new_leaf(targets, nibble, short_key, val)?; @@ -570,27 +605,485 @@ where } } + /// Given the lower and upper bounds (exclusive) of a range of keys, iterates over the + /// `hashed_cursor` and calculates all trie nodes possible based on those keys. If the upper + /// bound is None then it is considered unbounded. + /// + /// It is expected that this method is "driven" by `next_uncached_key_range`, which decides + /// which ranges of keys need to be calculated based on what cached trie data is available. 
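+    ///
+    /// A minimal sketch of that driving loop (simplified from `proof_inner` below; the
+    /// real loop also threads the cursors and targets through each call):
+    ///
+    /// ```ignore
+    /// while let Some((lower, upper)) = self.next_uncached_key_range(..)? {
+    ///     // push_leaf() over every hashed key in [lower, upper)
+    ///     self.calculate_key_range(.., lower, upper)?;
+    /// }
+    /// ```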
+ #[instrument( + target = TRACE_TARGET, + level = "trace", + skip(self, value_encoder, targets, hashed_cursor_current), + )] + fn calculate_key_range( + &mut self, + value_encoder: &VE, + targets: &mut TargetsIter>, + hashed_cursor_current: &mut Option<(Nibbles, VE::DeferredEncoder)>, + lower_bound: Nibbles, + upper_bound: Option, + ) -> Result<(), StateProofError> { + // A helper closure for mapping entries returned from the `hashed_cursor`, converting the + // key to Nibbles and immediately creating the DeferredValueEncoder so that encoding of the + // leaf value can begin ASAP. + let map_hashed_cursor_entry = |(key_b256, val): (B256, _)| { + debug_assert_eq!(key_b256.len(), 32); + // SAFETY: key is a B256 and so is exactly 32-bytes. + let key = unsafe { Nibbles::unpack_unchecked(key_b256.as_slice()) }; + let val = value_encoder.deferred_encoder(key_b256, val); + (key, val) + }; + + // If the cursor hasn't been used, or the last iterated key is prior to this range's + // key range, then seek forward to at least the first key. + if hashed_cursor_current.as_ref().is_none_or(|(key, _)| key < &lower_bound) { + let lower_key = B256::right_padding_from(&lower_bound.pack()); + *hashed_cursor_current = + self.hashed_cursor.seek(lower_key)?.map(map_hashed_cursor_entry); + } + + // Loop over all keys in the range, calling `push_leaf` on each. + while let Some((key, _)) = hashed_cursor_current.as_ref() && + upper_bound.is_none_or(|upper_bound| key < &upper_bound) + { + let (key, val) = + core::mem::take(hashed_cursor_current).expect("while-let checks for Some"); + self.push_leaf(targets, key, val)?; + *hashed_cursor_current = self.hashed_cursor.next()?.map(map_hashed_cursor_entry); + } + + Ok(()) + } + + /// Constructs and returns a new [`ProofTrieBranch`] based on an existing [`BranchNodeCompact`]. + #[inline] + const fn new_from_cached_branch( + cached_branch: &BranchNodeCompact, + ext_len: u8, + ) -> ProofTrieBranch { + ProofTrieBranch { + ext_len, + state_mask: TrieMask::new(0), + masks: TrieMasks { + tree_mask: Some(cached_branch.tree_mask), + hash_mask: Some(cached_branch.hash_mask), + }, + } + } + + /// Pushes a new branch onto the `branch_stack` which is based on a cached branch obtained via + /// the trie cursor. + /// + /// If there is already a child at the top branch of `branch_stack` occupying this new branch's + /// nibble then that child will have its short-key split with another new branch, and this + /// cached branch will be a child of that splitting branch. + fn push_cached_branch( + &mut self, + targets: &mut TargetsIter>, + cached_path: Nibbles, + cached_branch: &BranchNodeCompact, + ) -> Result<(), StateProofError> { + debug_assert!( + cached_path.starts_with(&self.branch_path), + "push_cached_branch called with path {cached_path:?} which is not a child of current branch {:?}", + self.branch_path, + ); + + let parent_branch = self.branch_stack.last(); + + // If both stacks are empty then there were no leaves before this cached branch, push it and + // be done; the extension of the branch will be its full path. + if self.child_stack.is_empty() && parent_branch.is_none() { + self.branch_path = cached_path; + self.branch_stack + .push(Self::new_from_cached_branch(cached_branch, cached_path.len() as u8)); + return Ok(()) + } + + // Get the nibble which should be set in the parent branch's `state_mask` for this new + // branch. 
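+        //
+        // E.g. (illustrative): with branch_path = 0x12 and cached_path = 0x1256, the
+        // cached branch occupies child slot 5 of the current branch; in the no-split
+        // case below that yields ext_len = 4 - 2 - 1 = 1, an extension through the
+        // trailing 0x6.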
+        let cached_branch_nibble = cached_path.get_unchecked(self.branch_path.len());
+
+        // We calculate the `ext_len` of the new branch, and potentially update its nibble if a new
+        // parent branch is inserted here, based on the state of the parent branch.
+        let (cached_branch_nibble, ext_len) = if parent_branch
+            .is_none_or(|parent_branch| parent_branch.state_mask.is_bit_set(cached_branch_nibble))
+        {
+            // If the `child_stack` is not empty but the `branch_stack` is then it implies that
+            // there must be a leaf or extension at the root of the trie whose short-key will get
+            // split by a new branch, which will become the parent of both that leaf/extension and
+            // this new branch.
+            //
+            // Similarly, if there is a branch on the `branch_stack` but its `state_mask` bit for
+            // this new branch is already set, then there must be a leaf/extension with a short-key
+            // to be split.
+            debug_assert!(!self
+                .child_stack
+                .last()
+                .expect("already checked for emptiness")
+                .short_key()
+                .is_empty());
+
+            // Split that leaf/extension's short key with a new branch.
+            let (nibble, short_key) = self.push_new_branch(cached_path);
+            (nibble, short_key.len())
+        } else {
+            // If there is a parent branch but its `state_mask` bit for this branch is not set
+            // then we can simply calculate the `ext_len` based on the difference of each, minus
+            // 1 to account for the nibble in the `state_mask`.
+            (cached_branch_nibble, cached_path.len() - self.branch_path.len() - 1)
+        };
+
+        // `commit_last_child` relies on the last set bit of the parent branch's `state_mask` to
+        // determine the path of the last child on the `child_stack`. Since we are about to
+        // change that mask we need to commit that last child first.
+        self.commit_last_child(targets)?;
+
+        // When pushing a new branch we need to set its child nibble in the `state_mask` of
+        // its parent, if there is one.
+        if let Some(parent_branch) = self.branch_stack.last_mut() {
+            parent_branch.state_mask.set_bit(cached_branch_nibble);
+        }
+
+        // Finally update the `branch_path` and push the new branch.
+        self.branch_path = cached_path;
+        self.branch_stack.push(Self::new_from_cached_branch(cached_branch, ext_len as u8));
+
+        trace!(
+            target: TRACE_TARGET,
+            branch=?self.branch_stack.last(),
+            branch_path=?self.branch_path,
+            "Pushed cached branch",
+        );
+
+        Ok(())
+    }
+
+    /// Attempts to pop off the top branch of the `cached_branch_stack`, returning
+    /// [`PopCachedBranchOutcome::Popped`] on success. Returns other variants to indicate that the
+    /// stack is empty and what to do about it.
+    ///
+    /// This method only returns [`PopCachedBranchOutcome::CalculateLeaves`] after pushing the
+    /// upcoming cached branch back onto the stack, so that it is still there once the returned
+    /// key range has been calculated.
+    #[inline]
+    fn try_pop_cached_branch(
+        &mut self,
+        trie_cursor_state: &mut TrieCursorState,
+        uncalculated_lower_bound: &Option<Nibbles>,
+    ) -> Result<PopCachedBranchOutcome, StateProofError> {
+        // If there is a branch on top of the stack we use that.
+        if let Some(cached) = self.cached_branch_stack.pop() {
+            return Ok(PopCachedBranchOutcome::Popped(cached));
+        }
+
+        // There is no cached branch on the stack. It's possible that another one exists
+        // farther on in the trie, but we perform some checks first to prevent unnecessary
+        // attempts to find it.
+
+        // If the `uncalculated_lower_bound` is None it indicates that there can be no more
+        // leaf data, so similarly there can be no more branches.
+ let Some(uncalculated_lower_bound) = uncalculated_lower_bound else { + return Ok(PopCachedBranchOutcome::Exhausted) + }; + + // If [`TrieCursorState::path`] returns None it means that the cursor has been + // exhausted, so there can be no more cached data. + let Some(trie_cursor_path) = trie_cursor_state.path() else { + return Ok(PopCachedBranchOutcome::Exhausted) + }; + + // If the trie cursor is seeked to a branch whose leaves have already been processed + // then we can't use it, instead we seek forward and try again. + if trie_cursor_path < uncalculated_lower_bound { + *trie_cursor_state = + TrieCursorState::new(self.trie_cursor.seek(*uncalculated_lower_bound)?); + + // Having just seeked forward we need to check if the cursor is now exhausted. + if matches!(trie_cursor_state, TrieCursorState::Exhausted) { + return Ok(PopCachedBranchOutcome::Exhausted) + }; + } + + // At this point we can be sure that the cursor is in an `Available` state. We know for + // sure it's not `Exhausted` because of the call to `path` above, and we know it's not + // `Taken` because we push all taken branches onto the `cached_branch_stack`, and the + // stack is empty. + // + // We will use this `Available` cached branch as our next branch. + let cached = trie_cursor_state.take(); + trace!(target: TRACE_TARGET, cached=?cached, "Pushed next trie node onto cached_branch_stack"); + + // If the calculated range is not caught up to the next cached branch it means there + // are portions of the trie prior to that branch which may need to be calculated; + // return the uncalculated range up to that branch to make that happen. + // + // If the next cached branch's path is all zeros then we can skip this catch-up step, + // because there cannot be any keys prior to that range. + let cached_path = &cached.0; + if uncalculated_lower_bound < cached_path && !PATH_ALL_ZEROS.starts_with(cached_path) { + let range = (*uncalculated_lower_bound, Some(*cached_path)); + trace!(target: TRACE_TARGET, ?range, "Returning key range to calculate in order to catch up to cached branch"); + + // Push the cached branch onto the stack so it's available once the leaf range is done + // being calculated. + self.cached_branch_stack.push(cached); + + return Ok(PopCachedBranchOutcome::CalculateLeaves(range)); + } + + Ok(PopCachedBranchOutcome::Popped(cached)) + } + + /// Accepts the current state of both hashed and trie cursors, and determines the next range of + /// hashed keys which need to be processed using [`Self::push_leaf`]. + /// + /// This method will use cached branch node data from the trie cursor to skip over all possible + /// ranges of keys, to reduce computation as much as possible. + /// + /// # Returns + /// + /// - `None`: No more data to process, finish computation + /// + /// - `Some(lower, None)`: Indicates to call `push_leaf` on all keys starting at `lower`, with + /// no upper bound. This method won't be called again after this. + /// + /// - `Some(lower, Some(upper))`: Indicates to call `push_leaf` on all keys starting at `lower`, + /// up to but excluding `upper`, and then call this method once done. + #[instrument(target = TRACE_TARGET, level = "trace", skip_all)] + fn next_uncached_key_range( + &mut self, + targets: &mut TargetsIter>, + trie_cursor_state: &mut TrieCursorState, + hashed_key_current_path: Option, + ) -> Result)>, StateProofError> { + // Pop any under-construction branches that are now complete. + // All trie data prior to the current cached branch, if any, has been computed. 
Any branches + // which were under-construction previously, and which are not on the same path as this + // cached branch, can be assumed to be completed; they will not have any further keys added + // to them. + if let Some(cached_path) = self.cached_branch_stack.last().map(|kv| kv.0) { + while !cached_path.starts_with(&self.branch_path) { + self.pop_branch(targets)?; + } + } + + // `uncalculated_lower_bound` tracks the lower bound of node paths which have yet to be + // visited, either via the hashed key cursor (`calculate_key_range`) or trie cursor (this + // method). If this is None then there are no further nodes which could exist. + // + // This starts off being based on the hashed cursor's current position, which is the + // next hashed key which hasn't been processed. If that is None then we start from zero. + let mut uncalculated_lower_bound = Some(hashed_key_current_path.unwrap_or_default()); + + loop { + // Pop the currently cached branch node. + // + // NOTE we pop off the `cached_branch_stack` because cloning the `BranchNodeCompact` + // means cloning an Arc, which incurs synchronization overhead. We have to be sure to + // push the cached branch back onto the stack once done. + let (cached_path, cached_branch) = match self + .try_pop_cached_branch(trie_cursor_state, &uncalculated_lower_bound)? + { + PopCachedBranchOutcome::Popped(cached) => cached, + PopCachedBranchOutcome::Exhausted => { + // If cached branches are exhausted it's possible that there is still an + // unbounded range of leaves to be processed. `uncalculated_lower_bound` is + // used to return that range. + trace!(target: TRACE_TARGET, ?uncalculated_lower_bound, "Exhausted cached trie nodes"); + return Ok(uncalculated_lower_bound.map(|lower| (lower, None))); + } + PopCachedBranchOutcome::CalculateLeaves(range) => { + return Ok(Some(range)); + } + }; + + trace!( + target: TRACE_TARGET, + branch_path = ?self.branch_path, + branch_state_mask = ?self.branch_stack.last().map(|b| b.state_mask), + ?cached_path, + cached_branch_state_mask = ?cached_branch.state_mask, + cached_branch_hash_mask = ?cached_branch.hash_mask, + "loop", + ); + + // Since we've popped all branches which don't start with cached_path, branch_path at + // this point must be equal to or shorter than cached_path. + debug_assert!( + self.branch_path.len() < cached_path.len() || self.branch_path == cached_path, + "branch_path {:?} is different-or-longer-than cached_path {cached_path:?}", + self.branch_path + ); + + // If the branch_path != cached_path it means the branch_stack is either empty, or the + // top branch is the parent of this cached branch. Either way we push a branch + // corresponding to the cached one onto the stack, so we can begin constructing it. + if self.branch_path != cached_path { + self.push_cached_branch(targets, cached_path, &cached_branch)?; + } + + // At this point the top of the branch stack is the same branch which was found in the + // cache. + let curr_branch = + self.branch_stack.last().expect("top of branch_stack corresponds to cached branch"); + + let cached_state_mask = cached_branch.state_mask.get(); + let curr_state_mask = curr_branch.state_mask.get(); + + // Determine all child nibbles which are set in the cached branch but not the + // under-construction branch. + let next_child_nibbles = curr_state_mask ^ cached_state_mask; + debug_assert_eq!( + cached_state_mask | next_child_nibbles, cached_state_mask, + "curr_branch has state_mask bits set which aren't set on cached_branch. 
curr_branch:{:?}", + curr_state_mask, + ); + + // If there are no further children to construct for this branch then pop it off both + // stacks and loop using the parent branch. + if next_child_nibbles == 0 { + trace!( + target: TRACE_TARGET, + path=?cached_path, + ?curr_branch, + ?cached_branch, + "No further children, popping branch", + ); + self.pop_branch(targets)?; + + // no need to pop from `cached_branch_stack`, the current cached branch is already + // popped (see note at the top of the loop). + + // The just-popped branch is completely processed; we know there can be no more keys + // with that prefix. Set the lower bound which can be returned from this method to + // be the next possible prefix, if any. + uncalculated_lower_bound = increment_and_strip_trailing_zeros(&cached_path); + + continue + } + + // Determine the next nibble of the branch which has not yet been constructed, and + // determine the child's full path. + let child_nibble = next_child_nibbles.trailing_zeros() as u8; + let child_path = self.child_path_at(child_nibble); + + // If the `hash_mask` bit is set for the next child it means the child's hash is cached + // in the `cached_branch`. We can use that instead of re-calculating the hash of the + // entire sub-trie. + // + // If the child needs to be retained for a proof then we should not use the cached + // hash, and instead continue on to calculate its node manually. + if cached_branch.hash_mask.is_bit_set(child_nibble) { + // Commit the last child. We do this here for two reasons: + // - `commit_last_child` will check if the last child needs to be retained. We need + // to check that before the subsequent `should_retain` call here to prevent + // `targets` from being moved beyond the last child before it is checked. + // - If we do end up using the cached hash value, then we will need to commit the + // last child before pushing a new one onto the stack anyway. + self.commit_last_child(targets)?; + + if !self.should_retain(targets, &child_path) { + // Pull this child's hash out of the cached branch node. To get the hash's index + // we first need to calculate the mask of which cached hashes have already been + // used by this branch (if any). The number of set bits in that mask will be the + // index of the next hash in the array to use. + let curr_hashed_used_mask = cached_branch.hash_mask.get() & curr_state_mask; + let hash_idx = curr_hashed_used_mask.count_ones() as usize; + let hash = cached_branch.hashes[hash_idx]; + + trace!( + target: TRACE_TARGET, + ?child_path, + ?hash_idx, + ?hash, + "Using cached hash for child", + ); + + self.child_stack.push(ProofTrieBranchChild::RlpNode(RlpNode::word_rlp(&hash))); + self.branch_stack + .last_mut() + .expect("already asserted there is a last branch") + .state_mask + .set_bit(child_nibble); + + // Update the `uncalculated_lower_bound` to indicate that the child whose bit + // was just set is completely processed. + uncalculated_lower_bound = increment_and_strip_trailing_zeros(&child_path); + + // Push the current cached branch back onto the stack before looping. + self.cached_branch_stack.push((cached_path, cached_branch)); + + continue + } + } + + // We now want to check if there is a cached branch node at this child. The cached + // branch node may be the node at this child directly, or this child may be an + // extension and the cached branch is the child of that extension. 
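+            //
+            // E.g. (illustrative): child_path = 0x12 while the cursor sits on a cached
+            // branch at 0x1234; the child at 0x12 is then an extension leading to that
+            // branch, so it is pushed onto the `cached_branch_stack` and the outer loop
+            // descends into it.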
+ + // All trie nodes prior to `child_path` will not be modified further, so we can seek the + // trie cursor to the next cached node at-or-after `child_path`. + if trie_cursor_state.path().is_some_and(|path| path < &child_path) { + trace!(target: TRACE_TARGET, ?child_path, "Seeking trie cursor to child path"); + *trie_cursor_state = TrieCursorState::new(self.trie_cursor.seek(child_path)?); + } + + // If the next cached branch node is a child of `child_path` then we can assume it is + // the cached branch for this child. We push it onto the `cached_branch_stack` and loop + // back to the top. + if let TrieCursorState::Available(next_cached_path, next_cached_branch) = + &trie_cursor_state && + next_cached_path.starts_with(&child_path) + { + // Push the current cached branch back on before pushing its child and then looping + self.cached_branch_stack.push((cached_path, cached_branch)); + + trace!( + target: TRACE_TARGET, + ?child_path, + ?next_cached_path, + ?next_cached_branch, + "Pushing cached branch for child", + ); + self.cached_branch_stack.push(trie_cursor_state.take()); + continue; + } + + // There is no cached data for the sub-trie at this child, we must recalculate the + // sub-trie root (this child) using the leaves. Return the range of keys based on the + // child path. + let child_path_upper = increment_and_strip_trailing_zeros(&child_path); + trace!( + target: TRACE_TARGET, + lower=?child_path, + upper=?child_path_upper, + "Returning sub-trie's key range to calculate", + ); + + // Push the current cached branch back onto the stack before returning. + self.cached_branch_stack.push((cached_path, cached_branch)); + + return Ok(Some((child_path, child_path_upper))); + } + } + /// Internal implementation of proof calculation. Assumes both cursors have already been reset. /// See docs on [`Self::proof`] for expected behavior. fn proof_inner( &mut self, value_encoder: &VE, - targets: impl IntoIterator, + targets: impl IntoIterator, ) -> Result, StateProofError> { trace!(target: TRACE_TARGET, "proof_inner: called"); // In debug builds, verify that targets are sorted #[cfg(debug_assertions)] let targets = { - let mut prev: Option = None; + let mut prev: Option = None; targets.into_iter().inspect(move |target| { if let Some(prev) = prev { - debug_assert!( - depth_first::cmp(&prev, target) != Ordering::Greater, - "targets must be sorted depth-first, instead {:?} > {:?}", - prev, - target - ); + debug_assert!(&prev <= target, "prev:{prev:?} target:{target:?}"); } prev = Some(*target); }) @@ -599,6 +1092,12 @@ where #[cfg(not(debug_assertions))] let targets = targets.into_iter(); + // Convert B256 targets into Nibbles. + let targets = targets.into_iter().map(|key| { + // SAFETY: key is a B256 and so is exactly 32-bytes. + unsafe { Nibbles::unpack_unchecked(key.as_slice()) } + }); + // Wrap targets into a `TargetsIter`. let mut targets = WindowIter::new(targets).peekable(); @@ -614,36 +1113,42 @@ where debug_assert!(self.branch_path.is_empty()); debug_assert!(self.child_stack.is_empty()); - let mut hashed_cursor_current = self.hashed_cursor.seek(B256::ZERO)?; + // Initialize the hashed cursor to None to indicate it hasn't been seeked yet. + let mut hashed_cursor_current: Option<(Nibbles, VE::DeferredEncoder)> = None; + + // Initialize the `trie_cursor_state` with the node closest to root. 
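+        // (Seeking to the empty path positions the cursor at the first stored branch in
+        // the trie, i.e. the root branch if one is cached; `TrieCursorState::new` maps an
+        // exhausted cursor to `Exhausted`.)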
+ let mut trie_cursor_state = TrieCursorState::new(self.trie_cursor.seek(Nibbles::new())?); + loop { - trace!( - target: TRACE_TARGET, - ?hashed_cursor_current, - branch_stack_len = ?self.branch_stack.len(), - branch_path = ?self.branch_path, - child_stack_len = ?self.child_stack.len(), - "proof_inner: loop", - ); - - // Sanity check before making any further changes: - // If there is a branch, there must be at least two children - debug_assert!(self.branch_stack.last().is_none_or(|_| self.child_stack.len() >= 2)); - - // Fetch the next leaf from the hashed cursor, converting the key to Nibbles and - // immediately creating the DeferredValueEncoder so that encoding of the leaf value can - // begin ASAP. - let Some((key, val)) = hashed_cursor_current.map(|(key_b256, val)| { - debug_assert_eq!(key_b256.len(), 32); - // SAFETY: key is a B256 and so is exactly 32-bytes. - let key = unsafe { Nibbles::unpack_unchecked(key_b256.as_slice()) }; - let val = value_encoder.deferred_encoder(key_b256, val); - (key, val) - }) else { - break + // Determine the range of keys of the overall trie which need to be re-computed. + let Some((lower_bound, upper_bound)) = self.next_uncached_key_range( + &mut targets, + &mut trie_cursor_state, + hashed_cursor_current.as_ref().map(|kv| kv.0), + )? + else { + // If `next_uncached_key_range` determines that there can be no more keys then + // complete the computation. + break; }; - self.add_leaf(&mut targets, key, val)?; - hashed_cursor_current = self.hashed_cursor.next()?; + // Calculate the trie for that range of keys + self.calculate_key_range( + value_encoder, + &mut targets, + &mut hashed_cursor_current, + lower_bound, + upper_bound, + )?; + + // Once outside `calculate_key_range`, `hashed_cursor_current` will be at the first key + // after the range. + // + // If the `hashed_cursor_current` is None then there are no more keys at all, meaning + // the trie couldn't possibly have more data and we should complete computation. + if hashed_cursor_current.is_none() { + break; + } } // Once there's no more leaves we can pop the remaining branches, if any. @@ -689,8 +1194,8 @@ where { /// Generate a proof for the given targets. /// - /// Given depth-first sorted targets, returns nodes whose paths are a prefix of any target. The - /// returned nodes will be sorted lexicographically by path. + /// Given lexicographically sorted targets, returns nodes whose paths are a prefix of any + /// target. The returned nodes will be sorted lexicographically by path. /// /// # Panics /// @@ -699,7 +1204,7 @@ where pub fn proof( &mut self, value_encoder: &VE, - targets: impl IntoIterator, + targets: impl IntoIterator, ) -> Result, StateProofError> { self.trie_cursor.reset(); self.hashed_cursor.reset(); @@ -722,8 +1227,8 @@ where /// Generate a proof for a storage trie at the given hashed address. /// - /// Given depth-first sorted targets, returns nodes whose paths are a prefix of any target. The - /// returned nodes will be sorted lexicographically by path. + /// Given lexicographically sorted targets, returns nodes whose paths are a prefix of any + /// target. The returned nodes will be sorted lexicographically by path. /// /// # Panics /// @@ -732,7 +1237,7 @@ where pub fn storage_proof( &mut self, hashed_address: B256, - targets: impl IntoIterator, + targets: impl IntoIterator, ) -> Result, StateProofError> { /// Static storage value encoder instance used by all storage proofs. 
        static STORAGE_VALUE_ENCODER: StorageValueEncoder = StorageValueEncoder;
@@ -795,20 +1300,101 @@ impl<I: Iterator<Item = Nibbles>> Iterator for WindowIter<I> {
     }
 }
 
+/// Used to track the state of the trie cursor, allowing us to differentiate between a branch
+/// having been taken (used as a cached branch) and the cursor having been exhausted.
+#[derive(Debug)]
+enum TrieCursorState {
+    /// Cursor is seeked to this path and the node has not been used yet.
+    Available(Nibbles, BranchNodeCompact),
+    /// Cursor is seeked to this path, but the node has been used.
+    Taken(Nibbles),
+    /// Cursor has been exhausted.
+    Exhausted,
+}
+
+impl TrieCursorState {
+    /// Creates a [`Self`] based on an entry returned from the cursor itself.
+    fn new(entry: Option<(Nibbles, BranchNodeCompact)>) -> Self {
+        entry.map_or(Self::Exhausted, |(path, node)| Self::Available(path, node))
+    }
+
+    /// Returns the path the cursor is seeked to, or None if it's exhausted.
+    const fn path(&self) -> Option<&Nibbles> {
+        match self {
+            Self::Available(path, _) | Self::Taken(path) => Some(path),
+            Self::Exhausted => None,
+        }
+    }
+
+    /// Takes the path and node from a [`Self::Available`]. Panics if not [`Self::Available`].
+    fn take(&mut self) -> (Nibbles, BranchNodeCompact) {
+        let Self::Available(path, _) = self else {
+            panic!("take called on non-Available: {self:?}")
+        };
+
+        let path = *path;
+        let Self::Available(path, node) = core::mem::replace(self, Self::Taken(path)) else {
+            unreachable!("already checked that self is Self::Available");
+        };
+
+        (path, node)
+    }
+}
+
+/// Describes the state of the currently cached branch node (if any).
+enum PopCachedBranchOutcome {
+    /// Cached branch has been popped from the `cached_branch_stack` and is ready to be used.
+    Popped((Nibbles, BranchNodeCompact)),
+    /// All cached branches have been exhausted.
+    Exhausted,
+    /// Need to calculate leaves from this range (exclusive upper) before the cached branch
+    /// (catch-up range). If the upper bound is None then the range is unbounded.
+    CalculateLeaves((Nibbles, Option<Nibbles>)),
+}
+
+/// Increments the nibbles and strips any trailing zeros.
+///
+/// This function wraps `Nibbles::increment` and when it returns a value with trailing zeros,
+/// it strips those zeros using bit manipulation on the underlying U256.
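+///
+/// For example, incrementing `0x0FF` yields `0x100`, which is returned as just `0x1`;
+/// over full-width keys the two forms admit exactly the same set when used as a bound,
+/// and the shorter form is cheaper to compare.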
+fn increment_and_strip_trailing_zeros(nibbles: &Nibbles) -> Option { + let mut result = nibbles.increment()?; + + // If result is empty, just return it + if result.is_empty() { + return Some(result); + } + + // Get access to the underlying U256 to detect trailing zeros + let uint_val = *result.as_mut_uint_unchecked(); + let non_zero_prefix_len = 64 - (uint_val.trailing_zeros() / 4); + result.truncate(non_zero_prefix_len); + + Some(result) +} + #[cfg(test)] mod tests { use super::*; use crate::{ - hashed_cursor::{mock::MockHashedCursorFactory, HashedCursorFactory}, + hashed_cursor::{ + mock::MockHashedCursorFactory, HashedCursorFactory, HashedCursorMetricsCache, + InstrumentedHashedCursor, + }, proof::Proof, - trie_cursor::{depth_first, mock::MockTrieCursorFactory, TrieCursorFactory}, + trie_cursor::{ + depth_first, mock::MockTrieCursorFactory, InstrumentedTrieCursor, TrieCursorFactory, + TrieCursorMetricsCache, + }, }; use alloy_primitives::map::{B256Map, B256Set}; use alloy_rlp::Decodable; use assert_matches::assert_matches; use itertools::Itertools; - use reth_trie_common::{HashedPostState, MultiProofTargets, TrieNode}; - use std::collections::BTreeMap; + use reth_primitives_traits::Account; + use reth_trie_common::{ + updates::{StorageTrieUpdates, TrieUpdates}, + HashedPostState, MultiProofTargets, TrieNode, + }; /// Target to use with the `tracing` crate. static TRACE_TARGET: &str = "trie::proof_v2::tests"; @@ -828,25 +1414,38 @@ mod tests { /// Creates a new test harness from a `HashedPostState`. /// /// The `HashedPostState` is used to populate the mock hashed cursor factory directly. - /// The trie cursor factory is empty by default, suitable for testing the leaf-only - /// proof calculator. + /// The trie cursor factory is initialized from `TrieUpdates` generated by `StateRoot`. fn new(post_state: HashedPostState) -> Self { - trace!(target: TRACE_TARGET, ?post_state, "Creating ProofTestHarness"); - - // Ensure that there's a storage trie dataset for every storage trie, even if empty. - let storage_trie_nodes: B256Map> = post_state - .storages + // Create empty trie cursor factory to serve as the initial state for StateRoot + // Ensure that there's a storage trie dataset for every account, to make + // `MockTrieCursorFactory` happy. + let storage_tries: B256Map<_> = post_state + .accounts .keys() .copied() - .map(|addr| (addr, Default::default())) + .map(|addr| (addr, StorageTrieUpdates::default())) .collect(); + let empty_trie_cursor_factory = MockTrieCursorFactory::from_trie_updates(TrieUpdates { + storage_tries: storage_tries.clone(), + ..Default::default() + }); + // Create mock hashed cursor factory from the post state let hashed_cursor_factory = MockHashedCursorFactory::from_hashed_post_state(post_state); - // Create empty trie cursor factory (leaf-only calculator doesn't need trie nodes) - let trie_cursor_factory = - MockTrieCursorFactory::new(BTreeMap::new(), storage_trie_nodes); + // Generate TrieUpdates using StateRoot + let (_root, mut trie_updates) = + crate::StateRoot::new(empty_trie_cursor_factory, hashed_cursor_factory.clone()) + .root_with_updates() + .expect("StateRoot should succeed"); + + // Continue using empty storage tries for each account, to keep `MockTrieCursorFactory` + // happy. 
+ trie_updates.storage_tries = storage_tries; + + // Initialize trie cursor factory from the generated TrieUpdates + let trie_cursor_factory = MockTrieCursorFactory::from_trie_updates(trie_updates); Self { trie_cursor_factory, hashed_cursor_factory } } @@ -858,12 +1457,10 @@ mod tests { /// the results. fn assert_proof( &self, - targets: impl IntoIterator + Clone, + targets: impl IntoIterator, ) -> Result<(), StateProofError> { - // Convert B256 targets to Nibbles for proof_v2 - let targets_vec: Vec = targets.into_iter().collect(); - let nibbles_targets: Vec = - targets_vec.iter().map(|b256| Nibbles::unpack(b256.as_slice())).sorted().collect(); + let targets_vec = targets.into_iter().sorted().collect::>(); + // Convert B256 targets to MultiProofTargets for legacy implementation // For account-only proofs, each account maps to an empty storage set let legacy_targets = targets_vec @@ -875,17 +1472,29 @@ mod tests { let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?; let hashed_cursor = self.hashed_cursor_factory.hashed_account_cursor()?; + // Collect metrics for cursors + let mut trie_cursor_metrics = TrieCursorMetricsCache::default(); + let trie_cursor = InstrumentedTrieCursor::new(trie_cursor, &mut trie_cursor_metrics); + let mut hashed_cursor_metrics = HashedCursorMetricsCache::default(); + let hashed_cursor = + InstrumentedHashedCursor::new(hashed_cursor, &mut hashed_cursor_metrics); + // Call ProofCalculator::proof with account targets let value_encoder = SyncAccountValueEncoder::new( self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone(), ); let mut proof_calculator = ProofCalculator::new(trie_cursor, hashed_cursor); - let proof_v2_result = proof_calculator.proof(&value_encoder, nibbles_targets)?; + let proof_v2_result = proof_calculator.proof(&value_encoder, targets_vec.clone())?; + + // Output metrics + trace!(target: TRACE_TARGET, ?trie_cursor_metrics, "V2 trie cursor metrics"); + trace!(target: TRACE_TARGET, ?hashed_cursor_metrics, "V2 hashed cursor metrics"); // Call Proof::multiproof (legacy implementation) let proof_legacy_result = Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone()) + .with_branch_node_masks(true) .multiproof(legacy_targets)?; // Decode and sort legacy proof nodes @@ -897,10 +1506,12 @@ mod tests { let node = TrieNode::decode(&mut buf) .expect("legacy implementation should not produce malformed proof nodes"); - ProofTrieNode { - path: *path, - node, - masks: TrieMasks { + // The legacy proof calculator will calculate masks for the root node, even + // though we never store the root node so the masks for it aren't really valid. 
+ let masks = if path.is_empty() { + TrieMasks::none() + } else { + TrieMasks { hash_mask: proof_legacy_result .branch_node_hash_masks .get(path) @@ -909,8 +1520,10 @@ mod tests { .branch_node_tree_masks .get(path) .copied(), - }, - } + } + }; + + ProofTrieNode { path: *path, node, masks } }) .sorted_by(|a, b| depth_first::cmp(&a.path, &b.path)) .collect::>(); @@ -927,7 +1540,7 @@ mod tests { } // Basic comparison: both should succeed and produce identical results - assert_eq!(proof_legacy_nodes, proof_v2_result); + pretty_assertions::assert_eq!(proof_legacy_nodes, proof_v2_result); Ok(()) } @@ -937,7 +1550,6 @@ mod tests { use super::*; use alloy_primitives::{map::B256Map, U256}; use proptest::prelude::*; - use reth_primitives_traits::Account; use reth_trie_common::HashedPostState; /// Generate a strategy for Account values @@ -953,21 +1565,14 @@ mod tests { /// Generate a strategy for `HashedPostState` with random accounts fn hashed_post_state_strategy() -> impl Strategy { - prop::collection::vec((any::<[u8; 32]>(), account_strategy()), 0..40).prop_map( + prop::collection::vec((any::<[u8; 32]>(), account_strategy()), 0..=100).prop_map( |accounts| { let account_map = accounts .into_iter() .map(|(addr_bytes, account)| (B256::from(addr_bytes), Some(account))) .collect::>(); - // All accounts have empty storages. - let storages = account_map - .keys() - .copied() - .map(|addr| (addr, Default::default())) - .collect::>(); - - HashedPostState { accounts: account_map, storages } + HashedPostState { accounts: account_map, ..Default::default() } }, ) } @@ -999,7 +1604,7 @@ mod tests { proptest! { #![proptest_config(ProptestConfig::with_cases(8000))] - + #[test] /// Tests that ProofCalculator produces valid proofs for randomly generated /// HashedPostState with proof targets. /// @@ -1009,11 +1614,12 @@ mod tests { /// - Creates a test harness with the generated state /// - Calls assert_proof with the generated targets /// - Verifies both ProofCalculator and legacy Proof produce equivalent results - #[test] fn proptest_proof_with_targets( (post_state, targets) in hashed_post_state_strategy() .prop_flat_map(|post_state| { - let account_keys: Vec = post_state.accounts.keys().copied().collect(); + let mut account_keys: Vec = post_state.accounts.keys().copied().collect(); + // Sort to ensure deterministic order when using PROPTEST_RNG_SEED + account_keys.sort_unstable(); let targets_strategy = proof_targets_strategy(account_keys); (Just(post_state), targets_strategy) }) @@ -1026,4 +1632,66 @@ mod tests { } } } + + #[test] + fn test_big_trie() { + use rand::prelude::*; + + reth_tracing::init_test_tracing(); + let mut rng = rand::rngs::SmallRng::seed_from_u64(1); + + let mut rand_b256 = || { + let mut buf: [u8; 32] = [0; 32]; + rng.fill_bytes(&mut buf); + B256::from_slice(&buf) + }; + + // Generate random HashedPostState. + let mut post_state = HashedPostState::default(); + for _ in 0..10240 { + let hashed_addr = rand_b256(); + let account = Account { bytecode_hash: Some(hashed_addr), ..Default::default() }; + post_state.accounts.insert(hashed_addr, Some(account)); + } + + // Collect targets; partially from real keys, partially random keys which probably won't + // exist. 
+        let num_real_targets = post_state.accounts.len() * 5;
+        let mut targets =
+            post_state.accounts.keys().copied().sorted().take(num_real_targets).collect::<Vec<_>>();
+        for _ in 0..post_state.accounts.len() / 5 {
+            targets.push(rand_b256());
+        }
+        targets.sort();
+
+        // Create test harness
+        let harness = ProofTestHarness::new(post_state);
+
+        // Assert the proof
+        harness.assert_proof(targets).expect("Proof generation failed");
+    }
+
+    #[test]
+    fn test_increment_and_strip_trailing_zeros() {
+        let test_cases: Vec<(Nibbles, Option<Nibbles>)> = vec![
+            // Basic increment without trailing zeros
+            (Nibbles::from_nibbles([0x1, 0x2, 0x3]), Some(Nibbles::from_nibbles([0x1, 0x2, 0x4]))),
+            // Increment with trailing zeros - should be stripped
+            (Nibbles::from_nibbles([0x0, 0x0, 0xF]), Some(Nibbles::from_nibbles([0x0, 0x1]))),
+            (Nibbles::from_nibbles([0x0, 0xF, 0xF]), Some(Nibbles::from_nibbles([0x1]))),
+            // Overflow case
+            (Nibbles::from_nibbles([0xF, 0xF, 0xF]), None),
+            // Empty nibbles
+            (Nibbles::new(), None),
+            // Single nibble
+            (Nibbles::from_nibbles([0x5]), Some(Nibbles::from_nibbles([0x6]))),
+            // All Fs except last - results in trailing zeros after increment
+            (Nibbles::from_nibbles([0xE, 0xF, 0xF]), Some(Nibbles::from_nibbles([0xF]))),
+        ];
+
+        for (input, expected) in test_cases {
+            let result = increment_and_strip_trailing_zeros(&input);
+            assert_eq!(result, expected, "Failed for input: {:?}", input);
+        }
+    }
 }
diff --git a/crates/trie/trie/src/proof_v2/node.rs b/crates/trie/trie/src/proof_v2/node.rs
index 536665f19a..9300123fbe 100644
--- a/crates/trie/trie/src/proof_v2/node.rs
+++ b/crates/trie/trie/src/proof_v2/node.rs
@@ -25,7 +25,12 @@ pub(crate) enum ProofTrieBranchChild {
         child: RlpNode,
     },
     /// A branch node whose children have already been flattened into [`RlpNode`]s.
-    Branch(BranchNode),
+    Branch {
+        /// The node itself, for use during RLP encoding.
+        node: BranchNode,
+        /// Bitmasks carried over from cached `BranchNodeCompact` values, if any.
+        masks: TrieMasks,
+    },
     /// A node whose type is not known, as it has already been converted to an [`RlpNode`].
     RlpNode(RlpNode),
 }
@@ -64,7 +69,7 @@ impl ProofTrieBranchChild {
                 ExtensionNodeRef::new(&short_key, child.as_slice()).encode(buf);
                 Ok((RlpNode::from_rlp(buf), None))
             }
-            Self::Branch(branch_node) => {
+            Self::Branch { node: branch_node, .. } => {
                 branch_node.encode(buf);
                 Ok((RlpNode::from_rlp(buf), Some(branch_node.stack)))
             }
@@ -98,8 +103,7 @@ impl ProofTrieBranchChild {
             Self::Extension { short_key, child } => {
                 (TrieNode::Extension(ExtensionNode { key: short_key, child }), TrieMasks::none())
             }
-            // TODO store trie masks on branch
-            Self::Branch(branch_node) => (TrieNode::Branch(branch_node), TrieMasks::none()),
+            Self::Branch { node, masks } => (TrieNode::Branch(node), masks),
             Self::RlpNode(_) => panic!("Cannot call `into_proof_trie_node` on RlpNode"),
         };
 
@@ -111,7 +115,7 @@ impl ProofTrieBranchChild {
     pub(crate) fn short_key(&self) -> &Nibbles {
         match self {
             Self::Leaf { short_key, .. } | Self::Extension { short_key, .. } => short_key,
-            Self::Branch(_) | Self::RlpNode(_) => {
+            Self::Branch { .. } | Self::RlpNode(_) => {
                 static EMPTY_NIBBLES: Nibbles = Nibbles::new();
                 &EMPTY_NIBBLES
             }
@@ -136,7 +140,7 @@ impl ProofTrieBranchChild {
             Self::Leaf { short_key, .. } | Self::Extension { short_key, .. } => {
                 *short_key = trim_nibbles_prefix(short_key, len);
             }
-            Self::Branch(_) | Self::RlpNode(_) => {
+            Self::Branch { .. } | Self::RlpNode(_) => {
                 panic!("Cannot call `trim_short_key_prefix` on Branch or RlpNode")
             }
         }
@@ -153,14 +157,8 @@ pub(crate) struct ProofTrieBranch {
     /// A mask tracking which child nibbles are set on the branch so far. There will be a single
     /// child on the stack for each set bit.
     pub(crate) state_mask: TrieMask,
-    /// A subset of `state_mask`. Each bit is set if the `state_mask` bit is set and:
-    /// - The child is a branch which is stored in the DB.
-    /// - The child is an extension whose child branch is stored in the DB.
-    #[expect(unused)]
-    pub(crate) tree_mask: TrieMask,
-    /// A subset of `state_mask`. Each bit is set if the hash for the child is cached in the DB.
-    #[expect(unused)]
-    pub(crate) hash_mask: TrieMask,
+    /// Bitmasks which are subsets of `state_mask`, sourced from cached branch nodes.
+    pub(crate) masks: TrieMasks,
 }
 
 /// Trims the first `len` nibbles from the head of the given `Nibbles`.
diff --git a/crates/trie/trie/src/proof_v2/value.rs b/crates/trie/trie/src/proof_v2/value.rs
index 9f5f97a271..b97e7579d4 100644
--- a/crates/trie/trie/src/proof_v2/value.rs
+++ b/crates/trie/trie/src/proof_v2/value.rs
@@ -7,7 +7,6 @@ use alloy_primitives::{B256, U256};
 use alloy_rlp::Encodable;
 use reth_execution_errors::trie::StateProofError;
 use reth_primitives_traits::Account;
-use reth_trie_common::Nibbles;
 use std::rc::Rc;
 
 /// A trait for deferred RLP-encoding of leaf values.
@@ -124,7 +123,7 @@ where
-        // Compute storage root by calling storage_proof with the root path as a target.
+        // Compute storage root by calling storage_proof with the zero key as a target.
         // This returns just the root node of the storage trie.
         let storage_root = storage_proof_calculator
-            .storage_proof(self.hashed_address, [Nibbles::new()])
+            .storage_proof(self.hashed_address, [B256::ZERO])
             .map(|nodes| {
                 // Encode the root node to RLP and hash it
                 let root_node =