feat(trie): add AsyncAccountValueEncoder for V2 proof computation (#21197)

Co-authored-by: Amp <amp@ampcode.com>
Author: Brian Picciano
Date: 2026-01-20 14:50:29 +01:00
Committed by: GitHub
Parent: ea3d4663ae
Commit: 346cc0da71
6 changed files with 217 additions and 37 deletions
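The idea behind the new encoder: an account leaf cannot be RLP-encoded until its storage root is known, so the encoder dispatches each storage-root computation to the worker pool up front and hands back a deferred encoder that only blocks on the result at the moment the leaf value is actually written. Below is a minimal self-contained sketch of that pattern, using std::sync::mpsc and a plain thread in place of crossbeam_channel and reth's storage workers, and placeholder types (RootJob, AsyncAccountEncoder, DeferredAccountEncoder) instead of the real ones; it is illustrative only, not the code added in this commit.

use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Placeholder for a storage-root request handled by a worker pool.
struct RootJob {
    address: u64,
    reply: Sender<[u8; 32]>,
}

// Deferred encoder: holds everything needed to finish encoding later.
struct DeferredAccountEncoder {
    nonce: u64,
    reply_rx: Receiver<[u8; 32]>,
}

impl DeferredAccountEncoder {
    // Block on the storage root only at the last possible moment.
    fn encode(self, buf: &mut Vec<u8>) {
        let root = self.reply_rx.recv().expect("worker dropped");
        buf.extend_from_slice(&self.nonce.to_be_bytes());
        buf.extend_from_slice(&root);
    }
}

// "Value encoder": dispatches the expensive work and returns immediately.
struct AsyncAccountEncoder {
    work_tx: Sender<RootJob>,
}

impl AsyncAccountEncoder {
    fn deferred_encoder(&mut self, address: u64, nonce: u64) -> DeferredAccountEncoder {
        let (reply_tx, reply_rx) = channel();
        self.work_tx
            .send(RootJob { address, reply: reply_tx })
            .expect("worker pool alive");
        DeferredAccountEncoder { nonce, reply_rx }
    }
}

fn main() {
    let (work_tx, work_rx) = channel::<RootJob>();
    // A single "storage worker" computing fake roots.
    let worker = thread::spawn(move || {
        for job in work_rx {
            let mut root = [0u8; 32];
            root[..8].copy_from_slice(&job.address.to_be_bytes());
            let _ = job.reply.send(root);
        }
    });

    let mut encoder = AsyncAccountEncoder { work_tx };
    // Dispatch both roots before blocking on either leaf encoding.
    let a = encoder.deferred_encoder(1, 10);
    let b = encoder.deferred_encoder(2, 20);

    let mut buf = Vec::new();
    a.encode(&mut buf);
    b.encode(&mut buf);
    println!("encoded {} bytes", buf.len());

    drop(encoder); // closes the work channel so the worker exits
    worker.join().unwrap();
}

Dispatching all jobs before blocking on any of them is what buys the parallelism; the real encoder additionally reuses pre-dispatched receivers and cached roots, as the new value_encoder.rs below shows.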


@@ -14,6 +14,7 @@ workspace = true
 [dependencies]
 # reth
 reth-execution-errors.workspace = true
+reth-primitives-traits.workspace = true
 reth-provider.workspace = true
 reth-storage-errors.workspace = true
 reth-trie-common.workspace = true


@@ -22,6 +22,9 @@ pub mod proof;
 pub mod proof_task;
+/// Async value encoder for V2 proofs.
+pub(crate) mod value_encoder;
 /// V2 multiproof targets and chunking.
 pub mod targets_v2;


@@ -41,7 +41,7 @@ use alloy_primitives::{
 use alloy_rlp::{BufMut, Encodable};
 use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
 use dashmap::DashMap;
-use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
+use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, StateProofError};
 use reth_provider::{DatabaseProviderROFactory, ProviderError, ProviderResult};
 use reth_storage_errors::db::DatabaseError;
 use reth_trie::{
@@ -305,17 +305,17 @@ impl ProofWorkerHandle {
 self.storage_work_tx
 .send(StorageWorkerJob::StorageProof { input, proof_result_sender })
 .map_err(|err| {
-let error =
-ProviderError::other(std::io::Error::other("storage workers unavailable"));
 if let StorageWorkerJob::StorageProof { proof_result_sender, .. } = err.0 {
 let _ = proof_result_sender.send(StorageProofResultMessage {
 hashed_address,
-result: Err(ParallelStateRootError::Provider(error.clone())),
+result: Err(DatabaseError::Other(
+"storage workers unavailable".to_string(),
+)
+.into()),
 });
 }
-error
+ProviderError::other(std::io::Error::other("storage workers unavailable"))
 })
 }
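The reworked map_err above leans on crossbeam's SendError handing back the unsent job, so the per-request result channel can still be told that no proof will arrive before the provider error is returned. A stripped-down sketch of the same recovery pattern, using std::sync::mpsc and a String error in place of reth's job and error types:

use std::sync::mpsc::{channel, Sender};

struct Job {
    reply: Sender<Result<u64, String>>,
}

// Try to queue a job; if the worker pool is gone, recover the job from the
// SendError and fail the caller's reply channel explicitly.
fn dispatch(work_tx: &Sender<Job>, reply: Sender<Result<u64, String>>) -> Result<(), String> {
    work_tx.send(Job { reply }).map_err(|err| {
        // SendError gives the unsent value back in field 0.
        let Job { reply } = err.0;
        let _ = reply.send(Err("storage workers unavailable".to_string()));
        "storage workers unavailable".to_string()
    })
}

fn main() {
    // Simulate a dead worker pool by dropping the receiver immediately.
    let (work_tx, work_rx) = channel::<Job>();
    drop(work_rx);

    let (reply_tx, reply_rx) = channel();
    let res = dispatch(&work_tx, reply_tx);
    assert!(res.is_err());
    // The waiting side still gets a definite answer instead of hanging.
    assert_eq!(reply_rx.recv().unwrap(), Err("storage workers unavailable".to_string()));
    println!("caller notified on dispatch failure");
}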
@@ -432,7 +432,7 @@ where
 input: StorageProofInput,
 trie_cursor_metrics: &mut TrieCursorMetricsCache,
 hashed_cursor_metrics: &mut HashedCursorMetricsCache,
-) -> Result<StorageProofResult, ParallelStateRootError> {
+) -> Result<StorageProofResult, StateProofError> {
 // Consume the input so we can move large collections (e.g. target slots) without cloning.
 let StorageProofInput::Legacy {
 hashed_address,
@@ -469,20 +469,13 @@ where
 .with_added_removed_keys(added_removed_keys)
 .with_trie_cursor_metrics(trie_cursor_metrics)
 .with_hashed_cursor_metrics(hashed_cursor_metrics)
-.storage_multiproof(target_slots)
-.map_err(|e| ParallelStateRootError::Other(e.to_string()));
+.storage_multiproof(target_slots);
 trie_cursor_metrics.record_span("trie_cursor");
 hashed_cursor_metrics.record_span("hashed_cursor");
 // Decode proof into DecodedStorageMultiProof
-let decoded_result = raw_proof_result.and_then(|raw_proof| {
-raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
-ParallelStateRootError::Other(format!(
-"Failed to decode storage proof for {}: {}",
-hashed_address, e
-))
-})
-})?;
+let decoded_result =
+raw_proof_result.and_then(|raw_proof| raw_proof.try_into().map_err(Into::into))?;
 trace!(
 target: "trie::proof_task",
@@ -502,7 +495,7 @@ where
 <Provider as TrieCursorFactory>::StorageTrieCursor<'_>,
 <Provider as HashedCursorFactory>::StorageCursor<'_>,
 >,
-) -> Result<StorageProofResult, ParallelStateRootError> {
+) -> Result<StorageProofResult, StateProofError> {
 let StorageProofInput::V2 { hashed_address, mut targets } = input else {
 panic!("compute_v2_storage_proof only accepts StorageProofInput::V2")
 };
@@ -717,12 +710,12 @@ pub struct StorageProofResultMessage {
 /// The hashed address this storage proof belongs to
 pub(crate) hashed_address: B256,
 /// The storage proof calculation result
-pub(crate) result: Result<StorageProofResult, ParallelStateRootError>,
+pub(crate) result: Result<StorageProofResult, StateProofError>,
 }
 /// Internal message for storage workers.
 #[derive(Debug)]
-enum StorageWorkerJob {
+pub(crate) enum StorageWorkerJob {
 /// Storage proof computation request
 StorageProof {
 /// Storage proof input parameters
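The remaining hunks in this file swap the stringly-typed ParallelStateRootError::Other wrapping for the structured StateProofError, which is what lets `?` and `.map_err(Into::into)` perform the conversions. The mechanism is ordinary From impls on the error enum; here is a minimal std-only sketch with invented DbError and DecodeError stand-ins (StateProofError's real variants are not shown here):

#[derive(Debug)]
struct DbError(String);

#[derive(Debug)]
struct DecodeError(String);

// Structured error with one variant per failure source, mirroring how a
// StateProofError-style enum absorbs lower-level errors.
#[derive(Debug)]
enum ProofError {
    Database(DbError),
    Decode(DecodeError),
}

impl From<DbError> for ProofError {
    fn from(e: DbError) -> Self {
        ProofError::Database(e)
    }
}

impl From<DecodeError> for ProofError {
    fn from(e: DecodeError) -> Self {
        ProofError::Decode(e)
    }
}

fn read_raw() -> Result<Vec<u8>, DbError> {
    Ok(vec![0x01])
}

fn decode(raw: Vec<u8>) -> Result<u64, DecodeError> {
    raw.first().copied().map(u64::from).ok_or_else(|| DecodeError("empty".into()))
}

// `?` converts DbError, and map_err(Into::into) converts DecodeError,
// without any format!-based string wrapping at the call site.
fn proof() -> Result<u64, ProofError> {
    let raw = read_raw()?;
    decode(raw).map_err(Into::into)
}

fn main() {
    println!("{:?}", proof());
}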


@@ -0,0 +1,185 @@
use crate::proof_task::{
StorageProofInput, StorageProofResult, StorageProofResultMessage, StorageWorkerJob,
};
use alloy_primitives::{map::B256Map, B256};
use alloy_rlp::Encodable;
use core::cell::RefCell;
use crossbeam_channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use dashmap::DashMap;
use reth_execution_errors::trie::StateProofError;
use reth_primitives_traits::Account;
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
proof_v2::{DeferredValueEncoder, LeafValueEncoder, Target},
ProofTrieNode,
};
use std::{rc::Rc, sync::Arc};
/// Returned from [`AsyncAccountValueEncoder`], used to track an async storage root calculation.
pub(crate) enum AsyncAccountDeferredValueEncoder {
Dispatched {
hashed_address: B256,
account: Account,
proof_result_rx: Result<CrossbeamReceiver<StorageProofResultMessage>, DatabaseError>,
// None if results shouldn't be retained for this dispatched proof.
storage_proof_results: Option<Rc<RefCell<B256Map<Vec<ProofTrieNode>>>>>,
},
FromCache {
account: Account,
root: B256,
},
}
impl DeferredValueEncoder for AsyncAccountDeferredValueEncoder {
fn encode(self, buf: &mut Vec<u8>) -> Result<(), StateProofError> {
let (account, root) = match self {
Self::Dispatched {
hashed_address,
account,
proof_result_rx,
storage_proof_results,
} => {
let result = proof_result_rx?
.recv()
.map_err(|_| {
StateProofError::Database(DatabaseError::Other(format!(
"Storage proof channel closed for {hashed_address:?}",
)))
})?
.result?;
let StorageProofResult::V2 { root: Some(root), proof } = result else {
panic!("StorageProofResult is not V2 with root: {result:?}")
};
if let Some(storage_proof_results) = storage_proof_results.as_ref() {
storage_proof_results.borrow_mut().insert(hashed_address, proof);
}
(account, root)
}
Self::FromCache { account, root } => (account, root),
};
let account = account.into_trie_account(root);
account.encode(buf);
Ok(())
}
}
/// Implements the [`LeafValueEncoder`] trait for accounts using a [`CrossbeamSender`] to dispatch
/// and compute storage roots asynchronously. Can also accept a set of already dispatched account
/// storage proofs, for cases where it's possible to determine some necessary accounts ahead of
/// time.
pub(crate) struct AsyncAccountValueEncoder {
storage_work_tx: CrossbeamSender<StorageWorkerJob>,
/// Storage proof jobs which were dispatched ahead of time.
dispatched: B256Map<CrossbeamReceiver<StorageProofResultMessage>>,
/// Storage roots which have already been computed. This can be used only if a storage proof
/// wasn't dispatched for an account, otherwise we must consume the proof result.
cached_storage_roots: Arc<DashMap<B256, B256>>,
/// Tracks storage proof results received from the storage workers. [`Rc`] + [`RefCell`] is
/// required because [`DeferredValueEncoder`] cannot have a lifetime.
storage_proof_results: Rc<RefCell<B256Map<Vec<ProofTrieNode>>>>,
}
impl AsyncAccountValueEncoder {
/// Initializes a [`Self`] using a `ProofWorkerHandle` which will be used to calculate storage
/// roots asynchronously.
#[expect(dead_code)]
pub(crate) fn new(
storage_work_tx: CrossbeamSender<StorageWorkerJob>,
dispatched: B256Map<CrossbeamReceiver<StorageProofResultMessage>>,
cached_storage_roots: Arc<DashMap<B256, B256>>,
) -> Self {
Self {
storage_work_tx,
dispatched,
cached_storage_roots,
storage_proof_results: Default::default(),
}
}
/// Consume [`Self`] and return all collected storage proofs which had been dispatched.
///
/// # Panics
///
/// This method panics if any deferred encoders produced by [`Self::deferred_encoder`] have not
/// been dropped.
#[expect(dead_code)]
pub(crate) fn into_storage_proofs(
self,
) -> Result<B256Map<Vec<ProofTrieNode>>, StateProofError> {
let mut storage_proof_results = Rc::into_inner(self.storage_proof_results)
.expect("no deferred encoders are still allocated")
.into_inner();
// Any remaining dispatched proofs need to have their results collected
for (hashed_address, rx) in &self.dispatched {
let result = rx
.recv()
.map_err(|_| {
StateProofError::Database(DatabaseError::Other(format!(
"Storage proof channel closed for {hashed_address:?}",
)))
})?
.result?;
let StorageProofResult::V2 { proof, .. } = result else {
panic!("StorageProofResult is not V2: {result:?}")
};
storage_proof_results.insert(*hashed_address, proof);
}
Ok(storage_proof_results)
}
}
impl LeafValueEncoder for AsyncAccountValueEncoder {
type Value = Account;
type DeferredEncoder = AsyncAccountDeferredValueEncoder;
fn deferred_encoder(
&mut self,
hashed_address: B256,
account: Self::Value,
) -> Self::DeferredEncoder {
// If the proof job has already been dispatched for this account then it's not necessary to
// dispatch another.
if let Some(rx) = self.dispatched.remove(&hashed_address) {
return AsyncAccountDeferredValueEncoder::Dispatched {
hashed_address,
account,
proof_result_rx: Ok(rx),
storage_proof_results: Some(self.storage_proof_results.clone()),
}
}
// If the address didn't have a job dispatched for it then we can assume it has no targets,
// and we only need its root.
// If the root is already calculated then just use it directly
if let Some(root) = self.cached_storage_roots.get(&hashed_address) {
return AsyncAccountDeferredValueEncoder::FromCache { account, root: *root }
}
// Create a proof input which targets a bogus key, so that we calculate the root as a
// side-effect.
let input = StorageProofInput::new(hashed_address, vec![Target::new(B256::ZERO)]);
let (tx, rx) = crossbeam_channel::bounded(1);
let proof_result_rx = self
.storage_work_tx
.send(StorageWorkerJob::StorageProof { input, proof_result_sender: tx })
.map_err(|_| DatabaseError::Other("storage workers unavailable".to_string()))
.map(|_| rx);
AsyncAccountDeferredValueEncoder::Dispatched {
hashed_address,
account,
proof_result_rx,
storage_proof_results: None,
}
}
}
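Two details in the new file are easy to miss. First, deferred_encoder resolves an account in one of three ways: reuse a receiver that was dispatched ahead of time, reuse a cached storage root, or dispatch a fresh proof against a dummy target purely to obtain the root. Second, collected proofs are shared between the encoder and its deferred encoders through Rc<RefCell<..>>, so into_storage_proofs can only reclaim the map with Rc::into_inner once every deferred encoder has been dropped, which is the documented panic condition. A small standalone sketch of that second mechanism, with a plain HashMap standing in for B256Map<Vec<ProofTrieNode>>:

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

type Shared = Rc<RefCell<HashMap<u64, Vec<u8>>>>;

struct Deferred {
    key: u64,
    results: Shared,
}

impl Deferred {
    fn finish(self, proof: Vec<u8>) {
        // Record the proof for later collection, then drop our Rc clone.
        self.results.borrow_mut().insert(self.key, proof);
    }
}

struct Collector {
    results: Shared,
}

impl Collector {
    fn deferred(&mut self, key: u64) -> Deferred {
        Deferred { key, results: Rc::clone(&self.results) }
    }

    // Succeeds only once all Deferred handles have been dropped; otherwise
    // Rc::into_inner returns None, matching the panic documented upstream.
    fn into_results(self) -> HashMap<u64, Vec<u8>> {
        Rc::into_inner(self.results)
            .expect("no deferred encoders are still allocated")
            .into_inner()
    }
}

fn main() {
    let mut collector = Collector { results: Rc::default() };
    let d1 = collector.deferred(1);
    let d2 = collector.deferred(2);
    d1.finish(vec![0xaa]);
    d2.finish(vec![0xbb]);
    let results = collector.into_results();
    println!("collected {} proofs", results.len());
}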


@@ -651,7 +651,7 @@ where
 )]
 fn calculate_key_range<'a>(
 &mut self,
-value_encoder: &VE,
+value_encoder: &mut VE,
 targets: &mut TargetsCursor<'a>,
 hashed_cursor_current: &mut Option<(Nibbles, VE::DeferredEncoder)>,
 lower_bound: Nibbles,
@@ -660,7 +660,7 @@ where
 // A helper closure for mapping entries returned from the `hashed_cursor`, converting the
 // key to Nibbles and immediately creating the DeferredValueEncoder so that encoding of the
 // leaf value can begin ASAP.
-let map_hashed_cursor_entry = |(key_b256, val): (B256, _)| {
+let mut map_hashed_cursor_entry = |(key_b256, val): (B256, _)| {
 debug_assert_eq!(key_b256.len(), 32);
 // SAFETY: key is a B256 and so is exactly 32-bytes.
 let key = unsafe { Nibbles::unpack_unchecked(key_b256.as_slice()) };
@@ -679,7 +679,7 @@ where
 let lower_key = B256::right_padding_from(&lower_bound.pack());
 *hashed_cursor_current =
-self.hashed_cursor.seek(lower_key)?.map(map_hashed_cursor_entry);
+self.hashed_cursor.seek(lower_key)?.map(&mut map_hashed_cursor_entry);
 }
 // Loop over all keys in the range, calling `push_leaf` on each.
@@ -689,7 +689,7 @@ where
 let (key, val) =
 core::mem::take(hashed_cursor_current).expect("while-let checks for Some");
 self.push_leaf(targets, key, val)?;
-*hashed_cursor_current = self.hashed_cursor.next()?.map(map_hashed_cursor_entry);
+*hashed_cursor_current = self.hashed_cursor.next()?.map(&mut map_hashed_cursor_entry);
 }
 trace!(target: TRACE_TARGET, "No further keys within range");
@@ -1125,7 +1125,7 @@ where
 )]
 fn proof_subtrie<'a>(
 &mut self,
-value_encoder: &VE,
+value_encoder: &mut VE,
 trie_cursor_state: &mut TrieCursorState,
 hashed_cursor_current: &mut Option<(Nibbles, VE::DeferredEncoder)>,
 sub_trie_targets: SubTrieTargets<'a>,
@@ -1254,7 +1254,7 @@ where
 /// See docs on [`Self::proof`] for expected behavior.
 fn proof_inner(
 &mut self,
-value_encoder: &VE,
+value_encoder: &mut VE,
 targets: &mut [Target],
 ) -> Result<Vec<ProofTrieNode>, StateProofError> {
 // If there are no targets then nothing could be returned, return early.
@@ -1305,7 +1305,7 @@ where
 #[instrument(target = TRACE_TARGET, level = "trace", skip_all)]
 pub fn proof(
 &mut self,
-value_encoder: &VE,
+value_encoder: &mut VE,
 targets: &mut [Target],
 ) -> Result<Vec<ProofTrieNode>, StateProofError> {
 self.trie_cursor.reset();
@@ -1341,9 +1341,6 @@ where
 hashed_address: B256,
 targets: &mut [Target],
 ) -> Result<Vec<ProofTrieNode>, StateProofError> {
-/// Static storage value encoder instance used by all storage proofs.
-static STORAGE_VALUE_ENCODER: StorageValueEncoder = StorageValueEncoder;
 self.hashed_cursor.set_hashed_address(hashed_address);
 // Shortcut: check if storage is empty
@@ -1360,8 +1357,9 @@ where
 // been checked.
 self.trie_cursor.set_hashed_address(hashed_address);
-// Use the static StorageValueEncoder and pass it to proof_inner
-self.proof_inner(&STORAGE_VALUE_ENCODER, targets)
+// Create a mutable storage value encoder
+let mut storage_value_encoder = StorageValueEncoder;
+self.proof_inner(&mut storage_value_encoder, targets)
 }
 /// Computes the root hash from a set of proof nodes.
@@ -1639,13 +1637,13 @@ mod tests {
 InstrumentedHashedCursor::new(hashed_cursor, &mut hashed_cursor_metrics);
 // Call ProofCalculator::proof with account targets
-let value_encoder = SyncAccountValueEncoder::new(
+let mut value_encoder = SyncAccountValueEncoder::new(
 self.trie_cursor_factory.clone(),
 self.hashed_cursor_factory.clone(),
 );
 let mut proof_calculator = ProofCalculator::new(trie_cursor, hashed_cursor);
 let proof_v2_result =
-proof_calculator.proof(&value_encoder, &mut targets_vec.clone())?;
+proof_calculator.proof(&mut value_encoder, &mut targets_vec.clone())?;
 // Output metrics
 trace!(target: TRACE_TARGET, ?trie_cursor_metrics, "V2 trie cursor metrics");
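The &mut changes in this file all follow from deferred_encoder now taking &mut self: map_hashed_cursor_entry mutably borrows the value encoder, which makes it an FnMut closure that must be bound with `let mut` and passed as `&mut` so it can be reused at both cursor call sites. A tiny standalone illustration of that borrow-checker interaction (not reth code):

// A closure that calls a &mut self method captures the receiver mutably, so
// it is FnMut rather than Fn. Passing `&mut closure` lets the same closure be
// used at several call sites without moving it, because `&mut F` implements
// FnMut whenever `F` does.
struct Encoder {
    calls: usize,
}

impl Encoder {
    fn deferred(&mut self, key: u32) -> u32 {
        self.calls += 1;
        key * 2
    }
}

fn main() {
    let mut encoder = Encoder { calls: 0 };
    let mut map_entry = |key: u32| encoder.deferred(key);

    let first = Some(1u32).map(&mut map_entry);
    let second = Some(2u32).map(&mut map_entry);

    // The closure's borrow of `encoder` ends after its last use above.
    println!("{first:?} {second:?} calls={}", encoder.calls);
}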


@@ -44,7 +44,7 @@ pub trait LeafValueEncoder {
 ///
 /// The returned deferred encoder will be called as late as possible in the algorithm to
 /// maximize the time available for parallel computation (e.g., storage root calculation).
-fn deferred_encoder(&self, key: B256, value: Self::Value) -> Self::DeferredEncoder;
+fn deferred_encoder(&mut self, key: B256, value: Self::Value) -> Self::DeferredEncoder;
 }
 /// An encoder for storage slot values.
@@ -68,7 +68,7 @@ impl LeafValueEncoder for StorageValueEncoder {
 type Value = U256;
 type DeferredEncoder = StorageDeferredValueEncoder;
-fn deferred_encoder(&self, _key: B256, value: Self::Value) -> Self::DeferredEncoder {
+fn deferred_encoder(&mut self, _key: B256, value: Self::Value) -> Self::DeferredEncoder {
 StorageDeferredValueEncoder(value)
 }
 }
@@ -157,7 +157,7 @@ where
 type DeferredEncoder = SyncAccountDeferredValueEncoder<T, H>;
 fn deferred_encoder(
-&self,
+&mut self,
 hashed_address: B256,
 account: Self::Value,
 ) -> Self::DeferredEncoder {
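After this change the trait reads roughly as sketched below: deferred_encoder may mutate the encoder, for example to consume a pre-dispatched receiver, and the returned value is only turned into bytes later, when the proof walk actually needs the leaf. This is a paraphrase with simplified placeholder types, not the actual trait definition:

// Simplified shape of the value-encoder API (placeholder types, not reth's).
trait DeferredEncode {
    fn encode(self, buf: &mut Vec<u8>) -> Result<(), String>;
}

trait ValueEncoder {
    type Value;
    type Deferred: DeferredEncode;

    // &mut self so stateful encoders can record or consume per-key state
    // while handing out the deferred encoder.
    fn deferred_encoder(&mut self, key: [u8; 32], value: Self::Value) -> Self::Deferred;
}

// Trivial implementation: storage slots need no async work, so the
// deferred encoder just carries the value.
struct SlotEncoder;

struct SlotDeferred(u64);

impl DeferredEncode for SlotDeferred {
    fn encode(self, buf: &mut Vec<u8>) -> Result<(), String> {
        buf.extend_from_slice(&self.0.to_be_bytes());
        Ok(())
    }
}

impl ValueEncoder for SlotEncoder {
    type Value = u64;
    type Deferred = SlotDeferred;

    fn deferred_encoder(&mut self, _key: [u8; 32], value: u64) -> SlotDeferred {
        SlotDeferred(value)
    }
}

fn main() {
    let mut enc = SlotEncoder;
    let deferred = enc.deferred_encoder([0u8; 32], 7);
    let mut buf = Vec::new();
    deferred.encode(&mut buf).unwrap();
    println!("{buf:?}");
}

The stateless StorageValueEncoder impl above shows why the old &self signature used to be enough: only the new async account encoder needs to mutate itself per call.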