From ebe1a8b014555495e3cce640561d0ec7cc30fb1e Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 19 Sep 2025 15:24:46 +0200 Subject: [PATCH] chore(trie): Use Vec> in InMemoryTrieCursor (#18479) --- crates/trie/common/src/updates.rs | 79 +- crates/trie/db/tests/trie.rs | 14 +- crates/trie/sparse/benches/root.rs | 9 +- crates/trie/trie/src/trie_cursor/in_memory.rs | 728 +++++++++++------- crates/trie/trie/src/trie_cursor/mock.rs | 3 +- 5 files changed, 509 insertions(+), 324 deletions(-) diff --git a/crates/trie/common/src/updates.rs b/crates/trie/common/src/updates.rs index a752fd06d7..5f32f388c0 100644 --- a/crates/trie/common/src/updates.rs +++ b/crates/trie/common/src/updates.rs @@ -107,15 +107,8 @@ impl TrieUpdates { } /// Converts trie updates into [`TrieUpdatesSorted`]. - pub fn into_sorted(self) -> TrieUpdatesSorted { - let mut account_nodes = Vec::from_iter(self.account_nodes); - account_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0)); - let storage_tries = self - .storage_tries - .into_iter() - .map(|(hashed_address, updates)| (hashed_address, updates.into_sorted())) - .collect(); - TrieUpdatesSorted { removed_nodes: self.removed_nodes, account_nodes, storage_tries } + pub fn into_sorted(mut self) -> TrieUpdatesSorted { + self.drain_into_sorted() } /// Converts trie updates into [`TrieUpdatesSorted`], but keeping the maps allocated by @@ -126,7 +119,17 @@ impl TrieUpdates { /// This allows us to reuse the allocated space. This allocates new space for the sorted /// updates, like `into_sorted`. pub fn drain_into_sorted(&mut self) -> TrieUpdatesSorted { - let mut account_nodes = self.account_nodes.drain().collect::>(); + let mut account_nodes = self + .account_nodes + .drain() + .map(|(path, node)| { + // Updated nodes take precedence over removed nodes. + self.removed_nodes.remove(&path); + (path, Some(node)) + }) + .collect::>(); + + account_nodes.extend(self.removed_nodes.drain().map(|path| (path, None))); account_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let storage_tries = self @@ -134,12 +137,7 @@ impl TrieUpdates { .drain() .map(|(hashed_address, updates)| (hashed_address, updates.into_sorted())) .collect(); - - TrieUpdatesSorted { - removed_nodes: self.removed_nodes.clone(), - account_nodes, - storage_tries, - } + TrieUpdatesSorted { account_nodes, storage_tries } } /// Converts trie updates into [`TrieUpdatesSortedRef`]. @@ -266,14 +264,21 @@ impl StorageTrieUpdates { } /// Convert storage trie updates into [`StorageTrieUpdatesSorted`]. - pub fn into_sorted(self) -> StorageTrieUpdatesSorted { - let mut storage_nodes = Vec::from_iter(self.storage_nodes); + pub fn into_sorted(mut self) -> StorageTrieUpdatesSorted { + let mut storage_nodes = self + .storage_nodes + .into_iter() + .map(|(path, node)| { + // Updated nodes take precedence over removed nodes. + self.removed_nodes.remove(&path); + (path, Some(node)) + }) + .collect::>(); + + storage_nodes.extend(self.removed_nodes.into_iter().map(|path| (path, None))); storage_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0)); - StorageTrieUpdatesSorted { - is_deleted: self.is_deleted, - removed_nodes: self.removed_nodes, - storage_nodes, - } + + StorageTrieUpdatesSorted { is_deleted: self.is_deleted, storage_nodes } } /// Convert storage trie updates into [`StorageTrieUpdatesSortedRef`]. @@ -425,25 +430,19 @@ pub struct TrieUpdatesSortedRef<'a> { #[derive(PartialEq, Eq, Clone, Default, Debug)] #[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))] pub struct TrieUpdatesSorted { - /// Sorted collection of updated state nodes with corresponding paths. - pub account_nodes: Vec<(Nibbles, BranchNodeCompact)>, - /// The set of removed state node keys. - pub removed_nodes: HashSet, + /// Sorted collection of updated state nodes with corresponding paths. None indicates that a + /// node was removed. + pub account_nodes: Vec<(Nibbles, Option)>, /// Storage tries stored by hashed address of the account the trie belongs to. pub storage_tries: B256Map, } impl TrieUpdatesSorted { /// Returns reference to updated account nodes. - pub fn account_nodes_ref(&self) -> &[(Nibbles, BranchNodeCompact)] { + pub fn account_nodes_ref(&self) -> &[(Nibbles, Option)] { &self.account_nodes } - /// Returns reference to removed account nodes. - pub const fn removed_nodes_ref(&self) -> &HashSet { - &self.removed_nodes - } - /// Returns reference to updated storage tries. pub const fn storage_tries_ref(&self) -> &B256Map { &self.storage_tries @@ -468,10 +467,9 @@ pub struct StorageTrieUpdatesSortedRef<'a> { pub struct StorageTrieUpdatesSorted { /// Flag indicating whether the trie has been deleted/wiped. pub is_deleted: bool, - /// Sorted collection of updated storage nodes with corresponding paths. - pub storage_nodes: Vec<(Nibbles, BranchNodeCompact)>, - /// The set of removed storage node keys. - pub removed_nodes: HashSet, + /// Sorted collection of updated storage nodes with corresponding paths. None indicates a node + /// is removed. + pub storage_nodes: Vec<(Nibbles, Option)>, } impl StorageTrieUpdatesSorted { @@ -481,14 +479,9 @@ impl StorageTrieUpdatesSorted { } /// Returns reference to updated storage nodes. - pub fn storage_nodes_ref(&self) -> &[(Nibbles, BranchNodeCompact)] { + pub fn storage_nodes_ref(&self) -> &[(Nibbles, Option)] { &self.storage_nodes } - - /// Returns reference to removed storage nodes. - pub const fn removed_nodes_ref(&self) -> &HashSet { - &self.removed_nodes - } } /// Excludes empty nibbles from the given iterator. diff --git a/crates/trie/db/tests/trie.rs b/crates/trie/db/tests/trie.rs index 6f2588f39e..e16c24c57f 100644 --- a/crates/trie/db/tests/trie.rs +++ b/crates/trie/db/tests/trie.rs @@ -428,6 +428,7 @@ fn account_and_storage_trie() { let (nibbles1a, node1a) = account_updates.first().unwrap(); assert_eq!(nibbles1a.to_vec(), vec![0xB]); + let node1a = node1a.as_ref().unwrap(); assert_eq!(node1a.state_mask, TrieMask::new(0b1011)); assert_eq!(node1a.tree_mask, TrieMask::new(0b0001)); assert_eq!(node1a.hash_mask, TrieMask::new(0b1001)); @@ -436,6 +437,7 @@ fn account_and_storage_trie() { let (nibbles2a, node2a) = account_updates.last().unwrap(); assert_eq!(nibbles2a.to_vec(), vec![0xB, 0x0]); + let node2a = node2a.as_ref().unwrap(); assert_eq!(node2a.state_mask, TrieMask::new(0b10001)); assert_eq!(node2a.tree_mask, TrieMask::new(0b00000)); assert_eq!(node2a.hash_mask, TrieMask::new(0b10000)); @@ -471,6 +473,7 @@ fn account_and_storage_trie() { let (nibbles1b, node1b) = account_updates.first().unwrap(); assert_eq!(nibbles1b.to_vec(), vec![0xB]); + let node1b = node1b.as_ref().unwrap(); assert_eq!(node1b.state_mask, TrieMask::new(0b1011)); assert_eq!(node1b.tree_mask, TrieMask::new(0b0001)); assert_eq!(node1b.hash_mask, TrieMask::new(0b1011)); @@ -481,6 +484,7 @@ fn account_and_storage_trie() { let (nibbles2b, node2b) = account_updates.last().unwrap(); assert_eq!(nibbles2b.to_vec(), vec![0xB, 0x0]); + let node2b = node2b.as_ref().unwrap(); assert_eq!(node2a, node2b); tx.commit().unwrap(); @@ -520,8 +524,9 @@ fn account_and_storage_trie() { assert_eq!(trie_updates.account_nodes_ref().len(), 1); - let (nibbles1c, node1c) = trie_updates.account_nodes_ref().iter().next().unwrap(); - assert_eq!(nibbles1c.to_vec(), vec![0xB]); + let entry = trie_updates.account_nodes_ref().iter().next().unwrap(); + assert_eq!(entry.0.to_vec(), vec![0xB]); + let node1c = entry.1; assert_eq!(node1c.state_mask, TrieMask::new(0b1011)); assert_eq!(node1c.tree_mask, TrieMask::new(0b0000)); @@ -578,8 +583,9 @@ fn account_and_storage_trie() { assert_eq!(trie_updates.account_nodes_ref().len(), 1); - let (nibbles1d, node1d) = trie_updates.account_nodes_ref().iter().next().unwrap(); - assert_eq!(nibbles1d.to_vec(), vec![0xB]); + let entry = trie_updates.account_nodes_ref().iter().next().unwrap(); + assert_eq!(entry.0.to_vec(), vec![0xB]); + let node1d = entry.1; assert_eq!(node1d.state_mask, TrieMask::new(0b1011)); assert_eq!(node1d.tree_mask, TrieMask::new(0b0000)); diff --git a/crates/trie/sparse/benches/root.rs b/crates/trie/sparse/benches/root.rs index 396776ecf5..9eaf54c2d0 100644 --- a/crates/trie/sparse/benches/root.rs +++ b/crates/trie/sparse/benches/root.rs @@ -7,7 +7,7 @@ use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner}; use reth_trie::{ hashed_cursor::{noop::NoopHashedStorageCursor, HashedPostStateStorageCursor}, node_iter::{TrieElement, TrieNodeIter}, - trie_cursor::{noop::NoopStorageTrieCursor, InMemoryStorageTrieCursor}, + trie_cursor::{noop::NoopStorageTrieCursor, InMemoryTrieCursor}, updates::StorageTrieUpdates, walker::TrieWalker, HashedStorage, @@ -134,10 +134,9 @@ fn calculate_root_from_leaves_repeated(c: &mut Criterion) { }; let walker = TrieWalker::<_>::storage_trie( - InMemoryStorageTrieCursor::new( - B256::ZERO, - NoopStorageTrieCursor::default(), - Some(&trie_updates_sorted), + InMemoryTrieCursor::new( + Some(NoopStorageTrieCursor::default()), + &trie_updates_sorted.storage_nodes, ), prefix_set, ); diff --git a/crates/trie/trie/src/trie_cursor/in_memory.rs b/crates/trie/trie/src/trie_cursor/in_memory.rs index 4925dc8a66..5a0223e180 100644 --- a/crates/trie/trie/src/trie_cursor/in_memory.rs +++ b/crates/trie/trie/src/trie_cursor/in_memory.rs @@ -1,9 +1,6 @@ use super::{TrieCursor, TrieCursorFactory}; -use crate::{ - forward_cursor::ForwardInMemoryCursor, - updates::{StorageTrieUpdatesSorted, TrieUpdatesSorted}, -}; -use alloy_primitives::{map::HashSet, B256}; +use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted}; +use alloy_primitives::B256; use reth_storage_errors::db::DatabaseError; use reth_trie_common::{BranchNodeCompact, Nibbles}; @@ -24,283 +21,472 @@ impl<'a, CF> InMemoryTrieCursorFactory<'a, CF> { } impl<'a, CF: TrieCursorFactory> TrieCursorFactory for InMemoryTrieCursorFactory<'a, CF> { - type AccountTrieCursor = InMemoryAccountTrieCursor<'a, CF::AccountTrieCursor>; - type StorageTrieCursor = InMemoryStorageTrieCursor<'a, CF::StorageTrieCursor>; + type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>; + type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>; fn account_trie_cursor(&self) -> Result { let cursor = self.cursor_factory.account_trie_cursor()?; - Ok(InMemoryAccountTrieCursor::new(cursor, self.trie_updates)) + Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.account_nodes_ref())) } fn storage_trie_cursor( &self, hashed_address: B256, ) -> Result { - let cursor = self.cursor_factory.storage_trie_cursor(hashed_address)?; - Ok(InMemoryStorageTrieCursor::new( - hashed_address, - cursor, - self.trie_updates.storage_tries.get(&hashed_address), - )) - } -} + // if the storage trie has no updates then we use this as the in-memory overlay. + static EMPTY_UPDATES: Vec<(Nibbles, Option)> = Vec::new(); -/// The cursor to iterate over account trie updates and corresponding database entries. -/// It will always give precedence to the data from the trie updates. -#[derive(Debug)] -pub struct InMemoryAccountTrieCursor<'a, C> { - /// The underlying cursor. - cursor: C, - /// Forward-only in-memory cursor over storage trie nodes. - in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, BranchNodeCompact>, - /// Collection of removed trie nodes. - removed_nodes: &'a HashSet, - /// Last key returned by the cursor. - last_key: Option, -} + let storage_trie_updates = self.trie_updates.storage_tries.get(&hashed_address); + let (storage_nodes, cleared) = storage_trie_updates + .map(|u| (u.storage_nodes_ref(), u.is_deleted())) + .unwrap_or((&EMPTY_UPDATES, false)); -impl<'a, C: TrieCursor> InMemoryAccountTrieCursor<'a, C> { - /// Create new account trie cursor from underlying cursor and reference to - /// [`TrieUpdatesSorted`]. - pub fn new(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self { - let in_memory_cursor = ForwardInMemoryCursor::new(&trie_updates.account_nodes); - Self { - cursor, - in_memory_cursor, - removed_nodes: &trie_updates.removed_nodes, - last_key: None, - } - } - - fn seek_inner( - &mut self, - key: Nibbles, - exact: bool, - ) -> Result, DatabaseError> { - let in_memory = self.in_memory_cursor.seek(&key); - if in_memory.as_ref().is_some_and(|entry| entry.0 == key) { - return Ok(in_memory) - } - - // Reposition the cursor to the first greater or equal node that wasn't removed. - let mut db_entry = self.cursor.seek(key)?; - while db_entry.as_ref().is_some_and(|entry| self.removed_nodes.contains(&entry.0)) { - db_entry = self.cursor.next()?; - } - - // Compare two entries and return the lowest. - // If seek is exact, filter the entry for exact key match. - Ok(compare_trie_node_entries(in_memory, db_entry) - .filter(|(nibbles, _)| !exact || nibbles == &key)) - } - - fn next_inner( - &mut self, - last: Nibbles, - ) -> Result, DatabaseError> { - let in_memory = self.in_memory_cursor.first_after(&last); - - // Reposition the cursor to the first greater or equal node that wasn't removed. - let mut db_entry = self.cursor.seek(last)?; - while db_entry - .as_ref() - .is_some_and(|entry| entry.0 < last || self.removed_nodes.contains(&entry.0)) - { - db_entry = self.cursor.next()?; - } - - // Compare two entries and return the lowest. - Ok(compare_trie_node_entries(in_memory, db_entry)) - } -} - -impl TrieCursor for InMemoryAccountTrieCursor<'_, C> { - fn seek_exact( - &mut self, - key: Nibbles, - ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, true)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); - Ok(entry) - } - - fn seek( - &mut self, - key: Nibbles, - ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, false)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); - Ok(entry) - } - - fn next(&mut self) -> Result, DatabaseError> { - let next = match &self.last_key { - Some(last) => { - let entry = self.next_inner(*last)?; - self.last_key = entry.as_ref().map(|entry| entry.0); - entry - } - // no previous entry was found - None => None, - }; - Ok(next) - } - - fn current(&mut self) -> Result, DatabaseError> { - match &self.last_key { - Some(key) => Ok(Some(*key)), - None => self.cursor.current(), - } - } -} - -/// The cursor to iterate over storage trie updates and corresponding database entries. -/// It will always give precedence to the data from the trie updates. -#[derive(Debug)] -#[expect(dead_code)] -pub struct InMemoryStorageTrieCursor<'a, C> { - /// The hashed address of the account that trie belongs to. - hashed_address: B256, - /// The underlying cursor. - cursor: C, - /// Forward-only in-memory cursor over storage trie nodes. - in_memory_cursor: Option>, - /// Reference to the set of removed storage node keys. - removed_nodes: Option<&'a HashSet>, - /// The flag indicating whether the storage trie was cleared. - storage_trie_cleared: bool, - /// Last key returned by the cursor. - last_key: Option, -} - -impl<'a, C> InMemoryStorageTrieCursor<'a, C> { - /// Create new storage trie cursor from underlying cursor and reference to - /// [`StorageTrieUpdatesSorted`]. - pub fn new( - hashed_address: B256, - cursor: C, - updates: Option<&'a StorageTrieUpdatesSorted>, - ) -> Self { - let in_memory_cursor = updates.map(|u| ForwardInMemoryCursor::new(&u.storage_nodes)); - let removed_nodes = updates.map(|u| &u.removed_nodes); - let storage_trie_cleared = updates.is_some_and(|u| u.is_deleted); - Self { - hashed_address, - cursor, - in_memory_cursor, - removed_nodes, - storage_trie_cleared, - last_key: None, - } - } -} - -impl InMemoryStorageTrieCursor<'_, C> { - fn seek_inner( - &mut self, - key: Nibbles, - exact: bool, - ) -> Result, DatabaseError> { - let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.seek(&key)); - if self.storage_trie_cleared || in_memory.as_ref().is_some_and(|entry| entry.0 == key) { - return Ok(in_memory.filter(|(nibbles, _)| !exact || nibbles == &key)) - } - - // Reposition the cursor to the first greater or equal node that wasn't removed. - let mut db_entry = self.cursor.seek(key)?; - while db_entry - .as_ref() - .is_some_and(|entry| self.removed_nodes.as_ref().is_some_and(|r| r.contains(&entry.0))) - { - db_entry = self.cursor.next()?; - } - - // Compare two entries and return the lowest. - // If seek is exact, filter the entry for exact key match. - Ok(compare_trie_node_entries(in_memory, db_entry) - .filter(|(nibbles, _)| !exact || nibbles == &key)) - } - - fn next_inner( - &mut self, - last: Nibbles, - ) -> Result, DatabaseError> { - let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.first_after(&last)); - if self.storage_trie_cleared { - return Ok(in_memory) - } - - // Reposition the cursor to the first greater or equal node that wasn't removed. - let mut db_entry = self.cursor.seek(last)?; - while db_entry.as_ref().is_some_and(|entry| { - entry.0 < last || self.removed_nodes.as_ref().is_some_and(|r| r.contains(&entry.0)) - }) { - db_entry = self.cursor.next()?; - } - - // Compare two entries and return the lowest. - Ok(compare_trie_node_entries(in_memory, db_entry)) - } -} - -impl TrieCursor for InMemoryStorageTrieCursor<'_, C> { - fn seek_exact( - &mut self, - key: Nibbles, - ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, true)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); - Ok(entry) - } - - fn seek( - &mut self, - key: Nibbles, - ) -> Result, DatabaseError> { - let entry = self.seek_inner(key, false)?; - self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); - Ok(entry) - } - - fn next(&mut self) -> Result, DatabaseError> { - let next = match &self.last_key { - Some(last) => { - let entry = self.next_inner(*last)?; - self.last_key = entry.as_ref().map(|entry| entry.0); - entry - } - // no previous entry was found - None => None, - }; - Ok(next) - } - - fn current(&mut self) -> Result, DatabaseError> { - match &self.last_key { - Some(key) => Ok(Some(*key)), - None => self.cursor.current(), - } - } -} - -/// Return the node with the lowest nibbles. -/// -/// Given the next in-memory and database entries, return the smallest of the two. -/// If the node keys are the same, the in-memory entry is given precedence. -fn compare_trie_node_entries( - mut in_memory_item: Option<(Nibbles, BranchNodeCompact)>, - mut db_item: Option<(Nibbles, BranchNodeCompact)>, -) -> Option<(Nibbles, BranchNodeCompact)> { - if let Some((in_memory_entry, db_entry)) = in_memory_item.as_ref().zip(db_item.as_ref()) { - // If both are not empty, return the smallest of the two - // In-memory is given precedence if keys are equal - if in_memory_entry.0 <= db_entry.0 { - in_memory_item.take() + let cursor = if cleared { + None } else { - db_item.take() - } - } else { - // Return either non-empty entry - db_item.or(in_memory_item) + Some(self.cursor_factory.storage_trie_cursor(hashed_address)?) + }; + + Ok(InMemoryTrieCursor::new(cursor, storage_nodes)) + } +} + +/// A cursor to iterate over trie updates and corresponding database entries. +/// It will always give precedence to the data from the trie updates. +#[derive(Debug)] +pub struct InMemoryTrieCursor<'a, C> { + /// The underlying cursor. If None then it is assumed there is no DB data. + cursor: Option, + /// Forward-only in-memory cursor over storage trie nodes. + in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option>, + /// Last key returned by the cursor. + last_key: Option, +} + +impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> { + /// Create new trie cursor which combines a DB cursor (None to assume empty DB) and a set of + /// in-memory trie nodes. + pub fn new( + cursor: Option, + trie_updates: &'a [(Nibbles, Option)], + ) -> Self { + let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates); + Self { cursor, in_memory_cursor, last_key: None } + } + + fn seek_inner( + &mut self, + key: Nibbles, + exact: bool, + ) -> Result, DatabaseError> { + let mut mem_entry = self.in_memory_cursor.seek(&key); + let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten(); + + // exact matching is easy, if overlay has a value then return that (updated or removed), or + // if db has a value then return that. + if exact { + return Ok(match (mem_entry, db_entry) { + (Some((mem_key, entry_inner)), _) if mem_key == key => { + entry_inner.map(|node| (key, node)) + } + (_, Some((db_key, node))) if db_key == key => Some((key, node)), + _ => None, + }) + } + + loop { + match (mem_entry, &db_entry) { + (Some((mem_key, None)), _) + if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) => + { + // If overlay has a removed node but DB cursor is exhausted or ahead of the + // in-memory cursor then move ahead in-memory, as there might be further + // non-removed overlay nodes. + mem_entry = self.in_memory_cursor.first_after(&mem_key); + } + (Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => { + // If overlay has a removed node which is returned from DB then move both + // cursors ahead to the next key. + mem_entry = self.in_memory_cursor.first_after(&mem_key); + db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten(); + } + (Some((mem_key, Some(node))), _) + if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) => + { + // If overlay returns a node prior to the DB's node, or the DB is exhausted, + // then we return the overlay's node. + return Ok(Some((mem_key, node))) + } + // All other cases: + // - mem_key > db_key + // - overlay is exhausted + // Return the db_entry. If DB is also exhausted then this returns None. + _ => return Ok(db_entry), + } + } + } + + fn next_inner( + &mut self, + last: Nibbles, + ) -> Result, DatabaseError> { + let Some(key) = last.increment() else { return Ok(None) }; + self.seek_inner(key, false) + } +} + +impl TrieCursor for InMemoryTrieCursor<'_, C> { + fn seek_exact( + &mut self, + key: Nibbles, + ) -> Result, DatabaseError> { + let entry = self.seek_inner(key, true)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); + Ok(entry) + } + + fn seek( + &mut self, + key: Nibbles, + ) -> Result, DatabaseError> { + let entry = self.seek_inner(key, false)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles); + Ok(entry) + } + + fn next(&mut self) -> Result, DatabaseError> { + let next = match &self.last_key { + Some(last) => { + let entry = self.next_inner(*last)?; + self.last_key = entry.as_ref().map(|entry| entry.0); + entry + } + // no previous entry was found + None => None, + }; + Ok(next) + } + + fn current(&mut self) -> Result, DatabaseError> { + match &self.last_key { + Some(key) => Ok(Some(*key)), + None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::trie_cursor::mock::MockTrieCursor; + use parking_lot::Mutex; + use std::{collections::BTreeMap, sync::Arc}; + + #[derive(Debug)] + struct InMemoryTrieCursorTestCase { + db_nodes: Vec<(Nibbles, BranchNodeCompact)>, + in_memory_nodes: Vec<(Nibbles, Option)>, + expected_results: Vec<(Nibbles, BranchNodeCompact)>, + } + + fn execute_test(test_case: InMemoryTrieCursorTestCase) { + let db_nodes_map: BTreeMap = + test_case.db_nodes.into_iter().collect(); + let db_nodes_arc = Arc::new(db_nodes_map); + let visited_keys = Arc::new(Mutex::new(Vec::new())); + let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys); + + let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &test_case.in_memory_nodes); + + let mut results = Vec::new(); + + if let Some(first_expected) = test_case.expected_results.first() { + if let Ok(Some(entry)) = cursor.seek(first_expected.0) { + results.push(entry); + } + } + + while let Ok(Some(entry)) = cursor.next() { + results.push(entry); + } + + assert_eq!( + results, test_case.expected_results, + "Results mismatch.\nGot: {:?}\nExpected: {:?}", + results, test_case.expected_results + ); + } + + #[test] + fn test_empty_db_and_memory() { + let test_case = InMemoryTrieCursorTestCase { + db_nodes: vec![], + in_memory_nodes: vec![], + expected_results: vec![], + }; + execute_test(test_case); + } + + #[test] + fn test_only_db_nodes() { + let db_nodes = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ]; + + let test_case = InMemoryTrieCursorTestCase { + db_nodes: db_nodes.clone(), + in_memory_nodes: vec![], + expected_results: db_nodes, + }; + execute_test(test_case); + } + + #[test] + fn test_only_in_memory_nodes() { + let in_memory_nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)), + ), + ( + Nibbles::from_nibbles([0x2]), + Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)), + ), + ( + Nibbles::from_nibbles([0x3]), + Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ), + ]; + + let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes + .iter() + .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone()))) + .collect(); + + let test_case = + InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_in_memory_overwrites_db() { + let db_nodes = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)), + ]; + + let in_memory_nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)), + ), + ( + Nibbles::from_nibbles([0x3]), + Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ), + ]; + + let expected_results = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ]; + + let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_in_memory_deletes_db_nodes() { + let db_nodes = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ]; + + let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)]; + + let expected_results = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ]; + + let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_complex_interleaving() { + let db_nodes = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)), + (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)), + ]; + + let in_memory_nodes = vec![ + ( + Nibbles::from_nibbles([0x2]), + Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)), + ), + (Nibbles::from_nibbles([0x3]), None), + ( + Nibbles::from_nibbles([0x4]), + Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)), + ), + ( + Nibbles::from_nibbles([0x6]), + Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)), + ), + (Nibbles::from_nibbles([0x7]), None), + ( + Nibbles::from_nibbles([0x8]), + Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)), + ), + ]; + + let expected_results = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)), + (Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)), + (Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)), + (Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)), + (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)), + ]; + + let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_seek_exact() { + let db_nodes = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)), + (Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ]; + + let in_memory_nodes = vec![( + Nibbles::from_nibbles([0x2]), + Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)), + )]; + + let db_nodes_map: BTreeMap = db_nodes.into_iter().collect(); + let db_nodes_arc = Arc::new(db_nodes_map); + let visited_keys = Arc::new(Mutex::new(Vec::new())); + let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys); + + let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &in_memory_nodes); + + let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap(); + assert_eq!( + result, + Some(( + Nibbles::from_nibbles([0x2]), + BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None) + )) + ); + + let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap(); + assert_eq!( + result, + Some(( + Nibbles::from_nibbles([0x3]), + BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None) + )) + ); + + let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_multiple_consecutive_deletes() { + let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10) + .map(|i| { + ( + Nibbles::from_nibbles([i]), + BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None), + ) + }) + .collect(); + + let in_memory_nodes = vec![ + (Nibbles::from_nibbles([0x3]), None), + (Nibbles::from_nibbles([0x4]), None), + (Nibbles::from_nibbles([0x5]), None), + (Nibbles::from_nibbles([0x6]), None), + ]; + + let expected_results = vec![ + (Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)), + (Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)), + (Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)), + (Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)), + (Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)), + (Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)), + ]; + + let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_empty_db_with_in_memory_deletes() { + let in_memory_nodes = vec![ + (Nibbles::from_nibbles([0x1]), None), + ( + Nibbles::from_nibbles([0x2]), + Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)), + ), + (Nibbles::from_nibbles([0x3]), None), + ]; + + let expected_results = vec![( + Nibbles::from_nibbles([0x2]), + BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None), + )]; + + let test_case = + InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results }; + execute_test(test_case); + } + + #[test] + fn test_current_key_tracking() { + let db_nodes = vec![( + Nibbles::from_nibbles([0x2]), + BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None), + )]; + + let in_memory_nodes = vec![ + ( + Nibbles::from_nibbles([0x1]), + Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)), + ), + ( + Nibbles::from_nibbles([0x3]), + Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)), + ), + ]; + + let db_nodes_map: BTreeMap = db_nodes.into_iter().collect(); + let db_nodes_arc = Arc::new(db_nodes_map); + let visited_keys = Arc::new(Mutex::new(Vec::new())); + let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys); + + let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &in_memory_nodes); + + assert_eq!(cursor.current().unwrap(), None); + + cursor.seek(Nibbles::from_nibbles([0x1])).unwrap(); + assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1]))); + + cursor.next().unwrap(); + assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2]))); + + cursor.next().unwrap(); + assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3]))); } } diff --git a/crates/trie/trie/src/trie_cursor/mock.rs b/crates/trie/trie/src/trie_cursor/mock.rs index feda1c72a8..4b0b7f699d 100644 --- a/crates/trie/trie/src/trie_cursor/mock.rs +++ b/crates/trie/trie/src/trie_cursor/mock.rs @@ -93,7 +93,8 @@ pub struct MockTrieCursor { } impl MockTrieCursor { - fn new( + /// Creates a new mock trie cursor with the given trie nodes and key tracking. + pub fn new( trie_nodes: Arc>, visited_keys: Arc>>>, ) -> Self {