Compare commits

...

1 Commits

Author SHA1 Message Date
yongkangc
2773663d28 feat(trie): add update_leaves method to SparseTrie
Add batch update_leaves method that collects proof targets for blinded
nodes instead of fetching proofs one-at-a-time. This enables batch proof
fetching with a single database call.

Key changes:
- Move Target struct to reth-trie-common for reuse across crates
- Make update_leaf transactional by deferring all mutations to end
- Add LeafUpdate enum for batch operations
- Add update_leaves method to SparseTrie trait
- Add find_blinded_on_path helper for read-only traversal

The transactional fix ensures that if update_leaf hits a BlindedNode
error, no partial mutations are left in the trie, making retries safe.

Closes RETH-177

Amp-Thread-ID: https://ampcode.com/threads/T-019c03f0-b517-7209-a875-8eeb8c8a185f
2026-01-28 10:15:46 +00:00
7 changed files with 441 additions and 102 deletions

View File

@@ -39,6 +39,9 @@ pub use key::{KeccakKeyHasher, KeyHasher};
mod nibbles;
pub use nibbles::{Nibbles, StoredNibbles, StoredNibblesSubKey};
mod target;
pub use target::Target;
mod storage;
pub use storage::StorageTrieEntry;

View File

@@ -0,0 +1,68 @@
use alloy_primitives::B256;
use crate::Nibbles;
/// Target describes a proof target for trie operations.
///
/// For every proof target given, the proof calculator will return all nodes
/// whose path is a prefix of the target's `key`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Target {
/// The full key path (64 nibbles).
pub key: Nibbles,
/// Minimum path length required for the proof.
/// Only match trie nodes whose path is at least this long.
pub min_len: u8,
}
impl Target {
/// Returns a new [`Target`] which matches all trie nodes whose path is a prefix of this key.
pub fn new(key: B256) -> Self {
// SAFETY: key is a B256 and so is exactly 32-bytes.
let key = unsafe { Nibbles::unpack_unchecked(key.as_slice()) };
Self { key, min_len: 0 }
}
/// Creates a new target from nibbles with a specific min_len.
#[inline]
pub const fn from_nibbles(key: Nibbles, min_len: u8) -> Self {
Self { key, min_len }
}
/// Returns the key the target was initialized with.
pub fn key(&self) -> B256 {
B256::from_slice(&self.key.pack())
}
/// Only match trie nodes whose path is at least this long.
///
/// # Panics
///
/// This method panics if `min_len` is greater than 64.
pub fn with_min_len(mut self, min_len: u8) -> Self {
debug_assert!(min_len <= 64);
self.min_len = min_len;
self
}
/// Returns the sub-trie prefix for this target.
///
/// A target will only match nodes which share the target's prefix, where the target's prefix
/// is the first `min_len` nibbles of its key. E.g. a target with `key` 0xabcd and `min_len` 2
/// will only match nodes with prefix 0xab.
///
/// The sub-trie prefix is the target prefix with a nibble truncated, because a branch node
/// must be constructed at the parent level to know the node exists at that path.
#[inline]
pub fn sub_trie_prefix(&self) -> Nibbles {
let mut sub_trie_prefix = self.key;
sub_trie_prefix.truncate(self.min_len.saturating_sub(1) as usize);
sub_trie_prefix
}
}
impl From<B256> for Target {
fn from(key: B256) -> Self {
Self::new(key)
}
}

View File

@@ -14,6 +14,8 @@ pub use trie::*;
mod traits;
pub use traits::*;
pub use reth_trie_common::Target;
pub mod provider;
#[cfg(feature = "metrics")]

View File

