perf(trie): flatten sparse trie branch node masks to reduce overhead (#20664)

This commit is contained in:
YK
2025-12-30 11:38:24 +08:00
committed by GitHub
parent f7c77e72a7
commit 0f585f892e
2 changed files with 100 additions and 119 deletions

View File

@@ -9,8 +9,8 @@ use alloy_trie::{BranchNodeCompact, TrieMask, EMPTY_ROOT_HASH};
use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult};
use reth_trie_common::{
prefix_set::{PrefixSet, PrefixSetMut},
BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, TrieMasks,
TrieNode, CHILD_INDEX_RANGE,
BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles,
ProofTrieNode, RlpNode, TrieMasks, TrieNode, CHILD_INDEX_RANGE,
};
use reth_trie_sparse::{
provider::{RevealedNode, TrieNodeProvider},
@@ -112,10 +112,12 @@ pub struct ParallelSparseTrie {
prefix_set: PrefixSetMut,
/// Optional tracking of trie updates for later use.
updates: Option<SparseTrieUpdates>,
/// When a bit is set, the corresponding child subtree is stored in the database.
branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
/// When a bit is set, the corresponding child is stored as a hash in the database.
branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// Branch node masks containing `tree_mask` and `hash_mask` for each path.
/// - `tree_mask`: When a bit is set, the corresponding child subtree is stored in the
/// database.
/// - `hash_mask`: When a bit is set, the corresponding child is stored as a hash in the
/// database.
branch_node_masks: BranchNodeMasksMap,
/// Reusable buffer pool used for collecting [`SparseTrieUpdatesAction`]s during hash
/// computations.
update_actions_buffers: Vec<Vec<SparseTrieUpdatesAction>>,
@@ -136,8 +138,7 @@ impl Default for ParallelSparseTrie {
lower_subtries: [const { LowerSparseSubtrie::Blind(None) }; NUM_LOWER_SUBTRIES],
prefix_set: PrefixSetMut::default(),
updates: None,
branch_node_tree_masks: HashMap::default(),
branch_node_hash_masks: HashMap::default(),
branch_node_masks: BranchNodeMasksMap::default(),
update_actions_buffers: Vec::default(),
parallelism_thresholds: Default::default(),
#[cfg(feature = "metrics")]
@@ -187,11 +188,14 @@ impl SparseTrieInterface for ParallelSparseTrie {
// Update the top-level branch node masks. This is simple and can't be done in parallel.
for ProofTrieNode { path, masks, .. } in &nodes {
if let Some(tree_mask) = masks.tree_mask {
self.branch_node_tree_masks.insert(*path, tree_mask);
}
if let Some(hash_mask) = masks.hash_mask {
self.branch_node_hash_masks.insert(*path, hash_mask);
if masks.tree_mask.is_some() || masks.hash_mask.is_some() {
self.branch_node_masks.insert(
*path,
BranchNodeMasks {
tree_mask: masks.tree_mask.unwrap_or_default(),
hash_mask: masks.hash_mask.unwrap_or_default(),
},
);
}
}
@@ -719,8 +723,7 @@ impl SparseTrieInterface for ParallelSparseTrie {
changed_subtrie.subtrie.update_hashes(
&mut changed_subtrie.prefix_set,
&mut changed_subtrie.update_actions_buf,
&self.branch_node_tree_masks,
&self.branch_node_hash_masks,
&self.branch_node_masks,
);
}
@@ -736,8 +739,7 @@ impl SparseTrieInterface for ParallelSparseTrie {
{
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let branch_node_tree_masks = &self.branch_node_tree_masks;
let branch_node_hash_masks = &self.branch_node_hash_masks;
let branch_node_masks = &self.branch_node_masks;
let updated_subtries: Vec<_> = changed_subtries
.into_par_iter()
.map(|mut changed_subtrie| {
@@ -746,8 +748,7 @@ impl SparseTrieInterface for ParallelSparseTrie {
changed_subtrie.subtrie.update_hashes(
&mut changed_subtrie.prefix_set,
&mut changed_subtrie.update_actions_buf,
branch_node_tree_masks,
branch_node_hash_masks,
branch_node_masks,
);
#[cfg(feature = "metrics")]
self.metrics.subtrie_hash_update_latency.record(start.elapsed());
@@ -786,8 +787,7 @@ impl SparseTrieInterface for ParallelSparseTrie {
}
self.prefix_set.clear();
self.updates = None;
self.branch_node_tree_masks.clear();
self.branch_node_hash_masks.clear();
self.branch_node_masks.clear();
// `update_actions_buffers` doesn't need to be cleared; we want to reuse the Vecs it has
// buffered, and all of those are already inherently cleared when they get used.
}
@@ -870,9 +870,8 @@ impl SparseTrieInterface for ParallelSparseTrie {
subtrie.shrink_nodes_to(size_per_subtrie);
}
// shrink masks maps
self.branch_node_hash_masks.shrink_to(size);
self.branch_node_tree_masks.shrink_to(size);
// shrink masks map
self.branch_node_masks.shrink_to(size);
}
fn shrink_values_to(&mut self, size: usize) {
@@ -1377,8 +1376,7 @@ impl ParallelSparseTrie {
&mut update_actions_buf,
stack_item,
node,
&self.branch_node_tree_masks,
&self.branch_node_hash_masks,
&self.branch_node_masks,
);
}
@@ -2047,8 +2045,7 @@ impl SparseSubtrie {
/// - `update_actions`: A buffer which `SparseTrieUpdatesAction`s will be written to in the
/// event that any changes to the top-level updates are required. If None then update
/// retention is disabled.
/// - `branch_node_tree_masks`: The tree masks for branch nodes
/// - `branch_node_hash_masks`: The hash masks for branch nodes
/// - `branch_node_masks`: The tree and hash masks for branch nodes.
///
/// # Returns
///
@@ -2062,8 +2059,7 @@ impl SparseSubtrie {
&mut self,
prefix_set: &mut PrefixSet,
update_actions: &mut Option<Vec<SparseTrieUpdatesAction>>,
branch_node_tree_masks: &HashMap<Nibbles, TrieMask>,
branch_node_hash_masks: &HashMap<Nibbles, TrieMask>,
branch_node_masks: &BranchNodeMasksMap,
) -> RlpNode {
trace!(target: "trie::parallel_sparse", "Updating subtrie hashes");
@@ -2082,14 +2078,7 @@ impl SparseSubtrie {
.get_mut(&path)
.unwrap_or_else(|| panic!("node at path {path:?} does not exist"));
self.inner.rlp_node(
prefix_set,
update_actions,
stack_item,
node,
branch_node_tree_masks,
branch_node_hash_masks,
);
self.inner.rlp_node(prefix_set, update_actions, stack_item, node, branch_node_masks);
}
debug_assert_eq!(self.inner.buffers.rlp_node_stack.len(), 1);
@@ -2149,8 +2138,7 @@ impl SparseSubtrieInner {
/// retention is disabled.
/// - `stack_item`: The stack item to process
/// - `node`: The sparse node to process (will be mutated to update hash)
/// - `branch_node_tree_masks`: The tree masks for branch nodes
/// - `branch_node_hash_masks`: The hash masks for branch nodes
/// - `branch_node_masks`: The tree and hash masks for branch nodes.
///
/// # Side Effects
///
@@ -2168,8 +2156,7 @@ impl SparseSubtrieInner {
update_actions: &mut Option<Vec<SparseTrieUpdatesAction>>,
mut stack_item: RlpNodePathStackItem,
node: &mut SparseNode,
branch_node_tree_masks: &HashMap<Nibbles, TrieMask>,
branch_node_hash_masks: &HashMap<Nibbles, TrieMask>,
branch_node_masks: &BranchNodeMasksMap,
) {
let path = stack_item.path;
trace!(
@@ -2303,6 +2290,12 @@ impl SparseSubtrieInner {
let mut tree_mask = TrieMask::default();
let mut hash_mask = TrieMask::default();
let mut hashes = Vec::new();
// Lazy lookup for branch node masks - shared across loop iterations
let mut path_masks_storage = None;
let mut path_masks =
|| *path_masks_storage.get_or_insert_with(|| branch_node_masks.get(&path));
for (i, child_path) in self.buffers.branch_child_buf.iter().enumerate() {
if self.buffers.rlp_node_stack.last().is_some_and(|e| &e.path == child_path) {
let RlpNodeStackItem {
@@ -2326,9 +2319,9 @@ impl SparseSubtrieInner {
} else {
// A blinded node has the tree mask bit set
child_node_type.is_hash() &&
branch_node_tree_masks
.get(&path)
.is_some_and(|mask| mask.is_bit_set(last_child_nibble))
path_masks().is_some_and(|masks| {
masks.tree_mask.is_bit_set(last_child_nibble)
})
};
if should_set_tree_mask_bit {
tree_mask.set_bit(last_child_nibble);
@@ -2340,9 +2333,9 @@ impl SparseSubtrieInner {
let hash = child.as_hash().filter(|_| {
child_node_type.is_branch() ||
(child_node_type.is_hash() &&
branch_node_hash_masks.get(&path).is_some_and(
|mask| mask.is_bit_set(last_child_nibble),
))
path_masks().is_some_and(|masks| {
masks.hash_mask.is_bit_set(last_child_nibble)
}))
});
if let Some(hash) = hash {
hash_mask.set_bit(last_child_nibble);
@@ -2409,19 +2402,17 @@ impl SparseSubtrieInner {
);
update_actions
.push(SparseTrieUpdatesAction::InsertUpdated(path, branch_node));
} else if branch_node_tree_masks.get(&path).is_some_and(|mask| !mask.is_empty()) ||
branch_node_hash_masks.get(&path).is_some_and(|mask| !mask.is_empty())
{
// If new tree and hash masks are empty, but previously they weren't, we
// need to remove the node update and add the node itself to the list of
// removed nodes.
update_actions.push(SparseTrieUpdatesAction::InsertRemoved(path));
} else if branch_node_tree_masks.get(&path).is_none_or(|mask| mask.is_empty()) &&
branch_node_hash_masks.get(&path).is_none_or(|mask| mask.is_empty())
{
// If new tree and hash masks are empty, and they were previously empty
// as well, we need to remove the node update.
update_actions.push(SparseTrieUpdatesAction::RemoveUpdated(path));
} else {
// New tree and hash masks are empty - check previous state
let prev_had_masks = path_masks()
.is_some_and(|m| !m.tree_mask.is_empty() || !m.hash_mask.is_empty());
if prev_had_masks {
// Previously had masks, now empty - mark as removed
update_actions.push(SparseTrieUpdatesAction::InsertRemoved(path));
} else {
// Previously empty too - just remove the update
update_actions.push(SparseTrieUpdatesAction::RemoveUpdated(path));
}
}
store_in_db_trie
@@ -2667,8 +2658,8 @@ mod tests {
prefix_set::PrefixSetMut,
proof::{ProofNodes, ProofRetainer},
updates::TrieUpdates,
BranchNode, ExtensionNode, HashBuilder, LeafNode, ProofTrieNode, RlpNode, TrieMask,
TrieMasks, TrieNode, EMPTY_ROOT_HASH,
BranchNode, BranchNodeMasksMap, ExtensionNode, HashBuilder, LeafNode, ProofTrieNode,
RlpNode, TrieMask, TrieMasks, TrieNode, EMPTY_ROOT_HASH,
};
use reth_trie_db::DatabaseTrieCursorFactory;
use reth_trie_sparse::{
@@ -3608,8 +3599,7 @@ mod tests {
&mut PrefixSetMut::from([leaf_1_full_path, leaf_2_full_path, leaf_3_full_path])
.freeze(),
&mut None,
&HashMap::default(),
&HashMap::default(),
&BranchNodeMasksMap::default(),
);
// Compare hashes between hash builder and subtrie

View File

@@ -19,8 +19,9 @@ use alloy_rlp::Decodable;
use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult};
use reth_trie_common::{
prefix_set::{PrefixSet, PrefixSetMut},
BranchNodeCompact, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, ProofTrieNode,
RlpNode, TrieMask, TrieMasks, TrieNode, CHILD_INDEX_RANGE, EMPTY_ROOT_HASH,
BranchNodeCompact, BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef,
LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, TrieMask, TrieMasks, TrieNode, CHILD_INDEX_RANGE,
EMPTY_ROOT_HASH,
};
use smallvec::SmallVec;
use tracing::{debug, instrument, trace};
@@ -298,10 +299,12 @@ pub struct SerialSparseTrie {
/// Map from a path (nibbles) to its corresponding sparse trie node.
/// This contains all of the revealed nodes in trie.
nodes: HashMap<Nibbles, SparseNode>,
/// When a branch is set, the corresponding child subtree is stored in the database.
branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
/// When a bit is set, the corresponding child is stored as a hash in the database.
branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// Branch node masks containing `tree_mask` and `hash_mask` for each path.
/// - `tree_mask`: When a bit is set, the corresponding child subtree is stored in the
/// database.
/// - `hash_mask`: When a bit is set, the corresponding child is stored as a hash in the
/// database.
branch_node_masks: BranchNodeMasksMap,
/// Map from leaf key paths to their values.
/// All values are stored here instead of directly in leaf nodes.
values: HashMap<Nibbles, Vec<u8>>,
@@ -318,8 +321,7 @@ impl fmt::Debug for SerialSparseTrie {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SerialSparseTrie")
.field("nodes", &self.nodes)
.field("branch_tree_masks", &self.branch_node_tree_masks)
.field("branch_hash_masks", &self.branch_node_hash_masks)
.field("branch_node_masks", &self.branch_node_masks)
.field("values", &self.values)
.field("prefix_set", &self.prefix_set)
.field("updates", &self.updates)
@@ -404,8 +406,7 @@ impl Default for SerialSparseTrie {
fn default() -> Self {
Self {
nodes: HashMap::from_iter([(Nibbles::default(), SparseNode::Empty)]),
branch_node_tree_masks: HashMap::default(),
branch_node_hash_masks: HashMap::default(),
branch_node_masks: BranchNodeMasksMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
updates: None,
@@ -456,11 +457,14 @@ impl SparseTrieInterface for SerialSparseTrie {
return Ok(())
}
if let Some(tree_mask) = masks.tree_mask {
self.branch_node_tree_masks.insert(path, tree_mask);
}
if let Some(hash_mask) = masks.hash_mask {
self.branch_node_hash_masks.insert(path, hash_mask);
if masks.tree_mask.is_some() || masks.hash_mask.is_some() {
self.branch_node_masks.insert(
path,
BranchNodeMasks {
tree_mask: masks.tree_mask.unwrap_or_default(),
hash_mask: masks.hash_mask.unwrap_or_default(),
},
);
}
match node {
@@ -959,8 +963,7 @@ impl SparseTrieInterface for SerialSparseTrie {
self.nodes.clear();
self.nodes.insert(Nibbles::default(), SparseNode::Empty);
self.branch_node_tree_masks.clear();
self.branch_node_hash_masks.clear();
self.branch_node_masks.clear();
self.values.clear();
self.prefix_set.clear();
self.updates = None;
@@ -1087,8 +1090,7 @@ impl SparseTrieInterface for SerialSparseTrie {
fn shrink_nodes_to(&mut self, size: usize) {
self.nodes.shrink_to(size);
self.branch_node_tree_masks.shrink_to(size);
self.branch_node_hash_masks.shrink_to(size);
self.branch_node_masks.shrink_to(size);
}
fn shrink_values_to(&mut self, size: usize) {
@@ -1624,6 +1626,13 @@ impl SerialSparseTrie {
let mut tree_mask = TrieMask::default();
let mut hash_mask = TrieMask::default();
let mut hashes = Vec::new();
// Lazy lookup for branch node masks - shared across loop iterations
let mut path_masks_storage = None;
let mut path_masks = || {
*path_masks_storage.get_or_insert_with(|| self.branch_node_masks.get(&path))
};
for (i, child_path) in buffers.branch_child_buf.iter().enumerate() {
if buffers.rlp_node_stack.last().is_some_and(|e| &e.path == child_path) {
let RlpNodeStackItem {
@@ -1647,9 +1656,9 @@ impl SerialSparseTrie {
} else {
// A blinded node has the tree mask bit set
child_node_type.is_hash() &&
self.branch_node_tree_masks.get(&path).is_some_and(
|mask| mask.is_bit_set(last_child_nibble),
)
path_masks().is_some_and(|masks| {
masks.tree_mask.is_bit_set(last_child_nibble)
})
};
if should_set_tree_mask_bit {
tree_mask.set_bit(last_child_nibble);
@@ -1661,11 +1670,9 @@ impl SerialSparseTrie {
let hash = child.as_hash().filter(|_| {
child_node_type.is_branch() ||
(child_node_type.is_hash() &&
self.branch_node_hash_masks
.get(&path)
.is_some_and(|mask| {
mask.is_bit_set(last_child_nibble)
}))
path_masks().is_some_and(|masks| {
masks.hash_mask.is_bit_set(last_child_nibble)
}))
});
if let Some(hash) = hash {
hash_mask.set_bit(last_child_nibble);
@@ -1729,30 +1736,16 @@ impl SerialSparseTrie {
hash.filter(|_| path.is_empty()),
);
updates.updated_nodes.insert(path, branch_node);
} else if self
.branch_node_tree_masks
.get(&path)
.is_some_and(|mask| !mask.is_empty()) ||
self.branch_node_hash_masks
.get(&path)
.is_some_and(|mask| !mask.is_empty())
{
// If new tree and hash masks are empty, but previously they weren't, we
// need to remove the node update and add the node itself to the list of
// removed nodes.
updates.updated_nodes.remove(&path);
updates.removed_nodes.insert(path);
} else if self
.branch_node_tree_masks
.get(&path)
.is_none_or(|mask| mask.is_empty()) &&
self.branch_node_hash_masks
.get(&path)
.is_none_or(|mask| mask.is_empty())
{
// If new tree and hash masks are empty, and they were previously empty
// as well, we need to remove the node update.
} else {
// New tree and hash masks are empty - check previous state
let prev_had_masks = path_masks().is_some_and(|m| {
!m.tree_mask.is_empty() || !m.hash_mask.is_empty()
});
updates.updated_nodes.remove(&path);
if prev_had_masks {
// Previously had masks, now empty - mark as removed
updates.removed_nodes.insert(path);
}
}
store_in_db_trie
@@ -2223,8 +2216,7 @@ mod find_leaf_tests {
let sparse = SerialSparseTrie {
nodes,
branch_node_tree_masks: Default::default(),
branch_node_hash_masks: Default::default(),
branch_node_masks: Default::default(),
/* The value is not in the values map, or else it would early return */
values: Default::default(),
prefix_set: Default::default(),
@@ -2266,8 +2258,7 @@ mod find_leaf_tests {
let sparse = SerialSparseTrie {
nodes,
branch_node_tree_masks: Default::default(),
branch_node_hash_masks: Default::default(),
branch_node_masks: Default::default(),
values,
prefix_set: Default::default(),
updates: None,