diff --git a/crates/engine/tree/src/tree/trie_updates.rs b/crates/engine/tree/src/tree/trie_updates.rs index ba8f7fc16a..11bb360c15 100644 --- a/crates/engine/tree/src/tree/trie_updates.rs +++ b/crates/engine/tree/src/tree/trie_updates.rs @@ -17,16 +17,14 @@ struct EntryDiff { #[derive(Debug, Default)] struct TrieUpdatesDiff { + /// Account node differences. `None` in value means the node was removed. account_nodes: HashMap>>, - removed_nodes: HashMap>, storage_tries: HashMap, } impl TrieUpdatesDiff { fn has_differences(&self) -> bool { - !self.account_nodes.is_empty() || - !self.removed_nodes.is_empty() || - !self.storage_tries.is_empty() + !self.account_nodes.is_empty() || !self.storage_tries.is_empty() } pub(super) fn log_differences(mut self) { @@ -35,18 +33,6 @@ impl TrieUpdatesDiff { warn!(target: "engine::tree", ?path, ?task, ?regular, ?database, "Difference in account trie updates"); } - for ( - path, - EntryDiff { - task: task_removed, - regular: regular_removed, - database: database_not_exists, - }, - ) in &self.removed_nodes - { - warn!(target: "engine::tree", ?path, ?task_removed, ?regular_removed, ?database_not_exists, "Difference in removed account trie nodes"); - } - for (address, storage_diff) in self.storage_tries { storage_diff.log_differences(address); } @@ -100,15 +86,13 @@ impl StorageTrieUpdatesDiff { /// and logs the differences if there's any. pub(super) fn compare_trie_updates( trie_cursor_factory: impl TrieCursorFactory, - task: TrieUpdates, - regular: TrieUpdates, + mut task: TrieUpdates, + mut regular: TrieUpdates, ) -> Result<(), DatabaseError> { - let mut task = adjust_trie_updates(task); - let mut regular = adjust_trie_updates(regular); - let mut diff = TrieUpdatesDiff::default(); - // compare account nodes + // compare account nodes (both updated and removed are in account_nodes map) + // None = removed, Some(node) = updated let mut account_trie_cursor = trie_cursor_factory.account_trie_cursor()?; for key in task .account_nodes @@ -117,37 +101,13 @@ pub(super) fn compare_trie_updates( .copied() .collect::>() { - let (task, regular) = (task.account_nodes.remove(&key), regular.account_nodes.remove(&key)); + let task_entry = task.account_nodes.remove(&key).flatten(); + let regular_entry = regular.account_nodes.remove(&key).flatten(); let database = account_trie_cursor.seek_exact(key)?.map(|x| x.1); - if !branch_nodes_equal(task.as_ref(), regular.as_ref(), database.as_ref())? { - diff.account_nodes.insert(key, EntryDiff { task, regular, database }); - } - } - - // compare removed nodes - let mut account_trie_cursor = trie_cursor_factory.account_trie_cursor()?; - for key in task - .removed_nodes - .iter() - .chain(regular.removed_nodes.iter()) - .copied() - .collect::>() - { - let (task_removed, regular_removed) = - (task.removed_nodes.contains(&key), regular.removed_nodes.contains(&key)); - let database_not_exists = account_trie_cursor.seek_exact(key)?.is_none(); - // If the deletion is a no-op, meaning that the entry is not in the - // database, do not add it to the diff. - if task_removed != regular_removed && !database_not_exists { - diff.removed_nodes.insert( - key, - EntryDiff { - task: task_removed, - regular: regular_removed, - database: database_not_exists, - }, - ); + if !branch_nodes_equal(task_entry.as_ref(), regular_entry.as_ref(), database.as_ref())? { + diff.account_nodes + .insert(key, EntryDiff { task: task_entry, regular: regular_entry, database }); } } @@ -244,36 +204,6 @@ fn compare_storage_trie_updates( Ok(diff) } -/// Filters the removed nodes of both account trie updates and storage trie updates, so that they -/// don't include those nodes that were also updated. -fn adjust_trie_updates(trie_updates: TrieUpdates) -> TrieUpdates { - TrieUpdates { - removed_nodes: trie_updates - .removed_nodes - .into_iter() - .filter(|key| !trie_updates.account_nodes.contains_key(key)) - .collect(), - storage_tries: trie_updates - .storage_tries - .into_iter() - .map(|(address, updates)| { - ( - address, - StorageTrieUpdates { - removed_nodes: updates - .removed_nodes - .into_iter() - .filter(|key| !updates.storage_nodes.contains_key(key)) - .collect(), - ..updates - }, - ) - }) - .collect(), - ..trie_updates - } -} - /// Compares the branch nodes from state root task and regular state root calculation. /// /// If one of the branch nodes is [`None`], it means it's not updated and the other is compared to diff --git a/crates/trie/common/Cargo.toml b/crates/trie/common/Cargo.toml index 55dadbab1d..ad23085b3e 100644 --- a/crates/trie/common/Cargo.toml +++ b/crates/trie/common/Cargo.toml @@ -141,3 +141,7 @@ harness = false name = "hashed_state" harness = false required-features = ["rayon"] + +[[bench]] +name = "trie_updates" +harness = false diff --git a/crates/trie/common/benches/trie_updates.rs b/crates/trie/common/benches/trie_updates.rs new file mode 100644 index 0000000000..a01c211050 --- /dev/null +++ b/crates/trie/common/benches/trie_updates.rs @@ -0,0 +1,494 @@ +#![allow(missing_docs, unreachable_pub)] +//! Benchmark comparing two representations for trie updates: +//! 1. Current: `HashMap` + `HashSet` (separate) +//! 2. Consolidated: `HashMap>` (unified) +//! +//! The consolidation aims to: +//! - Reduce `HashMap` overhead (one map instead of two) +//! - Improve memory layout (less fragmentation) +//! - Reduce cache misses (related data stored together) + +use alloy_primitives::map::{HashMap, HashSet}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use prop::test_runner::TestRng; +use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner}; +use reth_trie_common::{BranchNodeCompact, Nibbles, TrieMask}; +use std::hint::black_box; + +/// Current representation: separate `HashMap` and `HashSet` +#[derive(Default, Clone)] +struct TrieUpdatesSeparate { + account_nodes: HashMap, + removed_nodes: HashSet, +} + +impl TrieUpdatesSeparate { + fn insert_node(&mut self, path: Nibbles, node: BranchNodeCompact) { + self.removed_nodes.remove(&path); + self.account_nodes.insert(path, node); + } + + fn remove_node(&mut self, path: Nibbles) { + self.account_nodes.remove(&path); + self.removed_nodes.insert(path); + } + + fn get_node(&self, path: &Nibbles) -> Option<&BranchNodeCompact> { + self.account_nodes.get(path) + } + + fn is_removed(&self, path: &Nibbles) -> bool { + self.removed_nodes.contains(path) + } + + fn is_empty(&self) -> bool { + self.account_nodes.is_empty() && self.removed_nodes.is_empty() + } + + fn len(&self) -> usize { + self.account_nodes.len() + self.removed_nodes.len() + } + + fn extend(&mut self, other: Self) { + // Removed nodes in other should remove existing account nodes + self.account_nodes.retain(|k, _| !other.removed_nodes.contains(k)); + self.account_nodes.extend(other.account_nodes); + self.removed_nodes.extend(other.removed_nodes); + } + + fn iter_all(&self) -> impl Iterator)> { + self.account_nodes + .iter() + .map(|(k, v)| (k, Some(v))) + .chain(self.removed_nodes.iter().map(|k| (k, None))) + } +} + +/// Consolidated representation: single `HashMap` with `Option` +#[derive(Default, Clone)] +struct TrieUpdatesConsolidated { + /// None = removed, Some = updated + nodes: HashMap>, +} + +impl TrieUpdatesConsolidated { + fn insert_node(&mut self, path: Nibbles, node: BranchNodeCompact) { + self.nodes.insert(path, Some(node)); + } + + fn remove_node(&mut self, path: Nibbles) { + self.nodes.insert(path, None); + } + + fn get_node(&self, path: &Nibbles) -> Option<&BranchNodeCompact> { + self.nodes.get(path).and_then(|opt| opt.as_ref()) + } + + fn is_removed(&self, path: &Nibbles) -> bool { + self.nodes.get(path).is_some_and(|opt| opt.is_none()) + } + + fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + fn len(&self) -> usize { + self.nodes.len() + } + + fn extend(&mut self, other: Self) { + // When extending, other's entries override ours + // If other marks a path as removed (None), it overrides our entry + // If other has a node (Some), it overrides our entry + self.nodes.extend(other.nodes); + } + + fn iter_all(&self) -> impl Iterator)> { + self.nodes.iter().map(|(k, v)| (k, v.as_ref())) + } +} + +fn print_sizes() { + println!("\n=== Type Sizes ==="); + println!("Nibbles: {} bytes", std::mem::size_of::()); + println!("BranchNodeCompact: {} bytes", std::mem::size_of::()); + println!( + "Option: {} bytes", + std::mem::size_of::>() + ); + println!( + "(Nibbles, BranchNodeCompact): {} bytes", + std::mem::size_of::<(Nibbles, BranchNodeCompact)>() + ); + println!( + "(Nibbles, Option): {} bytes", + std::mem::size_of::<(Nibbles, Option)>() + ); + println!(); +} + +fn generate_nibbles(runner: &mut TestRunner, count: usize) -> Vec { + use prop::collection::vec; + let strategy = vec(any_with::((1..=32usize).into()), count); + let mut nibbles = strategy.new_tree(runner).unwrap().current(); + nibbles.sort(); + nibbles.dedup(); + nibbles +} + +fn generate_branch_node() -> BranchNodeCompact { + // Create a valid BranchNodeCompact with matching hash_mask and hashes count + // hash_mask must have exactly the same number of bits set as hashes.len() + BranchNodeCompact::new( + TrieMask::new(0b1111_0000_1111_0000), // state_mask + TrieMask::new(0b0011_0000_0011_0000), // tree_mask + TrieMask::new(0), // hash_mask (0 bits = 0 hashes) + vec![], // hashes (empty) + None, // root_hash + ) +} + +fn bench_insert(c: &mut Criterion) { + print_sizes(); + let mut group = c.benchmark_group("trie_updates_insert"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + group.throughput(Throughput::Elements(size as u64)); + + group.bench_with_input(BenchmarkId::new("separate", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesSeparate::default(); + for path in &nibbles { + updates.insert_node(*path, node.clone()); + } + black_box(updates) + }); + }); + + group.bench_with_input(BenchmarkId::new("consolidated", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesConsolidated::default(); + for path in &nibbles { + updates.insert_node(*path, node.clone()); + } + black_box(updates) + }); + }); + } + + group.finish(); +} + +fn bench_mixed_operations(c: &mut Criterion) { + let mut group = c.benchmark_group("trie_updates_mixed"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + group.throughput(Throughput::Elements(size as u64)); + + // Mix of 70% inserts, 30% removes + group.bench_with_input(BenchmarkId::new("separate", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesSeparate::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + updates.insert_node(*path, node.clone()); + } else { + updates.remove_node(*path); + } + } + black_box(updates) + }); + }); + + group.bench_with_input(BenchmarkId::new("consolidated", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + updates.insert_node(*path, node.clone()); + } else { + updates.remove_node(*path); + } + } + black_box(updates) + }); + }); + } + + group.finish(); +} + +fn bench_lookup(c: &mut Criterion) { + let mut group = c.benchmark_group("trie_updates_lookup"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + // Pre-populate the structures + let mut separate = TrieUpdatesSeparate::default(); + let mut consolidated = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + separate.insert_node(*path, node.clone()); + consolidated.insert_node(*path, node.clone()); + } else { + separate.remove_node(*path); + consolidated.remove_node(*path); + } + } + + group.throughput(Throughput::Elements(size as u64)); + + group.bench_with_input(BenchmarkId::new("separate", size), &size, |b, _| { + b.iter(|| { + let mut found = 0usize; + for path in &nibbles { + if separate.get_node(path).is_some() { + found += 1; + } + if separate.is_removed(path) { + found += 1; + } + } + black_box(found) + }); + }); + + group.bench_with_input(BenchmarkId::new("consolidated", size), &size, |b, _| { + b.iter(|| { + let mut found = 0usize; + for path in &nibbles { + if consolidated.get_node(path).is_some() { + found += 1; + } + if consolidated.is_removed(path) { + found += 1; + } + } + black_box(found) + }); + }); + } + + group.finish(); +} + +fn bench_extend(c: &mut Criterion) { + let mut group = c.benchmark_group("trie_updates_extend"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles1 = generate_nibbles(&mut runner, size); + let nibbles2 = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + // Pre-populate first set + let mut separate1 = TrieUpdatesSeparate::default(); + let mut consolidated1 = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles1.iter().enumerate() { + if i % 10 < 7 { + separate1.insert_node(*path, node.clone()); + consolidated1.insert_node(*path, node.clone()); + } else { + separate1.remove_node(*path); + consolidated1.remove_node(*path); + } + } + + // Pre-populate second set + let mut separate2 = TrieUpdatesSeparate::default(); + let mut consolidated2 = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles2.iter().enumerate() { + if i % 10 < 7 { + separate2.insert_node(*path, node.clone()); + consolidated2.insert_node(*path, node.clone()); + } else { + separate2.remove_node(*path); + consolidated2.remove_node(*path); + } + } + + group.throughput(Throughput::Elements((size * 2) as u64)); + + group.bench_with_input(BenchmarkId::new("separate", size), &size, |b, _| { + b.iter_batched( + || (separate1.clone(), separate2.clone()), + |(mut s1, s2)| { + s1.extend(s2); + black_box(s1) + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_with_input(BenchmarkId::new("consolidated", size), &size, |b, _| { + b.iter_batched( + || (consolidated1.clone(), consolidated2.clone()), + |(mut c1, c2)| { + c1.extend(c2); + black_box(c1) + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +fn bench_iteration(c: &mut Criterion) { + let mut group = c.benchmark_group("trie_updates_iteration"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + // Pre-populate the structures + let mut separate = TrieUpdatesSeparate::default(); + let mut consolidated = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + separate.insert_node(*path, node.clone()); + consolidated.insert_node(*path, node.clone()); + } else { + separate.remove_node(*path); + consolidated.remove_node(*path); + } + } + + group.throughput(Throughput::Elements(size as u64)); + + group.bench_with_input(BenchmarkId::new("separate", size), &size, |b, _| { + b.iter(|| { + let mut count = 0usize; + for (_, node) in separate.iter_all() { + if node.is_some() { + count += 1; + } + } + black_box(count) + }); + }); + + group.bench_with_input(BenchmarkId::new("consolidated", size), &size, |b, _| { + b.iter(|| { + let mut count = 0usize; + for (_, node) in consolidated.iter_all() { + if node.is_some() { + count += 1; + } + } + black_box(count) + }); + }); + } + + group.finish(); +} + +fn bench_memory_size(c: &mut Criterion) { + let mut group = c.benchmark_group("trie_updates_memory"); + + for size in [100, 1_000, 10_000] { + let config = proptest::test_runner::Config::default(); + let rng = TestRng::deterministic_rng(config.rng_algorithm); + let mut runner = TestRunner::new_with_rng(config, rng); + let nibbles = generate_nibbles(&mut runner, size); + let node = generate_branch_node(); + + // Pre-populate with 70% inserts, 30% removes + let mut separate = TrieUpdatesSeparate::default(); + let mut consolidated = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + separate.insert_node(*path, node.clone()); + consolidated.insert_node(*path, node.clone()); + } else { + separate.remove_node(*path); + consolidated.remove_node(*path); + } + } + + // Calculate approximate memory usage + let separate_size = std::mem::size_of::() + + separate.account_nodes.capacity() * + (std::mem::size_of::() + std::mem::size_of::()) + + separate.removed_nodes.capacity() * std::mem::size_of::(); + + let consolidated_size = std::mem::size_of::() + + consolidated.nodes.capacity() * + (std::mem::size_of::() + + std::mem::size_of::>()); + + println!( + "Size {}: Separate={} bytes, Consolidated={} bytes, Savings={}%", + size, + separate_size, + consolidated_size, + 100 - (consolidated_size * 100 / separate_size.max(1)) + ); + + // Benchmark memory allocation overhead + group.bench_with_input(BenchmarkId::new("separate_alloc", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesSeparate::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + updates.insert_node(*path, node.clone()); + } else { + updates.remove_node(*path); + } + } + black_box((updates.len(), updates.is_empty())) + }); + }); + + group.bench_with_input(BenchmarkId::new("consolidated_alloc", size), &size, |b, _| { + b.iter(|| { + let mut updates = TrieUpdatesConsolidated::default(); + for (i, path) in nibbles.iter().enumerate() { + if i % 10 < 7 { + updates.insert_node(*path, node.clone()); + } else { + updates.remove_node(*path); + } + } + black_box((updates.len(), updates.is_empty())) + }); + }); + } + + group.finish(); +} + +criterion_group!( + trie_updates, + bench_insert, + bench_mixed_operations, + bench_lookup, + bench_extend, + bench_iteration, + bench_memory_size, +); +criterion_main!(trie_updates); diff --git a/crates/trie/common/src/updates.rs b/crates/trie/common/src/updates.rs index f1db882781..083d9b0279 100644 --- a/crates/trie/common/src/updates.rs +++ b/crates/trie/common/src/updates.rs @@ -9,15 +9,17 @@ use alloy_primitives::{ }; /// The aggregation of trie updates. +/// +/// Account nodes are stored as `Option` where: +/// - `Some(node)` indicates an updated node +/// - `None` indicates a removed node #[derive(PartialEq, Eq, Clone, Default, Debug)] #[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize, serde::Deserialize))] pub struct TrieUpdates { - /// Collection of updated intermediate account nodes indexed by full path. - #[cfg_attr(any(test, feature = "serde"), serde(with = "serde_nibbles_map"))] - pub account_nodes: HashMap, - /// Collection of removed intermediate account nodes indexed by full path. - #[cfg_attr(any(test, feature = "serde"), serde(with = "serde_nibbles_set"))] - pub removed_nodes: HashSet, + /// Collection of account node updates indexed by full path. + /// `Some(node)` = updated node, `None` = removed node. + #[cfg_attr(any(test, feature = "serde"), serde(with = "serde_nibbles_option_map"))] + pub account_nodes: HashMap>, /// Collection of updated storage tries indexed by the hashed address. pub storage_tries: B256Map, } @@ -25,21 +27,14 @@ pub struct TrieUpdates { impl TrieUpdates { /// Returns `true` if the updates are empty. pub fn is_empty(&self) -> bool { - self.account_nodes.is_empty() && - self.removed_nodes.is_empty() && - self.storage_tries.is_empty() + self.account_nodes.is_empty() && self.storage_tries.is_empty() } - /// Returns reference to updated account nodes. - pub const fn account_nodes_ref(&self) -> &HashMap { + /// Returns reference to account nodes (both updated and removed). + pub const fn account_nodes_ref(&self) -> &HashMap> { &self.account_nodes } - /// Returns a reference to removed account nodes. - pub const fn removed_nodes_ref(&self) -> &HashSet { - &self.removed_nodes - } - /// Returns a reference to updated storage tries. pub const fn storage_tries_ref(&self) -> &B256Map { &self.storage_tries @@ -47,9 +42,7 @@ impl TrieUpdates { /// Extends the trie updates. pub fn extend(&mut self, other: Self) { - self.extend_common(&other); - self.account_nodes.extend(exclude_empty_from_pair(other.account_nodes)); - self.removed_nodes.extend(exclude_empty(other.removed_nodes)); + self.account_nodes.extend(exclude_empty_from_option_pair(other.account_nodes)); for (hashed_address, storage_trie) in other.storage_tries { self.storage_tries.entry(hashed_address).or_default().extend(storage_trie); } @@ -59,20 +52,14 @@ impl TrieUpdates { /// /// Slightly less efficient than [`Self::extend`], but preferred to `extend(other.clone())`. pub fn extend_ref(&mut self, other: &Self) { - self.extend_common(other); - self.account_nodes.extend(exclude_empty_from_pair( + self.account_nodes.extend(exclude_empty_from_option_pair( other.account_nodes.iter().map(|(k, v)| (*k, v.clone())), )); - self.removed_nodes.extend(exclude_empty(other.removed_nodes.iter().copied())); for (hashed_address, storage_trie) in &other.storage_tries { self.storage_tries.entry(*hashed_address).or_default().extend_ref(storage_trie); } } - fn extend_common(&mut self, other: &Self) { - self.account_nodes.retain(|nibbles, _| !other.removed_nodes.contains(nibbles)); - } - /// Extend trie updates with sorted data, converting directly into the unsorted `HashMap` /// representation. This is more efficient than first converting to `TrieUpdates` and /// then extending, as it avoids creating intermediate `HashMap` allocations. @@ -84,21 +71,12 @@ impl TrieUpdates { let new_nodes_count = sorted.account_nodes.len(); self.account_nodes.reserve(new_nodes_count); - // Insert account nodes from sorted (only non-None entries) + // Insert account nodes from sorted for (nibbles, maybe_node) in &sorted.account_nodes { if nibbles.is_empty() { continue; } - match maybe_node { - Some(node) => { - self.removed_nodes.remove(nibbles); - self.account_nodes.insert(*nibbles, node.clone()); - } - None => { - self.account_nodes.remove(nibbles); - self.removed_nodes.insert(*nibbles); - } - } + self.account_nodes.insert(*nibbles, maybe_node.clone()); } // Extend storage tries @@ -131,12 +109,15 @@ impl TrieUpdates { removed_keys: HashSet, destroyed_accounts: B256Set, ) { - // Retrieve updated nodes from hash builder. - let (_, updated_nodes) = hash_builder.split(); - self.account_nodes.extend(exclude_empty_from_pair(updated_nodes)); + // Add deleted node paths first (None indicates removal). + // Updated nodes take precedence over removed nodes. + self.account_nodes.extend(exclude_empty(removed_keys).map(|k| (k, None))); - // Add deleted node paths. - self.removed_nodes.extend(exclude_empty(removed_keys)); + // Retrieve updated nodes from hash builder. + // Extend after removed_keys so updates take precedence. + let (_, updated_nodes) = hash_builder.split(); + self.account_nodes + .extend(exclude_empty_from_pair(updated_nodes).map(|(k, v)| (k, Some(v)))); // Add deleted storage tries for destroyed accounts. for destroyed in destroyed_accounts { @@ -145,35 +126,35 @@ impl TrieUpdates { } /// Converts trie updates into [`TrieUpdatesSorted`]. - pub fn into_sorted(mut self) -> TrieUpdatesSorted { - let mut account_nodes = self - .account_nodes - .drain() - .map(|(path, node)| { - // Updated nodes take precedence over removed nodes. - self.removed_nodes.remove(&path); - (path, Some(node)) - }) - .collect::>(); - - account_nodes.extend(self.removed_nodes.drain().map(|path| (path, None))); + pub fn into_sorted(self) -> TrieUpdatesSorted { + let mut account_nodes = self.account_nodes.into_iter().collect::>(); account_nodes.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let storage_tries = self .storage_tries - .drain() + .into_iter() .map(|(hashed_address, updates)| (hashed_address, updates.into_sorted())) .collect(); TrieUpdatesSorted { account_nodes, storage_tries } } /// Converts trie updates into [`TrieUpdatesSortedRef`]. - pub fn into_sorted_ref<'a>(&'a self) -> TrieUpdatesSortedRef<'a> { - let mut account_nodes = self.account_nodes.iter().collect::>(); + pub fn into_sorted_ref(&self) -> TrieUpdatesSortedRef<'_> { + let mut account_nodes = self + .account_nodes + .iter() + .filter_map(|(k, v)| v.as_ref().map(|node| (k, node))) + .collect::>(); account_nodes.sort_unstable_by(|a, b| a.0.cmp(b.0)); + let removed_nodes = self + .account_nodes + .iter() + .filter_map(|(k, v)| v.is_none().then_some(k)) + .collect::>(); + TrieUpdatesSortedRef { - removed_nodes: self.removed_nodes.iter().collect::>(), + removed_nodes, account_nodes, storage_tries: self .storage_tries @@ -186,7 +167,6 @@ impl TrieUpdates { /// Clears the nodes and storage trie maps in this `TrieUpdates`. pub fn clear(&mut self) { self.account_nodes.clear(); - self.removed_nodes.clear(); self.storage_tries.clear(); } } @@ -473,6 +453,91 @@ mod serde_nibbles_map { } } +/// Serializes and deserializes any [`HashMap`] that uses [`Nibbles`] as keys and `Option` as +/// values, by using the hex-encoded packed representation. +/// +/// This also sorts the map's keys before encoding and serializing. +#[cfg(any(test, feature = "serde"))] +mod serde_nibbles_option_map { + use crate::Nibbles; + use alloc::{ + string::{String, ToString}, + vec::Vec, + }; + use alloy_primitives::{hex, map::HashMap}; + use core::marker::PhantomData; + use serde::{ + de::{Error, MapAccess, Visitor}, + ser::SerializeMap, + Deserialize, Deserializer, Serialize, Serializer, + }; + + pub(super) fn serialize( + map: &HashMap>, + serializer: S, + ) -> Result + where + S: Serializer, + T: Serialize, + { + let mut map_serializer = serializer.serialize_map(Some(map.len()))?; + let mut nodes = Vec::from_iter(map); + nodes.sort_unstable_by_key(|node| node.0); + for (k, v) in nodes { + // pack, then hex encode the Nibbles + let packed = alloy_primitives::hex::encode(k.pack()); + map_serializer.serialize_entry(&packed, &v)?; + } + map_serializer.end() + } + + pub(super) fn deserialize<'de, D, T>( + deserializer: D, + ) -> Result>, D::Error> + where + D: Deserializer<'de>, + T: Deserialize<'de>, + { + struct NibblesOptionMapVisitor { + marker: PhantomData, + } + + impl<'de, T> Visitor<'de> for NibblesOptionMapVisitor + where + T: Deserialize<'de>, + { + type Value = HashMap>; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + formatter.write_str("a map with hex-encoded Nibbles keys and optional values") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut result = HashMap::with_capacity_and_hasher( + map.size_hint().unwrap_or(0), + Default::default(), + ); + + while let Some((key, value)) = map.next_entry::>()? { + let decoded_key = + hex::decode(&key).map_err(|err| Error::custom(err.to_string()))?; + + let nibbles = Nibbles::unpack(&decoded_key); + + result.insert(nibbles, value); + } + + Ok(result) + } + } + + deserializer.deserialize_map(NibblesOptionMapVisitor { marker: PhantomData }) + } +} + /// Sorted trie updates reference used for serializing trie to file. #[derive(PartialEq, Eq, Clone, Default, Debug)] #[cfg_attr(any(test, feature = "serde"), derive(serde::Serialize))] @@ -574,16 +639,7 @@ impl AsRef for TrieUpdatesSorted { impl From for TrieUpdates { fn from(sorted: TrieUpdatesSorted) -> Self { - let mut account_nodes = HashMap::default(); - let mut removed_nodes = HashSet::default(); - - for (nibbles, node) in sorted.account_nodes { - if let Some(node) = node { - account_nodes.insert(nibbles, node); - } else { - removed_nodes.insert(nibbles); - } - } + let account_nodes = sorted.account_nodes.into_iter().collect(); let storage_tries = sorted .storage_tries @@ -591,7 +647,7 @@ impl From for TrieUpdates { .map(|(address, storage)| (address, storage.into())) .collect(); - Self { account_nodes, removed_nodes, storage_tries } + Self { account_nodes, storage_tries } } } @@ -669,6 +725,14 @@ fn exclude_empty_from_pair( iter.into_iter().filter(|(n, _)| !n.is_empty()) } +/// Excludes empty nibbles from the given iterator of pairs where the nibbles are the key +/// and value is `Option`. +fn exclude_empty_from_option_pair( + iter: impl IntoIterator)>, +) -> impl Iterator)> { + iter.into_iter().filter(|(n, _)| !n.is_empty()) +} + impl From for StorageTrieUpdates { fn from(sorted: StorageTrieUpdatesSorted) -> Self { let mut storage_nodes = HashMap::default(); @@ -987,8 +1051,7 @@ pub mod serde_bincode_compat { /// ``` #[derive(Debug, Serialize, Deserialize)] pub struct TrieUpdates<'a> { - account_nodes: Cow<'a, HashMap>, - removed_nodes: Cow<'a, HashSet>, + account_nodes: Cow<'a, HashMap>>, storage_tries: B256Map>, } @@ -996,7 +1059,6 @@ pub mod serde_bincode_compat { fn from(value: &'a super::TrieUpdates) -> Self { Self { account_nodes: Cow::Borrowed(&value.account_nodes), - removed_nodes: Cow::Borrowed(&value.removed_nodes), storage_tries: value.storage_tries.iter().map(|(k, v)| (*k, v.into())).collect(), } } @@ -1006,7 +1068,6 @@ pub mod serde_bincode_compat { fn from(value: TrieUpdates<'a>) -> Self { Self { account_nodes: value.account_nodes.into_owned(), - removed_nodes: value.removed_nodes.into_owned(), storage_tries: value .storage_tries .into_iter() @@ -1122,16 +1183,18 @@ pub mod serde_bincode_compat { let decoded: Data = bincode::deserialize(&encoded).unwrap(); assert_eq!(decoded, data); + // Insert a removed node (None) data.trie_updates - .removed_nodes - .insert(Nibbles::from_nibbles_unchecked([0x0b, 0x0e, 0x0e, 0x0f])); + .account_nodes + .insert(Nibbles::from_nibbles_unchecked([0x0b, 0x0e, 0x0e, 0x0f]), None); let encoded = bincode::serialize(&data).unwrap(); let decoded: Data = bincode::deserialize(&encoded).unwrap(); assert_eq!(decoded, data); + // Insert an updated node (Some) data.trie_updates.account_nodes.insert( Nibbles::from_nibbles_unchecked([0x0d, 0x0e, 0x0a, 0x0d]), - BranchNodeCompact::default(), + Some(BranchNodeCompact::default()), ); let encoded = bincode::serialize(&data).unwrap(); let decoded: Data = bincode::deserialize(&encoded).unwrap(); @@ -1186,16 +1249,18 @@ mod serde_tests { let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap(); assert_eq!(updates_deserialized, default_updates); + // Insert a removed node (None) default_updates - .removed_nodes - .insert(Nibbles::from_nibbles_unchecked([0x0b, 0x0e, 0x0e, 0x0f])); + .account_nodes + .insert(Nibbles::from_nibbles_unchecked([0x0b, 0x0e, 0x0e, 0x0f]), None); let updates_serialized = serde_json::to_string(&default_updates).unwrap(); let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap(); assert_eq!(updates_deserialized, default_updates); + // Insert an updated node (Some) default_updates.account_nodes.insert( Nibbles::from_nibbles_unchecked([0x0d, 0x0e, 0x0a, 0x0d]), - BranchNodeCompact::default(), + Some(BranchNodeCompact::default()), ); let updates_serialized = serde_json::to_string(&default_updates).unwrap(); let updates_deserialized: TrieUpdates = serde_json::from_str(&updates_serialized).unwrap(); diff --git a/crates/trie/db/tests/trie.rs b/crates/trie/db/tests/trie.rs index 8f543a711d..c3cdc7e6ba 100644 --- a/crates/trie/db/tests/trie.rs +++ b/crates/trie/db/tests/trie.rs @@ -521,16 +521,17 @@ fn account_and_storage_trie() { .root_with_updates() .unwrap(); assert_eq!(root, computed_expected_root); - assert_eq!( - trie_updates.account_nodes_ref().len() + trie_updates.removed_nodes_ref().len(), - 1 - ); - + // account_nodes now contains both updated (Some) and removed (None) entries assert_eq!(trie_updates.account_nodes_ref().len(), 1); - let entry = trie_updates.account_nodes_ref().iter().next().unwrap(); + // Count only updated entries (Some) + let updated_count = + trie_updates.account_nodes_ref().iter().filter(|(_, v)| v.is_some()).count(); + assert_eq!(updated_count, 1); + + let entry = trie_updates.account_nodes_ref().iter().find(|(_, v)| v.is_some()).unwrap(); assert_eq!(entry.0.to_vec(), vec![0xB]); - let node1c = entry.1; + let node1c = entry.1.as_ref().unwrap(); assert_eq!(node1c.state_mask, TrieMask::new(0b1011)); assert_eq!(node1c.tree_mask, TrieMask::new(0b0000)); @@ -576,20 +577,21 @@ fn account_and_storage_trie() { .root_with_updates() .unwrap(); assert_eq!(root, computed_expected_root); - assert_eq!( - trie_updates.account_nodes_ref().len() + trie_updates.removed_nodes_ref().len(), - 1 - ); + // account_nodes contains both updated (Some) and removed (None) entries + assert_eq!(trie_updates.account_nodes_ref().len(), 1); assert!(!trie_updates .storage_tries_ref() .iter() - .any(|(_, u)| !u.storage_nodes_ref().is_empty() || !u.removed_nodes_ref().is_empty())); // no storage root update + .any(|(_, u)| !u.storage_nodes.is_empty() || !u.removed_nodes.is_empty())); // no storage root update - assert_eq!(trie_updates.account_nodes_ref().len(), 1); + // Count only updated entries (Some) + let updated_count = + trie_updates.account_nodes_ref().iter().filter(|(_, v)| v.is_some()).count(); + assert_eq!(updated_count, 1); - let entry = trie_updates.account_nodes_ref().iter().next().unwrap(); + let entry = trie_updates.account_nodes_ref().iter().find(|(_, v)| v.is_some()).unwrap(); assert_eq!(entry.0.to_vec(), vec![0xB]); - let node1d = entry.1; + let node1d = entry.1.as_ref().unwrap(); assert_eq!(node1d.state_mask, TrieMask::new(0b1011)); assert_eq!(node1d.tree_mask, TrieMask::new(0b0000)); @@ -629,11 +631,11 @@ fn account_trie_around_extension_node_with_dbtrie() { // read the account updates from the db let mut accounts_trie = tx.tx_ref().cursor_read::().unwrap(); let walker = accounts_trie.walk(None).unwrap(); - let account_updates = walker + let account_updates: HashMap> = walker .into_iter() .map(|item| { let (key, node) = item.unwrap(); - (key.0, node) + (key.0, Some(node)) }) .collect(); assert_trie_updates(&account_updates); @@ -688,7 +690,7 @@ fn storage_trie_around_extension_node() { StorageRoot::from_tx_hashed(tx.tx_ref(), hashed_address).root_with_updates().unwrap(); assert_eq!(expected_root, got); assert_eq!(expected_updates, updates); - assert_trie_updates(updates.storage_nodes_ref()); + assert_storage_trie_updates(updates.storage_nodes_ref()); } fn extension_node_storage_trie( @@ -745,14 +747,38 @@ fn extension_node_trie( hb.root() } -fn assert_trie_updates(account_updates: &HashMap) { - assert_eq!(account_updates.len(), 2); +fn assert_trie_updates(account_updates: &HashMap>) { + // Filter to only updated nodes (Some variants) + let updated_count = account_updates.values().filter(|v| v.is_some()).count(); + assert_eq!(updated_count, 2); - let node = account_updates.get(&Nibbles::from_nibbles_unchecked([0x3])).unwrap(); + let node = account_updates + .get(&Nibbles::from_nibbles_unchecked([0x3])) + .and_then(|v| v.as_ref()) + .unwrap(); let expected = BranchNodeCompact::new(0b0011, 0b0001, 0b0000, vec![], None); assert_eq!(node, &expected); - let node = account_updates.get(&Nibbles::from_nibbles_unchecked([0x3, 0x0, 0xA, 0xF])).unwrap(); + let node = account_updates + .get(&Nibbles::from_nibbles_unchecked([0x3, 0x0, 0xA, 0xF])) + .and_then(|v| v.as_ref()) + .unwrap(); + assert_eq!(node.state_mask, TrieMask::new(0b101100000)); + assert_eq!(node.tree_mask, TrieMask::new(0b000000000)); + assert_eq!(node.hash_mask, TrieMask::new(0b001000000)); + + assert_eq!(node.root_hash, None); + assert_eq!(node.hashes.len(), 1); +} + +fn assert_storage_trie_updates(storage_updates: &HashMap) { + assert_eq!(storage_updates.len(), 2); + + let node = storage_updates.get(&Nibbles::from_nibbles_unchecked([0x3])).unwrap(); + let expected = BranchNodeCompact::new(0b0011, 0b0001, 0b0000, vec![], None); + assert_eq!(node, &expected); + + let node = storage_updates.get(&Nibbles::from_nibbles_unchecked([0x3, 0x0, 0xA, 0xF])).unwrap(); assert_eq!(node.state_mask, TrieMask::new(0b101100000)); assert_eq!(node.tree_mask, TrieMask::new(0b000000000)); assert_eq!(node.hash_mask, TrieMask::new(0b001000000)); diff --git a/crates/trie/sparse-parallel/src/trie.rs b/crates/trie/sparse-parallel/src/trie.rs index 3ccc5aad1a..c8e10cf2fe 100644 --- a/crates/trie/sparse-parallel/src/trie.rs +++ b/crates/trie/sparse-parallel/src/trie.rs @@ -9,8 +9,8 @@ use alloy_trie::{BranchNodeCompact, TrieMask, EMPTY_ROOT_HASH}; use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult}; use reth_trie_common::{ prefix_set::{PrefixSet, PrefixSetMut}, - BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, TrieMasks, - TrieNode, CHILD_INDEX_RANGE, + BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, + ProofTrieNode, RlpNode, TrieMasks, TrieNode, CHILD_INDEX_RANGE, }; use reth_trie_sparse::{ provider::{RevealedNode, TrieNodeProvider}, @@ -112,10 +112,12 @@ pub struct ParallelSparseTrie { prefix_set: PrefixSetMut, /// Optional tracking of trie updates for later use. updates: Option, - /// When a bit is set, the corresponding child subtree is stored in the database. - branch_node_tree_masks: HashMap, - /// When a bit is set, the corresponding child is stored as a hash in the database. - branch_node_hash_masks: HashMap, + /// Branch node masks containing `tree_mask` and `hash_mask` for each path. + /// - `tree_mask`: When a bit is set, the corresponding child subtree is stored in the + /// database. + /// - `hash_mask`: When a bit is set, the corresponding child is stored as a hash in the + /// database. + branch_node_masks: BranchNodeMasksMap, /// Reusable buffer pool used for collecting [`SparseTrieUpdatesAction`]s during hash /// computations. update_actions_buffers: Vec>, @@ -136,8 +138,7 @@ impl Default for ParallelSparseTrie { lower_subtries: [const { LowerSparseSubtrie::Blind(None) }; NUM_LOWER_SUBTRIES], prefix_set: PrefixSetMut::default(), updates: None, - branch_node_tree_masks: HashMap::default(), - branch_node_hash_masks: HashMap::default(), + branch_node_masks: BranchNodeMasksMap::default(), update_actions_buffers: Vec::default(), parallelism_thresholds: Default::default(), #[cfg(feature = "metrics")] @@ -187,11 +188,14 @@ impl SparseTrieInterface for ParallelSparseTrie { // Update the top-level branch node masks. This is simple and can't be done in parallel. for ProofTrieNode { path, masks, .. } in &nodes { - if let Some(tree_mask) = masks.tree_mask { - self.branch_node_tree_masks.insert(*path, tree_mask); - } - if let Some(hash_mask) = masks.hash_mask { - self.branch_node_hash_masks.insert(*path, hash_mask); + if masks.tree_mask.is_some() || masks.hash_mask.is_some() { + self.branch_node_masks.insert( + *path, + BranchNodeMasks { + tree_mask: masks.tree_mask.unwrap_or_default(), + hash_mask: masks.hash_mask.unwrap_or_default(), + }, + ); } } @@ -719,8 +723,7 @@ impl SparseTrieInterface for ParallelSparseTrie { changed_subtrie.subtrie.update_hashes( &mut changed_subtrie.prefix_set, &mut changed_subtrie.update_actions_buf, - &self.branch_node_tree_masks, - &self.branch_node_hash_masks, + &self.branch_node_masks, ); } @@ -736,8 +739,7 @@ impl SparseTrieInterface for ParallelSparseTrie { { use rayon::iter::{IntoParallelIterator, ParallelIterator}; - let branch_node_tree_masks = &self.branch_node_tree_masks; - let branch_node_hash_masks = &self.branch_node_hash_masks; + let branch_node_masks = &self.branch_node_masks; let updated_subtries: Vec<_> = changed_subtries .into_par_iter() .map(|mut changed_subtrie| { @@ -746,8 +748,7 @@ impl SparseTrieInterface for ParallelSparseTrie { changed_subtrie.subtrie.update_hashes( &mut changed_subtrie.prefix_set, &mut changed_subtrie.update_actions_buf, - branch_node_tree_masks, - branch_node_hash_masks, + branch_node_masks, ); #[cfg(feature = "metrics")] self.metrics.subtrie_hash_update_latency.record(start.elapsed()); @@ -786,8 +787,7 @@ impl SparseTrieInterface for ParallelSparseTrie { } self.prefix_set.clear(); self.updates = None; - self.branch_node_tree_masks.clear(); - self.branch_node_hash_masks.clear(); + self.branch_node_masks.clear(); // `update_actions_buffers` doesn't need to be cleared; we want to reuse the Vecs it has // buffered, and all of those are already inherently cleared when they get used. } @@ -870,9 +870,8 @@ impl SparseTrieInterface for ParallelSparseTrie { subtrie.shrink_nodes_to(size_per_subtrie); } - // shrink masks maps - self.branch_node_hash_masks.shrink_to(size); - self.branch_node_tree_masks.shrink_to(size); + // shrink masks map + self.branch_node_masks.shrink_to(size); } fn shrink_values_to(&mut self, size: usize) { @@ -1377,8 +1376,7 @@ impl ParallelSparseTrie { &mut update_actions_buf, stack_item, node, - &self.branch_node_tree_masks, - &self.branch_node_hash_masks, + &self.branch_node_masks, ); } @@ -2047,8 +2045,7 @@ impl SparseSubtrie { /// - `update_actions`: A buffer which `SparseTrieUpdatesAction`s will be written to in the /// event that any changes to the top-level updates are required. If None then update /// retention is disabled. - /// - `branch_node_tree_masks`: The tree masks for branch nodes - /// - `branch_node_hash_masks`: The hash masks for branch nodes + /// - `branch_node_masks`: The tree and hash masks for branch nodes. /// /// # Returns /// @@ -2062,8 +2059,7 @@ impl SparseSubtrie { &mut self, prefix_set: &mut PrefixSet, update_actions: &mut Option>, - branch_node_tree_masks: &HashMap, - branch_node_hash_masks: &HashMap, + branch_node_masks: &BranchNodeMasksMap, ) -> RlpNode { trace!(target: "trie::parallel_sparse", "Updating subtrie hashes"); @@ -2082,14 +2078,7 @@ impl SparseSubtrie { .get_mut(&path) .unwrap_or_else(|| panic!("node at path {path:?} does not exist")); - self.inner.rlp_node( - prefix_set, - update_actions, - stack_item, - node, - branch_node_tree_masks, - branch_node_hash_masks, - ); + self.inner.rlp_node(prefix_set, update_actions, stack_item, node, branch_node_masks); } debug_assert_eq!(self.inner.buffers.rlp_node_stack.len(), 1); @@ -2149,8 +2138,7 @@ impl SparseSubtrieInner { /// retention is disabled. /// - `stack_item`: The stack item to process /// - `node`: The sparse node to process (will be mutated to update hash) - /// - `branch_node_tree_masks`: The tree masks for branch nodes - /// - `branch_node_hash_masks`: The hash masks for branch nodes + /// - `branch_node_masks`: The tree and hash masks for branch nodes. /// /// # Side Effects /// @@ -2168,8 +2156,7 @@ impl SparseSubtrieInner { update_actions: &mut Option>, mut stack_item: RlpNodePathStackItem, node: &mut SparseNode, - branch_node_tree_masks: &HashMap, - branch_node_hash_masks: &HashMap, + branch_node_masks: &BranchNodeMasksMap, ) { let path = stack_item.path; trace!( @@ -2303,6 +2290,12 @@ impl SparseSubtrieInner { let mut tree_mask = TrieMask::default(); let mut hash_mask = TrieMask::default(); let mut hashes = Vec::new(); + + // Lazy lookup for branch node masks - shared across loop iterations + let mut path_masks_storage = None; + let mut path_masks = + || *path_masks_storage.get_or_insert_with(|| branch_node_masks.get(&path)); + for (i, child_path) in self.buffers.branch_child_buf.iter().enumerate() { if self.buffers.rlp_node_stack.last().is_some_and(|e| &e.path == child_path) { let RlpNodeStackItem { @@ -2326,9 +2319,9 @@ impl SparseSubtrieInner { } else { // A blinded node has the tree mask bit set child_node_type.is_hash() && - branch_node_tree_masks - .get(&path) - .is_some_and(|mask| mask.is_bit_set(last_child_nibble)) + path_masks().is_some_and(|masks| { + masks.tree_mask.is_bit_set(last_child_nibble) + }) }; if should_set_tree_mask_bit { tree_mask.set_bit(last_child_nibble); @@ -2340,9 +2333,9 @@ impl SparseSubtrieInner { let hash = child.as_hash().filter(|_| { child_node_type.is_branch() || (child_node_type.is_hash() && - branch_node_hash_masks.get(&path).is_some_and( - |mask| mask.is_bit_set(last_child_nibble), - )) + path_masks().is_some_and(|masks| { + masks.hash_mask.is_bit_set(last_child_nibble) + })) }); if let Some(hash) = hash { hash_mask.set_bit(last_child_nibble); @@ -2409,19 +2402,17 @@ impl SparseSubtrieInner { ); update_actions .push(SparseTrieUpdatesAction::InsertUpdated(path, branch_node)); - } else if branch_node_tree_masks.get(&path).is_some_and(|mask| !mask.is_empty()) || - branch_node_hash_masks.get(&path).is_some_and(|mask| !mask.is_empty()) - { - // If new tree and hash masks are empty, but previously they weren't, we - // need to remove the node update and add the node itself to the list of - // removed nodes. - update_actions.push(SparseTrieUpdatesAction::InsertRemoved(path)); - } else if branch_node_tree_masks.get(&path).is_none_or(|mask| mask.is_empty()) && - branch_node_hash_masks.get(&path).is_none_or(|mask| mask.is_empty()) - { - // If new tree and hash masks are empty, and they were previously empty - // as well, we need to remove the node update. - update_actions.push(SparseTrieUpdatesAction::RemoveUpdated(path)); + } else { + // New tree and hash masks are empty - check previous state + let prev_had_masks = path_masks() + .is_some_and(|m| !m.tree_mask.is_empty() || !m.hash_mask.is_empty()); + if prev_had_masks { + // Previously had masks, now empty - mark as removed + update_actions.push(SparseTrieUpdatesAction::InsertRemoved(path)); + } else { + // Previously empty too - just remove the update + update_actions.push(SparseTrieUpdatesAction::RemoveUpdated(path)); + } } store_in_db_trie @@ -2667,8 +2658,8 @@ mod tests { prefix_set::PrefixSetMut, proof::{ProofNodes, ProofRetainer}, updates::TrieUpdates, - BranchNode, ExtensionNode, HashBuilder, LeafNode, ProofTrieNode, RlpNode, TrieMask, - TrieMasks, TrieNode, EMPTY_ROOT_HASH, + BranchNode, BranchNodeMasksMap, ExtensionNode, HashBuilder, LeafNode, ProofTrieNode, + RlpNode, TrieMask, TrieMasks, TrieNode, EMPTY_ROOT_HASH, }; use reth_trie_db::DatabaseTrieCursorFactory; use reth_trie_sparse::{ @@ -3608,8 +3599,7 @@ mod tests { &mut PrefixSetMut::from([leaf_1_full_path, leaf_2_full_path, leaf_3_full_path]) .freeze(), &mut None, - &HashMap::default(), - &HashMap::default(), + &BranchNodeMasksMap::default(), ); // Compare hashes between hash builder and subtrie diff --git a/crates/trie/sparse/src/state.rs b/crates/trie/sparse/src/state.rs index 1d68f19cd7..8c5fdd5c1d 100644 --- a/crates/trie/sparse/src/state.rs +++ b/crates/trie/sparse/src/state.rs @@ -613,11 +613,16 @@ where let revealed = self.revealed_trie_mut(provider_factory)?; let (root, updates) = (revealed.root(), revealed.take_updates()); - let updates = TrieUpdates { - account_nodes: updates.updated_nodes, - removed_nodes: updates.removed_nodes, - storage_tries, - }; + // Convert updated_nodes (HashMap) to + // account_nodes (HashMap>) + let mut account_nodes = updates + .updated_nodes + .into_iter() + .map(|(k, v)| (k, Some(v))) + .collect::>(); + // Add removed nodes as None entries + account_nodes.extend(updates.removed_nodes.into_iter().map(|k| (k, None))); + let updates = TrieUpdates { account_nodes, storage_tries }; Ok((root, updates)) } @@ -649,11 +654,16 @@ where let storage_tries = self.storage_trie_updates(); self.state.as_revealed_mut().map(|state| { let updates = state.take_updates(); - TrieUpdates { - account_nodes: updates.updated_nodes, - removed_nodes: updates.removed_nodes, - storage_tries, - } + // Convert updated_nodes (HashMap) to + // account_nodes (HashMap>) + let mut account_nodes = updates + .updated_nodes + .into_iter() + .map(|(k, v)| (k, Some(v))) + .collect::>(); + // Add removed nodes as None entries + account_nodes.extend(updates.removed_nodes.into_iter().map(|k| (k, None))); + TrieUpdates { account_nodes, storage_tries } }) } @@ -1337,7 +1347,6 @@ mod tests { removed_nodes: HashSet::default() } )]), - removed_nodes: HashSet::default() } ); } diff --git a/crates/trie/sparse/src/trie.rs b/crates/trie/sparse/src/trie.rs index acad15bc15..1e0db04bf5 100644 --- a/crates/trie/sparse/src/trie.rs +++ b/crates/trie/sparse/src/trie.rs @@ -19,8 +19,9 @@ use alloy_rlp::Decodable; use reth_execution_errors::{SparseTrieErrorKind, SparseTrieResult}; use reth_trie_common::{ prefix_set::{PrefixSet, PrefixSetMut}, - BranchNodeCompact, BranchNodeRef, ExtensionNodeRef, LeafNodeRef, Nibbles, ProofTrieNode, - RlpNode, TrieMask, TrieMasks, TrieNode, CHILD_INDEX_RANGE, EMPTY_ROOT_HASH, + BranchNodeCompact, BranchNodeMasks, BranchNodeMasksMap, BranchNodeRef, ExtensionNodeRef, + LeafNodeRef, Nibbles, ProofTrieNode, RlpNode, TrieMask, TrieMasks, TrieNode, CHILD_INDEX_RANGE, + EMPTY_ROOT_HASH, }; use smallvec::SmallVec; use tracing::{debug, instrument, trace}; @@ -298,10 +299,12 @@ pub struct SerialSparseTrie { /// Map from a path (nibbles) to its corresponding sparse trie node. /// This contains all of the revealed nodes in trie. nodes: HashMap, - /// When a branch is set, the corresponding child subtree is stored in the database. - branch_node_tree_masks: HashMap, - /// When a bit is set, the corresponding child is stored as a hash in the database. - branch_node_hash_masks: HashMap, + /// Branch node masks containing `tree_mask` and `hash_mask` for each path. + /// - `tree_mask`: When a bit is set, the corresponding child subtree is stored in the + /// database. + /// - `hash_mask`: When a bit is set, the corresponding child is stored as a hash in the + /// database. + branch_node_masks: BranchNodeMasksMap, /// Map from leaf key paths to their values. /// All values are stored here instead of directly in leaf nodes. values: HashMap>, @@ -318,8 +321,7 @@ impl fmt::Debug for SerialSparseTrie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SerialSparseTrie") .field("nodes", &self.nodes) - .field("branch_tree_masks", &self.branch_node_tree_masks) - .field("branch_hash_masks", &self.branch_node_hash_masks) + .field("branch_node_masks", &self.branch_node_masks) .field("values", &self.values) .field("prefix_set", &self.prefix_set) .field("updates", &self.updates) @@ -404,8 +406,7 @@ impl Default for SerialSparseTrie { fn default() -> Self { Self { nodes: HashMap::from_iter([(Nibbles::default(), SparseNode::Empty)]), - branch_node_tree_masks: HashMap::default(), - branch_node_hash_masks: HashMap::default(), + branch_node_masks: BranchNodeMasksMap::default(), values: HashMap::default(), prefix_set: PrefixSetMut::default(), updates: None, @@ -456,11 +457,14 @@ impl SparseTrieInterface for SerialSparseTrie { return Ok(()) } - if let Some(tree_mask) = masks.tree_mask { - self.branch_node_tree_masks.insert(path, tree_mask); - } - if let Some(hash_mask) = masks.hash_mask { - self.branch_node_hash_masks.insert(path, hash_mask); + if masks.tree_mask.is_some() || masks.hash_mask.is_some() { + self.branch_node_masks.insert( + path, + BranchNodeMasks { + tree_mask: masks.tree_mask.unwrap_or_default(), + hash_mask: masks.hash_mask.unwrap_or_default(), + }, + ); } match node { @@ -959,8 +963,7 @@ impl SparseTrieInterface for SerialSparseTrie { self.nodes.clear(); self.nodes.insert(Nibbles::default(), SparseNode::Empty); - self.branch_node_tree_masks.clear(); - self.branch_node_hash_masks.clear(); + self.branch_node_masks.clear(); self.values.clear(); self.prefix_set.clear(); self.updates = None; @@ -1087,8 +1090,7 @@ impl SparseTrieInterface for SerialSparseTrie { fn shrink_nodes_to(&mut self, size: usize) { self.nodes.shrink_to(size); - self.branch_node_tree_masks.shrink_to(size); - self.branch_node_hash_masks.shrink_to(size); + self.branch_node_masks.shrink_to(size); } fn shrink_values_to(&mut self, size: usize) { @@ -1624,6 +1626,13 @@ impl SerialSparseTrie { let mut tree_mask = TrieMask::default(); let mut hash_mask = TrieMask::default(); let mut hashes = Vec::new(); + + // Lazy lookup for branch node masks - shared across loop iterations + let mut path_masks_storage = None; + let mut path_masks = || { + *path_masks_storage.get_or_insert_with(|| self.branch_node_masks.get(&path)) + }; + for (i, child_path) in buffers.branch_child_buf.iter().enumerate() { if buffers.rlp_node_stack.last().is_some_and(|e| &e.path == child_path) { let RlpNodeStackItem { @@ -1647,9 +1656,9 @@ impl SerialSparseTrie { } else { // A blinded node has the tree mask bit set child_node_type.is_hash() && - self.branch_node_tree_masks.get(&path).is_some_and( - |mask| mask.is_bit_set(last_child_nibble), - ) + path_masks().is_some_and(|masks| { + masks.tree_mask.is_bit_set(last_child_nibble) + }) }; if should_set_tree_mask_bit { tree_mask.set_bit(last_child_nibble); @@ -1661,11 +1670,9 @@ impl SerialSparseTrie { let hash = child.as_hash().filter(|_| { child_node_type.is_branch() || (child_node_type.is_hash() && - self.branch_node_hash_masks - .get(&path) - .is_some_and(|mask| { - mask.is_bit_set(last_child_nibble) - })) + path_masks().is_some_and(|masks| { + masks.hash_mask.is_bit_set(last_child_nibble) + })) }); if let Some(hash) = hash { hash_mask.set_bit(last_child_nibble); @@ -1729,30 +1736,16 @@ impl SerialSparseTrie { hash.filter(|_| path.is_empty()), ); updates.updated_nodes.insert(path, branch_node); - } else if self - .branch_node_tree_masks - .get(&path) - .is_some_and(|mask| !mask.is_empty()) || - self.branch_node_hash_masks - .get(&path) - .is_some_and(|mask| !mask.is_empty()) - { - // If new tree and hash masks are empty, but previously they weren't, we - // need to remove the node update and add the node itself to the list of - // removed nodes. - updates.updated_nodes.remove(&path); - updates.removed_nodes.insert(path); - } else if self - .branch_node_tree_masks - .get(&path) - .is_none_or(|mask| mask.is_empty()) && - self.branch_node_hash_masks - .get(&path) - .is_none_or(|mask| mask.is_empty()) - { - // If new tree and hash masks are empty, and they were previously empty - // as well, we need to remove the node update. + } else { + // New tree and hash masks are empty - check previous state + let prev_had_masks = path_masks().is_some_and(|m| { + !m.tree_mask.is_empty() || !m.hash_mask.is_empty() + }); updates.updated_nodes.remove(&path); + if prev_had_masks { + // Previously had masks, now empty - mark as removed + updates.removed_nodes.insert(path); + } } store_in_db_trie @@ -2223,8 +2216,7 @@ mod find_leaf_tests { let sparse = SerialSparseTrie { nodes, - branch_node_tree_masks: Default::default(), - branch_node_hash_masks: Default::default(), + branch_node_masks: Default::default(), /* The value is not in the values map, or else it would early return */ values: Default::default(), prefix_set: Default::default(), @@ -2266,8 +2258,7 @@ mod find_leaf_tests { let sparse = SerialSparseTrie { nodes, - branch_node_tree_masks: Default::default(), - branch_node_hash_masks: Default::default(), + branch_node_masks: Default::default(), values, prefix_set: Default::default(), updates: None, @@ -2386,6 +2377,24 @@ mod tests { nibbles } + /// Extract only updated nodes (Some entries) from consolidated TrieUpdates account_nodes. + fn extract_updated_nodes( + updates: &TrieUpdates, + ) -> HashMap { + updates + .account_nodes + .iter() + .filter_map(|(k, v)| v.as_ref().map(|node| (*k, node.clone()))) + .collect() + } + + /// Extract only updated nodes from account_nodes HashMap, returning a BTreeMap. + fn extract_updated_nodes_btree( + account_nodes: alloy_primitives::map::HashMap>, + ) -> BTreeMap { + account_nodes.into_iter().filter_map(|(k, v)| v.map(|node| (k, node))).collect() + } + /// Calculate the state root by feeding the provided state to the hash builder and retaining the /// proofs for the provided targets. /// @@ -2533,7 +2542,7 @@ mod tests { let sparse_updates = sparse.take_updates(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); } @@ -2566,7 +2575,7 @@ mod tests { let sparse_updates = sparse.take_updates(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); } @@ -2597,7 +2606,7 @@ mod tests { let sparse_updates = sparse.take_updates(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); } @@ -2638,7 +2647,7 @@ mod tests { assert_eq!(sparse_root, hash_builder_root); pretty_assertions::assert_eq!( BTreeMap::from_iter(sparse_updates.updated_nodes), - BTreeMap::from_iter(hash_builder_updates.account_nodes) + BTreeMap::from_iter(extract_updated_nodes(&hash_builder_updates)) ); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); } @@ -2676,7 +2685,7 @@ mod tests { let sparse_updates = sparse.updates_ref(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) = @@ -2694,7 +2703,7 @@ mod tests { let sparse_updates = sparse.take_updates(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes); } @@ -3072,8 +3081,9 @@ mod tests { state.keys().copied(), ); - // Extract account nodes before moving hash_builder_updates - let hash_builder_account_nodes = hash_builder_updates.account_nodes.clone(); + // Extract updated account nodes before moving hash_builder_updates + let hash_builder_account_nodes = + extract_updated_nodes_btree(hash_builder_updates.account_nodes.clone()); // Write trie updates to the database let provider_rw = provider_factory.provider_rw().unwrap(); @@ -3085,7 +3095,7 @@ mod tests { // Assert that the sparse trie updates match the hash builder updates pretty_assertions::assert_eq!( BTreeMap::from_iter(sparse_updates.updated_nodes), - BTreeMap::from_iter(hash_builder_account_nodes) + hash_builder_account_nodes ); // Assert that the sparse trie nodes match the hash builder proof nodes assert_eq_sparse_trie_proof_nodes(&updated_sparse, hash_builder_proof_nodes); @@ -3117,8 +3127,9 @@ mod tests { state.keys().copied(), ); - // Extract account nodes before moving hash_builder_updates - let hash_builder_account_nodes = hash_builder_updates.account_nodes.clone(); + // Extract updated account nodes before moving hash_builder_updates + let hash_builder_account_nodes = + extract_updated_nodes_btree(hash_builder_updates.account_nodes.clone()); // Write trie updates to the database let provider_rw = provider_factory.provider_rw().unwrap(); @@ -3130,7 +3141,7 @@ mod tests { // Assert that the sparse trie updates match the hash builder updates pretty_assertions::assert_eq!( BTreeMap::from_iter(sparse_updates.updated_nodes), - BTreeMap::from_iter(hash_builder_account_nodes) + hash_builder_account_nodes ); // Assert that the sparse trie nodes match the hash builder proof nodes assert_eq_sparse_trie_proof_nodes(&updated_sparse, hash_builder_proof_nodes); @@ -3599,7 +3610,7 @@ mod tests { let sparse_updates = sparse.take_updates(); assert_eq!(sparse_root, hash_builder_root); - assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes); + assert_eq!(sparse_updates.updated_nodes, extract_updated_nodes(&hash_builder_updates)); } #[test] diff --git a/crates/trie/trie/src/node_iter.rs b/crates/trie/trie/src/node_iter.rs index 7d53bd4b6d..518bfdcc49 100644 --- a/crates/trie/trie/src/node_iter.rs +++ b/crates/trie/trie/src/node_iter.rs @@ -376,7 +376,12 @@ mod tests { let mut trie_updates = TrieUpdates::default(); trie_updates.finalize(hash_builder, Default::default(), Default::default()); - trie_updates.account_nodes + // Extract only updated nodes (Some), not removed nodes (None) + trie_updates + .account_nodes + .into_iter() + .filter_map(|(k, v)| v.map(|node| (k, node))) + .collect() } #[test] diff --git a/crates/trie/trie/src/trie.rs b/crates/trie/trie/src/trie.rs index 17cdd1f96c..a89074f0ed 100644 --- a/crates/trie/trie/src/trie.rs +++ b/crates/trie/trie/src/trie.rs @@ -369,9 +369,13 @@ impl StateRootContext { K: AsRef, { let (walker_stack, walker_deleted_keys) = account_node_iter.walker.split(); - self.trie_updates.removed_nodes.extend(walker_deleted_keys); + // Add removed nodes as None entries + self.trie_updates.account_nodes.extend(walker_deleted_keys.into_iter().map(|k| (k, None))); let (hash_builder, hash_builder_updates) = hash_builder.split(); - self.trie_updates.account_nodes.extend(hash_builder_updates); + // Add updated nodes as Some entries + self.trie_updates + .account_nodes + .extend(hash_builder_updates.into_iter().map(|(k, v)| (k, Some(v)))); let account_state = IntermediateRootState { hash_builder, walker_stack, last_hashed_key }; diff --git a/crates/trie/trie/src/trie_cursor/mock.rs b/crates/trie/trie/src/trie_cursor/mock.rs index 5f29a6734b..db26ab3be2 100644 --- a/crates/trie/trie/src/trie_cursor/mock.rs +++ b/crates/trie/trie/src/trie_cursor/mock.rs @@ -40,9 +40,12 @@ impl MockTrieCursorFactory { /// Creates a new mock trie cursor factory from `TrieUpdates`. pub fn from_trie_updates(updates: TrieUpdates) -> Self { - // Convert account nodes from HashMap to BTreeMap - let account_trie_nodes: BTreeMap = - updates.account_nodes.into_iter().collect(); + // Convert account nodes from HashMap to BTreeMap (only updated nodes, not removed) + let account_trie_nodes: BTreeMap = updates + .account_nodes + .into_iter() + .filter_map(|(k, v)| v.map(|node| (k, node))) + .collect(); // Convert storage tries let storage_tries: B256Map> = updates diff --git a/crates/trie/trie/src/verify.rs b/crates/trie/trie/src/verify.rs index 4299a66916..e72f6db909 100644 --- a/crates/trie/trie/src/verify.rs +++ b/crates/trie/trie/src/verify.rs @@ -119,7 +119,10 @@ impl Iterator for StateRootBranchNodesIter { // collect account updates and sort them in descending order, so that when we pop them // off the Vec they are popped in ascending order. - self.account_nodes.extend(updates.account_nodes); + // Only include updated nodes (Some), not removed nodes (None) + self.account_nodes.extend( + updates.account_nodes.into_iter().filter_map(|(k, v)| v.map(|node| (k, node))), + ); Self::sort_updates(&mut self.account_nodes); self.storage_tries = updates