From 747c0169a77076712574d32f91658811cdf5f089 Mon Sep 17 00:00:00 2001 From: YK Date: Wed, 28 Jan 2026 21:55:21 +0800 Subject: [PATCH] feat(trie): add prune method to SparseTrieInterface (#21427) Co-authored-by: Amp Co-authored-by: Georgios Konstantopoulos --- crates/trie/sparse-parallel/src/trie.rs | 578 +++++++++++++++++++++++- crates/trie/sparse/src/lib.rs | 6 + crates/trie/sparse/src/state.rs | 115 ++++- crates/trie/sparse/src/traits.rs | 30 ++ 4 files changed, 725 insertions(+), 4 deletions(-) diff --git a/crates/trie/sparse-parallel/src/trie.rs b/crates/trie/sparse-parallel/src/trie.rs index 7b55b66fd4..d6a1bfec93 100644 --- a/crates/trie/sparse-parallel/src/trie.rs +++ b/crates/trie/sparse-parallel/src/trie.rs @@ -15,7 +15,7 @@ use reth_trie_common::{ use reth_trie_sparse::{ provider::{RevealedNode, TrieNodeProvider}, LeafLookup, LeafLookupError, RlpNodeStackItem, SparseNode, SparseNodeType, SparseTrie, - SparseTrieUpdates, + SparseTrieExt, SparseTrieUpdates, }; use smallvec::SmallVec; use std::cmp::{Ord, Ordering, PartialOrd}; @@ -908,6 +908,162 @@ impl SparseTrie for ParallelSparseTrie { } } +impl SparseTrieExt for ParallelSparseTrie { + /// Returns the count of revealed (non-hash) nodes across all subtries. + fn revealed_node_count(&self) -> usize { + let upper_count = self.upper_subtrie.nodes.values().filter(|n| !n.is_hash()).count(); + + let lower_count: usize = self + .lower_subtries + .iter() + .filter_map(|s| s.as_revealed_ref()) + .map(|s| s.nodes.values().filter(|n| !n.is_hash()).count()) + .sum(); + + upper_count + lower_count + } + + fn prune(&mut self, max_depth: usize) -> usize { + // DFS traversal to find nodes at max_depth that can be pruned. + // Collects "effective pruned roots" - children of nodes at max_depth with computed hashes. + // We replace nodes with Hash stubs inline during traversal. + let mut effective_pruned_roots = Vec::<(Nibbles, B256)>::new(); + let mut stack: SmallVec<[(Nibbles, usize); 32]> = SmallVec::new(); + stack.push((Nibbles::default(), 0)); + + // DFS traversal: pop path and depth, skip if subtrie or node not found. + while let Some((path, depth)) = stack.pop() { + // Get children to visit from current node (immutable access) + let children: SmallVec<[Nibbles; 16]> = { + let Some(subtrie) = self.subtrie_for_path(&path) else { continue }; + let Some(node) = subtrie.nodes.get(&path) else { continue }; + + match node { + SparseNode::Empty | SparseNode::Hash(_) | SparseNode::Leaf { .. } => { + SmallVec::new() + } + SparseNode::Extension { key, .. } => { + let mut child = path; + child.extend(key); + SmallVec::from_buf_and_len([child; 16], 1) + } + SparseNode::Branch { state_mask, .. 
} => { + let mut children = SmallVec::new(); + let mut mask = state_mask.get(); + while mask != 0 { + let nibble = mask.trailing_zeros() as u8; + mask &= mask - 1; + let mut child = path; + child.push_unchecked(nibble); + children.push(child); + } + children + } + } + }; + + // Process children - either continue traversal or prune + for child in children { + if depth == max_depth { + // Check if child has a computed hash and replace inline + let hash = self + .subtrie_for_path(&child) + .and_then(|s| s.nodes.get(&child)) + .filter(|n| !n.is_hash()) + .and_then(|n| n.hash()); + + if let Some(hash) = hash { + self.subtrie_for_path_mut(&child) + .nodes + .insert(child, SparseNode::Hash(hash)); + effective_pruned_roots.push((child, hash)); + } + } else { + stack.push((child, depth + 1)); + } + } + } + + if effective_pruned_roots.is_empty() { + return 0; + } + + let nodes_converted = effective_pruned_roots.len(); + + // Sort roots by subtrie type (upper first), then by path for efficient partitioning. + effective_pruned_roots.sort_unstable_by(|(path_a, _), (path_b, _)| { + let subtrie_type_a = SparseSubtrieType::from_path(path_a); + let subtrie_type_b = SparseSubtrieType::from_path(path_b); + subtrie_type_a.cmp(&subtrie_type_b).then(path_a.cmp(path_b)) + }); + + // Split off upper subtrie roots (they come first due to sorting) + let num_upper_roots = effective_pruned_roots + .iter() + .position(|(p, _)| !SparseSubtrieType::path_len_is_upper(p.len())) + .unwrap_or(effective_pruned_roots.len()); + + let roots_upper = &effective_pruned_roots[..num_upper_roots]; + let roots_lower = &effective_pruned_roots[num_upper_roots..]; + + debug_assert!( + { + let mut all_roots: Vec<_> = effective_pruned_roots.iter().map(|(p, _)| p).collect(); + all_roots.sort_unstable(); + all_roots.windows(2).all(|w| !w[1].starts_with(w[0])) + }, + "prune roots must be prefix-free" + ); + + // Upper prune roots that are prefixes of lower subtrie root paths cause the entire + // subtrie to be cleared (preserving allocations for reuse). + if !roots_upper.is_empty() { + for subtrie in &mut self.lower_subtries { + let should_clear = subtrie.as_revealed_ref().is_some_and(|s| { + let search_idx = roots_upper.partition_point(|(root, _)| root <= &s.path); + search_idx > 0 && s.path.starts_with(&roots_upper[search_idx - 1].0) + }); + if should_clear { + subtrie.clear(); + } + } + } + + // Upper subtrie: prune nodes and values + self.upper_subtrie.nodes.retain(|p, _| !is_strict_descendant_in(roots_upper, p)); + self.upper_subtrie.inner.values.retain(|p, _| { + !starts_with_pruned_in(roots_upper, p) && !starts_with_pruned_in(roots_lower, p) + }); + + // Process lower subtries using chunk_by to group roots by subtrie + for roots_group in roots_lower.chunk_by(|(path_a, _), (path_b, _)| { + SparseSubtrieType::from_path(path_a) == SparseSubtrieType::from_path(path_b) + }) { + let subtrie_idx = path_subtrie_index_unchecked(&roots_group[0].0); + + // Skip unrevealed/blinded subtries - nothing to prune + let Some(subtrie) = self.lower_subtries[subtrie_idx].as_revealed_mut() else { + continue; + }; + + // Retain only nodes/values not descended from any pruned root. 
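+            // Strict vs. inclusive matching matters here: a pruned root itself must survive
+            // as its `Hash` stub in `nodes`, while a leaf value stored exactly at a pruned
+            // root path must be dropped from `values` along with all its descendants.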
+            subtrie.nodes.retain(|p, _| !is_strict_descendant_in(roots_group, p));
+            subtrie.inner.values.retain(|p, _| !starts_with_pruned_in(roots_group, p));
+        }
+
+        // Branch node masks pruning
+        self.branch_node_masks.retain(|p, _| {
+            if SparseSubtrieType::path_len_is_upper(p.len()) {
+                !starts_with_pruned_in(roots_upper, p)
+            } else {
+                !starts_with_pruned_in(roots_lower, p) && !starts_with_pruned_in(roots_upper, p)
+            }
+        });
+
+        nodes_converted
+    }
+}
+
 impl ParallelSparseTrie {
     /// Sets the thresholds that control when parallelism is used during operations.
     pub const fn with_parallelism_thresholds(mut self, thresholds: ParallelismThresholds) -> Self {
@@ -2654,6 +2810,44 @@ fn path_subtrie_index_unchecked(path: &Nibbles) -> usize {
     path.get_byte_unchecked(0) as usize
 }
 
+/// Checks if `path` is a strict descendant of any root in a sorted slice.
+///
+/// Uses binary search to find the candidate root that could be an ancestor.
+/// Returns `true` if `path` starts with a root and is longer (strict descendant).
+fn is_strict_descendant_in(roots: &[(Nibbles, B256)], path: &Nibbles) -> bool {
+    if roots.is_empty() {
+        return false;
+    }
+    debug_assert!(roots.windows(2).all(|w| w[0].0 <= w[1].0), "roots must be sorted by path");
+    let idx = roots.partition_point(|(root, _)| root <= path);
+    if idx > 0 {
+        let candidate = &roots[idx - 1].0;
+        if path.starts_with(candidate) && path.len() > candidate.len() {
+            return true;
+        }
+    }
+    false
+}
+
+/// Checks if `path` starts with any root in a sorted slice (inclusive).
+///
+/// Uses binary search to find the candidate root that could be a prefix.
+/// Returns `true` if `path` starts with a root (including exact match).
+fn starts_with_pruned_in(roots: &[(Nibbles, B256)], path: &Nibbles) -> bool {
+    if roots.is_empty() {
+        return false;
+    }
+    debug_assert!(roots.windows(2).all(|w| w[0].0 <= w[1].0), "roots must be sorted by path");
+    let idx = roots.partition_point(|(root, _)| root <= path);
+    if idx > 0 {
+        let candidate = &roots[idx - 1].0;
+        if path.starts_with(candidate) {
+            return true;
+        }
+    }
+    false
+}
+
 /// Used by lower subtries to communicate updates to the top-level [`SparseTrieUpdates`] set.
 #[derive(Clone, Debug, Eq, PartialEq)]
 enum SparseTrieUpdatesAction {
@@ -2704,7 +2898,8 @@ mod tests {
     use reth_trie_db::DatabaseTrieCursorFactory;
     use reth_trie_sparse::{
         provider::{DefaultTrieNodeProvider, RevealedNode, TrieNodeProvider},
-        LeafLookup, LeafLookupError, SerialSparseTrie, SparseNode, SparseTrie, SparseTrieUpdates,
+        LeafLookup, LeafLookupError, SerialSparseTrie, SparseNode, SparseTrie, SparseTrieExt,
+        SparseTrieUpdates,
     };
     use std::collections::{BTreeMap, BTreeSet};
@@ -2749,6 +2944,17 @@
         Account { nonce, ..Default::default() }
     }
 
+    fn large_account_value() -> Vec<u8> {
+        let account = Account {
+            nonce: 0x123456789abcdef,
+            balance: U256::from(0x123456789abcdef0123456789abcdef_u128),
+            ..Default::default()
+        };
+        let mut buf = Vec::new();
+        account.into_trie_account(EMPTY_ROOT_HASH).encode(&mut buf);
+        buf
+    }
+
     fn encode_account_value(nonce: u64) -> Vec<u8> {
         let account = Account { nonce, ..Default::default() };
         let trie_account = account.into_trie_account(EMPTY_ROOT_HASH);
@@ -7106,4 +7312,372 @@
         // Value should be retrievable
         assert_eq!(trie.get_leaf_value(&slot_path), Some(&slot_value));
     }
+
+    #[test]
+    fn test_prune_empty_suffix_key_regression() {
+        // Regression test: when a leaf has an empty suffix key (full path == node path),
+        // the value must be removed when that path becomes a pruned root.
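+        // ("Empty suffix key" means the leaf's full key equals its node path, so the
+        // value sits exactly at a pruned-root path rather than strictly below it.)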
+ // This catches the bug where is_strict_descendant fails to remove p == pruned_root. + + use reth_trie_sparse::provider::DefaultTrieNodeProvider; + + let provider = DefaultTrieNodeProvider; + let mut parallel = ParallelSparseTrie::default(); + + // Large value to ensure nodes have hashes (RLP >= 32 bytes) + let value = { + let account = Account { + nonce: 0x123456789abcdef, + balance: U256::from(0x123456789abcdef0123456789abcdef_u128), + ..Default::default() + }; + let mut buf = Vec::new(); + account.into_trie_account(EMPTY_ROOT_HASH).encode(&mut buf); + buf + }; + + // Create a trie with multiple leaves to force a branch at root + for i in 0..16u8 { + parallel + .update_leaf( + Nibbles::from_nibbles([i, 0x1, 0x2, 0x3, 0x4, 0x5]), + value.clone(), + &provider, + ) + .unwrap(); + } + + // Compute root to get hashes + let root_before = parallel.root(); + + // Prune at depth 0: the children of root become pruned roots + parallel.prune(0); + + let root_after = parallel.root(); + assert_eq!(root_before, root_after, "root hash must be preserved"); + + // Key assertion: values under pruned paths must be removed + // With the bug, values at pruned_root paths (not strict descendants) would remain + for i in 0..16u8 { + let path = Nibbles::from_nibbles([i, 0x1, 0x2, 0x3, 0x4, 0x5]); + assert!( + parallel.get_leaf_value(&path).is_none(), + "value at {:?} should be removed after prune", + path + ); + } + } + + #[test] + fn test_prune_at_various_depths() { + for max_depth in [0, 1, 2] { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + for i in 0..4u8 { + for j in 0..4u8 { + for k in 0..4u8 { + trie.update_leaf( + Nibbles::from_nibbles([i, j, k, 0x1, 0x2, 0x3]), + value.clone(), + &provider, + ) + .unwrap(); + } + } + } + + let root_before = trie.root(); + let nodes_before = trie.revealed_node_count(); + + trie.prune(max_depth); + + let root_after = trie.root(); + assert_eq!(root_before, root_after, "root hash should be preserved after prune"); + + let nodes_after = trie.revealed_node_count(); + assert!( + nodes_after < nodes_before, + "node count should decrease after prune at depth {max_depth}" + ); + + if max_depth == 0 { + assert_eq!(nodes_after, 1, "only root should be revealed after prune(0)"); + } + } + } + + #[test] + fn test_prune_empty_trie() { + let mut trie = ParallelSparseTrie::default(); + trie.prune(2); + let root = trie.root(); + assert_eq!(root, EMPTY_ROOT_HASH, "empty trie should have empty root hash"); + } + + #[test] + fn test_prune_preserves_root_hash() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + for i in 0..8u8 { + for j in 0..4u8 { + trie.update_leaf( + Nibbles::from_nibbles([i, j, 0x3, 0x4, 0x5, 0x6]), + value.clone(), + &provider, + ) + .unwrap(); + } + } + + let root_before = trie.root(); + trie.prune(1); + let root_after = trie.root(); + assert_eq!(root_before, root_after, "root hash must be preserved after prune"); + } + + #[test] + fn test_prune_single_leaf_trie() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + trie.update_leaf(Nibbles::from_nibbles([0x1, 0x2, 0x3, 0x4]), value, &provider).unwrap(); + + let root_before = trie.root(); + let nodes_before = trie.revealed_node_count(); + + trie.prune(0); + + let root_after = trie.root(); + assert_eq!(root_before, root_after, "root hash should be preserved"); + 
assert_eq!(trie.revealed_node_count(), nodes_before, "single leaf trie should not change"); + } + + #[test] + fn test_prune_deep_depth_no_effect() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + for i in 0..4u8 { + trie.update_leaf(Nibbles::from_nibbles([i, 0x2, 0x3, 0x4]), value.clone(), &provider) + .unwrap(); + } + + trie.root(); + let nodes_before = trie.revealed_node_count(); + + trie.prune(100); + + assert_eq!(nodes_before, trie.revealed_node_count(), "deep prune should have no effect"); + } + + #[test] + fn test_prune_extension_node_depth_semantics() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + trie.update_leaf(Nibbles::from_nibbles([0, 1, 2, 3, 0, 5, 6, 7]), value.clone(), &provider) + .unwrap(); + trie.update_leaf(Nibbles::from_nibbles([0, 1, 2, 3, 1, 5, 6, 7]), value, &provider) + .unwrap(); + + let root_before = trie.root(); + trie.prune(1); + + assert_eq!(root_before, trie.root(), "root hash should be preserved"); + assert_eq!(trie.revealed_node_count(), 2, "should have root + extension after prune(1)"); + } + + #[test] + fn test_prune_embedded_node_preserved() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let small_value = vec![0x80]; + trie.update_leaf(Nibbles::from_nibbles([0x0]), small_value.clone(), &provider).unwrap(); + trie.update_leaf(Nibbles::from_nibbles([0x1]), small_value, &provider).unwrap(); + + let root_before = trie.root(); + let nodes_before = trie.revealed_node_count(); + + trie.prune(0); + + assert_eq!(root_before, trie.root(), "root hash must be preserved"); + + if trie.revealed_node_count() == nodes_before { + assert!(trie.get_leaf_value(&Nibbles::from_nibbles([0x0])).is_some()); + assert!(trie.get_leaf_value(&Nibbles::from_nibbles([0x1])).is_some()); + } + } + + #[test] + fn test_prune_mixed_embedded_and_hashed() { + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let large_value = large_account_value(); + let small_value = vec![0x80]; + + for i in 0..8u8 { + let value = if i < 4 { large_value.clone() } else { small_value.clone() }; + trie.update_leaf(Nibbles::from_nibbles([i, 0x1, 0x2, 0x3]), value, &provider).unwrap(); + } + + let root_before = trie.root(); + trie.prune(0); + assert_eq!(root_before, trie.root(), "root hash must be preserved"); + } + + #[test] + fn test_prune_many_lower_subtries() { + let provider = DefaultTrieNodeProvider; + + let large_value = large_account_value(); + + let mut keys = Vec::new(); + for first in 0..16u8 { + for second in 0..16u8 { + keys.push(Nibbles::from_nibbles([first, second, 0x1, 0x2, 0x3, 0x4])); + } + } + + let mut trie = ParallelSparseTrie::default(); + + for key in &keys { + trie.update_leaf(*key, large_value.clone(), &provider).unwrap(); + } + + let root_before = trie.root(); + let pruned = trie.prune(1); + + assert!(pruned > 0, "should have pruned some nodes"); + assert_eq!(root_before, trie.root(), "root hash should be preserved"); + + for key in &keys { + assert!(trie.get_leaf_value(key).is_none(), "value should be pruned"); + } + } + + #[test] + #[ignore = "profiling test - run manually"] + fn test_prune_profile() { + use std::time::Instant; + + let provider = DefaultTrieNodeProvider; + let large_value = large_account_value(); + + // Generate 65536 keys (16^4) for a large trie + let mut keys = Vec::with_capacity(65536); + for a in 0..16u8 { 
+ for b in 0..16u8 { + for c in 0..16u8 { + for d in 0..16u8 { + keys.push(Nibbles::from_nibbles([a, b, c, d, 0x5, 0x6, 0x7, 0x8])); + } + } + } + } + + // Build base trie once + let mut base_trie = ParallelSparseTrie::default(); + for key in &keys { + base_trie.update_leaf(*key, large_value.clone(), &provider).unwrap(); + } + base_trie.root(); // ensure hashes computed + + // Pre-clone tries to exclude clone time from profiling + let iterations = 100; + let mut tries: Vec<_> = (0..iterations).map(|_| base_trie.clone()).collect(); + + // Measure only prune() + let mut total_pruned = 0; + let start = Instant::now(); + for trie in &mut tries { + total_pruned += trie.prune(2); + } + let elapsed = start.elapsed(); + + println!( + "Prune benchmark: {} iterations, total: {:?}, avg: {:?}, pruned/iter: {}", + iterations, + elapsed, + elapsed / iterations as u32, + total_pruned / iterations + ); + } + + #[test] + fn test_prune_max_depth_overflow() { + // Verify that max_depth > 255 is not truncated (was u8, now usize) + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + for i in 0..4u8 { + trie.update_leaf(Nibbles::from_nibbles([i, 0x1, 0x2, 0x3]), value.clone(), &provider) + .unwrap(); + } + + trie.root(); + let nodes_before = trie.revealed_node_count(); + + // If depth were truncated to u8, 300 would become 44 and might prune something + trie.prune(300); + + assert_eq!( + nodes_before, + trie.revealed_node_count(), + "prune(300) should have no effect on a shallow trie" + ); + } + + #[test] + fn test_prune_fast_path_case2_update_after() { + // Test fast-path Case 2: upper prune root is prefix of lower subtrie. + // After pruning, we should be able to update leaves without panic. + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + let value = large_account_value(); + + // Create keys that span into lower subtries (path.len() >= UPPER_TRIE_MAX_DEPTH) + // UPPER_TRIE_MAX_DEPTH is typically 2, so paths of length 3+ go to lower subtries + for first in 0..4u8 { + for second in 0..4u8 { + trie.update_leaf( + Nibbles::from_nibbles([first, second, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]), + value.clone(), + &provider, + ) + .unwrap(); + } + } + + let root_before = trie.root(); + + // Prune at depth 0 - upper roots become prefixes of lower subtrie paths + trie.prune(0); + + let root_after = trie.root(); + assert_eq!(root_before, root_after, "root hash should be preserved"); + + // Now try to update a leaf - this should not panic even though lower subtries + // were replaced with Blind(None) + let new_path = Nibbles::from_nibbles([0x5, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]); + trie.update_leaf(new_path, value, &provider).unwrap(); + + // The trie should still be functional + let _ = trie.root(); + } } diff --git a/crates/trie/sparse/src/lib.rs b/crates/trie/sparse/src/lib.rs index 6b17597048..d63027fde1 100644 --- a/crates/trie/sparse/src/lib.rs +++ b/crates/trie/sparse/src/lib.rs @@ -5,6 +5,12 @@ extern crate alloc; +/// Default depth to prune sparse tries to for cross-payload caching. +pub const DEFAULT_SPARSE_TRIE_PRUNE_DEPTH: usize = 4; + +/// Default number of storage tries to preserve across payload validations. 
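+///
+/// Storage tries beyond this limit, ranked by revealed node count, are cleared
+/// entirely by [`SparseStateTrie::prune`] rather than pruned to a depth.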
+pub const DEFAULT_MAX_PRESERVED_STORAGE_TRIES: usize = 100;
+
 mod state;
 pub use state::*;
 
diff --git a/crates/trie/sparse/src/state.rs b/crates/trie/sparse/src/state.rs
index 84f57cde78..1b032b8dda 100644
--- a/crates/trie/sparse/src/state.rs
+++ b/crates/trie/sparse/src/state.rs
@@ -1,6 +1,6 @@
 use crate::{
     provider::{TrieNodeProvider, TrieNodeProviderFactory},
-    traits::SparseTrie as SparseTrieTrait,
+    traits::{SparseTrie as SparseTrieTrait, SparseTrieExt},
     RevealableSparseTrie, SerialSparseTrie,
 };
 use alloc::{collections::VecDeque, vec::Vec};
@@ -972,6 +972,117 @@ where
     }
 }
 
+impl<A, S> SparseStateTrie<A, S>
+where
+    A: SparseTrieTrait + SparseTrieExt + Default,
+    S: SparseTrieTrait + SparseTrieExt + Default + Clone,
+{
+    /// Minimum number of storage tries before parallel pruning is enabled.
+    const PARALLEL_PRUNE_THRESHOLD: usize = 16;
+
+    /// Returns true if parallelism should be enabled for pruning the given number of tries.
+    /// Will always return false in `no_std` builds.
+    const fn is_prune_parallelism_enabled(num_tries: usize) -> bool {
+        #[cfg(not(feature = "std"))]
+        return false;
+
+        num_tries >= Self::PARALLEL_PRUNE_THRESHOLD
+    }
+
+    /// Prunes the account trie and selected storage tries to reduce memory usage.
+    ///
+    /// Storage tries not in the top `max_storage_tries` by revealed node count are cleared
+    /// entirely.
+    ///
+    /// # Preconditions
+    ///
+    /// Node hashes must be computed via `root()` before calling this method. Otherwise, nodes
+    /// cannot be converted to hash stubs and pruning will have no effect.
+    ///
+    /// # Effects
+    ///
+    /// - Clears `revealed_account_paths` and `revealed_paths` for all storage tries
+    pub fn prune(&mut self, max_depth: usize, max_storage_tries: usize) {
+        if let Some(trie) = self.state.as_revealed_mut() {
+            trie.prune(max_depth);
+        }
+        self.revealed_account_paths.clear();
+
+        let mut storage_trie_counts: Vec<(B256, usize)> = self
+            .storage
+            .tries
+            .iter()
+            .map(|(hash, trie)| {
+                let count = match trie {
+                    RevealableSparseTrie::Revealed(t) => t.revealed_node_count(),
+                    RevealableSparseTrie::Blind(_) => 0,
+                };
+                (*hash, count)
+            })
+            .collect();
+
+        // Use O(n) selection instead of O(n log n) sort
+        let tries_to_keep: HashSet<B256> = if storage_trie_counts.len() <= max_storage_tries {
+            storage_trie_counts.iter().map(|(hash, _)| *hash).collect()
+        } else {
+            storage_trie_counts
+                .select_nth_unstable_by(max_storage_tries.saturating_sub(1), |a, b| b.1.cmp(&a.1));
+            storage_trie_counts[..max_storage_tries].iter().map(|(hash, _)| *hash).collect()
+        };
+
+        // Collect keys to avoid borrow conflict
+        let tries_to_clear: Vec<B256> = self
+            .storage
+            .tries
+            .keys()
+            .filter(|hash| !tries_to_keep.contains(*hash))
+            .copied()
+            .collect();
+
+        // Evict storage tries that exceeded limit, saving cleared allocations for reuse
+        for hash in tries_to_clear {
+            if let Some(trie) = self.storage.tries.remove(&hash) {
+                self.storage.cleared_tries.push(trie.clear());
+            }
+            if let Some(mut paths) = self.storage.revealed_paths.remove(&hash) {
+                paths.clear();
+                self.storage.cleared_revealed_paths.push(paths);
+            }
+        }
+
+        // Prune storage tries that are kept
+        if Self::is_prune_parallelism_enabled(tries_to_keep.len()) {
+            #[cfg(feature = "std")]
+            {
+                use rayon::prelude::*;
+
+                self.storage.tries.par_iter_mut().for_each(|(hash, trie)| {
+                    if tries_to_keep.contains(hash) &&
+                        let Some(t) = trie.as_revealed_mut()
+                    {
+                        t.prune(max_depth);
+                    }
+                });
+            }
+        } else {
+            for hash in &tries_to_keep {
+                if let Some(trie) =
+                    self.storage.tries.get_mut(hash).and_then(|t|
t.as_revealed_mut()) + { + trie.prune(max_depth); + } + } + } + + // Clear revealed_paths for kept tries + for hash in &tries_to_keep { + if let Some(paths) = self.storage.revealed_paths.get_mut(hash) { + paths.clear(); + } + } + } +} + /// The fields of [`SparseStateTrie`] related to storage tries. This is kept separate from the rest /// of [`SparseStateTrie`] both to help enforce allocation re-use and to allow us to implement /// methods like `get_trie_and_revealed_paths` which return multiple mutable borrows. @@ -1260,7 +1371,7 @@ mod tests { use reth_trie::{updates::StorageTrieUpdates, HashBuilder, MultiProof, EMPTY_ROOT_HASH}; use reth_trie_common::{ proof::{ProofNodes, ProofRetainer}, - BranchNode, BranchNodeMasks, LeafNode, StorageMultiProof, TrieMask, + BranchNode, BranchNodeMasks, BranchNodeMasksMap, LeafNode, StorageMultiProof, TrieMask, }; #[test] diff --git a/crates/trie/sparse/src/traits.rs b/crates/trie/sparse/src/traits.rs index 15f474c6a2..e235cead63 100644 --- a/crates/trie/sparse/src/traits.rs +++ b/crates/trie/sparse/src/traits.rs @@ -232,6 +232,36 @@ pub trait SparseTrie: Sized + Debug + Send + Sync { fn shrink_values_to(&mut self, size: usize); } +/// Extension trait for sparse tries that support pruning. +/// +/// This trait provides the `prune` method for sparse trie implementations that support +/// converting nodes beyond a certain depth into hash stubs. This is useful for reducing +/// memory usage when caching tries across payload validations. +pub trait SparseTrieExt: SparseTrie { + /// Returns the number of revealed (non-Hash) nodes in the trie. + fn revealed_node_count(&self) -> usize; + + /// Replaces nodes beyond `max_depth` with hash stubs and removes their descendants. + /// + /// Depth counts nodes traversed (not nibbles), so extension nodes count as 1 depth + /// regardless of key length. `max_depth == 0` prunes all children of the root node. + /// + /// # Preconditions + /// + /// Must be called after `root()` to ensure all nodes have computed hashes. + /// Calling on a trie without computed hashes will result in no pruning. + /// + /// # Behavior + /// + /// - Embedded nodes (RLP < 32 bytes) are preserved since they have no hash + /// - Returns 0 if `max_depth` exceeds trie depth or trie is empty + /// + /// # Returns + /// + /// The number of nodes converted to hash stubs. + fn prune(&mut self, max_depth: usize) -> usize; +} + /// Tracks modifications to the sparse trie structure. /// /// Maintains references to both modified and pruned/removed branches, enabling
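
Usage note (illustrative, not part of the patch): pruning is only effective after node
hashes have been computed, as the trait docs above state. A minimal sketch of the intended
call sequence, assuming the `SparseTrie` trait exposes the `root()` method the tests call
and that `shrink_for_cache` is a hypothetical helper name:

    use reth_trie_sparse::{DEFAULT_SPARSE_TRIE_PRUNE_DEPTH, SparseTrieExt};

    /// Hypothetical helper: shrinks a revealed sparse trie for cross-payload caching
    /// and returns the number of nodes converted to hash stubs.
    fn shrink_for_cache<T: SparseTrieExt>(trie: &mut T) -> usize {
        // Precondition from the trait docs: compute hashes first, otherwise
        // prune() cannot create hash stubs and converts nothing.
        let _root = trie.root();
        trie.prune(DEFAULT_SPARSE_TRIE_PRUNE_DEPTH)
    }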