From b2eb061fe2f9928097cbedf7a52663d7616c2abe Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Tue, 3 Mar 2026 10:57:13 +0100 Subject: [PATCH] chore(trie): remove `DatabaseTrieWitness` trait and add `MaskedTrieCursorFactory` (#22564) Co-authored-by: Amp Co-authored-by: Alexey Shekhirin --- .changelog/young-horses-play.md | 7 + .../src/providers/state/historical.rs | 30 +- .../provider/src/providers/state/latest.rs | 34 +- crates/trie/db/src/lib.rs | 2 - crates/trie/db/src/witness.rs | 48 -- crates/trie/db/tests/witness.rs | 74 +- crates/trie/trie/src/trie_cursor/masked.rs | 752 ++++++++++++++++++ crates/trie/trie/src/trie_cursor/mod.rs | 3 + crates/trie/trie/src/witness.rs | 14 - 9 files changed, 847 insertions(+), 117 deletions(-) create mode 100644 .changelog/young-horses-play.md delete mode 100644 crates/trie/db/src/witness.rs create mode 100644 crates/trie/trie/src/trie_cursor/masked.rs diff --git a/.changelog/young-horses-play.md b/.changelog/young-horses-play.md new file mode 100644 index 0000000000..7dfc940474 --- /dev/null +++ b/.changelog/young-horses-play.md @@ -0,0 +1,7 @@ +--- +reth-trie: major +reth-trie-db: major +reth-provider: minor +--- + +Added `MaskedTrieCursorFactory` and `MaskedTrieCursor` to handle prefix-set-based hash invalidation at the cursor layer, replacing the `DatabaseTrieWitness` trait abstraction. Removed `with_prefix_sets_mut` from `TrieWitness` and deleted `DatabaseTrieWitness` — callers should now wrap their cursor factory with `MaskedTrieCursorFactory` to apply prefix sets during witness/proof computation. 
diff --git a/crates/storage/provider/src/providers/state/historical.rs b/crates/storage/provider/src/providers/state/historical.rs index dd757122fe..2da49626c1 100644 --- a/crates/storage/provider/src/providers/state/historical.rs +++ b/crates/storage/provider/src/providers/state/historical.rs @@ -18,7 +18,9 @@ use reth_storage_api::{ }; use reth_storage_errors::provider::ProviderResult; use reth_trie::{ + hashed_cursor::HashedPostStateCursorFactory, proof::{Proof, StorageProof}, + trie_cursor::{masked::MaskedTrieCursorFactory, InMemoryTrieCursorFactory}, updates::TrieUpdates, witness::TrieWitness, AccountProof, HashedPostState, HashedPostStateSorted, HashedStorage, KeccakKeyHasher, @@ -27,7 +29,7 @@ use reth_trie::{ }; use reth_trie_db::{ hashed_storage_from_reverts_with_provider, DatabaseProof, DatabaseStateRoot, - DatabaseStorageProof, DatabaseStorageRoot, DatabaseTrieWitness, + DatabaseStorageProof, DatabaseStorageRoot, }; use std::fmt::Debug; @@ -49,10 +51,6 @@ type DbProof<'a, TX, A> = Proof< reth_trie_db::DatabaseTrieCursorFactory<&'a TX, A>, reth_trie_db::DatabaseHashedCursorFactory<&'a TX>, >; -type DbTrieWitness<'a, TX, A> = TrieWitness< - reth_trie_db::DatabaseTrieCursorFactory<&'a TX, A>, - reth_trie_db::DatabaseHashedCursorFactory<&'a TX>, ->; /// Result of a history lookup for an account or storage slot. 
/// @@ -524,9 +522,25 @@ impl< reth_trie_db::with_adapter!(self.provider, |A| { let mut input = input; input.prepend(self.revert_state()?.into()); - >::overlay_witness(self.tx(), input, target) - .map_err(ProviderError::from) - .map(|hm| hm.into_values().collect()) + let nodes_sorted = input.nodes.into_sorted(); + let state_sorted = input.state.into_sorted(); + TrieWitness::new( + MaskedTrieCursorFactory::new( + InMemoryTrieCursorFactory::new( + reth_trie_db::DatabaseTrieCursorFactory::<_, A>::new(self.tx()), + &nodes_sorted, + ), + input.prefix_sets.freeze(), + ), + HashedPostStateCursorFactory::new( + reth_trie_db::DatabaseHashedCursorFactory::new(self.tx()), + &state_sorted, + ), + ) + .always_include_root_node() + .compute(target) + .map_err(ProviderError::from) + .map(|hm| hm.into_values().collect()) }) } } diff --git a/crates/storage/provider/src/providers/state/latest.rs b/crates/storage/provider/src/providers/state/latest.rs index 5f8caea547..ad66f2d6d9 100644 --- a/crates/storage/provider/src/providers/state/latest.rs +++ b/crates/storage/provider/src/providers/state/latest.rs @@ -9,16 +9,15 @@ use reth_storage_api::{ }; use reth_storage_errors::provider::{ProviderError, ProviderResult}; use reth_trie::{ + hashed_cursor::HashedPostStateCursorFactory, proof::{Proof, StorageProof}, + trie_cursor::{masked::MaskedTrieCursorFactory, InMemoryTrieCursorFactory}, updates::TrieUpdates, witness::TrieWitness, AccountProof, HashedPostState, HashedStorage, KeccakKeyHasher, MultiProof, MultiProofTargets, StateRoot, StorageMultiProof, StorageRoot, TrieInput, TrieInputSorted, }; -use reth_trie_db::{ - DatabaseProof, DatabaseStateRoot, DatabaseStorageProof, DatabaseStorageRoot, - DatabaseTrieWitness, -}; +use reth_trie_db::{DatabaseProof, DatabaseStateRoot, DatabaseStorageProof, DatabaseStorageRoot}; type DbStateRoot<'a, TX, A> = StateRoot< reth_trie_db::DatabaseTrieCursorFactory<&'a TX, A>, @@ -37,11 +36,6 @@ type DbProof<'a, TX, A> = Proof< 
reth_trie_db::DatabaseTrieCursorFactory<&'a TX, A>, reth_trie_db::DatabaseHashedCursorFactory<&'a TX>, >; -type DbTrieWitness<'a, TX, A> = TrieWitness< - reth_trie_db::DatabaseTrieCursorFactory<&'a TX, A>, - reth_trie_db::DatabaseHashedCursorFactory<&'a TX>, ->; - /// State provider over latest state that takes tx reference. /// /// Wraps a [`DBProvider`] to get access to database. @@ -226,9 +220,25 @@ impl StateProofProvider fn witness(&self, input: TrieInput, target: HashedPostState) -> ProviderResult> { reth_trie_db::with_adapter!(self.0, |A| { - Ok(>::overlay_witness(self.tx(), input, target)? - .into_values() - .collect()) + let nodes_sorted = input.nodes.into_sorted(); + let state_sorted = input.state.into_sorted(); + Ok(TrieWitness::new( + MaskedTrieCursorFactory::new( + InMemoryTrieCursorFactory::new( + reth_trie_db::DatabaseTrieCursorFactory::<_, A>::new(self.tx()), + &nodes_sorted, + ), + input.prefix_sets.freeze(), + ), + HashedPostStateCursorFactory::new( + reth_trie_db::DatabaseHashedCursorFactory::new(self.tx()), + &state_sorted, + ), + ) + .always_include_root_node() + .compute(target)? + .into_values() + .collect()) }) } } diff --git a/crates/trie/db/src/lib.rs b/crates/trie/db/src/lib.rs index 100a8bb4c5..31bede063f 100644 --- a/crates/trie/db/src/lib.rs +++ b/crates/trie/db/src/lib.rs @@ -10,7 +10,6 @@ mod proof; mod state; mod storage; mod trie_cursor; -mod witness; pub use hashed_cursor::{ DatabaseHashedAccountCursor, DatabaseHashedCursorFactory, DatabaseHashedStorageCursor, @@ -24,7 +23,6 @@ pub use trie_cursor::{ DatabaseAccountTrieCursor, DatabaseStorageTrieCursor, DatabaseTrieCursorFactory, LegacyKeyAdapter, PackedKeyAdapter, StorageTrieEntryLike, TrieKeyAdapter, TrieTableAdapter, }; -pub use witness::DatabaseTrieWitness; /// Dispatches a trie operation using the correct [`TrieKeyAdapter`] based on storage settings. 
/// diff --git a/crates/trie/db/src/witness.rs b/crates/trie/db/src/witness.rs deleted file mode 100644 index 2fb01ac850..0000000000 --- a/crates/trie/db/src/witness.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory, TrieTableAdapter}; -use alloy_primitives::{map::B256Map, Bytes}; -use reth_db_api::transaction::DbTx; -use reth_execution_errors::TrieWitnessError; -use reth_trie::{ - hashed_cursor::HashedPostStateCursorFactory, trie_cursor::InMemoryTrieCursorFactory, - witness::TrieWitness, HashedPostState, TrieInput, -}; - -/// Extends [`TrieWitness`] with operations specific for working with a database transaction. -pub trait DatabaseTrieWitness<'a, TX> { - /// Create a new [`TrieWitness`] from database transaction. - fn from_tx(tx: &'a TX) -> Self; - - /// Generates trie witness for target state based on [`TrieInput`]. - fn overlay_witness( - tx: &'a TX, - input: TrieInput, - target: HashedPostState, - ) -> Result, TrieWitnessError>; -} - -impl<'a, TX: DbTx, A: TrieTableAdapter> DatabaseTrieWitness<'a, TX> - for TrieWitness, DatabaseHashedCursorFactory<&'a TX>> -{ - fn from_tx(tx: &'a TX) -> Self { - Self::new(DatabaseTrieCursorFactory::<_, A>::new(tx), DatabaseHashedCursorFactory::new(tx)) - } - - fn overlay_witness( - tx: &'a TX, - input: TrieInput, - target: HashedPostState, - ) -> Result, TrieWitnessError> { - let nodes_sorted = input.nodes.into_sorted(); - let state_sorted = input.state.into_sorted(); - TrieWitness::new( - InMemoryTrieCursorFactory::new( - DatabaseTrieCursorFactory::<_, A>::new(tx), - &nodes_sorted, - ), - HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted), - ) - .with_prefix_sets_mut(input.prefix_sets) - .always_include_root_node() - .compute(target) - } -} diff --git a/crates/trie/db/tests/witness.rs b/crates/trie/db/tests/witness.rs index 176f2a6aab..1e630482f0 100644 --- a/crates/trie/db/tests/witness.rs +++ b/crates/trie/db/tests/witness.rs @@ -18,16 
+18,12 @@ use reth_trie::{ }; use reth_trie_db::{ DatabaseHashedCursorFactory, DatabaseProof, DatabaseStateRoot, DatabaseTrieCursorFactory, - DatabaseTrieWitness, }; type DbStateRoot<'a, TX, A> = StateRoot, DatabaseHashedCursorFactory<&'a TX>>; type DbProof<'a, TX, A> = Proof, DatabaseHashedCursorFactory<&'a TX>>; -type DbTrieWitness<'a, TX, A> = - TrieWitness, DatabaseHashedCursorFactory<&'a TX>>; - #[test] fn includes_empty_node_preimage() { let factory = create_test_provider_factory(); @@ -40,12 +36,15 @@ fn includes_empty_node_preimage() { reth_trie_db::with_adapter!(provider, |A| { // witness includes empty state trie root node assert_eq!( - DbTrieWitness::<_, A>::from_tx(provider.tx_ref()) - .compute(HashedPostState { - accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), - storages: HashMap::default(), - }) - .unwrap(), + TrieWitness::new( + DatabaseTrieCursorFactory::<_, A>::new(provider.tx_ref()), + DatabaseHashedCursorFactory::new(provider.tx_ref()), + ) + .compute(HashedPostState { + accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), + storages: HashMap::default(), + }) + .unwrap(), HashMap::from_iter([(EMPTY_ROOT_HASH, Bytes::from([EMPTY_STRING_CODE]))]) ); @@ -61,15 +60,18 @@ fn includes_empty_node_preimage() { )])) .unwrap(); - let witness = DbTrieWitness::<_, A>::from_tx(provider.tx_ref()) - .compute(HashedPostState { - accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), - storages: HashMap::from_iter([( - hashed_address, - HashedStorage::from_iter(false, [(hashed_slot, U256::from(1))]), - )]), - }) - .unwrap(); + let witness = TrieWitness::new( + DatabaseTrieCursorFactory::<_, A>::new(provider.tx_ref()), + DatabaseHashedCursorFactory::new(provider.tx_ref()), + ) + .compute(HashedPostState { + accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), + storages: HashMap::from_iter([( + hashed_address, + HashedStorage::from_iter(false, [(hashed_slot, 
U256::from(1))]), + )]), + }) + .unwrap(); assert!(witness.contains_key(&state_root)); for node in multiproof.account_subtree.values() { assert_eq!(witness.get(&keccak256(node)), Some(node)); @@ -105,7 +107,11 @@ fn includes_nodes_for_destroyed_storage_nodes() { )])) .unwrap(); - let witness = DbTrieWitness::<_, A>::from_tx(provider.tx_ref()) + let witness = + TrieWitness::new( + DatabaseTrieCursorFactory::<_, A>::new(provider.tx_ref()), + DatabaseHashedCursorFactory::new(provider.tx_ref()), + ) .compute(HashedPostState { accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), storages: HashMap::from_iter([( @@ -155,19 +161,21 @@ fn correctly_decodes_branch_node_values() { )])) .unwrap(); - let witness = DbTrieWitness::<_, A>::from_tx(provider.tx_ref()) - .compute(HashedPostState { - accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), - storages: HashMap::from_iter([( - hashed_address, - HashedStorage::from_iter( - false, - [hashed_slot1, hashed_slot2] - .map(|hashed_slot| (hashed_slot, U256::from(2))), - ), - )]), - }) - .unwrap(); + let witness = TrieWitness::new( + DatabaseTrieCursorFactory::<_, A>::new(provider.tx_ref()), + DatabaseHashedCursorFactory::new(provider.tx_ref()), + ) + .compute(HashedPostState { + accounts: HashMap::from_iter([(hashed_address, Some(Account::default()))]), + storages: HashMap::from_iter([( + hashed_address, + HashedStorage::from_iter( + false, + [hashed_slot1, hashed_slot2].map(|hashed_slot| (hashed_slot, U256::from(2))), + ), + )]), + }) + .unwrap(); assert!(witness.contains_key(&state_root)); for node in multiproof.account_subtree.values() { assert_eq!(witness.get(&keccak256(node)), Some(node)); diff --git a/crates/trie/trie/src/trie_cursor/masked.rs b/crates/trie/trie/src/trie_cursor/masked.rs new file mode 100644 index 0000000000..cf8cfa2e41 --- /dev/null +++ b/crates/trie/trie/src/trie_cursor/masked.rs @@ -0,0 +1,752 @@ +use super::{TrieCursor, TrieCursorFactory, TrieStorageCursor}; 
+use alloy_primitives::{map::B256Map, B256}; +use reth_storage_errors::db::DatabaseError; +use reth_trie_common::{ + prefix_set::{PrefixSet, TriePrefixSets}, + BranchNodeCompact, Nibbles, +}; +use std::sync::Arc; + +/// A [`TrieCursorFactory`] wrapper that creates cursors which invalidate cached trie hash data +/// for children whose paths match the prefix sets in a [`TriePrefixSets`]. +/// +/// The `destroyed_accounts` field of the prefix sets is not used by the cursor — it is only +/// relevant during trie update finalization, not during cursor traversal. +#[derive(Debug, Clone)] +pub struct MaskedTrieCursorFactory { + /// Underlying trie cursor factory. + cursor_factory: CF, + /// Frozen prefix sets used for masking. + prefix_sets: TriePrefixSets, +} + +impl MaskedTrieCursorFactory { + /// Create a new factory from an inner cursor factory and frozen prefix sets. + pub const fn new(cursor_factory: CF, prefix_sets: TriePrefixSets) -> Self { + Self { cursor_factory, prefix_sets } + } +} + +impl TrieCursorFactory for MaskedTrieCursorFactory { + type AccountTrieCursor<'a> + = MaskedTrieCursor> + where + Self: 'a; + + type StorageTrieCursor<'a> + = MaskedTrieCursor> + where + Self: 'a; + + fn account_trie_cursor(&self) -> Result, DatabaseError> { + let cursor = self.cursor_factory.account_trie_cursor()?; + Ok(MaskedTrieCursor::new(cursor, self.prefix_sets.account_prefix_set.clone())) + } + + fn storage_trie_cursor( + &self, + hashed_address: B256, + ) -> Result, DatabaseError> { + let cursor = self.cursor_factory.storage_trie_cursor(hashed_address)?; + let prefix_set = + self.prefix_sets.storage_prefix_sets.get(&hashed_address).cloned().unwrap_or_default(); + Ok(MaskedTrieCursor::new_storage( + cursor, + prefix_set, + self.prefix_sets.storage_prefix_sets.clone(), + )) + } +} + +/// A [`TrieCursor`] wrapper that invalidates cached trie hash data for children whose paths match +/// a [`PrefixSet`]. 
+/// +/// For each node returned by the inner cursor, hash bits are unset for children whose paths match +/// the prefix set, and the corresponding hashes are removed from the node. If a node's `hash_mask` +/// and `tree_mask` are both empty after masking, the node is skipped entirely. +#[derive(Debug)] +pub struct MaskedTrieCursor { + /// The inner cursor. + cursor: C, + /// Prefix set used to determine which children's hashes to invalidate. + prefix_set: PrefixSet, + /// Storage prefix sets for swapping on `set_hashed_address`. + storage_prefix_sets: Option>, +} + +impl MaskedTrieCursor { + /// Create a new cursor wrapping `cursor`, masking hash bits for children whose paths match + /// `prefix_set`. + pub const fn new(cursor: C, prefix_set: PrefixSet) -> Self { + Self { cursor, prefix_set, storage_prefix_sets: None } + } + + /// Create a new storage cursor that can swap its prefix set on `set_hashed_address`. + pub const fn new_storage( + cursor: C, + prefix_set: PrefixSet, + storage_prefix_sets: B256Map, + ) -> Self { + Self { cursor, prefix_set, storage_prefix_sets: Some(storage_prefix_sets) } + } +} + +impl MaskedTrieCursor { + /// Mask hash bits on a node for children whose paths match the prefix set. + /// + /// Returns `true` if the node should be kept, `false` if it should be skipped (both + /// `hash_mask` and `tree_mask` are empty after masking). + fn mask_node(&mut self, key: &Nibbles, node: &mut BranchNodeCompact) -> bool { + if !self.prefix_set.contains(key) { + return true; + } + + // The subtree is modified — root hash is always invalid. 
+ node.root_hash = None; + + let original_hash_mask = node.hash_mask; + if original_hash_mask.is_empty() { + return true; + } + + let mut new_hash_mask = original_hash_mask; + let mut child_path = *key; + let key_len = key.len(); + + for nibble in original_hash_mask.iter() { + child_path.truncate(key_len); + child_path.push(nibble); + + if self.prefix_set.contains(&child_path) { + new_hash_mask.unset_bit(nibble); + } + } + + if new_hash_mask != original_hash_mask { + // Remove hashes for unset bits in-place. + let hashes = Arc::make_mut(&mut node.hashes); + let mut write = 0; + for (read, nibble) in original_hash_mask.iter().enumerate() { + if new_hash_mask.is_bit_set(nibble) { + hashes[write] = hashes[read]; + write += 1; + } + } + hashes.truncate(write); + + node.hash_mask = new_hash_mask; + + if node.hash_mask.is_empty() && node.tree_mask.is_empty() { + return false; + } + } + + true + } + + /// Apply masking to entries, advancing past fully-masked nodes. + fn mask_entries( + &mut self, + mut entry: Option<(Nibbles, BranchNodeCompact)>, + ) -> Result, DatabaseError> { + while let Some((key, mut node)) = entry { + if self.mask_node(&key, &mut node) { + return Ok(Some((key, node))); + } + entry = self.cursor.next()?; + } + Ok(None) + } +} + +impl TrieCursor for MaskedTrieCursor { + fn seek_exact( + &mut self, + key: Nibbles, + ) -> Result, DatabaseError> { + if let Some((key, mut node)) = self.cursor.seek_exact(key)? 
{ + if self.mask_node(&key, &mut node) { + Ok(Some((key, node))) + } else { + Ok(None) + } + } else { + Ok(None) + } + } + + fn seek( + &mut self, + key: Nibbles, + ) -> Result, DatabaseError> { + let entry = self.cursor.seek(key)?; + self.mask_entries(entry) + } + + fn next(&mut self) -> Result, DatabaseError> { + let entry = self.cursor.next()?; + self.mask_entries(entry) + } + + fn current(&mut self) -> Result, DatabaseError> { + self.cursor.current() + } + + fn reset(&mut self) { + self.cursor.reset(); + } +} + +impl TrieStorageCursor for MaskedTrieCursor { + fn set_hashed_address(&mut self, hashed_address: B256) { + self.cursor.set_hashed_address(hashed_address); + if let Some(storage_prefix_sets) = &self.storage_prefix_sets { + self.prefix_set = storage_prefix_sets.get(&hashed_address).cloned().unwrap_or_default(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::trie_cursor::mock::MockTrieCursor; + use parking_lot::Mutex; + use reth_trie_common::prefix_set::PrefixSetMut; + use std::{collections::BTreeMap, sync::Arc}; + + fn make_cursor(nodes: Vec<(Nibbles, BranchNodeCompact)>) -> MockTrieCursor { + let map: BTreeMap = nodes.into_iter().collect(); + MockTrieCursor::new(Arc::new(map), Arc::new(Mutex::new(Vec::new()))) + } + + fn node(state_mask: u16) -> BranchNodeCompact { + BranchNodeCompact::new(state_mask, 0, 0, vec![], None) + } + + fn node_with_hashes(state_mask: u16, hash_mask: u16, hashes: Vec) -> BranchNodeCompact { + BranchNodeCompact::new(state_mask, 0, hash_mask, hashes, None) + } + + fn node_with_tree_mask( + state_mask: u16, + tree_mask: u16, + hash_mask: u16, + hashes: Vec, + ) -> BranchNodeCompact { + BranchNodeCompact::new(state_mask, tree_mask, hash_mask, hashes, None) + } + + fn hash(byte: u8) -> B256 { + B256::repeat_byte(byte) + } + + #[test] + fn test_seek_masks_matching_child_hashes() { + // Node at [0x1] with children 2 and 5 hashed. + // Prefix set marks child 2 as changed. 
+ let nodes = vec![( + Nibbles::from_nibbles([0x1]), + node_with_hashes(0b0000_0000_0010_0100, 0b0000_0000_0010_0100, vec![hash(2), hash(5)]), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x2])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek(Nibbles::default()).unwrap(); + let (key, node) = result.unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x1])); + // Hash bit 2 should be unset, only bit 5 remains. + assert!(!node.hash_mask.is_bit_set(2)); + assert!(node.hash_mask.is_bit_set(5)); + assert_eq!(&*node.hashes, &[hash(5)]); + } + + #[test] + fn test_seek_skips_fully_masked_node() { + // Node at [0x1] with only child 3 hashed, tree_mask empty. + // Prefix set marks child 3 as changed → fully masked → skipped. + // Node at [0x2] is unaffected → returned. + let nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + node_with_hashes(0b0000_0000_0000_1000, 0b0000_0000_0000_1000, vec![hash(3)]), + ), + (Nibbles::from_nibbles([0x2]), node(0b0000_0000_0000_0001)), + ]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x3])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek(Nibbles::default()).unwrap(); + assert_eq!(result, Some((Nibbles::from_nibbles([0x2]), node(0b0000_0000_0000_0001)))); + } + + #[test] + fn test_node_with_tree_mask_not_skipped() { + // Node at [0x1] with child 3 hashed, tree_mask has bit 3 set. + // Prefix set marks child 3 → hash cleared, but tree_mask keeps the node alive. 
+ let nodes = vec![( + Nibbles::from_nibbles([0x1]), + node_with_tree_mask( + 0b0000_0000_0000_1000, + 0b0000_0000_0000_1000, + 0b0000_0000_0000_1000, + vec![hash(3)], + ), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x3])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek(Nibbles::default()).unwrap(); + let (key, node) = result.unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x1])); + assert!(node.hash_mask.is_empty()); + assert!(node.tree_mask.is_bit_set(3)); + assert!(node.hashes.is_empty()); + } + + #[test] + fn test_seek_exact_masks_hash_bits() { + let nodes = vec![( + Nibbles::from_nibbles([0x1]), + node_with_tree_mask( + 0b0000_0000_0010_0100, + 0b0000_0000_0010_0100, + 0b0000_0000_0010_0100, + vec![hash(2), hash(5)], + ), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x5])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek_exact(Nibbles::from_nibbles([0x1])).unwrap(); + let (_, node) = result.unwrap(); + assert!(node.hash_mask.is_bit_set(2)); + assert!(!node.hash_mask.is_bit_set(5)); + assert_eq!(&*node.hashes, &[hash(2)]); + } + + #[test] + fn test_seek_exact_returns_none_for_fully_masked() { + let nodes = vec![( + Nibbles::from_nibbles([0x1]), + node_with_hashes(0b0000_0000_0000_0100, 0b0000_0000_0000_0100, vec![hash(2)]), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x2])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek_exact(Nibbles::from_nibbles([0x1])).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_next_masks_and_skips() { + // Three nodes: [0x1] unaffected, [0x2] fully masked, [0x3] unaffected. 
+ let nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + node_with_hashes(0b0000_0000_0000_0010, 0b0000_0000_0000_0010, vec![hash(1)]), + ), + ( + Nibbles::from_nibbles([0x2]), + node_with_hashes(0b0000_0000_0001_0000, 0b0000_0000_0001_0000, vec![hash(4)]), + ), + ( + Nibbles::from_nibbles([0x3]), + node_with_hashes(0b0000_0000_0100_0000, 0b0000_0000_0100_0000, vec![hash(6)]), + ), + ]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x2, 0x4])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + // seek to [0x1], no match → returned unchanged. + let result = cursor.seek(Nibbles::from_nibbles([0x1])).unwrap(); + let (key, node) = result.unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x1])); + assert_eq!(&*node.hashes, &[hash(1)]); + + // next() should skip [0x2] (fully masked), returning [0x3]. + let result = cursor.next().unwrap(); + let (key, node) = result.unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x3])); + assert_eq!(&*node.hashes, &[hash(6)]); + } + + #[test] + fn test_no_match_returns_unchanged() { + let nodes = vec![( + Nibbles::from_nibbles([0x2]), + node_with_hashes(0b0000_0000_0000_0010, 0b0000_0000_0000_0010, vec![hash(1)]), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x3])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek(Nibbles::default()).unwrap(); + let (key, node) = result.unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x2])); + // Unchanged — prefix set doesn't match [0x2]. 
+ assert!(node.hash_mask.is_bit_set(1)); + assert_eq!(&*node.hashes, &[hash(1)]); + } + + #[test] + fn test_empty_prefix_set_returns_all_unchanged() { + let h1 = hash(1); + let h2 = hash(2); + let nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + node_with_hashes(0b0000_0000_0000_0010, 0b0000_0000_0000_0010, vec![h1]), + ), + ( + Nibbles::from_nibbles([0x2]), + node_with_hashes(0b0000_0000_0000_0100, 0b0000_0000_0000_0100, vec![h2]), + ), + ]; + + let ps = PrefixSetMut::default(); + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let r1 = cursor.seek(Nibbles::default()).unwrap().unwrap(); + assert_eq!(r1.0, Nibbles::from_nibbles([0x1])); + assert_eq!(&*r1.1.hashes, &[h1]); + + let r2 = cursor.next().unwrap().unwrap(); + assert_eq!(r2.0, Nibbles::from_nibbles([0x2])); + assert_eq!(&*r2.1.hashes, &[h2]); + + assert_eq!(cursor.next().unwrap(), None); + } + + #[test] + fn test_root_hash_cleared_on_mask() { + let mut n = + node_with_hashes(0b0000_0000_0010_0100, 0b0000_0000_0010_0100, vec![hash(2), hash(5)]); + n.root_hash = Some(hash(0xFF)); + + let nodes = vec![(Nibbles::from_nibbles([0x1]), n)]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x2])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let (_, node) = cursor.seek(Nibbles::default()).unwrap().unwrap(); + assert_eq!(node.root_hash, None); + } + + #[test] + fn test_node_without_hashes_returned_unchanged() { + // Node with state_mask only (no hashes, no tree_mask) should pass through. 
+ let nodes = vec![(Nibbles::from_nibbles([0x1]), node(0b0000_0000_0000_0011))]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x0])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let result = cursor.seek(Nibbles::default()).unwrap(); + assert_eq!(result, Some((Nibbles::from_nibbles([0x1]), node(0b0000_0000_0000_0011)))); + } + + #[test] + fn test_empty_cursor_returns_none() { + let nodes = vec![]; + let ps = PrefixSetMut::default(); + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + assert_eq!(cursor.seek(Nibbles::default()).unwrap(), None); + } + + #[test] + fn test_reset_delegates() { + let nodes = + vec![(Nibbles::from_nibbles([0x1]), node(1)), (Nibbles::from_nibbles([0x2]), node(2))]; + + let ps = PrefixSetMut::default(); + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let _ = cursor.seek(Nibbles::from_nibbles([0x1])).unwrap(); + assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1]))); + + cursor.reset(); + assert_eq!(cursor.current().unwrap(), None); + } + + #[test] + fn test_partial_mask_preserves_remaining_hashes() { + // Node at [0x1] with children 0, 3, 7 hashed. + // Prefix set marks children 0 and 7 as changed. + // Only hash for child 3 should remain. 
+ let nodes = vec![( + Nibbles::from_nibbles([0x1]), + node_with_tree_mask( + 0b0000_0000_1000_1001, + 0b0000_0000_1000_1001, + 0b0000_0000_1000_1001, + vec![hash(0), hash(3), hash(7)], + ), + )]; + + let mut ps = PrefixSetMut::default(); + ps.insert(Nibbles::from_nibbles([0x1, 0x0])); + ps.insert(Nibbles::from_nibbles([0x1, 0x7])); + + let inner = make_cursor(nodes); + let mut cursor = MaskedTrieCursor::new(inner, ps.freeze()); + + let (key, node) = cursor.seek(Nibbles::default()).unwrap().unwrap(); + assert_eq!(key, Nibbles::from_nibbles([0x1])); + assert!(!node.hash_mask.is_bit_set(0)); + assert!(node.hash_mask.is_bit_set(3)); + assert!(!node.hash_mask.is_bit_set(7)); + assert_eq!(&*node.hashes, &[hash(3)]); + assert_eq!(node.root_hash, None); + } + + mod proptest_tests { + use crate::{ + hashed_cursor::{mock::MockHashedCursorFactory, HashedPostStateCursorFactory}, + proof::Proof, + trie_cursor::{ + masked::MaskedTrieCursorFactory, mock::MockTrieCursorFactory, + noop::NoopTrieCursorFactory, + }, + StateRoot, + }; + use alloy_primitives::{map::B256Set, B256, U256}; + use proptest::prelude::*; + use reth_primitives_traits::Account; + use reth_trie_common::{HashedPostState, HashedStorage, MultiProofTargets}; + + fn account_strategy() -> impl Strategy { + (any::(), any::(), any::<[u8; 32]>()).prop_map( + |(nonce, balance, code_hash)| Account { + nonce, + balance: U256::from(balance), + bytecode_hash: Some(B256::from(code_hash)), + }, + ) + } + + fn storage_value_strategy() -> impl Strategy { + any::().prop_filter("non-zero", |v| *v != 0).prop_map(U256::from) + } + + /// Generates a base dataset of 1000 storage slots for account `B256::ZERO`, + /// a 200-entry changeset partially overlapping with the base, and random + /// proof targets partially overlapping with both. 
+        #[allow(clippy::type_complexity)]
+        fn test_input_strategy(
+        ) -> impl Strategy<Value = (Vec<(B256, U256)>, Account, Vec<(B256, Option<U256>)>, Vec<B256>)>
+        {
+            (
+                // 1000 base storage slots (deduplicated by key below), values from storage_value_strategy()
+                prop::collection::vec(
+                    (any::<[u8; 32]>().prop_map(B256::from), storage_value_strategy()),
+                    1000,
+                ),
+                account_strategy(),
+                // 200 changeset entries: (key, Option<U256>) where None = removal
+                prop::collection::vec(
+                    (
+                        any::<[u8; 32]>().prop_map(B256::from),
+                        prop::option::of(storage_value_strategy()),
+                    ),
+                    200,
+                ),
+                // Extra random keys for proof targets
+                prop::collection::vec(any::<[u8; 32]>().prop_map(B256::from), 50),
+            )
+                .prop_flat_map(
+                    |(base_slots, account, changeset_raw, extra_targets)| {
+                        // Dedup base slots by key
+                        let mut base_map = alloy_primitives::map::B256Map::default();
+                        for (k, v) in &base_slots {
+                            base_map.insert(*k, *v);
+                        }
+                        let base_deduped: Vec<(B256, U256)> =
+                            base_map.iter().map(|(&k, &v)| (k, v)).collect();
+                        let base_keys: Vec<B256> = base_deduped.iter().map(|(k, _)| *k).collect();
+
+                        // Build changeset: 50% overlap with base keys, 50% new keys
+                        let changeset_len = changeset_raw.len();
+                        let half = changeset_len / 2;
+                        let base_keys_for_overlap = base_keys.clone();
+
+                        // Use indices to select from base keys for overlap portion
+                        let overlap_indices =
+                            prop::collection::vec(0..base_keys_for_overlap.len().max(1), half);
+
+                        overlap_indices.prop_map(move |indices| {
+                            let mut changeset: Vec<(B256, Option<U256>)> = Vec::new();
+
+                            // First half: overlapping with base keys
+                            for (i, (_, value)) in
+                                indices.iter().zip(changeset_raw.iter()).take(half)
+                            {
+                                let key = if base_keys_for_overlap.is_empty() {
+                                    changeset_raw[*i].0
+                                } else {
+                                    base_keys_for_overlap[*i % base_keys_for_overlap.len()]
+                                };
+                                changeset.push((key, *value));
+                            }
+
+                            // Second half: new keys from changeset_raw
+                            for (key, value) in changeset_raw.iter().skip(half) {
+                                changeset.push((*key, *value));
+                            }
+
+                            // Build proof targets: mix of base keys, changeset keys, and randoms
+                            let changeset_keys: Vec<B256> =
+                                changeset.iter().map(|(k, _)| *k).collect();
+                            let mut proof_slot_targets: Vec<B256> = Vec::new();
+
+                            // ~40% from base
+                            for k in base_keys.iter().take(40) {
+                                proof_slot_targets.push(*k);
+                            }
+                            // ~30% from changeset
+                            for k in changeset_keys.iter().take(30) {
+                                proof_slot_targets.push(*k);
+                            }
+                            // ~30% random
+                            for k in extra_targets.iter().take(30) {
+                                proof_slot_targets.push(*k);
+                            }
+
+                            (base_deduped.clone(), account, changeset, proof_slot_targets)
+                        })
+                    },
+                )
+        }
+
+        proptest! {
+            #![proptest_config(ProptestConfig::with_cases(50))]
+            #[test]
+            fn proptest_masked_cursor_multiproof_equivalence(
+                (base_slots, account, changeset, proof_slot_targets) in test_input_strategy()
+            ) {
+                reth_tracing::init_test_tracing();
+
+                let hashed_address = B256::ZERO;
+
+                // Step 1: Create the base hashed post state with a single account
+                // and 1000 storage slots.
+                let base_state = HashedPostState {
+                    accounts: std::iter::once((hashed_address, Some(account))).collect(),
+                    storages: std::iter::once((
+                        hashed_address,
+                        HashedStorage::from_iter(false, base_slots),
+                    ))
+                    .collect(),
+                };
+
+                // Step 2: Compute trie updates from state root over the full base state.
+                let base_hashed_cursor_factory =
+                    MockHashedCursorFactory::from_hashed_post_state(base_state);
+                let (_, trie_updates) = StateRoot::new(
+                    NoopTrieCursorFactory,
+                    base_hashed_cursor_factory.clone(),
+                )
+                .root_with_updates()
+                .expect("state root computation should succeed");
+
+                // Step 3: Create a MockTrieCursorFactory from those trie updates.
+                let mock_trie_cursor_factory =
+                    MockTrieCursorFactory::from_trie_updates(trie_updates);
+
+                // Step 4: Build the changeset post state. Removals use U256::ZERO.
+                let changeset_storage: Vec<(B256, U256)> = changeset
+                    .iter()
+                    .map(|(k, v)| (*k, v.unwrap_or(U256::ZERO)))
+                    .collect();
+                let changeset_state = HashedPostState {
+                    accounts: std::iter::once((hashed_address, Some(account))).collect(),
+                    storages: std::iter::once((
+                        hashed_address,
+                        HashedStorage::from_iter(false, changeset_storage),
+                    ))
+                    .collect(),
+                };
+
+                // Step 5: Generate prefix sets from the changeset.
+                let prefix_sets_mut = changeset_state.construct_prefix_sets();
+
+                // Step 6: Build proof targets.
+                let slot_targets: B256Set = proof_slot_targets.into_iter().collect();
+                let targets =
+                    MultiProofTargets::from_iter([(hashed_address, slot_targets)]);
+
+                // Step 7: Create the HashedPostStateCursorFactory overlaying changeset
+                // on the base.
+                let changeset_sorted = changeset_state.into_sorted();
+                let overlay_cursor_factory = HashedPostStateCursorFactory::new(
+                    base_hashed_cursor_factory,
+                    &changeset_sorted,
+                );
+
+                // Step 8a: Approach A — prefix sets passed to Proof directly.
+                let proof_a = Proof::new(
+                    mock_trie_cursor_factory.clone(),
+                    overlay_cursor_factory.clone(),
+                )
+                .with_prefix_sets_mut(prefix_sets_mut.clone());
+                let multiproof_a = proof_a
+                    .multiproof(targets.clone())
+                    .expect("multiproof A should succeed");
+
+                // Step 8b: Approach B — MaskedTrieCursorFactory, no prefix sets on Proof.
+                let masked_trie_cursor_factory = MaskedTrieCursorFactory::new(
+                    mock_trie_cursor_factory,
+                    prefix_sets_mut.freeze(),
+                );
+                let proof_b = Proof::new(
+                    masked_trie_cursor_factory,
+                    overlay_cursor_factory,
+                );
+                let multiproof_b = proof_b
+                    .multiproof(targets)
+                    .expect("multiproof B should succeed");
+
+                // Step 9: Compare results.
+                assert_eq!(
+                    multiproof_a, multiproof_b,
+                    "multiproof with prefix sets should equal multiproof with masked cursor"
+                );
+            }
+        }
+    }
+}
diff --git a/crates/trie/trie/src/trie_cursor/mod.rs b/crates/trie/trie/src/trie_cursor/mod.rs
index 4390bc0947..1ef2546761 100644
--- a/crates/trie/trie/src/trie_cursor/mod.rs
+++ b/crates/trie/trie/src/trie_cursor/mod.rs
@@ -11,6 +11,9 @@ pub mod subnode;
 /// Noop trie cursor implementations.
 pub mod noop;
 
