diff --git a/bin/reth/src/drop_stage.rs b/bin/reth/src/drop_stage.rs index c3a77ec3e2..9ef004a64d 100644 --- a/bin/reth/src/drop_stage.rs +++ b/bin/reth/src/drop_stage.rs @@ -98,6 +98,7 @@ impl Command { tx.clear::()?; tx.put::(MERKLE_EXECUTION.0.to_string(), 0)?; tx.put::(MERKLE_UNWIND.0.to_string(), 0)?; + tx.delete::(MERKLE_EXECUTION.0.into(), None)?; Ok::<_, eyre::Error>(()) })??; } diff --git a/crates/primitives/src/checkpoints.rs b/crates/primitives/src/checkpoints.rs index 62a90bfa0a..dd0f8e93c5 100644 --- a/crates/primitives/src/checkpoints.rs +++ b/crates/primitives/src/checkpoints.rs @@ -1,18 +1,70 @@ -use crate::{Address, H256}; +use crate::{ + trie::{HashBuilderState, StoredSubNode}, + Address, H256, +}; +use bytes::Buf; use reth_codecs::{main_codec, Compact}; -/// Saves the progress of MerkleStage -#[main_codec] -#[derive(Default, Debug, Copy, Clone, PartialEq)] -pub struct ProofCheckpoint { - /// The next hashed account to insert into the trie. - pub hashed_address: Option, - /// The next storage entry to insert into the trie. - pub storage_key: Option, - /// Current intermediate root for `AccountsTrie`. - pub account_root: Option, - /// Current intermediate storage root from an account. - pub storage_root: Option, +/// Saves the progress of Merkle stage. +#[derive(Default, Debug, Clone, PartialEq)] +pub struct MerkleCheckpoint { + // TODO: target block? + /// The last hashed account key processed. + pub last_account_key: H256, + /// The last walker key processed. + pub last_walker_key: Vec, + /// Previously recorded walker stack. + pub walker_stack: Vec, + /// The hash builder state. 
+ pub state: HashBuilderState, +} + +impl Compact for MerkleCheckpoint { + fn to_compact(self, buf: &mut B) -> usize + where + B: bytes::BufMut + AsMut<[u8]>, + { + let mut len = 0; + + buf.put_slice(self.last_account_key.as_slice()); + len += self.last_account_key.len(); + + buf.put_u16(self.last_walker_key.len() as u16); + buf.put_slice(&self.last_walker_key[..]); + len += 2 + self.last_walker_key.len(); + + buf.put_u16(self.walker_stack.len() as u16); + len += 2; + for item in self.walker_stack.into_iter() { + len += item.to_compact(buf); + } + + len += self.state.to_compact(buf); + len + } + + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) + where + Self: Sized, + { + let last_account_key = H256::from_slice(&buf[..32]); + buf.advance(32); + + let last_walker_key_len = buf.get_u16() as usize; + let last_walker_key = Vec::from(&buf[..last_walker_key_len]); + buf.advance(last_walker_key_len); + + let walker_stack_len = buf.get_u16() as usize; + let mut walker_stack = Vec::with_capacity(walker_stack_len); + for _ in 0..walker_stack_len { + let (item, rest) = StoredSubNode::from_compact(buf, 0); + walker_stack.push(item); + buf = rest; + } + + let (state, buf) = HashBuilderState::from_compact(buf, 0); + (MerkleCheckpoint { last_account_key, last_walker_key, walker_stack, state }, buf) + } } /// Saves the progress of AccountHashing diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index 99b7c8d361..2ee5fa0681 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -48,7 +48,7 @@ pub use chain::{ AllGenesisFormats, Chain, ChainInfo, ChainSpec, ChainSpecBuilder, ForkCondition, GOERLI, MAINNET, SEPOLIA, }; -pub use checkpoints::{AccountHashingCheckpoint, ProofCheckpoint, StorageHashingCheckpoint}; +pub use checkpoints::{AccountHashingCheckpoint, MerkleCheckpoint, StorageHashingCheckpoint}; pub use constants::{ EMPTY_OMMER_ROOT, GOERLI_GENESIS, KECCAK_EMPTY, MAINNET_GENESIS, SEPOLIA_GENESIS, }; diff --git 
a/crates/primitives/src/trie/branch_node.rs b/crates/primitives/src/trie/branch_node.rs index 00f171cd14..b6c30caf17 100644 --- a/crates/primitives/src/trie/branch_node.rs +++ b/crates/primitives/src/trie/branch_node.rs @@ -1,5 +1,6 @@ use super::TrieMask; use crate::H256; +use bytes::Buf; use reth_codecs::Compact; use serde::{Deserialize, Serialize}; @@ -88,7 +89,7 @@ impl Compact for BranchNodeCompact { buf_size } - fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) where Self: Sized, { @@ -98,9 +99,9 @@ impl Compact for BranchNodeCompact { assert_eq!(buf.len() % hash_len, 6); // Consume the masks. - let (state_mask, buf) = TrieMask::from_compact(buf, len); - let (tree_mask, buf) = TrieMask::from_compact(buf, len); - let (hash_mask, buf) = TrieMask::from_compact(buf, len); + let (state_mask, buf) = TrieMask::from_compact(buf, 0); + let (tree_mask, buf) = TrieMask::from_compact(buf, 0); + let (hash_mask, buf) = TrieMask::from_compact(buf, 0); let mut buf = buf; let mut num_hashes = buf.len() / hash_len; @@ -108,18 +109,16 @@ impl Compact for BranchNodeCompact { // Check if the root hash is present if hash_mask.count_ones() as usize + 1 == num_hashes { - let (hash, remaining) = H256::from_compact(buf, hash_len); - root_hash = Some(hash); - buf = remaining; + root_hash = Some(H256::from_slice(&buf[..hash_len])); + buf.advance(hash_len); num_hashes -= 1; } // Consume all remaining hashes. 
let mut hashes = Vec::::with_capacity(num_hashes); for _ in 0..num_hashes { - let (hash, remaining) = H256::from_compact(buf, hash_len); - hashes.push(hash); - buf = remaining; + hashes.push(H256::from_slice(&buf[..hash_len])); + buf.advance(hash_len); } (Self::new(state_mask, tree_mask, hash_mask, hashes, root_hash), buf) diff --git a/crates/primitives/src/trie/hash_builder.rs b/crates/primitives/src/trie/hash_builder.rs new file mode 100644 index 0000000000..6944b2a788 --- /dev/null +++ b/crates/primitives/src/trie/hash_builder.rs @@ -0,0 +1,219 @@ +use super::TrieMask; +use crate::H256; +use bytes::Buf; +use reth_codecs::{derive_arbitrary, Compact}; +use serde::{Deserialize, Serialize}; + +/// The hash builder state for storing in the database. +/// Check the `reth-trie` crate for more info on hash builder. +#[derive_arbitrary(compact)] +#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)] +pub struct HashBuilderState { + /// The current key. + pub key: Vec, + /// The builder stack. + pub stack: Vec>, + /// The current node value. + pub value: HashBuilderValue, + + /// Group masks. + pub groups: Vec, + /// Tree masks. + pub tree_masks: Vec, + /// Hash masks. + pub hash_masks: Vec, + + /// Flag indicating if the current node is stored in the database. 
+ pub stored_in_database: bool, +} + +impl Compact for HashBuilderState { + fn to_compact(self, buf: &mut B) -> usize + where + B: bytes::BufMut + AsMut<[u8]>, + { + let mut len = 0; + + len += self.key.to_compact(buf); + + buf.put_u16(self.stack.len() as u16); + len += 2; + for item in self.stack.iter() { + buf.put_u16(item.len() as u16); + buf.put_slice(&item[..]); + len += 2 + item.len(); + } + + len += self.value.to_compact(buf); + + buf.put_u16(self.groups.len() as u16); + len += 2; + for item in self.groups.iter() { + len += item.to_compact(buf); + } + + buf.put_u16(self.tree_masks.len() as u16); + len += 2; + for item in self.tree_masks.iter() { + len += item.to_compact(buf); + } + + buf.put_u16(self.hash_masks.len() as u16); + len += 2; + for item in self.hash_masks.iter() { + len += item.to_compact(buf); + } + + buf.put_u8(self.stored_in_database as u8); + len += 1; + len + } + + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) + where + Self: Sized, + { + let (key, mut buf) = Vec::from_compact(buf, 0); + + let stack_len = buf.get_u16() as usize; + let mut stack = Vec::with_capacity(stack_len); + for _ in 0..stack_len { + let item_len = buf.get_u16() as usize; + stack.push(Vec::from(&buf[..item_len])); + buf.advance(item_len); + } + + let (value, mut buf) = HashBuilderValue::from_compact(buf, 0); + + let groups_len = buf.get_u16() as usize; + let mut groups = Vec::with_capacity(groups_len); + for _ in 0..groups_len { + let (item, rest) = TrieMask::from_compact(buf, 0); + groups.push(item); + buf = rest; + } + + let tree_masks_len = buf.get_u16() as usize; + let mut tree_masks = Vec::with_capacity(tree_masks_len); + for _ in 0..tree_masks_len { + let (item, rest) = TrieMask::from_compact(buf, 0); + tree_masks.push(item); + buf = rest; + } + + let hash_masks_len = buf.get_u16() as usize; + let mut hash_masks = Vec::with_capacity(hash_masks_len); + for _ in 0..hash_masks_len { + let (item, rest) = TrieMask::from_compact(buf, 0); + 
hash_masks.push(item); + buf = rest; + } + + let stored_in_database = buf.get_u8() != 0; + (Self { key, stack, value, groups, tree_masks, hash_masks, stored_in_database }, buf) + } +} + +/// The current value of the hash builder. +#[derive_arbitrary(compact)] +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub enum HashBuilderValue { + /// Value of the leaf node. + Hash(H256), + /// Hash of adjacent nodes. + Bytes(Vec), +} + +impl Compact for HashBuilderValue { + fn to_compact(self, buf: &mut B) -> usize + where + B: bytes::BufMut + AsMut<[u8]>, + { + match self { + Self::Hash(hash) => { + buf.put_u8(0); + 1 + hash.to_compact(buf) + } + Self::Bytes(bytes) => { + buf.put_u8(1); + 1 + bytes.to_compact(buf) + } + } + } + + fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) + where + Self: Sized, + { + match buf[0] { + 0 => { + let (hash, buf) = H256::from_compact(&buf[1..], 32); + (Self::Hash(hash), buf) + } + 1 => { + let (bytes, buf) = Vec::from_compact(&buf[1..], 0); + (Self::Bytes(bytes), buf) + } + _ => panic!("Invalid hash builder value"), + } + } +} + +impl std::fmt::Debug for HashBuilderValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bytes(bytes) => write!(f, "Bytes({:?})", hex::encode(bytes)), + Self::Hash(hash) => write!(f, "Hash({:?})", hash), + } + } +} + +impl From> for HashBuilderValue { + fn from(value: Vec) -> Self { + Self::Bytes(value) + } +} + +impl From<&[u8]> for HashBuilderValue { + fn from(value: &[u8]) -> Self { + Self::Bytes(value.to_vec()) + } +} + +impl From for HashBuilderValue { + fn from(value: H256) -> Self { + Self::Hash(value) + } +} + +impl Default for HashBuilderValue { + fn default() -> Self { + Self::Bytes(vec![]) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + #[test] + fn hash_builder_state_regression() { + let mut state = HashBuilderState::default(); + state.stack.push(vec![]); + let mut buf = vec![]; + let len = 
state.clone().to_compact(&mut buf); + let (decoded, _) = HashBuilderState::from_compact(&buf, len); + assert_eq!(state, decoded); + } + + proptest! { + #[test] + fn hash_builder_state_roundtrip(state: HashBuilderState) { + let mut buf = vec![]; + let len = state.clone().to_compact(&mut buf); + let (decoded, _) = HashBuilderState::from_compact(&buf, len); + assert_eq!(state, decoded); + } + } +} diff --git a/crates/primitives/src/trie/mask.rs b/crates/primitives/src/trie/mask.rs index 19b6864500..d54f239ad0 100644 --- a/crates/primitives/src/trie/mask.rs +++ b/crates/primitives/src/trie/mask.rs @@ -1,5 +1,6 @@ +use bytes::Buf; use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Deref, From, Not}; -use reth_codecs::Compact; +use reth_codecs::{derive_arbitrary, Compact}; use serde::{Deserialize, Serialize}; /// A struct representing a mask of 16 bits, used for Ethereum trie operations. @@ -26,6 +27,7 @@ use serde::{Deserialize, Serialize}; BitOrAssign, Not, )] +#[derive_arbitrary(compact)] pub struct TrieMask(u16); impl TrieMask { @@ -66,14 +68,15 @@ impl Compact for TrieMask { where B: bytes::BufMut + AsMut<[u8]>, { - buf.put_slice(self.to_be_bytes().as_slice()); + buf.put_u16(self.0); 2 } - fn from_compact(buf: &[u8], _len: usize) -> (Self, &[u8]) + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) where Self: Sized, { - (Self(u16::from_be_bytes(buf[..2].try_into().unwrap())), &buf[2..]) + let mask = buf.get_u16(); + (Self(mask), buf) } } diff --git a/crates/primitives/src/trie/mod.rs b/crates/primitives/src/trie/mod.rs index 771bb94a81..34a92c8824 100644 --- a/crates/primitives/src/trie/mod.rs +++ b/crates/primitives/src/trie/mod.rs @@ -1,13 +1,17 @@ //! Collection of trie related types. 
mod branch_node; +mod hash_builder; mod mask; mod nibbles; mod storage; +mod subnode; pub use self::{ branch_node::BranchNodeCompact, + hash_builder::{HashBuilderState, HashBuilderValue}, mask::TrieMask, nibbles::{StoredNibbles, StoredNibblesSubKey}, storage::StorageTrieEntry, + subnode::StoredSubNode, }; diff --git a/crates/primitives/src/trie/subnode.rs b/crates/primitives/src/trie/subnode.rs new file mode 100644 index 0000000000..5766e4a829 --- /dev/null +++ b/crates/primitives/src/trie/subnode.rs @@ -0,0 +1,98 @@ +use super::BranchNodeCompact; +use bytes::Buf; +use reth_codecs::Compact; + +/// Walker sub node for storing intermediate state root calculation state in the database. +/// See [crate::MerkleCheckpoint]. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct StoredSubNode { + /// The key of the current node. + pub key: Vec, + /// The index of the next child to visit. + pub nibble: Option, + /// The node itself. + pub node: Option, +} + +impl Compact for StoredSubNode { + fn to_compact(self, buf: &mut B) -> usize + where + B: bytes::BufMut + AsMut<[u8]>, + { + let mut len = 0; + + buf.put_u16(self.key.len() as u16); + buf.put_slice(&self.key[..]); + len += 2 + self.key.len(); + + if let Some(nibble) = self.nibble { + buf.put_u8(1); + buf.put_u8(nibble); + len += 2; + } else { + buf.put_u8(0); + len += 1; + } + + if let Some(node) = self.node { + buf.put_u8(1); + len += 1; + len += node.to_compact(buf); + } else { + len += 1; + buf.put_u8(0); + } + + len + } + + fn from_compact(mut buf: &[u8], _len: usize) -> (Self, &[u8]) + where + Self: Sized, + { + let key_len = buf.get_u16() as usize; + let key = Vec::from(&buf[..key_len]); + buf.advance(key_len); + + let nibbles_exists = buf.get_u8() != 0; + let nibble = if nibbles_exists { Some(buf.get_u8()) } else { None }; + + let node_exsists = buf.get_u8() != 0; + let node = if node_exsists { + let (node, rest) = BranchNodeCompact::from_compact(buf, 0); + buf = rest; + Some(node) + } else { + None + }; + + 
(StoredSubNode { key, nibble, node }, buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{trie::TrieMask, H256}; + + #[test] + fn subnode_roundtrip() { + let subnode = StoredSubNode { + key: vec![], + nibble: None, + node: Some(BranchNodeCompact { + state_mask: TrieMask::new(1), + tree_mask: TrieMask::new(0), + hash_mask: TrieMask::new(1), + hashes: vec![H256::zero()], + root_hash: None, + }), + }; + + let mut encoded = vec![]; + subnode.clone().to_compact(&mut encoded); + let (decoded, _) = StoredSubNode::from_compact(&encoded[..], 0); + + assert_eq!(subnode, decoded); + } +} diff --git a/crates/stages/benches/setup/mod.rs b/crates/stages/benches/setup/mod.rs index 9c04e9d42e..650bccf33c 100644 --- a/crates/stages/benches/setup/mod.rs +++ b/crates/stages/benches/setup/mod.rs @@ -121,7 +121,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { tx.insert_accounts_and_storages(start_state.clone()).unwrap(); // make first block after genesis have valid state root - let root = StateRoot::new(tx.inner().deref()).root(None).unwrap(); + let (root, updates) = StateRoot::new(tx.inner().deref()).root_with_updates().unwrap(); let second_block = blocks.get_mut(1).unwrap(); let cloned_second = second_block.clone(); let mut updated_header = cloned_second.header.unseal(); @@ -131,6 +131,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { let offset = transitions.len() as u64; tx.insert_transitions(transitions, None).unwrap(); + tx.commit(|tx| updates.flush(tx)).unwrap(); let (transitions, final_state) = random_transition_range(blocks.iter().skip(2), start_state, n_changes, key_range); @@ -142,7 +143,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf { // make last block have valid state root let root = { let mut tx_mut = tx.inner(); - let root = StateRoot::new(tx.inner().deref()).root(None).unwrap(); + let root = StateRoot::new(tx.inner().deref()).root().unwrap(); tx_mut.commit().unwrap(); root }; diff --git 
a/crates/stages/src/stages/merkle.rs b/crates/stages/src/stages/merkle.rs index 575a055300..6385d214ec 100644 --- a/crates/stages/src/stages/merkle.rs +++ b/crates/stages/src/stages/merkle.rs @@ -1,9 +1,14 @@ use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; -use reth_db::{database::Database, tables, transaction::DbTxMut}; +use reth_codecs::Compact; +use reth_db::{ + database::Database, + tables, + transaction::{DbTx, DbTxMut}, +}; use reth_interfaces::consensus; -use reth_primitives::{BlockNumber, H256}; +use reth_primitives::{hex, BlockNumber, MerkleCheckpoint, H256}; use reth_provider::Transaction; -use reth_trie::StateRoot; +use reth_trie::{IntermediateStateRootState, StateRoot, StateRootProgress}; use std::{fmt::Debug, ops::DerefMut}; use tracing::*; @@ -82,6 +87,42 @@ impl MerkleStage { }) } } + + /// Gets the hashing progress + pub fn get_execution_checkpoint( + &self, + tx: &Transaction<'_, DB>, + ) -> Result, StageError> { + let buf = + tx.get::(MERKLE_EXECUTION.0.into())?.unwrap_or_default(); + + if buf.is_empty() { + return Ok(None) + } + + let (checkpoint, _) = MerkleCheckpoint::from_compact(&buf, buf.len()); + Ok(Some(checkpoint)) + } + + /// Saves the hashing progress + pub fn save_execution_checkpoint( + &mut self, + tx: &Transaction<'_, DB>, + checkpoint: Option, + ) -> Result<(), StageError> { + let mut buf = vec![]; + if let Some(checkpoint) = checkpoint { + debug!( + target: "sync::stages::merkle::exec", + last_account_key = ?checkpoint.last_account_key, + last_walker_key = ?hex::encode(&checkpoint.last_walker_key), + "Saving inner merkle checkpoint" + ); + checkpoint.to_compact(&mut buf); + } + tx.put::(MERKLE_EXECUTION.0.into(), buf)?; + Ok(()) + } } #[async_trait::async_trait] @@ -121,21 +162,58 @@ impl Stage for MerkleStage { let block_root = tx.get_header(current_blook)?.state_root; + let checkpoint = self.get_execution_checkpoint(tx)?; + let trie_root = if range.is_empty() { block_root } else if 
to_block - from_block > threshold || from_block == 1 { // if there are more blocks than threshold it is faster to rebuild the trie - debug!(target: "sync::stages::merkle::exec", current = ?current_blook, target = ?to_block, "Rebuilding trie"); - tx.clear::()?; - tx.clear::()?; - StateRoot::new(tx.deref_mut()).root(None).map_err(|e| StageError::Fatal(Box::new(e)))? + if let Some(checkpoint) = &checkpoint { + debug!( + target: "sync::stages::merkle::exec", + current = ?current_blook, + target = ?to_block, + last_account_key = ?checkpoint.last_account_key, + last_walker_key = ?hex::encode(&checkpoint.last_walker_key), + "Continuing inner merkle checkpoint" + ); + } else { + debug!( + target: "sync::stages::merkle::exec", + current = ?current_blook, + target = ?to_block, + "Rebuilding trie" + ); + tx.clear::()?; + tx.clear::()?; + } + + let progress = StateRoot::new(tx.deref_mut()) + .with_intermediate_state(checkpoint.map(IntermediateStateRootState::from)) + .root_with_progress() + .map_err(|e| StageError::Fatal(Box::new(e)))?; + match progress { + StateRootProgress::Progress(state, updates) => { + updates.flush(tx.deref_mut())?; + self.save_execution_checkpoint(tx, Some(state.into()))?; + return Ok(ExecOutput { stage_progress: input.stage_progress(), done: false }) + } + StateRootProgress::Complete(root, updates) => { + updates.flush(tx.deref_mut())?; + root + } + } } else { - debug!(target: "sync::stages::merkle::exec", current = ?current_blook, target = - ?to_block, "Updating trie"); // Iterate over - StateRoot::incremental_root(tx.deref_mut(), range, None) - .map_err(|e| StageError::Fatal(Box::new(e)))? 
+ debug!(target: "sync::stages::merkle::exec", current = ?current_blook, target = ?to_block, "Updating trie"); + let (root, updates) = StateRoot::incremental_root_with_updates(tx.deref_mut(), range) + .map_err(|e| StageError::Fatal(Box::new(e)))?; + updates.flush(tx.deref_mut())?; + root }; + // Reset the checkpoint + self.save_execution_checkpoint(tx, None)?; + self.validate_state_root(trie_root, block_root, to_block)?; info!(target: "sync::stages::merkle::exec", "Stage finished"); @@ -162,10 +240,16 @@ impl Stage for MerkleStage { // Unwind trie only if there are transitions if !range.is_empty() { - let block_root = StateRoot::incremental_root(tx.deref_mut(), range, None) - .map_err(|e| StageError::Fatal(Box::new(e)))?; + let (block_root, updates) = + StateRoot::incremental_root_with_updates(tx.deref_mut(), range) + .map_err(|e| StageError::Fatal(Box::new(e)))?; + + // Validate the calculated state root let target_root = tx.get_header(input.unwind_to)?.state_root; self.validate_state_root(block_root, target_root, input.unwind_to)?; + + // Validation passed, apply unwind changes to the database. + updates.flush(tx.deref_mut())?; } else { info!(target: "sync::stages::merkle::unwind", "Nothing to unwind"); } diff --git a/crates/storage/codecs/src/lib.rs b/crates/storage/codecs/src/lib.rs index b0c36f3ad5..b3d54a0fa5 100644 --- a/crates/storage/codecs/src/lib.rs +++ b/crates/storage/codecs/src/lib.rs @@ -78,7 +78,7 @@ macro_rules! 
impl_uint_compact { }; } -impl_uint_compact!(u64, u128); +impl_uint_compact!(u8, u64, u128); impl Compact for Vec where diff --git a/crates/storage/db/src/tables/codecs/compact.rs b/crates/storage/db/src/tables/codecs/compact.rs index b34ada2835..de9eeb447b 100644 --- a/crates/storage/db/src/tables/codecs/compact.rs +++ b/crates/storage/db/src/tables/codecs/compact.rs @@ -45,8 +45,7 @@ impl_compression_for_compact!( StoredBlockBodyIndices, StoredBlockOmmers, StoredBlockWithdrawals, - Bytecode, - ProofCheckpoint + Bytecode ); impl_compression_for_compact!(AccountBeforeTx, TransactionSignedNoHash); impl_compression_for_compact!(CompactU256); diff --git a/crates/storage/provider/src/transaction.rs b/crates/storage/provider/src/transaction.rs index a331975fdc..7047ecbc3b 100644 --- a/crates/storage/provider/src/transaction.rs +++ b/crates/storage/provider/src/transaction.rs @@ -581,7 +581,8 @@ where // merkle tree { - let state_root = StateRoot::incremental_root(self.deref_mut(), range.clone(), None)?; + let (state_root, trie_updates) = + StateRoot::incremental_root_with_updates(self.deref_mut(), range.clone())?; if state_root != expected_state_root { return Err(TransactionError::StateTrieRootMismatch { got: state_root, @@ -590,6 +591,7 @@ where block_hash: end_block_hash, }) } + trie_updates.flush(self.deref_mut())?; } Ok(()) } @@ -983,7 +985,8 @@ where self.unwind_storage_history_indices(storage_range)?; // merkle tree - let new_state_root = StateRoot::incremental_root(self.deref(), range.clone(), None)?; + let (new_state_root, trie_updates) = + StateRoot::incremental_root_with_updates(self.deref(), range.clone())?; let parent_number = range.start().saturating_sub(1); let parent_state_root = self.get_header(parent_number)?.state_root; @@ -999,6 +1002,7 @@ where block_hash: parent_hash, }) } + trie_updates.flush(self.deref())?; } // get blocks let blocks = self.get_take_block_range::(chain_spec, range.clone())?; diff --git a/crates/trie/src/cursor/account_cursor.rs 
b/crates/trie/src/cursor/account_cursor.rs index aeb2a90458..82639d0756 100644 --- a/crates/trie/src/cursor/account_cursor.rs +++ b/crates/trie/src/cursor/account_cursor.rs @@ -1,8 +1,6 @@ use super::TrieCursor; -use reth_db::{ - cursor::{DbCursorRO, DbCursorRW}, - tables, Error, -}; +use crate::updates::TrieKey; +use reth_db::{cursor::DbCursorRO, tables, Error}; use reth_primitives::trie::{BranchNodeCompact, StoredNibbles}; /// A cursor over the account trie. @@ -17,7 +15,7 @@ impl AccountTrieCursor { impl<'a, C> TrieCursor for AccountTrieCursor where - C: DbCursorRO<'a, tables::AccountsTrie> + DbCursorRW<'a, tables::AccountsTrie>, + C: DbCursorRO<'a, tables::AccountsTrie>, { fn seek_exact( &mut self, @@ -30,12 +28,8 @@ where Ok(self.0.seek(key)?.map(|value| (value.0.inner.to_vec(), value.1))) } - fn upsert(&mut self, key: StoredNibbles, value: BranchNodeCompact) -> Result<(), Error> { - self.0.upsert(key, value) - } - - fn delete_current(&mut self) -> Result<(), Error> { - self.0.delete_current() + fn current(&mut self) -> Result, Error> { + Ok(self.0.current()?.map(|(k, _)| TrieKey::AccountNode(k))) } } diff --git a/crates/trie/src/cursor/storage_cursor.rs b/crates/trie/src/cursor/storage_cursor.rs index 03e9190f5b..4cabb8d9d3 100644 --- a/crates/trie/src/cursor/storage_cursor.rs +++ b/crates/trie/src/cursor/storage_cursor.rs @@ -1,10 +1,11 @@ use super::TrieCursor; +use crate::updates::TrieKey; use reth_db::{ - cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW}, + cursor::{DbCursorRO, DbDupCursorRO}, tables, Error, }; use reth_primitives::{ - trie::{BranchNodeCompact, StorageTrieEntry, StoredNibblesSubKey}, + trie::{BranchNodeCompact, StoredNibblesSubKey}, H256, }; @@ -24,10 +25,7 @@ impl StorageTrieCursor { impl<'a, C> TrieCursor for StorageTrieCursor where - C: DbDupCursorRO<'a, tables::StoragesTrie> - + DbDupCursorRW<'a, tables::StoragesTrie> - + DbCursorRO<'a, tables::StoragesTrie> - + DbCursorRW<'a, tables::StoragesTrie>, + C: DbDupCursorRO<'a, 
tables::StoragesTrie> + DbCursorRO<'a, tables::StoragesTrie>, { fn seek_exact( &mut self, @@ -50,28 +48,18 @@ where .map(|value| (value.nibbles.inner.to_vec(), value.node))) } - fn upsert(&mut self, key: StoredNibblesSubKey, value: BranchNodeCompact) -> Result<(), Error> { - if let Some(entry) = self.cursor.seek_by_key_subkey(self.hashed_address, key.clone())? { - // "seek exact" - if entry.nibbles == key { - self.cursor.delete_current()?; - } - } - - self.cursor.upsert(self.hashed_address, StorageTrieEntry { nibbles: key, node: value })?; - Ok(()) - } - - fn delete_current(&mut self) -> Result<(), Error> { - self.cursor.delete_current() + fn current(&mut self) -> Result, Error> { + Ok(self.cursor.current()?.map(|(k, v)| TrieKey::StorageNode(k, v.nibbles))) } } #[cfg(test)] mod tests { use super::*; - use reth_db::{mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut}; - use reth_primitives::trie::BranchNodeCompact; + use reth_db::{ + cursor::DbCursorRW, mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut, + }; + use reth_primitives::trie::{BranchNodeCompact, StorageTrieEntry}; use reth_provider::Transaction; // tests that upsert and seek match on the storagetrie cursor @@ -79,14 +67,20 @@ mod tests { fn test_storage_cursor_abstraction() { let db = create_test_rw_db(); let tx = Transaction::new(db.as_ref()).unwrap(); - let cursor = tx.cursor_dup_write::().unwrap(); - - let mut cursor = StorageTrieCursor::new(cursor, H256::random()); + let mut cursor = tx.cursor_dup_write::().unwrap(); + let hashed_address = H256::random(); let key = vec![0x2, 0x3]; let value = BranchNodeCompact::new(1, 1, 1, vec![H256::random()], None); - cursor.upsert(key.clone().into(), value.clone()).unwrap(); - assert_eq!(cursor.seek(key.into()).unwrap().unwrap().1, value); + cursor + .upsert( + hashed_address, + StorageTrieEntry { nibbles: key.clone().into(), node: value.clone() }, + ) + .unwrap(); + + let mut cursor = StorageTrieCursor::new(cursor, hashed_address); 
+ assert_eq!(cursor.seek(key.clone().into()).unwrap().unwrap().1, value); } } diff --git a/crates/trie/src/cursor/subnode.rs b/crates/trie/src/cursor/subnode.rs index cc4f292ae6..31a404a677 100644 --- a/crates/trie/src/cursor/subnode.rs +++ b/crates/trie/src/cursor/subnode.rs @@ -1,5 +1,8 @@ use crate::{nodes::CHILD_INDEX_RANGE, Nibbles}; -use reth_primitives::{trie::BranchNodeCompact, H256}; +use reth_primitives::{ + trie::{BranchNodeCompact, StoredSubNode}, + H256, +}; /// Cursor for iterating over a subtrie. #[derive(Clone)] @@ -31,6 +34,23 @@ impl std::fmt::Debug for CursorSubNode { } } +impl From for CursorSubNode { + fn from(value: StoredSubNode) -> Self { + let nibble = match value.nibble { + Some(n) => n as i8, + None => -1, + }; + Self { key: Nibbles::from(value.key), nibble, node: value.node } + } +} + +impl From for StoredSubNode { + fn from(value: CursorSubNode) -> Self { + let nibble = if value.nibble >= 0 { Some(value.nibble as u8) } else { None }; + Self { key: value.key.hex_data, nibble, node: value.node } + } +} + impl CursorSubNode { /// Creates a new `CursorSubNode` from a key and an optional node. pub fn new(key: Nibbles, node: Option) -> Self { diff --git a/crates/trie/src/cursor/trie_cursor.rs b/crates/trie/src/cursor/trie_cursor.rs index 97f2aa9b51..4e275a2fc9 100644 --- a/crates/trie/src/cursor/trie_cursor.rs +++ b/crates/trie/src/cursor/trie_cursor.rs @@ -1,3 +1,4 @@ +use crate::updates::TrieKey; use reth_db::{table::Key, Error}; use reth_primitives::trie::BranchNodeCompact; @@ -9,9 +10,6 @@ pub trait TrieCursor { /// Move the cursor to the key and return a value matching of greater than the key. fn seek(&mut self, key: K) -> Result, BranchNodeCompact)>, Error>; - /// Upsert the key/value pair. - fn upsert(&mut self, key: K, value: BranchNodeCompact) -> Result<(), Error>; - - /// Delete the key/value pair at the current cursor position. - fn delete_current(&mut self) -> Result<(), Error>; + /// Get the current entry. 
+ fn current(&mut self) -> Result, Error>; } diff --git a/crates/trie/src/hash_builder/mod.rs b/crates/trie/src/hash_builder.rs similarity index 90% rename from crates/trie/src/hash_builder/mod.rs rename to crates/trie/src/hash_builder.rs index 4de0ab240d..ab615cc7c2 100644 --- a/crates/trie/src/hash_builder/mod.rs +++ b/crates/trie/src/hash_builder.rs @@ -5,13 +5,10 @@ use crate::{ use reth_primitives::{ keccak256, proofs::EMPTY_ROOT, - trie::{BranchNodeCompact, TrieMask}, + trie::{BranchNodeCompact, HashBuilderState, HashBuilderValue, TrieMask}, H256, }; -use std::{fmt::Debug, sync::mpsc}; - -mod value; -use value::HashBuilderValue; +use std::{collections::BTreeMap, fmt::Debug, sync::mpsc}; /// A type alias for a sender of branch nodes. /// Branch nodes are sent by the Hash Builder to be stored in the database. @@ -40,7 +37,7 @@ pub type BranchNodeSender = mpsc::Sender<(Nibbles, BranchNodeCompact)>; /// up, combining the hashes of child nodes and ultimately generating the root hash. The root hash /// can then be used to verify the integrity and authenticity of the trie's data by constructing and /// verifying Merkle proofs. 
-#[derive(Clone, Debug, Default)] +#[derive(Debug, Default)] pub struct HashBuilder { key: Nibbles, stack: Vec>, @@ -52,19 +49,66 @@ pub struct HashBuilder { stored_in_database: bool, - branch_node_sender: Option, + updated_branch_nodes: Option>, +} + +impl From for HashBuilder { + fn from(state: HashBuilderState) -> Self { + Self { + key: Nibbles::from(state.key), + stack: state.stack, + value: state.value, + groups: state.groups, + tree_masks: state.tree_masks, + hash_masks: state.hash_masks, + stored_in_database: state.stored_in_database, + updated_branch_nodes: None, + } + } +} + +impl From for HashBuilderState { + fn from(state: HashBuilder) -> Self { + Self { + key: state.key.hex_data, + stack: state.stack, + value: state.value, + groups: state.groups, + tree_masks: state.tree_masks, + hash_masks: state.hash_masks, + stored_in_database: state.stored_in_database, + } + } } impl HashBuilder { - /// Creates a new instance of the Hash Builder. - pub fn new(store_tx: Option) -> Self { - Self { branch_node_sender: store_tx, ..Default::default() } + /// Enables the Hash Builder to store updated branch nodes. + /// + /// Call [HashBuilder::split] to get the updates to branch nodes. + pub fn with_updates(mut self, retain_updates: bool) -> Self { + self.set_updates(retain_updates); + self } - /// Set a branch node sender on the Hash Builder instance. - pub fn with_branch_node_sender(mut self, tx: BranchNodeSender) -> Self { - self.branch_node_sender = Some(tx); - self + /// Enables the Hash Builder to store updated branch nodes. + /// + /// Call [HashBuilder::split] to get the updates to branch nodes. + pub fn set_updates(&mut self, retain_updates: bool) { + if retain_updates { + self.updated_branch_nodes = Some(BTreeMap::default()); + } + } + + /// Splits the [HashBuilder] into a [HashBuilder] and hash builder updates. 
+ pub fn split(mut self) -> (Self, BTreeMap) { + let updates = self.updated_branch_nodes.take(); + (self, updates.unwrap_or_default()) + } + + /// The number of total updates accrued. + /// Returns `0` if [Self::with_updates] was not called. + pub fn updates_len(&self) -> usize { + self.updated_branch_nodes.as_ref().map(|u| u.len()).unwrap_or(0) } /// Print the current stack of the Hash Builder. @@ -326,8 +370,8 @@ impl HashBuilder { // other side of the HashBuilder tracing::debug!(target: "trie::hash_builder", node = ?n, "intermediate node"); let common_prefix = current.slice(0, len); - if let Some(tx) = &self.branch_node_sender { - let _ = tx.send((common_prefix, n)); + if let Some(nodes) = self.updated_branch_nodes.as_mut() { + nodes.insert(common_prefix, n); } } } @@ -429,8 +473,7 @@ mod tests { #[test] fn test_generates_branch_node() { - let (sender, recv) = mpsc::channel(); - let mut hb = HashBuilder::new(Some(sender)); + let mut hb = HashBuilder::default().with_updates(true); // We have 1 branch node update to be stored at 0x01, indicated by the first nibble. // That branch root node has 2 branch node children present at 0x1 and 0x2. 
@@ -477,11 +520,9 @@ mod tests { hb.add_leaf(nibbles, val.as_ref()); }); let root = hb.root(); - drop(hb); - let updates = recv.iter().collect::>(); + let (_, updates) = hb.split(); - let updates = updates.iter().cloned().collect::>(); let update = updates.get(&Nibbles::from(hex!("01").as_slice())).unwrap(); assert_eq!(update.state_mask, TrieMask::new(0b1111)); // 1st nibble: 0, 1, 2, 3 assert_eq!(update.tree_mask, TrieMask::new(0)); diff --git a/crates/trie/src/hash_builder/value.rs b/crates/trie/src/hash_builder/value.rs deleted file mode 100644 index 71acfdf133..0000000000 --- a/crates/trie/src/hash_builder/value.rs +++ /dev/null @@ -1,40 +0,0 @@ -use reth_primitives::H256; - -#[derive(Clone)] -pub(crate) enum HashBuilderValue { - Bytes(Vec), - Hash(H256), -} - -impl std::fmt::Debug for HashBuilderValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Bytes(bytes) => write!(f, "Bytes({:?})", hex::encode(bytes)), - Self::Hash(hash) => write!(f, "Hash({:?})", hash), - } - } -} - -impl From> for HashBuilderValue { - fn from(value: Vec) -> Self { - Self::Bytes(value) - } -} - -impl From<&[u8]> for HashBuilderValue { - fn from(value: &[u8]) -> Self { - Self::Bytes(value.to_vec()) - } -} - -impl From for HashBuilderValue { - fn from(value: H256) -> Self { - Self::Hash(value) - } -} - -impl Default for HashBuilderValue { - fn default() -> Self { - Self::Bytes(vec![]) - } -} diff --git a/crates/trie/src/lib.rs b/crates/trie/src/lib.rs index 3380444f6c..2a5e570f60 100644 --- a/crates/trie/src/lib.rs +++ b/crates/trie/src/lib.rs @@ -36,7 +36,14 @@ pub use errors::{StateRootError, StorageRootError}; /// The implementation of the Merkle Patricia Trie. mod trie; -pub use trie::{BranchNodeUpdate, BranchNodeUpdateSender, StateRoot, StorageRoot}; +pub use trie::{StateRoot, StorageRoot}; + +/// Buffer for trie updates. +pub mod updates; + +/// Utilities for state root checkpoint progress. 
+mod progress; +pub use progress::{IntermediateStateRootState, StateRootProgress}; /// Collection of trie-related test utilities. #[cfg(any(test, feature = "test-utils"))] diff --git a/crates/trie/src/prefix_set/loader.rs b/crates/trie/src/prefix_set/loader.rs index af368a617c..33a8569aa3 100644 --- a/crates/trie/src/prefix_set/loader.rs +++ b/crates/trie/src/prefix_set/loader.rs @@ -11,7 +11,7 @@ use reth_db::{ use reth_primitives::{keccak256, BlockNumber, StorageEntry, H256}; use std::{collections::HashMap, ops::RangeInclusive}; -/// A wrapper around a database transaction that loads prefix sets within a given transition range. +/// A wrapper around a database transaction that loads prefix sets within a given block range. #[derive(Deref)] pub struct PrefixSetLoader<'a, TX>(&'a TX); @@ -26,7 +26,7 @@ impl<'a, 'b, TX> PrefixSetLoader<'a, TX> where TX: DbTx<'b>, { - /// Load all account and storage changes for the given transition id range. + /// Load all account and storage changes for the given block range. pub fn load( self, range: RangeInclusive, diff --git a/crates/trie/src/progress.rs b/crates/trie/src/progress.rs new file mode 100644 index 0000000000..5fadeeeda2 --- /dev/null +++ b/crates/trie/src/progress.rs @@ -0,0 +1,47 @@ +use crate::{cursor::CursorSubNode, hash_builder::HashBuilder, updates::TrieUpdates, Nibbles}; +use reth_primitives::{trie::StoredSubNode, MerkleCheckpoint, H256}; + +/// The progress of the state root computation. +#[derive(Debug)] +pub enum StateRootProgress { + /// The complete state root computation with updates and computed root. + Complete(H256, TrieUpdates), + /// The intermediate progress of state root computation. + /// Contains the walker stack, the hash builder and the trie updates. + Progress(IntermediateStateRootState, TrieUpdates), +} + +/// The intermediate state of the state root computation. +#[derive(Debug)] +pub struct IntermediateStateRootState { + /// Previously constructed hash builder. 
+ pub hash_builder: HashBuilder, + /// Previously recorded walker stack. + pub walker_stack: Vec, + /// The last hashed account key processed. + pub last_account_key: H256, + /// The last walker key processed. + pub last_walker_key: Nibbles, +} + +impl From for MerkleCheckpoint { + fn from(value: IntermediateStateRootState) -> Self { + Self { + last_account_key: value.last_account_key, + last_walker_key: value.last_walker_key.hex_data, + walker_stack: value.walker_stack.into_iter().map(StoredSubNode::from).collect(), + state: value.hash_builder.into(), + } + } +} + +impl From for IntermediateStateRootState { + fn from(value: MerkleCheckpoint) -> Self { + Self { + hash_builder: HashBuilder::from(value.state), + walker_stack: value.walker_stack.into_iter().map(CursorSubNode::from).collect(), + last_account_key: value.last_account_key, + last_walker_key: Nibbles::from(value.last_walker_key), + } + } +} diff --git a/crates/trie/src/trie.rs b/crates/trie/src/trie.rs index ea868938f9..db62337d12 100644 --- a/crates/trie/src/trie.rs +++ b/crates/trie/src/trie.rs @@ -4,34 +4,19 @@ use crate::{ hash_builder::HashBuilder, nibbles::Nibbles, prefix_set::{PrefixSet, PrefixSetLoader}, + progress::{IntermediateStateRootState, StateRootProgress}, + updates::{TrieKey, TrieOp, TrieUpdates}, walker::TrieWalker, StateRootError, StorageRootError, }; use reth_db::{ - cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW}, + cursor::{DbCursorRO, DbDupCursorRO}, tables, - transaction::{DbTx, DbTxMut}, -}; -use reth_primitives::{ - keccak256, - proofs::EMPTY_ROOT, - trie::{BranchNodeCompact, StorageTrieEntry, StoredNibblesSubKey}, - Address, BlockNumber, StorageEntry, H256, + transaction::DbTx, }; +use reth_primitives::{keccak256, proofs::EMPTY_ROOT, Address, BlockNumber, StorageEntry, H256}; use reth_rlp::Encodable; -use std::{collections::HashMap, ops::RangeInclusive, sync::mpsc}; - -/// The branch node update sender -pub type BranchNodeUpdateSender = mpsc::Sender; - -/// The 
branch node message to update the database. -#[derive(Debug, Clone)] -pub enum BranchNodeUpdate { - /// The account trie branch node. - Account(Nibbles, BranchNodeCompact), - /// The storage trie branch node with the hashed key of the account. - Storage(H256, Nibbles, BranchNodeCompact), -} +use std::{collections::HashMap, ops::RangeInclusive}; /// StateRoot is used to compute the root node of a state trie. pub struct StateRoot<'a, TX> { @@ -42,6 +27,10 @@ pub struct StateRoot<'a, TX> { /// A map containing storage changes with the hashed address as key and a set of storage key /// prefixes as the value. pub changed_storage_prefixes: HashMap, + /// Previous intermediate state. + previous_state: Option, + /// The number of updates after which the intermediate progress should be returned. + threshold: u64, } impl<'a, TX> StateRoot<'a, TX> { @@ -51,6 +40,8 @@ impl<'a, TX> StateRoot<'a, TX> { tx, changed_account_prefixes: PrefixSet::default(), changed_storage_prefixes: HashMap::default(), + previous_state: None, + threshold: 100_000, } } @@ -65,30 +56,100 @@ impl<'a, TX> StateRoot<'a, TX> { self.changed_storage_prefixes = prefixes; self } + + /// Set the threshold. + pub fn with_threshold(mut self, threshold: u64) -> Self { + self.threshold = threshold; + self + } + + /// Set the threshold to maximum value so that intermediate progress is not returned. + pub fn with_no_threshold(mut self) -> Self { + self.threshold = u64::MAX; + self + } + + /// Set the previously recorded intermediate state. + pub fn with_intermediate_state(mut self, state: Option) -> Self { + self.previous_state = state; + self + } } -impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StateRoot<'a, TX> { - /// Given a transition id range, identifies all the accounts and storage keys that - /// have changed. Calculates the new state root using existing unchanged intermediate nodes and - /// updating the nodes that are present in the prefix set. 
+impl<'a, 'tx, TX: DbTx<'tx>> StateRoot<'a, TX> { + /// Given a block number range, identifies all the accounts and storage keys that + /// have changed. /// /// # Returns /// - /// The updated state root hash. + /// An instance of state root calculator with account and storage prefixes loaded. + pub fn incremental_root_calculator( + tx: &'a TX, + range: RangeInclusive, + ) -> Result { + let (account_prefixes, storage_prefixes) = PrefixSetLoader::new(tx).load(range)?; + Ok(Self::new(tx) + .with_changed_account_prefixes(account_prefixes) + .with_changed_storage_prefixes(storage_prefixes)) + } + + /// Computes the state root of the trie with the changed account and storage prefixes and + /// existing trie nodes. + /// + /// # Returns + /// + /// The updated state root. pub fn incremental_root( tx: &'a TX, range: RangeInclusive, - branch_node_sender: Option, ) -> Result { tracing::debug!(target: "loader", "incremental state root"); - let (account_prefixes, storage_prefixes) = PrefixSetLoader::new(tx).load(range)?; - let this = Self::new(tx) - .with_changed_account_prefixes(account_prefixes) - .with_changed_storage_prefixes(storage_prefixes); + Self::incremental_root_calculator(tx, range)?.root() + } - let root = this.root(branch_node_sender)?; + /// Computes the state root of the trie with the changed account and storage prefixes and + /// existing trie nodes collecting updates in the process. + /// + /// Ignores the threshold. + /// + /// # Returns + /// + /// The updated state root and the trie updates. + pub fn incremental_root_with_updates( + tx: &'a TX, + range: RangeInclusive, + ) -> Result<(H256, TrieUpdates), StateRootError> { + tracing::debug!(target: "loader", "incremental state root"); + Self::incremental_root_calculator(tx, range)?.root_with_updates() + } - Ok(root) + /// Computes the state root of the trie with the changed account and storage prefixes and + /// existing trie nodes collecting updates in the process. 
+ /// + /// # Returns + /// + /// The intermediate progress of state root computation. + pub fn incremental_root_with_progress( + tx: &'a TX, + range: RangeInclusive, + ) -> Result { + tracing::debug!(target: "loader", "incremental state root with progress"); + Self::incremental_root_calculator(tx, range)?.root_with_progress() + } + + /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the + /// nodes into the hash builder. Collects the updates in the process. + /// + /// Ignores the threshold. + /// + /// # Returns + /// + /// The state root and the trie updates. + pub fn root_with_updates(self) -> Result<(H256, TrieUpdates), StateRootError> { + match self.with_no_threshold().calculate(true)? { + StateRootProgress::Complete(root, updates) => Ok((root, updates)), + StateRootProgress::Progress(..) => unreachable!(), // unreachable threshold + } } /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the @@ -97,43 +158,78 @@ impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StateRoot<'a, TX> { /// # Returns /// /// The state root hash. - pub fn root( - &self, - branch_node_sender: Option, - ) -> Result { - tracing::debug!(target: "loader", "calculating state root"); + pub fn root(self) -> Result { + match self.calculate(false)? { + StateRootProgress::Complete(root, _) => Ok(root), + StateRootProgress::Progress(..) => unreachable!(), // update retention is disabled + } + } - let (sender, maybe_receiver) = match branch_node_sender { - Some(sender) => (sender, None), - None => { - let (sender, recv) = mpsc::channel(); - (sender, Some(recv)) - } - }; + /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the + /// nodes into the hash builder. Collects the updates in the process. + /// + /// # Returns + /// + /// The intermediate progress of state root computation. 
+ pub fn root_with_progress(self) -> Result { + self.calculate(true) + } + + fn calculate(self, retain_updates: bool) -> Result { + tracing::debug!(target: "loader", "calculating state root"); + let mut trie_updates = TrieUpdates::default(); let mut hashed_account_cursor = self.tx.cursor_read::()?; let mut trie_cursor = - AccountTrieCursor::new(self.tx.cursor_write::()?); - let mut walker = TrieWalker::new(&mut trie_cursor, self.changed_account_prefixes.clone()); + AccountTrieCursor::new(self.tx.cursor_read::()?); - let (account_branch_node_tx, account_branch_node_rx) = mpsc::channel(); - let mut hash_builder = - HashBuilder::default().with_branch_node_sender(account_branch_node_tx); - - while let Some(key) = walker.key() { - if walker.can_skip_current_node { - let value = walker.hash().unwrap(); - let is_in_db_trie = walker.children_are_in_trie(); - hash_builder.add_branch(key, value, is_in_db_trie); - } - - let seek_key = match walker.next_unprocessed_key() { - Some(key) => key, - None => break, // no more keys + let (mut walker, mut hash_builder, mut last_account_key, mut last_walker_key) = + match self.previous_state { + Some(state) => ( + TrieWalker::from_stack( + &mut trie_cursor, + state.walker_stack, + self.changed_account_prefixes, + ), + state.hash_builder, + Some(state.last_account_key), + Some(state.last_walker_key), + ), + None => ( + TrieWalker::new(&mut trie_cursor, self.changed_account_prefixes), + HashBuilder::default(), + None, + None, + ), + }; + + walker.set_updates(retain_updates); + hash_builder.set_updates(retain_updates); + + while let Some(key) = last_walker_key.take().or_else(|| walker.key()) { + // Take the last account key to make sure we take it into consideration only once. + let (next_key, mut next_account_entry) = match last_account_key.take() { + // Seek the last processed entry and take the next after. + Some(account_key) => { + hashed_account_cursor.seek(account_key)?; + (walker.key(), hashed_account_cursor.next()?) 
+ } + None => { + if walker.can_skip_current_node { + let value = walker.hash().unwrap(); + let is_in_db_trie = walker.children_are_in_trie(); + hash_builder.add_branch(key.clone(), value, is_in_db_trie); + } + + let seek_key = match walker.next_unprocessed_key() { + Some(key) => key, + None => break, // no more keys + }; + + (walker.advance()?, hashed_account_cursor.seek(seek_key)?) + } }; - let next_key = walker.advance()?; - let mut next_account_entry = hashed_account_cursor.seek(seek_key)?; while let Some((hashed_address, account)) = next_account_entry { let account_nibbles = Nibbles::unpack(hashed_address); @@ -151,14 +247,21 @@ impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StateRoot<'a, TX> { // progress. // TODO: We can consider introducing the TrieProgress::Progress/Complete // abstraction inside StorageRoot, but let's give it a try as-is for now. - let storage_root = StorageRoot::new_hashed(self.tx, hashed_address) + let storage_root_calculator = StorageRoot::new_hashed(self.tx, hashed_address) .with_changed_prefixes( self.changed_storage_prefixes .get(&hashed_address) .cloned() .unwrap_or_default(), - ) - .root(Some(sender.clone()))?; + ); + + let storage_root = if retain_updates { + let (root, mut updates) = storage_root_calculator.root_with_updates()?; + trie_updates.append(&mut updates); + root + } else { + storage_root_calculator.root()? + }; let account = EthAccount::from(account).with_storage_root(storage_root); let mut account_rlp = Vec::with_capacity(account.length()); @@ -166,50 +269,40 @@ impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StateRoot<'a, TX> { hash_builder.add_leaf(account_nibbles, &account_rlp); + // Decide if we need to return intermediate progress. 
+ let total_updates_len = + trie_updates.len() + walker.updates_len() + hash_builder.updates_len(); + if retain_updates && total_updates_len as u64 >= self.threshold { + let (walker_stack, mut walker_updates) = walker.split(); + let (hash_builder, hash_builder_updates) = hash_builder.split(); + + let state = IntermediateStateRootState { + hash_builder, + walker_stack, + last_walker_key: key, + last_account_key: hashed_address, + }; + + trie_updates.append(&mut walker_updates); + trie_updates.extend_with_account_updates(hash_builder_updates); + + return Ok(StateRootProgress::Progress(state, trie_updates)) + } + + // Move the next account entry next_account_entry = hashed_account_cursor.next()?; } } let root = hash_builder.root(); - drop(hash_builder); - for (nibbles, branch_node) in account_branch_node_rx.iter() { - let _ = sender.send(BranchNodeUpdate::Account(nibbles, branch_node)); - } - drop(sender); + let (_, mut walker_updates) = walker.split(); + let (_, hash_builder_updates) = hash_builder.split(); - if let Some(receiver) = maybe_receiver { - let mut account_cursor = self.tx.cursor_write::()?; - let mut storage_cursor = self.tx.cursor_dup_write::()?; + trie_updates.append(&mut walker_updates); + trie_updates.extend_with_account_updates(hash_builder_updates); - for update in receiver.iter() { - match update { - BranchNodeUpdate::Account(nibbles, branch_node) => { - if !nibbles.is_empty() { - account_cursor.upsert(nibbles.hex_data.into(), branch_node)?; - } - } - BranchNodeUpdate::Storage(hashed_address, nibbles, node) => { - if !nibbles.is_empty() { - let key: StoredNibblesSubKey = nibbles.hex_data.into(); - if let Some(entry) = - storage_cursor.seek_by_key_subkey(hashed_address, key.clone())? 
- { - // "seek exact" - if entry.nibbles == key { - storage_cursor.delete_current()?; - } - } - - storage_cursor - .upsert(hashed_address, StorageTrieEntry { nibbles: key, node })?; - } - } - } - } - } - - Ok(root) + Ok(StateRootProgress::Complete(root, trie_updates)) } } @@ -241,34 +334,48 @@ impl<'a, TX> StorageRoot<'a, TX> { } } -impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StorageRoot<'a, TX> { +impl<'a, 'tx, TX: DbTx<'tx>> StorageRoot<'a, TX> { /// Walks the hashed storage table entries for a given address and calculates the storage root. - pub fn root( - &self, - branch_node_update_sender: Option, - ) -> Result { + /// + /// # Returns + /// + /// The storage root and storage trie updates for a given address. + pub fn root_with_updates(&self) -> Result<(H256, TrieUpdates), StorageRootError> { + self.calculate(true) + } + + /// Walks the hashed storage table entries for a given address and calculates the storage root. + /// + /// # Returns + /// + /// The storage root. + pub fn root(&self) -> Result { + let (root, _) = self.calculate(false)?; + Ok(root) + } + + fn calculate(&self, retain_updates: bool) -> Result<(H256, TrieUpdates), StorageRootError> { tracing::debug!(target: "trie::storage_root", hashed_address = ?self.hashed_address, "calculating storage root"); let mut hashed_storage_cursor = self.tx.cursor_dup_read::()?; let mut trie_cursor = StorageTrieCursor::new( - self.tx.cursor_dup_write::()?, + self.tx.cursor_dup_read::()?, self.hashed_address, ); // do not add a branch node on empty storage if hashed_storage_cursor.seek_exact(self.hashed_address)?.is_none() { - if trie_cursor.cursor.seek_exact(self.hashed_address)?.is_some() { - trie_cursor.cursor.delete_current_duplicates()?; - } - return Ok(EMPTY_ROOT) + return Ok(( + EMPTY_ROOT, + TrieUpdates::from([(TrieKey::StorageTrie(self.hashed_address), TrieOp::Delete)]), + )) } - let mut walker = TrieWalker::new(&mut trie_cursor, self.changed_prefixes.clone()); + let mut walker = TrieWalker::new(&mut 
trie_cursor, self.changed_prefixes.clone()) + .with_updates(retain_updates); - let (storage_branch_node_tx, storage_branch_node_rx) = mpsc::channel(); - let mut hash_builder = - HashBuilder::default().with_branch_node_sender(storage_branch_node_tx); + let mut hash_builder = HashBuilder::default().with_updates(retain_updates); while let Some(key) = walker.key() { if walker.can_skip_current_node { @@ -297,20 +404,16 @@ impl<'a, 'tx, TX: DbTx<'tx> + DbTxMut<'tx>> StorageRoot<'a, TX> { } let root = hash_builder.root(); - drop(hash_builder); - if let Some(sender) = branch_node_update_sender { - for (nibbles, branch_node) in storage_branch_node_rx.iter() { - let _ = sender.send(BranchNodeUpdate::Storage( - self.hashed_address, - nibbles, - branch_node, - )); - } - } + let (_, hash_builder_updates) = hash_builder.split(); + let (_, mut walker_updates) = walker.split(); + + let mut trie_updates = TrieUpdates::default(); + trie_updates.append(&mut walker_updates); + trie_updates.extend_with_storage_updates(self.hashed_address, hash_builder_updates); tracing::debug!(target: "trie::storage_root", ?root, hashed_address = ?self.hashed_address, "calculated storage root"); - Ok(root) + Ok((root, trie_updates)) } } @@ -328,8 +431,11 @@ mod tests { transaction::DbTxMut, }; use reth_primitives::{ - hex_literal::hex, keccak256, proofs::KeccakHasher, trie::TrieMask, Account, Address, H256, - U256, + hex_literal::hex, + keccak256, + proofs::KeccakHasher, + trie::{BranchNodeCompact, TrieMask}, + Account, Address, H256, U256, }; use reth_provider::Transaction; use std::{ @@ -376,9 +482,8 @@ mod tests { } // Generate the intermediate nodes on the receiving end of the channel - let (branch_node_tx, branch_node_rx) = mpsc::channel(); - let _ = - StorageRoot::new_hashed(tx.deref(), hashed_address).root(Some(branch_node_tx)).unwrap(); + let (_, trie_updates) = + StorageRoot::new_hashed(tx.deref(), hashed_address).root_with_updates().unwrap(); // 1. 
Some state transition happens, update the hashed storage to the new value let modified_key = H256::from_str(modified).unwrap(); @@ -393,31 +498,17 @@ mod tests { // 2. Calculate full merkle root let loader = StorageRoot::new_hashed(tx.deref(), hashed_address); - let modified_root = loader.root(None).unwrap(); + let modified_root = loader.root().unwrap(); // Update the intermediate roots table so that we can run the incremental verification - let mut trie_cursor = tx.cursor_dup_write::().unwrap(); - let updates = branch_node_rx.iter().collect::>(); - for update in updates { - match update { - BranchNodeUpdate::Storage(_, nibbles, node) => { - trie_cursor - .upsert( - hashed_address, - StorageTrieEntry { nibbles: nibbles.hex_data.into(), node }, - ) - .unwrap(); - } - _ => unreachable!(), - } - } + trie_updates.flush(tx.deref()).unwrap(); // 3. Calculate the incremental root let mut storage_changes = PrefixSet::default(); storage_changes.insert(Nibbles::unpack(modified_key)); let loader = StorageRoot::new_hashed(tx.deref_mut(), hashed_address) .with_changed_prefixes(storage_changes); - let incremental_root = loader.root(None).unwrap(); + let incremental_root = loader.root().unwrap(); assert_eq!(modified_root, incremental_root); } @@ -441,26 +532,23 @@ mod tests { #[test] fn arbitrary_storage_root() { proptest!(ProptestConfig::with_cases(10), |(item: (Address, std::collections::BTreeMap))| { - tokio::runtime::Runtime::new().unwrap().block_on(async { - let (address, storage) = item; + let (address, storage) = item; - let hashed_address = keccak256(address); - let db = create_test_rw_db(); - let mut tx = Transaction::new(db.as_ref()).unwrap(); - for (key, value) in &storage { - tx.put::( - hashed_address, - StorageEntry { key: keccak256(key), value: *value }, - ) - .unwrap(); - } - tx.commit().unwrap(); - - let got = StorageRoot::new(tx.deref_mut(), address).root(None).unwrap(); - let expected = storage_root(storage.into_iter()); - assert_eq!(expected, got); - }); + let 
hashed_address = keccak256(address); + let db = create_test_rw_db(); + let mut tx = Transaction::new(db.as_ref()).unwrap(); + for (key, value) in &storage { + tx.put::( + hashed_address, + StorageEntry { key: keccak256(key), value: *value }, + ) + .unwrap(); + } + tx.commit().unwrap(); + let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap(); + let expected = storage_root(storage.into_iter()); + assert_eq!(expected, got); }); } @@ -516,7 +604,7 @@ mod tests { insert_account(&mut *tx, address, account, &Default::default()); tx.commit().unwrap(); - let got = StorageRoot::new(tx.deref_mut(), address).root(None).unwrap(); + let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap(); assert_eq!(got, EMPTY_ROOT); } @@ -542,7 +630,7 @@ mod tests { insert_account(&mut *tx, address, account, &storage); tx.commit().unwrap(); - let got = StorageRoot::new(tx.deref_mut(), address).root(None).unwrap(); + let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap(); assert_eq!(storage_root(storage.into_iter()), got); } @@ -553,23 +641,42 @@ mod tests { fn arbitrary_state_root() { proptest!( ProptestConfig::with_cases(10), | (state: State) | { - // set the bytecodehash for the accounts so that storage root is computed - // this is needed because proptest will generate accs with empty bytecodehash - // but non-empty storage, which is obviously invalid - let state = state - .into_iter() - .map(|(addr, (mut acc, storage))| { - if !storage.is_empty() { - acc.bytecode_hash = Some(H256::random()); - } - (addr, (acc, storage)) - }) - .collect::>(); test_state_root_with_state(state); } ); } + #[test] + fn arbitrary_state_root_with_progress() { + proptest!( + ProptestConfig::with_cases(10), | (state: State) | { + let db = create_test_rw_db(); + let mut tx = Transaction::new(db.as_ref()).unwrap(); + + for (address, (account, storage)) in &state { + insert_account(&mut *tx, *address, *account, storage) + } + tx.commit().unwrap(); + let expected = 
state_root(state.into_iter()); + + let threshold = 10; + let mut got = None; + + let mut intermediate_state = None; + while got.is_none() { + let calculator = StateRoot::new(tx.deref_mut()) + .with_threshold(threshold) + .with_intermediate_state(intermediate_state.take()); + match calculator.root_with_progress().unwrap() { + StateRootProgress::Progress(state, _updates) => intermediate_state = Some(state), + StateRootProgress::Complete(root, _updates) => got = Some(root), + }; + } + assert_eq!(expected, got.unwrap()); + } + ); + } + fn test_state_root_with_state(state: State) { let db = create_test_rw_db(); let mut tx = Transaction::new(db.as_ref()).unwrap(); @@ -580,7 +687,7 @@ mod tests { tx.commit().unwrap(); let expected = state_root(state.into_iter()); - let got = StateRoot::new(tx.deref_mut()).root(None).unwrap(); + let got = StateRoot::new(tx.deref_mut()).root().unwrap(); assert_eq!(expected, got); } @@ -620,7 +727,7 @@ mod tests { } tx.commit().unwrap(); - let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root(None).unwrap(); + let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root().unwrap(); let expected_root = storage_root_prehashed(storage.into_iter()); assert_eq!(expected_root, account3_storage_root); } @@ -685,7 +792,7 @@ mod tests { } hashed_storage_cursor.upsert(key3, StorageEntry { key: hashed_slot, value }).unwrap(); } - let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root(None).unwrap(); + let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root().unwrap(); hash_builder.add_leaf( Nibbles::unpack(key3), &encode_account(account3, Some(account3_storage_root)), @@ -734,36 +841,29 @@ mod tests { assert_eq!(hash_builder.root(), computed_expected_root); // Check state root calculation from scratch - let (branch_node_tx, branch_node_rx) = mpsc::channel(); - let loader = StateRoot::new(tx.deref()); - assert_eq!(loader.root(Some(branch_node_tx)).unwrap(), 
computed_expected_root); + let (root, trie_updates) = StateRoot::new(tx.deref()).root_with_updates().unwrap(); + assert_eq!(root, computed_expected_root); // Check account trie - drop(loader); - let updates = branch_node_rx.iter().collect::>(); - - let account_updates = updates + let account_updates = trie_updates .iter() - .filter_map(|u| { - if let BranchNodeUpdate::Account(nibbles, node) = u { - Some((nibbles, node)) - } else { - None - } + .filter_map(|(k, v)| match (k, v) { + (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), + _ => None, }) .collect::>(); assert_eq!(account_updates.len(), 2); - let (nibbles1a, node1a) = account_updates.last().unwrap(); - assert_eq!(**nibbles1a, Nibbles::from(&[0xB])); + let (nibbles1a, node1a) = account_updates.first().unwrap(); + assert_eq!(nibbles1a.inner[..], [0xB]); assert_eq!(node1a.state_mask, TrieMask::new(0b1011)); assert_eq!(node1a.tree_mask, TrieMask::new(0b0001)); assert_eq!(node1a.hash_mask, TrieMask::new(0b1001)); assert_eq!(node1a.root_hash, None); assert_eq!(node1a.hashes.len(), 2); - let (nibbles2a, node2a) = account_updates.first().unwrap(); - assert_eq!(**nibbles2a, Nibbles::from(&[0xB, 0x0])); + let (nibbles2a, node2a) = account_updates.last().unwrap(); + assert_eq!(nibbles2a.inner[..], [0xB, 0x0]); assert_eq!(node2a.state_mask, TrieMask::new(0b10001)); assert_eq!(node2a.tree_mask, TrieMask::new(0b00000)); assert_eq!(node2a.hash_mask, TrieMask::new(0b10000)); @@ -771,20 +871,17 @@ mod tests { assert_eq!(node2a.hashes.len(), 1); // Check storage trie - let storage_updates = updates + let storage_updates = trie_updates .iter() - .filter_map(|u| { - if let BranchNodeUpdate::Storage(_, nibbles, node) = u { - Some((nibbles, node)) - } else { - None - } + .filter_map(|entry| match entry { + (TrieKey::StorageNode(_, nibbles), TrieOp::Update(node)) => Some((nibbles, node)), + _ => None, }) .collect::>(); assert_eq!(storage_updates.len(), 1); let (nibbles3, node3) = 
storage_updates.first().unwrap(); - assert!(nibbles3.is_empty()); + assert!(nibbles3.inner.is_empty()); assert_eq!(node3.state_mask, TrieMask::new(0b1010)); assert_eq!(node3.tree_mask, TrieMask::new(0b0000)); assert_eq!(node3.hash_mask, TrieMask::new(0b0010)); @@ -808,27 +905,23 @@ mod tests { H256::from_str("8e263cd4eefb0c3cbbb14e5541a66a755cad25bcfab1e10dd9d706263e811b28") .unwrap(); - let (branch_node_tx, branch_node_rx) = mpsc::channel(); - let loader = StateRoot::new(tx.deref()).with_changed_account_prefixes(prefix_set); - assert_eq!(loader.root(Some(branch_node_tx)).unwrap(), expected_state_root); + let (root, trie_updates) = StateRoot::new(tx.deref()) + .with_changed_account_prefixes(prefix_set) + .root_with_updates() + .unwrap(); + assert_eq!(root, expected_state_root); - drop(loader); - let updates = branch_node_rx.iter().collect::>(); - - let account_updates = updates + let account_updates = trie_updates .iter() - .filter_map(|u| { - if let BranchNodeUpdate::Account(nibbles, node) = u { - Some((nibbles, node)) - } else { - None - } + .filter_map(|entry| match entry { + (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), + _ => None, }) .collect::>(); assert_eq!(account_updates.len(), 2); - let (nibbles1b, node1b) = account_updates.last().unwrap(); - assert_eq!(**nibbles1b, Nibbles::from(&[0xB])); + let (nibbles1b, node1b) = account_updates.first().unwrap(); + assert_eq!(nibbles1b.inner[..], [0xB]); assert_eq!(node1b.state_mask, TrieMask::new(0b1011)); assert_eq!(node1b.tree_mask, TrieMask::new(0b0001)); assert_eq!(node1b.hash_mask, TrieMask::new(0b1011)); @@ -837,8 +930,8 @@ mod tests { assert_eq!(node1a.hashes[0], node1b.hashes[0]); assert_eq!(node1a.hashes[1], node1b.hashes[2]); - let (nibbles2b, node2b) = account_updates.first().unwrap(); - assert_eq!(**nibbles2b, Nibbles::from(&[0xB, 0x0])); + let (nibbles2b, node2b) = account_updates.last().unwrap(); + assert_eq!(nibbles2b.inner[..], [0xB, 0x0]); assert_eq!(node2a, 
node2b); tx.commit().unwrap(); @@ -861,29 +954,25 @@ mod tests { (key6, encode_account(account6, None)), ]); - let (branch_node_tx, branch_node_rx) = mpsc::channel(); - let loader = - StateRoot::new(tx.deref_mut()).with_changed_account_prefixes(account_prefix_set); - assert_eq!(loader.root(Some(branch_node_tx)).unwrap(), computed_expected_root); - drop(loader); + let (root, trie_updates) = StateRoot::new(tx.deref()) + .with_changed_account_prefixes(account_prefix_set) + .root_with_updates() + .unwrap(); + assert_eq!(root, computed_expected_root); + assert_eq!(trie_updates.len(), 7); + assert_eq!(trie_updates.iter().filter(|(_, op)| op.is_update()).count(), 2); - let updates = branch_node_rx.iter().collect::>(); - assert_eq!(updates.len(), 2); - - let account_updates = updates + let account_updates = trie_updates .iter() - .filter_map(|u| { - if let BranchNodeUpdate::Account(nibbles, node) = u { - Some((nibbles, node)) - } else { - None - } + .filter_map(|entry| match entry { + (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), + _ => None, }) .collect::>(); assert_eq!(account_updates.len(), 1); let (nibbles1c, node1c) = account_updates.first().unwrap(); - assert_eq!(**nibbles1c, Nibbles::from(&[0xB])); + assert_eq!(nibbles1c.inner[..], [0xB]); assert_eq!(node1c.state_mask, TrieMask::new(0b1011)); assert_eq!(node1c.tree_mask, TrieMask::new(0b0000)); @@ -920,29 +1009,25 @@ mod tests { (key6, encode_account(account6, None)), ]); - let (branch_node_tx, branch_node_rx) = mpsc::channel(); - let loader = - StateRoot::new(tx.deref_mut()).with_changed_account_prefixes(account_prefix_set); - assert_eq!(loader.root(Some(branch_node_tx)).unwrap(), computed_expected_root); - drop(loader); + let (root, trie_updates) = StateRoot::new(tx.deref_mut()) + .with_changed_account_prefixes(account_prefix_set) + .root_with_updates() + .unwrap(); + assert_eq!(root, computed_expected_root); + assert_eq!(trie_updates.len(), 6); + 
assert_eq!(trie_updates.iter().filter(|(_, op)| op.is_update()).count(), 1); // no storage root update - let updates = branch_node_rx.iter().collect::>(); - assert_eq!(updates.len(), 1); // no storage root update - - let account_updates = updates + let account_updates = trie_updates .iter() - .filter_map(|u| { - if let BranchNodeUpdate::Account(nibbles, node) = u { - Some((nibbles, node)) - } else { - None - } + .filter_map(|entry| match entry { + (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => Some((nibbles, node)), + _ => None, }) .collect::>(); assert_eq!(account_updates.len(), 1); let (nibbles1d, node1d) = account_updates.first().unwrap(); - assert_eq!(**nibbles1d, Nibbles::from(&[0xB])); + assert_eq!(nibbles1d.inner[..], [0xB]); assert_eq!(node1d.state_mask, TrieMask::new(0b1011)); assert_eq!(node1d.tree_mask, TrieMask::new(0b0000)); @@ -963,23 +1048,17 @@ mod tests { let expected = extension_node_trie(&mut tx); - let (sender, recv) = mpsc::channel(); - let loader = StateRoot::new(tx.deref_mut()); - let got = loader.root(Some(sender)).unwrap(); + let (got, updates) = StateRoot::new(tx.deref_mut()).root_with_updates().unwrap(); assert_eq!(expected, got); // Check account trie - drop(loader); - let updates = recv.iter().collect::>(); - let account_updates = updates - .into_iter() - .filter_map(|u| { - if let BranchNodeUpdate::Account(nibbles, node) = u { - Some((nibbles, node)) - } else { - None + .iter() + .filter_map(|entry| match entry { + (TrieKey::AccountNode(nibbles), TrieOp::Update(node)) => { + Some((nibbles.inner[..].into(), node.clone())) } + _ => None, }) .collect::>(); @@ -994,11 +1073,9 @@ mod tests { let expected = extension_node_trie(&mut tx); - let loader = StateRoot::new(tx.deref_mut()); - let got = loader.root(None).unwrap(); + let (got, updates) = StateRoot::new(tx.deref_mut()).root_with_updates().unwrap(); assert_eq!(expected, got); - - drop(loader); + updates.flush(tx.deref_mut()).unwrap(); // read the account updates from the db 
let mut accounts_trie = tx.cursor_read::().unwrap(); @@ -1006,7 +1083,7 @@ mod tests { let mut account_updates = BTreeMap::new(); for item in walker { let (key, node) = item.unwrap(); - account_updates.insert(Nibbles::from(key.inner.0.as_ref()), node); + account_updates.insert(key.inner[..].into(), node); } assert_trie_updates(&account_updates); @@ -1024,37 +1101,26 @@ mod tests { let mut state = BTreeMap::default(); for accounts in account_changes { - let mut account_trie = tx.cursor_write::().unwrap(); - let should_generate_changeset = !state.is_empty(); let mut changes = PrefixSet::default(); for (hashed_address, balance) in accounts.clone() { - hashed_account_cursor.upsert(hashed_address, Account { balance, ..Default::default() }).unwrap(); + hashed_account_cursor.upsert(hashed_address, Account { balance,..Default::default() }).unwrap(); if should_generate_changeset { changes.insert(Nibbles::unpack(hashed_address)); } } - let (branch_node_rx, branch_node_tx) = mpsc::channel(); - let account_storage_root = StateRoot::new(tx.deref_mut()).with_changed_account_prefixes(changes).root(Some(branch_node_rx)).unwrap(); + let (state_root, trie_updates) = StateRoot::new(tx.deref_mut()) + .with_changed_account_prefixes(changes) + .root_with_updates() + .unwrap(); state.append(&mut accounts.clone()); let expected_root = state_root_prehashed( - state.clone().into_iter().map(|(key, balance)| (key, (Account { balance, ..Default::default() }, std::iter::empty()))), + state.clone().into_iter().map(|(key, balance)| (key, (Account { balance, ..Default::default() }, std::iter::empty()))) ); - assert_eq!(expected_root, account_storage_root); - - let updates = branch_node_tx.iter().collect::>(); - for update in updates { - match update { - BranchNodeUpdate::Account(nibbles, node) => { - if !nibbles.is_empty() { - account_trie.upsert(nibbles.hex_data.into(), node).unwrap(); - } - } - BranchNodeUpdate::Storage(..) 
=> unreachable!(), - }; - } + assert_eq!(expected_root, state_root); + trie_updates.flush(tx.deref_mut()).unwrap(); } }); } @@ -1069,23 +1135,18 @@ mod tests { let (expected_root, expected_updates) = extension_node_storage_trie(&mut tx, hashed_address); - let (sender, recv) = mpsc::channel(); - let loader = StorageRoot::new_hashed(tx.deref_mut(), hashed_address); - let got = loader.root(Some(sender)).unwrap(); + let (got, updates) = + StorageRoot::new_hashed(tx.deref_mut(), hashed_address).root_with_updates().unwrap(); assert_eq!(expected_root, got); // Check account trie - drop(loader); - let updates = recv.iter().collect::>(); - let storage_updates = updates - .into_iter() - .filter_map(|u| { - if let BranchNodeUpdate::Storage(_, nibbles, node) = u { - Some((nibbles, node)) - } else { - None + .iter() + .filter_map(|entry| match entry { + (TrieKey::StorageNode(_, nibbles), TrieOp::Update(node)) => { + Some((nibbles.inner[..].into(), node.clone())) } + _ => None, }) .collect::>(); assert_eq!(expected_updates, storage_updates); @@ -1101,8 +1162,7 @@ mod tests { let mut hashed_storage = tx.cursor_write::().unwrap(); - let (sender, receiver) = mpsc::channel(); - let mut hb = HashBuilder::new(Some(sender)); + let mut hb = HashBuilder::default().with_updates(true); for key in [ hex!("30af561000000000000000000000000000000000000000000000000000000000"), @@ -1115,12 +1175,9 @@ mod tests { hashed_storage.upsert(hashed_address, StorageEntry { key: H256(key), value }).unwrap(); hb.add_leaf(Nibbles::unpack(key), &reth_rlp::encode_fixed_size(&value)); } + let root = hb.root(); - - drop(hb); - let updates = receiver.iter().collect::>(); - let updates = updates.iter().cloned().collect(); - + let (_, updates) = hb.split(); (root, updates) } @@ -1134,7 +1191,7 @@ mod tests { let val = encode_account(a, None); let mut hashed_accounts = tx.cursor_write::().unwrap(); - let mut hb = HashBuilder::new(None); + let mut hb = HashBuilder::default(); for key in [ 
hex!("30af561000000000000000000000000000000000000000000000000000000000"), @@ -1154,11 +1211,11 @@ mod tests { fn assert_trie_updates(account_updates: &BTreeMap) { assert_eq!(account_updates.len(), 2); - let node = account_updates.get(&Nibbles::from(vec![0x3])).unwrap(); + let node = account_updates.get(&vec![0x3].into()).unwrap(); let expected = BranchNodeCompact::new(0b0011, 0b0001, 0b0000, vec![], None); assert_eq!(node, &expected); - let node = account_updates.get(&Nibbles::from(vec![0x3, 0x0, 0xA, 0xF])).unwrap(); + let node = account_updates.get(&vec![0x3, 0x0, 0xA, 0xF].into()).unwrap(); assert_eq!(node.state_mask, TrieMask::new(0b101100000)); assert_eq!(node.tree_mask, TrieMask::new(0b000000000)); assert_eq!(node.hash_mask, TrieMask::new(0b001000000)); diff --git a/crates/trie/src/updates.rs b/crates/trie/src/updates.rs new file mode 100644 index 0000000000..afd635a26e --- /dev/null +++ b/crates/trie/src/updates.rs @@ -0,0 +1,151 @@ +use crate::Nibbles; +use derive_more::Deref; +use reth_db::{ + cursor::{DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW}, + tables, + transaction::{DbTx, DbTxMut}, +}; +use reth_primitives::{ + trie::{BranchNodeCompact, StorageTrieEntry, StoredNibbles, StoredNibblesSubKey}, + H256, +}; +use std::collections::BTreeMap; + +/// The key of a trie node. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum TrieKey { + /// A node in the account trie. + AccountNode(StoredNibbles), + /// A node in the storage trie. + StorageNode(H256, StoredNibblesSubKey), + /// Storage trie of an account. + StorageTrie(H256), +} + +/// The operation to perform on the trie. +#[derive(Debug, Clone)] +pub enum TrieOp { + /// Delete the node entry. + Delete, + /// Update the node entry with the provided value. + Update(BranchNodeCompact), +} + +impl TrieOp { + /// Returns `true` if the operation is an update. + pub fn is_update(&self) -> bool { + matches!(self, TrieOp::Update(..)) + } +} + +/// The aggregation of trie updates. 
+#[derive(Debug, Default, Clone, Deref)] +pub struct TrieUpdates { + trie_operations: BTreeMap, +} + +impl From<[(TrieKey, TrieOp); N]> for TrieUpdates { + fn from(value: [(TrieKey, TrieOp); N]) -> Self { + Self { trie_operations: BTreeMap::from(value) } + } +} + +impl TrieUpdates { + /// Schedule a delete operation on a trie key. + /// + /// # Panics + /// + /// If the key already exists and the operation is an update. + pub fn schedule_delete(&mut self, key: TrieKey) { + let existing = self.trie_operations.insert(key, TrieOp::Delete); + if let Some(op) = existing { + assert!(!op.is_update(), "Tried to delete a node that was already updated"); + } + } + + /// Append the updates to the current updates. + pub fn append(&mut self, other: &mut Self) { + self.trie_operations.append(&mut other.trie_operations); + } + + /// Extend the updates with trie updates. + pub fn extend(&mut self, updates: impl Iterator) { + self.trie_operations.extend(updates); + } + + /// Extend the updates with account trie updates. + pub fn extend_with_account_updates(&mut self, updates: BTreeMap) { + self.extend(updates.into_iter().map(|(nibbles, node)| { + (TrieKey::AccountNode(nibbles.hex_data.into()), TrieOp::Update(node)) + })); + } + + /// Extend the updates with storage trie updates. + pub fn extend_with_storage_updates( + &mut self, + hashed_address: H256, + updates: BTreeMap, + ) { + self.extend(updates.into_iter().map(|(nibbles, node)| { + (TrieKey::StorageNode(hashed_address, nibbles.hex_data.into()), TrieOp::Update(node)) + })); + } + + /// Flush updates all aggregated updates to the database. 
+ pub fn flush<'a, 'tx, TX>(self, tx: &'a TX) -> Result<(), reth_db::Error> + where + TX: DbTx<'tx> + DbTxMut<'tx>, + { + if self.trie_operations.is_empty() { + return Ok(()) + } + + let mut account_trie_cursor = tx.cursor_write::()?; + let mut storage_trie_cursor = tx.cursor_dup_write::()?; + + for (key, operation) in self.trie_operations { + match key { + TrieKey::AccountNode(nibbles) => match operation { + TrieOp::Delete => { + if account_trie_cursor.seek_exact(nibbles)?.is_some() { + account_trie_cursor.delete_current()?; + } + } + TrieOp::Update(node) => { + if !nibbles.inner.is_empty() { + account_trie_cursor.upsert(nibbles, node)?; + } + } + }, + TrieKey::StorageTrie(hashed_address) => match operation { + TrieOp::Delete => { + if storage_trie_cursor.seek_exact(hashed_address)?.is_some() { + storage_trie_cursor.delete_current_duplicates()?; + } + } + TrieOp::Update(..) => unreachable!("Cannot update full storage trie."), + }, + TrieKey::StorageNode(hashed_address, nibbles) => { + if !nibbles.inner.is_empty() { + // Delete the old entry if it exists. + if storage_trie_cursor + .seek_by_key_subkey(hashed_address, nibbles.clone())? + .filter(|e| e.nibbles == nibbles) + .is_some() + { + storage_trie_cursor.delete_current()?; + } + + // The operation is an update, insert new entry. + if let TrieOp::Update(node) = operation { + storage_trie_cursor + .upsert(hashed_address, StorageTrieEntry { nibbles, node })?; + } + } + } + }; + } + + Ok(()) + } +} diff --git a/crates/trie/src/walker.rs b/crates/trie/src/walker.rs index 8a4dcf0fd5..df9bce92fc 100644 --- a/crates/trie/src/walker.rs +++ b/crates/trie/src/walker.rs @@ -1,6 +1,7 @@ use crate::{ cursor::{CursorSubNode, TrieCursor}, prefix_set::PrefixSet, + updates::TrieUpdates, Nibbles, }; use reth_db::{table::Key, Error}; @@ -20,6 +21,8 @@ pub struct TrieWalker<'a, K, C> { pub can_skip_current_node: bool, /// A `PrefixSet` representing the changes to be applied to the trie. 
pub changes: PrefixSet, + /// The trie updates to be applied to the trie. + trie_updates: Option, __phantom: PhantomData, } @@ -30,8 +33,9 @@ impl<'a, K: Key + From>, C: TrieCursor> TrieWalker<'a, K, C> { let mut this = Self { cursor, changes, - can_skip_current_node: false, stack: vec![CursorSubNode::default()], + can_skip_current_node: false, + trie_updates: None, __phantom: PhantomData::default(), }; @@ -45,6 +49,39 @@ impl<'a, K: Key + From>, C: TrieCursor> TrieWalker<'a, K, C> { this } + /// Constructs a new TrieWalker from existing stack and a cursor. + pub fn from_stack(cursor: &'a mut C, stack: Vec, changes: PrefixSet) -> Self { + let mut this = Self { + cursor, + changes, + stack, + can_skip_current_node: false, + trie_updates: None, + __phantom: PhantomData::default(), + }; + this.update_skip_node(); + this + } + + /// Sets the flag whether the trie updates should be stored. + pub fn with_updates(mut self, retain_updates: bool) -> Self { + self.set_updates(retain_updates); + self + } + + /// Sets the flag whether the trie updates should be stored. + pub fn set_updates(&mut self, retain_updates: bool) { + if retain_updates { + self.trie_updates = Some(TrieUpdates::default()); + } + } + + /// Split the walker into stack and trie updates. + pub fn split(mut self) -> (Vec, TrieUpdates) { + let trie_updates = self.trie_updates.take(); + (self.stack, trie_updates.unwrap_or_default()) + } + /// Prints the current stack of trie nodes. pub fn print_stack(&self) { println!("====================== STACK ======================"); @@ -54,6 +91,11 @@ impl<'a, K: Key + From>, C: TrieCursor> TrieWalker<'a, K, C> { println!("====================== END STACK ======================\n"); } + /// The current length of the trie updates. + pub fn updates_len(&self) -> usize { + self.trie_updates.as_ref().map(|u| u.len()).unwrap_or(0) + } + /// Advances the walker to the next trie node and updates the skip node flag. 
/// /// # Returns @@ -121,7 +163,9 @@ impl<'a, K: Key + From>, C: TrieCursor> TrieWalker<'a, K, C> { // Delete the current node if it's included in the prefix set or it doesn't contain the root // hash. if !self.can_skip_current_node || nibble != -1 { - self.cursor.delete_current()?; + if let Some((updates, key)) = self.trie_updates.as_mut().zip(self.cursor.current()?) { + updates.schedule_delete(key); + } } Ok(()) @@ -209,7 +253,10 @@ impl<'a, K: Key + From>, C: TrieCursor> TrieWalker<'a, K, C> { mod tests { use super::*; use crate::cursor::{AccountTrieCursor, StorageTrieCursor}; - use reth_db::{mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut}; + use reth_db::{ + cursor::DbCursorRW, mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut, + }; + use reth_primitives::trie::StorageTrieEntry; use reth_provider::Transaction; #[test] @@ -237,26 +284,32 @@ mod tests { let db = create_test_rw_db(); let tx = Transaction::new(db.as_ref()).unwrap(); - let account_trie = - AccountTrieCursor::new(tx.cursor_write::().unwrap()); - test_cursor(account_trie, &inputs, &expected); + let mut account_cursor = tx.cursor_write::().unwrap(); + for (k, v) in &inputs { + account_cursor.upsert(k.clone().into(), v.clone()).unwrap(); + } + let account_trie = AccountTrieCursor::new(account_cursor); + test_cursor(account_trie, &expected); - let storage_trie = StorageTrieCursor::new( - tx.cursor_dup_write::().unwrap(), - H256::random(), - ); - test_cursor(storage_trie, &inputs, &expected); + let hashed_address = H256::random(); + let mut storage_cursor = tx.cursor_dup_write::().unwrap(); + for (k, v) in &inputs { + storage_cursor + .upsert( + hashed_address, + StorageTrieEntry { nibbles: k.clone().into(), node: v.clone() }, + ) + .unwrap(); + } + let storage_trie = StorageTrieCursor::new(storage_cursor, hashed_address); + test_cursor(storage_trie, &expected); } - fn test_cursor(mut trie: T, inputs: &[(Vec, BranchNodeCompact)], expected: &[Vec]) + fn test_cursor(mut 
trie: T, expected: &[Vec]) where K: Key + From>, T: TrieCursor, { - for (k, v) in inputs { - trie.upsert(k.clone().into(), v.clone()).unwrap(); - } - let mut walker = TrieWalker::new(&mut trie, Default::default()); assert!(walker.key().unwrap().is_empty()); @@ -275,11 +328,7 @@ mod tests { fn cursor_rootnode_with_changesets() { let db = create_test_rw_db(); let tx = Transaction::new(db.as_ref()).unwrap(); - - let mut trie = StorageTrieCursor::new( - tx.cursor_dup_write::().unwrap(), - H256::random(), - ); + let mut cursor = tx.cursor_dup_write::().unwrap(); let nodes = vec![ ( @@ -306,10 +355,13 @@ mod tests { ), ]; + let hashed_address = H256::random(); for (k, v) in nodes { - trie.upsert(k.into(), v).unwrap(); + cursor.upsert(hashed_address, StorageTrieEntry { nibbles: k.into(), node: v }).unwrap(); } + let mut trie = StorageTrieCursor::new(cursor, hashed_address); + // No changes let mut cursor = TrieWalker::new(&mut trie, Default::default()); assert_eq!(cursor.key(), Some(Nibbles::from(vec![]))); // root