diff --git a/crates/trie/trie/src/forward_cursor.rs b/crates/trie/trie/src/forward_cursor.rs index 1f14a462b1..da71326cea 100644 --- a/crates/trie/trie/src/forward_cursor.rs +++ b/crates/trie/trie/src/forward_cursor.rs @@ -23,8 +23,8 @@ impl<'a, K, V> ForwardInMemoryCursor<'a, K, V> { impl<'a, K, V> ForwardInMemoryCursor<'a, K, V> where - K: PartialOrd + Copy, - V: Copy, + K: PartialOrd + Clone, + V: Clone, { /// Advances the cursor forward while `comparator` returns `true` or until the collection is /// exhausted. Returns the first entry for which `comparator` returns `false` or `None`. @@ -34,7 +34,7 @@ where self.index += 1; entry = self.entries.get(self.index); } - entry.copied() + entry.cloned() } /// Returns the first entry from the current cursor position that's greater or equal to the diff --git a/crates/trie/trie/src/trie_cursor/database_cursors.rs b/crates/trie/trie/src/trie_cursor/database_cursors.rs index 53a64a0b09..585ef46249 100644 --- a/crates/trie/trie/src/trie_cursor/database_cursors.rs +++ b/crates/trie/trie/src/trie_cursor/database_cursors.rs @@ -30,7 +30,7 @@ impl<'a, TX: DbTx> TrieCursorFactory for &'a TX { /// A cursor over the account trie. #[derive(Debug)] -pub struct DatabaseAccountTrieCursor(C); +pub struct DatabaseAccountTrieCursor(pub(crate) C); impl DatabaseAccountTrieCursor { /// Create a new account trie cursor. @@ -59,6 +59,11 @@ where Ok(self.0.seek(StoredNibbles(key))?.map(|value| (value.0 .0, value.1 .0))) } + /// Move the cursor to the next entry and return it. + fn next(&mut self) -> Result, DatabaseError> { + Ok(self.0.next()?.map(|value| (value.0 .0, value.1 .0))) + } + /// Retrieves the current key in the cursor. fn current(&mut self) -> Result, DatabaseError> { Ok(self.0.current()?.map(|(k, _)| k.0)) @@ -83,7 +88,7 @@ impl DatabaseStorageTrieCursor { impl TrieCursor for DatabaseStorageTrieCursor where - C: DbDupCursorRO + DbCursorRO + Send + Sync, + C: DbCursorRO + DbDupCursorRO + Send + Sync, { /// Seeks an exact match for the given key in the storage trie. fn seek_exact( @@ -108,6 +113,11 @@ where .map(|value| (value.nibbles.0, value.node))) } + /// Move the cursor to the next entry and return it. + fn next(&mut self) -> Result, DatabaseError> { + Ok(self.cursor.next_dup()?.map(|(_, v)| (v.nibbles.0, v.node))) + } + /// Retrieves the current value in the storage trie cursor. fn current(&mut self) -> Result, DatabaseError> { Ok(self.cursor.current()?.map(|(_, v)| v.nibbles.0)) diff --git a/crates/trie/trie/src/trie_cursor/in_memory.rs b/crates/trie/trie/src/trie_cursor/in_memory.rs index 983974da38..c74ee0eaf3 100644 --- a/crates/trie/trie/src/trie_cursor/in_memory.rs +++ b/crates/trie/trie/src/trie_cursor/in_memory.rs @@ -61,7 +61,7 @@ pub struct InMemoryAccountTrieCursor<'a, C> { last_key: Option, } -impl<'a, C> InMemoryAccountTrieCursor<'a, C> { +impl<'a, C: TrieCursor> InMemoryAccountTrieCursor<'a, C> { const fn new(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self { let in_memory_cursor = ForwardInMemoryCursor::new(&trie_updates.account_nodes); Self { @@ -71,25 +71,86 @@ impl<'a, C> InMemoryAccountTrieCursor<'a, C> { last_key: None, } } + + fn seek_inner( + &mut self, + key: Nibbles, + exact: bool, + ) -> Result, DatabaseError> { + let in_memory = self.in_memory_cursor.seek(&key); + if exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key) { + return Ok(in_memory) + } + + // Reposition the cursor to the first greater or equal node that wasn't removed. + let mut db_entry = self.cursor.seek(key.clone())?; + while db_entry.as_ref().map_or(false, |entry| self.removed_nodes.contains(&entry.0)) { + db_entry = self.cursor.next()?; + } + + // Compare two entries and return the lowest. + // If seek is exact, filter the entry for exact key match. + Ok(compare_trie_node_entries(in_memory, db_entry) + .filter(|(nibbles, _)| !exact || nibbles == &key)) + } + + fn next_inner( + &mut self, + last: Nibbles, + ) -> Result, DatabaseError> { + let in_memory = self.in_memory_cursor.first_after(&last); + + // Reposition the cursor to the first greater or equal node that wasn't removed. + let mut db_entry = self.cursor.seek(last.clone())?; + while db_entry + .as_ref() + .map_or(false, |entry| entry.0 < last || self.removed_nodes.contains(&entry.0)) + { + db_entry = self.cursor.next()?; + } + + // Compare two entries and return the lowest. + Ok(compare_trie_node_entries(in_memory, db_entry)) + } } impl<'a, C: TrieCursor> TrieCursor for InMemoryAccountTrieCursor<'a, C> { fn seek_exact( &mut self, - _key: Nibbles, + key: Nibbles, ) -> Result, DatabaseError> { - unimplemented!() + let entry = self.seek_inner(key, true)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone()); + Ok(entry) } fn seek( &mut self, - _key: Nibbles, + key: Nibbles, ) -> Result, DatabaseError> { - unimplemented!() + let entry = self.seek_inner(key, false)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone()); + Ok(entry) + } + + fn next(&mut self) -> Result, DatabaseError> { + let next = match &self.last_key { + Some(last) => { + let entry = self.next_inner(last.clone())?; + self.last_key = entry.as_ref().map(|entry| entry.0.clone()); + entry + } + // no previous entry was found + None => None, + }; + Ok(next) } fn current(&mut self) -> Result, DatabaseError> { - unimplemented!() + match &self.last_key { + Some(key) => Ok(Some(key.clone())), + None => self.cursor.current(), + } } } @@ -128,22 +189,172 @@ impl<'a, C> InMemoryStorageTrieCursor<'a, C> { } } +impl<'a, C: TrieCursor> InMemoryStorageTrieCursor<'a, C> { + fn seek_inner( + &mut self, + key: Nibbles, + exact: bool, + ) -> Result, DatabaseError> { + let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.seek(&key)); + if self.storage_trie_cleared || + (exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key)) + { + return Ok(in_memory) + } + + // Reposition the cursor to the first greater or equal node that wasn't removed. + let mut db_entry = self.cursor.seek(key.clone())?; + while db_entry.as_ref().map_or(false, |entry| { + self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0)) + }) { + db_entry = self.cursor.next()?; + } + + // Compare two entries and return the lowest. + // If seek is exact, filter the entry for exact key match. + Ok(compare_trie_node_entries(in_memory, db_entry) + .filter(|(nibbles, _)| !exact || nibbles == &key)) + } + + fn next_inner( + &mut self, + last: Nibbles, + ) -> Result, DatabaseError> { + let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.first_after(&last)); + + // Reposition the cursor to the first greater or equal node that wasn't removed. + let mut db_entry = self.cursor.seek(last.clone())?; + while db_entry.as_ref().map_or(false, |entry| { + entry.0 < last || self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0)) + }) { + db_entry = self.cursor.next()?; + } + + // Compare two entries and return the lowest. + Ok(compare_trie_node_entries(in_memory, db_entry)) + } +} + impl<'a, C: TrieCursor> TrieCursor for InMemoryStorageTrieCursor<'a, C> { fn seek_exact( &mut self, - _key: Nibbles, + key: Nibbles, ) -> Result, DatabaseError> { - unimplemented!() + let entry = self.seek_inner(key, true)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone()); + Ok(entry) } fn seek( &mut self, - _key: Nibbles, + key: Nibbles, ) -> Result, DatabaseError> { - unimplemented!() + let entry = self.seek_inner(key, false)?; + self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone()); + Ok(entry) + } + + fn next(&mut self) -> Result, DatabaseError> { + let next = match &self.last_key { + Some(last) => { + let entry = self.next_inner(last.clone())?; + self.last_key = entry.as_ref().map(|entry| entry.0.clone()); + entry + } + // no previous entry was found + None => None, + }; + Ok(next) } fn current(&mut self) -> Result, DatabaseError> { - unimplemented!() + match &self.last_key { + Some(key) => Ok(Some(key.clone())), + None => self.cursor.current(), + } + } +} + +/// Return the node with the lowest nibbles. +/// +/// Given the next in-memory and database entries, return the smallest of the two. +/// If the node keys are the same, the in-memory entry is given precedence. +fn compare_trie_node_entries( + mut in_memory_item: Option<(Nibbles, BranchNodeCompact)>, + mut db_item: Option<(Nibbles, BranchNodeCompact)>, +) -> Option<(Nibbles, BranchNodeCompact)> { + if let Some((in_memory_entry, db_entry)) = in_memory_item.as_ref().zip(db_item.as_ref()) { + // If both are not empty, return the smallest of the two + // In-memory is given precedence if keys are equal + if in_memory_entry.0 <= db_entry.0 { + in_memory_item.take() + } else { + db_item.take() + } + } else { + // Return either non-empty entry + db_item.or(in_memory_item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + prefix_set::{PrefixSetMut, TriePrefixSets}, + test_utils::state_root_prehashed, + StateRoot, + }; + use proptest::prelude::*; + use reth_db::{cursor::DbCursorRW, tables, transaction::DbTxMut}; + use reth_primitives::{Account, U256}; + use reth_provider::test_utils::create_test_provider_factory; + use std::collections::BTreeMap; + + proptest! { + #![proptest_config(ProptestConfig { + cases: 128, ..ProptestConfig::default() + })] + + #[test] + fn fuzz_in_memory_nodes(mut init_state: BTreeMap, mut updated_state: BTreeMap) { + let factory = create_test_provider_factory(); + let provider = factory.provider_rw().unwrap(); + let mut hashed_account_cursor = provider.tx_ref().cursor_write::().unwrap(); + + // Insert init state into database + for (hashed_address, balance) in init_state.clone() { + hashed_account_cursor.upsert(hashed_address, Account { balance, ..Default::default() }).unwrap(); + } + + // Compute initial root and updates + let (_, trie_updates) = StateRoot::from_tx(provider.tx_ref()) + .root_with_updates() + .unwrap(); + + // Insert state updates into database + let mut changes = PrefixSetMut::default(); + for (hashed_address, balance) in updated_state.clone() { + hashed_account_cursor.upsert(hashed_address, Account { balance, ..Default::default() }).unwrap(); + changes.insert(Nibbles::unpack(hashed_address)); + } + + // Compute root with in-memory trie nodes overlay + let (state_root, _) = StateRoot::from_tx(provider.tx_ref()) + .with_prefix_sets(TriePrefixSets { account_prefix_set: changes.freeze(), ..Default::default() }) + .with_trie_cursor_factory(InMemoryTrieCursorFactory::new(provider.tx_ref(), &trie_updates.into_sorted())) + .root_with_updates() + .unwrap(); + + // Verify the result + let mut state = BTreeMap::default(); + state.append(&mut init_state); + state.append(&mut updated_state); + let expected_root = state_root_prehashed( + state.iter().map(|(&key, &balance)| (key, (Account { balance, ..Default::default() }, std::iter::empty()))) + ); + assert_eq!(expected_root, state_root); + + } } } diff --git a/crates/trie/trie/src/trie_cursor/mod.rs b/crates/trie/trie/src/trie_cursor/mod.rs index e5160a5526..d297fa2bf1 100644 --- a/crates/trie/trie/src/trie_cursor/mod.rs +++ b/crates/trie/trie/src/trie_cursor/mod.rs @@ -50,6 +50,9 @@ pub trait TrieCursor: Send + Sync { fn seek(&mut self, key: Nibbles) -> Result, DatabaseError>; + /// Move the cursor to the next key. + fn next(&mut self) -> Result, DatabaseError>; + /// Get the current entry. fn current(&mut self) -> Result, DatabaseError>; } diff --git a/crates/trie/trie/src/trie_cursor/noop.rs b/crates/trie/trie/src/trie_cursor/noop.rs index e49c90613d..8db0cbb9d3 100644 --- a/crates/trie/trie/src/trie_cursor/noop.rs +++ b/crates/trie/trie/src/trie_cursor/noop.rs @@ -12,12 +12,12 @@ impl TrieCursorFactory for NoopTrieCursorFactory { type AccountTrieCursor = NoopAccountTrieCursor; type StorageTrieCursor = NoopStorageTrieCursor; - /// Generates a Noop account trie cursor. + /// Generates a noop account trie cursor. fn account_trie_cursor(&self) -> Result { Ok(NoopAccountTrieCursor::default()) } - /// Generates a Noop storage trie cursor. + /// Generates a noop storage trie cursor. fn storage_trie_cursor( &self, _hashed_address: B256, @@ -32,7 +32,6 @@ impl TrieCursorFactory for NoopTrieCursorFactory { pub struct NoopAccountTrieCursor; impl TrieCursor for NoopAccountTrieCursor { - /// Seeks an exact match within the account trie. fn seek_exact( &mut self, _key: Nibbles, @@ -40,7 +39,6 @@ impl TrieCursor for NoopAccountTrieCursor { Ok(None) } - /// Seeks within the account trie. fn seek( &mut self, _key: Nibbles, @@ -48,7 +46,10 @@ impl TrieCursor for NoopAccountTrieCursor { Ok(None) } - /// Retrieves the current cursor position within the account trie. + fn next(&mut self) -> Result, DatabaseError> { + Ok(None) + } + fn current(&mut self) -> Result, DatabaseError> { Ok(None) } @@ -60,7 +61,6 @@ impl TrieCursor for NoopAccountTrieCursor { pub struct NoopStorageTrieCursor; impl TrieCursor for NoopStorageTrieCursor { - /// Seeks an exact match in storage tries. fn seek_exact( &mut self, _key: Nibbles, @@ -68,7 +68,6 @@ impl TrieCursor for NoopStorageTrieCursor { Ok(None) } - /// Seeks a key in storage tries. fn seek( &mut self, _key: Nibbles, @@ -76,7 +75,10 @@ impl TrieCursor for NoopStorageTrieCursor { Ok(None) } - /// Retrieves the current cursor position within storage tries. + fn next(&mut self) -> Result, DatabaseError> { + Ok(None) + } + fn current(&mut self) -> Result, DatabaseError> { Ok(None) }