Mirror of https://github.com/paradigmxyz/reth.git (synced 2026-02-19 03:04:27 -05:00)

feat(trie): add prune method to SparseTrieInterface (#21427)

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>
@@ -15,7 +15,7 @@ use reth_trie_common::{
use reth_trie_sparse::{
    provider::{RevealedNode, TrieNodeProvider},
    LeafLookup, LeafLookupError, RlpNodeStackItem, SparseNode, SparseNodeType, SparseTrie,
-    SparseTrieUpdates,
+    SparseTrieExt, SparseTrieUpdates,
};
use smallvec::SmallVec;
use std::cmp::{Ord, Ordering, PartialOrd};
@@ -908,6 +908,162 @@ impl SparseTrie for ParallelSparseTrie {
    }
}

impl SparseTrieExt for ParallelSparseTrie {
    /// Returns the count of revealed (non-hash) nodes across all subtries.
    fn revealed_node_count(&self) -> usize {
        let upper_count = self.upper_subtrie.nodes.values().filter(|n| !n.is_hash()).count();

        let lower_count: usize = self
            .lower_subtries
            .iter()
            .filter_map(|s| s.as_revealed_ref())
            .map(|s| s.nodes.values().filter(|n| !n.is_hash()).count())
            .sum();

        upper_count + lower_count
    }

    fn prune(&mut self, max_depth: usize) -> usize {
        // DFS traversal to find nodes at max_depth that can be pruned.
        // Collects "effective pruned roots" - children of nodes at max_depth with computed hashes.
        // We replace nodes with Hash stubs inline during traversal.
        let mut effective_pruned_roots = Vec::<(Nibbles, B256)>::new();
        let mut stack: SmallVec<[(Nibbles, usize); 32]> = SmallVec::new();
        stack.push((Nibbles::default(), 0));

        // DFS traversal: pop path and depth, skip if subtrie or node not found.
        while let Some((path, depth)) = stack.pop() {
            // Get children to visit from current node (immutable access)
            let children: SmallVec<[Nibbles; 16]> = {
                let Some(subtrie) = self.subtrie_for_path(&path) else { continue };
                let Some(node) = subtrie.nodes.get(&path) else { continue };

                match node {
                    SparseNode::Empty | SparseNode::Hash(_) | SparseNode::Leaf { .. } => {
                        SmallVec::new()
                    }
                    SparseNode::Extension { key, .. } => {
                        let mut child = path;
                        child.extend(key);
                        SmallVec::from_buf_and_len([child; 16], 1)
                    }
                    SparseNode::Branch { state_mask, .. } => {
                        let mut children = SmallVec::new();
                        let mut mask = state_mask.get();
                        while mask != 0 {
                            let nibble = mask.trailing_zeros() as u8;
                            mask &= mask - 1;
                            let mut child = path;
                            child.push_unchecked(nibble);
                            children.push(child);
                        }
                        children
                    }
                }
            };

            // Process children - either continue traversal or prune
            for child in children {
                if depth == max_depth {
                    // Check if child has a computed hash and replace inline
                    let hash = self
                        .subtrie_for_path(&child)
                        .and_then(|s| s.nodes.get(&child))
                        .filter(|n| !n.is_hash())
                        .and_then(|n| n.hash());

                    if let Some(hash) = hash {
                        self.subtrie_for_path_mut(&child)
                            .nodes
                            .insert(child, SparseNode::Hash(hash));
                        effective_pruned_roots.push((child, hash));
                    }
                } else {
                    stack.push((child, depth + 1));
                }
            }
        }

        if effective_pruned_roots.is_empty() {
            return 0;
        }

        let nodes_converted = effective_pruned_roots.len();

        // Sort roots by subtrie type (upper first), then by path for efficient partitioning.
        effective_pruned_roots.sort_unstable_by(|(path_a, _), (path_b, _)| {
            let subtrie_type_a = SparseSubtrieType::from_path(path_a);
            let subtrie_type_b = SparseSubtrieType::from_path(path_b);
            subtrie_type_a.cmp(&subtrie_type_b).then(path_a.cmp(path_b))
        });

        // Split off upper subtrie roots (they come first due to sorting)
        let num_upper_roots = effective_pruned_roots
            .iter()
            .position(|(p, _)| !SparseSubtrieType::path_len_is_upper(p.len()))
            .unwrap_or(effective_pruned_roots.len());

        let roots_upper = &effective_pruned_roots[..num_upper_roots];
        let roots_lower = &effective_pruned_roots[num_upper_roots..];

        debug_assert!(
            {
                let mut all_roots: Vec<_> = effective_pruned_roots.iter().map(|(p, _)| p).collect();
                all_roots.sort_unstable();
                all_roots.windows(2).all(|w| !w[1].starts_with(w[0]))
            },
            "prune roots must be prefix-free"
        );

        // Upper prune roots that are prefixes of lower subtrie root paths cause the entire
        // subtrie to be cleared (preserving allocations for reuse).
        if !roots_upper.is_empty() {
            for subtrie in &mut self.lower_subtries {
                let should_clear = subtrie.as_revealed_ref().is_some_and(|s| {
                    let search_idx = roots_upper.partition_point(|(root, _)| root <= &s.path);
                    search_idx > 0 && s.path.starts_with(&roots_upper[search_idx - 1].0)
                });
                if should_clear {
                    subtrie.clear();
                }
            }
        }

        // Upper subtrie: prune nodes and values
        self.upper_subtrie.nodes.retain(|p, _| !is_strict_descendant_in(roots_upper, p));
        self.upper_subtrie.inner.values.retain(|p, _| {
            !starts_with_pruned_in(roots_upper, p) && !starts_with_pruned_in(roots_lower, p)
        });

        // Process lower subtries using chunk_by to group roots by subtrie
        for roots_group in roots_lower.chunk_by(|(path_a, _), (path_b, _)| {
            SparseSubtrieType::from_path(path_a) == SparseSubtrieType::from_path(path_b)
        }) {
            let subtrie_idx = path_subtrie_index_unchecked(&roots_group[0].0);

            // Skip unrevealed/blinded subtries - nothing to prune
            let Some(subtrie) = self.lower_subtries[subtrie_idx].as_revealed_mut() else {
                continue;
            };

            // Retain only nodes/values not descended from any pruned root.
            subtrie.nodes.retain(|p, _| !is_strict_descendant_in(roots_group, p));
            subtrie.inner.values.retain(|p, _| !starts_with_pruned_in(roots_group, p));
        }

        // Branch node masks pruning
        self.branch_node_masks.retain(|p, _| {
            if SparseSubtrieType::path_len_is_upper(p.len()) {
                !starts_with_pruned_in(roots_upper, p)
            } else {
                !starts_with_pruned_in(roots_lower, p) && !starts_with_pruned_in(roots_upper, p)
            }
        });

        nodes_converted
    }
}

impl ParallelSparseTrie {
    /// Sets the thresholds that control when parallelism is used during operations.
    pub const fn with_parallelism_thresholds(mut self, thresholds: ParallelismThresholds) -> Self {
@@ -2654,6 +2810,44 @@ fn path_subtrie_index_unchecked(path: &Nibbles) -> usize {
    path.get_byte_unchecked(0) as usize
}

/// Checks if `path` is a strict descendant of any root in a sorted slice.
///
/// Uses binary search to find the candidate root that could be an ancestor.
/// Returns `true` if `path` starts with a root and is longer (strict descendant).
fn is_strict_descendant_in(roots: &[(Nibbles, B256)], path: &Nibbles) -> bool {
    if roots.is_empty() {
        return false;
    }
    debug_assert!(roots.windows(2).all(|w| w[0].0 <= w[1].0), "roots must be sorted by path");
    let idx = roots.partition_point(|(root, _)| root <= path);
    if idx > 0 {
        let candidate = &roots[idx - 1].0;
        if path.starts_with(candidate) && path.len() > candidate.len() {
            return true;
        }
    }
    false
}

/// Checks if `path` starts with any root in a sorted slice (inclusive).
///
/// Uses binary search to find the candidate root that could be a prefix.
/// Returns `true` if `path` starts with a root (including exact match).
fn starts_with_pruned_in(roots: &[(Nibbles, B256)], path: &Nibbles) -> bool {
    if roots.is_empty() {
        return false;
    }
    debug_assert!(roots.windows(2).all(|w| w[0].0 <= w[1].0), "roots must be sorted by path");
    let idx = roots.partition_point(|(root, _)| root <= path);
    if idx > 0 {
        let candidate = &roots[idx - 1].0;
        if path.starts_with(candidate) {
            return true;
        }
    }
    false
}
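
// Illustrative sketch (not part of the diff): how the two helpers above differ, using the same
// `Nibbles`/`B256` types. The concrete values are hypothetical.
//
//     let roots = [(Nibbles::from_nibbles([0x1, 0x2]), B256::ZERO)];
//     // Strict descendant: must start with a root and be strictly longer.
//     assert!(is_strict_descendant_in(&roots, &Nibbles::from_nibbles([0x1, 0x2, 0x3])));
//     assert!(!is_strict_descendant_in(&roots, &Nibbles::from_nibbles([0x1, 0x2])));
//     // Inclusive prefix check: an exact match with a pruned root also counts.
//     assert!(starts_with_pruned_in(&roots, &Nibbles::from_nibbles([0x1, 0x2])));
//     assert!(!starts_with_pruned_in(&roots, &Nibbles::from_nibbles([0x1, 0x3])));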

/// Used by lower subtries to communicate updates to the top-level [`SparseTrieUpdates`] set.
#[derive(Clone, Debug, Eq, PartialEq)]
enum SparseTrieUpdatesAction {
@@ -2704,7 +2898,8 @@ mod tests {
    use reth_trie_db::DatabaseTrieCursorFactory;
    use reth_trie_sparse::{
        provider::{DefaultTrieNodeProvider, RevealedNode, TrieNodeProvider},
-        LeafLookup, LeafLookupError, SerialSparseTrie, SparseNode, SparseTrie, SparseTrieUpdates,
+        LeafLookup, LeafLookupError, SerialSparseTrie, SparseNode, SparseTrie, SparseTrieExt,
+        SparseTrieUpdates,
    };
    use std::collections::{BTreeMap, BTreeSet};

@@ -2749,6 +2944,17 @@ mod tests {
        Account { nonce, ..Default::default() }
    }

    fn large_account_value() -> Vec<u8> {
        let account = Account {
            nonce: 0x123456789abcdef,
            balance: U256::from(0x123456789abcdef0123456789abcdef_u128),
            ..Default::default()
        };
        let mut buf = Vec::new();
        account.into_trie_account(EMPTY_ROOT_HASH).encode(&mut buf);
        buf
    }

    fn encode_account_value(nonce: u64) -> Vec<u8> {
        let account = Account { nonce, ..Default::default() };
        let trie_account = account.into_trie_account(EMPTY_ROOT_HASH);
@@ -7106,4 +7312,372 @@ mod tests {
        // Value should be retrievable
        assert_eq!(trie.get_leaf_value(&slot_path), Some(&slot_value));
    }

    #[test]
    fn test_prune_empty_suffix_key_regression() {
        // Regression test: when a leaf has an empty suffix key (full path == node path),
        // the value must be removed when that path becomes a pruned root.
        // This catches the bug where a strict-descendant check alone would fail to remove a
        // value stored at exactly the pruned root path (p == pruned_root).

        use reth_trie_sparse::provider::DefaultTrieNodeProvider;

        let provider = DefaultTrieNodeProvider;
        let mut parallel = ParallelSparseTrie::default();

        // Large value to ensure nodes have hashes (RLP >= 32 bytes)
        let value = {
            let account = Account {
                nonce: 0x123456789abcdef,
                balance: U256::from(0x123456789abcdef0123456789abcdef_u128),
                ..Default::default()
            };
            let mut buf = Vec::new();
            account.into_trie_account(EMPTY_ROOT_HASH).encode(&mut buf);
            buf
        };

        // Create a trie with multiple leaves to force a branch at root
        for i in 0..16u8 {
            parallel
                .update_leaf(
                    Nibbles::from_nibbles([i, 0x1, 0x2, 0x3, 0x4, 0x5]),
                    value.clone(),
                    &provider,
                )
                .unwrap();
        }

        // Compute root to get hashes
        let root_before = parallel.root();

        // Prune at depth 0: the children of root become pruned roots
        parallel.prune(0);

        let root_after = parallel.root();
        assert_eq!(root_before, root_after, "root hash must be preserved");

        // Key assertion: values under pruned paths must be removed.
        // With the bug, values at pruned_root paths (not strict descendants) would remain.
        for i in 0..16u8 {
            let path = Nibbles::from_nibbles([i, 0x1, 0x2, 0x3, 0x4, 0x5]);
            assert!(
                parallel.get_leaf_value(&path).is_none(),
                "value at {:?} should be removed after prune",
                path
            );
        }
    }

    #[test]
    fn test_prune_at_various_depths() {
        for max_depth in [0, 1, 2] {
            let provider = DefaultTrieNodeProvider;
            let mut trie = ParallelSparseTrie::default();

            let value = large_account_value();

            for i in 0..4u8 {
                for j in 0..4u8 {
                    for k in 0..4u8 {
                        trie.update_leaf(
                            Nibbles::from_nibbles([i, j, k, 0x1, 0x2, 0x3]),
                            value.clone(),
                            &provider,
                        )
                        .unwrap();
                    }
                }
            }

            let root_before = trie.root();
            let nodes_before = trie.revealed_node_count();

            trie.prune(max_depth);

            let root_after = trie.root();
            assert_eq!(root_before, root_after, "root hash should be preserved after prune");

            let nodes_after = trie.revealed_node_count();
            assert!(
                nodes_after < nodes_before,
                "node count should decrease after prune at depth {max_depth}"
            );

            if max_depth == 0 {
                assert_eq!(nodes_after, 1, "only root should be revealed after prune(0)");
            }
        }
    }

    #[test]
    fn test_prune_empty_trie() {
        let mut trie = ParallelSparseTrie::default();
        trie.prune(2);
        let root = trie.root();
        assert_eq!(root, EMPTY_ROOT_HASH, "empty trie should have empty root hash");
    }

    #[test]
    fn test_prune_preserves_root_hash() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();

        for i in 0..8u8 {
            for j in 0..4u8 {
                trie.update_leaf(
                    Nibbles::from_nibbles([i, j, 0x3, 0x4, 0x5, 0x6]),
                    value.clone(),
                    &provider,
                )
                .unwrap();
            }
        }

        let root_before = trie.root();
        trie.prune(1);
        let root_after = trie.root();
        assert_eq!(root_before, root_after, "root hash must be preserved after prune");
    }

    #[test]
    fn test_prune_single_leaf_trie() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();
        trie.update_leaf(Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4]), value, &provider).unwrap();

        let root_before = trie.root();
        let nodes_before = trie.revealed_node_count();

        trie.prune(0);

        let root_after = trie.root();
        assert_eq!(root_before, root_after, "root hash should be preserved");
        assert_eq!(trie.revealed_node_count(), nodes_before, "single leaf trie should not change");
    }

    #[test]
    fn test_prune_deep_depth_no_effect() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();

        for i in 0..4u8 {
            trie.update_leaf(Nibbles::from_nibbles([i, 0x2, 0x3, 0x4]), value.clone(), &provider)
                .unwrap();
        }

        trie.root();
        let nodes_before = trie.revealed_node_count();

        trie.prune(100);

        assert_eq!(nodes_before, trie.revealed_node_count(), "deep prune should have no effect");
    }

    #[test]
    fn test_prune_extension_node_depth_semantics() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();

        trie.update_leaf(Nibbles::from_nibbles([0, 1, 2, 3, 0, 5, 6, 7]), value.clone(), &provider)
            .unwrap();
        trie.update_leaf(Nibbles::from_nibbles([0, 1, 2, 3, 1, 5, 6, 7]), value, &provider)
            .unwrap();

        let root_before = trie.root();
        trie.prune(1);

        assert_eq!(root_before, trie.root(), "root hash should be preserved");
        assert_eq!(trie.revealed_node_count(), 2, "should have root + extension after prune(1)");
    }

    #[test]
    fn test_prune_embedded_node_preserved() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let small_value = vec![0x80];
        trie.update_leaf(Nibbles::from_nibbles([0x0]), small_value.clone(), &provider).unwrap();
        trie.update_leaf(Nibbles::from_nibbles([0x1]), small_value, &provider).unwrap();

        let root_before = trie.root();
        let nodes_before = trie.revealed_node_count();

        trie.prune(0);

        assert_eq!(root_before, trie.root(), "root hash must be preserved");

        if trie.revealed_node_count() == nodes_before {
            assert!(trie.get_leaf_value(&Nibbles::from_nibbles([0x0])).is_some());
            assert!(trie.get_leaf_value(&Nibbles::from_nibbles([0x1])).is_some());
        }
    }

    #[test]
    fn test_prune_mixed_embedded_and_hashed() {
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let large_value = large_account_value();
        let small_value = vec![0x80];

        for i in 0..8u8 {
            let value = if i < 4 { large_value.clone() } else { small_value.clone() };
            trie.update_leaf(Nibbles::from_nibbles([i, 0x1, 0x2, 0x3]), value, &provider).unwrap();
        }

        let root_before = trie.root();
        trie.prune(0);
        assert_eq!(root_before, trie.root(), "root hash must be preserved");
    }

    #[test]
    fn test_prune_many_lower_subtries() {
        let provider = DefaultTrieNodeProvider;

        let large_value = large_account_value();

        let mut keys = Vec::new();
        for first in 0..16u8 {
            for second in 0..16u8 {
                keys.push(Nibbles::from_nibbles([first, second, 0x1, 0x2, 0x3, 0x4]));
            }
        }

        let mut trie = ParallelSparseTrie::default();

        for key in &keys {
            trie.update_leaf(*key, large_value.clone(), &provider).unwrap();
        }

        let root_before = trie.root();
        let pruned = trie.prune(1);

        assert!(pruned > 0, "should have pruned some nodes");
        assert_eq!(root_before, trie.root(), "root hash should be preserved");

        for key in &keys {
            assert!(trie.get_leaf_value(key).is_none(), "value should be pruned");
        }
    }

    #[test]
    #[ignore = "profiling test - run manually"]
    fn test_prune_profile() {
        use std::time::Instant;

        let provider = DefaultTrieNodeProvider;
        let large_value = large_account_value();

        // Generate 65536 keys (16^4) for a large trie
        let mut keys = Vec::with_capacity(65536);
        for a in 0..16u8 {
            for b in 0..16u8 {
                for c in 0..16u8 {
                    for d in 0..16u8 {
                        keys.push(Nibbles::from_nibbles([a, b, c, d, 0x5, 0x6, 0x7, 0x8]));
                    }
                }
            }
        }

        // Build base trie once
        let mut base_trie = ParallelSparseTrie::default();
        for key in &keys {
            base_trie.update_leaf(*key, large_value.clone(), &provider).unwrap();
        }
        base_trie.root(); // ensure hashes computed

        // Pre-clone tries to exclude clone time from profiling
        let iterations = 100;
        let mut tries: Vec<_> = (0..iterations).map(|_| base_trie.clone()).collect();

        // Measure only prune()
        let mut total_pruned = 0;
        let start = Instant::now();
        for trie in &mut tries {
            total_pruned += trie.prune(2);
        }
        let elapsed = start.elapsed();

        println!(
            "Prune benchmark: {} iterations, total: {:?}, avg: {:?}, pruned/iter: {}",
            iterations,
            elapsed,
            elapsed / iterations as u32,
            total_pruned / iterations
        );
    }

    #[test]
    fn test_prune_max_depth_overflow() {
        // Verify that max_depth > 255 is not truncated (was u8, now usize)
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();

        for i in 0..4u8 {
            trie.update_leaf(Nibbles::from_nibbles([i, 0x1, 0x2, 0x3]), value.clone(), &provider)
                .unwrap();
        }

        trie.root();
        let nodes_before = trie.revealed_node_count();

        // If depth were truncated to u8, 300 would become 44 and might prune something
        trie.prune(300);

        assert_eq!(
            nodes_before,
            trie.revealed_node_count(),
            "prune(300) should have no effect on a shallow trie"
        );
    }

    #[test]
    fn test_prune_fast_path_case2_update_after() {
        // Test fast-path Case 2: upper prune root is prefix of lower subtrie.
        // After pruning, we should be able to update leaves without panic.
        let provider = DefaultTrieNodeProvider;
        let mut trie = ParallelSparseTrie::default();

        let value = large_account_value();

        // Create keys that span into lower subtries (path.len() >= UPPER_TRIE_MAX_DEPTH)
        // UPPER_TRIE_MAX_DEPTH is typically 2, so paths of length 3+ go to lower subtries
        for first in 0..4u8 {
            for second in 0..4u8 {
                trie.update_leaf(
                    Nibbles::from_nibbles([first, second, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]),
                    value.clone(),
                    &provider,
                )
                .unwrap();
            }
        }

        let root_before = trie.root();

        // Prune at depth 0 - upper roots become prefixes of lower subtrie paths
        trie.prune(0);

        let root_after = trie.root();
        assert_eq!(root_before, root_after, "root hash should be preserved");

        // Now try to update a leaf - this should not panic even though lower subtries
        // were replaced with Blind(None)
        let new_path = Nibbles::from_nibbles([0x5, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]);
        trie.update_leaf(new_path, value, &provider).unwrap();

        // The trie should still be functional
        let _ = trie.root();
    }
}

@@ -5,6 +5,12 @@

extern crate alloc;

/// Default depth to which sparse tries are pruned for cross-payload caching.
pub const DEFAULT_SPARSE_TRIE_PRUNE_DEPTH: usize = 4;

/// Default number of storage tries to preserve across payload validations.
pub const DEFAULT_MAX_PRESERVED_STORAGE_TRIES: usize = 100;
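
// Rough intuition for the depth default (a back-of-the-envelope note, not part of the diff):
// pruning to depth 4 keeps revealed nodes at depths 0..=4, so with a branching factor of at
// most 16 a single trie retains at most 1 + 16 + 256 + 4096 + 65536 = 69,905 revealed nodes;
// everything deeper is collapsed into hash stubs.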

mod state;
pub use state::*;

@@ -1,6 +1,6 @@
use crate::{
    provider::{TrieNodeProvider, TrieNodeProviderFactory},
-    traits::SparseTrie as SparseTrieTrait,
+    traits::{SparseTrie as SparseTrieTrait, SparseTrieExt},
    RevealableSparseTrie, SerialSparseTrie,
};
use alloc::{collections::VecDeque, vec::Vec};
@@ -972,6 +972,117 @@ where
    }
}

impl<A, S> SparseStateTrie<A, S>
where
    A: SparseTrieTrait + SparseTrieExt + Default,
    S: SparseTrieTrait + SparseTrieExt + Default + Clone,
{
    /// Minimum number of storage tries before parallel pruning is enabled.
    const PARALLEL_PRUNE_THRESHOLD: usize = 16;

    /// Returns true if parallelism should be enabled for pruning the given number of tries.
    /// Will always return false in `no_std` builds.
    const fn is_prune_parallelism_enabled(num_tries: usize) -> bool {
        #[cfg(not(feature = "std"))]
        return false;

        num_tries >= Self::PARALLEL_PRUNE_THRESHOLD
    }

    /// Prunes the account trie and selected storage tries to reduce memory usage.
    ///
    /// Storage tries not in the top `max_storage_tries` by revealed node count are cleared
    /// entirely.
    ///
    /// # Preconditions
    ///
    /// Node hashes must be computed via `root()` before calling this method. Otherwise, nodes
    /// cannot be converted to hash stubs and pruning will have no effect.
    ///
    /// # Effects
    ///
    /// - Clears `revealed_account_paths` and `revealed_paths` for all storage tries
    pub fn prune(&mut self, max_depth: usize, max_storage_tries: usize) {
        if let Some(trie) = self.state.as_revealed_mut() {
            trie.prune(max_depth);
        }
        self.revealed_account_paths.clear();

        let mut storage_trie_counts: Vec<(B256, usize)> = self
            .storage
            .tries
            .iter()
            .map(|(hash, trie)| {
                let count = match trie {
                    RevealableSparseTrie::Revealed(t) => t.revealed_node_count(),
                    RevealableSparseTrie::Blind(_) => 0,
                };
                (*hash, count)
            })
            .collect();

        // Use O(n) selection instead of O(n log n) sort
        let tries_to_keep: HashSet<B256> = if storage_trie_counts.len() <= max_storage_tries {
            storage_trie_counts.iter().map(|(hash, _)| *hash).collect()
        } else {
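            // select_nth_unstable_by with a descending-count comparator is an O(n) partial
            // sort: after the call, the `max_storage_tries` entries with the highest revealed
            // node counts occupy the front of the vector (in arbitrary order among themselves).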
            storage_trie_counts
                .select_nth_unstable_by(max_storage_tries.saturating_sub(1), |a, b| b.1.cmp(&a.1));
            storage_trie_counts[..max_storage_tries].iter().map(|(hash, _)| *hash).collect()
        };

        // Collect keys to avoid borrow conflict
        let tries_to_clear: Vec<B256> = self
            .storage
            .tries
            .keys()
            .filter(|hash| !tries_to_keep.contains(*hash))
            .copied()
            .collect();

        // Evict storage tries that exceeded the limit, saving cleared allocations for reuse
        for hash in tries_to_clear {
            if let Some(trie) = self.storage.tries.remove(&hash) {
                self.storage.cleared_tries.push(trie.clear());
            }
            if let Some(mut paths) = self.storage.revealed_paths.remove(&hash) {
                paths.clear();
                self.storage.cleared_revealed_paths.push(paths);
            }
        }

        // Prune storage tries that are kept
        if Self::is_prune_parallelism_enabled(tries_to_keep.len()) {
            #[cfg(feature = "std")]
            {
                use rayon::prelude::*;

                self.storage.tries.par_iter_mut().for_each(|(hash, trie)| {
                    if tries_to_keep.contains(hash) &&
                        let Some(t) = trie.as_revealed_mut()
                    {
                        t.prune(max_depth);
                    }
                });
            }
        } else {
            for hash in &tries_to_keep {
                if let Some(trie) =
                    self.storage.tries.get_mut(hash).and_then(|t| t.as_revealed_mut())
                {
                    trie.prune(max_depth);
                }
            }
        }

        // Clear revealed_paths for kept tries
        for hash in &tries_to_keep {
            if let Some(paths) = self.storage.revealed_paths.get_mut(hash) {
                paths.clear();
            }
        }
    }
}
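
// Illustrative usage sketch (not part of the diff): pruning a cached state trie between
// payload validations, using the defaults added to the crate's lib.rs in this commit. The
// variable name and error handling are hypothetical.
//
//     let _ = sparse_state_trie.root(); // compute hashes first, otherwise prune() is a no-op
//     sparse_state_trie
//         .prune(DEFAULT_SPARSE_TRIE_PRUNE_DEPTH, DEFAULT_MAX_PRESERVED_STORAGE_TRIES);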

/// The fields of [`SparseStateTrie`] related to storage tries. This is kept separate from the rest
/// of [`SparseStateTrie`] both to help enforce allocation re-use and to allow us to implement
/// methods like `get_trie_and_revealed_paths` which return multiple mutable borrows.
@@ -1260,7 +1371,7 @@ mod tests {
    use reth_trie::{updates::StorageTrieUpdates, HashBuilder, MultiProof, EMPTY_ROOT_HASH};
    use reth_trie_common::{
        proof::{ProofNodes, ProofRetainer},
-        BranchNode, BranchNodeMasks, LeafNode, StorageMultiProof, TrieMask,
+        BranchNode, BranchNodeMasks, BranchNodeMasksMap, LeafNode, StorageMultiProof, TrieMask,
    };

    #[test]

@@ -232,6 +232,36 @@ pub trait SparseTrie: Sized + Debug + Send + Sync {
    fn shrink_values_to(&mut self, size: usize);
}

/// Extension trait for sparse tries that support pruning.
///
/// This trait provides the `prune` method for sparse trie implementations that support
/// converting nodes beyond a certain depth into hash stubs. This is useful for reducing
/// memory usage when caching tries across payload validations.
pub trait SparseTrieExt: SparseTrie {
    /// Returns the number of revealed (non-Hash) nodes in the trie.
    fn revealed_node_count(&self) -> usize;

    /// Replaces nodes beyond `max_depth` with hash stubs and removes their descendants.
    ///
    /// Depth counts nodes traversed (not nibbles), so an extension node counts as one level of
    /// depth regardless of key length. `max_depth == 0` prunes all children of the root node.
    ///
    /// # Preconditions
    ///
    /// Must be called after `root()` to ensure all nodes have computed hashes.
    /// Calling on a trie without computed hashes will result in no pruning.
    ///
    /// # Behavior
    ///
    /// - Embedded nodes (RLP < 32 bytes) are preserved since they have no hash
    /// - Returns 0 if `max_depth` exceeds trie depth or trie is empty
    ///
    /// # Returns
    ///
    /// The number of nodes converted to hash stubs.
    fn prune(&mut self, max_depth: usize) -> usize;
}
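
// Illustrative sketch (not part of the diff): a caller generic over `SparseTrieExt` might
// shrink a cached trie like this; the helper name is hypothetical.
//
//     fn shrink_for_cache<T: SparseTrieExt>(trie: &mut T, max_depth: usize) -> usize {
//         // Precondition from the docs above: `root()` must already have been called,
//         // otherwise nodes carry no hashes and nothing is pruned.
//         let converted_to_stubs = trie.prune(max_depth);
//         // `revealed_node_count()` now reflects only nodes at depth <= max_depth.
//         converted_to_stubs
//     }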

/// Tracks modifications to the sparse trie structure.
///
/// Maintains references to both modified and pruned/removed branches, enabling