fix: strengthen batch invariants and prevent blinded node starvation

- Change debug_assert to assert for the multi_added_removed_keys Arc
  identity check in BatchedStorageProof::merge, ensuring incorrect proofs
  are caught in release builds, not just debug builds (see the first
  sketch after this list)

- Change BatchedAccountProof::merge to try_merge returning Result,
  handling incompatible caches properly by processing them as separate
  batches instead of panicking (see the second sketch after this list)

- Add a MAX_DEFERRED_BLINDED_NODES (16) limit to prevent starvation of
  blinded node requests under high proof load; batching stops early once
  the limit is reached (see the third sketch after this list)

- Pre-allocate the deferred_blinded_nodes vectors with capacity
  MAX_DEFERRED_BLINDED_NODES (also shown in the third sketch)

- Remove an unnecessary clone of storage_work_tx by passing it by
  reference (see the sketch following the final diff hunk)
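
The first change is easiest to see in isolation: Arc::ptr_eq compares
allocations, not contents. A minimal standalone sketch of the identity check
that merge now enforces unconditionally (the Batch type and Vec<u8> payload
here are illustrative stand-ins, not the real reth types):

use std::sync::Arc;

// Illustrative stand-in for the batch type; only the shared-keys field matters.
struct Batch {
    multi_added_removed_keys: Option<Arc<Vec<u8>>>,
}

// Mirrors the match inside merge: identity of the allocation, not value equality.
fn caches_match(a: &Option<Arc<Vec<u8>>>, b: &Option<Arc<Vec<u8>>>) -> bool {
    match (a, b) {
        (Some(a), Some(b)) => Arc::ptr_eq(a, b),
        (None, None) => true,
        _ => false,
    }
}

fn main() {
    let keys = Arc::new(vec![1u8, 2, 3]);
    let batch = Batch { multi_added_removed_keys: Some(Arc::clone(&keys)) };

    // Same allocation: the job may join the batch.
    assert!(caches_match(&batch.multi_added_removed_keys, &Some(Arc::clone(&keys))));

    // Equal contents but a different allocation: must be rejected, since all
    // batched jobs are expected to observe one shared value.
    assert!(!caches_match(&batch.multi_added_removed_keys, &Some(Arc::new(vec![1u8, 2, 3]))));
}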
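
The try_merge pattern is likewise small enough to sketch on its own. Below,
Job and Batch are hypothetical placeholders for AccountMultiproofInput and
BatchedAccountProof: an incompatible job is handed back through Err and
becomes the seed of the next batch instead of aborting the worker:

struct Job {
    cache_id: u32,
}

struct Batch {
    cache_id: u32,
    jobs: Vec<Job>,
}

impl Batch {
    fn new(job: Job) -> Self {
        Self { cache_id: job.cache_id, jobs: vec![job] }
    }

    // Returns the job back on mismatch so the caller can re-batch it.
    fn try_merge(&mut self, job: Job) -> Result<(), Job> {
        if job.cache_id != self.cache_id {
            return Err(job);
        }
        self.jobs.push(job);
        Ok(())
    }
}

fn process_all(queue: Vec<Job>) {
    let mut iter = queue.into_iter();
    let mut next = iter.next();
    // Each rejected job seeds a fresh batch on the following iteration.
    while let Some(seed) = next.take() {
        let mut batch = Batch::new(seed);
        for job in iter.by_ref() {
            if let Err(incompatible) = batch.try_merge(job) {
                next = Some(incompatible);
                break;
            }
        }
        println!("processing batch of {} job(s)", batch.jobs.len());
    }
}

fn main() {
    // Two batches: [1, 1] and [2, 2].
    process_all(vec![
        Job { cache_id: 1 },
        Job { cache_id: 1 },
        Job { cache_id: 2 },
        Job { cache_id: 2 },
    ]);
}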
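
The deferral cap and the pre-allocation fit in one sketch. This uses
std::sync::mpsc in place of the crossbeam channel the workers actually use,
and plain integers in place of real jobs:

use std::sync::mpsc;

const MAX_DEFERRED_BLINDED_NODES: usize = 16;

enum Job {
    Proof(u64),
    BlindedNode(u64),
}

fn main() {
    let (tx, rx) = mpsc::channel();
    for i in 0..100u64 {
        let job = if i % 3 == 0 { Job::BlindedNode(i) } else { Job::Proof(i) };
        tx.send(job).unwrap();
    }

    let mut batched = Vec::new();
    // Pre-allocate: the loop below never defers more than the cap.
    let mut deferred = Vec::with_capacity(MAX_DEFERRED_BLINDED_NODES);

    // Drain the queue, batching proof jobs but capping how many blinded-node
    // requests may pile up behind the batch.
    while let Ok(job) = rx.try_recv() {
        match job {
            Job::Proof(id) => batched.push(id),
            Job::BlindedNode(id) => {
                deferred.push(id);
                // Stop batching early so deferred requests are served promptly.
                if deferred.len() >= MAX_DEFERRED_BLINDED_NODES {
                    break;
                }
            }
        }
    }

    println!("batched {} proofs, deferred {} blinded nodes", batched.len(), deferred.len());
}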

Author: yongkangc
Date:   2025-12-09 10:27:40 +00:00
parent 1a5d9a3ad3
commit 9bd5a3ecba


@@ -83,6 +83,11 @@ type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
 /// Maximum number of storage proof jobs to batch together per account.
 const STORAGE_PROOF_BATCH_LIMIT: usize = 32;
 
+/// Maximum number of blinded node requests to defer during storage proof batching.
+/// When this limit is reached, batching stops early to process deferred nodes,
+/// preventing starvation of blinded node requests under high proof load.
+const MAX_DEFERRED_BLINDED_NODES: usize = 16;
+
 /// Holds batched storage proof jobs for the same account.
 ///
 /// When multiple storage proof requests arrive for the same account, they can be merged
@@ -117,13 +122,16 @@ impl BatchedStorageProof {
     /// Merges another storage proof job into this batch.
     ///
-    /// # Panics (debug builds only)
+    /// # Panics
     /// Panics if `input.multi_added_removed_keys` does not point to the same Arc as the batch's.
+    /// This is a critical invariant for proof correctness.
     fn merge(&mut self, input: StorageProofInput, sender: ProofResultContext) {
         // Validate that all batched jobs share the same multi_added_removed_keys Arc.
         // This is a critical invariant: if jobs have different keys, the merged proof
         // would be computed with only the first job's keys, producing incorrect results.
-        debug_assert!(
+        // Using assert! (not debug_assert!) because incorrect proofs could cause consensus
+        // failures.
+        assert!(
             match (&self.multi_added_removed_keys, &input.multi_added_removed_keys) {
                 (Some(a), Some(b)) => Arc::ptr_eq(a, b),
                 (None, None) => true,
                 _ => false,
@@ -235,31 +243,24 @@ impl BatchedAccountProof {
         }
     }
 
-    /// Merges another account multiproof job into this batch.
+    /// Attempts to merge another account multiproof job into this batch.
     ///
-    /// # Panics (debug builds only)
-    /// Panics if `input.multi_added_removed_keys` or `input.missed_leaves_storage_roots`
-    /// do not point to the same Arc as the batch's.
-    fn merge(&mut self, input: AccountMultiproofInput) {
-        // Validate that all batched jobs share the same multi_added_removed_keys Arc.
-        // This is a critical invariant: if jobs have different keys, the merged proof
-        // would be computed with only the first job's keys, producing incorrect results.
-        debug_assert!(
-            match (&self.multi_added_removed_keys, &input.multi_added_removed_keys) {
+    /// Returns the job back if caches are incompatible so the caller can process it separately.
+    fn try_merge(&mut self, input: AccountMultiproofInput) -> Result<(), AccountMultiproofInput> {
+        // Require all jobs to share the same caches; otherwise merging would produce
+        // incorrect proofs by reusing the wrong retained keys or missed-leaf storage roots.
+        let multi_added_removed_keys_mismatch =
+            !match (&self.multi_added_removed_keys, &input.multi_added_removed_keys) {
                 (Some(a), Some(b)) => Arc::ptr_eq(a, b),
                 (None, None) => true,
                 _ => false,
-            },
-            "All batched account proof jobs must share the same multi_added_removed_keys Arc"
-        );
+            };
 
-        // Validate that all batched jobs share the same missed_leaves_storage_roots cache.
-        // This is critical because workers may add entries to this cache during proof computation,
-        // and all receivers expect to see those entries in their shared cache.
-        debug_assert!(
-            Arc::ptr_eq(&self.missed_leaves_storage_roots, &input.missed_leaves_storage_roots),
-            "All batched account proof jobs must share the same missed_leaves_storage_roots Arc"
-        );
+        if multi_added_removed_keys_mismatch ||
+            !Arc::ptr_eq(&self.missed_leaves_storage_roots, &input.missed_leaves_storage_roots)
+        {
+            return Err(input);
+        }
 
         // Merge targets.
         self.targets.extend(input.targets);
@@ -287,6 +288,8 @@ impl BatchedAccountProof {
 
         // Collect the sender.
         self.senders.push(input.proof_result_sender);
+
+        Ok(())
     }
 
     /// Converts this batch into a single `AccountMultiproofInput` for computation.
@@ -1032,8 +1035,9 @@ where
             available_workers.fetch_add(1, Ordering::Relaxed);
 
             // Deferred blinded node jobs to process after batched storage proofs.
+            // Pre-allocate with capacity to avoid reallocations during batching.
             let mut deferred_blinded_nodes: Vec<(B256, Nibbles, Sender<TrieNodeProviderResult>)> =
-                Vec::new();
+                Vec::with_capacity(MAX_DEFERRED_BLINDED_NODES);
 
             while let Ok(job) = work_rx.recv() {
                 // Mark worker as busy.
@@ -1077,6 +1081,11 @@ where
                                 }) => {
                                     // Defer blinded node jobs to process after batched proofs.
                                     deferred_blinded_nodes.push((account, path, result_sender));
+                                    // Stop batching if too many blinded nodes are deferred to prevent
+                                    // starvation.
+                                    if deferred_blinded_nodes.len() >= MAX_DEFERRED_BLINDED_NODES {
+                                        break;
+                                    }
                                 }
                                 Err(_) => break,
                             }
@@ -1419,7 +1428,9 @@ where
             available_workers.fetch_add(1, Ordering::Relaxed);
 
             // Deferred blinded node jobs to process after batched account proofs.
-            let mut deferred_blinded_nodes: Vec<(Nibbles, Sender<TrieNodeProviderResult>)> = Vec::new();
+            // Pre-allocate with capacity to avoid reallocations during batching.
+            let mut deferred_blinded_nodes: Vec<(Nibbles, Sender<TrieNodeProviderResult>)> =
+                Vec::with_capacity(MAX_DEFERRED_BLINDED_NODES);
 
             while let Ok(job) = work_rx.recv() {
                 // Mark worker as busy.
@@ -1427,48 +1438,77 @@ where
                 match job {
                     AccountWorkerJob::AccountMultiproof { input } => {
-                        // Start batching: accumulate account multiproof jobs.
-                        let mut batch = BatchedAccountProof::new(*input);
-                        let mut total_jobs = 1usize;
+                        // Start batching: accumulate account multiproof jobs. If we encounter an
+                        // incompatible job (different caches), process it as a separate batch.
+                        let mut next_account_job: Option<Box<AccountMultiproofInput>> = Some(input);
 
-                        // Drain additional jobs from the queue.
-                        while total_jobs < ACCOUNT_PROOF_BATCH_LIMIT {
-                            match work_rx.try_recv() {
-                                Ok(AccountWorkerJob::AccountMultiproof { input: next_input }) => {
-                                    total_jobs += 1;
-                                    batch.merge(*next_input);
-                                }
+                        while let Some(account_job) = next_account_job.take() {
+                            let mut batch = BatchedAccountProof::new(*account_job);
+                            let mut pending_incompatible: Option<Box<AccountMultiproofInput>> = None;
+
+                            // Drain additional jobs from the queue.
+                            while batch.senders.len() < ACCOUNT_PROOF_BATCH_LIMIT {
+                                match work_rx.try_recv() {
+                                    Ok(AccountWorkerJob::AccountMultiproof { input: next_input }) => {
+                                        match batch.try_merge(*next_input) {
+                                            Ok(()) => {}
+                                            Err(incompatible) => {
+                                                trace!(
+                                                    target: "trie::proof_task",
+                                                    worker_id,
+                                                    "Account multiproof batch split due to incompatible caches"
+                                                );
+                                                pending_incompatible = Some(Box::new(incompatible));
+                                                break;
+                                            }
+                                        }
+                                    }
+                                    Ok(AccountWorkerJob::BlindedAccountNode {
+                                        path,
+                                        result_sender,
+                                    }) => {
+                                        // Defer blinded node jobs to process after batched proofs.
+                                        deferred_blinded_nodes.push((path, result_sender));
+                                        // Stop batching if too many blinded nodes are deferred to
+                                        // prevent starvation.
+                                        if deferred_blinded_nodes.len() >= MAX_DEFERRED_BLINDED_NODES {
+                                            break;
+                                        }
+                                    }
+                                    Err(_) => break,
+                                }
+                            }
-                                Ok(AccountWorkerJob::BlindedAccountNode { path, result_sender }) => {
-                                    // Defer blinded node jobs to process after batched proofs.
-                                    deferred_blinded_nodes.push((path, result_sender));
-                                }
-                                Err(_) => break,
-                            }
-                        }
 
+                            let batch_size = batch.senders.len();
+                            batch_metrics.record_batch_size(batch_size);
+                            let (merged_input, senders) = batch.into_input();
+
+                            trace!(
+                                target: "trie::proof_task",
+                                worker_id,
+                                batch_size,
+                                targets_len = merged_input.targets.len(),
+                                "Processing batched account multiproof"
+                            );
+
+                            Self::process_batched_account_multiproof(
+                                worker_id,
+                                &proof_tx,
+                                &storage_work_tx,
+                                merged_input,
+                                senders,
+                                &mut account_proofs_processed,
+                                &mut cursor_metrics_cache,
+                            );
+
+                            // If we encountered an incompatible job, process it as its own batch
+                            // before handling any deferred blinded node requests.
+                            if let Some(incompatible_job) = pending_incompatible {
+                                next_account_job = Some(incompatible_job);
+                            }
+                        }
-                        let batch_size = batch.senders.len();
-                        batch_metrics.record_batch_size(batch_size);
-                        let (merged_input, senders) = batch.into_input();
-
-                        trace!(
-                            target: "trie::proof_task",
-                            worker_id,
-                            batch_size,
-                            targets_len = merged_input.targets.len(),
-                            "Processing batched account multiproof"
-                        );
-
-                        Self::process_batched_account_multiproof(
-                            worker_id,
-                            &proof_tx,
-                            storage_work_tx.clone(),
-                            merged_input,
-                            senders,
-                            &mut account_proofs_processed,
-                            &mut cursor_metrics_cache,
-                        );
 
                         // Process any deferred blinded node jobs.
                         for (path, result_sender) in std::mem::take(&mut deferred_blinded_nodes) {
                             Self::process_blinded_node(
@@ -1520,7 +1560,7 @@ where
     fn process_batched_account_multiproof<Provider>(
         worker_id: usize,
         proof_tx: &ProofTaskTx<Provider>,
-        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+        storage_work_tx: &CrossbeamSender<StorageWorkerJob>,
         input: AccountMultiproofInput,
         senders: Vec<ProofResultContext>,
         account_proofs_processed: &mut u64,
@@ -1563,7 +1603,7 @@ where
         tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);
 
         let storage_proof_receivers = match dispatch_storage_proofs(
-            &storage_work_tx,
+            storage_work_tx,
             &targets,
             &mut storage_prefix_sets,
             collect_branch_node_masks,
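
The storage_work_tx change in the two hunks above works because
crossbeam_channel::Sender::send takes &self, so a shared borrow is enough to
dispatch jobs and the per-batch clone (an atomic refcount bump) can be
dropped. A minimal sketch, assuming the crossbeam-channel crate:

// Borrowing the sender is sufficient: send() takes &self.
fn dispatch(storage_work_tx: &crossbeam_channel::Sender<u64>, jobs: &[u64]) {
    for &job in jobs {
        let _ = storage_work_tx.send(job);
    }
}

fn main() {
    let (tx, rx) = crossbeam_channel::unbounded();
    dispatch(&tx, &[1, 2, 3]);
    drop(tx); // disconnect so the iterator below terminates
    assert_eq!(rx.iter().count(), 3);
}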