chore(trie): Use Vec<Option<...>> in InMemoryTrieCursor (#18479)

This commit is contained in:
Brian Picciano
2025-09-19 15:24:46 +02:00
committed by GitHub
parent d6160de610
commit ebe1a8b014
5 changed files with 509 additions and 324 deletions

View File

@@ -107,15 +107,8 @@ impl TrieUpdates {
}
/// Converts trie updates into [`TrieUpdatesSorted`].
pub fn into_sorted(self) -> TrieUpdatesSorted {
let mut account_nodes = Vec::from_iter(self.account_nodes);
account_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0));
let storage_tries = self
.storage_tries
.into_iter()
.map(|(hashed_address, updates)| (hashed_address, updates.into_sorted()))
.collect();
TrieUpdatesSorted { removed_nodes: self.removed_nodes, account_nodes, storage_tries }
pub fn into_sorted(mut self) -> TrieUpdatesSorted {
self.drain_into_sorted()
}
/// Converts trie updates into [`TrieUpdatesSorted`], but keeping the maps allocated by
@@ -126,7 +119,17 @@ impl TrieUpdates {
/// This allows us to reuse the allocated space. This allocates new space for the sorted
/// updates, like `into_sorted`.
pub fn drain_into_sorted(&mut self) -> TrieUpdatesSorted {
let mut account_nodes = self.account_nodes.drain().collect::<Vec<_>>();
let mut account_nodes = self
.account_nodes
.drain()
.map(|(path, node)| {
// Updated nodes take precedence over removed nodes.
self.removed_nodes.remove(&path);
(path, Some(node))
})
.collect::<Vec<_>>();
account_nodes.extend(self.removed_nodes.drain().map(|path| (path, None)));
account_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0));
let storage_tries = self
@@ -134,12 +137,7 @@ impl TrieUpdates {
.drain()
.map(|(hashed_address, updates)| (hashed_address, updates.into_sorted()))
.collect();
TrieUpdatesSorted {
removed_nodes: self.removed_nodes.clone(),
account_nodes,
storage_tries,
}
TrieUpdatesSorted { account_nodes, storage_tries }
}
/// Converts trie updates into [`TrieUpdatesSortedRef`].
@@ -266,14 +264,21 @@ impl StorageTrieUpdates {
}
/// Convert storage trie updates into [`StorageTrieUpdatesSorted`].
pub fn into_sorted(self) -> StorageTrieUpdatesSorted {
let mut storage_nodes = Vec::from_iter(self.storage_nodes);
pub fn into_sorted(mut self) -> StorageTrieUpdatesSorted {
let mut storage_nodes = self
.storage_nodes
.into_iter()
.map(|(path, node)| {
// Updated nodes take precedence over removed nodes.
self.removed_nodes.remove(&path);
(path, Some(node))
})
.collect::<Vec<_>>();
storage_nodes.extend(self.removed_nodes.into_iter().map(|path| (path, None)));
storage_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0));
StorageTrieUpdatesSorted {
is_deleted: self.is_deleted,
removed_nodes: self.removed_nodes,
storage_nodes,
}
StorageTrieUpdatesSorted { is_deleted: self.is_deleted, storage_nodes }
}
/// Convert storage trie updates into [`StorageTrieUpdatesSortedRef`].
@@ -425,25 +430,19 @@ pub struct TrieUpdatesSortedRef<'a> {
#[derive(PartialEq, Eq, Clone, Default, Debug)]
#[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))]
pub struct TrieUpdatesSorted {
/// Sorted collection of updated state nodes with corresponding paths.
pub account_nodes: Vec<(Nibbles, BranchNodeCompact)>,
/// The set of removed state node keys.
pub removed_nodes: HashSet<Nibbles>,
/// Sorted collection of updated state nodes with corresponding paths. None indicates that a
/// node was removed.
pub account_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
/// Storage tries stored by hashed address of the account the trie belongs to.
pub storage_tries: B256Map<StorageTrieUpdatesSorted>,
}
impl TrieUpdatesSorted {
/// Returns reference to updated account nodes.
pub fn account_nodes_ref(&self) -> &[(Nibbles, BranchNodeCompact)] {
pub fn account_nodes_ref(&self) -> &[(Nibbles, Option<BranchNodeCompact>)] {
&self.account_nodes
}
/// Returns reference to removed account nodes.
pub const fn removed_nodes_ref(&self) -> &HashSet<Nibbles> {
&self.removed_nodes
}
/// Returns reference to updated storage tries.
pub const fn storage_tries_ref(&self) -> &B256Map<StorageTrieUpdatesSorted> {
&self.storage_tries
@@ -468,10 +467,9 @@ pub struct StorageTrieUpdatesSortedRef<'a> {
pub struct StorageTrieUpdatesSorted {
/// Flag indicating whether the trie has been deleted/wiped.
pub is_deleted: bool,
/// Sorted collection of updated storage nodes with corresponding paths.
pub storage_nodes: Vec<(Nibbles, BranchNodeCompact)>,
/// The set of removed storage node keys.
pub removed_nodes: HashSet<Nibbles>,
/// Sorted collection of updated storage nodes with corresponding paths. None indicates a node
/// is removed.
pub storage_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
}
impl StorageTrieUpdatesSorted {
@@ -481,14 +479,9 @@ impl StorageTrieUpdatesSorted {
}
/// Returns reference to updated storage nodes.
pub fn storage_nodes_ref(&self) -> &[(Nibbles, BranchNodeCompact)] {
pub fn storage_nodes_ref(&self) -> &[(Nibbles, Option<BranchNodeCompact>)] {
&self.storage_nodes
}
/// Returns reference to removed storage nodes.
pub const fn removed_nodes_ref(&self) -> &HashSet<Nibbles> {
&self.removed_nodes
}
}
/// Excludes empty nibbles from the given iterator.

View File

@@ -428,6 +428,7 @@ fn account_and_storage_trie() {
let (nibbles1a, node1a) = account_updates.first().unwrap();
assert_eq!(nibbles1a.to_vec(), vec![0xB]);
let node1a = node1a.as_ref().unwrap();
assert_eq!(node1a.state_mask, TrieMask::new(0b1011));
assert_eq!(node1a.tree_mask, TrieMask::new(0b0001));
assert_eq!(node1a.hash_mask, TrieMask::new(0b1001));
@@ -436,6 +437,7 @@ fn account_and_storage_trie() {
let (nibbles2a, node2a) = account_updates.last().unwrap();
assert_eq!(nibbles2a.to_vec(), vec![0xB, 0x0]);
let node2a = node2a.as_ref().unwrap();
assert_eq!(node2a.state_mask, TrieMask::new(0b10001));
assert_eq!(node2a.tree_mask, TrieMask::new(0b00000));
assert_eq!(node2a.hash_mask, TrieMask::new(0b10000));
@@ -471,6 +473,7 @@ fn account_and_storage_trie() {
let (nibbles1b, node1b) = account_updates.first().unwrap();
assert_eq!(nibbles1b.to_vec(), vec![0xB]);
let node1b = node1b.as_ref().unwrap();
assert_eq!(node1b.state_mask, TrieMask::new(0b1011));
assert_eq!(node1b.tree_mask, TrieMask::new(0b0001));
assert_eq!(node1b.hash_mask, TrieMask::new(0b1011));
@@ -481,6 +484,7 @@ fn account_and_storage_trie() {
let (nibbles2b, node2b) = account_updates.last().unwrap();
assert_eq!(nibbles2b.to_vec(), vec![0xB, 0x0]);
let node2b = node2b.as_ref().unwrap();
assert_eq!(node2a, node2b);
tx.commit().unwrap();
@@ -520,8 +524,9 @@ fn account_and_storage_trie() {
assert_eq!(trie_updates.account_nodes_ref().len(), 1);
let (nibbles1c, node1c) = trie_updates.account_nodes_ref().iter().next().unwrap();
assert_eq!(nibbles1c.to_vec(), vec![0xB]);
let entry = trie_updates.account_nodes_ref().iter().next().unwrap();
assert_eq!(entry.0.to_vec(), vec![0xB]);
let node1c = entry.1;
assert_eq!(node1c.state_mask, TrieMask::new(0b1011));
assert_eq!(node1c.tree_mask, TrieMask::new(0b0000));
@@ -578,8 +583,9 @@ fn account_and_storage_trie() {
assert_eq!(trie_updates.account_nodes_ref().len(), 1);
let (nibbles1d, node1d) = trie_updates.account_nodes_ref().iter().next().unwrap();
assert_eq!(nibbles1d.to_vec(), vec![0xB]);
let entry = trie_updates.account_nodes_ref().iter().next().unwrap();
assert_eq!(entry.0.to_vec(), vec![0xB]);
let node1d = entry.1;
assert_eq!(node1d.state_mask, TrieMask::new(0b1011));
assert_eq!(node1d.tree_mask, TrieMask::new(0b0000));

View File

@@ -7,7 +7,7 @@ use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner};
use reth_trie::{
hashed_cursor::{noop::NoopHashedStorageCursor, HashedPostStateStorageCursor},
node_iter::{TrieElement, TrieNodeIter},
trie_cursor::{noop::NoopStorageTrieCursor, InMemoryStorageTrieCursor},
trie_cursor::{noop::NoopStorageTrieCursor, InMemoryTrieCursor},
updates::StorageTrieUpdates,
walker::TrieWalker,
HashedStorage,
@@ -134,10 +134,9 @@ fn calculate_root_from_leaves_repeated(c: &mut Criterion) {
};
let walker = TrieWalker::<_>::storage_trie(
InMemoryStorageTrieCursor::new(
B256::ZERO,
NoopStorageTrieCursor::default(),
Some(&trie_updates_sorted),
InMemoryTrieCursor::new(
Some(NoopStorageTrieCursor::default()),
&trie_updates_sorted.storage_nodes,
),
prefix_set,
);

View File

@@ -1,9 +1,6 @@
use super::{TrieCursor, TrieCursorFactory};
use crate::{
forward_cursor::ForwardInMemoryCursor,
updates::{StorageTrieUpdatesSorted, TrieUpdatesSorted},
};
use alloy_primitives::{map::HashSet, B256};
use crate::{forward_cursor::ForwardInMemoryCursor, updates::TrieUpdatesSorted};
use alloy_primitives::B256;
use reth_storage_errors::db::DatabaseError;
use reth_trie_common::{BranchNodeCompact, Nibbles};
@@ -24,283 +21,472 @@ impl<'a, CF> InMemoryTrieCursorFactory<'a, CF> {
}
impl<'a, CF: TrieCursorFactory> TrieCursorFactory for InMemoryTrieCursorFactory<'a, CF> {
type AccountTrieCursor = InMemoryAccountTrieCursor<'a, CF::AccountTrieCursor>;
type StorageTrieCursor = InMemoryStorageTrieCursor<'a, CF::StorageTrieCursor>;
type AccountTrieCursor = InMemoryTrieCursor<'a, CF::AccountTrieCursor>;
type StorageTrieCursor = InMemoryTrieCursor<'a, CF::StorageTrieCursor>;
fn account_trie_cursor(&self) -> Result<Self::AccountTrieCursor, DatabaseError> {
let cursor = self.cursor_factory.account_trie_cursor()?;
Ok(InMemoryAccountTrieCursor::new(cursor, self.trie_updates))
Ok(InMemoryTrieCursor::new(Some(cursor), self.trie_updates.account_nodes_ref()))
}
fn storage_trie_cursor(
&self,
hashed_address: B256,
) -> Result<Self::StorageTrieCursor, DatabaseError> {
let cursor = self.cursor_factory.storage_trie_cursor(hashed_address)?;
Ok(InMemoryStorageTrieCursor::new(
hashed_address,
cursor,
self.trie_updates.storage_tries.get(&hashed_address),
))
}
}
// if the storage trie has no updates then we use this as the in-memory overlay.
static EMPTY_UPDATES: Vec<(Nibbles, Option<BranchNodeCompact>)> = Vec::new();
/// The cursor to iterate over account trie updates and corresponding database entries.
/// It will always give precedence to the data from the trie updates.
#[derive(Debug)]
pub struct InMemoryAccountTrieCursor<'a, C> {
/// The underlying cursor.
cursor: C,
/// Forward-only in-memory cursor over storage trie nodes.
in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, BranchNodeCompact>,
/// Collection of removed trie nodes.
removed_nodes: &'a HashSet<Nibbles>,
/// Last key returned by the cursor.
last_key: Option<Nibbles>,
}
let storage_trie_updates = self.trie_updates.storage_tries.get(&hashed_address);
let (storage_nodes, cleared) = storage_trie_updates
.map(|u| (u.storage_nodes_ref(), u.is_deleted()))
.unwrap_or((&EMPTY_UPDATES, false));
impl<'a, C: TrieCursor> InMemoryAccountTrieCursor<'a, C> {
/// Create new account trie cursor from underlying cursor and reference to
/// [`TrieUpdatesSorted`].
pub fn new(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
let in_memory_cursor = ForwardInMemoryCursor::new(&trie_updates.account_nodes);
Self {
cursor,
in_memory_cursor,
removed_nodes: &trie_updates.removed_nodes,
last_key: None,
}
}
fn seek_inner(
&mut self,
key: Nibbles,
exact: bool,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.seek(&key);
if in_memory.as_ref().is_some_and(|entry| entry.0 == key) {
return Ok(in_memory)
}
// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(key)?;
while db_entry.as_ref().is_some_and(|entry| self.removed_nodes.contains(&entry.0)) {
db_entry = self.cursor.next()?;
}
// Compare two entries and return the lowest.
// If seek is exact, filter the entry for exact key match.
Ok(compare_trie_node_entries(in_memory, db_entry)
.filter(|(nibbles, _)| !exact || nibbles == &key))
}
fn next_inner(
&mut self,
last: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.first_after(&last);
// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(last)?;
while db_entry
.as_ref()
.is_some_and(|entry| entry.0 < last || self.removed_nodes.contains(&entry.0))
{
db_entry = self.cursor.next()?;
}
// Compare two entries and return the lowest.
Ok(compare_trie_node_entries(in_memory, db_entry))
}
}
impl<C: TrieCursor> TrieCursor for InMemoryAccountTrieCursor<'_, C> {
fn seek_exact(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, true)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn seek(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, false)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let next = match &self.last_key {
Some(last) => {
let entry = self.next_inner(*last)?;
self.last_key = entry.as_ref().map(|entry| entry.0);
entry
}
// no previous entry was found
None => None,
};
Ok(next)
}
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
match &self.last_key {
Some(key) => Ok(Some(*key)),
None => self.cursor.current(),
}
}
}
/// The cursor to iterate over storage trie updates and corresponding database entries.
/// It will always give precedence to the data from the trie updates.
#[derive(Debug)]
#[expect(dead_code)]
pub struct InMemoryStorageTrieCursor<'a, C> {
/// The hashed address of the account that trie belongs to.
hashed_address: B256,
/// The underlying cursor.
cursor: C,
/// Forward-only in-memory cursor over storage trie nodes.
in_memory_cursor: Option<ForwardInMemoryCursor<'a, Nibbles, BranchNodeCompact>>,
/// Reference to the set of removed storage node keys.
removed_nodes: Option<&'a HashSet<Nibbles>>,
/// The flag indicating whether the storage trie was cleared.
storage_trie_cleared: bool,
/// Last key returned by the cursor.
last_key: Option<Nibbles>,
}
impl<'a, C> InMemoryStorageTrieCursor<'a, C> {
/// Create new storage trie cursor from underlying cursor and reference to
/// [`StorageTrieUpdatesSorted`].
pub fn new(
hashed_address: B256,
cursor: C,
updates: Option<&'a StorageTrieUpdatesSorted>,
) -> Self {
let in_memory_cursor = updates.map(|u| ForwardInMemoryCursor::new(&u.storage_nodes));
let removed_nodes = updates.map(|u| &u.removed_nodes);
let storage_trie_cleared = updates.is_some_and(|u| u.is_deleted);
Self {
hashed_address,
cursor,
in_memory_cursor,
removed_nodes,
storage_trie_cleared,
last_key: None,
}
}
}
impl<C: TrieCursor> InMemoryStorageTrieCursor<'_, C> {
fn seek_inner(
&mut self,
key: Nibbles,
exact: bool,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.seek(&key));
if self.storage_trie_cleared || in_memory.as_ref().is_some_and(|entry| entry.0 == key) {
return Ok(in_memory.filter(|(nibbles, _)| !exact || nibbles == &key))
}
// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(key)?;
while db_entry
.as_ref()
.is_some_and(|entry| self.removed_nodes.as_ref().is_some_and(|r| r.contains(&entry.0)))
{
db_entry = self.cursor.next()?;
}
// Compare two entries and return the lowest.
// If seek is exact, filter the entry for exact key match.
Ok(compare_trie_node_entries(in_memory, db_entry)
.filter(|(nibbles, _)| !exact || nibbles == &key))
}
fn next_inner(
&mut self,
last: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.first_after(&last));
if self.storage_trie_cleared {
return Ok(in_memory)
}
// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(last)?;
while db_entry.as_ref().is_some_and(|entry| {
entry.0 < last || self.removed_nodes.as_ref().is_some_and(|r| r.contains(&entry.0))
}) {
db_entry = self.cursor.next()?;
}
// Compare two entries and return the lowest.
Ok(compare_trie_node_entries(in_memory, db_entry))
}
}
impl<C: TrieCursor> TrieCursor for InMemoryStorageTrieCursor<'_, C> {
fn seek_exact(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, true)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn seek(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, false)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let next = match &self.last_key {
Some(last) => {
let entry = self.next_inner(*last)?;
self.last_key = entry.as_ref().map(|entry| entry.0);
entry
}
// no previous entry was found
None => None,
};
Ok(next)
}
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
match &self.last_key {
Some(key) => Ok(Some(*key)),
None => self.cursor.current(),
}
}
}
/// Return the node with the lowest nibbles.
///
/// Given the next in-memory and database entries, return the smallest of the two.
/// If the node keys are the same, the in-memory entry is given precedence.
fn compare_trie_node_entries(
mut in_memory_item: Option<(Nibbles, BranchNodeCompact)>,
mut db_item: Option<(Nibbles, BranchNodeCompact)>,
) -> Option<(Nibbles, BranchNodeCompact)> {
if let Some((in_memory_entry, db_entry)) = in_memory_item.as_ref().zip(db_item.as_ref()) {
// If both are not empty, return the smallest of the two
// In-memory is given precedence if keys are equal
if in_memory_entry.0 <= db_entry.0 {
in_memory_item.take()
let cursor = if cleared {
None
} else {
db_item.take()
}
} else {
// Return either non-empty entry
db_item.or(in_memory_item)
Some(self.cursor_factory.storage_trie_cursor(hashed_address)?)
};
Ok(InMemoryTrieCursor::new(cursor, storage_nodes))
}
}
/// A cursor to iterate over trie updates and corresponding database entries.
/// It will always give precedence to the data from the trie updates.
#[derive(Debug)]
pub struct InMemoryTrieCursor<'a, C> {
/// The underlying cursor. If None then it is assumed there is no DB data.
cursor: Option<C>,
/// Forward-only in-memory cursor over storage trie nodes.
in_memory_cursor: ForwardInMemoryCursor<'a, Nibbles, Option<BranchNodeCompact>>,
/// Last key returned by the cursor.
last_key: Option<Nibbles>,
}
impl<'a, C: TrieCursor> InMemoryTrieCursor<'a, C> {
/// Create new trie cursor which combines a DB cursor (None to assume empty DB) and a set of
/// in-memory trie nodes.
pub fn new(
cursor: Option<C>,
trie_updates: &'a [(Nibbles, Option<BranchNodeCompact>)],
) -> Self {
let in_memory_cursor = ForwardInMemoryCursor::new(trie_updates);
Self { cursor, in_memory_cursor, last_key: None }
}
fn seek_inner(
&mut self,
key: Nibbles,
exact: bool,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let mut mem_entry = self.in_memory_cursor.seek(&key);
let mut db_entry = self.cursor.as_mut().map(|c| c.seek(key)).transpose()?.flatten();
// exact matching is easy, if overlay has a value then return that (updated or removed), or
// if db has a value then return that.
if exact {
return Ok(match (mem_entry, db_entry) {
(Some((mem_key, entry_inner)), _) if mem_key == key => {
entry_inner.map(|node| (key, node))
}
(_, Some((db_key, node))) if db_key == key => Some((key, node)),
_ => None,
})
}
loop {
match (mem_entry, &db_entry) {
(Some((mem_key, None)), _)
if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key < db_key) =>
{
// If overlay has a removed node but DB cursor is exhausted or ahead of the
// in-memory cursor then move ahead in-memory, as there might be further
// non-removed overlay nodes.
mem_entry = self.in_memory_cursor.first_after(&mem_key);
}
(Some((mem_key, None)), Some((db_key, _))) if &mem_key == db_key => {
// If overlay has a removed node which is returned from DB then move both
// cursors ahead to the next key.
mem_entry = self.in_memory_cursor.first_after(&mem_key);
db_entry = self.cursor.as_mut().map(|c| c.next()).transpose()?.flatten();
}
(Some((mem_key, Some(node))), _)
if db_entry.as_ref().is_none_or(|(db_key, _)| &mem_key <= db_key) =>
{
// If overlay returns a node prior to the DB's node, or the DB is exhausted,
// then we return the overlay's node.
return Ok(Some((mem_key, node)))
}
// All other cases:
// - mem_key > db_key
// - overlay is exhausted
// Return the db_entry. If DB is also exhausted then this returns None.
_ => return Ok(db_entry),
}
}
}
fn next_inner(
&mut self,
last: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let Some(key) = last.increment() else { return Ok(None) };
self.seek_inner(key, false)
}
}
impl<C: TrieCursor> TrieCursor for InMemoryTrieCursor<'_, C> {
fn seek_exact(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, true)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn seek(
&mut self,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let entry = self.seek_inner(key, false)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| *nibbles);
Ok(entry)
}
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let next = match &self.last_key {
Some(last) => {
let entry = self.next_inner(*last)?;
self.last_key = entry.as_ref().map(|entry| entry.0);
entry
}
// no previous entry was found
None => None,
};
Ok(next)
}
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
match &self.last_key {
Some(key) => Ok(Some(*key)),
None => Ok(self.cursor.as_mut().map(|c| c.current()).transpose()?.flatten()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::trie_cursor::mock::MockTrieCursor;
use parking_lot::Mutex;
use std::{collections::BTreeMap, sync::Arc};
#[derive(Debug)]
struct InMemoryTrieCursorTestCase {
db_nodes: Vec<(Nibbles, BranchNodeCompact)>,
in_memory_nodes: Vec<(Nibbles, Option<BranchNodeCompact>)>,
expected_results: Vec<(Nibbles, BranchNodeCompact)>,
}
fn execute_test(test_case: InMemoryTrieCursorTestCase) {
let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> =
test_case.db_nodes.into_iter().collect();
let db_nodes_arc = Arc::new(db_nodes_map);
let visited_keys = Arc::new(Mutex::new(Vec::new()));
let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &test_case.in_memory_nodes);
let mut results = Vec::new();
if let Some(first_expected) = test_case.expected_results.first() {
if let Ok(Some(entry)) = cursor.seek(first_expected.0) {
results.push(entry);
}
}
while let Ok(Some(entry)) = cursor.next() {
results.push(entry);
}
assert_eq!(
results, test_case.expected_results,
"Results mismatch.\nGot: {:?}\nExpected: {:?}",
results, test_case.expected_results
);
}
#[test]
fn test_empty_db_and_memory() {
let test_case = InMemoryTrieCursorTestCase {
db_nodes: vec![],
in_memory_nodes: vec![],
expected_results: vec![],
};
execute_test(test_case);
}
#[test]
fn test_only_db_nodes() {
let db_nodes = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
];
let test_case = InMemoryTrieCursorTestCase {
db_nodes: db_nodes.clone(),
in_memory_nodes: vec![],
expected_results: db_nodes,
};
execute_test(test_case);
}
#[test]
fn test_only_in_memory_nodes() {
let in_memory_nodes = vec![
(
Nibbles::from_nibbles([0x1]),
Some(BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
),
(
Nibbles::from_nibbles([0x2]),
Some(BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
),
(
Nibbles::from_nibbles([0x3]),
Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
),
];
let expected_results: Vec<(Nibbles, BranchNodeCompact)> = in_memory_nodes
.iter()
.filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone())))
.collect();
let test_case =
InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_in_memory_overwrites_db() {
let db_nodes = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
];
let in_memory_nodes = vec![
(
Nibbles::from_nibbles([0x1]),
Some(BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
),
(
Nibbles::from_nibbles([0x3]),
Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
),
];
let expected_results = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b1111, 0b1111, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
];
let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_in_memory_deletes_db_nodes() {
let db_nodes = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0011, 0b0010, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
];
let in_memory_nodes = vec![(Nibbles::from_nibbles([0x2]), None)];
let expected_results = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0011, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
];
let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_complex_interleaving() {
let db_nodes = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
(Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
(Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(0b0111, 0b0111, 0, vec![], None)),
];
let in_memory_nodes = vec![
(
Nibbles::from_nibbles([0x2]),
Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
),
(Nibbles::from_nibbles([0x3]), None),
(
Nibbles::from_nibbles([0x4]),
Some(BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
),
(
Nibbles::from_nibbles([0x6]),
Some(BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
),
(Nibbles::from_nibbles([0x7]), None),
(
Nibbles::from_nibbles([0x8]),
Some(BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
),
];
let expected_results = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
(Nibbles::from_nibbles([0x4]), BranchNodeCompact::new(0b0100, 0b0100, 0, vec![], None)),
(Nibbles::from_nibbles([0x5]), BranchNodeCompact::new(0b0101, 0b0101, 0, vec![], None)),
(Nibbles::from_nibbles([0x6]), BranchNodeCompact::new(0b0110, 0b0110, 0, vec![], None)),
(Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(0b1000, 0b1000, 0, vec![], None)),
];
let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_seek_exact() {
let db_nodes = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
(Nibbles::from_nibbles([0x3]), BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
];
let in_memory_nodes = vec![(
Nibbles::from_nibbles([0x2]),
Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
)];
let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
let db_nodes_arc = Arc::new(db_nodes_map);
let visited_keys = Arc::new(Mutex::new(Vec::new()));
let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &in_memory_nodes);
let result = cursor.seek_exact(Nibbles::from_nibbles([0x2])).unwrap();
assert_eq!(
result,
Some((
Nibbles::from_nibbles([0x2]),
BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)
))
);
let result = cursor.seek_exact(Nibbles::from_nibbles([0x3])).unwrap();
assert_eq!(
result,
Some((
Nibbles::from_nibbles([0x3]),
BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)
))
);
let result = cursor.seek_exact(Nibbles::from_nibbles([0x4])).unwrap();
assert_eq!(result, None);
}
#[test]
fn test_multiple_consecutive_deletes() {
let db_nodes: Vec<(Nibbles, BranchNodeCompact)> = (1..=10)
.map(|i| {
(
Nibbles::from_nibbles([i]),
BranchNodeCompact::new(i as u16, i as u16, 0, vec![], None),
)
})
.collect();
let in_memory_nodes = vec![
(Nibbles::from_nibbles([0x3]), None),
(Nibbles::from_nibbles([0x4]), None),
(Nibbles::from_nibbles([0x5]), None),
(Nibbles::from_nibbles([0x6]), None),
];
let expected_results = vec![
(Nibbles::from_nibbles([0x1]), BranchNodeCompact::new(1, 1, 0, vec![], None)),
(Nibbles::from_nibbles([0x2]), BranchNodeCompact::new(2, 2, 0, vec![], None)),
(Nibbles::from_nibbles([0x7]), BranchNodeCompact::new(7, 7, 0, vec![], None)),
(Nibbles::from_nibbles([0x8]), BranchNodeCompact::new(8, 8, 0, vec![], None)),
(Nibbles::from_nibbles([0x9]), BranchNodeCompact::new(9, 9, 0, vec![], None)),
(Nibbles::from_nibbles([0xa]), BranchNodeCompact::new(10, 10, 0, vec![], None)),
];
let test_case = InMemoryTrieCursorTestCase { db_nodes, in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_empty_db_with_in_memory_deletes() {
let in_memory_nodes = vec![
(Nibbles::from_nibbles([0x1]), None),
(
Nibbles::from_nibbles([0x2]),
Some(BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None)),
),
(Nibbles::from_nibbles([0x3]), None),
];
let expected_results = vec![(
Nibbles::from_nibbles([0x2]),
BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
)];
let test_case =
InMemoryTrieCursorTestCase { db_nodes: vec![], in_memory_nodes, expected_results };
execute_test(test_case);
}
#[test]
fn test_current_key_tracking() {
let db_nodes = vec![(
Nibbles::from_nibbles([0x2]),
BranchNodeCompact::new(0b0010, 0b0010, 0, vec![], None),
)];
let in_memory_nodes = vec![
(
Nibbles::from_nibbles([0x1]),
Some(BranchNodeCompact::new(0b0001, 0b0001, 0, vec![], None)),
),
(
Nibbles::from_nibbles([0x3]),
Some(BranchNodeCompact::new(0b0011, 0b0011, 0, vec![], None)),
),
];
let db_nodes_map: BTreeMap<Nibbles, BranchNodeCompact> = db_nodes.into_iter().collect();
let db_nodes_arc = Arc::new(db_nodes_map);
let visited_keys = Arc::new(Mutex::new(Vec::new()));
let mock_cursor = MockTrieCursor::new(db_nodes_arc, visited_keys);
let mut cursor = InMemoryTrieCursor::new(Some(mock_cursor), &in_memory_nodes);
assert_eq!(cursor.current().unwrap(), None);
cursor.seek(Nibbles::from_nibbles([0x1])).unwrap();
assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x1])));
cursor.next().unwrap();
assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x2])));
cursor.next().unwrap();
assert_eq!(cursor.current().unwrap(), Some(Nibbles::from_nibbles([0x3])));
}
}

View File

@@ -93,7 +93,8 @@ pub struct MockTrieCursor {
}
impl MockTrieCursor {
fn new(
/// Creates a new mock trie cursor with the given trie nodes and key tracking.
pub fn new(
trie_nodes: Arc<BTreeMap<Nibbles, BranchNodeCompact>>,
visited_keys: Arc<Mutex<Vec<KeyVisit<Nibbles>>>>,
) -> Self {