From 6ff2510ad972ebb74e3c12101d618da074e23e64 Mon Sep 17 00:00:00 2001 From: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:28:12 +0000 Subject: [PATCH] perf(engine): cache proof targets in proof sequencer for state root task (#13310) --- crates/engine/tree/src/tree/root.rs | 213 ++++++++++++++++++-------- crates/trie/common/src/proofs.rs | 20 ++- crates/trie/db/src/proof.rs | 13 +- crates/trie/db/tests/witness.rs | 19 ++- crates/trie/sparse/src/state.rs | 8 +- crates/trie/trie/src/proof/blinded.rs | 2 +- crates/trie/trie/src/proof/mod.rs | 13 +- crates/trie/trie/src/witness.rs | 6 +- 8 files changed, 206 insertions(+), 88 deletions(-) diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 723d70ac58..5fc042f6c8 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -1,6 +1,6 @@ //! State root task related functionality. -use alloy_primitives::map::{HashMap, HashSet}; +use alloy_primitives::map::HashSet; use rayon::iter::{ParallelBridge, ParallelIterator}; use reth_evm::system_calls::OnStateHook; use reth_execution_errors::StateProofError; @@ -75,14 +75,7 @@ pub enum StateRootMessage { /// New state update from transaction execution StateUpdate(EvmState), /// Proof calculation completed for a specific state update - ProofCalculated { - /// The calculated proof - proof: MultiProof, - /// The state update that was used to calculate the proof - state_update: HashedPostState, - /// The index of this proof in the sequence of state updates - sequence_number: u64, - }, + ProofCalculated(Box), /// Error during proof calculation ProofCalculationError(StateProofError), /// State root calculation completed @@ -98,6 +91,19 @@ pub enum StateRootMessage { FinishedStateUpdates, } +/// Message about completion of proof calculation for a specific state update +#[derive(Debug)] +pub struct ProofCalculated { + /// The state update that was used to calculate the proof + state_update: HashedPostState, + /// The proof targets + targets: MultiProofTargets, + /// The calculated proof + proof: MultiProof, + /// The index of this proof in the sequence of state updates + sequence_number: u64, +} + /// Handle to track proof calculation ordering #[derive(Debug, Default)] pub(crate) struct ProofSequencer { @@ -106,7 +112,7 @@ pub(crate) struct ProofSequencer { /// The next sequence number expected to be delivered. next_to_deliver: u64, /// Buffer for out-of-order proofs and corresponding state updates - pending_proofs: BTreeMap, + pending_proofs: BTreeMap, } impl ProofSequencer { @@ -127,11 +133,12 @@ impl ProofSequencer { pub(crate) fn add_proof( &mut self, sequence: u64, - proof: MultiProof, state_update: HashedPostState, - ) -> Vec<(MultiProof, HashedPostState)> { + targets: MultiProofTargets, + proof: MultiProof, + ) -> Vec<(HashedPostState, MultiProofTargets, MultiProof)> { if sequence >= self.next_to_deliver { - self.pending_proofs.insert(sequence, (proof, state_update)); + self.pending_proofs.insert(sequence, (state_update, targets, proof)); } // return early if we don't have the next expected proof @@ -143,8 +150,8 @@ impl ProofSequencer { let mut current_sequence = self.next_to_deliver; // keep collecting proofs and state updates as long as we have consecutive sequence numbers - while let Some((proof, state_update)) = self.pending_proofs.remove(¤t_sequence) { - consecutive_proofs.push((proof, state_update)); + while let Some(pending) = self.pending_proofs.remove(¤t_sequence) { + consecutive_proofs.push(pending); current_sequence += 1; // if we don't have the next number, stop collecting @@ -319,9 +326,7 @@ where let hashed_state_update = evm_state_to_hashed_post_state(update); let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets); - for (address, slots) in &proof_targets { - fetched_proof_targets.entry(*address).or_default().extend(slots) - } + fetched_proof_targets.extend_ref(&proof_targets); // Dispatch proof gathering for this state update scope.spawn(move |_| { @@ -338,15 +343,18 @@ where provider.tx_ref(), // TODO(alexey): this clone can be expensive, we should avoid it input.as_ref().clone(), - proof_targets, + proof_targets.clone(), ); match result { Ok(proof) => { - let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated { - proof, - state_update: hashed_state_update, - sequence_number: proof_sequence_number, - }); + let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( + Box::new(ProofCalculated { + state_update: hashed_state_update, + targets: proof_targets, + proof, + sequence_number: proof_sequence_number, + }), + )); } Err(e) => { let _ = @@ -360,18 +368,21 @@ where fn on_proof( &mut self, sequence_number: u64, - proof: MultiProof, state_update: HashedPostState, - ) -> Option<(MultiProof, HashedPostState)> { - let ready_proofs = self.proof_sequencer.add_proof(sequence_number, proof, state_update); + targets: MultiProofTargets, + proof: MultiProof, + ) -> Option<(HashedPostState, MultiProofTargets, MultiProof)> { + let ready_proofs = + self.proof_sequencer.add_proof(sequence_number, state_update, targets, proof); if ready_proofs.is_empty() { None } else { // Merge all ready proofs and state updates - ready_proofs.into_iter().reduce(|mut acc, (proof, state_update)| { - acc.0.extend(proof); - acc.1.extend(state_update); + ready_proofs.into_iter().reduce(|mut acc, (state_update, targets, proof)| { + acc.0.extend(state_update); + acc.1.extend(targets); + acc.2.extend(proof); acc }) } @@ -382,6 +393,7 @@ where &mut self, scope: &rayon::Scope<'env>, state: HashedPostState, + targets: MultiProofTargets, multiproof: MultiProof, ) { let Some(trie) = self.sparse_trie.take() else { return }; @@ -394,7 +406,7 @@ where ); // TODO(alexey): store proof targets in `ProofSequecner` to avoid recomputing them - let targets = get_proof_targets(&state, &HashMap::default()); + let targets = get_proof_targets(&state, &targets); let tx = self.tx.clone(); scope.spawn(move |_| { @@ -417,6 +429,7 @@ where fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult { let mut current_state_update = HashedPostState::default(); + let mut current_proof_targets = MultiProofTargets::default(); let mut current_multiproof = MultiProof::default(); let mut updates_received = 0; let mut proofs_processed = 0; @@ -447,27 +460,36 @@ where StateRootMessage::FinishedStateUpdates => { updates_finished = true; } - StateRootMessage::ProofCalculated { proof, state_update, sequence_number } => { + StateRootMessage::ProofCalculated(proof_calculated) => { proofs_processed += 1; trace!( target: "engine::root", - sequence = sequence_number, + sequence = proof_calculated.sequence_number, total_proofs = proofs_processed, "Processing calculated proof" ); - trace!(target: "engine::root", ?proof, "Proof calculated"); + trace!(target: "engine::root", proof = ?proof_calculated.proof, "Proof calculated"); - if let Some((combined_proof, combined_state_update)) = - self.on_proof(sequence_number, proof, state_update) - { + if let Some(( + combined_state_update, + combined_proof_targets, + combined_proof, + )) = self.on_proof( + proof_calculated.sequence_number, + proof_calculated.state_update, + proof_calculated.targets, + proof_calculated.proof, + ) { if self.sparse_trie.is_none() { - current_multiproof.extend(combined_proof); current_state_update.extend(combined_state_update); + current_proof_targets.extend(combined_proof_targets); + current_multiproof.extend(combined_proof); } else { self.spawn_root_calculation( scope, combined_state_update, + combined_proof_targets, combined_proof, ); } @@ -509,6 +531,7 @@ where self.spawn_root_calculation( scope, std::mem::take(&mut current_state_update), + std::mem::take(&mut current_proof_targets), std::mem::take(&mut current_multiproof), ); } else if all_proofs_received && no_pending && updates_finished { @@ -564,7 +587,7 @@ fn get_proof_targets( state_update: &HashedPostState, fetched_proof_targets: &MultiProofTargets, ) -> MultiProofTargets { - let mut targets = HashMap::default(); + let mut targets = MultiProofTargets::default(); // first collect all new accounts (not previously fetched) for &hashed_address in state_update.accounts.keys() { @@ -830,11 +853,21 @@ mod tests { let proof2 = MultiProof::default(); sequencer.next_sequence = 2; - let ready = sequencer.add_proof(0, proof1, HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proof1, + ); assert_eq!(ready.len(), 1); assert!(!sequencer.has_pending()); - let ready = sequencer.add_proof(1, proof2, HashedPostState::default()); + let ready = sequencer.add_proof( + 1, + HashedPostState::default(), + MultiProofTargets::default(), + proof2, + ); assert_eq!(ready.len(), 1); assert!(!sequencer.has_pending()); } @@ -847,15 +880,30 @@ mod tests { let proof3 = MultiProof::default(); sequencer.next_sequence = 3; - let ready = sequencer.add_proof(2, proof3, HashedPostState::default()); + let ready = sequencer.add_proof( + 2, + HashedPostState::default(), + MultiProofTargets::default(), + proof3, + ); assert_eq!(ready.len(), 0); assert!(sequencer.has_pending()); - let ready = sequencer.add_proof(0, proof1, HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proof1, + ); assert_eq!(ready.len(), 1); assert!(sequencer.has_pending()); - let ready = sequencer.add_proof(1, proof2, HashedPostState::default()); + let ready = sequencer.add_proof( + 1, + HashedPostState::default(), + MultiProofTargets::default(), + proof2, + ); assert_eq!(ready.len(), 2); assert!(!sequencer.has_pending()); } @@ -867,10 +915,20 @@ mod tests { let proof3 = MultiProof::default(); sequencer.next_sequence = 3; - let ready = sequencer.add_proof(0, proof1, HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proof1, + ); assert_eq!(ready.len(), 1); - let ready = sequencer.add_proof(2, proof3, HashedPostState::default()); + let ready = sequencer.add_proof( + 2, + HashedPostState::default(), + MultiProofTargets::default(), + proof3, + ); assert_eq!(ready.len(), 0); assert!(sequencer.has_pending()); } @@ -881,10 +939,20 @@ mod tests { let proof1 = MultiProof::default(); let proof2 = MultiProof::default(); - let ready = sequencer.add_proof(0, proof1, HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proof1, + ); assert_eq!(ready.len(), 1); - let ready = sequencer.add_proof(0, proof2, HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proof2, + ); assert_eq!(ready.len(), 0); assert!(!sequencer.has_pending()); } @@ -895,12 +963,37 @@ mod tests { let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect(); sequencer.next_sequence = 5; - sequencer.add_proof(4, proofs[4].clone(), HashedPostState::default()); - sequencer.add_proof(2, proofs[2].clone(), HashedPostState::default()); - sequencer.add_proof(1, proofs[1].clone(), HashedPostState::default()); - sequencer.add_proof(3, proofs[3].clone(), HashedPostState::default()); + sequencer.add_proof( + 4, + HashedPostState::default(), + MultiProofTargets::default(), + proofs[4].clone(), + ); + sequencer.add_proof( + 2, + HashedPostState::default(), + MultiProofTargets::default(), + proofs[2].clone(), + ); + sequencer.add_proof( + 1, + HashedPostState::default(), + MultiProofTargets::default(), + proofs[1].clone(), + ); + sequencer.add_proof( + 3, + HashedPostState::default(), + MultiProofTargets::default(), + proofs[3].clone(), + ); - let ready = sequencer.add_proof(0, proofs[0].clone(), HashedPostState::default()); + let ready = sequencer.add_proof( + 0, + HashedPostState::default(), + MultiProofTargets::default(), + proofs[0].clone(), + ); assert_eq!(ready.len(), 5); assert!(!sequencer.has_pending()); } @@ -926,7 +1019,7 @@ mod tests { #[test] fn test_get_proof_targets_new_account_targets() { let state = create_get_proof_targets_state(); - let fetched = HashMap::default(); + let fetched = MultiProofTargets::default(); let targets = get_proof_targets(&state, &fetched); @@ -940,7 +1033,7 @@ mod tests { #[test] fn test_get_proof_targets_new_storage_targets() { let state = create_get_proof_targets_state(); - let fetched = HashMap::default(); + let fetched = MultiProofTargets::default(); let targets = get_proof_targets(&state, &fetched); @@ -958,7 +1051,7 @@ mod tests { #[test] fn test_get_proof_targets_filter_already_fetched_accounts() { let state = create_get_proof_targets_state(); - let mut fetched = HashMap::default(); + let mut fetched = MultiProofTargets::default(); // select an account that has no storage updates let fetched_addr = state @@ -981,7 +1074,7 @@ mod tests { #[test] fn test_get_proof_targets_filter_already_fetched_storage() { let state = create_get_proof_targets_state(); - let mut fetched = HashMap::default(); + let mut fetched = MultiProofTargets::default(); // mark one storage slot as already fetched let (addr, storage) = state.storages.iter().next().unwrap(); @@ -1001,7 +1094,7 @@ mod tests { #[test] fn test_get_proof_targets_empty_state() { let state = HashedPostState::default(); - let fetched = HashMap::default(); + let fetched = MultiProofTargets::default(); let targets = get_proof_targets(&state, &fetched); @@ -1011,7 +1104,7 @@ mod tests { #[test] fn test_get_proof_targets_mixed_fetched_state() { let mut state = HashedPostState::default(); - let mut fetched = HashMap::default(); + let mut fetched = MultiProofTargets::default(); let addr1 = B256::random(); let addr2 = B256::random(); @@ -1040,7 +1133,7 @@ mod tests { #[test] fn test_get_proof_targets_unmodified_account_with_storage() { let mut state = HashedPostState::default(); - let fetched = HashMap::default(); + let fetched = MultiProofTargets::default(); let addr = B256::random(); let slot1 = B256::random(); diff --git a/crates/trie/common/src/proofs.rs b/crates/trie/common/src/proofs.rs index eb3626d90e..8455d1e8ac 100644 --- a/crates/trie/common/src/proofs.rs +++ b/crates/trie/common/src/proofs.rs @@ -13,11 +13,29 @@ use alloy_trie::{ proof::{verify_proof, ProofNodes, ProofVerificationError}, TrieMask, EMPTY_ROOT_HASH, }; +use derive_more::derive::{Deref, DerefMut, From, Into, IntoIterator}; use itertools::Itertools; use reth_primitives_traits::Account; /// Proof targets map. -pub type MultiProofTargets = B256HashMap; +#[derive(Debug, Default, Clone, Deref, DerefMut, From, Into, IntoIterator)] +pub struct MultiProofTargets(B256HashMap); + +impl MultiProofTargets { + /// Extends the proof targets map with another one. + pub fn extend(&mut self, other: Self) { + for (address, slots) in other.0 { + self.0.entry(address).or_default().extend(slots); + } + } + + /// Extends the proof targets map with another one by reference. + pub fn extend_ref(&mut self, other: &Self) { + for (address, slots) in &other.0 { + self.0.entry(*address).or_default().extend(slots); + } + } +} /// The state multiproof of target accounts and multiproofs of their storage tries. /// Multiproof is effectively a state subtrie that only contains the nodes diff --git a/crates/trie/db/src/proof.rs b/crates/trie/db/src/proof.rs index d7263a9436..137e661b05 100644 --- a/crates/trie/db/src/proof.rs +++ b/crates/trie/db/src/proof.rs @@ -1,16 +1,13 @@ use crate::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; -use alloy_primitives::{ - keccak256, - map::{B256HashMap, B256HashSet, HashMap}, - Address, B256, -}; +use alloy_primitives::{keccak256, map::HashMap, Address, B256}; use reth_db_api::transaction::DbTx; use reth_execution_errors::StateProofError; use reth_trie::{ hashed_cursor::HashedPostStateCursorFactory, proof::{Proof, StorageProof}, trie_cursor::InMemoryTrieCursorFactory, - AccountProof, HashedPostStateSorted, HashedStorage, MultiProof, StorageMultiProof, TrieInput, + AccountProof, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, + StorageMultiProof, TrieInput, }; /// Extends [`Proof`] with operations specific for working with a database transaction. @@ -30,7 +27,7 @@ pub trait DatabaseProof<'a, TX> { fn overlay_multiproof( tx: &'a TX, input: TrieInput, - targets: B256HashMap, + targets: MultiProofTargets, ) -> Result; } @@ -66,7 +63,7 @@ impl<'a, TX: DbTx> DatabaseProof<'a, TX> fn overlay_multiproof( tx: &'a TX, input: TrieInput, - targets: B256HashMap, + targets: MultiProofTargets, ) -> Result { let nodes_sorted = input.nodes.into_sorted(); let state_sorted = input.state.into_sorted(); diff --git a/crates/trie/db/tests/witness.rs b/crates/trie/db/tests/witness.rs index 385f6269f3..c9732bef49 100644 --- a/crates/trie/db/tests/witness.rs +++ b/crates/trie/db/tests/witness.rs @@ -39,7 +39,9 @@ fn includes_empty_node_preimage() { let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap(); let multiproof = Proof::from_tx(provider.tx_ref()) - .multiproof(HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))])) + .multiproof( + HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]).into(), + ) .unwrap(); let witness = TrieWitness::from_tx(provider.tx_ref()) @@ -77,7 +79,9 @@ fn includes_nodes_for_destroyed_storage_nodes() { let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap(); let multiproof = Proof::from_tx(provider.tx_ref()) - .multiproof(HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))])) + .multiproof( + HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]).into(), + ) .unwrap(); let witness = @@ -122,10 +126,13 @@ fn correctly_decodes_branch_node_values() { let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap(); let multiproof = Proof::from_tx(provider.tx_ref()) - .multiproof(HashMap::from_iter([( - hashed_address, - HashSet::from_iter([hashed_slot1, hashed_slot2]), - )])) + .multiproof( + HashMap::from_iter([( + hashed_address, + HashSet::from_iter([hashed_slot1, hashed_slot2]), + )]) + .into(), + ) .unwrap(); let witness = TrieWitness::from_tx(provider.tx_ref()) diff --git a/crates/trie/sparse/src/state.rs b/crates/trie/sparse/src/state.rs index 1dad2a1378..5112e64a22 100644 --- a/crates/trie/sparse/src/state.rs +++ b/crates/trie/sparse/src/state.rs @@ -15,7 +15,8 @@ use reth_primitives_traits::Account; use reth_tracing::tracing::trace; use reth_trie_common::{ updates::{StorageTrieUpdates, TrieUpdates}, - MultiProof, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH, TRIE_ACCOUNT_RLP_MAX_SIZE, + MultiProof, MultiProofTargets, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH, + TRIE_ACCOUNT_RLP_MAX_SIZE, }; use std::{fmt, iter::Peekable}; @@ -206,7 +207,7 @@ impl SparseStateTrie { /// NOTE: This method does not extensively validate the proof. pub fn reveal_multiproof( &mut self, - targets: B256HashMap, + targets: MultiProofTargets, multiproof: MultiProof, ) -> SparseStateTrieResult<()> { let account_subtree = multiproof.account_subtree.into_nodes_sorted(); @@ -559,7 +560,8 @@ mod tests { HashMap::from_iter([ (address_1, HashSet::from_iter([slot_1, slot_2])), (address_2, HashSet::from_iter([slot_1, slot_2])), - ]), + ]) + .into(), MultiProof { account_subtree: proof_nodes, branch_node_hash_masks: HashMap::from_iter([( diff --git a/crates/trie/trie/src/proof/blinded.rs b/crates/trie/trie/src/proof/blinded.rs index 33a1a43b57..eb713c05aa 100644 --- a/crates/trie/trie/src/proof/blinded.rs +++ b/crates/trie/trie/src/proof/blinded.rs @@ -91,7 +91,7 @@ where let proof = Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone()) .with_prefix_sets_mut(self.prefix_sets.as_ref().clone()) - .multiproof(targets) + .multiproof(targets.into()) .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?; Ok(proof.account_subtree.into_inner().remove(path)) diff --git a/crates/trie/trie/src/proof/mod.rs b/crates/trie/trie/src/proof/mod.rs index 1414d4a344..fe31f7baa9 100644 --- a/crates/trie/trie/src/proof/mod.rs +++ b/crates/trie/trie/src/proof/mod.rs @@ -14,7 +14,8 @@ use alloy_primitives::{ use alloy_rlp::{BufMut, Encodable}; use reth_execution_errors::trie::StateProofError; use reth_trie_common::{ - proof::ProofRetainer, AccountProof, MultiProof, StorageMultiProof, TrieAccount, + proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof, + TrieAccount, }; mod blinded; @@ -93,17 +94,17 @@ where slots: &[B256], ) -> Result { Ok(self - .multiproof(HashMap::from_iter([( - keccak256(address), - slots.iter().map(keccak256).collect(), - )]))? + .multiproof( + HashMap::from_iter([(keccak256(address), slots.iter().map(keccak256).collect())]) + .into(), + )? .account_proof(address, slots)?) } /// Generate a state multiproof according to specified targets. pub fn multiproof( mut self, - mut targets: B256HashMap, + mut targets: MultiProofTargets, ) -> Result { let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?; let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?; diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs index f48815d9ba..e297921468 100644 --- a/crates/trie/trie/src/witness.rs +++ b/crates/trie/trie/src/witness.rs @@ -15,7 +15,7 @@ use reth_execution_errors::{ SparseStateTrieError, SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind, StateProofError, TrieWitnessError, }; -use reth_trie_common::Nibbles; +use reth_trie_common::{MultiProofTargets, Nibbles}; use reth_trie_sparse::{ blinded::{BlindedProvider, BlindedProviderFactory}, SparseStateTrie, @@ -171,8 +171,8 @@ where fn get_proof_targets( &self, state: &HashedPostState, - ) -> Result, StateProofError> { - let mut proof_targets = B256HashMap::default(); + ) -> Result { + let mut proof_targets = MultiProofTargets::default(); for hashed_address in state.accounts.keys() { proof_targets.insert(*hashed_address, B256HashSet::default()); }