diff --git a/crates/trie/parallel/Cargo.toml b/crates/trie/parallel/Cargo.toml index 9fb882b44a..d64f2dfb51 100644 --- a/crates/trie/parallel/Cargo.toml +++ b/crates/trie/parallel/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] # reth reth-execution-errors.workspace = true +reth-primitives-traits.workspace = true reth-provider.workspace = true reth-storage-errors.workspace = true reth-trie-common.workspace = true diff --git a/crates/trie/parallel/src/lib.rs b/crates/trie/parallel/src/lib.rs index ba88ab690d..cba9d9440e 100644 --- a/crates/trie/parallel/src/lib.rs +++ b/crates/trie/parallel/src/lib.rs @@ -22,6 +22,9 @@ pub mod proof; pub mod proof_task; +/// Async value encoder for V2 proofs. +pub(crate) mod value_encoder; + /// V2 multiproof targets and chunking. pub mod targets_v2; diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs index c6c0d89555..1d492e2775 100644 --- a/crates/trie/parallel/src/proof_task.rs +++ b/crates/trie/parallel/src/proof_task.rs @@ -41,7 +41,7 @@ use alloy_primitives::{ use alloy_rlp::{BufMut, Encodable}; use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; use dashmap::DashMap; -use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind}; +use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, StateProofError}; use reth_provider::{DatabaseProviderROFactory, ProviderError, ProviderResult}; use reth_storage_errors::db::DatabaseError; use reth_trie::{ @@ -305,17 +305,17 @@ impl ProofWorkerHandle { self.storage_work_tx .send(StorageWorkerJob::StorageProof { input, proof_result_sender }) .map_err(|err| { - let error = - ProviderError::other(std::io::Error::other("storage workers unavailable")); - if let StorageWorkerJob::StorageProof { proof_result_sender, .. } = err.0 { let _ = proof_result_sender.send(StorageProofResultMessage { hashed_address, - result: Err(ParallelStateRootError::Provider(error.clone())), + result: Err(DatabaseError::Other( + "storage workers unavailable".to_string(), + ) + .into()), }); } - error + ProviderError::other(std::io::Error::other("storage workers unavailable")) }) } @@ -432,7 +432,7 @@ where input: StorageProofInput, trie_cursor_metrics: &mut TrieCursorMetricsCache, hashed_cursor_metrics: &mut HashedCursorMetricsCache, - ) -> Result { + ) -> Result { // Consume the input so we can move large collections (e.g. target slots) without cloning. let StorageProofInput::Legacy { hashed_address, @@ -469,20 +469,13 @@ where .with_added_removed_keys(added_removed_keys) .with_trie_cursor_metrics(trie_cursor_metrics) .with_hashed_cursor_metrics(hashed_cursor_metrics) - .storage_multiproof(target_slots) - .map_err(|e| ParallelStateRootError::Other(e.to_string())); + .storage_multiproof(target_slots); trie_cursor_metrics.record_span("trie_cursor"); hashed_cursor_metrics.record_span("hashed_cursor"); // Decode proof into DecodedStorageMultiProof - let decoded_result = raw_proof_result.and_then(|raw_proof| { - raw_proof.try_into().map_err(|e: alloy_rlp::Error| { - ParallelStateRootError::Other(format!( - "Failed to decode storage proof for {}: {}", - hashed_address, e - )) - }) - })?; + let decoded_result = + raw_proof_result.and_then(|raw_proof| raw_proof.try_into().map_err(Into::into))?; trace!( target: "trie::proof_task", @@ -502,7 +495,7 @@ where ::StorageTrieCursor<'_>, ::StorageCursor<'_>, >, - ) -> Result { + ) -> Result { let StorageProofInput::V2 { hashed_address, mut targets } = input else { panic!("compute_v2_storage_proof only accepts StorageProofInput::V2") }; @@ -717,12 +710,12 @@ pub struct StorageProofResultMessage { /// The hashed address this storage proof belongs to pub(crate) hashed_address: B256, /// The storage proof calculation result - pub(crate) result: Result, + pub(crate) result: Result, } /// Internal message for storage workers. #[derive(Debug)] -enum StorageWorkerJob { +pub(crate) enum StorageWorkerJob { /// Storage proof computation request StorageProof { /// Storage proof input parameters diff --git a/crates/trie/parallel/src/value_encoder.rs b/crates/trie/parallel/src/value_encoder.rs new file mode 100644 index 0000000000..13c611922d --- /dev/null +++ b/crates/trie/parallel/src/value_encoder.rs @@ -0,0 +1,185 @@ +use crate::proof_task::{ + StorageProofInput, StorageProofResult, StorageProofResultMessage, StorageWorkerJob, +}; +use alloy_primitives::{map::B256Map, B256}; +use alloy_rlp::Encodable; +use core::cell::RefCell; +use crossbeam_channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; +use dashmap::DashMap; +use reth_execution_errors::trie::StateProofError; +use reth_primitives_traits::Account; +use reth_storage_errors::db::DatabaseError; +use reth_trie::{ + proof_v2::{DeferredValueEncoder, LeafValueEncoder, Target}, + ProofTrieNode, +}; +use std::{rc::Rc, sync::Arc}; + +/// Returned from [`AsyncAccountValueEncoder`], used to track an async storage root calculation. +pub(crate) enum AsyncAccountDeferredValueEncoder { + Dispatched { + hashed_address: B256, + account: Account, + proof_result_rx: Result, DatabaseError>, + // None if results shouldn't be retained for this dispatched proof. + storage_proof_results: Option>>>>, + }, + FromCache { + account: Account, + root: B256, + }, +} + +impl DeferredValueEncoder for AsyncAccountDeferredValueEncoder { + fn encode(self, buf: &mut Vec) -> Result<(), StateProofError> { + let (account, root) = match self { + Self::Dispatched { + hashed_address, + account, + proof_result_rx, + storage_proof_results, + } => { + let result = proof_result_rx? + .recv() + .map_err(|_| { + StateProofError::Database(DatabaseError::Other(format!( + "Storage proof channel closed for {hashed_address:?}", + ))) + })? + .result?; + + let StorageProofResult::V2 { root: Some(root), proof } = result else { + panic!("StorageProofResult is not V2 with root: {result:?}") + }; + + if let Some(storage_proof_results) = storage_proof_results.as_ref() { + storage_proof_results.borrow_mut().insert(hashed_address, proof); + } + + (account, root) + } + Self::FromCache { account, root } => (account, root), + }; + + let account = account.into_trie_account(root); + account.encode(buf); + Ok(()) + } +} + +/// Implements the [`LeafValueEncoder`] trait for accounts using a [`CrossbeamSender`] to dispatch +/// and compute storage roots asynchronously. Can also accept a set of already dispatched account +/// storage proofs, for cases where it's possible to determine some necessary accounts ahead of +/// time. +pub(crate) struct AsyncAccountValueEncoder { + storage_work_tx: CrossbeamSender, + /// Storage proof jobs which were dispatched ahead of time. + dispatched: B256Map>, + /// Storage roots which have already been computed. This can be used only if a storage proof + /// wasn't dispatched for an account, otherwise we must consume the proof result. + cached_storage_roots: Arc>, + /// Tracks storage proof results received from the storage workers. [`Rc`] + [`RefCell`] is + /// required because [`DeferredValueEncoder`] cannot have a lifetime. + storage_proof_results: Rc>>>, +} + +impl AsyncAccountValueEncoder { + /// Initializes a [`Self`] using a `ProofWorkerHandle` which will be used to calculate storage + /// roots asynchronously. + #[expect(dead_code)] + pub(crate) fn new( + storage_work_tx: CrossbeamSender, + dispatched: B256Map>, + cached_storage_roots: Arc>, + ) -> Self { + Self { + storage_work_tx, + dispatched, + cached_storage_roots, + storage_proof_results: Default::default(), + } + } + + /// Consume [`Self`] and return all collected storage proofs which had been dispatched. + /// + /// # Panics + /// + /// This method panics if any deferred encoders produced by [`Self::deferred_encoder`] have not + /// been dropped. + #[expect(dead_code)] + pub(crate) fn into_storage_proofs( + self, + ) -> Result>, StateProofError> { + let mut storage_proof_results = Rc::into_inner(self.storage_proof_results) + .expect("no deferred encoders are still allocated") + .into_inner(); + + // Any remaining dispatched proofs need to have their results collected + for (hashed_address, rx) in &self.dispatched { + let result = rx + .recv() + .map_err(|_| { + StateProofError::Database(DatabaseError::Other(format!( + "Storage proof channel closed for {hashed_address:?}", + ))) + })? + .result?; + + let StorageProofResult::V2 { proof, .. } = result else { + panic!("StorageProofResult is not V2: {result:?}") + }; + + storage_proof_results.insert(*hashed_address, proof); + } + + Ok(storage_proof_results) + } +} + +impl LeafValueEncoder for AsyncAccountValueEncoder { + type Value = Account; + type DeferredEncoder = AsyncAccountDeferredValueEncoder; + + fn deferred_encoder( + &mut self, + hashed_address: B256, + account: Self::Value, + ) -> Self::DeferredEncoder { + // If the proof job has already been dispatched for this account then it's not necessary to + // dispatch another. + if let Some(rx) = self.dispatched.remove(&hashed_address) { + return AsyncAccountDeferredValueEncoder::Dispatched { + hashed_address, + account, + proof_result_rx: Ok(rx), + storage_proof_results: Some(self.storage_proof_results.clone()), + } + } + + // If the address didn't have a job dispatched for it then we can assume it has no targets, + // and we only need its root. + + // If the root is already calculated then just use it directly + if let Some(root) = self.cached_storage_roots.get(&hashed_address) { + return AsyncAccountDeferredValueEncoder::FromCache { account, root: *root } + } + + // Create a proof input which targets a bogus key, so that we calculate the root as a + // side-effect. + let input = StorageProofInput::new(hashed_address, vec![Target::new(B256::ZERO)]); + let (tx, rx) = crossbeam_channel::bounded(1); + + let proof_result_rx = self + .storage_work_tx + .send(StorageWorkerJob::StorageProof { input, proof_result_sender: tx }) + .map_err(|_| DatabaseError::Other("storage workers unavailable".to_string())) + .map(|_| rx); + + AsyncAccountDeferredValueEncoder::Dispatched { + hashed_address, + account, + proof_result_rx, + storage_proof_results: None, + } + } +} diff --git a/crates/trie/trie/src/proof_v2/mod.rs b/crates/trie/trie/src/proof_v2/mod.rs index f421ba7bb8..8861def8a5 100644 --- a/crates/trie/trie/src/proof_v2/mod.rs +++ b/crates/trie/trie/src/proof_v2/mod.rs @@ -651,7 +651,7 @@ where )] fn calculate_key_range<'a>( &mut self, - value_encoder: &VE, + value_encoder: &mut VE, targets: &mut TargetsCursor<'a>, hashed_cursor_current: &mut Option<(Nibbles, VE::DeferredEncoder)>, lower_bound: Nibbles, @@ -660,7 +660,7 @@ where // A helper closure for mapping entries returned from the `hashed_cursor`, converting the // key to Nibbles and immediately creating the DeferredValueEncoder so that encoding of the // leaf value can begin ASAP. - let map_hashed_cursor_entry = |(key_b256, val): (B256, _)| { + let mut map_hashed_cursor_entry = |(key_b256, val): (B256, _)| { debug_assert_eq!(key_b256.len(), 32); // SAFETY: key is a B256 and so is exactly 32-bytes. let key = unsafe { Nibbles::unpack_unchecked(key_b256.as_slice()) }; @@ -679,7 +679,7 @@ where let lower_key = B256::right_padding_from(&lower_bound.pack()); *hashed_cursor_current = - self.hashed_cursor.seek(lower_key)?.map(map_hashed_cursor_entry); + self.hashed_cursor.seek(lower_key)?.map(&mut map_hashed_cursor_entry); } // Loop over all keys in the range, calling `push_leaf` on each. @@ -689,7 +689,7 @@ where let (key, val) = core::mem::take(hashed_cursor_current).expect("while-let checks for Some"); self.push_leaf(targets, key, val)?; - *hashed_cursor_current = self.hashed_cursor.next()?.map(map_hashed_cursor_entry); + *hashed_cursor_current = self.hashed_cursor.next()?.map(&mut map_hashed_cursor_entry); } trace!(target: TRACE_TARGET, "No further keys within range"); @@ -1125,7 +1125,7 @@ where )] fn proof_subtrie<'a>( &mut self, - value_encoder: &VE, + value_encoder: &mut VE, trie_cursor_state: &mut TrieCursorState, hashed_cursor_current: &mut Option<(Nibbles, VE::DeferredEncoder)>, sub_trie_targets: SubTrieTargets<'a>, @@ -1254,7 +1254,7 @@ where /// See docs on [`Self::proof`] for expected behavior. fn proof_inner( &mut self, - value_encoder: &VE, + value_encoder: &mut VE, targets: &mut [Target], ) -> Result, StateProofError> { // If there are no targets then nothing could be returned, return early. @@ -1305,7 +1305,7 @@ where #[instrument(target = TRACE_TARGET, level = "trace", skip_all)] pub fn proof( &mut self, - value_encoder: &VE, + value_encoder: &mut VE, targets: &mut [Target], ) -> Result, StateProofError> { self.trie_cursor.reset(); @@ -1341,9 +1341,6 @@ where hashed_address: B256, targets: &mut [Target], ) -> Result, StateProofError> { - /// Static storage value encoder instance used by all storage proofs. - static STORAGE_VALUE_ENCODER: StorageValueEncoder = StorageValueEncoder; - self.hashed_cursor.set_hashed_address(hashed_address); // Shortcut: check if storage is empty @@ -1360,8 +1357,9 @@ where // been checked. self.trie_cursor.set_hashed_address(hashed_address); - // Use the static StorageValueEncoder and pass it to proof_inner - self.proof_inner(&STORAGE_VALUE_ENCODER, targets) + // Create a mutable storage value encoder + let mut storage_value_encoder = StorageValueEncoder; + self.proof_inner(&mut storage_value_encoder, targets) } /// Computes the root hash from a set of proof nodes. @@ -1639,13 +1637,13 @@ mod tests { InstrumentedHashedCursor::new(hashed_cursor, &mut hashed_cursor_metrics); // Call ProofCalculator::proof with account targets - let value_encoder = SyncAccountValueEncoder::new( + let mut value_encoder = SyncAccountValueEncoder::new( self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone(), ); let mut proof_calculator = ProofCalculator::new(trie_cursor, hashed_cursor); let proof_v2_result = - proof_calculator.proof(&value_encoder, &mut targets_vec.clone())?; + proof_calculator.proof(&mut value_encoder, &mut targets_vec.clone())?; // Output metrics trace!(target: TRACE_TARGET, ?trie_cursor_metrics, "V2 trie cursor metrics"); diff --git a/crates/trie/trie/src/proof_v2/value.rs b/crates/trie/trie/src/proof_v2/value.rs index dd330d9a87..2b7b085119 100644 --- a/crates/trie/trie/src/proof_v2/value.rs +++ b/crates/trie/trie/src/proof_v2/value.rs @@ -44,7 +44,7 @@ pub trait LeafValueEncoder { /// /// The returned deferred encoder will be called as late as possible in the algorithm to /// maximize the time available for parallel computation (e.g., storage root calculation). - fn deferred_encoder(&self, key: B256, value: Self::Value) -> Self::DeferredEncoder; + fn deferred_encoder(&mut self, key: B256, value: Self::Value) -> Self::DeferredEncoder; } /// An encoder for storage slot values. @@ -68,7 +68,7 @@ impl LeafValueEncoder for StorageValueEncoder { type Value = U256; type DeferredEncoder = StorageDeferredValueEncoder; - fn deferred_encoder(&self, _key: B256, value: Self::Value) -> Self::DeferredEncoder { + fn deferred_encoder(&mut self, _key: B256, value: Self::Value) -> Self::DeferredEncoder { StorageDeferredValueEncoder(value) } } @@ -157,7 +157,7 @@ where type DeferredEncoder = SyncAccountDeferredValueEncoder; fn deferred_encoder( - &self, + &mut self, hashed_address: B256, account: Self::Value, ) -> Self::DeferredEncoder {