diff --git a/crates/trie/sparse/src/parallel.rs b/crates/trie/sparse/src/parallel.rs index 765e8b074b..f7028fdaf8 100644 --- a/crates/trie/sparse/src/parallel.rs +++ b/crates/trie/sparse/src/parallel.rs @@ -1,10 +1,8 @@ #[cfg(feature = "trie-debug")] use crate::debug_recorder::{LeafUpdateRecord, ProofTrieNodeRecord, RecordedOp, TrieDebugRecorder}; use crate::{ - lower::LowerSparseSubtrie, - provider::{RevealedNode, TrieNodeProvider}, - LeafLookup, LeafLookupError, RlpNodeStackItem, SparseNode, SparseNodeState, SparseNodeType, - SparseTrie, SparseTrieUpdates, + lower::LowerSparseSubtrie, provider::TrieNodeProvider, LeafLookup, LeafLookupError, + RlpNodeStackItem, SparseNode, SparseNodeState, SparseNodeType, SparseTrie, SparseTrieUpdates, }; use alloc::{borrow::Cow, boxed::Box, vec, vec::Vec}; use alloy_primitives::{ @@ -20,10 +18,10 @@ use reth_primitives_traits::FastInstant as Instant; use reth_trie_common::{ prefix_set::{PrefixSet, PrefixSetMut}, BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, - ProofTrieNodeV2, RlpNode, TrieNode, TrieNodeV2, + ProofTrieNodeV2, RlpNode, TrieNodeV2, }; use smallvec::SmallVec; -use tracing::{debug, instrument, trace}; +use tracing::{instrument, trace}; /// The maximum length of a path, in nibbles, which belongs to the upper subtrie of a /// [`ParallelSparseTrie`]. All longer paths belong to a lower subtrie. @@ -67,7 +65,7 @@ pub struct ParallelismThresholds { /// ## Node Revealing /// /// The trie uses lazy loading to efficiently handle large state tries. Nodes can be: -/// - **Blind nodes**: Stored as hashes ([`SparseNode::Hash`]), representing unloaded trie parts +/// - **Blind nodes**: Stored as hashes on [`SparseNode::Branch::blinded_hashes`] /// - **Revealed nodes**: Fully loaded nodes (Branch, Extension, Leaf) with complete structure /// /// Note: An empty trie contains an `EmptyRoot` node at the root path, rather than no nodes at all. @@ -186,6 +184,16 @@ impl SparseTrie for ParallelSparseTrie { self.set_updates(retain_updates); + if let Some(masks) = masks { + let branch_path = if let TrieNodeV2::Branch(branch) = &root { + branch.key + } else { + Nibbles::default() + }; + + self.branch_node_masks.insert(branch_path, masks); + } + self.reveal_upper_node(Nibbles::default(), &root, masks) } @@ -249,6 +257,30 @@ impl SparseTrie for ParallelSparseTrie { let reachable_subtries = self.reachable_subtries(); + // For boundary nodes that are blinded in upper subtrie, unset the blinded bit and remember + // the hash to pass into `reveal_node`. + let hashes_from_upper = nodes + .iter() + .filter_map(|node| { + if node.path.len() == UPPER_TRIE_MAX_DEPTH && + reachable_subtries.get(path_subtrie_index_unchecked(&node.path)) && + let SparseNode::Branch { blinded_mask, blinded_hashes, .. } = self + .upper_subtrie + .nodes + .get_mut(&node.path.slice(0..UPPER_TRIE_MAX_DEPTH - 1)) + .unwrap() + { + let nibble = node.path.last().unwrap(); + blinded_mask.is_bit_set(nibble).then(|| { + blinded_mask.unset_bit(nibble); + (node.path, blinded_hashes[nibble as usize]) + }) + } else { + None + } + }) + .collect::>(); + if !self.is_reveal_parallelism_enabled(lower_nodes.len()) { for node in lower_nodes { let idx = path_subtrie_index_unchecked(&node.path); @@ -277,10 +309,12 @@ impl SparseTrie for ParallelSparseTrie { } self.lower_subtries[idx].reveal(&node.path); self.subtrie_heat.mark_modified(idx); - self.lower_subtries[idx] - .as_revealed_mut() - .expect("just revealed") - .reveal_node(node.path, &node.node, node.masks)?; + self.lower_subtries[idx].as_revealed_mut().expect("just revealed").reveal_node( + node.path, + &node.node, + node.masks, + hashes_from_upper.get(&node.path).copied(), + )?; } return Ok(()) } @@ -372,7 +406,12 @@ impl SparseTrie for ParallelSparseTrie { continue; } // Reveal each node in the subtrie, returning early on any errors - let res = subtrie.reveal_node(node.path, &node.node, node.masks); + let res = subtrie.reveal_node( + node.path, + &node.node, + node.masks, + hashes_from_upper.get(&node.path).copied(), + ); if res.is_err() { return (subtrie_idx, subtrie, res.map(|_| ())) } @@ -567,7 +606,7 @@ impl SparseTrie for ParallelSparseTrie { fn remove_leaf( &mut self, full_path: &Nibbles, - provider: P, + _provider: P, ) -> SparseTrieResult<()> { debug_assert_eq!( full_path.len(), @@ -623,8 +662,8 @@ impl SparseTrie for ParallelSparseTrie { match Self::find_next_to_leaf(&curr_path, curr_node, full_path) { FindNextToLeafOutcome::NotFound => return Ok(()), // leaf isn't in the trie - FindNextToLeafOutcome::BlindedNode(hash) => { - return Err(SparseTrieErrorKind::BlindedNode { path: curr_path, hash }.into()) + FindNextToLeafOutcome::BlindedNode { path, hash } => { + return Err(SparseTrieErrorKind::BlindedNode { path, hash }.into()) } FindNextToLeafOutcome::Found => { // this node is the target leaf @@ -662,7 +701,7 @@ impl SparseTrie for ParallelSparseTrie { ext_grandparent_path = Some(curr_path); ext_grandparent_node = Some(curr_node.clone()); } - SparseNode::Empty | SparseNode::Hash(_) | SparseNode::Leaf { .. } => { + SparseNode::Empty | SparseNode::Leaf { .. } => { unreachable!( "find_next_to_leaf only continues to a branch or extension" ) @@ -684,28 +723,28 @@ impl SparseTrie for ParallelSparseTrie { // Before mutating, check if branch collapse would require revealing a blinded node. // This ensures remove_leaf is atomic: if it errors, the trie is unchanged. - if let (Some(branch_path), Some(SparseNode::Branch { state_mask, .. })) = - (&branch_parent_path, &branch_parent_node) + if let ( + Some(branch_path), + Some(SparseNode::Branch { state_mask, blinded_mask, blinded_hashes, .. }), + ) = (&branch_parent_path, &branch_parent_node) { let mut check_mask = *state_mask; let child_nibble = leaf_path.get_unchecked(branch_path.len()); check_mask.unset_bit(child_nibble); if check_mask.count_bits() == 1 { - // Branch will collapse - check if remaining child needs revealing - let remaining_child_path = { - let mut p = *branch_path; - p.push_unchecked( - check_mask.first_set_bit_index().expect("state mask is not empty"), - ); - p - }; + let remaining_nibble = + check_mask.first_set_bit_index().expect("state mask is not empty"); - // Pre-validate the entire reveal chain (including extension grandchildren). - // This check mirrors the logic in `reveal_remaining_child_on_leaf_removal` with - // `recurse_into_extension: true` to ensure all nodes that would be revealed - // are accessible before any mutations occur. - self.pre_validate_reveal_chain(&remaining_child_path, &provider)?; + if blinded_mask.is_bit_set(remaining_nibble) { + let mut path = *branch_path; + path.push_unchecked(remaining_nibble); + return Err(SparseTrieErrorKind::BlindedNode { + path, + hash: blinded_hashes[remaining_nibble as usize], + } + .into()); + } } } @@ -734,7 +773,7 @@ impl SparseTrie for ParallelSparseTrie { SparseNode::Extension { state, .. } | SparseNode::Branch { state, .. } => { *state = SparseNodeState::Dirty } - SparseNode::Empty | SparseNode::Hash(_) | SparseNode::Leaf { .. } => { + SparseNode::Empty | SparseNode::Leaf { .. } => { unreachable!( "only branch and extension nodes can be marked dirty when removing a leaf" ) @@ -752,8 +791,10 @@ impl SparseTrie for ParallelSparseTrie { // If there is a parent branch node (very likely, unless the leaf is at the root) execute // any required changes for that node, relative to the removed leaf. - if let (Some(branch_path), &Some(SparseNode::Branch { mut state_mask, .. })) = - (&branch_parent_path, &branch_parent_node) + if let ( + Some(branch_path), + &Some(SparseNode::Branch { mut state_mask, blinded_mask, ref blinded_hashes, .. }), + ) = (&branch_parent_path, &branch_parent_node) { let child_nibble = leaf_path.get_unchecked(branch_path.len()); state_mask.unset_bit(child_nibble); @@ -761,13 +802,10 @@ impl SparseTrie for ParallelSparseTrie { let new_branch_node = if state_mask.count_bits() == 1 { // If only one child is left set in the branch node, we need to collapse it. Get // full path of the only child node left. - let remaining_child_path = { - let mut p = *branch_path; - p.push_unchecked( - state_mask.first_set_bit_index().expect("state mask is not empty"), - ); - p - }; + let remaining_child_nibble = + state_mask.first_set_bit_index().expect("state mask is not empty"); + let mut remaining_child_path = *branch_path; + remaining_child_path.push_unchecked(remaining_child_nibble); trace!( target: "trie::parallel_sparse", @@ -779,16 +817,24 @@ impl SparseTrie for ParallelSparseTrie { // If the remaining child node is not yet revealed then we have to reveal it here, // otherwise it's not possible to know how to collapse the branch. - let remaining_child_node = self.reveal_remaining_child_on_leaf_removal( - provider, - full_path, - &remaining_child_path, - )?; + if blinded_mask.is_bit_set(remaining_child_nibble) { + return Err(SparseTrieErrorKind::BlindedNode { + path: remaining_child_path, + hash: blinded_hashes[remaining_child_nibble as usize], + } + .into()); + } + + let remaining_child_node = self + .subtrie_for_path_mut(&remaining_child_path) + .nodes + .get(&remaining_child_path) + .unwrap(); let (new_branch_node, remove_child) = Self::branch_changes_on_leaf_removal( branch_path, &remaining_child_path, - &remaining_child_node, + remaining_child_node, ); if remove_child { @@ -809,7 +855,12 @@ impl SparseTrie for ParallelSparseTrie { } else { // If more than one child is left set in the branch, we just re-insert it with the // updated state_mask. - SparseNode::new_branch(state_mask) + SparseNode::Branch { + state_mask, + blinded_mask, + blinded_hashes: blinded_hashes.clone(), + state: SparseNodeState::Dirty, + } }; let branch_subtrie = self.subtrie_for_path_mut(branch_path); @@ -1050,8 +1101,8 @@ impl SparseTrie for ParallelSparseTrie { match Self::find_next_to_leaf(&curr_path, curr_node, full_path) { FindNextToLeafOutcome::NotFound => return Ok(LeafLookup::NonExistent), - FindNextToLeafOutcome::BlindedNode(hash) => { - return Err(LeafLookupError::BlindedNode { path: curr_path, hash }); + FindNextToLeafOutcome::BlindedNode { path, hash } => { + return Err(LeafLookupError::BlindedNode { path, hash }); } FindNextToLeafOutcome::Found => { panic!("target leaf {full_path:?} found at path {curr_path:?}, even though value wasn't in values hashmap"); @@ -1144,64 +1195,91 @@ impl SparseTrie for ParallelSparseTrie { continue; } - // Get children to visit from current node (immutable access) - let mut is_extension = false; - let children: SmallVec<[Nibbles; 16]> = { - let Some(subtrie) = self.subtrie_for_path(&path) else { continue }; - let Some(node) = subtrie.nodes.get(&path) else { continue }; + let Some(subtrie) = self.subtrie_for_path_mut_untracked(&path) else { continue }; + let Some(node) = subtrie.nodes.get_mut(&path) else { continue }; - match node { - SparseNode::Empty | SparseNode::Hash(_) | SparseNode::Leaf { .. } => { - SmallVec::new() - } - SparseNode::Extension { key, .. } => { + match node { + SparseNode::Empty | SparseNode::Leaf { .. } => {} + SparseNode::Extension { key, state, .. } => { + // For extension nodes at max depth, collapse both extension and its child + // branch to preserve invariant of all extension nodes children being revealed. + if depth == max_depth { + let Some(hash) = state.cached_hash() else { continue }; + subtrie.nodes.remove(&path); + + let parent_path = path.slice(0..path.len() - 1); + let SparseNode::Branch { blinded_mask, blinded_hashes, .. } = + subtrie.nodes.get_mut(&parent_path).unwrap() + else { + panic!("expected branch node at path {parent_path:?}"); + }; + + let nibble = path.last().unwrap(); + blinded_mask.set_bit(nibble); + blinded_hashes[nibble as usize] = hash; + + effective_pruned_roots.push(path); + } else { let mut child = path; child.extend(key); - is_extension = true; - SmallVec::from_slice(&[child]) - } - SparseNode::Branch { state_mask, .. } => { - let mut children = SmallVec::new(); - let mut mask = state_mask.get(); - while mask != 0 { - let nibble = mask.trailing_zeros() as u8; - mask &= mask - 1; - let mut child = path; - child.push_unchecked(nibble); - children.push(child); - } - children + stack.push((child, depth + 1)); } } - }; + SparseNode::Branch { state_mask, blinded_mask, blinded_hashes, .. } => { + // For branch nodes at max depth, collapse all children onto them, + if depth == max_depth { + let mut blinded_mask = *blinded_mask; + let mut blinded_hashes = blinded_hashes.clone(); + for nibble in state_mask.iter() { + if blinded_mask.is_bit_set(nibble) { + continue; + } + let mut child = path; + child.push_unchecked(nibble); - // Process children - either continue traversal or prune - for child in children { - if depth == max_depth { - let path_to_prune = if is_extension { - // If this is a child of extension node, we want to prune the extension node - // itself to preserve invariant of both extension and branch nodes being - // revealed. - path + let Entry::Occupied(entry) = self + .subtrie_for_path_mut_untracked(&child) + .unwrap() + .nodes + .entry(child) + else { + panic!("expected node at path {child:?}"); + }; + + let Some(hash) = entry.get().cached_hash() else { + continue; + }; + entry.remove(); + blinded_mask.set_bit(nibble); + blinded_hashes[nibble as usize] = hash; + effective_pruned_roots.push(child); + } + + let SparseNode::Branch { + blinded_mask: old_blinded_mask, + blinded_hashes: old_blinded_hashes, + .. + } = self + .subtrie_for_path_mut_untracked(&path) + .unwrap() + .nodes + .get_mut(&path) + .unwrap() + else { + unreachable!("expected branch node at path {path:?}"); + }; + *old_blinded_mask = blinded_mask; + *old_blinded_hashes = blinded_hashes; } else { - child - }; - // Check if child has a computed hash and replace inline - let hash = self - .subtrie_for_path(&path_to_prune) - .and_then(|s| s.nodes.get(&path_to_prune)) - .filter(|n| !n.is_hash()) - .and_then(|n| n.cached_hash()); - - if let Some(hash) = hash { - // Use untracked access to avoid marking subtrie as modified during pruning - if let Some(subtrie) = self.subtrie_for_path_mut_untracked(&path_to_prune) { - subtrie.nodes.insert(path_to_prune, SparseNode::Hash(hash)); - effective_pruned_roots.push(path_to_prune); + for nibble in state_mask.iter() { + if blinded_mask.is_bit_set(nibble) { + continue; + } + let mut child = path; + child.push_unchecked(nibble); + stack.push((child, depth + 1)); } } - } else { - stack.push((child, depth + 1)); } } } @@ -1230,9 +1308,9 @@ impl SparseTrie for ParallelSparseTrie { debug_assert!( { - let mut all_roots: Vec<_> = effective_pruned_roots.iter().collect(); + let mut all_roots: Vec<_> = effective_pruned_roots.clone(); all_roots.sort_unstable(); - all_roots.windows(2).all(|w| !w[1].starts_with(w[0])) + all_roots.windows(2).all(|w| !w[1].starts_with(&w[0])) }, "prune roots must be prefix-free" ); @@ -1560,7 +1638,6 @@ impl ParallelSparseTrie { // If empty node is found it means the subtrie doesn't have any nodes in it, let alone // the target leaf. SparseNode::Empty => FindNextToLeafOutcome::NotFound, - SparseNode::Hash(rlp_node) => FindNextToLeafOutcome::BlindedNode(*rlp_node), SparseNode::Leaf { key, .. } => { let mut found_full_path = *from_path; found_full_path.extend(key); @@ -1583,7 +1660,7 @@ impl ParallelSparseTrie { } FindNextToLeafOutcome::ContinueFrom(child_path) } - SparseNode::Branch { state_mask, .. } => { + SparseNode::Branch { state_mask, blinded_mask, blinded_hashes, .. } => { if leaf_full_path.len() == from_path.len() { return FindNextToLeafOutcome::NotFound } @@ -1596,6 +1673,13 @@ impl ParallelSparseTrie { let mut child_path = *from_path; child_path.push_unchecked(nibble); + if blinded_mask.is_bit_set(nibble) { + return FindNextToLeafOutcome::BlindedNode { + path: child_path, + hash: blinded_hashes[nibble as usize], + }; + } + FindNextToLeafOutcome::ContinueFrom(child_path) } } @@ -1692,7 +1776,7 @@ impl ParallelSparseTrie { // If we swap the branch node out either an extension or leaf, depending on // what its remaining child is. match remaining_child_node { - SparseNode::Empty | SparseNode::Hash(_) => { + SparseNode::Empty => { panic!("remaining child must have been revealed already") } // If the only child is a leaf node, we downgrade the branch node into a @@ -1740,7 +1824,7 @@ impl ParallelSparseTrie { // If the parent node is an extension node, we need to look at its child to see // if we need to merge it. match child { - SparseNode::Empty | SparseNode::Hash(_) => { + SparseNode::Empty => { panic!("child must be revealed") } // For a leaf node, we collapse the extension node into a leaf node, @@ -1765,112 +1849,6 @@ impl ParallelSparseTrie { } } - /// Pre-validates reveal chain accessibility before mutations. - /// - /// Walks the trie path checking that all nodes can be revealed. This is called before - /// any mutations to ensure the operation will succeed atomically. - /// - /// Returns `BlindedNode` error if any node in the chain cannot be revealed by the provider. - fn pre_validate_reveal_chain( - &self, - path: &Nibbles, - provider: &P, - ) -> SparseTrieResult<()> { - // Find the subtrie containing this path, or return Ok if path doesn't exist - let subtrie = match self.subtrie_for_path(path) { - Some(s) => s, - None => return Ok(()), - }; - - match subtrie.nodes.get(path) { - // Hash node: attempt to reveal from provider - Some(SparseNode::Hash(hash)) => match provider.trie_node(path)? { - Some(RevealedNode { node, .. }) => { - let decoded = TrieNode::decode(&mut &node[..])?; - // Extension nodes have children that also need validation - if let TrieNode::Extension(ext) = decoded { - let mut grandchild_path = *path; - grandchild_path.extend(&ext.key); - - return self.pre_validate_reveal_chain(&grandchild_path, provider); - } - - Ok(()) - } - // Provider cannot reveal this node - operation would fail - None => Err(SparseTrieErrorKind::BlindedNode { path: *path, hash: *hash }.into()), - }, - // Leaf, Extension, Branch, Empty, or missing: no further validation needed - _ => Ok(()), - } - } - - /// Called when a leaf is removed on a branch which has only one other remaining child. That - /// child must be revealed in order to properly collapse the branch. - /// - /// If `recurse_into_extension` is true, and the remaining child is an extension node, then its - /// child will be ensured to be revealed as well. - /// - /// ## Returns - /// - /// The node of the remaining child, whether it was already revealed or not. - fn reveal_remaining_child_on_leaf_removal( - &mut self, - provider: P, - full_path: &Nibbles, // only needed for logs - remaining_child_path: &Nibbles, - ) -> SparseTrieResult { - let remaining_child_subtrie = self.subtrie_for_path_mut(remaining_child_path); - - let (remaining_child_node, remaining_child_masks) = match remaining_child_subtrie - .nodes - .get(remaining_child_path) - .unwrap() - { - SparseNode::Hash(_) => { - debug!( - target: "trie::parallel_sparse", - child_path = ?remaining_child_path, - leaf_full_path = ?full_path, - "Node child not revealed in remove_leaf, falling back to db", - ); - if let Some(RevealedNode { node, tree_mask, hash_mask }) = - provider.trie_node(remaining_child_path)? - { - let decoded = TrieNodeV2::decode(&mut &node[..])?; - trace!( - target: "trie::parallel_sparse", - ?remaining_child_path, - ?decoded, - ?tree_mask, - ?hash_mask, - "Revealing remaining blinded branch child" - ); - let masks = BranchNodeMasks::from_optional(hash_mask, tree_mask); - remaining_child_subtrie.reveal_node(*remaining_child_path, &decoded, masks)?; - ( - remaining_child_subtrie.nodes.get(remaining_child_path).unwrap().clone(), - masks, - ) - } else { - return Err(SparseTrieErrorKind::NodeNotFoundInProvider { - path: *remaining_child_path, - } - .into()) - } - } - // The node is already revealed so we don't need to return its masks here, as they don't - // need to be inserted. - node => (node.clone(), None), - }; - - if let Some(masks) = remaining_child_masks { - self.branch_node_masks.insert(*remaining_child_path, masks); - } - - Ok(remaining_child_node) - } - /// Drains any [`SparseTrieUpdatesAction`]s from the given subtrie, and applies each action to /// the given `updates` set. If the given set is None then this is a no-op. #[instrument(level = "trace", target = "trie::parallel_sparse", skip_all)] @@ -2099,7 +2077,7 @@ impl ParallelSparseTrie { } // Exit early if the node was already revealed before. - if !self.upper_subtrie.reveal_node(path, node, masks)? { + if !self.upper_subtrie.reveal_node(path, node, masks, None)? { if let TrieNodeV2::Branch(branch) = node { if branch.key.is_empty() { return Ok(()); @@ -2144,9 +2122,20 @@ impl ParallelSparseTrie { for (stack_ptr, idx) in branch.state_mask.iter().enumerate() { let mut child_path = branch_path; child_path.push_unchecked(idx); - self.lower_subtrie_for_path_mut(&child_path) - .expect("child_path must have a lower subtrie") - .reveal_node_or_hash(child_path, &branch.stack[stack_ptr])?; + let child = &branch.stack[stack_ptr]; + + // Only reveal children that are not hashes. Hashes are stored on branch + // nodes directly. + if !child.is_hash() { + self.lower_subtrie_for_path_mut(&child_path) + .expect("child_path must have a lower subtrie") + .reveal_node( + child_path, + &TrieNodeV2::decode(&mut branch.stack[stack_ptr].as_ref())?, + None, + None, + )?; + } } } } @@ -2154,7 +2143,12 @@ impl ParallelSparseTrie { let mut child_path = path; child_path.extend(&ext.key); if let Some(subtrie) = self.lower_subtrie_for_path_mut(&child_path) { - subtrie.reveal_node_or_hash(child_path, &ext.child)?; + subtrie.reveal_node( + child_path, + &TrieNodeV2::decode(&mut ext.child.as_ref())?, + None, + None, + )?; } } TrieNodeV2::EmptyRoot | TrieNodeV2::Leaf(_) => (), @@ -2247,7 +2241,7 @@ impl ParallelSparseTrie { } current.extend(key); } - SparseNode::Hash(_) | SparseNode::Empty | SparseNode::Leaf { .. } => return false, + SparseNode::Empty | SparseNode::Leaf { .. } => return false, } } true @@ -2309,7 +2303,7 @@ impl ParallelSparseTrie { stack.push(next); } } - SparseNode::Hash(_) | SparseNode::Empty | SparseNode::Leaf { .. } => {} + SparseNode::Empty | SparseNode::Leaf { .. } => {} }; } @@ -2429,7 +2423,7 @@ enum FindNextToLeafOutcome { NotFound, /// `BlindedNode` indicates that the node is blinded with the contained hash and cannot be /// traversed. - BlindedNode(B256), + BlindedNode { path: Nibbles, hash: B256 }, } impl SparseSubtrie { @@ -2590,9 +2584,6 @@ impl SparseSubtrie { *node = SparseNode::new_leaf(path); Ok(LeafUpdateStep::complete_with_insertions(vec![current])) } - SparseNode::Hash(hash) => { - Err(SparseTrieErrorKind::BlindedNode { path: current, hash: *hash }.into()) - } SparseNode::Leaf { key: current_key, .. } => { current.extend(current_key); @@ -2671,7 +2662,7 @@ impl SparseSubtrie { Ok(LeafUpdateStep::continue_with(current)) } - SparseNode::Branch { state_mask, .. } => { + SparseNode::Branch { state_mask, blinded_mask, blinded_hashes, .. } => { let nibble = path.get_unchecked(current.len()); current.push_unchecked(nibble); if !state_mask.is_bit_set(nibble) { @@ -2681,6 +2672,14 @@ impl SparseSubtrie { return Ok(LeafUpdateStep::complete_with_insertions(vec![current])) } + if blinded_mask.is_bit_set(nibble) { + return Err(SparseTrieErrorKind::BlindedNode { + path: current, + hash: blinded_hashes[nibble as usize], + } + .into()); + } + // If the nibble is set, we can continue traversing the branch. Ok(LeafUpdateStep::continue_with(current)) } @@ -2697,25 +2696,9 @@ impl SparseSubtrie { rlp_node: Option, ) -> SparseTrieResult<()> { match self.nodes.entry(path) { - Entry::Occupied(mut entry) => { - match entry.get() { - SparseNode::Hash(hash) => { - // Replace a hash node with a fully revealed branch node. - entry.insert(SparseNode::Branch { - state_mask, - // Memoize the hash of a previously blinded node in a new branch - // node. - state: SparseNodeState::Cached { - rlp_node: RlpNode::word_rlp(hash), - store_in_db_trie: Some(masks.is_some_and(|m| { - !m.hash_mask.is_empty() || !m.tree_mask.is_empty() - })), - }, - }); - } - // Branch already revealed, do nothing - _ => return Ok(()), - } + Entry::Occupied(_) => { + // Branch already revealed, do nothing + return Ok(()); } Entry::Vacant(entry) => { let state = @@ -2728,7 +2711,27 @@ impl SparseSubtrie { }, None => SparseNodeState::Dirty, }; - entry.insert(SparseNode::Branch { state_mask, state }); + + let mut blinded_mask = TrieMask::default(); + let mut blinded_hashes = Box::new([B256::ZERO; 16]); + + for (stack_ptr, idx) in state_mask.iter().enumerate() { + let mut child_path = path; + child_path.push_unchecked(idx); + let child = &children[stack_ptr]; + + if let Some(hash) = child.as_hash() { + blinded_mask.set_bit(idx); + blinded_hashes[idx as usize] = hash; + } + } + + entry.insert(SparseNode::Branch { + state_mask, + state, + blinded_mask, + blinded_hashes, + }); } } @@ -2737,10 +2740,16 @@ impl SparseSubtrie { for (stack_ptr, idx) in state_mask.iter().enumerate() { let mut child_path = path; child_path.push_unchecked(idx); - if Self::is_child_same_level(&path, &child_path) { + let child = &children[stack_ptr]; + if !child.is_hash() && Self::is_child_same_level(&path, &child_path) { // Reveal each child node or hash it has, but only if the child is on // the same level as the parent. - self.reveal_node_or_hash(child_path, &children[stack_ptr])?; + self.reveal_node( + child_path, + &TrieNodeV2::decode(&mut child.as_ref())?, + None, + None, + )?; } } @@ -2748,42 +2757,46 @@ impl SparseSubtrie { } /// Internal implementation of the method of the same name on `ParallelSparseTrie`. + /// + /// This accepts `hash_from_upper` to handle cases when boundary nodes revealed in lower subtrie + /// but its blinded hash is known from the upper subtrie. fn reveal_node( &mut self, path: Nibbles, node: &TrieNodeV2, masks: Option, + hash_from_upper: Option, ) -> SparseTrieResult { debug_assert!(path.starts_with(&self.path)); - let mut skip_extension_reveal = false; - - // If the node is already revealed and it's not a hash node, do nothing. - if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) { - // Make sure that we reveal branch nodes properly even when extension was already - // revealed. - if let TrieNodeV2::Branch(branch) = node && - !branch.key.is_empty() - { - let mut branch_path = path; - branch_path.extend(&branch.key); - - // If branch does not belong to this subtrie, we can exit. - if !Self::is_child_same_level(&path, &branch_path) { - return Ok(false); - } - - // If branch is already revealed, we can exit. - if self.nodes.get(&branch_path).is_some_and(|node| !node.is_hash()) { - return Ok(false); - } - - skip_extension_reveal = true; - } else { - return Ok(false) - } + // If the node is already revealed, do nothing. + if self.nodes.contains_key(&path) { + return Ok(false); } + // If the hash is provided from the upper subtrie, use it. Otherwise, find the parent branch + // node, unset its blinded bit and use the hash. + let hash = if let Some(hash) = hash_from_upper { + Some(hash) + } else if path.len() != UPPER_TRIE_MAX_DEPTH && !path.is_empty() { + let Some(SparseNode::Branch { state_mask, blinded_mask, blinded_hashes, .. }) = + self.nodes.get_mut(&path.slice(0..path.len() - 1)) + else { + return Ok(false); + }; + let nibble = path.last().unwrap(); + if !state_mask.is_bit_set(nibble) { + return Ok(false); + } + + blinded_mask.is_bit_set(nibble).then(|| { + blinded_mask.unset_bit(nibble); + blinded_hashes[nibble as usize] + }) + } else { + None + }; + trace!( target: "trie::parallel_sparse", ?path, @@ -2806,38 +2819,29 @@ impl SparseSubtrie { branch.state_mask, &branch.stack, masks, - branch.branch_rlp_node.clone(), + hash.as_ref().map(RlpNode::word_rlp), )?; return Ok(true); } - if !skip_extension_reveal { - match self.nodes.entry(path) { - Entry::Occupied(mut entry) => match entry.get() { - SparseNode::Hash(hash) => { - // Replace a hash node with a revealed extension node. - entry.insert(SparseNode::Extension { - key: branch.key, - state: SparseNodeState::Cached { - // Memoize the hash of a previously blinded node in a new - // extension node. - rlp_node: RlpNode::word_rlp(hash), - // Inherit `store_in_db_trie` from the child branch - // node masks so that the memoized hash can be used - // without needing to fetch the child branch. - store_in_db_trie: Some(masks.is_some_and(|m| { - !m.hash_mask.is_empty() || !m.tree_mask.is_empty() - })), - }, - }); - } - _ => unreachable!("checked that node is either a hash or non-existent"), - }, - Entry::Vacant(entry) => { - entry.insert(SparseNode::new_ext(branch.key)); - } - } - } + self.nodes.insert( + path, + SparseNode::Extension { + key: branch.key, + state: hash + .as_ref() + .map(|hash| SparseNodeState::Cached { + rlp_node: RlpNode::word_rlp(hash), + // Inherit `store_in_db_trie` from the child branch + // node masks so that the memoized hash can be used + // without needing to fetch the child branch. + store_in_db_trie: Some(masks.is_some_and(|m| { + !m.hash_mask.is_empty() || !m.tree_mask.is_empty() + })), + }) + .unwrap_or(SparseNodeState::Dirty), + }, + ); let mut branch_path = path; branch_path.extend(&branch.key); @@ -2856,36 +2860,7 @@ impl SparseSubtrie { branch.branch_rlp_node.clone(), )?; } - TrieNodeV2::Extension(ext) => match self.nodes.entry(path) { - Entry::Occupied(mut entry) => match entry.get() { - // Replace a hash node with a revealed extension node. - SparseNode::Hash(hash) => { - let mut child_path = *entry.key(); - child_path.extend(&ext.key); - entry.insert(SparseNode::Extension { - key: ext.key, - state: SparseNodeState::Cached { - // Memoize the hash of a previously blinded node in the new - // extension node. - rlp_node: RlpNode::word_rlp(hash), - store_in_db_trie: None, - }, - }); - if Self::is_child_same_level(&path, &child_path) { - self.reveal_node_or_hash(child_path, &ext.child)?; - } - } - _ => unreachable!("checked that node is either a hash or non-existent"), - }, - Entry::Vacant(entry) => { - let mut child_path = *entry.key(); - child_path.extend(&ext.key); - entry.insert(SparseNode::new_ext(ext.key)); - if Self::is_child_same_level(&path, &child_path) { - self.reveal_node_or_hash(child_path, &ext.child)?; - } - } - }, + TrieNodeV2::Extension(_) => unreachable!(), TrieNodeV2::Leaf(leaf) => { // Skip the reachability check when path.len() == UPPER_TRIE_MAX_DEPTH because // at that boundary the leaf is in the lower subtrie but its parent branch is in @@ -2919,77 +2894,25 @@ impl SparseSubtrie { } } - match self.nodes.entry(path) { - Entry::Occupied(mut entry) => match entry.get() { - // Replace a hash node with a revealed leaf node and store leaf node value. - SparseNode::Hash(hash) => { - entry.insert(SparseNode::Leaf { - key: leaf.key, - state: SparseNodeState::Cached { - // Memoize the hash of a previously blinded node in the new leaf - // node. - rlp_node: RlpNode::word_rlp(hash), - store_in_db_trie: Some(false), - }, - }); - } - _ => unreachable!("checked that node is either a hash or non-existent"), + self.nodes.insert( + path, + SparseNode::Leaf { + key: leaf.key, + state: hash + .as_ref() + .map(|hash| SparseNodeState::Cached { + rlp_node: RlpNode::word_rlp(hash), + store_in_db_trie: Some(false), + }) + .unwrap_or(SparseNodeState::Dirty), }, - Entry::Vacant(entry) => { - entry.insert(SparseNode::new_leaf(leaf.key)); - } - } + ); } } Ok(true) } - /// Reveals either a node or its hash placeholder based on the provided child data. - /// - /// When traversing the trie, we often encounter references to child nodes that - /// are either directly embedded or represented by their hash. This method - /// handles both cases: - /// - /// 1. If the child data represents a hash (32+1=33 bytes), store it as a hash node - /// 2. Otherwise, decode the data as a [`TrieNode`] and recursively reveal it using - /// `reveal_node` - /// - /// # Returns - /// - /// Returns `Ok(())` if successful, or an error if the node cannot be revealed. - /// - /// # Error Handling - /// - /// Will error if there's a conflict between a new hash node and an existing one - /// at the same path - fn reveal_node_or_hash(&mut self, path: Nibbles, child: &[u8]) -> SparseTrieResult<()> { - if child.len() == B256::len_bytes() + 1 { - let hash = B256::from_slice(&child[1..]); - match self.nodes.entry(path) { - Entry::Occupied(entry) => match entry.get() { - // Hash node with a different hash can't be handled. - SparseNode::Hash(previous_hash) if previous_hash != &hash => { - return Err(SparseTrieErrorKind::Reveal { - path: *entry.key(), - node: Box::new(SparseNode::Hash(hash)), - } - .into()) - } - _ => {} - }, - Entry::Vacant(entry) => { - entry.insert(SparseNode::Hash(hash)); - } - } - return Ok(()) - } - - self.reveal_node(path, &TrieNodeV2::decode(&mut &child[..])?, None)?; - - Ok(()) - } - /// Recalculates and updates the RLP hashes for the changed nodes in this subtrie. /// /// The function starts from the subtrie root, traverses down to leaves, and then calculates @@ -3157,10 +3080,6 @@ impl SparseSubtrieInner { let (rlp_node, node_type) = match node { SparseNode::Empty => (RlpNode::word_rlp(&EMPTY_ROOT_HASH), SparseNodeType::Empty), - SparseNode::Hash(hash) => { - // Return pre-computed hash of a blinded node immediately - (RlpNode::word_rlp(hash), SparseNodeType::Hash) - } SparseNode::Leaf { key, state } => { let mut path = path; path.extend(key); @@ -3256,7 +3175,7 @@ impl SparseSubtrieInner { return } } - SparseNode::Branch { state_mask, state } => { + SparseNode::Branch { state_mask, state, blinded_mask, blinded_hashes } => { if let Some((rlp_node, store_in_db_trie)) = state .cached_rlp_node() .zip(state.store_in_db_trie()) @@ -3291,13 +3210,13 @@ impl SparseSubtrieInner { for bit in state_mask.iter().rev() { let mut child = path; child.push_unchecked(bit); - self.buffers.branch_child_buf.push(child); + + if !blinded_mask.is_bit_set(bit) { + self.buffers.branch_child_buf.push(child); + } } - self.buffers - .branch_value_stack_buf - .resize(self.buffers.branch_child_buf.len(), Default::default()); - let mut added_children = false; + self.buffers.branch_value_stack_buf.resize(state_mask.len(), Default::default()); let mut tree_mask = TrieMask::default(); let mut hash_mask = TrieMask::default(); @@ -3308,62 +3227,28 @@ impl SparseSubtrieInner { let mut path_masks = || *path_masks_storage.get_or_insert_with(|| branch_node_masks.get(&path)); - for (i, child_path) in self.buffers.branch_child_buf.iter().enumerate() { - if self.buffers.rlp_node_stack.last().is_some_and(|e| &e.path == child_path) { - let RlpNodeStackItem { - path: _, - rlp_node: child, - node_type: child_node_type, - } = self.buffers.rlp_node_stack.pop().unwrap(); + for (i, child_nibble) in state_mask.iter().enumerate().rev() { + let mut child_path = path; + child_path.push_unchecked(child_nibble); - // Update the masks only if we need to retain trie updates - if retain_updates { - // SAFETY: it's a child, so it's never empty - let last_child_nibble = child_path.last().unwrap(); + let (child, child_node_type) = if blinded_mask.is_bit_set(child_nibble) { + ( + RlpNode::word_rlp(&blinded_hashes[child_nibble as usize]), + SparseNodeType::Hash, + ) + } else if self + .buffers + .rlp_node_stack + .last() + .is_some_and(|e| e.path == child_path) + { + let RlpNodeStackItem { path: _, rlp_node, node_type } = + self.buffers.rlp_node_stack.pop().unwrap(); - // Determine whether we need to set trie mask bit. - let should_set_tree_mask_bit = if let Some(store_in_db_trie) = - child_node_type.store_in_db_trie() - { - // A branch or an extension node explicitly set the - // `store_in_db_trie` flag - store_in_db_trie - } else { - // A blinded node has the tree mask bit set - child_node_type.is_hash() && - path_masks().is_some_and(|masks| { - masks.tree_mask.is_bit_set(last_child_nibble) - }) - }; - if should_set_tree_mask_bit { - tree_mask.set_bit(last_child_nibble); - } - // Set the hash mask. If a child node is a revealed branch node OR - // is a blinded node that has its hash mask bit set according to the - // database, set the hash mask bit and save the hash. - let hash = child.as_hash().filter(|_| { - child_node_type.is_branch() || - (child_node_type.is_hash() && - path_masks().is_some_and(|masks| { - masks.hash_mask.is_bit_set(last_child_nibble) - })) - }); - if let Some(hash) = hash { - hash_mask.set_bit(last_child_nibble); - hashes.push(hash); - } - } - - // Insert children in the resulting buffer in a normal order, - // because initially we iterated in reverse. - // SAFETY: i < len and len is never 0 - let original_idx = self.buffers.branch_child_buf.len() - i - 1; - self.buffers.branch_value_stack_buf[original_idx] = child; - added_children = true; + (rlp_node, node_type) } else { // Need to defer processing until children are computed, on the next // invocation update the node's hash. - debug_assert!(!added_children); self.buffers.path_stack.push(RlpNodePathStackItem { path, is_in_prefix_set: Some(prefix_set_contains(&path)), @@ -3375,7 +3260,46 @@ impl SparseSubtrieInner { .map(|path| RlpNodePathStackItem { path, is_in_prefix_set: None }), ); return + }; + + // Update the masks only if we need to retain trie updates + if retain_updates { + // Determine whether we need to set trie mask bit. + let should_set_tree_mask_bit = + if let Some(store_in_db_trie) = child_node_type.store_in_db_trie() { + // A branch or an extension node explicitly set the + // `store_in_db_trie` flag + store_in_db_trie + } else { + // A blinded node has the tree mask bit set + child_node_type.is_hash() && + path_masks().is_some_and(|masks| { + masks.tree_mask.is_bit_set(child_nibble) + }) + }; + if should_set_tree_mask_bit { + tree_mask.set_bit(child_nibble); + } + // Set the hash mask. If a child node is a revealed branch node OR + // is a blinded node that has its hash mask bit set according to the + // database, set the hash mask bit and save the hash. + let hash = child.as_hash().filter(|_| { + child_node_type.is_branch() || + (child_node_type.is_hash() && + path_masks().is_some_and(|masks| { + masks.hash_mask.is_bit_set(child_nibble) + })) + }); + if let Some(hash) = hash { + hash_mask.set_bit(child_nibble); + hashes.push(hash); + } } + + // Insert children in the resulting buffer in a normal order, + // because initially we iterated in reverse. + // SAFETY: i < len and len is never 0 + self.buffers.branch_value_stack_buf[i] = child; } trace!( @@ -3684,22 +3608,22 @@ mod tests { }; use crate::{ parallel::ChangedSubtrie, - provider::{DefaultTrieNodeProvider, RevealedNode, TrieNodeProvider}, + provider::{DefaultTrieNodeProvider, NoRevealProvider}, trie::SparseNodeState, LeafLookup, LeafLookupError, SparseNode, SparseTrie, SparseTrieUpdates, }; use alloy_primitives::{ b256, hex, - map::{B256Set, DefaultHashBuilder, HashMap}, + map::{B256Set, HashMap}, B256, U256, }; use alloy_rlp::{Decodable, Encodable}; - use alloy_trie::{BranchNodeCompact, Nibbles}; + use alloy_trie::{proof::AddedRemovedKeys, BranchNodeCompact, Nibbles}; use assert_matches::assert_matches; use itertools::Itertools; use proptest::{prelude::*, sample::SizeRange}; use proptest_arbitrary_interop::arb; - use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind}; + use reth_execution_errors::SparseTrieErrorKind; use reth_primitives_traits::Account; use reth_provider::{test_utils::create_test_provider_factory, TrieWriter}; use reth_trie::{ @@ -3713,8 +3637,9 @@ mod tests { prefix_set::PrefixSetMut, proof::{ProofNodes, ProofRetainer}, updates::TrieUpdates, - BranchNodeMasks, BranchNodeMasksMap, BranchNodeV2, ExtensionNode, HashBuilder, LeafNode, - ProofTrieNodeV2, RlpNode, TrieMask, TrieNodeV2, EMPTY_ROOT_HASH, + BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, BranchNodeV2, ExtensionNode, + HashBuilder, LeafNode, ProofTrieNodeV2, RlpNode, TrieMask, TrieNode, TrieNodeV2, + EMPTY_ROOT_HASH, }; use reth_trie_db::DatabaseTrieCursorFactory; use std::collections::{BTreeMap, BTreeSet}; @@ -3737,34 +3662,6 @@ mod tests { nibbles } - /// Mock trie node provider for testing that allows pre-setting nodes at specific paths. - /// - /// This provider can be used in tests to simulate trie nodes that need to be revealed - /// during trie operations, particularly when collapsing branch nodes during leaf removal. - #[derive(Debug, Clone)] - struct MockTrieNodeProvider { - /// Mapping from path to revealed node data - nodes: HashMap, - } - - impl MockTrieNodeProvider { - /// Creates a new empty mock provider - fn new() -> Self { - Self { nodes: HashMap::default() } - } - - /// Adds a revealed node at the specified path - fn add_revealed_node(&mut self, path: Nibbles, node: RevealedNode) { - self.nodes.insert(path, node); - } - } - - impl TrieNodeProvider for MockTrieNodeProvider { - fn trie_node(&self, path: &Nibbles) -> Result, SparseTrieError> { - Ok(self.nodes.get(path).cloned()) - } - } - fn create_account(nonce: u64) -> Account { Account { nonce, ..Default::default() } } @@ -3974,14 +3871,8 @@ mod tests { )) } - fn create_extension_node(key: impl AsRef<[u8]>, child_hash: B256) -> TrieNodeV2 { - TrieNodeV2::Extension(ExtensionNode::new( - Nibbles::from_nibbles(key), - RlpNode::word_rlp(&child_hash), - )) - } - - fn create_branch_node_with_children( + fn create_branch_node( + key: Nibbles, children_indices: &[u8], child_hashes: impl IntoIterator, ) -> TrieNodeV2 { @@ -3993,7 +3884,20 @@ mod tests { stack.push(hash); } - TrieNodeV2::Branch(BranchNodeV2::new(Nibbles::default(), stack, state_mask, None)) + let branch_rlp_node = if key.is_empty() { + None + } else { + Some(RlpNode::from_rlp(&alloy_rlp::encode(BranchNodeRef::new(&stack, state_mask)))) + }; + + TrieNodeV2::Branch(BranchNodeV2::new(key, stack, state_mask, branch_rlp_node)) + } + + fn create_branch_node_with_children( + children_indices: &[u8], + child_hashes: impl IntoIterator, + ) -> TrieNodeV2 { + create_branch_node(Nibbles::default(), children_indices, child_hashes) } /// Calculate the state root by feeding the provided state to the hash builder and retaining the @@ -4011,7 +3915,9 @@ mod tests { let mut hash_builder = HashBuilder::default() .with_updates(true) - .with_proof_retainer(ProofRetainer::from_iter(proof_targets)); + .with_proof_retainer(ProofRetainer::from_iter(proof_targets).with_added_removed_keys( + Some(AddedRemovedKeys::default().with_assume_added(true)), + )); let mut prefix_set = PrefixSetMut::default(); prefix_set.extend_keys(state.clone().into_iter().map(|(nibbles, _)| nibbles)); @@ -4140,7 +4046,7 @@ mod tests { SparseNode::Leaf { key: sparse_key, .. }, ) => proof_key == sparse_key, // Empty and hash nodes are specific to the sparse trie, skip them - (_, SparseNode::Empty | SparseNode::Hash(_)) => continue, + (_, SparseNode::Empty) => continue, _ => false, }; assert!( @@ -4384,81 +4290,6 @@ mod tests { } } - #[test] - fn test_reveal_node_extension_all_upper() { - let path = Nibbles::new(); - let child_hash = B256::repeat_byte(0xab); - let node = create_extension_node([0x1], child_hash); - let masks = None; - let trie = ParallelSparseTrie::from_root(node, masks, true).unwrap(); - - assert_matches!( - trie.upper_subtrie.nodes.get(&path), - Some(SparseNode::Extension { key, state: SparseNodeState::Dirty, .. }) - if key == &Nibbles::from_nibbles([0x1]) - ); - - // Child path should be in upper trie - let child_path = Nibbles::from_nibbles([0x1]); - assert_eq!(trie.upper_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash))); - } - - #[test] - fn test_reveal_node_extension_cross_level() { - let path = Nibbles::new(); - let child_hash = B256::repeat_byte(0xcd); - let node = create_extension_node([0x1, 0x2, 0x3], child_hash); - let masks = None; - let trie = ParallelSparseTrie::from_root(node, masks, true).unwrap(); - - // Extension node should be in upper trie - assert_matches!( - trie.upper_subtrie.nodes.get(&path), - Some(SparseNode::Extension { key, state: SparseNodeState::Dirty, .. }) - if key == &Nibbles::from_nibbles([0x1, 0x2, 0x3]) - ); - - // Child path (0x1, 0x2, 0x3) should be in lower trie - let child_path = Nibbles::from_nibbles([0x1, 0x2, 0x3]); - let idx = path_subtrie_index_unchecked(&child_path); - assert!(trie.lower_subtries[idx].as_revealed_ref().is_some()); - - let lower_subtrie = trie.lower_subtries[idx].as_revealed_ref().unwrap(); - assert_eq!(lower_subtrie.path, child_path); - assert_eq!(lower_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash))); - } - - #[test] - fn test_reveal_node_extension_cross_level_boundary() { - // Set up root branch with nibble 0x1 so path [0x1] is reachable. - let root_branch = - create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]); - let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap(); - - let path = Nibbles::from_nibbles([0x1]); - let child_hash = B256::repeat_byte(0xcd); - let node = create_extension_node([0x2], child_hash); - let masks = None; - - trie.reveal_nodes(&mut [ProofTrieNodeV2 { path, node, masks }]).unwrap(); - - // Extension node should be in upper trie, hash is memoized from the previous Hash node - assert_matches!( - trie.upper_subtrie.nodes.get(&path), - Some(SparseNode::Extension { key, state: SparseNodeState::Cached { .. }, .. }) - if key == &Nibbles::from_nibbles([0x2]) - ); - - // Child path (0x1, 0x2) should be in lower trie - let child_path = Nibbles::from_nibbles([0x1, 0x2]); - let idx = path_subtrie_index_unchecked(&child_path); - assert!(trie.lower_subtries[idx].as_revealed_ref().is_some()); - - let lower_subtrie = trie.lower_subtries[idx].as_revealed_ref().unwrap(); - assert_eq!(lower_subtrie.path, child_path); - assert_eq!(lower_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash))); - } - #[test] fn test_reveal_node_branch_all_upper() { let path = Nibbles::new(); @@ -4471,23 +4302,19 @@ mod tests { let trie = ParallelSparseTrie::from_root(node, masks, true).unwrap(); // Branch node should be in upper trie - assert_matches!( - trie.upper_subtrie.nodes.get(&path), - Some(SparseNode::Branch { state_mask, state: SparseNodeState::Dirty, .. }) - if *state_mask == 0b0000000000100001.into() + assert_eq!( + trie.upper_subtrie.nodes.get(&path).unwrap(), + &SparseNode::new_branch( + 0b0000000000100001.into(), + &[(0, child_hashes[0].as_hash().unwrap()), (5, child_hashes[1].as_hash().unwrap())] + ) ); - // Children should be in upper trie (paths of length 2) + // Children should not be revealed yet let child_path_0 = Nibbles::from_nibbles([0x0]); let child_path_5 = Nibbles::from_nibbles([0x5]); - assert_eq!( - trie.upper_subtrie.nodes.get(&child_path_0), - Some(&SparseNode::Hash(child_hashes[0].as_hash().unwrap())) - ); - assert_eq!( - trie.upper_subtrie.nodes.get(&child_path_5), - Some(&SparseNode::Hash(child_hashes[1].as_hash().unwrap())) - ); + assert!(!trie.upper_subtrie.nodes.contains_key(&child_path_0)); + assert!(!trie.upper_subtrie.nodes.contains_key(&child_path_5)); } #[test] @@ -4509,10 +4336,20 @@ mod tests { trie.reveal_nodes(&mut [ProofTrieNodeV2 { path, node, masks }]).unwrap(); // Branch node should be in upper trie, hash is memoized from the previous Hash node - assert_matches!( - trie.upper_subtrie.nodes.get(&path), - Some(SparseNode::Branch { state_mask, state: SparseNodeState::Cached { .. }, .. }) - if *state_mask == 0b1000000010000001.into() + assert_eq!( + trie.upper_subtrie.nodes.get(&path).unwrap(), + &SparseNode::new_branch( + 0b1000000010000001.into(), + &[ + (0x0, child_hashes[0].as_hash().unwrap()), + (0x7, child_hashes[1].as_hash().unwrap()), + (0xf, child_hashes[2].as_hash().unwrap()) + ] + ) + .with_state(SparseNodeState::Cached { + rlp_node: RlpNode::word_rlp(&B256::repeat_byte(0xAA)), + store_in_db_trie: Some(false), + }) ); // All children should be in lower tries since they have paths of length 3 @@ -4522,13 +4359,41 @@ mod tests { Nibbles::from_nibbles([0x1, 0xf]), ]; + let mut children = child_paths + .iter() + .map(|path| ProofTrieNodeV2 { + path: *path, + node: create_leaf_node([0x0], 1), + masks: None, + }) + .collect::>(); + + trie.reveal_nodes(&mut children).unwrap(); + + // Branch node should still be in upper trie but without any blinded children + assert_matches!( + trie.upper_subtrie.nodes.get(&path), + Some(&SparseNode::Branch { + state_mask, + state: SparseNodeState::Cached { ref rlp_node, store_in_db_trie: Some(false) }, + blinded_mask, + .. + }) if state_mask == 0b1000000010000001.into() && blinded_mask.is_empty() && *rlp_node == RlpNode::word_rlp(&B256::repeat_byte(0xAA)) + ); + for (i, child_path) in child_paths.iter().enumerate() { let idx = path_subtrie_index_unchecked(child_path); let lower_subtrie = trie.lower_subtries[idx].as_revealed_ref().unwrap(); assert_eq!(&lower_subtrie.path, child_path); assert_eq!( lower_subtrie.nodes.get(child_path), - Some(&SparseNode::Hash(child_hashes[i].as_hash().unwrap())), + Some(&SparseNode::Leaf { + key: Nibbles::from_nibbles([0x0]), + state: SparseNodeState::Cached { + rlp_node: child_hashes[i].clone(), + store_in_db_trie: Some(false) + } + }) ); } } @@ -4577,7 +4442,7 @@ mod tests { // Insert leaf_3 via update_leaf. This modifies the branch at [0x0] to add child // 0x2 and creates a fresh leaf node with hash: None in the lower subtrie. - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; trie.update_leaf(leaf_3_full_path, encode_account_value(3), provider).unwrap(); // Calculate subtrie indexes @@ -4642,8 +4507,10 @@ mod tests { let leaf_3 = create_leaf_node(leaf_3_key.to_vec(), account_3.nonce); // Create bottom branch node + let extension_path = Nibbles::from_nibbles([0, 0, 0]); let branch_1_path = Nibbles::from_nibbles([0, 0, 0, 0]); - let branch_1 = create_branch_node_with_children( + let branch_1 = create_branch_node( + Nibbles::from_nibbles([0]), &[0, 1], vec![ RlpNode::from_rlp(&alloy_rlp::encode(&leaf_1)), @@ -4651,31 +4518,22 @@ mod tests { ], ); - // Create an extension node - let extension_path = Nibbles::from_nibbles([0, 0, 0]); - let extension_key = Nibbles::from_nibbles([0]); - let extension = create_extension_node( - extension_key.to_vec(), - RlpNode::from_rlp(&alloy_rlp::encode(&branch_1)).as_hash().unwrap(), - ); - // Create top branch node let branch_2_path = Nibbles::from_nibbles([0, 0]); let branch_2 = create_branch_node_with_children( &[0, 1], vec![ - RlpNode::from_rlp(&alloy_rlp::encode(&extension)), + RlpNode::from_rlp(&alloy_rlp::encode(&branch_1)), RlpNode::from_rlp(&alloy_rlp::encode(&leaf_3)), ], ); // Reveal nodes - subtrie.reveal_node(branch_2_path, &branch_2, None).unwrap(); - subtrie.reveal_node(leaf_1_path, &leaf_1, None).unwrap(); - subtrie.reveal_node(extension_path, &extension, None).unwrap(); - subtrie.reveal_node(branch_1_path, &branch_1, None).unwrap(); - subtrie.reveal_node(leaf_2_path, &leaf_2, None).unwrap(); - subtrie.reveal_node(leaf_3_path, &leaf_3, None).unwrap(); + subtrie.reveal_node(branch_2_path, &branch_2, None, None).unwrap(); + subtrie.reveal_node(extension_path, &branch_1, None, None).unwrap(); + subtrie.reveal_node(leaf_1_path, &leaf_1, None, None).unwrap(); + subtrie.reveal_node(leaf_2_path, &leaf_2, None, None).unwrap(); + subtrie.reveal_node(leaf_3_path, &leaf_3, None, None).unwrap(); // Run hash builder for two leaf nodes let (_, _, proof_nodes, _, _) = run_hash_builder( @@ -4686,14 +4544,7 @@ mod tests { ], NoopAccountTrieCursor::default(), Default::default(), - [ - branch_1_path, - extension_path, - branch_2_path, - leaf_1_full_path, - leaf_2_full_path, - leaf_3_full_path, - ], + [extension_path, branch_2_path, leaf_1_full_path, leaf_2_full_path, leaf_3_full_path], ); // Update hashes for the subtrie @@ -4757,14 +4608,14 @@ mod tests { let mut trie = new_test_trie( [ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(TrieMask::new(0b1001))), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(TrieMask::new(0b1001), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_ext(Nibbles::from_nibbles([0x2, 0x3])), ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3]), - SparseNode::new_branch(TrieMask::new(0b0101)), + SparseNode::new_branch(TrieMask::new(0b0101), &[]), ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3, 0x1]), @@ -4779,7 +4630,7 @@ mod tests { .into_iter(), ); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove the leaf with a full path of 0x537 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x5, 0x3, 0x7])); @@ -4819,7 +4670,7 @@ mod tests { // let mut trie = new_test_trie( [ - (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b0011))), + (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b0011), &[])), (Nibbles::from_nibbles([0x0]), SparseNode::new_leaf(leaf_key([0x1, 0x2], 63))), (Nibbles::from_nibbles([0x1]), SparseNode::new_leaf(leaf_key([0x3, 0x4], 63))), ] @@ -4833,7 +4684,7 @@ mod tests { .insert(Nibbles::default(), BranchNodeCompact::new(0b11, 0, 0, vec![], None)); } - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove the leaf with a full path of 0x012 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x0, 0x1, 0x2])); @@ -4879,14 +4730,14 @@ mod tests { let mut trie = new_test_trie( [ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(TrieMask::new(0b0011))), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(TrieMask::new(0b0011), &[])), (Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_leaf(leaf_key([0x1, 0x2], 62))), (Nibbles::from_nibbles([0x5, 0x1]), SparseNode::new_leaf(leaf_key([0x3, 0x4], 62))), ] .into_iter(), ); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove the leaf with a full path of 0x5012 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x5, 0x0, 0x1, 0x2])); @@ -4928,16 +4779,16 @@ mod tests { // let mut trie = new_test_trie( [ - (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b0101))), + (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b0101), &[])), (Nibbles::from_nibbles([0x0]), SparseNode::new_leaf(leaf_key([0x1, 0x2], 63))), - (Nibbles::from_nibbles([0x2]), SparseNode::new_branch(TrieMask::new(0b0011))), + (Nibbles::from_nibbles([0x2]), SparseNode::new_branch(TrieMask::new(0b0011), &[])), (Nibbles::from_nibbles([0x2, 0x0]), SparseNode::new_leaf(leaf_key([0x3, 0x4], 62))), (Nibbles::from_nibbles([0x2, 0x1]), SparseNode::new_leaf(leaf_key([0x5, 0x6], 62))), ] .into_iter(), ); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove the leaf with a full path of 0x2034 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x2, 0x0, 0x3, 0x4])); @@ -4990,7 +4841,7 @@ mod tests { (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x1, 0x2, 0x3]))), ( Nibbles::from_nibbles([0x1, 0x2, 0x3]), - SparseNode::new_branch(TrieMask::new(0b0011000)), + SparseNode::new_branch(TrieMask::new(0b0011000), &[]), ), ( Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x3]), @@ -5002,7 +4853,7 @@ mod tests { ), ( Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4, 0x5]), - SparseNode::new_branch(TrieMask::new(0b0011)), + SparseNode::new_branch(TrieMask::new(0b0011), &[]), ), ( Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4, 0x5, 0x0]), @@ -5016,7 +4867,7 @@ mod tests { .into_iter(), ); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Verify initial state - the lower subtrie's path should be 0x123 let lower_subtrie_root_path = Nibbles::from_nibbles([0x1, 0x2, 0x3]); @@ -5063,26 +4914,39 @@ mod tests { // let mut trie = new_test_trie( [ - (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b0011))), + ( + Nibbles::default(), + SparseNode::new_branch( + TrieMask::new(0b0011), + &[(0x1, B256::repeat_byte(0xab))], + ), + ), (Nibbles::from_nibbles([0x0]), SparseNode::new_leaf(leaf_key([0x1, 0x2], 63))), - (Nibbles::from_nibbles([0x1]), SparseNode::Hash(B256::repeat_byte(0xab))), ] .into_iter(), ); // Create a mock provider that will reveal the blinded leaf - let mut provider = MockTrieNodeProvider::new(); let revealed_leaf = create_leaf_node(leaf_key([0x3, 0x4], 63).to_vec(), 42); let mut encoded = Vec::new(); revealed_leaf.encode(&mut encoded); - provider.add_revealed_node( - Nibbles::from_nibbles([0x1]), - RevealedNode { node: encoded.into(), tree_mask: None, hash_mask: None }, - ); - // Remove the leaf with a full path of 0x012 + // Try removing the leaf with a full path of 0x012, this should fail because the leaf is + // blinded let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x0, 0x1, 0x2])); - trie.remove_leaf(&leaf_full_path, provider).unwrap(); + let Err(err) = trie.remove_leaf(&leaf_full_path, NoRevealProvider) else { + panic!("expected error"); + }; + assert_matches!(err.kind(), SparseTrieErrorKind::BlindedNode { path, hash } if *path == Nibbles::from_nibbles([0x1]) && *hash == B256::repeat_byte(0xab)); + + // Now reveal the leaf and try removing it again + trie.reveal_nodes(&mut [ProofTrieNodeV2 { + path: Nibbles::from_nibbles([0x1]), + node: revealed_leaf, + masks: None, + }]) + .unwrap(); + trie.remove_leaf(&leaf_full_path, NoRevealProvider).unwrap(); let upper_subtrie = &trie.upper_subtrie; @@ -5112,7 +4976,7 @@ mod tests { SparseNode::new_leaf(pad_nibbles_right(Nibbles::from_nibbles([0x1, 0x2, 0x3]))), ))); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove the leaf with a full key of 0x123 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x1, 0x2, 0x3])); @@ -5154,6 +5018,8 @@ mod tests { SparseNode::Branch { state_mask: TrieMask::new(0b0011), state: make_revealed(B256::repeat_byte(0x10)), + blinded_mask: Default::default(), + blinded_hashes: Default::default(), }, ), ( @@ -5168,6 +5034,8 @@ mod tests { SparseNode::Branch { state_mask: TrieMask::new(0b11100), state: make_revealed(B256::repeat_byte(0x30)), + blinded_mask: Default::default(), + blinded_hashes: Default::default(), }, ), ( @@ -5202,12 +5070,12 @@ mod tests { .into_iter(), ); - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Remove a leaf which does not exist; this should have no effect. trie.remove_leaf( &pad_nibbles_right(Nibbles::from_nibbles([0x0, 0x1, 0x2, 0x3, 0x4, 0xF])), - &provider, + provider, ) .unwrap(); for (path, node) in trie.all_nodes() { @@ -5216,7 +5084,7 @@ mod tests { // Remove the leaf at path 0x01234 let leaf_full_path = pad_nibbles_right(Nibbles::from_nibbles([0x0, 0x1, 0x2, 0x3, 0x4])); - trie.remove_leaf(&leaf_full_path, &provider).unwrap(); + trie.remove_leaf(&leaf_full_path, provider).unwrap(); let upper_subtrie = &trie.upper_subtrie; let lower_subtrie_10 = trie.lower_subtries[0x01].as_revealed_ref().unwrap(); @@ -5278,7 +5146,8 @@ mod tests { let leaf_2 = create_leaf_node(leaf_2_key.to_vec(), account_2.nonce); // Create branch node with children at indices 0 and 1 - let branch = create_branch_node_with_children( + let branch = create_branch_node( + extension_key, &[0, 1], vec![ RlpNode::from_rlp(&alloy_rlp::encode(&leaf_1)), @@ -5286,16 +5155,9 @@ mod tests { ], ); - // Create extension node pointing to branch - let extension = create_extension_node( - extension_key.to_vec(), - RlpNode::from_rlp(&alloy_rlp::encode(&branch)).as_hash().unwrap(), - ); - // Step 2: Reveal nodes in the trie - let mut trie = ParallelSparseTrie::from_root(extension, None, true).unwrap(); + let mut trie = ParallelSparseTrie::from_root(branch, None, true).unwrap(); trie.reveal_nodes(&mut [ - ProofTrieNodeV2 { path: branch_path, node: branch, masks: None }, ProofTrieNodeV2 { path: leaf_1_path, node: leaf_1, masks: None }, ProofTrieNodeV2 { path: leaf_2_path, node: leaf_2, masks: None }, ]) @@ -5604,14 +5466,14 @@ mod tests { .collect::>(), BTreeMap::from_iter([ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1101.into())), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1101.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_ext(Nibbles::from_nibbles([0x2, 0x3])) ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3]), - SparseNode::new_branch(0b1010.into()) + SparseNode::new_branch(0b1010.into(), &[]) ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3, 0x1]), @@ -5625,12 +5487,15 @@ mod tests { Nibbles::from_nibbles([0x5, 0x2]), SparseNode::new_leaf(leaf_key([0x0, 0x1, 0x3], 62)) ), - (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into())), + (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x3, 0x1]), SparseNode::new_leaf(leaf_key([0x0, 0x2], 61)) ), - (Nibbles::from_nibbles([0x5, 0x3, 0x3]), SparseNode::new_branch(0b0101.into())), + ( + Nibbles::from_nibbles([0x5, 0x3, 0x3]), + SparseNode::new_branch(0b0101.into(), &[]) + ), ( Nibbles::from_nibbles([0x5, 0x3, 0x3, 0x0]), SparseNode::new_leaf(leaf_key([0x2], 60)) @@ -5667,14 +5532,14 @@ mod tests { .collect::>(), BTreeMap::from_iter([ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into())), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_ext(Nibbles::from_nibbles([0x2, 0x3])) ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3]), - SparseNode::new_branch(0b1010.into()) + SparseNode::new_branch(0b1010.into(), &[]) ), ( Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3, 0x1]), @@ -5684,12 +5549,15 @@ mod tests { Nibbles::from_nibbles([0x5, 0x0, 0x2, 0x3, 0x3]), SparseNode::new_leaf(leaf_key([], 59)) ), - (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into())), + (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x3, 0x1]), SparseNode::new_leaf(leaf_key([0x0, 0x2], 61)) ), - (Nibbles::from_nibbles([0x5, 0x3, 0x3]), SparseNode::new_branch(0b0101.into())), + ( + Nibbles::from_nibbles([0x5, 0x3, 0x3]), + SparseNode::new_branch(0b0101.into(), &[]) + ), ( Nibbles::from_nibbles([0x5, 0x3, 0x3, 0x0]), SparseNode::new_leaf(leaf_key([0x2], 60)) @@ -5723,17 +5591,20 @@ mod tests { .collect::>(), BTreeMap::from_iter([ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into())), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_leaf(leaf_key([0x2, 0x3, 0x3], 62)) ), - (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into())), + (Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_branch(0b1010.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x3, 0x1]), SparseNode::new_leaf(leaf_key([0x0, 0x2], 61)) ), - (Nibbles::from_nibbles([0x5, 0x3, 0x3]), SparseNode::new_branch(0b0101.into())), + ( + Nibbles::from_nibbles([0x5, 0x3, 0x3]), + SparseNode::new_branch(0b0101.into(), &[]) + ), ( Nibbles::from_nibbles([0x5, 0x3, 0x3, 0x0]), SparseNode::new_leaf(leaf_key([0x2], 60)) @@ -5765,7 +5636,7 @@ mod tests { .collect::>(), BTreeMap::from_iter([ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into())), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_leaf(leaf_key([0x2, 0x3, 0x3], 62)) @@ -5774,7 +5645,10 @@ mod tests { Nibbles::from_nibbles([0x5, 0x3]), SparseNode::new_ext(Nibbles::from_nibbles([0x3])) ), - (Nibbles::from_nibbles([0x5, 0x3, 0x3]), SparseNode::new_branch(0b0101.into())), + ( + Nibbles::from_nibbles([0x5, 0x3, 0x3]), + SparseNode::new_branch(0b0101.into(), &[]) + ), ( Nibbles::from_nibbles([0x5, 0x3, 0x3, 0x0]), SparseNode::new_leaf(leaf_key([0x2], 60)) @@ -5804,7 +5678,7 @@ mod tests { .collect::>(), BTreeMap::from_iter([ (Nibbles::default(), SparseNode::new_ext(Nibbles::from_nibbles([0x5]))), - (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into())), + (Nibbles::from_nibbles([0x5]), SparseNode::new_branch(0b1001.into(), &[])), ( Nibbles::from_nibbles([0x5, 0x0]), SparseNode::new_leaf(leaf_key([0x2, 0x3, 0x3], 62)) @@ -6224,18 +6098,18 @@ mod tests { sparse.reveal_nodes(&mut revealed_nodes).unwrap(); // Check that the branch node exists with only two nibbles set - assert_eq!( + assert_matches!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(&SparseNode::Branch { state_mask: 0b101.into(), state: SparseNodeState::Dirty }) + Some(&SparseNode::Branch { state_mask, state: SparseNodeState::Dirty, .. }) if state_mask == TrieMask::new(0b101) ); // Insert the leaf for the second key sparse.update_leaf(key2(), value_encoded(), &provider).unwrap(); // Check that the branch node was updated and another nibble was set - assert_eq!( + assert_matches!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(&SparseNode::Branch { state_mask: 0b111.into(), state: SparseNodeState::Dirty }) + Some(&SparseNode::Branch { state_mask, state: SparseNodeState::Dirty, .. }) if state_mask == TrieMask::new(0b111) ); // Generate the proof for the third key and reveal it in the sparse trie @@ -6259,9 +6133,9 @@ mod tests { sparse.reveal_nodes(&mut revealed_nodes).unwrap(); // Check that nothing changed in the branch node - assert_eq!( + assert_matches!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(&SparseNode::Branch { state_mask: 0b111.into(), state: SparseNodeState::Dirty }) + Some(&SparseNode::Branch { state_mask, state: SparseNodeState::Dirty, .. }) if state_mask == TrieMask::new(0b111) ); // Generate the nodes for the full trie with all three key using the hash builder, and @@ -6345,9 +6219,9 @@ mod tests { sparse.reveal_nodes(&mut revealed_nodes).unwrap(); // Check that the branch node exists - assert_eq!( + assert_matches!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(&SparseNode::Branch { state_mask: 0b11.into(), state: SparseNodeState::Dirty }) + Some(&SparseNode::Branch { state_mask, state: SparseNodeState::Dirty, .. }) if state_mask == TrieMask::new(0b11) ); // Remove the leaf for the first key @@ -6415,26 +6289,22 @@ mod tests { [Nibbles::default()], ); + let mut nodes = Vec::new(); + + for (path, node) in hash_builder_proof_nodes.nodes_sorted() { + let hash_mask = branch_node_hash_masks.get(&path).copied(); + let tree_mask = branch_node_tree_masks.get(&path).copied(); + let masks = BranchNodeMasks::from_optional(hash_mask, tree_mask); + nodes.push((path, TrieNode::decode(&mut &node[..]).unwrap(), masks)); + } + + nodes.sort_unstable_by(|a, b| reth_trie_common::depth_first_cmp(&a.0, &b.0)); + + let nodes = ProofTrieNodeV2::from_sorted_trie_nodes(nodes); + let provider = DefaultTrieNodeProvider; - let masks = match ( - branch_node_hash_masks.get(&Nibbles::default()).copied(), - branch_node_tree_masks.get(&Nibbles::default()).copied(), - ) { - (Some(h), Some(t)) => Some(BranchNodeMasks { hash_mask: h, tree_mask: t }), - (Some(h), None) => { - Some(BranchNodeMasks { hash_mask: h, tree_mask: TrieMask::default() }) - } - (None, Some(t)) => { - Some(BranchNodeMasks { hash_mask: TrieMask::default(), tree_mask: t }) - } - (None, None) => None, - }; - let mut sparse = ParallelSparseTrie::from_root( - TrieNodeV2::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(), - masks, - false, - ) - .unwrap(); + let mut sparse = + ParallelSparseTrie::from_root(nodes[0].node.clone(), nodes[0].masks, false).unwrap(); // Check that the root extension node exists assert_matches!( @@ -6446,9 +6316,9 @@ mod tests { sparse.update_leaf(key3(), value_encoded(), &provider).unwrap(); // Check that the extension node was turned into a branch node - assert_matches!( + assert_eq!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(SparseNode::Branch { state_mask, state: SparseNodeState::Dirty }) if *state_mask == TrieMask::new(0b11) + Some(&SparseNode::new_branch(TrieMask::new(0b11), &[])) ); // Generate the proof for the first key and reveal it in the sparse trie @@ -6472,9 +6342,9 @@ mod tests { sparse.reveal_nodes(&mut revealed_nodes).unwrap(); // Check that the branch node wasn't overwritten by the extension node in the proof - assert_matches!( + assert_eq!( sparse.upper_subtrie.nodes.get(&Nibbles::default()), - Some(SparseNode::Branch { state_mask, state: SparseNodeState::Dirty }) if *state_mask == TrieMask::new(0b11) + Some(&SparseNode::new_branch(TrieMask::new(0b11), &[])) ); } @@ -7573,12 +7443,7 @@ mod tests { ( // Branch at 0x123, child 4 Nibbles::from_nibbles_unchecked([0x1, 0x2, 0x3]), - SparseNode::new_branch(TrieMask::new(0b10000)), - ), - ( - // Blinded node at 0x1234 - leaf_path, - SparseNode::Hash(blinded_hash), + SparseNode::new_branch(TrieMask::new(0b10000), &[(0x4, blinded_hash)]), ), ] .into_iter(), @@ -7602,8 +7467,10 @@ mod tests { [ // Root is a branch with child 0x1 (blinded) and 0x5 (revealed leaf) // So we set Bit 1 and Bit 5 in the state_mask - (Nibbles::default(), SparseNode::new_branch(TrieMask::new(0b100010))), - (path_to_blind, SparseNode::Hash(blinded_hash)), + ( + Nibbles::default(), + SparseNode::new_branch(TrieMask::new(0b100010), &[(0x1, blinded_hash)]), + ), ( Nibbles::from_nibbles_unchecked([0x5]), SparseNode::new_leaf(Nibbles::from_nibbles_unchecked([0x6, 0x7, 0x8])), @@ -7643,7 +7510,8 @@ mod tests { B256::from(hex!("9e514589a9c9210b783c19fa3f0b384bbfaefe98f10ea189a2bfc58c6bf000a1")), B256::from(hex!("85bdaabbcfa583cbd049650e41d3d19356bd833b3ed585cf225a3548557c7fa3")), ]; - let branch_0x3_node = create_branch_node_with_children( + let branch_0x3_node = create_branch_node( + Nibbles::from_nibbles([0x3]), &[0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf], branch_0x3_hashes.iter().map(RlpNode::word_rlp), ); @@ -7672,48 +7540,25 @@ mod tests { &[0x3, 0x7, 0xc], branch_0x31c_hashes.into_iter().map(|h| RlpNode::word_rlp(&h)), ); - let mut branch_0x31c_node_encoded = Vec::new(); - branch_0x31c_node.encode(&mut branch_0x31c_node_encoded); - - // Create a mock provider and preload 0x31c onto it, it will be revealed during remove_leaf. - let mut provider = MockTrieNodeProvider::new(); - provider.add_revealed_node( - Nibbles::from_nibbles([0x3, 0x1, 0xc]), - RevealedNode { - node: branch_0x31c_node_encoded.into(), - tree_mask: Some(0.into()), - hash_mask: Some(4096.into()), - }, - ); // Reveal the trie structure using ProofTrieNode - let mut proof_nodes = vec![ - ProofTrieNodeV2 { - path: Nibbles::from_nibbles([0x3]), - node: branch_0x3_node, - masks: Some(BranchNodeMasks { - tree_mask: TrieMask::new(26099), - hash_mask: TrieMask::new(65535), - }), - }, - ProofTrieNodeV2 { - path: Nibbles::from_nibbles([0x3, 0x1]), - node: branch_0x31_node, - masks: Some(BranchNodeMasks { - tree_mask: TrieMask::new(4096), - hash_mask: TrieMask::new(4096), - }), - }, - ]; + let mut proof_nodes = vec![ProofTrieNodeV2 { + path: Nibbles::from_nibbles([0x3, 0x1]), + node: branch_0x31_node, + masks: Some(BranchNodeMasks { + tree_mask: TrieMask::new(4096), + hash_mask: TrieMask::new(4096), + }), + }]; // Create a sparse trie and reveal nodes let mut trie = ParallelSparseTrie::default() .with_root( - TrieNodeV2::Extension(ExtensionNode { - key: Nibbles::from_nibbles([0x3]), - child: RlpNode::word_rlp(&B256::ZERO), + branch_0x3_node, + Some(BranchNodeMasks { + tree_mask: TrieMask::new(26099), + hash_mask: TrieMask::new(65535), }), - None, true, ) .expect("root revealed"); @@ -7721,10 +7566,23 @@ mod tests { trie.reveal_nodes(&mut proof_nodes).unwrap(); // Update the leaf in order to reveal it in the trie - trie.update_leaf(leaf_nibbles, leaf_value, &provider).unwrap(); + trie.update_leaf(leaf_nibbles, leaf_value, NoRevealProvider).unwrap(); - // Now delete the leaf - trie.remove_leaf(&leaf_nibbles, &provider).unwrap(); + // Now try deleting the leaf + let Err(err) = trie.remove_leaf(&leaf_nibbles, NoRevealProvider) else { + panic!("expected blinded node error"); + }; + assert_matches!(err.kind(), SparseTrieErrorKind::BlindedNode { path, hash } if path == &Nibbles::from_nibbles([0x3, 0x1, 0xc])); + + trie.reveal_nodes(&mut [ProofTrieNodeV2 { + path: Nibbles::from_nibbles([0x3, 0x1, 0xc]), + node: branch_0x31c_node, + masks: Some(BranchNodeMasks { tree_mask: 0.into(), hash_mask: 4096.into() }), + }]) + .unwrap(); + + // Now remove the leaf again, this should succeed + trie.remove_leaf(&leaf_nibbles, NoRevealProvider).unwrap(); // Compute the root to trigger updates let _ = trie.root(); @@ -7811,7 +7669,7 @@ mod tests { /// missing values stored in `upper_subtrie.inner.values`. #[test] fn test_get_leaf_value_upper_subtrie_via_update_leaf() { - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Create an empty trie with an empty root let mut trie = ParallelSparseTrie::default() @@ -7824,7 +7682,7 @@ mod tests { // Insert the leaf - since the trie is empty, the leaf node will be created // at the root level (depth 0), which is in the upper subtrie - trie.update_leaf(full_path, value.clone(), &provider).unwrap(); + trie.update_leaf(full_path, value.clone(), provider).unwrap(); // Verify the value is stored in upper_subtrie (where update_leaf puts it) assert!( @@ -7842,7 +7700,7 @@ mod tests { /// Test that `get_leaf_value` works for values in both upper and lower subtries. #[test] fn test_get_leaf_value_upper_and_lower_subtries() { - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Create an empty trie let mut trie = ParallelSparseTrie::default() @@ -7852,12 +7710,12 @@ mod tests { // Insert first leaf - will be at root level (upper subtrie) let path1 = pad_nibbles_right(Nibbles::from_nibbles([0x0, 0xA])); let value1 = encode_account_value(1); - trie.update_leaf(path1, value1.clone(), &provider).unwrap(); + trie.update_leaf(path1, value1.clone(), provider).unwrap(); // Insert second leaf with different prefix - creates a branch let path2 = pad_nibbles_right(Nibbles::from_nibbles([0x1, 0xB])); let value2 = encode_account_value(2); - trie.update_leaf(path2, value2.clone(), &provider).unwrap(); + trie.update_leaf(path2, value2.clone(), provider).unwrap(); // Both values should be retrievable assert_eq!(trie.get_leaf_value(&path1), Some(&value1)); @@ -7867,7 +7725,7 @@ mod tests { /// Test that `get_leaf_value` works for storage tries which are often very sparse. #[test] fn test_get_leaf_value_sparse_storage_trie() { - let provider = MockTrieNodeProvider::new(); + let provider = NoRevealProvider; // Simulate a storage trie with a single slot let mut trie = ParallelSparseTrie::default() @@ -7877,7 +7735,7 @@ mod tests { // Single storage slot - leaf will be at root (depth 0) let slot_path = pad_nibbles_right(Nibbles::from_nibbles([0x2, 0x9])); let slot_value = alloy_rlp::encode(U256::from(12345)); - trie.update_leaf(slot_path, slot_value.clone(), &provider).unwrap(); + trie.update_leaf(slot_path, slot_value.clone(), provider).unwrap(); // Value should be retrievable assert_eq!(trie.get_leaf_value(&slot_path), Some(&slot_value)); @@ -7983,8 +7841,8 @@ mod tests { ); if max_depth == 0 { - // Root + 4 hash stubs for children at [0], [1], [2], [3] - assert_eq!(nodes_after, 5, "root + 4 hash stubs after prune(0)"); + // Root with 4 blinded hashes for children at [0], [1], [2], [3] + assert_eq!(nodes_after, 1, "root"); } } } @@ -8097,8 +7955,8 @@ mod tests { } assert_eq!(root_before, trie.root(), "root hash should be preserved"); - // Root + extension + 2 hash stubs (for the two leaves' parent branches) - assert_eq!(trie.size_hint(), 4, "root + extension + hash stubs after prune(1)"); + // Root + branch + assert_eq!(trie.size_hint(), 2, "root + extension + hash stubs after prune(1)"); } #[test] diff --git a/crates/trie/sparse/src/trie.rs b/crates/trie/sparse/src/trie.rs index 95277e06b5..4fc1b5b349 100644 --- a/crates/trie/sparse/src/trie.rs +++ b/crates/trie/sparse/src/trie.rs @@ -5,7 +5,7 @@ use crate::{ use alloc::{borrow::Cow, boxed::Box, vec::Vec}; use alloy_primitives::{map::B256Map, B256}; use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult}; -use reth_trie_common::{BranchNodeMasks, Nibbles, RlpNode, TrieMask, TrieNode, TrieNodeV2}; +use reth_trie_common::{BranchNodeMasks, Nibbles, RlpNode, TrieMask, TrieNodeV2}; use tracing::instrument; /// A sparse trie that is either in a "blind" state (no nodes are revealed, root node hash is @@ -335,8 +335,6 @@ impl SparseNodeType { pub enum SparseNode { /// Empty trie node. Empty, - /// The hash of the node that was not revealed. - Hash(B256), /// Sparse leaf node with remaining key suffix. Leaf { /// Remaining key suffix for the leaf node. @@ -357,32 +355,39 @@ pub enum SparseNode { state_mask: TrieMask, /// Tracker for the node's state, e.g. cached `RlpNode` tracking. state: SparseNodeState, + /// The mask of the children that are blinded. + blinded_mask: TrieMask, + /// The hashes of the children that are blinded. + blinded_hashes: Box<[B256; 16]>, }, } impl SparseNode { - /// Create new sparse node from [`TrieNode`]. - pub fn from_node(node: TrieNode) -> Self { - match node { - TrieNode::EmptyRoot => Self::Empty, - TrieNode::Leaf(leaf) => Self::new_leaf(leaf.key), - TrieNode::Extension(ext) => Self::new_ext(ext.key), - TrieNode::Branch(branch) => Self::new_branch(branch.state_mask), - } - } + /// Create new [`SparseNode::Branch`] from state mask and blinded nodes. + #[cfg(test)] + pub fn new_branch(state_mask: TrieMask, blinded_children: &[(u8, B256)]) -> Self { + let mut blinded_mask = TrieMask::default(); + let mut blinded_hashes = Box::new([B256::ZERO; 16]); - /// Create new [`SparseNode::Branch`] from state mask. - pub const fn new_branch(state_mask: TrieMask) -> Self { - Self::Branch { state_mask, state: SparseNodeState::Dirty } + for (nibble, hash) in blinded_children { + blinded_mask.set_bit(*nibble); + blinded_hashes[*nibble as usize] = *hash; + } + Self::Branch { state_mask, state: SparseNodeState::Dirty, blinded_mask, blinded_hashes } } /// Create new [`SparseNode::Branch`] with two bits set. - pub const fn new_split_branch(bit_a: u8, bit_b: u8) -> Self { + pub fn new_split_branch(bit_a: u8, bit_b: u8) -> Self { let state_mask = TrieMask::new( // set bits for both children (1u16 << bit_a) | (1u16 << bit_b), ); - Self::Branch { state_mask, state: SparseNodeState::Dirty } + Self::Branch { + state_mask, + state: SparseNodeState::Dirty, + blinded_mask: TrieMask::default(), + blinded_hashes: Box::new([B256::ZERO; 16]), + } } /// Create new [`SparseNode::Extension`] from the key slice. @@ -395,16 +400,10 @@ impl SparseNode { Self::Leaf { key, state: SparseNodeState::Dirty } } - /// Returns `true` if the node is a hash node. - pub const fn is_hash(&self) -> bool { - matches!(self, Self::Hash(_)) - } - /// Returns the cached [`RlpNode`] of the node, if it's available. pub fn cached_rlp_node(&self) -> Option> { match &self { Self::Empty => None, - Self::Hash(hash) => Some(Cow::Owned(RlpNode::word_rlp(hash))), Self::Leaf { state, .. } | Self::Extension { state, .. } | Self::Branch { state, .. } => state.cached_rlp_node().map(Cow::Borrowed), @@ -415,7 +414,6 @@ impl SparseNode { pub fn cached_hash(&self) -> Option { match &self { Self::Empty => None, - Self::Hash(hash) => Some(*hash), Self::Leaf { state, .. } | Self::Extension { state, .. } | Self::Branch { state, .. } => state.cached_hash(), @@ -424,11 +422,11 @@ impl SparseNode { /// Sets the hash of the node for testing purposes. /// - /// For [`SparseNode::Empty`] and [`SparseNode::Hash`] nodes, this method panics. + /// For [`SparseNode::Empty`] nodes, this method panics. #[cfg(any(test, feature = "test-utils"))] pub fn set_state(&mut self, new_state: SparseNodeState) { match self { - Self::Empty | Self::Hash(_) => { + Self::Empty => { panic!("Cannot set hash for Empty or Hash nodes") } Self::Leaf { state, .. } | @@ -439,10 +437,17 @@ impl SparseNode { } } + /// Sets the state of the node and returns a new node with the same state. + #[cfg(any(test, feature = "test-utils"))] + pub fn with_state(mut self, state: SparseNodeState) -> Self { + self.set_state(state); + self + } + /// Returns the memory size of this node in bytes. pub const fn memory_size(&self) -> usize { match self { - Self::Empty | Self::Hash(_) | Self::Branch { .. } => core::mem::size_of::(), + Self::Empty | Self::Branch { .. } => core::mem::size_of::(), Self::Leaf { key, .. } | Self::Extension { key, .. } => { core::mem::size_of::() + key.len() }