perf(trie): ParallelSparseTrie::reveal_node (#16894)

Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com>
This commit is contained in:
Brian Picciano
2025-06-19 15:13:12 +02:00
committed by GitHub
parent aa725dd0cf
commit ebd57f77bc

View File

@@ -1,11 +1,15 @@
use crate::{blinded::BlindedProvider, SparseNode, SparseTrieUpdates, TrieMasks};
use alloc::{boxed::Box, vec::Vec};
use alloy_primitives::{map::HashMap, B256};
use alloy_primitives::{
map::{Entry, HashMap},
B256,
};
use alloy_rlp::Decodable;
use alloy_trie::TrieMask;
use reth_execution_errors::SparseTrieResult;
use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult};
use reth_trie_common::{
prefix_set::{PrefixSet, PrefixSetMut},
Nibbles, TrieNode,
Nibbles, TrieNode, CHILD_INDEX_RANGE,
};
use tracing::trace;
@@ -37,6 +41,22 @@ impl Default for ParallelSparseTrie {
}
impl ParallelSparseTrie {
/// Returns mutable ref to the lower `SparseSubtrie` for the given path, or None if the path
/// belongs to the upper trie.
fn lower_subtrie_for_path(&mut self, path: &Nibbles) -> Option<&mut SparseSubtrie> {
match SparseSubtrieType::from_path(path) {
SparseSubtrieType::Upper => None,
SparseSubtrieType::Lower(idx) => {
if self.lower_subtries[idx].is_none() {
let upper_path = path.slice(..2);
self.lower_subtries[idx] = Some(SparseSubtrie::new(upper_path));
}
self.lower_subtries[idx].as_mut()
}
}
}
/// Creates a new revealed sparse trie from the given root node.
///
/// # Returns
@@ -58,7 +78,6 @@ impl ParallelSparseTrie {
/// It handles different node types (leaf, extension, branch) by appropriately
/// adding them to the trie structure and recursively revealing their children.
///
///
/// # Returns
///
/// `Ok(())` if successful, or an error if node was not revealed.
@@ -68,10 +87,50 @@ impl ParallelSparseTrie {
node: TrieNode,
masks: TrieMasks,
) -> SparseTrieResult<()> {
let _path = path;
let _node = node;
let _masks = masks;
todo!()
// TODO parallelize
if let Some(subtrie) = self.lower_subtrie_for_path(&path) {
return subtrie.reveal_node(path, &node, masks);
}
// If there is no subtrie for the path it means the path is 2 or less nibbles, and so
// belongs to the upper trie.
self.upper_subtrie.reveal_node(path.clone(), &node, masks)?;
// The previous upper_trie.reveal_node call will not have revealed any child nodes via
// reveal_node_or_hash if the child node would be found on a lower subtrie. We handle that
// here by manually checking the specific cases where this could happen, and calling
// reveal_node_or_hash for each.
match node {
TrieNode::Branch(branch) => {
// If a branch is at the second level of the trie then it will be in the upper trie,
// but all of its children will be in the lower trie.
if path.len() == 2 {
let mut stack_ptr = branch.as_ref().first_child_index();
for idx in CHILD_INDEX_RANGE {
if branch.state_mask.is_bit_set(idx) {
let mut child_path = path.clone();
child_path.push_unchecked(idx);
self.lower_subtrie_for_path(&child_path)
.expect("child_path must have a lower subtrie")
.reveal_node_or_hash(child_path, &branch.stack[stack_ptr])?;
stack_ptr += 1;
}
}
}
}
TrieNode::Extension(ext) => {
let mut child_path = path.clone();
child_path.extend_from_slice_unchecked(&ext.key);
if child_path.len() > 2 {
self.lower_subtrie_for_path(&child_path)
.expect("child_path must have a lower subtrie")
.reveal_node_or_hash(child_path, &ext.child)?;
}
}
TrieNode::EmptyRoot | TrieNode::Leaf(_) => (),
}
Ok(())
}
/// Updates or inserts a leaf node at the specified key path with the provided RLP-encoded
@@ -219,11 +278,213 @@ pub struct SparseSubtrie {
}
impl SparseSubtrie {
/// Creates a new sparse subtrie with the given root path.
pub fn new(path: Nibbles) -> Self {
fn new(path: Nibbles) -> Self {
Self { path, ..Default::default() }
}
/// Returns true if the current path and its child are both found in the same level. This
/// function assumes that if `current_path` is in a lower level then `child_path` is too.
fn is_child_same_level(current_path: &Nibbles, child_path: &Nibbles) -> bool {
let current_level = core::mem::discriminant(&SparseSubtrieType::from_path(current_path));
let child_level = core::mem::discriminant(&SparseSubtrieType::from_path(child_path));
current_level == child_level
}
/// Internal implementation of the method of the same name on `ParallelSparseTrie`.
fn reveal_node(
&mut self,
path: Nibbles,
node: &TrieNode,
masks: TrieMasks,
) -> SparseTrieResult<()> {
// If the node is already revealed and it's not a hash node, do nothing.
if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) {
return Ok(())
}
if let Some(tree_mask) = masks.tree_mask {
self.branch_node_tree_masks.insert(path.clone(), tree_mask);
}
if let Some(hash_mask) = masks.hash_mask {
self.branch_node_hash_masks.insert(path.clone(), hash_mask);
}
match node {
TrieNode::EmptyRoot => {
// For an empty root, ensure that we are at the root path, and at the upper subtrie.
debug_assert!(path.is_empty());
debug_assert!(self.path.is_empty());
self.nodes.insert(path, SparseNode::Empty);
}
TrieNode::Branch(branch) => {
// For a branch node, iterate over all potential children
let mut stack_ptr = branch.as_ref().first_child_index();
for idx in CHILD_INDEX_RANGE {
if branch.state_mask.is_bit_set(idx) {
let mut child_path = path.clone();
child_path.push_unchecked(idx);
if Self::is_child_same_level(&path, &child_path) {
// Reveal each child node or hash it has, but only if the child is on
// the same level as the parent.
self.reveal_node_or_hash(child_path, &branch.stack[stack_ptr])?;
}
stack_ptr += 1;
}
}
// Update the branch node entry in the nodes map, handling cases where a blinded
// node is now replaced with a revealed node.
match self.nodes.entry(path) {
Entry::Occupied(mut entry) => match entry.get() {
// Replace a hash node with a fully revealed branch node.
SparseNode::Hash(hash) => {
entry.insert(SparseNode::Branch {
state_mask: branch.state_mask,
// Memoize the hash of a previously blinded node in a new branch
// node.
hash: Some(*hash),
store_in_db_trie: Some(
masks.hash_mask.is_some_and(|mask| !mask.is_empty()) ||
masks.tree_mask.is_some_and(|mask| !mask.is_empty()),
),
});
}
// Branch node already exists, or an extension node was placed where a
// branch node was before.
SparseNode::Branch { .. } | SparseNode::Extension { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: entry.key().clone(),
node: Box::new(node.clone()),
}
.into())
}
},
Entry::Vacant(entry) => {
entry.insert(SparseNode::new_branch(branch.state_mask));
}
}
}
TrieNode::Extension(ext) => match self.nodes.entry(path.clone()) {
Entry::Occupied(mut entry) => match entry.get() {
// Replace a hash node with a revealed extension node.
SparseNode::Hash(hash) => {
let mut child_path = entry.key().clone();
child_path.extend_from_slice_unchecked(&ext.key);
entry.insert(SparseNode::Extension {
key: ext.key.clone(),
// Memoize the hash of a previously blinded node in a new extension
// node.
hash: Some(*hash),
store_in_db_trie: None,
});
if Self::is_child_same_level(&path, &child_path) {
self.reveal_node_or_hash(child_path, &ext.child)?;
}
}
// Extension node already exists, or an extension node was placed where a branch
// node was before.
SparseNode::Extension { .. } | SparseNode::Branch { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: entry.key().clone(),
node: Box::new(node.clone()),
}
.into())
}
},
Entry::Vacant(entry) => {
let mut child_path = entry.key().clone();
child_path.extend_from_slice_unchecked(&ext.key);
entry.insert(SparseNode::new_ext(ext.key.clone()));
if Self::is_child_same_level(&path, &child_path) {
self.reveal_node_or_hash(child_path, &ext.child)?;
}
}
},
TrieNode::Leaf(leaf) => match self.nodes.entry(path) {
Entry::Occupied(mut entry) => match entry.get() {
// Replace a hash node with a revealed leaf node and store leaf node value.
SparseNode::Hash(hash) => {
let mut full = entry.key().clone();
full.extend_from_slice_unchecked(&leaf.key);
self.values.insert(full, leaf.value.clone());
entry.insert(SparseNode::Leaf {
key: leaf.key.clone(),
// Memoize the hash of a previously blinded node in a new leaf
// node.
hash: Some(*hash),
});
}
// Leaf node already exists.
SparseNode::Leaf { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty |
SparseNode::Extension { .. } |
SparseNode::Branch { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: entry.key().clone(),
node: Box::new(node.clone()),
}
.into())
}
},
Entry::Vacant(entry) => {
let mut full = entry.key().clone();
full.extend_from_slice_unchecked(&leaf.key);
entry.insert(SparseNode::new_leaf(leaf.key.clone()));
self.values.insert(full, leaf.value.clone());
}
},
}
Ok(())
}
/// Reveals either a node or its hash placeholder based on the provided child data.
///
/// When traversing the trie, we often encounter references to child nodes that
/// are either directly embedded or represented by their hash. This method
/// handles both cases:
///
/// 1. If the child data represents a hash (32+1=33 bytes), store it as a hash node
/// 2. Otherwise, decode the data as a [`TrieNode`] and recursively reveal it using
/// `reveal_node`
///
/// # Returns
///
/// Returns `Ok(())` if successful, or an error if the node cannot be revealed.
///
/// # Error Handling
///
/// Will error if there's a conflict between a new hash node and an existing one
/// at the same path
fn reveal_node_or_hash(&mut self, path: Nibbles, child: &[u8]) -> SparseTrieResult<()> {
if child.len() == B256::len_bytes() + 1 {
let hash = B256::from_slice(&child[1..]);
match self.nodes.entry(path) {
Entry::Occupied(entry) => match entry.get() {
// Hash node with a different hash can't be handled.
SparseNode::Hash(previous_hash) if previous_hash != &hash => {
return Err(SparseTrieErrorKind::Reveal {
path: entry.key().clone(),
node: Box::new(SparseNode::Hash(hash)),
}
.into())
}
_ => {}
},
Entry::Vacant(entry) => {
entry.insert(SparseNode::Hash(hash));
}
}
return Ok(())
}
self.reveal_node(path, &TrieNode::decode(&mut &child[..])?, TrieMasks::none())
}
/// Recalculates and updates the RLP hashes for the changed nodes in this subtrie.
pub fn update_hashes(&mut self, prefix_set: &mut PrefixSet) -> SparseTrieResult<()> {
trace!(target: "trie::parallel_sparse", path=?self.path, "Updating subtrie hashes");
@@ -271,13 +532,57 @@ fn path_subtrie_index_unchecked(path: &Nibbles) -> usize {
#[cfg(test)]
mod tests {
use alloy_trie::Nibbles;
use reth_trie_common::prefix_set::{PrefixSet, PrefixSetMut};
use crate::{
parallel_trie::{path_subtrie_index_unchecked, SparseSubtrieType},
ParallelSparseTrie, SparseSubtrie,
ParallelSparseTrie, SparseNode, SparseSubtrie, TrieMasks,
};
use alloy_primitives::B256;
use alloy_rlp::Encodable;
use alloy_trie::Nibbles;
use assert_matches::assert_matches;
use reth_primitives_traits::Account;
use reth_trie_common::{
prefix_set::{PrefixSet, PrefixSetMut},
BranchNode, ExtensionNode, LeafNode, RlpNode, TrieMask, TrieNode, EMPTY_ROOT_HASH,
};
// Test helpers
fn encode_account_value(nonce: u64) -> Vec<u8> {
let account = Account { nonce, ..Default::default() };
let trie_account = account.into_trie_account(EMPTY_ROOT_HASH);
let mut buf = Vec::new();
trie_account.encode(&mut buf);
buf
}
fn create_leaf_node(key: &[u8], value_nonce: u64) -> TrieNode {
TrieNode::Leaf(LeafNode::new(
Nibbles::from_nibbles_unchecked(key),
encode_account_value(value_nonce),
))
}
fn create_extension_node(key: &[u8], child_hash: B256) -> TrieNode {
TrieNode::Extension(ExtensionNode::new(
Nibbles::from_nibbles_unchecked(key),
RlpNode::word_rlp(&child_hash),
))
}
fn create_branch_node_with_children(
children_indices: &[u8],
child_hashes: &[B256],
) -> TrieNode {
let mut stack = Vec::new();
let mut state_mask = 0u16;
for (&idx, &hash) in children_indices.iter().zip(child_hashes.iter()) {
state_mask |= 1 << idx;
stack.push(RlpNode::word_rlp(&hash));
}
TrieNode::Branch(BranchNode::new(stack, TrieMask::new(state_mask)))
}
#[test]
fn test_get_changed_subtries_empty() {
@@ -405,4 +710,159 @@ mod tests {
SparseSubtrieType::Lower(255)
);
}
#[test]
fn reveal_node_leaves() {
let mut trie = ParallelSparseTrie::default();
// Reveal leaf in the upper trie
{
let path = Nibbles::from_nibbles([0x1, 0x2]);
let node = create_leaf_node(&[0x3, 0x4], 42);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Leaf { key, hash: None })
if key == &Nibbles::from_nibbles([0x3, 0x4])
);
let full_path = Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4]);
assert_eq!(trie.upper_subtrie.values.get(&full_path), Some(&encode_account_value(42)));
}
// Reveal leaf in a lower trie
{
let path = Nibbles::from_nibbles([0x1, 0x2, 0x3]);
let node = create_leaf_node(&[0x4, 0x5], 42);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
// Check that the lower subtrie was created
let idx = path_subtrie_index_unchecked(&path);
assert!(trie.lower_subtries[idx].is_some());
let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
assert_matches!(
lower_subtrie.nodes.get(&path),
Some(SparseNode::Leaf { key, hash: None })
if key == &Nibbles::from_nibbles([0x4, 0x5])
);
}
}
#[test]
fn reveal_node_extension_all_upper() {
let mut trie = ParallelSparseTrie::default();
let path = Nibbles::from_nibbles([0x1]);
let child_hash = B256::repeat_byte(0xab);
let node = create_extension_node(&[0x2], child_hash);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Extension { key, hash: None, .. })
if key == &Nibbles::from_nibbles([0x2])
);
// Child path should be in upper trie
let child_path = Nibbles::from_nibbles([0x1, 0x2]);
assert_eq!(trie.upper_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash)));
}
#[test]
fn reveal_node_extension_cross_level() {
let mut trie = ParallelSparseTrie::default();
let path = Nibbles::from_nibbles([0x1, 0x2]);
let child_hash = B256::repeat_byte(0xcd);
let node = create_extension_node(&[0x3], child_hash);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
// Extension node should be in upper trie
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Extension { key, hash: None, .. })
if key == &Nibbles::from_nibbles([0x3])
);
// Child path (0x1, 0x2, 0x3) should be in lower trie
let child_path = Nibbles::from_nibbles([0x1, 0x2, 0x3]);
let idx = path_subtrie_index_unchecked(&child_path);
assert!(trie.lower_subtries[idx].is_some());
let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
assert_eq!(lower_subtrie.nodes.get(&child_path), Some(&SparseNode::Hash(child_hash)));
}
#[test]
fn reveal_node_branch_all_upper() {
let mut trie = ParallelSparseTrie::default();
let path = Nibbles::from_nibbles([0x1]);
let child_hashes = [B256::repeat_byte(0x11), B256::repeat_byte(0x22)];
let node = create_branch_node_with_children(&[0x0, 0x5], &child_hashes);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
// Branch node should be in upper trie
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Branch { state_mask, hash: None, .. })
if *state_mask == 0b0000000000100001.into()
);
// Children should be in upper trie (paths of length 2)
let child_path_0 = Nibbles::from_nibbles([0x1, 0x0]);
let child_path_5 = Nibbles::from_nibbles([0x1, 0x5]);
assert_eq!(
trie.upper_subtrie.nodes.get(&child_path_0),
Some(&SparseNode::Hash(child_hashes[0]))
);
assert_eq!(
trie.upper_subtrie.nodes.get(&child_path_5),
Some(&SparseNode::Hash(child_hashes[1]))
);
}
#[test]
fn reveal_node_branch_cross_level() {
let mut trie = ParallelSparseTrie::default();
let path = Nibbles::from_nibbles([0x1, 0x2]); // Exactly 2 nibbles - boundary case
let child_hashes =
[B256::repeat_byte(0x33), B256::repeat_byte(0x44), B256::repeat_byte(0x55)];
let node = create_branch_node_with_children(&[0x0, 0x7, 0xf], &child_hashes);
let masks = TrieMasks::none();
trie.reveal_node(path.clone(), node, masks).unwrap();
// Branch node should be in upper trie
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Branch { state_mask, hash: None, .. })
if *state_mask == 0b1000000010000001.into()
);
// All children should be in lower tries since they have paths of length 3
let child_paths = [
Nibbles::from_nibbles([0x1, 0x2, 0x0]),
Nibbles::from_nibbles([0x1, 0x2, 0x7]),
Nibbles::from_nibbles([0x1, 0x2, 0xf]),
];
for (i, child_path) in child_paths.iter().enumerate() {
let idx = path_subtrie_index_unchecked(child_path);
let lower_subtrie = trie.lower_subtries[idx].as_ref().unwrap();
assert_eq!(
lower_subtrie.nodes.get(child_path),
Some(&SparseNode::Hash(child_hashes[i])),
);
}
}
}