diff --git a/crates/primitives/src/trie/nibbles.rs b/crates/primitives/src/trie/nibbles.rs index 313eefbaf3..65bbf2df5e 100644 --- a/crates/primitives/src/trie/nibbles.rs +++ b/crates/primitives/src/trie/nibbles.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; /// The nibbles are the keys for the AccountsTrie and the subkeys for the StorageTrie. #[main_codec] -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct StoredNibbles { /// The inner nibble bytes pub inner: Bytes, @@ -18,7 +18,7 @@ impl From> for StoredNibbles { } /// The representation of nibbles of the merkle trie stored in the database. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Deref)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Hash, Deref)] pub struct StoredNibblesSubKey(StoredNibbles); impl From> for StoredNibblesSubKey { diff --git a/crates/stages/src/stages/merkle.rs b/crates/stages/src/stages/merkle.rs index 6385d214ec..7e46a8bade 100644 --- a/crates/stages/src/stages/merkle.rs +++ b/crates/stages/src/stages/merkle.rs @@ -195,7 +195,7 @@ impl Stage for MerkleStage { match progress { StateRootProgress::Progress(state, updates) => { updates.flush(tx.deref_mut())?; - self.save_execution_checkpoint(tx, Some(state.into()))?; + self.save_execution_checkpoint(tx, Some((*state).into()))?; return Ok(ExecOutput { stage_progress: input.stage_progress(), done: false }) } StateRootProgress::Complete(root, updates) => { diff --git a/crates/trie/src/hash_builder.rs b/crates/trie/src/hash_builder.rs index 4e2801f492..d99fcf0796 100644 --- a/crates/trie/src/hash_builder.rs +++ b/crates/trie/src/hash_builder.rs @@ -8,7 +8,7 @@ use reth_primitives::{ trie::{BranchNodeCompact, HashBuilderState, HashBuilderValue, TrieMask}, H256, }; -use std::{collections::BTreeMap, fmt::Debug}; +use std::{collections::HashMap, fmt::Debug}; /// A component used to construct the root hash of the trie. The primary purpose of a Hash Builder /// is to build the Merkle proof that is essential for verifying the integrity and authenticity of @@ -45,7 +45,7 @@ pub struct HashBuilder { stored_in_database: bool, - updated_branch_nodes: Option>, + updated_branch_nodes: Option>, } impl From for HashBuilder { @@ -91,12 +91,12 @@ impl HashBuilder { /// Call [HashBuilder::split] to get the updates to branch nodes. pub fn set_updates(&mut self, retain_updates: bool) { if retain_updates { - self.updated_branch_nodes = Some(BTreeMap::default()); + self.updated_branch_nodes = Some(HashMap::default()); } } /// Splits the [HashBuilder] into a [HashBuilder] and hash builder updates. - pub fn split(mut self) -> (Self, BTreeMap) { + pub fn split(mut self) -> (Self, HashMap) { let updates = self.updated_branch_nodes.take(); (self, updates.unwrap_or_default()) } diff --git a/crates/trie/src/nibbles.rs b/crates/trie/src/nibbles.rs index b3699dce68..fd3660ecb7 100644 --- a/crates/trie/src/nibbles.rs +++ b/crates/trie/src/nibbles.rs @@ -20,6 +20,7 @@ use reth_rlp::RlpEncodableWrapper; RlpEncodableWrapper, PartialOrd, Ord, + Hash, Index, From, Deref, diff --git a/crates/trie/src/progress.rs b/crates/trie/src/progress.rs index 41fa61df1e..20175bf9b5 100644 --- a/crates/trie/src/progress.rs +++ b/crates/trie/src/progress.rs @@ -8,7 +8,7 @@ pub enum StateRootProgress { Complete(H256, TrieUpdates), /// The intermediate progress of state root computation. /// Contains the walker stack, the hash builder and the trie updates. - Progress(IntermediateStateRootState, TrieUpdates), + Progress(Box, TrieUpdates), } /// The intermediate state of the state root computation. diff --git a/crates/trie/src/trie.rs b/crates/trie/src/trie.rs index 118105298e..1b80091f2f 100644 --- a/crates/trie/src/trie.rs +++ b/crates/trie/src/trie.rs @@ -256,8 +256,8 @@ impl<'a, 'tx, TX: DbTx<'tx>> StateRoot<'a, TX> { ); let storage_root = if retain_updates { - let (root, mut updates) = storage_root_calculator.root_with_updates()?; - trie_updates.append(&mut updates); + let (root, updates) = storage_root_calculator.root_with_updates()?; + trie_updates.extend(updates.into_iter()); root } else { storage_root_calculator.root()? @@ -273,7 +273,7 @@ impl<'a, 'tx, TX: DbTx<'tx>> StateRoot<'a, TX> { let total_updates_len = trie_updates.len() + walker.updates_len() + hash_builder.updates_len(); if retain_updates && total_updates_len as u64 >= self.threshold { - let (walker_stack, mut walker_updates) = walker.split(); + let (walker_stack, walker_updates) = walker.split(); let (hash_builder, hash_builder_updates) = hash_builder.split(); let state = IntermediateStateRootState { @@ -283,10 +283,10 @@ impl<'a, 'tx, TX: DbTx<'tx>> StateRoot<'a, TX> { last_account_key: hashed_address, }; - trie_updates.append(&mut walker_updates); + trie_updates.extend(walker_updates.into_iter()); trie_updates.extend_with_account_updates(hash_builder_updates); - return Ok(StateRootProgress::Progress(state, trie_updates)) + return Ok(StateRootProgress::Progress(Box::new(state), trie_updates)) } // Move the next account entry @@ -296,10 +296,10 @@ impl<'a, 'tx, TX: DbTx<'tx>> StateRoot<'a, TX> { let root = hash_builder.root(); - let (_, mut walker_updates) = walker.split(); + let (_, walker_updates) = walker.split(); let (_, hash_builder_updates) = hash_builder.split(); - trie_updates.append(&mut walker_updates); + trie_updates.extend(walker_updates.into_iter()); trie_updates.extend_with_account_updates(hash_builder_updates); Ok(StateRootProgress::Complete(root, trie_updates)) @@ -406,10 +406,10 @@ impl<'a, 'tx, TX: DbTx<'tx>> StorageRoot<'a, TX> { let root = hash_builder.root(); let (_, hash_builder_updates) = hash_builder.split(); - let (_, mut walker_updates) = walker.split(); + let (_, walker_updates) = walker.split(); let mut trie_updates = TrieUpdates::default(); - trie_updates.append(&mut walker_updates); + trie_updates.extend(walker_updates.into_iter()); trie_updates.extend_with_storage_updates(self.hashed_address, hash_builder_updates); tracing::debug!(target: "trie::storage_root", ?root, hashed_address = ?self.hashed_address, "calculated storage root"); @@ -662,11 +662,11 @@ mod tests { let threshold = 10; let mut got = None; - let mut intermediate_state = None; + let mut intermediate_state: Option> = None; while got.is_none() { let calculator = StateRoot::new(tx.deref_mut()) .with_threshold(threshold) - .with_intermediate_state(intermediate_state.take()); + .with_intermediate_state(intermediate_state.take().map(|state| *state)); match calculator.root_with_progress().unwrap() { StateRootProgress::Progress(state, _updates) => intermediate_state = Some(state), StateRootProgress::Complete(root, _updates) => got = Some(root), @@ -845,13 +845,14 @@ mod tests { assert_eq!(root, computed_expected_root); // Check account trie - let account_updates = trie_updates + let mut account_updates = trie_updates .iter() .filter_map(|(k, v)| match (k, v) { (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), _ => None, }) .collect::>(); + account_updates.sort_unstable_by(|a, b| a.0.cmp(b.0)); assert_eq!(account_updates.len(), 2); let (nibbles1a, node1a) = account_updates.first().unwrap(); @@ -911,13 +912,14 @@ mod tests { .unwrap(); assert_eq!(root, expected_state_root); - let account_updates = trie_updates + let mut account_updates = trie_updates .iter() .filter_map(|entry| match entry { (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), _ => None, }) .collect::>(); + account_updates.sort_by(|a, b| a.0.cmp(b.0)); assert_eq!(account_updates.len(), 2); let (nibbles1b, node1b) = account_updates.first().unwrap(); @@ -1060,7 +1062,7 @@ mod tests { } _ => None, }) - .collect::>(); + .collect::>(); assert_trie_updates(&account_updates); } @@ -1080,7 +1082,7 @@ mod tests { // read the account updates from the db let mut accounts_trie = tx.cursor_read::().unwrap(); let walker = accounts_trie.walk(None).unwrap(); - let mut account_updates = BTreeMap::new(); + let mut account_updates = HashMap::new(); for item in walker { let (key, node) = item.unwrap(); account_updates.insert(key.inner[..].into(), node); @@ -1148,7 +1150,7 @@ mod tests { } _ => None, }) - .collect::>(); + .collect::>(); assert_eq!(expected_updates, storage_updates); assert_trie_updates(&storage_updates); @@ -1157,7 +1159,7 @@ mod tests { fn extension_node_storage_trie( tx: &mut Transaction<'_, Env>, hashed_address: H256, - ) -> (H256, BTreeMap) { + ) -> (H256, HashMap) { let value = U256::from(1); let mut hashed_storage = tx.cursor_write::().unwrap(); @@ -1208,7 +1210,7 @@ mod tests { hb.root() } - fn assert_trie_updates(account_updates: &BTreeMap) { + fn assert_trie_updates(account_updates: &HashMap) { assert_eq!(account_updates.len(), 2); let node = account_updates.get(&vec![0x3].into()).unwrap(); diff --git a/crates/trie/src/updates.rs b/crates/trie/src/updates.rs index afd635a26e..847fd4d19c 100644 --- a/crates/trie/src/updates.rs +++ b/crates/trie/src/updates.rs @@ -9,10 +9,10 @@ use reth_primitives::{ trie::{BranchNodeCompact, StorageTrieEntry, StoredNibbles, StoredNibblesSubKey}, H256, }; -use std::collections::BTreeMap; +use std::collections::{hash_map::IntoIter, HashMap}; /// The key of a trie node. -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum TrieKey { /// A node in the account trie. AccountNode(StoredNibbles), @@ -41,12 +41,21 @@ impl TrieOp { /// The aggregation of trie updates. #[derive(Debug, Default, Clone, Deref)] pub struct TrieUpdates { - trie_operations: BTreeMap, + trie_operations: HashMap, } impl From<[(TrieKey, TrieOp); N]> for TrieUpdates { fn from(value: [(TrieKey, TrieOp); N]) -> Self { - Self { trie_operations: BTreeMap::from(value) } + Self { trie_operations: HashMap::from(value) } + } +} + +impl IntoIterator for TrieUpdates { + type Item = (TrieKey, TrieOp); + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.trie_operations.into_iter() } } @@ -63,18 +72,13 @@ impl TrieUpdates { } } - /// Append the updates to the current updates. - pub fn append(&mut self, other: &mut Self) { - self.trie_operations.append(&mut other.trie_operations); - } - /// Extend the updates with trie updates. pub fn extend(&mut self, updates: impl Iterator) { self.trie_operations.extend(updates); } /// Extend the updates with account trie updates. - pub fn extend_with_account_updates(&mut self, updates: BTreeMap) { + pub fn extend_with_account_updates(&mut self, updates: HashMap) { self.extend(updates.into_iter().map(|(nibbles, node)| { (TrieKey::AccountNode(nibbles.hex_data.into()), TrieOp::Update(node)) })); @@ -84,7 +88,7 @@ impl TrieUpdates { pub fn extend_with_storage_updates( &mut self, hashed_address: H256, - updates: BTreeMap, + updates: HashMap, ) { self.extend(updates.into_iter().map(|(nibbles, node)| { (TrieKey::StorageNode(hashed_address, nibbles.hex_data.into()), TrieOp::Update(node)) @@ -103,7 +107,9 @@ impl TrieUpdates { let mut account_trie_cursor = tx.cursor_write::()?; let mut storage_trie_cursor = tx.cursor_dup_write::()?; - for (key, operation) in self.trie_operations { + let mut trie_operations = Vec::from_iter(self.trie_operations.into_iter()); + trie_operations.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + for (key, operation) in trie_operations { match key { TrieKey::AccountNode(nibbles) => match operation { TrieOp::Delete => {