+/// Masked trie cursor wrapper that skips nodes matching a prefix set.
+pub mod masked;
+
 /// Depth-first trie iterator.
 pub mod depth_first;
 
diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs
index 5d7050a7bc..92a3916a3c 100644
--- a/crates/trie/trie/src/witness.rs
+++ b/crates/trie/trie/src/witness.rs
@@ -1,6 +1,5 @@
 use crate::{
     hashed_cursor::{HashedCursor, HashedCursorFactory},
-    prefix_set::TriePrefixSetsMut,
     proof::{Proof, ProofTrieNodeProviderFactory},
     trie_cursor::TrieCursorFactory,
 };
@@ -33,8 +32,6 @@ pub struct TrieWitness<T, H> {
     trie_cursor_factory: T,
     /// The factory for hashed cursors.
     hashed_cursor_factory: H,
-    /// A set of prefix sets that have changes.
-    prefix_sets: TriePrefixSetsMut,
     /// Flag indicating whether the root node should always be included (even if the target state
     /// is empty). This setting is useful if the caller wants to verify the witness against the
     /// parent state root.
@@ -50,7 +47,6 @@ impl TrieWitness { Self { trie_cursor_factory, hashed_cursor_factory, - prefix_sets: TriePrefixSetsMut::default(), always_include_root_node: false, witness: HashMap::default(), } @@ -61,7 +57,6 @@ impl TrieWitness { TrieWitness { trie_cursor_factory, hashed_cursor_factory: self.hashed_cursor_factory, - prefix_sets: self.prefix_sets, always_include_root_node: self.always_include_root_node, witness: self.witness, } @@ -72,18 +67,11 @@ impl TrieWitness { TrieWitness { trie_cursor_factory: self.trie_cursor_factory, hashed_cursor_factory, - prefix_sets: self.prefix_sets, always_include_root_node: self.always_include_root_node, witness: self.witness, } } - /// Set the prefix sets. They have to be mutable in order to allow extension with proof target. - pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self { - self.prefix_sets = prefix_sets; - self - } - /// Set `always_include_root_node` to true. Root node will be included even in empty state. /// This setting is useful if the caller wants to verify the witness against the /// parent state root. @@ -115,10 +103,8 @@ where } else { self.get_proof_targets(&state)? }; - let prefix_sets = core::mem::take(&mut self.prefix_sets); let multiproof = Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone()) - .with_prefix_sets_mut(prefix_sets) .multiproof(proof_targets.clone())?; // No need to reconstruct the rest of the trie, we just need to include