From b94a31f6d8c36be03a905cbfd3ba9bc201c2703f Mon Sep 17 00:00:00 2001 From: Federico Gimenez Date: Thu, 12 Dec 2024 14:24:47 +0100 Subject: [PATCH] feat(trie): replace TrieInput by its components in ParallelProof (#13346) --- crates/trie/parallel/src/proof.rs | 51 ++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 6dd4a9a013..cabd9c7e06 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -17,9 +17,10 @@ use reth_trie::{ prefix_set::{PrefixSetMut, TriePrefixSetsMut}, proof::StorageProof, trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory}, + updates::TrieUpdatesSorted, walker::TrieWalker, - HashBuilder, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof, TrieAccount, TrieInput, - TRIE_ACCOUNT_RLP_MAX_SIZE, + HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof, + TrieAccount, TRIE_ACCOUNT_RLP_MAX_SIZE, }; use reth_trie_common::proof::ProofRetainer; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; @@ -34,8 +35,15 @@ use crate::metrics::ParallelStateRootMetrics; pub struct ParallelProof { /// Consistent view of the database. view: ConsistentDbView, - /// Trie input. - input: Arc, + /// The sorted collection of cached in-memory intermediate trie nodes that + /// can be reused for computation. + pub nodes_sorted: Arc, + /// The sorted in-memory overlay hashed state. + pub state_sorted: Arc, + /// The collection of prefix sets for the computation. Since the prefix sets _always_ + /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here, + /// if we have cached nodes for them. + pub prefix_sets: Arc, /// Flag indicating whether to include branch node hash masks in the proof. collect_branch_node_hash_masks: bool, /// Parallel state root metrics. @@ -45,10 +53,17 @@ pub struct ParallelProof { impl ParallelProof { /// Create new state proof generator. - pub fn new(view: ConsistentDbView, input: Arc) -> Self { + pub fn new( + view: ConsistentDbView, + nodes_sorted: Arc, + state_sorted: Arc, + prefix_sets: Arc, + ) -> Self { Self { view, - input, + nodes_sorted, + state_sorted, + prefix_sets, collect_branch_node_hash_masks: false, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), @@ -78,11 +93,8 @@ where ) -> Result { let mut tracker = ParallelTrieTracker::default(); - let trie_nodes_sorted = self.input.nodes.clone().into_sorted(); - let hashed_state_sorted = self.input.state.clone().into_sorted(); - // Extend prefix sets with targets - let mut prefix_sets = self.input.prefix_sets.clone(); + let mut prefix_sets = (*self.prefix_sets).clone(); prefix_sets.extend(TriePrefixSetsMut { account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)), storage_prefix_sets: targets @@ -112,8 +124,8 @@ where let view = self.view.clone(); let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); - let trie_nodes_sorted = trie_nodes_sorted.clone(); - let hashed_state_sorted = hashed_state_sorted.clone(); + let trie_nodes_sorted = self.nodes_sorted.clone(); + let hashed_state_sorted = self.state_sorted.clone(); let (tx, rx) = std::sync::mpsc::sync_channel(1); @@ -149,11 +161,11 @@ where let provider_ro = self.view.provider_ro()?; let trie_cursor_factory = InMemoryTrieCursorFactory::new( DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), - &trie_nodes_sorted, + &self.nodes_sorted, ); let hashed_cursor_factory = HashedPostStateCursorFactory::new( DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), - &hashed_state_sorted, + &self.state_sorted, ); // Create the walker. @@ -327,9 +339,14 @@ mod tests { let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref()); assert_eq!( - ParallelProof::new(consistent_view, Default::default()) - .multiproof(targets.clone()) - .unwrap(), + ParallelProof::new( + consistent_view, + Default::default(), + Default::default(), + Default::default() + ) + .multiproof(targets.clone()) + .unwrap(), Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap() ); }