use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError, StateCommitmentProvider,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted, MultiProof,
    MultiProofTargets, Nibbles, StorageMultiProof, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;

/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
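///
/// A minimal usage sketch (hypothetical `view`, `task_handle`, and overlay values; assumes the
/// `ProofTaskManager` behind `task_handle` is already running so queued storage proofs make
/// progress):
///
/// ```ignore
/// let proof = ParallelProof::new(view, nodes_sorted, state_sorted, prefix_sets, task_handle)
///     .with_branch_node_masks(true)
///     .multiproof(targets)?;
/// ```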
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that
    /// can be reused for computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    collect_branch_node_masks: bool,
    /// Handle to the storage proof task.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
    /// Create new state proof generator.
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            collect_branch_node_masks: false,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

    /// Set the flag indicating whether to include branch node masks in the proof.
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory:
        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
    /// Spawns a storage proof on the storage proof task and returns a receiver for the result.
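    ///
    /// The result is delivered over a `std::sync::mpsc` channel; if the proof task shuts down
    /// before responding, the channel closes and the returned receiver yields an error on
    /// `recv`.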
    fn spawn_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
        );

        let (sender, receiver) = std::sync::mpsc::channel();
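        // The send result is intentionally ignored: if queueing fails, `sender` is dropped and
        // callers observe the closed channel as an error on `recv`.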
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        receiver
    }

    /// Generate a storage multiproof according to the specified targets and hashed address.
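    ///
    /// Consumes `self` and blocks on the receiver until the spawned storage proof task responds.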
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<StorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }

    /// Generate a [`DecodedStorageMultiProof`] for the given proof by first calling
    /// `storage_proof`, then decoding the proof nodes.
    pub fn decoded_storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        let proof = self.storage_proof(hashed_address, target_slots)?;

        // Now decode the nodes of the proof
        let proof = proof.try_into()?;

        Ok(proof)
    }

    /// Generate a state multiproof according to specified targets.
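    ///
    /// A minimal sketch of assembling targets (hypothetical `hashed_address` and `target_slots`
    /// values, mirroring the test at the bottom of this file):
    ///
    /// ```ignore
    /// let mut targets = MultiProofTargets::default();
    /// targets.insert(hashed_address, target_slots);
    /// let multiproof = parallel_proof.multiproof(targets)?;
    /// ```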
    pub fn multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<MultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

        // Extend prefix sets with targets
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        // Pre-calculate storage roots for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

        // stores the receiver for the storage proof outcome for the hashed addresses
        // this way we can lazily await the outcome when we iterate over the map
        let mut storage_proofs =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);

            // store the receiver for that result with the hashed address so we can await this in
            // place when we iterate over the trie
            storage_proofs.insert(hashed_address, receiver);
        }

        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        // Create the walker.
        let walker = TrieWalker::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(true);

        // Create a hash builder to rebuild the root node since it is not available in the database.
        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

        // Initialize all storage multiproofs as empty.
        // Storage multiproofs for non empty tries will be overwritten if necessary.
        let mut storages: B256Map<_> =
            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, there might
                        // be a possibility of re-adding a non-modified leaf to the hash builder.
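                        // In that case, compute the storage proof inline on the current thread,
                        // reusing the same trie and hashed cursor factories.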
                        None => {
                            tracker.inc_missed_leaves();
                            StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?
                        }
                    };

                    // Encode account
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);

                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);

                    // We might be adding leaves that are not necessarily our proof targets.
                    if targets.contains_key(&hashed_address) {
                        storages.insert(hashed_address, storage_multiproof);
                    }
                }
            }
        }
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree = hash_builder.take_proof_nodes();
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes
                    .iter()
                    .map(|(path, node)| (path.clone(), node.hash_mask))
                    .collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated proof"
        );

        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
    }

    /// Returns a [`DecodedMultiProof`] for the given proof.
    ///
    /// Uses `multiproof` first to get the proof, and then decodes the nodes of the multiproof.
    pub fn decoded_multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
        let multiproof = self.multiproof(targets)?;

        // Now decode the nodes of the multiproof
        let multiproof = multiproof.try_into()?;

        Ok(multiproof)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        // keep the join handle around to make sure it does not return any errors
        // after we compute the state root
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .multiproof(targets.clone())
        .unwrap();

        let sequential_result =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        // to help narrow down what is wrong - first compare account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof = sequential_result.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result);

        // drop the handle to terminate the task and then block on the proof task handle to make
        // sure it does not return any errors
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}