From 2d71243cf60af823a10f07f3a92b13cffaea82d5 Mon Sep 17 00:00:00 2001 From: YK Date: Thu, 29 Jan 2026 19:25:08 +0800 Subject: [PATCH] feat(trie): add update_leaves method to SparseTrieExt (#21525) Co-authored-by: Amp Co-authored-by: Georgios Konstantopoulos --- crates/trie/sparse-parallel/src/trie.rs | 1091 +++++++++++++++++++++-- crates/trie/sparse/src/provider.rs | 15 + crates/trie/sparse/src/traits.rs | 33 +- crates/trie/sparse/src/trie.rs | 35 +- 4 files changed, 1094 insertions(+), 80 deletions(-) diff --git a/crates/trie/sparse-parallel/src/trie.rs b/crates/trie/sparse-parallel/src/trie.rs index d6a1bfec93..5d7e62acb5 100644 --- a/crates/trie/sparse-parallel/src/trie.rs +++ b/crates/trie/sparse-parallel/src/trie.rs @@ -6,7 +6,7 @@ use alloy_primitives::{ }; use alloy_rlp::Decodable; use alloy_trie::{BranchNodeCompact, TrieMask, EMPTY_ROOT_HASH}; -use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult}; +use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, SparseTrieResult}; use reth_trie_common::{ prefix_set::{PrefixSet, PrefixSetMut}, BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, @@ -123,6 +123,9 @@ pub struct ParallelSparseTrie { update_actions_buffers: Vec>, /// Thresholds controlling when parallelism is enabled for different operations. parallelism_thresholds: ParallelismThresholds, + /// Tracks proof targets already requested via `update_leaves` to avoid duplicate callbacks + /// across retry calls. Key is (`leaf_path`, `min_depth`). + requested_proof_targets: alloy_primitives::map::HashSet<(Nibbles, u8)>, /// Metrics for the parallel sparse trie. 
#[cfg(feature = "metrics")] metrics: crate::metrics::ParallelSparseTrieMetrics, @@ -141,6 +144,7 @@ impl Default for ParallelSparseTrie { branch_node_masks: BranchNodeMasksMap::default(), update_actions_buffers: Vec::default(), parallelism_thresholds: Default::default(), + requested_proof_targets: Default::default(), #[cfg(feature = "metrics")] metrics: Default::default(), } @@ -301,15 +305,31 @@ impl SparseTrie for ParallelSparseTrie { value: Vec, provider: P, ) -> SparseTrieResult<()> { - self.prefix_set.insert(full_path); - let existing = self.upper_subtrie.inner.values.insert(full_path, value.clone()); - if existing.is_some() { - // upper trie structure unchanged, return immediately - return Ok(()) + // Check if the value already exists - if so, just update it (no structural changes needed) + if self.upper_subtrie.inner.values.contains_key(&full_path) { + self.prefix_set.insert(full_path); + self.upper_subtrie.inner.values.insert(full_path, value); + return Ok(()); + } + // Also check lower subtries for existing value + if let Some(subtrie) = self.lower_subtrie_for_path(&full_path) && + subtrie.inner.values.contains_key(&full_path) + { + self.prefix_set.insert(full_path); + self.lower_subtrie_for_path_mut(&full_path) + .expect("subtrie exists") + .inner + .values + .insert(full_path, value); + return Ok(()); } let retain_updates = self.updates_enabled(); + // Insert value into upper subtrie temporarily. We'll move it to the correct subtrie + // during traversal, or clean it up if we error. + self.upper_subtrie.inner.values.insert(full_path, value.clone()); + // Start at the root, traversing until we find either the node to update or a subtrie to // update. // @@ -320,6 +340,8 @@ impl SparseTrie for ParallelSparseTrie { // `new_nodes` to keep track of any nodes that were created during the traversal. 
let mut new_nodes = Vec::new(); let mut next = Some(Nibbles::default()); + // Track the original node that was modified (path, original_node) for rollback + let mut modified_original: Option<(Nibbles, SparseNode)> = None; // Traverse the upper subtrie to find the node to update or the subtrie to update. // @@ -328,11 +350,28 @@ impl SparseTrie for ParallelSparseTrie { while let Some(current) = next.filter(|next| SparseSubtrieType::path_len_is_upper(next.len())) { + // Save original node for potential rollback (only if not already saved) + if modified_original.is_none() && + let Some(node) = self.upper_subtrie.nodes.get(&current) + { + modified_original = Some((current, node.clone())); + } + // Traverse the next node, keeping track of any changed nodes and the next step in the - // trie - match self.upper_subtrie.update_next_node(current, &full_path, retain_updates)? { + // trie. If traversal fails, clean up the value we inserted and propagate the error. + let step_result = + self.upper_subtrie.update_next_node(current, &full_path, retain_updates); + + if step_result.is_err() { + self.upper_subtrie.inner.values.remove(&full_path); + return step_result.map(|_| ()); + } + + match step_result? { LeafUpdateStep::Continue { next_node } => { next = Some(next_node); + // Clear modified_original since we haven't actually modified anything yet + modified_original = None; } LeafUpdateStep::Complete { inserted_nodes, reveal_path } => { new_nodes.extend(inserted_nodes); @@ -351,10 +390,30 @@ impl SparseTrie for ParallelSparseTrie { leaf_full_path = ?full_path, "Extension node child not revealed in update_leaf, falling back to db", ); - if let Some(RevealedNode { node, tree_mask, hash_mask }) = - provider.trie_node(&reveal_path)? 
+ let revealed_node = match provider.trie_node(&reveal_path) { + Ok(node) => node, + Err(e) => { + self.rollback_insert( + &full_path, + &new_nodes, + modified_original.take(), + ); + return Err(e); + } + }; + if let Some(RevealedNode { node, tree_mask, hash_mask }) = revealed_node { - let decoded = TrieNode::decode(&mut &node[..])?; + let decoded = match TrieNode::decode(&mut &node[..]) { + Ok(d) => d, + Err(e) => { + self.rollback_insert( + &full_path, + &new_nodes, + modified_original.take(), + ); + return Err(e.into()); + } + }; trace!( target: "trie::parallel_sparse", ?reveal_path, @@ -364,9 +423,21 @@ impl SparseTrie for ParallelSparseTrie { "Revealing child (from upper)", ); let masks = BranchNodeMasks::from_optional(hash_mask, tree_mask); - subtrie.reveal_node(reveal_path, &decoded, masks)?; + if let Err(e) = subtrie.reveal_node(reveal_path, &decoded, masks) { + self.rollback_insert( + &full_path, + &new_nodes, + modified_original.take(), + ); + return Err(e); + } masks } else { + self.rollback_insert( + &full_path, + &new_nodes, + modified_original.take(), + ); return Err(SparseTrieErrorKind::NodeNotFoundInProvider { path: reveal_path, } @@ -448,13 +519,24 @@ impl SparseTrie for ParallelSparseTrie { // If we didn't update the target leaf, we need to call update_leaf on the subtrie // to ensure that the leaf is updated correctly. - if let Some((revealed_path, revealed_masks)) = - subtrie.update_leaf(full_path, value, provider, retain_updates)? 
- { - self.branch_node_masks.insert(revealed_path, revealed_masks); + match subtrie.update_leaf(full_path, value, provider, retain_updates) { + Ok(Some((revealed_path, revealed_masks))) => { + self.branch_node_masks.insert(revealed_path, revealed_masks); + } + Ok(None) => {} + Err(e) => { + // Clean up: remove the value from lower subtrie if it was inserted + if let Some(lower) = self.lower_subtrie_for_path_mut(&full_path) { + lower.inner.values.remove(&full_path); + } + return Err(e); + } } } + // Insert into prefix_set only after all operations succeed + self.prefix_set.insert(full_path); + Ok(()) } @@ -479,7 +561,7 @@ impl SparseTrie for ParallelSparseTrie { // and grandparent. let leaf_path; - let leaf_subtrie; + let leaf_subtrie_type; let mut branch_parent_path: Option = None; let mut branch_parent_node: Option = None; @@ -488,13 +570,18 @@ impl SparseTrie for ParallelSparseTrie { let mut ext_grandparent_node: Option = None; let mut curr_path = Nibbles::new(); // start traversal from root - let mut curr_subtrie = self.upper_subtrie.as_mut(); - let mut curr_subtrie_is_upper = true; + let mut curr_subtrie_type = SparseSubtrieType::Upper; // List of node paths which need to have their hashes reset let mut paths_to_reset_hashes = Vec::new(); loop { + let curr_subtrie = match curr_subtrie_type { + SparseSubtrieType::Upper => &mut self.upper_subtrie, + SparseSubtrieType::Lower(idx) => { + self.lower_subtries[idx].as_revealed_mut().expect("lower subtrie is revealed") + } + }; let curr_node = curr_subtrie.nodes.get_mut(&curr_path).unwrap(); match Self::find_next_to_leaf(&curr_path, curr_node, full_path) { @@ -505,7 +592,7 @@ impl SparseTrie for ParallelSparseTrie { FindNextToLeafOutcome::Found => { // this node is the target leaf leaf_path = curr_path; - leaf_subtrie = curr_subtrie; + leaf_subtrie_type = curr_subtrie_type; break; } FindNextToLeafOutcome::ContinueFrom(next_path) => { @@ -551,24 +638,53 @@ impl SparseTrie for ParallelSparseTrie { curr_path = next_path; 
- // If we were previously looking at the upper trie, and the new path is in the - // lower trie, we need to pull out a ref to the lower trie. - if curr_subtrie_is_upper && - let SparseSubtrieType::Lower(idx) = - SparseSubtrieType::from_path(&curr_path) + // Update subtrie type if we're crossing into the lower trie. + let next_subtrie_type = SparseSubtrieType::from_path(&curr_path); + if matches!(curr_subtrie_type, SparseSubtrieType::Upper) && + matches!(next_subtrie_type, SparseSubtrieType::Lower(_)) { - curr_subtrie = self.lower_subtries[idx] - .as_revealed_mut() - .expect("lower subtrie is revealed"); - curr_subtrie_is_upper = false; + curr_subtrie_type = next_subtrie_type; } } }; } + // Before mutating, check if branch collapse would require revealing a blinded node. + // This ensures remove_leaf is atomic: if it errors, the trie is unchanged. + if let (Some(branch_path), Some(SparseNode::Branch { state_mask, .. })) = + (&branch_parent_path, &branch_parent_node) + { + let mut check_mask = *state_mask; + let child_nibble = leaf_path.get_unchecked(branch_path.len()); + check_mask.unset_bit(child_nibble); + + if check_mask.count_bits() == 1 { + // Branch will collapse - check if remaining child needs revealing + let remaining_child_path = { + let mut p = *branch_path; + p.push_unchecked( + check_mask.first_set_bit_index().expect("state mask is not empty"), + ); + p + }; + + // Pre-validate the entire reveal chain (including extension grandchildren). + // This check mirrors the logic in `reveal_remaining_child_on_leaf_removal` with + // `recurse_into_extension: true` to ensure all nodes that would be revealed + // are accessible before any mutations occur. + self.pre_validate_reveal_chain(&remaining_child_path, &provider)?; + } + } + // We've traversed to the leaf and collected its ancestors as necessary. Remove the leaf // from its SparseSubtrie and reset the hashes of the nodes along the path. 
self.prefix_set.insert(*full_path); + let leaf_subtrie = match leaf_subtrie_type { + SparseSubtrieType::Upper => &mut self.upper_subtrie, + SparseSubtrieType::Lower(idx) => { + self.lower_subtries[idx].as_revealed_mut().expect("lower subtrie is revealed") + } + }; leaf_subtrie.inner.values.remove(full_path); for (subtrie_type, path) in paths_to_reset_hashes { let node = match subtrie_type { @@ -1062,6 +1178,77 @@ impl SparseTrieExt for ParallelSparseTrie { nodes_converted } + + fn update_leaves( + &mut self, + updates: &mut alloy_primitives::map::B256Map, + mut proof_required_fn: impl FnMut(Nibbles, u8), + ) -> SparseTrieResult<()> { + use reth_trie_sparse::{provider::NoRevealProvider, LeafUpdate}; + + // Collect keys upfront since we mutate `updates` during iteration. + // On success, entries are removed; on blinded node failure, they're re-inserted. + let keys: Vec = updates.keys().copied().collect(); + + for key in keys { + let full_path = Nibbles::unpack(key); + // Remove upfront - we'll re-insert if the operation fails due to blinded node. + let update = updates.remove(&key).unwrap(); + + match update { + LeafUpdate::Changed(value) => { + if value.is_empty() { + // Removal: remove_leaf with NoRevealProvider is atomic - returns a + // retriable error before any mutations (via pre_validate_reveal_chain). + match self.remove_leaf(&full_path, NoRevealProvider) { + Ok(()) => {} + Err(e) => { + if let Some(path) = Self::get_retriable_path(&e) { + let min_len = (path.len() as u8).min(64); + if self.requested_proof_targets.insert((full_path, min_len)) { + proof_required_fn(full_path, min_len); + } + updates.insert(key, LeafUpdate::Changed(value)); + } else { + return Err(e); + } + } + } + } else { + // Update/insert: update_leaf is atomic - cleans up on error. 
+ if let Err(e) = self.update_leaf(full_path, value.clone(), NoRevealProvider) + { + if let Some(path) = Self::get_retriable_path(&e) { + let min_len = (path.len() as u8).min(64); + if self.requested_proof_targets.insert((full_path, min_len)) { + proof_required_fn(full_path, min_len); + } + updates.insert(key, LeafUpdate::Changed(value)); + } else { + return Err(e); + } + } + } + } + LeafUpdate::Touched => { + // Touched is read-only: check if path is accessible, request proof if blinded. + match self.find_leaf(&full_path, None) { + Err(LeafLookupError::BlindedNode { path, .. }) => { + let min_len = (path.len() as u8).min(64); + if self.requested_proof_targets.insert((full_path, min_len)) { + proof_required_fn(full_path, min_len); + } + updates.insert(key, LeafUpdate::Touched); + } + // Path is fully revealed (exists or proven non-existent), no action needed. + Ok(_) | Err(LeafLookupError::ValueMismatch { .. }) => {} + } + } + } + } + + Ok(()) + } } impl ParallelSparseTrie { @@ -1076,6 +1263,14 @@ impl ParallelSparseTrie { self.updates.is_some() } + /// Clears the set of already-requested proof targets. + /// + /// Call this when reusing the trie for a new payload to ensure proof callbacks + /// are emitted fresh. + pub fn clear_requested_proof_targets(&mut self) { + self.requested_proof_targets.clear(); + } + /// Returns true if parallelism should be enabled for revealing the given number of nodes. /// Will always return false in nostd builds. const fn is_reveal_parallelism_enabled(&self, num_nodes: usize) -> bool { @@ -1094,6 +1289,46 @@ impl ParallelSparseTrie { num_changed_keys >= self.parallelism_thresholds.min_updated_nodes } + /// Checks if an error is retriable (`BlindedNode` or `NodeNotFoundInProvider`) and extracts + /// the path if so. + /// + /// Both error types indicate that a node needs to be revealed before the operation can + /// succeed. 
`BlindedNode` occurs when traversing to a Hash node, while `NodeNotFoundInProvider` + /// occurs when `retain_updates` is enabled and an extension node's child needs revealing. + const fn get_retriable_path(e: &SparseTrieError) -> Option { + match e.kind() { + SparseTrieErrorKind::BlindedNode { path, .. } | + SparseTrieErrorKind::NodeNotFoundInProvider { path } => Some(*path), + _ => None, + } + } + + /// Rolls back a partial update by removing the value, removing any inserted nodes, + /// and restoring any modified original node. + /// This ensures `update_leaf` is atomic - either it succeeds completely or leaves the trie + /// unchanged. + fn rollback_insert( + &mut self, + full_path: &Nibbles, + inserted_nodes: &[Nibbles], + modified_original: Option<(Nibbles, SparseNode)>, + ) { + self.upper_subtrie.inner.values.remove(full_path); + for node_path in inserted_nodes { + // Try upper subtrie first - nodes may be there even if path length suggests lower + if self.upper_subtrie.nodes.remove(node_path).is_none() { + // Not in upper, try lower subtrie + if let Some(subtrie) = self.lower_subtrie_for_path_mut(node_path) { + subtrie.nodes.remove(node_path); + } + } + } + // Restore the original node that was modified + if let Some((path, original_node)) = modified_original { + self.upper_subtrie.nodes.insert(path, original_node); + } + } + /// Creates a new revealed sparse trie from the given root node. /// /// This function initializes the internal structures and then reveals the root. @@ -1396,6 +1631,50 @@ impl ParallelSparseTrie { } } + /// Pre-validates reveal chain accessibility before mutations. + /// + /// Walks the trie path checking that all nodes can be revealed. This is called before + /// any mutations to ensure the operation will succeed atomically. + /// + /// Returns `BlindedNode` error if any node in the chain cannot be revealed by the provider. 
+ fn pre_validate_reveal_chain( + &self, + path: &Nibbles, + provider: &P, + ) -> SparseTrieResult<()> { + // Find the subtrie containing this path, or return Ok if path doesn't exist + let subtrie = match self.subtrie_for_path(path) { + Some(s) => s, + None => return Ok(()), + }; + + match subtrie.nodes.get(path) { + // Hash node: attempt to reveal from provider + Some(SparseNode::Hash(hash)) => match provider.trie_node(path)? { + Some(RevealedNode { node, .. }) => { + let decoded = TrieNode::decode(&mut &node[..])?; + // Extension nodes have children that also need validation + if let TrieNode::Extension(ext) = decoded { + let mut grandchild_path = *path; + grandchild_path.extend(&ext.key); + return self.pre_validate_reveal_chain(&grandchild_path, provider); + } + Ok(()) + } + // Provider cannot reveal this node - operation would fail + None => Err(SparseTrieErrorKind::BlindedNode { path: *path, hash: *hash }.into()), + }, + // Already-revealed extension: recursively validate its child + Some(SparseNode::Extension { key, .. }) => { + let mut child_path = *path; + child_path.extend(key); + self.pre_validate_reveal_chain(&child_path, provider) + } + // Leaf, Branch, Empty, or missing: no further validation needed + _ => Ok(()), + } + } + /// Called when a leaf is removed on a branch which has only one other remaining child. That /// child must be revealed in order to properly collapse the branch. 
/// @@ -7568,54 +7847,6 @@ mod tests { } } - #[test] - #[ignore = "profiling test - run manually"] - fn test_prune_profile() { - use std::time::Instant; - - let provider = DefaultTrieNodeProvider; - let large_value = large_account_value(); - - // Generate 65536 keys (16^4) for a large trie - let mut keys = Vec::with_capacity(65536); - for a in 0..16u8 { - for b in 0..16u8 { - for c in 0..16u8 { - for d in 0..16u8 { - keys.push(Nibbles::from_nibbles([a, b, c, d, 0x5, 0x6, 0x7, 0x8])); - } - } - } - } - - // Build base trie once - let mut base_trie = ParallelSparseTrie::default(); - for key in &keys { - base_trie.update_leaf(*key, large_value.clone(), &provider).unwrap(); - } - base_trie.root(); // ensure hashes computed - - // Pre-clone tries to exclude clone time from profiling - let iterations = 100; - let mut tries: Vec<_> = (0..iterations).map(|_| base_trie.clone()).collect(); - - // Measure only prune() - let mut total_pruned = 0; - let start = Instant::now(); - for trie in &mut tries { - total_pruned += trie.prune(2); - } - let elapsed = start.elapsed(); - - println!( - "Prune benchmark: {} iterations, total: {:?}, avg: {:?}, pruned/iter: {}", - iterations, - elapsed, - elapsed / iterations as u32, - total_pruned / iterations - ); - } - #[test] fn test_prune_max_depth_overflow() { // Verify that max_depth > 255 is not truncated (was u8, now usize) @@ -7680,4 +7911,710 @@ mod tests { // The trie should still be functional let _ = trie.root(); } + + // update_leaves tests + + #[test] + fn test_update_leaves_successful_update() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + // Create a leaf in the trie using a full-length key + let b256_key = B256::with_last_byte(42); + let key = Nibbles::unpack(b256_key); + let value = encode_account_value(1); + trie.update_leaf(key, value, &provider).unwrap(); + + // Create 
update map with a new value for the same key + let new_value = encode_account_value(2); + + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key, LeafUpdate::Changed(new_value)); + + let proof_targets = RefCell::new(Vec::new()); + trie.update_leaves(&mut updates, |path, min_len| { + proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Update should succeed: map empty, callback not invoked + assert!(updates.is_empty(), "Update map should be empty after successful update"); + assert!( + proof_targets.borrow().is_empty(), + "Callback should not be invoked for revealed paths" + ); + } + + #[test] + fn test_update_leaves_insert_new_leaf() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + let mut trie = ParallelSparseTrie::default(); + + // Insert a NEW leaf (key doesn't exist yet) via update_leaves + let b256_key = B256::with_last_byte(99); + let new_value = encode_account_value(42); + + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key, LeafUpdate::Changed(new_value.clone())); + + let proof_targets = RefCell::new(Vec::new()); + trie.update_leaves(&mut updates, |path, min_len| { + proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Insert should succeed: map empty, callback not invoked + assert!(updates.is_empty(), "Update map should be empty after successful insert"); + assert!( + proof_targets.borrow().is_empty(), + "Callback should not be invoked for new leaf insert" + ); + + // Verify the leaf was actually inserted + let full_path = Nibbles::unpack(b256_key); + assert_eq!( + trie.get_leaf_value(&full_path), + Some(&new_value), + "New leaf value should be retrievable" + ); + } + + #[test] + fn test_update_leaves_blinded_node() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + // Create a trie with a blinded node + // Use a small value that fits in RLP encoding + let 
small_value = alloy_rlp::encode_fixed_size(&U256::from(1)).to_vec(); + let leaf = LeafNode::new( + Nibbles::default(), // short key for RLP encoding + small_value, + ); + let branch = TrieNode::Branch(BranchNode::new( + vec![ + RlpNode::word_rlp(&B256::repeat_byte(1)), // blinded child at 0 + RlpNode::from_raw_rlp(&alloy_rlp::encode(leaf.clone())).unwrap(), // revealed at 1 + ], + TrieMask::new(0b11), + )); + + let mut trie = ParallelSparseTrie::from_root( + branch.clone(), + Some(BranchNodeMasks { + hash_mask: TrieMask::new(0b01), + tree_mask: TrieMask::default(), + }), + false, + ) + .unwrap(); + + // Reveal only the branch and one child, leaving child 0 as a Hash node + trie.reveal_node( + Nibbles::default(), + branch, + Some(BranchNodeMasks { + hash_mask: TrieMask::default(), + tree_mask: TrieMask::new(0b01), + }), + ) + .unwrap(); + trie.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap(); + + // The path 0x0... is blinded (Hash node) + // Create an update targeting the blinded path using a full B256 key + let b256_key = B256::ZERO; // starts with 0x0... 
+ + let new_value = encode_account_value(42); + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key, LeafUpdate::Changed(new_value)); + + let proof_targets = RefCell::new(Vec::new()); + let prefix_set_len_before = trie.prefix_set.len(); + trie.update_leaves(&mut updates, |path, min_len| { + proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Update should remain in map (blinded node) + assert!(!updates.is_empty(), "Update should remain in map when hitting blinded node"); + + // prefix_set should be unchanged after failed update + assert_eq!( + trie.prefix_set.len(), + prefix_set_len_before, + "prefix_set should be unchanged after failed update on blinded node" + ); + + // Callback should be invoked + let targets = proof_targets.borrow(); + assert!(!targets.is_empty(), "Callback should be invoked for blinded path"); + + // min_len should equal the blinded node's path length (1 nibble) + assert_eq!(targets[0].1, 1, "min_len should equal blinded node path length"); + } + + #[test] + fn test_update_leaves_removal() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + let provider = DefaultTrieNodeProvider; + let mut trie = ParallelSparseTrie::default(); + + // Create two leaves so removal doesn't result in empty trie issues + // Use full-length keys + let b256_key1 = B256::with_last_byte(1); + let b256_key2 = B256::with_last_byte(2); + let key1 = Nibbles::unpack(b256_key1); + let key2 = Nibbles::unpack(b256_key2); + let value = encode_account_value(1); + trie.update_leaf(key1, value.clone(), &provider).unwrap(); + trie.update_leaf(key2, value, &provider).unwrap(); + + // Create an update to remove key1 (empty value = removal) + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key1, LeafUpdate::Changed(vec![])); // empty = removal + + let proof_targets = RefCell::new(Vec::new()); + trie.update_leaves(&mut updates, |path, min_len| { + 
proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Removal should succeed: map empty + assert!(updates.is_empty(), "Update map should be empty after successful removal"); + } + + #[test] + fn test_update_leaves_removal_blinded() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + // Create a trie with a blinded node + // Use a small value that fits in RLP encoding + let small_value = alloy_rlp::encode_fixed_size(&U256::from(1)).to_vec(); + let leaf = LeafNode::new( + Nibbles::default(), // short key for RLP encoding + small_value, + ); + let branch = TrieNode::Branch(BranchNode::new( + vec![ + RlpNode::word_rlp(&B256::repeat_byte(1)), // blinded child at 0 + RlpNode::from_raw_rlp(&alloy_rlp::encode(leaf.clone())).unwrap(), // revealed at 1 + ], + TrieMask::new(0b11), + )); + + let mut trie = ParallelSparseTrie::from_root( + branch.clone(), + Some(BranchNodeMasks { + hash_mask: TrieMask::new(0b01), + tree_mask: TrieMask::default(), + }), + false, + ) + .unwrap(); + + trie.reveal_node( + Nibbles::default(), + branch, + Some(BranchNodeMasks { + hash_mask: TrieMask::default(), + tree_mask: TrieMask::new(0b01), + }), + ) + .unwrap(); + trie.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap(); + + // Simulate having a known value behind the blinded node + let b256_key = B256::ZERO; // starts with 0x0... 
+ let full_path = Nibbles::unpack(b256_key); + + // Insert the value into the trie's values map (simulating we know about it) + let old_value = encode_account_value(99); + trie.upper_subtrie.inner.values.insert(full_path, old_value.clone()); + + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key, LeafUpdate::Changed(vec![])); // empty = removal + + let proof_targets = RefCell::new(Vec::new()); + let prefix_set_len_before = trie.prefix_set.len(); + trie.update_leaves(&mut updates, |path, min_len| { + proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Callback should be invoked + assert!( + !proof_targets.borrow().is_empty(), + "Callback should be invoked when removal hits blinded node" + ); + + // Update should remain in map + assert!(!updates.is_empty(), "Update should remain in map when removal hits blinded node"); + + // Original value should be preserved (reverted) + assert_eq!( + trie.upper_subtrie.inner.values.get(&full_path), + Some(&old_value), + "Original value should be preserved after failed removal" + ); + + // prefix_set should be unchanged after failed removal + assert_eq!( + trie.prefix_set.len(), + prefix_set_len_before, + "prefix_set should be unchanged after failed removal on blinded node" + ); + } + + #[test] + fn test_update_leaves_removal_branch_collapse_blinded() { + use alloy_primitives::map::B256Map; + use reth_trie_sparse::LeafUpdate; + use std::cell::RefCell; + + // Create a branch node at root with two children: + // - Child at nibble 0: a blinded Hash node + // - Child at nibble 1: a revealed Leaf node + let small_value = alloy_rlp::encode_fixed_size(&U256::from(1)).to_vec(); + let leaf = LeafNode::new(Nibbles::default(), small_value); + let branch = TrieNode::Branch(BranchNode::new( + vec![ + RlpNode::word_rlp(&B256::repeat_byte(1)), // blinded child at nibble 0 + RlpNode::from_raw_rlp(&alloy_rlp::encode(leaf.clone())).unwrap(), /* leaf at nibble 1 */ + ], + TrieMask::new(0b11), + )); + + let 
mut trie = ParallelSparseTrie::from_root( + branch.clone(), + Some(BranchNodeMasks { + hash_mask: TrieMask::new(0b01), // nibble 0 is hashed + tree_mask: TrieMask::default(), + }), + false, + ) + .unwrap(); + + // Reveal the branch and the leaf at nibble 1, leaving nibble 0 as Hash node + trie.reveal_node( + Nibbles::default(), + branch, + Some(BranchNodeMasks { + hash_mask: TrieMask::default(), + tree_mask: TrieMask::new(0b01), + }), + ) + .unwrap(); + trie.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap(); + + // Insert the leaf's value into the values map for the revealed leaf + // Use B256 key that starts with nibble 1 (0x10 has first nibble = 1) + let b256_key = B256::with_last_byte(0x10); + let full_path = Nibbles::unpack(b256_key); + let leaf_value = encode_account_value(42); + trie.upper_subtrie.inner.values.insert(full_path, leaf_value.clone()); + + // Record state before update_leaves + let prefix_set_len_before = trie.prefix_set.len(); + let node_count_before = trie.upper_subtrie.nodes.len() + + trie.lower_subtries + .iter() + .filter_map(|s| s.as_revealed_ref()) + .map(|s| s.nodes.len()) + .sum::(); + + let mut updates: B256Map = B256Map::default(); + updates.insert(b256_key, LeafUpdate::Changed(vec![])); // removal + + let proof_targets = RefCell::new(Vec::new()); + trie.update_leaves(&mut updates, |path, min_len| { + proof_targets.borrow_mut().push((path, min_len)); + }) + .unwrap(); + + // Assert: update remains in map (removal blocked by blinded sibling) + assert!( + !updates.is_empty(), + "Update should remain in map when removal would collapse branch with blinded sibling" + ); + + // Assert: callback was invoked for the blinded path + assert!( + !proof_targets.borrow().is_empty(), + "Callback should be invoked for blinded sibling path" + ); + + // Assert: prefix_set unchanged (atomic failure) + assert_eq!( + trie.prefix_set.len(), + prefix_set_len_before, + "prefix_set should be unchanged after atomic failure" + ); + + 
// Assert: node count unchanged
+        let node_count_after = trie.upper_subtrie.nodes.len() +
+            trie.lower_subtries
+                .iter()
+                .filter_map(|s| s.as_revealed_ref())
+                .map(|s| s.nodes.len())
+                .sum::<usize>();
+        assert_eq!(
+            node_count_before, node_count_after,
+            "Node count should be unchanged after atomic failure"
+        );
+
+        // Assert: the leaf value still exists (not removed)
+        assert_eq!(
+            trie.upper_subtrie.inner.values.get(&full_path),
+            Some(&leaf_value),
+            "Leaf value should still exist after failed removal"
+        );
+    }
+
+    #[test]
+    fn test_update_leaves_touched() {
+        use alloy_primitives::map::B256Map;
+        use reth_trie_sparse::LeafUpdate;
+        use std::cell::RefCell;
+
+        let provider = DefaultTrieNodeProvider;
+        let mut trie = ParallelSparseTrie::default();
+
+        // Create a leaf in the trie using a full-length key
+        let b256_key = B256::with_last_byte(42);
+        let key = Nibbles::unpack(b256_key);
+        let value = encode_account_value(1);
+        trie.update_leaf(key, value, &provider).unwrap();
+
+        // Create a Touched update for the existing key
+        let mut updates: B256Map<LeafUpdate> = B256Map::default();
+        updates.insert(b256_key, LeafUpdate::Touched);
+
+        let proof_targets = RefCell::new(Vec::new());
+        let prefix_set_len_before = trie.prefix_set.len();
+
+        trie.update_leaves(&mut updates, |path, min_len| {
+            proof_targets.borrow_mut().push((path, min_len));
+        })
+        .unwrap();
+
+        // Update should be removed (path is accessible)
+        assert!(updates.is_empty(), "Touched update should be removed for accessible path");
+
+        // No callback
+        assert!(
+            proof_targets.borrow().is_empty(),
+            "Callback should not be invoked for accessible path"
+        );
+
+        // prefix_set should be unchanged since Touched is read-only
+        assert_eq!(
+            trie.prefix_set.len(),
+            prefix_set_len_before,
+            "prefix_set should be unchanged for Touched update (read-only)"
+        );
+    }
+
+    #[test]
+    fn test_update_leaves_touched_nonexistent() {
+        use alloy_primitives::map::B256Map;
+        use reth_trie_sparse::LeafUpdate;
+        use std::cell::RefCell;
+
+        let mut trie = ParallelSparseTrie::default();
+
+        // Create a Touched update for a key that doesn't exist
+        let b256_key = B256::with_last_byte(99);
+        let full_path = Nibbles::unpack(b256_key);
+
+        let prefix_set_len_before = trie.prefix_set.len();
+
+        let mut updates: B256Map<LeafUpdate> = B256Map::default();
+        updates.insert(b256_key, LeafUpdate::Touched);
+
+        let proof_targets = RefCell::new(Vec::new());
+        trie.update_leaves(&mut updates, |path, min_len| {
+            proof_targets.borrow_mut().push((path, min_len));
+        })
+        .unwrap();
+
+        // Update should be removed (path IS accessible - it's just empty)
+        assert!(updates.is_empty(), "Touched update should be removed for accessible (empty) path");
+
+        // No callback should be invoked (path is revealed, just empty)
+        assert!(
+            proof_targets.borrow().is_empty(),
+            "Callback should not be invoked for accessible path"
+        );
+
+        // prefix_set should NOT be modified (Touched is read-only)
+        assert_eq!(
+            trie.prefix_set.len(),
+            prefix_set_len_before,
+            "prefix_set should not be modified by Touched update"
+        );
+
+        // No value should be inserted
+        assert!(
+            trie.get_leaf_value(&full_path).is_none(),
+            "No value should exist for non-existent key after Touched update"
+        );
+    }
+
+    #[test]
+    fn test_update_leaves_touched_blinded() {
+        use alloy_primitives::map::B256Map;
+        use reth_trie_sparse::LeafUpdate;
+        use std::cell::RefCell;
+
+        // Create a trie with a blinded node
+        // Use a small value that fits in RLP encoding
+        let small_value = alloy_rlp::encode_fixed_size(&U256::from(1)).to_vec();
+        let leaf = LeafNode::new(
+            Nibbles::default(), // short key for RLP encoding
+            small_value,
+        );
+        let branch = TrieNode::Branch(BranchNode::new(
+            vec![
+                RlpNode::word_rlp(&B256::repeat_byte(1)), // blinded child at 0
+                RlpNode::from_raw_rlp(&alloy_rlp::encode(leaf.clone())).unwrap(), // revealed at 1
+            ],
+            TrieMask::new(0b11),
+        ));
+
+        let mut trie = ParallelSparseTrie::from_root(
+            branch.clone(),
+            Some(BranchNodeMasks {
+                hash_mask: TrieMask::new(0b01),
+                tree_mask: TrieMask::default(),
+            }),
+            false,
+        )
+        .unwrap();
+
+        trie.reveal_node(
+            Nibbles::default(),
+            branch,
+            Some(BranchNodeMasks {
+                hash_mask: TrieMask::default(),
+                tree_mask: TrieMask::new(0b01),
+            }),
+        )
+        .unwrap();
+        trie.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap();
+
+        // Create a Touched update targeting the blinded path using full B256 key
+        let b256_key = B256::ZERO; // starts with 0x0...
+
+        let mut updates: B256Map<LeafUpdate> = B256Map::default();
+        updates.insert(b256_key, LeafUpdate::Touched);
+
+        let proof_targets = RefCell::new(Vec::new());
+        let prefix_set_len_before = trie.prefix_set.len();
+        trie.update_leaves(&mut updates, |path, min_len| {
+            proof_targets.borrow_mut().push((path, min_len));
+        })
+        .unwrap();
+
+        // Callback should be invoked
+        assert!(!proof_targets.borrow().is_empty(), "Callback should be invoked for blinded path");
+
+        // Update should remain in map
+        assert!(!updates.is_empty(), "Touched update should remain in map for blinded path");
+
+        // prefix_set should be unchanged since Touched is read-only
+        assert_eq!(
+            trie.prefix_set.len(),
+            prefix_set_len_before,
+            "prefix_set should be unchanged for Touched update on blinded path"
+        );
+    }
+
+    #[test]
+    fn test_update_leaves_deduplication() {
+        use alloy_primitives::map::B256Map;
+        use reth_trie_sparse::LeafUpdate;
+        use std::cell::RefCell;
+
+        // Create a trie with a blinded node
+        // Use a small value that fits in RLP encoding
+        let small_value = alloy_rlp::encode_fixed_size(&U256::from(1)).to_vec();
+        let leaf = LeafNode::new(
+            Nibbles::default(), // short key for RLP encoding
+            small_value,
+        );
+        let branch = TrieNode::Branch(BranchNode::new(
+            vec![
+                RlpNode::word_rlp(&B256::repeat_byte(1)), // blinded child at 0
+                RlpNode::from_raw_rlp(&alloy_rlp::encode(leaf.clone())).unwrap(), // revealed at 1
+            ],
+            TrieMask::new(0b11),
+        ));
+
+        let mut trie = ParallelSparseTrie::from_root(
+            branch.clone(),
+            Some(BranchNodeMasks {
+                hash_mask: TrieMask::new(0b01),
+                tree_mask: TrieMask::default(),
+            }),
+            false,
+        )
+        .unwrap();
+
+        trie.reveal_node(
+            Nibbles::default(),
+            branch,
+            Some(BranchNodeMasks {
+                hash_mask: TrieMask::default(),
+                tree_mask: TrieMask::new(0b01),
+            }),
+        )
+        .unwrap();
+        trie.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap();
+
+        // Create multiple updates that would all hit the same blinded node at path 0x0
+        // Use full B256 keys that all start with 0x0
+        let b256_key1 = B256::ZERO;
+        let b256_key2 = B256::with_last_byte(1); // still starts with 0x0
+        let b256_key3 = B256::with_last_byte(2); // still starts with 0x0
+
+        let mut updates: B256Map<LeafUpdate> = B256Map::default();
+        let value = encode_account_value(42);
+
+        updates.insert(b256_key1, LeafUpdate::Changed(value.clone()));
+        updates.insert(b256_key2, LeafUpdate::Changed(value.clone()));
+        updates.insert(b256_key3, LeafUpdate::Changed(value));
+
+        let proof_targets = RefCell::new(Vec::new());
+        trie.update_leaves(&mut updates, |path, min_len| {
+            proof_targets.borrow_mut().push((path, min_len));
+        })
+        .unwrap();
+
+        // The callback should be invoked 3 times - once for each unique full_path
+        // The deduplication is by (full_path, min_len), not by blinded node
+        let targets = proof_targets.borrow();
+        assert_eq!(targets.len(), 3, "Callback should be invoked for each unique key");
+
+        // All should have the same min_len (1) since they all hit blinded node at path 0x0
+        for (_, min_len) in targets.iter() {
+            assert_eq!(*min_len, 1, "All should have min_len 1 from blinded node at 0x0");
+        }
+    }
+
+    #[test]
+    fn test_update_leaves_node_not_found_in_provider_atomicity() {
+        use alloy_primitives::map::B256Map;
+        use reth_trie_sparse::LeafUpdate;
+        use std::cell::RefCell;
+
+        // Create a trie with retain_updates enabled (this triggers the code path that
+        // can return NodeNotFoundInProvider when an extension node's child needs revealing).
+        //
+        // Structure: Extension at root -> Hash node (blinded child)
+        // When we try to insert a new leaf that would split the extension, with retain_updates
+        // enabled, it tries to reveal the hash child via the provider. With NoRevealProvider,
+        // this returns NodeNotFoundInProvider.
+
+        let child_hash = B256::repeat_byte(0xAB);
+        let extension = TrieNode::Extension(ExtensionNode::new(
+            Nibbles::from_nibbles([0x1, 0x2, 0x3]),
+            RlpNode::word_rlp(&child_hash),
+        ));
+
+        // Create trie with retain_updates = true
+        let mut trie =
+            ParallelSparseTrie::from_root(extension, None, true).expect("from_root failed");
+
+        // Record state before update_leaves
+        let prefix_set_len_before = trie.prefix_set.len();
+        let node_count_before = trie.upper_subtrie.nodes.len() +
+            trie.lower_subtries
+                .iter()
+                .filter_map(|s| s.as_revealed_ref())
+                .map(|s| s.nodes.len())
+                .sum::<usize>();
+        let value_count_before = trie.upper_subtrie.inner.values.len() +
+            trie.lower_subtries
+                .iter()
+                .filter_map(|s| s.as_revealed_ref())
+                .map(|s| s.inner.values.len())
+                .sum::<usize>();
+
+        // Create an update that would cause an extension split.
+        // The key starts with 0x1 but diverges from 0x123... at the second nibble.
+        let b256_key = {
+            let mut k = B256::ZERO;
+            k.0[0] = 0x14; // nibbles: 1, 4 - matches first nibble, diverges at second
+            k
+        };
+
+        let new_value = encode_account_value(42);
+        let mut updates: B256Map<LeafUpdate> = B256Map::default();
+        updates.insert(b256_key, LeafUpdate::Changed(new_value));
+
+        let proof_targets = RefCell::new(Vec::new());
+        trie.update_leaves(&mut updates, |path, min_len| {
+            proof_targets.borrow_mut().push((path, min_len));
+        })
+        .expect("update_leaves should succeed");
+
+        // Assert: update remains in map (NodeNotFoundInProvider is retriable)
+        assert!(
+            !updates.is_empty(),
+            "Update should remain in map when NodeNotFoundInProvider occurs"
+        );
+        assert!(
+            updates.contains_key(&b256_key),
+            "The specific key should be re-inserted for retry"
+        );
+
+        // Assert: callback was invoked
+        let targets = proof_targets.borrow();
+        assert!(!targets.is_empty(), "Callback should be invoked for NodeNotFoundInProvider");
+
+        // Assert: prefix_set unchanged (atomic - no partial state)
+        assert_eq!(
+            trie.prefix_set.len(),
+            prefix_set_len_before,
+            "prefix_set should be unchanged after atomic failure"
+        );
+
+        // Assert: node count unchanged (no structural changes persisted)
+        let node_count_after = trie.upper_subtrie.nodes.len() +
+            trie.lower_subtries
+                .iter()
+                .filter_map(|s| s.as_revealed_ref())
+                .map(|s| s.nodes.len())
+                .sum::<usize>();
+        assert_eq!(
+            node_count_before, node_count_after,
+            "Node count should be unchanged after atomic failure"
+        );
+
+        // Assert: value count unchanged (no values left dangling)
+        let value_count_after = trie.upper_subtrie.inner.values.len() +
+            trie.lower_subtries
+                .iter()
+                .filter_map(|s| s.as_revealed_ref())
+                .map(|s| s.inner.values.len())
+                .sum::<usize>();
+        assert_eq!(
+            value_count_before, value_count_after,
+            "Value count should be unchanged after atomic failure (no dangling values)"
+        );
+    }
 }
diff --git a/crates/trie/sparse/src/provider.rs b/crates/trie/sparse/src/provider.rs
index 405b3a8474..bfd44424cc 100644
--- 
a/crates/trie/sparse/src/provider.rs
+++ b/crates/trie/sparse/src/provider.rs
@@ -64,6 +64,21 @@ impl TrieNodeProvider for DefaultTrieNodeProvider {
     }
 }
 
+/// A provider that never reveals nodes from the database.
+///
+/// This is used by `update_leaves` to attempt trie operations without
+/// performing any database lookups. When the trie encounters a blinded node
+/// that would normally trigger a reveal, this provider returns `None`,
+/// causing the operation to fail with a `BlindedNode` or `NodeNotFoundInProvider` error.
+#[derive(PartialEq, Eq, Clone, Copy, Default, Debug)]
+pub struct NoRevealProvider;
+
+impl TrieNodeProvider for NoRevealProvider {
+    fn trie_node(&self, _path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
+        Ok(None)
+    }
+}
+
 /// Right pad the path with 0s and return as [`B256`].
 #[inline]
 pub fn pad_path_to_key(path: &Nibbles) -> B256 {
diff --git a/crates/trie/sparse/src/traits.rs b/crates/trie/sparse/src/traits.rs
index e235cead63..286b264c1e 100644
--- a/crates/trie/sparse/src/traits.rs
+++ b/crates/trie/sparse/src/traits.rs
@@ -4,7 +4,7 @@ use core::fmt::Debug;
 
 use alloc::{borrow::Cow, vec, vec::Vec};
 use alloy_primitives::{
-    map::{HashMap, HashSet},
+    map::{B256Map, HashMap, HashSet},
     B256,
 };
 use alloy_trie::BranchNodeCompact;
@@ -13,6 +13,17 @@ use reth_trie_common::{BranchNodeMasks, Nibbles, ProofTrieNode, TrieNode};
 
 use crate::provider::TrieNodeProvider;
 
+/// Describes an update to a leaf in the sparse trie.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum LeafUpdate {
+    /// The leaf value has been changed to the given RLP-encoded value.
+    /// Empty Vec indicates the leaf has been removed.
+    Changed(Vec<u8>),
+    /// The leaf value may have changed, but the new value is not yet known.
+    /// Used for optimistic prewarming when the actual value is unavailable.
+    Touched,
+}
+
 /// Trait defining common operations for revealed sparse trie implementations.
 ///
 /// This trait abstracts over different sparse trie implementations (serial vs parallel)
@@ -260,6 +271,26 @@ pub trait SparseTrieExt: SparseTrie {
     ///
     /// The number of nodes converted to hash stubs.
     fn prune(&mut self, max_depth: usize) -> usize;
+
+    /// Applies leaf updates to the sparse trie.
+    ///
+    /// When a [`LeafUpdate::Changed`] is successfully applied, it is removed from the
+    /// given [`B256Map`]. If it could not be applied due to blinded nodes, it remains
+    /// in the map and the callback is invoked with the required proof target.
+    ///
+    /// Once that proof is calculated and revealed via [`SparseTrie::reveal_nodes`], the same
+    /// `updates` map can be reused to retry the update.
+    ///
+    /// Proof targets are deduplicated by `(full_path, min_len)` across all calls to this method.
+    /// The callback will only be invoked once per unique target, even across retry loops.
+    /// A deeper blinded node (higher `min_len`) for the same path is considered a new target.
+    ///
+    /// [`LeafUpdate::Touched`] behaves identically except it does not modify the leaf value.
+    fn update_leaves(
+        &mut self,
+        updates: &mut B256Map<LeafUpdate>,
+        proof_required_fn: impl FnMut(Nibbles, u8),
+    ) -> SparseTrieResult<()>;
 }
 
 /// Tracks modifications to the sparse trie structure.
diff --git a/crates/trie/sparse/src/trie.rs b/crates/trie/sparse/src/trie.rs
index 0ca4a20cb7..fe56d4b857 100644
--- a/crates/trie/sparse/src/trie.rs
+++ b/crates/trie/sparse/src/trie.rs
@@ -1,6 +1,7 @@
 use crate::{
     provider::{RevealedNode, TrieNodeProvider},
-    LeafLookup, LeafLookupError, SparseTrie as SparseTrieTrait, SparseTrieUpdates,
+    LeafLookup, LeafLookupError, LeafUpdate, SparseTrie as SparseTrieTrait, SparseTrieExt,
+    SparseTrieUpdates,
 };
 use alloc::{
     borrow::Cow,
@@ -12,7 +13,7 @@ use alloc::{
 };
 use alloy_primitives::{
     hex, keccak256,
-    map::{Entry, HashMap, HashSet},
+    map::{B256Map, Entry, HashMap, HashSet},
     B256,
 };
 use alloy_rlp::Decodable;
@@ -287,6 +288,36 @@ impl RevealableSparseTrie {
     }
 }
 
+impl<T: SparseTrieExt> RevealableSparseTrie<T> {
+    /// Applies batch leaf updates to the sparse trie.
+    ///
+    /// For blind tries, all updates are kept in the map and proof targets are emitted
+    /// for every key (with `min_len = 0` since nothing is revealed).
+    ///
+    /// For revealed tries, delegates to the inner implementation which will:
+    /// - Apply updates where possible
+    /// - Keep blocked updates in the map
+    /// - Emit proof targets for blinded paths
+    pub fn update_leaves(
+        &mut self,
+        updates: &mut B256Map<LeafUpdate>,
+        mut proof_required_fn: impl FnMut(Nibbles, u8),
+    ) -> SparseTrieResult<()> {
+        match self {
+            Self::Blind(_) => {
+                // Nothing is revealed - emit proof targets for all keys with min_len = 0
+                for key in updates.keys() {
+                    let full_path = Nibbles::unpack(*key);
+                    proof_required_fn(full_path, 0);
+                }
+                // All updates remain in the map for retry after proofs are fetched
+                Ok(())
+            }
+            Self::Revealed(trie) => trie.update_leaves(updates, proof_required_fn),
+        }
+    }
+}
+
 /// The representation of revealed sparse trie.
 ///
 /// The revealed sparse trie contains the actual trie structure with nodes, values, and