diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 9913dcfa87..e8be9ad9ba 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -6,7 +6,7 @@ use crate::{ StorageRootTargets, }; use alloy_primitives::{ - map::{B256Map, HashMap}, + map::{B256Map, B256Set, HashMap}, B256, }; use alloy_rlp::{BufMut, Encodable}; @@ -20,7 +20,7 @@ use reth_storage_errors::db::DatabaseError; use reth_trie::{ hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory}, node_iter::{TrieElement, TrieNodeIter}, - prefix_set::{PrefixSetMut, TriePrefixSetsMut}, + prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut}, proof::StorageProof, trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory}, updates::TrieUpdatesSorted, @@ -30,7 +30,7 @@ use reth_trie::{ }; use reth_trie_common::proof::ProofRetainer; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; -use std::sync::Arc; +use std::sync::{mpsc::Receiver, Arc}; use tracing::debug; /// Parallel proof calculator. @@ -91,6 +91,60 @@ where Factory: DatabaseProviderFactory + StateCommitmentProvider + Clone + 'static, { + /// Spawns a storage proof on the storage proof task and returns a receiver for the result. + fn spawn_storage_proof( + &self, + hashed_address: B256, + prefix_set: PrefixSet, + target_slots: B256Set, + ) -> Receiver> { + let input = StorageProofInput::new( + hashed_address, + prefix_set, + target_slots, + self.collect_branch_node_masks, + ); + + let (sender, receiver) = std::sync::mpsc::channel(); + let _ = + self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender)); + receiver + } + + /// Generate a storage multiproof according to the specified targets and hashed address. + pub fn storage_proof( + self, + hashed_address: B256, + target_slots: B256Set, + ) -> Result { + let total_targets = target_slots.len(); + let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack)); + let prefix_set = prefix_set.freeze(); + + debug!( + target: "trie::parallel_proof", + total_targets, + ?hashed_address, + "Starting storage proof generation" + ); + + let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots); + let proof_result = receiver.recv().map_err(|_| { + ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other( + format!("channel closed for {hashed_address}"), + ))) + })?; + + debug!( + target: "trie::parallel_proof", + total_targets, + ?hashed_address, + "Storage proof generation completed" + ); + + proof_result + } + /// Generate a state multiproof according to specified targets. pub fn multiproof( self, @@ -137,17 +191,7 @@ where storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) { let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); - - let input = StorageProofInput::new( - hashed_address, - prefix_set, - target_slots, - self.collect_branch_node_masks, - ); - let (sender, receiver) = std::sync::mpsc::channel(); - let _ = self - .storage_proof_task_handle - .queue_task(ProofTaskKind::StorageProof(input, sender)); + let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots); // store the receiver for that result with the hashed address so we can await this in // place when we iterate over the trie