@@ -64,6 +64,22 @@ impl TrieNodeProvider for DefaultTrieNodeProvider {
}
}
/// A trie node provider that always returns `Ok(None)`.
///
/// This is used by [`update_leaves`](crate::SparseTrie::update_leaves) to force
/// [`BlindedNode`](reth_execution_errors::SparseTrieErrorKind::BlindedNode) errors
/// when the trie traversal hits a hash node, rather than attempting to fetch from a database.
/// This enables a "short-circuit" pattern where blinded nodes are detected and collected
/// for batch proof fetching instead of being resolved one-at-a-time.
#[derive(PartialEq, Eq, Clone, Copy, Default, Debug)]
pub struct ShortCircuitTrieNodeProvider;
impl TrieNodeProvider for ShortCircuitTrieNodeProvider {
fn trie_node(&self, _path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
Ok(None)
}
}
/// Right pad the path with 0s and return as [`B256`].
#[inline]
pub fn pad_path_to_key(path: &Nibbles) -> B256 {

View File

@@ -9,7 +9,7 @@ use alloy_primitives::{
};
use alloy_trie::BranchNodeCompact;
use reth_execution_errors::SparseTrieResult;
use reth_trie_common::{BranchNodeMasks, Nibbles, ProofTrieNode, TrieNode};
use reth_trie_common::{BranchNodeMasks, Nibbles, ProofTrieNode, Target, TrieNode};
use crate::provider::TrieNodeProvider;
@@ -223,6 +223,35 @@ pub trait SparseTrie: Sized + Debug + Send + Sync {
/// This is useful for reusing the trie without needing to reallocate memory.
fn clear(&mut self);
/// Batch update multiple leaves, collecting proof targets for any blinded nodes.
///
/// This method attempts to apply all leaf updates in the provided map. For any update
/// that would hit a blinded node during traversal, it calls `on_blinded` with a
/// [`Target`] instead of failing. Successfully applied updates are removed from
/// the input map.
///
/// This enables batch proof fetching: collect all blinded node targets in one pass,
/// fetch proofs in a single database call, reveal them, then retry.
///
/// # Arguments
///
/// * `updates` - Map of key paths to leaf updates. Modified in place - successful updates are
/// removed.
/// * `on_blinded` - Callback invoked for each blinded node encountered. Receives a
/// [`Target`] that can be used to fetch the required proof.
///
/// # Returns
///
/// `Ok(())` after processing all updates. Blinded nodes don't cause errors;
/// they're reported via the callback.
fn update_leaves<F>(
&mut self,
updates: &mut HashMap<Nibbles, LeafUpdate>,
on_blinded: F,
) -> SparseTrieResult<()>
where
F: FnMut(Target);
/// Shrink the capacity of the sparse trie's node storage to the given size.
/// This will reduce memory usage if the current capacity is higher than the given size.
fn shrink_nodes_to(&mut self, size: usize);
@@ -277,3 +306,38 @@ pub enum LeafLookup {
/// Leaf does not exist (exclusion proof found).
NonExistent,
}
/// Represents an update operation on a leaf node.
///
/// Used with [`SparseTrie::update_leaves`] to specify whether a leaf
/// should be inserted/updated or deleted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LeafUpdate {
/// Insert a new leaf or update an existing leaf with the given RLP-encoded value.
Upsert(Vec<u8>),
/// Delete the leaf at this path.
Deleted,
}
impl LeafUpdate {
/// Returns `true` if this is an upsert operation.
#[inline]
pub const fn is_upsert(&self) -> bool {
matches!(self, Self::Upsert(_))
}
/// Returns `true` if this is a delete operation.
#[inline]
pub const fn is_deleted(&self) -> bool {
matches!(self, Self::Deleted)
}
/// Returns the value if this is an upsert, or `None` if it's a delete.
#[inline]
pub const fn value(&self) -> Option<&Vec<u8>> {
match self {
Self::Upsert(v) => Some(v),
Self::Deleted => None,
}
}
}

View File

@@ -1,6 +1,6 @@
use crate::{
provider::{RevealedNode, TrieNodeProvider},
LeafLookup, LeafLookupError, SparseTrie as SparseTrieTrait, SparseTrieUpdates,
LeafLookup, LeafLookupError, LeafUpdate, SparseTrie as SparseTrieTrait, SparseTrieUpdates,
};
use alloc::{
borrow::Cow,
@@ -20,7 +20,7 @@ use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult};
use reth_trie_common::{
prefix_set::{PrefixSet, PrefixSetMut},
BranchNodeCompact, BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef,
LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, TrieMask, TrieNode, CHILD_INDEX_RANGE,
LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, Target, TrieMask, TrieNode, CHILD_INDEX_RANGE,
EMPTY_ROOT_HASH,
};
use smallvec::SmallVec;
@@ -609,25 +609,32 @@ impl SparseTrieTrait for SerialSparseTrie {
value: Vec<u8>,
provider: P,
) -> SparseTrieResult<()> {
self.prefix_set.insert(full_path);
let existing = self.values.insert(full_path, value);
if existing.is_some() {
// trie structure unchanged, return immediately
// Check if value already exists - if so, update in place and return early
if let Some(existing) = self.values.get_mut(&full_path) {
*existing = value;
self.prefix_set.insert(full_path);
return Ok(())
}
enum NodeOp {
Insert(Nibbles, SparseNode),
Update(Nibbles, SparseNode),
}
let mut ops: SmallVec<[NodeOp; 4]> = SmallVec::new();
let mut current = Nibbles::default();
while let Some(node) = self.nodes.get_mut(&current) {
while let Some(node) = self.nodes.get(&current) {
match node {
SparseNode::Empty => {
*node = SparseNode::new_leaf(full_path);
ops.push(NodeOp::Update(current, SparseNode::new_leaf(full_path)));
break
}
&mut SparseNode::Hash(hash) => {
&SparseNode::Hash(hash) => {
return Err(SparseTrieErrorKind::BlindedNode { path: current, hash }.into())
}
SparseNode::Leaf { key: current_key, .. } => {
current.extend(current_key);
let current_key = *current_key;
current.extend(&current_key);
// this leaf is being updated
if current == full_path {
@@ -637,37 +644,39 @@ impl SparseTrieTrait for SerialSparseTrie {
// find the common prefix
let common = current.common_prefix_length(&full_path);
// update existing node
// update existing node to extension
let new_ext_key = current.slice(current.len() - current_key.len()..common);
*node = SparseNode::new_ext(new_ext_key);
ops.push(NodeOp::Update(
current.slice(..current.len() - current_key.len()),
SparseNode::new_ext(new_ext_key),
));
// create a branch node and corresponding leaves
self.nodes.reserve(3);
self.nodes.insert(
ops.push(NodeOp::Insert(
current.slice(..common),
SparseNode::new_split_branch(
current.get_unchecked(common),
full_path.get_unchecked(common),
),
);
self.nodes.insert(
));
ops.push(NodeOp::Insert(
full_path.slice(..=common),
SparseNode::new_leaf(full_path.slice(common + 1..)),
);
self.nodes.insert(
));
ops.push(NodeOp::Insert(
current.slice(..=common),
SparseNode::new_leaf(current.slice(common + 1..)),
);
));
break;
}
SparseNode::Extension { key, .. } => {
current.extend(key);
let key = *key;
current.extend(&key);
if !full_path.starts_with(&current) {
// find the common prefix
let common = current.common_prefix_length(&full_path);
*key = current.slice(current.len() - key.len()..common);
// If branch node updates retention is enabled, we need to query the
// extension node child to later set the hash mask for a parent branch node
@@ -700,23 +709,35 @@ impl SparseTrieTrait for SerialSparseTrie {
}
}
// update existing extension node
let new_ext_key = current.slice(current.len() - key.len()..common);
ops.push(NodeOp::Update(
current.slice(..current.len() - key.len()),
SparseNode::new_ext(new_ext_key),
));
// create state mask for new branch node
// NOTE: this might overwrite the current extension node
self.nodes.reserve(3);
let branch = SparseNode::new_split_branch(
current.get_unchecked(common),
full_path.get_unchecked(common),
);
self.nodes.insert(current.slice(..common), branch);
ops.push(NodeOp::Insert(
current.slice(..common),
SparseNode::new_split_branch(
current.get_unchecked(common),
full_path.get_unchecked(common),
),
));
// create new leaf
let new_leaf = SparseNode::new_leaf(full_path.slice(common + 1..));
self.nodes.insert(full_path.slice(..=common), new_leaf);
ops.push(NodeOp::Insert(
full_path.slice(..=common),
SparseNode::new_leaf(full_path.slice(common + 1..)),
));
// recreate extension to previous child if needed
let key = current.slice(common + 1..);
if !key.is_empty() {
self.nodes.insert(current.slice(..=common), SparseNode::new_ext(key));
let remaining_key = current.slice(common + 1..);
if !remaining_key.is_empty() {
ops.push(NodeOp::Insert(
current.slice(..=common),
SparseNode::new_ext(remaining_key),
));
}
break;
@@ -726,15 +747,90 @@ impl SparseTrieTrait for SerialSparseTrie {
let nibble = full_path.get_unchecked(current.len());
current.push_unchecked(nibble);
if !state_mask.is_bit_set(nibble) {
state_mask.set_bit(nibble);
let new_leaf = SparseNode::new_leaf(full_path.slice(current.len()..));
self.nodes.insert(current, new_leaf);
let mut new_state_mask = *state_mask;
new_state_mask.set_bit(nibble);
ops.push(NodeOp::Update(
current.slice(..current.len() - 1),
SparseNode::Branch {
state_mask: new_state_mask,
hash: None,
store_in_db_trie: None,
},
));
ops.push(NodeOp::Insert(
current,
SparseNode::new_leaf(full_path.slice(current.len()..)),
));
break;
}
}
};
}
// Apply all staged operations
for op in ops {
match op {
NodeOp::Insert(path, node) | NodeOp::Update(path, node) => {
self.nodes.insert(path, node);
}
}
}
self.values.insert(full_path, value);
self.prefix_set.insert(full_path);
Ok(())
}
fn update_leaves<F>(
&mut self,
updates: &mut HashMap<Nibbles, LeafUpdate>,
mut on_blinded: F,
) -> SparseTrieResult<()>
where
F: FnMut(Target),
{
use crate::provider::ShortCircuitTrieNodeProvider;
// Sort keys for cache-friendly traversal (lexicographic order)
let mut keys: Vec<_> = updates.keys().copied().collect();
keys.sort_unstable();
for key in keys {
// Check if path is blinded before attempting update
if let Some((blinded_path, _hash)) = self.find_blinded_on_path(&key) {
// Report blinded node - min_len is one level deeper than blinded path
let min_len = (blinded_path.len() + 1) as u8;
on_blinded(Target::from_nibbles(key, min_len));
continue;
}
// Path is revealed, apply the update
let update = updates.get(&key).unwrap();
let result = match update {
LeafUpdate::Upsert(value) => {
self.update_leaf(key, value.clone(), ShortCircuitTrieNodeProvider)
}
LeafUpdate::Deleted => self.remove_leaf(&key, ShortCircuitTrieNodeProvider),
};
match result {
Ok(()) => {
// Successfully applied - remove from updates map
updates.remove(&key);
}
Err(err) => {
// Check if it's a BlindedNode error (shouldn't happen after our check)
if let SparseTrieErrorKind::BlindedNode { path, .. } = err.kind() {
let min_len = (path.len() + 1) as u8;
on_blinded(Target::from_nibbles(key, min_len));
continue;
}
// Other errors propagate
return Err(err);
}
}
}
Ok(())
}
@@ -1775,6 +1871,44 @@ impl SerialSparseTrie {
debug_assert_eq!(buffers.rlp_node_stack.len(), 1);
buffers.rlp_node_stack.pop().unwrap().rlp_node
}
/// Traverses the trie read-only to find the first blinded hash node on the path.
/// Returns `Some((blinded_path, hash))` if found, `None` if path is fully revealed.
fn find_blinded_on_path(&self, full_path: &Nibbles) -> Option<(Nibbles, B256)> {
let mut current = Nibbles::default();
while let Some(node) = self.nodes.get(&current) {
match node {
SparseNode::Empty => return None,
&SparseNode::Hash(hash) => return Some((current, hash)),
SparseNode::Leaf { key, .. } => {
current.extend(key);
// Leaf found - path is revealed (even if it's a different leaf)
return None;
}
SparseNode::Extension { key, .. } => {
current.extend(key);
// If the path diverges, check if the child at current is blinded
if !full_path.starts_with(&current) {
// Check the extension's child
if let Some(&SparseNode::Hash(hash)) = self.nodes.get(&current) {
return Some((current, hash));
}
return None;
}
}
SparseNode::Branch { state_mask, .. } => {
let nibble = full_path.get_unchecked(current.len());
current.push_unchecked(nibble);
if !state_mask.is_bit_set(nibble) {
// Child doesn't exist - path is revealed (leaf will be inserted)
return None;
}
// Continue traversing to the child
}
}
}
None
}
}
/// Enum representing sparse trie node type.
@@ -3723,4 +3857,119 @@ Root -> Extension { key: Nibbles(0x5), hash: None, store_in_db_trie: None }
assert_eq!(alternate_printed, expected);
}
#[test]
fn test_update_leaves_all_revealed() {
let mut trie = SerialSparseTrie::default();
let path1 = Nibbles::unpack(B256::repeat_byte(0x11));
let path2 = Nibbles::unpack(B256::repeat_byte(0x22));
trie.update_leaf(path1, vec![1, 2, 3], DefaultTrieNodeProvider).unwrap();
trie.update_leaf(path2, vec![4, 5, 6], DefaultTrieNodeProvider).unwrap();
let mut updates = HashMap::default();
updates.insert(path1, LeafUpdate::Upsert(vec![10, 20, 30]));
updates.insert(path2, LeafUpdate::Upsert(vec![40, 50, 60]));
let path3 = Nibbles::unpack(B256::repeat_byte(0x33));
updates.insert(path3, LeafUpdate::Upsert(vec![7, 8, 9]));
let mut blinded_targets = Vec::new();
trie.update_leaves(&mut updates, |target| blinded_targets.push(target)).unwrap();
assert!(blinded_targets.is_empty());
assert!(updates.is_empty());
assert_eq!(trie.get_leaf_value(&path1), Some(&vec![10, 20, 30]));
assert_eq!(trie.get_leaf_value(&path2), Some(&vec![40, 50, 60]));
assert_eq!(trie.get_leaf_value(&path3), Some(&vec![7, 8, 9]));
}
#[test]
fn test_update_leaves_with_blinded_nodes() {
let mut trie = SerialSparseTrie::default();
let hash = B256::repeat_byte(0xaa);
trie.nodes.clear();
trie.nodes.insert(Nibbles::default(), SparseNode::Hash(hash));
let path = Nibbles::unpack(B256::repeat_byte(0x11));
let mut updates = HashMap::default();
updates.insert(path, LeafUpdate::Upsert(vec![1, 2, 3]));
let mut blinded_targets = Vec::new();
trie.update_leaves(&mut updates, |target| blinded_targets.push(target)).unwrap();
assert_eq!(blinded_targets.len(), 1);
assert_eq!(blinded_targets[0].key, path);
assert_eq!(blinded_targets[0].min_len, 1);
assert!(updates.contains_key(&path));
}
#[test]
fn test_update_leaves_mixed() {
let mut trie = SerialSparseTrie::default();
let revealed_path = Nibbles::unpack(B256::repeat_byte(0x11));
let other_path = Nibbles::unpack(B256::repeat_byte(0x55));
trie.update_leaf(revealed_path, vec![1], DefaultTrieNodeProvider).unwrap();
trie.update_leaf(other_path, vec![2], DefaultTrieNodeProvider).unwrap();
let blinded_prefix = Nibbles::from_nibbles([0x2]);
let hash = B256::repeat_byte(0xbb);
trie.nodes.insert(blinded_prefix, SparseNode::Hash(hash));
if let Some(SparseNode::Branch { state_mask, .. }) = trie.nodes.get_mut(&Nibbles::default())
{
state_mask.set_bit(2);
}
let blinded_path = Nibbles::unpack(B256::repeat_byte(0x22));
let mut updates = HashMap::default();
updates.insert(revealed_path, LeafUpdate::Upsert(vec![10]));
updates.insert(blinded_path, LeafUpdate::Upsert(vec![20]));
let mut blinded_targets = Vec::new();
trie.update_leaves(&mut updates, |target| blinded_targets.push(target)).unwrap();
assert!(!updates.contains_key(&revealed_path));
assert_eq!(trie.get_leaf_value(&revealed_path), Some(&vec![10]));
assert!(updates.contains_key(&blinded_path));
assert_eq!(blinded_targets.len(), 1);
}
#[test]
fn test_update_leaves_delete() {
let mut trie = SerialSparseTrie::default();
let path1 = Nibbles::unpack(B256::repeat_byte(0x11));
let path2 = Nibbles::unpack(B256::repeat_byte(0x22));
trie.update_leaf(path1, vec![1], DefaultTrieNodeProvider).unwrap();
trie.update_leaf(path2, vec![2], DefaultTrieNodeProvider).unwrap();
let mut updates = HashMap::default();
updates.insert(path1, LeafUpdate::Deleted);
let mut blinded_targets = Vec::new();
trie.update_leaves(&mut updates, |target| blinded_targets.push(target)).unwrap();
assert!(blinded_targets.is_empty());
assert!(updates.is_empty());
assert_eq!(trie.get_leaf_value(&path1), None);
assert_eq!(trie.get_leaf_value(&path2), Some(&vec![2]));
}
#[test]
fn test_update_leaf_existing_value_marks_dirty() {
let mut trie = SerialSparseTrie::default();
let path = Nibbles::unpack(B256::repeat_byte(0x11));
trie.update_leaf(path, vec![1, 2, 3], DefaultTrieNodeProvider).unwrap();
trie.prefix_set.clear();
trie.update_leaf(path, vec![4, 5, 6], DefaultTrieNodeProvider).unwrap();
let mut frozen = trie.prefix_set.clone().freeze();
assert!(frozen.contains(&path));
}
}

View File

@@ -1,71 +1,7 @@
use crate::proof_v2::increment_and_strip_trailing_zeros;
use alloy_primitives::B256;
use reth_trie_common::Nibbles;
/// Target describes a proof target. For every proof target given, the
/// [`crate::proof_v2::ProofCalculator`] will calculate and return all nodes whose path is a prefix
/// of the target's `key`.
#[derive(Debug, Copy, Clone)]
pub struct Target {
pub(crate) key: Nibbles,
pub(crate) min_len: u8,
}
impl Target {
/// Returns a new [`Target`] which matches all trie nodes whose path is a prefix of this key.
pub fn new(key: B256) -> Self {
// SAFETY: key is a B256 and so is exactly 32-bytes.
let key = unsafe { Nibbles::unpack_unchecked(key.as_slice()) };
Self { key, min_len: 0 }
}
/// Returns the key the target was initialized with.
pub fn key(&self) -> B256 {
B256::from_slice(&self.key.pack())
}
/// Only match trie nodes whose path is at least this long.
///
/// # Panics
///
/// This method panics if `min_len` is greater than 64.
pub fn with_min_len(mut self, min_len: u8) -> Self {
debug_assert!(min_len <= 64);
self.min_len = min_len;
self
}
// A helper function for getting the largest prefix of the sub-trie which contains a particular
// target, based on its `min_len`.
//
// A target will only match nodes which share the target's prefix, where the target's prefix is
// the first `min_len` nibbles of its key. E.g. a target with `key` 0xabcd and `min_len` 2 will
// only match nodes with prefix 0xab.
//
// In general the target will only match within the sub-trie whose prefix is identical to the
// target's. However there is an exception:
//
// Given a trie with a node at 0xabc, there must be a branch at 0xab. A target with prefix 0xabc
// needs to match that node, but the branch at 0xab must be constructed order to know the node
// is at that path. Therefore the sub-trie prefix is the target prefix with a nibble truncated.
//
// For a target with an empty prefix (`min_len` of 0) we still use an empty sub-trie prefix;
// this will still construct the branch at the root node (if there is one). Targets with
// `min_len` of both 0 and 1 will therefore construct the root node, but only those with
// `min_len` of 0 will retain it.
#[inline]
fn sub_trie_prefix(&self) -> Nibbles {
let mut sub_trie_prefix = self.key;
sub_trie_prefix.truncate(self.min_len.saturating_sub(1) as usize);
sub_trie_prefix
}
}
impl From<B256> for Target {
fn from(key: B256) -> Self {
Self::new(key)
}
}
pub use reth_trie_common::Target;
// A helper function which returns the first path following a sub-trie in lexicographical order.
#[inline]
@@ -153,6 +89,7 @@ pub(crate) fn iter_sub_trie_targets<'a>(
#[cfg(test)]
mod tests {
use super::*;
use alloy_primitives::B256;
#[test]
fn test_iter_sub_trie_targets() {