From ebd57f77bcdc3891c9d4960bed9af0fa88182e46 Mon Sep 17 00:00:00 2001
From: Brian Picciano
Date: Thu, 19 Jun 2025 15:13:12 +0200
Subject: [PATCH] perf(trie): `ParallelSparseTrie::reveal_node` (#16894)

Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com>
---
 crates/trie/sparse/src/parallel_trie.rs | 488 +++++++++++++++++++++++-
 1 file changed, 474 insertions(+), 14 deletions(-)

diff --git a/crates/trie/sparse/src/parallel_trie.rs b/crates/trie/sparse/src/parallel_trie.rs
index 6351999c5d..0e7a97efac 100644
--- a/crates/trie/sparse/src/parallel_trie.rs
+++ b/crates/trie/sparse/src/parallel_trie.rs
@@ -1,11 +1,15 @@
 use crate::{blinded::BlindedProvider, SparseNode, SparseTrieUpdates, TrieMasks};
 use alloc::{boxed::Box, vec::Vec};
-use alloy_primitives::{map::HashMap, B256};
+use alloy_primitives::{
+    map::{Entry, HashMap},
+    B256,
+};
+use alloy_rlp::Decodable;
 use alloy_trie::TrieMask;
-use reth_execution_errors::SparseTrieResult;
+use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult};
 use reth_trie_common::{
     prefix_set::{PrefixSet, PrefixSetMut},
-    Nibbles, TrieNode,
+    Nibbles, TrieNode, CHILD_INDEX_RANGE,
 };
 use tracing::trace;

@@ -37,6 +41,22 @@ impl Default for ParallelSparseTrie {
 }

 impl ParallelSparseTrie {
+    /// Returns a mutable reference to the lower `SparseSubtrie` for the given path, or `None` if
+    /// the path belongs to the upper trie.
+    fn lower_subtrie_for_path(&mut self, path: &Nibbles) -> Option<&mut SparseSubtrie> {
+        match SparseSubtrieType::from_path(path) {
+            SparseSubtrieType::Upper => None,
+            SparseSubtrieType::Lower(idx) => {
+                if self.lower_subtries[idx].is_none() {
+                    let upper_path = path.slice(..2);
+                    self.lower_subtries[idx] = Some(SparseSubtrie::new(upper_path));
+                }
+
+                self.lower_subtries[idx].as_mut()
+            }
+        }
+    }
+
     /// Creates a new revealed sparse trie from the given root node.
     ///
     /// # Returns
@@ -58,7 +78,6 @@ impl ParallelSparseTrie {
     /// It handles different node types (leaf, extension, branch) by appropriately
     /// adding them to the trie structure and recursively revealing their children.
     ///
-    ///
     /// # Returns
     ///
     /// `Ok(())` if successful, or an error if node was not revealed.
@@ -68,10 +87,50 @@ impl ParallelSparseTrie {
         node: TrieNode,
         masks: TrieMasks,
     ) -> SparseTrieResult<()> {
-        let _path = path;
-        let _node = node;
-        let _masks = masks;
-        todo!()
+        // TODO parallelize
+        if let Some(subtrie) = self.lower_subtrie_for_path(&path) {
+            return subtrie.reveal_node(path, &node, masks);
+        }
+
+        // If there is no subtrie for the path it means the path is 2 or fewer nibbles, and so
+        // belongs to the upper trie.
+        self.upper_subtrie.reveal_node(path.clone(), &node, masks)?;
+
+        // The previous upper_subtrie.reveal_node call will not have revealed any child nodes via
+        // reveal_node_or_hash if the child node would be found on a lower subtrie. We handle that
+        // here by manually checking the specific cases where this could happen, and calling
+        // reveal_node_or_hash for each.
+        match node {
+            TrieNode::Branch(branch) => {
+                // If a branch is at the second level of the trie then it will be in the upper trie,
+                // but all of its children will be in the lower trie.
+                if path.len() == 2 {
+                    let mut stack_ptr = branch.as_ref().first_child_index();
+                    for idx in CHILD_INDEX_RANGE {
+                        if branch.state_mask.is_bit_set(idx) {
+                            let mut child_path = path.clone();
+                            child_path.push_unchecked(idx);
+                            self.lower_subtrie_for_path(&child_path)
+                                .expect("child_path must have a lower subtrie")
+                                .reveal_node_or_hash(child_path, &branch.stack[stack_ptr])?;
+                            stack_ptr += 1;
+                        }
+                    }
+                }
+            }
+            TrieNode::Extension(ext) => {
+                let mut child_path = path.clone();
+                child_path.extend_from_slice_unchecked(&ext.key);
+                if child_path.len() > 2 {
+                    self.lower_subtrie_for_path(&child_path)
+                        .expect("child_path must have a lower subtrie")
+                        .reveal_node_or_hash(child_path, &ext.child)?;
+                }
+            }
+            TrieNode::EmptyRoot | TrieNode::Leaf(_) => (),
+        }
+
+        Ok(())
     }

     /// Updates or inserts a leaf node at the specified key path with the provided RLP-encoded
@@ -219,11 +278,213 @@ pub struct SparseSubtrie {
 }

 impl SparseSubtrie {
-    /// Creates a new sparse subtrie with the given root path.
-    pub fn new(path: Nibbles) -> Self {
+    fn new(path: Nibbles) -> Self {
         Self { path, ..Default::default() }
     }

+    /// Returns true if the current path and its child are both found in the same level. This
+    /// function assumes that if `current_path` is in a lower level then `child_path` is too.
+    fn is_child_same_level(current_path: &Nibbles, child_path: &Nibbles) -> bool {
+        let current_level = core::mem::discriminant(&SparseSubtrieType::from_path(current_path));
+        let child_level = core::mem::discriminant(&SparseSubtrieType::from_path(child_path));
+        current_level == child_level
+    }
+
+    /// Internal implementation of the method of the same name on `ParallelSparseTrie`.
+    fn reveal_node(
+        &mut self,
+        path: Nibbles,
+        node: &TrieNode,
+        masks: TrieMasks,
+    ) -> SparseTrieResult<()> {
+        // If the node is already revealed and it's not a hash node, do nothing.
+        if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) {
+            return Ok(())
+        }
+
+        if let Some(tree_mask) = masks.tree_mask {
+            self.branch_node_tree_masks.insert(path.clone(), tree_mask);
+        }
+        if let Some(hash_mask) = masks.hash_mask {
+            self.branch_node_hash_masks.insert(path.clone(), hash_mask);
+        }
+
+        match node {
+            TrieNode::EmptyRoot => {
+                // For an empty root, ensure that we are at the root path, and at the upper subtrie.
+                debug_assert!(path.is_empty());
+                debug_assert!(self.path.is_empty());
+                self.nodes.insert(path, SparseNode::Empty);
+            }
+            TrieNode::Branch(branch) => {
+                // For a branch node, iterate over all potential children
+                let mut stack_ptr = branch.as_ref().first_child_index();
+                for idx in CHILD_INDEX_RANGE {
+                    if branch.state_mask.is_bit_set(idx) {
+                        let mut child_path = path.clone();
+                        child_path.push_unchecked(idx);
+                        if Self::is_child_same_level(&path, &child_path) {
+                            // Reveal each child node or hash it has, but only if the child is on
+                            // the same level as the parent.
+                            self.reveal_node_or_hash(child_path, &branch.stack[stack_ptr])?;
+                        }
+                        stack_ptr += 1;
+                    }
+                }
+                // Update the branch node entry in the nodes map, handling cases where a blinded
+                // node is now replaced with a revealed node.
+                match self.nodes.entry(path) {
+                    Entry::Occupied(mut entry) => match entry.get() {
+                        // Replace a hash node with a fully revealed branch node.
+                        SparseNode::Hash(hash) => {
+                            entry.insert(SparseNode::Branch {
+                                state_mask: branch.state_mask,
+                                // Memoize the hash of a previously blinded node in a new branch
+                                // node.
+                                hash: Some(*hash),
+                                store_in_db_trie: Some(
+                                    masks.hash_mask.is_some_and(|mask| !mask.is_empty()) ||
+                                        masks.tree_mask.is_some_and(|mask| !mask.is_empty()),
+                                ),
+                            });
+                        }
+                        // Branch node already exists, or an extension node was placed where a
+                        // branch node was before.
+                        SparseNode::Branch { .. } | SparseNode::Extension { .. } => {}
+                        // All other node types can't be handled.
+                        node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
+                            return Err(SparseTrieErrorKind::Reveal {
+                                path: entry.key().clone(),
+                                node: Box::new(node.clone()),
+                            }
+                            .into())
+                        }
+                    },
+                    Entry::Vacant(entry) => {
+                        entry.insert(SparseNode::new_branch(branch.state_mask));
+                    }
+                }
+            }
+            TrieNode::Extension(ext) => match self.nodes.entry(path.clone()) {
+                Entry::Occupied(mut entry) => match entry.get() {
+                    // Replace a hash node with a revealed extension node.
+                    SparseNode::Hash(hash) => {
+                        let mut child_path = entry.key().clone();
+                        child_path.extend_from_slice_unchecked(&ext.key);
+                        entry.insert(SparseNode::Extension {
+                            key: ext.key.clone(),
+                            // Memoize the hash of a previously blinded node in a new extension
+                            // node.
+                            hash: Some(*hash),
+                            store_in_db_trie: None,
+                        });
+                        if Self::is_child_same_level(&path, &child_path) {
+                            self.reveal_node_or_hash(child_path, &ext.child)?;
+                        }
+                    }
+                    // Extension node already exists, or an extension node was placed where a branch
+                    // node was before.
+                    SparseNode::Extension { .. } | SparseNode::Branch { .. } => {}
+                    // All other node types can't be handled.
+                    node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
+                        return Err(SparseTrieErrorKind::Reveal {
+                            path: entry.key().clone(),
+                            node: Box::new(node.clone()),
+                        }
+                        .into())
+                    }
+                },
+                Entry::Vacant(entry) => {
+                    let mut child_path = entry.key().clone();
+                    child_path.extend_from_slice_unchecked(&ext.key);
+                    entry.insert(SparseNode::new_ext(ext.key.clone()));
+                    if Self::is_child_same_level(&path, &child_path) {
+                        self.reveal_node_or_hash(child_path, &ext.child)?;
+                    }
+                }
+            },
+            TrieNode::Leaf(leaf) => match self.nodes.entry(path) {
+                Entry::Occupied(mut entry) => match entry.get() {
+                    // Replace a hash node with a revealed leaf node and store leaf node value.
+                    SparseNode::Hash(hash) => {
+                        let mut full = entry.key().clone();
+                        full.extend_from_slice_unchecked(&leaf.key);
+                        self.values.insert(full, leaf.value.clone());
+                        entry.insert(SparseNode::Leaf {
+                            key: leaf.key.clone(),
+                            // Memoize the hash of a previously blinded node in a new leaf
+                            // node.
+                            hash: Some(*hash),
+                        });
+                    }
+                    // Leaf node already exists.
+                    SparseNode::Leaf { .. } => {}
+                    // All other node types can't be handled.
+                    node @ (SparseNode::Empty |
+                    SparseNode::Extension { .. } |
+                    SparseNode::Branch { .. }) => {
+                        return Err(SparseTrieErrorKind::Reveal {
+                            path: entry.key().clone(),
+                            node: Box::new(node.clone()),
+                        }
+                        .into())
+                    }
+                },
+                Entry::Vacant(entry) => {
+                    let mut full = entry.key().clone();
+                    full.extend_from_slice_unchecked(&leaf.key);
+                    entry.insert(SparseNode::new_leaf(leaf.key.clone()));
+                    self.values.insert(full, leaf.value.clone());
+                }
+            },
+        }
+
+        Ok(())
+    }
+
+    /// Reveals either a node or its hash placeholder based on the provided child data.
+    ///
+    /// When traversing the trie, we often encounter references to child nodes that
+    /// are either directly embedded or represented by their hash. This method
+    /// handles both cases:
+    ///
+    /// 1. If the child data represents a hash (32+1=33 bytes), store it as a hash node
+    /// 2. Otherwise, decode the data as a [`TrieNode`] and recursively reveal it using
+    ///    `reveal_node`
+    ///
+    /// # Returns
+    ///
+    /// Returns `Ok(())` if successful, or an error if the node cannot be revealed.
+    ///
+    /// # Error Handling
+    ///
+    /// Will error if there's a conflict between a new hash node and an existing one
+    /// at the same path.
+    fn reveal_node_or_hash(&mut self, path: Nibbles, child: &[u8]) -> SparseTrieResult<()> {
+        if child.len() == B256::len_bytes() + 1 {
+            let hash = B256::from_slice(&child[1..]);
+            match self.nodes.entry(path) {
+                Entry::Occupied(entry) => match entry.get() {
+                    // Hash node with a different hash can't be handled.
+                    SparseNode::Hash(previous_hash) if previous_hash != &hash => {
+                        return Err(SparseTrieErrorKind::Reveal {
+                            path: entry.key().clone(),
+                            node: Box::new(SparseNode::Hash(hash)),
+                        }
+                        .into())
+                    }
+                    _ => {}
+                },
+                Entry::Vacant(entry) => {
+                    entry.insert(SparseNode::Hash(hash));
+                }
+            }
+            return Ok(())
+        }
+
+        self.reveal_node(path, &TrieNode::decode(&mut &child[..])?, TrieMasks::none())
+    }
+
     /// Recalculates and updates the RLP hashes for the changed nodes in this subtrie.
     pub fn update_hashes(&mut self, prefix_set: &mut PrefixSet) -> SparseTrieResult<()> {
         trace!(target: "trie::parallel_sparse", path=?self.path, "Updating subtrie hashes");
@@ -271,13 +532,57 @@ fn path_subtrie_index_unchecked(path: &Nibbles) -> usize {

 #[cfg(test)]
 mod tests {
-    use alloy_trie::Nibbles;
-    use reth_trie_common::prefix_set::{PrefixSet, PrefixSetMut};
-
     use crate::{
         parallel_trie::{path_subtrie_index_unchecked, SparseSubtrieType},
-        ParallelSparseTrie, SparseSubtrie,
+        ParallelSparseTrie, SparseNode, SparseSubtrie, TrieMasks,
     };
+    use alloy_primitives::B256;
+    use alloy_rlp::Encodable;
+    use alloy_trie::Nibbles;
+    use assert_matches::assert_matches;
+    use reth_primitives_traits::Account;
+    use reth_trie_common::{
+        prefix_set::{PrefixSet, PrefixSetMut},
+        BranchNode, ExtensionNode, LeafNode, RlpNode, TrieMask, TrieNode, EMPTY_ROOT_HASH,
+    };
+
+    // Test helpers
+    fn encode_account_value(nonce: u64) -> Vec<u8> {
+        let account = Account { nonce, ..Default::default() };
+        let trie_account = account.into_trie_account(EMPTY_ROOT_HASH);
+        let mut buf = Vec::new();
+        trie_account.encode(&mut buf);
+        buf
+    }
+
+    fn create_leaf_node(key: &[u8], value_nonce: u64) -> TrieNode {
+        TrieNode::Leaf(LeafNode::new(
+            Nibbles::from_nibbles_unchecked(key),
+            encode_account_value(value_nonce),
+        ))
+    }
+
+    fn create_extension_node(key: &[u8], child_hash: B256) -> TrieNode {
+        TrieNode::Extension(ExtensionNode::new(
+            Nibbles::from_nibbles_unchecked(key),
+            RlpNode::word_rlp(&child_hash),
+        ))
+    }
+
+    fn create_branch_node_with_children(
+        children_indices: &[u8],
+        child_hashes: &[B256],
+    ) -> TrieNode {
+        let mut stack = Vec::new();
+        let mut state_mask = 0u16;
+
+        for (&idx, &hash) in children_indices.iter().zip(child_hashes.iter()) {
+            state_mask |= 1 << idx;
+            stack.push(RlpNode::word_rlp(&hash));
+        }
+
+        TrieNode::Branch(BranchNode::new(stack, TrieMask::new(state_mask)))
+    }

     #[test]
     fn test_get_changed_subtries_empty() {
@@ -405,4 +710,159 @@ mod tests {
             SparseSubtrieType::Lower(255)
         );
     }
+
+    #[test]
+    fn reveal_node_leaves() {
+        let mut trie = ParallelSparseTrie::default();
+
+        // Reveal leaf in the upper trie
+        {
+            let path = Nibbles::from_nibbles([0x1, 0x2]);
+            let node = create_leaf_node(&[0x3, 0x4], 42);
+            let masks = TrieMasks::none();
+
+            trie.reveal_node(path.clone(), node, masks).unwrap();
+
+            assert_matches!(
+                trie.upper_subtrie.nodes.get(&path),
+                Some(SparseNode::Leaf { key, hash: None })
+                    if key == &Nibbles::from_nibbles([0x3, 0x4])
+            );
+
+            let full_path = Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4]);
+            assert_eq!(trie.upper_subtrie.values.get(&full_path), Some(&encode_account_value(42)));
+        }
+
+        // Reveal leaf in a lower trie
+        {
+            let path = Nibbles::from_nibbles([0x1, 0x2, 0x3]);
+            let node = create_leaf_node(&[0x4, 0x5], 42);
+            let masks = TrieMasks::none();
+
+            trie.reveal_node(path.clone(), node, masks).unwrap();
+
+            // Check that the lower subtrie was created
+            let idx = path_subtrie_index_unchecked(&path);
+            assert!(trie.lower_subtries[idx].is_some());
+
+            let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
+            assert_matches!(
+                lower_subtrie.nodes.get(&path),
+                Some(SparseNode::Leaf { key, hash: None })
+                    if key == &Nibbles::from_nibbles([0x4, 0x5])
+            );
+        }
+    }
+
+    #[test]
+    fn reveal_node_extension_all_upper() {
+        let mut trie = ParallelSparseTrie::default();
+        let path = Nibbles::from_nibbles([0x1]);
+        let child_hash = B256::repeat_byte(0xab);
+        let node = create_extension_node(&[0x2], child_hash);
+        let masks = TrieMasks::none();
+
+        trie.reveal_node(path.clone(), node, masks).unwrap();
+
+        assert_matches!(
+            trie.upper_subtrie.nodes.get(&path),
+            Some(SparseNode::Extension { key, hash: None, .. })
+                if key == &Nibbles::from_nibbles([0x2])
+        );
+
+        // Child path should be in upper trie
+        let child_path = Nibbles::from_nibbles([0x1, 0x2]);
+        assert_eq!(trie.upper_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash)));
+    }
+
+    #[test]
+    fn reveal_node_extension_cross_level() {
+        let mut trie = ParallelSparseTrie::default();
+        let path = Nibbles::from_nibbles([0x1, 0x2]);
+        let child_hash = B256::repeat_byte(0xcd);
+        let node = create_extension_node(&[0x3], child_hash);
+        let masks = TrieMasks::none();
+
+        trie.reveal_node(path.clone(), node, masks).unwrap();
+
+        // Extension node should be in upper trie
+        assert_matches!(
+            trie.upper_subtrie.nodes.get(&path),
+            Some(SparseNode::Extension { key, hash: None, .. })
+                if key == &Nibbles::from_nibbles([0x3])
+        );
+
+        // Child path (0x1, 0x2, 0x3) should be in lower trie
+        let child_path = Nibbles::from_nibbles([0x1, 0x2, 0x3]);
+        let idx = path_subtrie_index_unchecked(&child_path);
+        assert!(trie.lower_subtries[idx].is_some());
+
+        let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
+        assert_eq!(lower_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash)));
+    }
+
+    #[test]
+    fn reveal_node_branch_all_upper() {
+        let mut trie = ParallelSparseTrie::default();
+        let path = Nibbles::from_nibbles([0x1]);
+        let child_hashes = [B256::repeat_byte(0x11), B256::repeat_byte(0x22)];
+        let node = create_branch_node_with_children(&[0x0, 0x5], &child_hashes);
+        let masks = TrieMasks::none();
+
+        trie.reveal_node(path.clone(), node, masks).unwrap();
+
+        // Branch node should be in upper trie
+        assert_matches!(
+            trie.upper_subtrie.nodes.get(&path),
+            Some(SparseNode::Branch { state_mask, hash: None, .. })
+                if *state_mask == 0b0000000000100001.into()
+        );
+
+        // Children should be in upper trie (paths of length 2)
+        let child_path_0 = Nibbles::from_nibbles([0x1, 0x0]);
+        let child_path_5 = Nibbles::from_nibbles([0x1, 0x5]);
+        assert_eq!(
+            trie.upper_subtrie.nodes.get(&child_path_0),
+            Some(&SparseNode::Hash(child_hashes[0]))
+        );
+        assert_eq!(
+            trie.upper_subtrie.nodes.get(&child_path_5),
+            Some(&SparseNode::Hash(child_hashes[1]))
+        );
+    }
+
+    #[test]
+    fn reveal_node_branch_cross_level() {
+        let mut trie = ParallelSparseTrie::default();
+        let path = Nibbles::from_nibbles([0x1, 0x2]); // Exactly 2 nibbles - boundary case
+        let child_hashes =
+            [B256::repeat_byte(0x33), B256::repeat_byte(0x44), B256::repeat_byte(0x55)];
+        let node = create_branch_node_with_children(&[0x0, 0x7, 0xf], &child_hashes);
+        let masks = TrieMasks::none();
+
+        trie.reveal_node(path.clone(), node, masks).unwrap();
+
+        // Branch node should be in upper trie
+        assert_matches!(
+            trie.upper_subtrie.nodes.get(&path),
+            Some(SparseNode::Branch { state_mask, hash: None, .. })
+                if *state_mask == 0b1000000010000001.into()
+        );
+
+        // All children should be in lower tries since they have paths of length 3
+        let child_paths = [
+            Nibbles::from_nibbles([0x1, 0x2, 0x0]),
+            Nibbles::from_nibbles([0x1, 0x2, 0x7]),
+            Nibbles::from_nibbles([0x1, 0x2, 0xf]),
+        ];
+
+        for (i, child_path) in child_paths.iter().enumerate() {
+            let idx = path_subtrie_index_unchecked(child_path);
+            let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
+            assert_eq!(
+                lower_subtrie.nodes.get(child_path),
+                Some(&SparseNode::Hash(child_hashes[i])),
+            );
+        }
+    }
 }
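
For reference, a minimal standalone sketch of the path-routing rule that `lower_subtrie_for_path` and the cross-level tests above rely on: paths of two or fewer nibbles stay in the upper subtrie, longer paths go to a lower subtrie indexed by the first two nibbles. The names `SubtriePosition` and `position_for_path` are illustrative only (the real types are `SparseSubtrieType::from_path` and `path_subtrie_index_unchecked` elsewhere in this file), and the index formula is an assumption inferred from the `SparseSubtrieType::Lower(255)` test case.

// Illustrative sketch only, not part of the patch.
#[derive(Debug, PartialEq)]
enum SubtriePosition {
    Upper,
    Lower(usize),
}

fn position_for_path(path: &[u8]) -> SubtriePosition {
    if path.len() <= 2 {
        // Paths of 0, 1, or 2 nibbles stay in the upper subtrie.
        SubtriePosition::Upper
    } else {
        // Assumed index: first nibble is the high half, second nibble the low half.
        SubtriePosition::Lower(((path[0] as usize) << 4) | path[1] as usize)
    }
}

fn main() {
    // A two-nibble branch path stays in the upper subtrie...
    assert_eq!(position_for_path(&[0x1, 0x2]), SubtriePosition::Upper);
    // ...while each of its children lands in a lower subtrie,
    // which is what reveal_node_branch_cross_level exercises.
    assert_eq!(position_for_path(&[0x1, 0x2, 0x0]), SubtriePosition::Lower(0x12));
    assert_eq!(position_for_path(&[0xf, 0xf, 0x0]), SubtriePosition::Lower(255));
}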