diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index bbc544e49f..844f4558d1 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -107,7 +107,7 @@ where let decoded_storage_multiproof = match storage_receivers.remove(&hashed_address) { Some(receiver) => { // Try non-blocking receive first to check if proof is already available - + match receiver.try_recv() { Ok(Ok(proof)) => { // Immediate: proof was already ready @@ -118,9 +118,9 @@ where Err(crossbeam_channel::TryRecvError::Empty) => { // Blocked: need to wait for proof tracker.inc_storage_proof_blocked(); - match receiver.recv() { - Ok(Ok(proof)) => proof, - Ok(Err(e)) => return Err(e), + match receiver.recv() { + Ok(Ok(proof)) => proof, + Ok(Err(e)) => return Err(e), Err(_) => { return Err(storage_channel_closed_error(&hashed_address)) } @@ -570,4 +570,204 @@ mod tests { drop(proof_task_handle); rt.block_on(join_handle).unwrap().expect("The proof task should not return an error"); } + + /// Test parallel proof with mixed storage targets (some accounts have storage, some don't) + #[test] + fn parallel_proof_handles_mixed_storage_targets() { + let factory = create_test_provider_factory(); + let consistent_view = ConsistentDbView::new(factory.clone(), None); + + let mut rng = rand::rng(); + let state = (0..20) + .map(|i| { + let address = Address::random(); + let account = + Account { balance: U256::from(rng.random::()), ..Default::default() }; + + // Every other account has storage + let mut storage = HashMap::::default(); + if i % 2 == 0 { + for _ in 0..10 { + storage.insert( + B256::from(U256::from(rng.random::())), + U256::from(rng.random::()), + ); + } + } + (address, (account, storage)) + }) + .collect::>(); + + { + let provider_rw = factory.provider_rw().unwrap(); + provider_rw + .insert_account_for_hashing( + state.iter().map(|(address, (account, _))| (*address, Some(*account))), + ) + .unwrap(); + provider_rw + .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| { + ( + *address, + storage + .iter() + .map(|(slot, value)| StorageEntry { key: *slot, value: *value }), + ) + })) + .unwrap(); + provider_rw.commit().unwrap(); + } + + // Create targets with mixed storage (some empty, some with slots) + let mut targets = MultiProofTargets::default(); + for (address, (_, storage)) in &state { + let hashed_address = keccak256(*address); + let target_slots = if storage.is_empty() { + B256Set::default() // Empty storage + } else { + storage.iter().take(3).map(|(slot, _)| *slot).collect() + }; + targets.insert(hashed_address, target_slots); + } + + let provider_rw = factory.provider_rw().unwrap(); + let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref()); + let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref()); + + let rt = Runtime::new().unwrap(); + let task_ctx = + ProofTaskCtx::new(Default::default(), Default::default(), Default::default()); + let proof_task = ProofTaskManager::new( + rt.handle().clone(), + consistent_view.clone(), + task_ctx, + 2, // storage_worker_count + 1, // account_worker_count + 1, // max_concurrency + ) + .unwrap(); + let proof_task_handle = proof_task.handle(); + let join_handle = rt.spawn_blocking(move || proof_task.run()); + + let parallel_result = ParallelProof::new( + consistent_view, + Default::default(), + Default::default(), + Default::default(), + proof_task_handle.clone(), + ) + .decoded_multiproof(targets.clone()) + .unwrap(); + + let sequential_result_raw = + Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap(); + let sequential_result_decoded: DecodedMultiProof = + sequential_result_raw.try_into().unwrap(); + + assert_eq!(parallel_result, sequential_result_decoded); + + drop(proof_task_handle); + rt.block_on(join_handle).unwrap().expect("proof task should succeed"); + } + + /// Test parallel proof with varying storage sizes (validates ordering independence) + #[test] + fn parallel_proof_ordering_independence() { + let factory = create_test_provider_factory(); + let consistent_view = ConsistentDbView::new(factory.clone(), None); + + let mut rng = rand::rng(); + // Create state with varying storage sizes to ensure random completion order + let state = (0..15) + .map(|_| { + let address = Address::random(); + let account = + Account { balance: U256::from(rng.random::()), ..Default::default() }; + + // Random storage sizes (1-50 slots) to create different proof computation times + let storage_size = rng.random_range(1..50); + let storage: HashMap = (0..storage_size) + .map(|_| { + ( + B256::from(U256::from(rng.random::())), + U256::from(rng.random::()), + ) + }) + .collect(); + + (address, (account, storage)) + }) + .collect::>(); + + { + let provider_rw = factory.provider_rw().unwrap(); + provider_rw + .insert_account_for_hashing( + state.iter().map(|(address, (account, _))| (*address, Some(*account))), + ) + .unwrap(); + provider_rw + .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| { + ( + *address, + storage + .iter() + .map(|(slot, value)| StorageEntry { key: *slot, value: *value }), + ) + })) + .unwrap(); + provider_rw.commit().unwrap(); + } + + let mut targets = MultiProofTargets::default(); + for (address, (_, storage)) in &state { + let hashed_address = keccak256(*address); + let target_slots: B256Set = storage.keys().take(5).copied().collect(); + if !target_slots.is_empty() { + targets.insert(hashed_address, target_slots); + } + } + + let provider_rw = factory.provider_rw().unwrap(); + let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref()); + let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref()); + + let rt = Runtime::new().unwrap(); + let task_ctx = + ProofTaskCtx::new(Default::default(), Default::default(), Default::default()); + + // Use 3 workers to increase chance of out-of-order completion + let proof_task = ProofTaskManager::new( + rt.handle().clone(), + consistent_view.clone(), + task_ctx, + 3, // storage_worker_count + 1, // account_worker_count + 1, // max_concurrency + ) + .unwrap(); + let proof_task_handle = proof_task.handle(); + let join_handle = rt.spawn_blocking(move || proof_task.run()); + + let parallel_result = ParallelProof::new( + consistent_view, + Default::default(), + Default::default(), + Default::default(), + proof_task_handle.clone(), + ) + .decoded_multiproof(targets.clone()) + .unwrap(); + + let sequential_result_raw = + Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap(); + let sequential_result_decoded: DecodedMultiProof = + sequential_result_raw.try_into().unwrap(); + + // Results should be identical regardless of completion order + assert_eq!(parallel_result, sequential_result_decoded); + + drop(proof_task_handle); + rt.block_on(join_handle).unwrap().expect("proof task should succeed"); + } } diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs index 66f9b94405..4def415159 100644 --- a/crates/trie/parallel/src/proof_task.rs +++ b/crates/trie/parallel/src/proof_task.rs @@ -1024,11 +1024,20 @@ mod tests { let rt = Runtime::new().unwrap(); let task_ctx = default_task_ctx(); - let num_workers = 2usize; - let manager = - ProofTaskManager::new(rt.handle().clone(), view, task_ctx, num_workers, 0, 4).unwrap(); + let storage_workers = 2usize; + let account_workers = 1usize; + let manager = ProofTaskManager::new( + rt.handle().clone(), + view, + task_ctx, + storage_workers, + account_workers, + 4, + ) + .unwrap(); - assert_eq!(calls.load(Ordering::SeqCst), num_workers); + let expected_total_workers = storage_workers + account_workers; + assert_eq!(calls.load(Ordering::SeqCst), expected_total_workers); let handle = manager.handle(); let join_handle = rt.spawn_blocking(move || manager.run()); @@ -1056,4 +1065,101 @@ mod tests { drop(handle); rt.block_on(join_handle).unwrap().unwrap(); } + + /// Tests that storage workers reuse the same database transaction across multiple proofs, + /// validating the core Phase 1a optimization that eliminates per-proof transaction overhead. + #[test] + fn storage_worker_reuses_transaction_across_multiple_proofs() { + let inner_factory = create_test_provider_factory(); + let calls = Arc::new(AtomicUsize::new(0)); + let counting_factory = CountingFactory::new(inner_factory, Arc::clone(&calls)); + let view = ConsistentDbView::new(counting_factory, None); + + let rt = Runtime::new().unwrap(); + let task_ctx = default_task_ctx(); + let storage_workers = 1usize; + let account_workers = 0usize; + let manager = ProofTaskManager::new( + rt.handle().clone(), + view, + task_ctx, + storage_workers, + account_workers, + 4, + ) + .unwrap(); + + // Expect 1 transaction: 1 for storage worker (0 account workers = no account workers) + let initial_calls = calls.load(Ordering::SeqCst); + assert_eq!(initial_calls, 1); + + let handle = manager.handle(); + let join_handle = rt.spawn_blocking(move || manager.run()); + + // Queue 10 storage proofs - all should use same transaction + let prefix_set = PrefixSetMut::default().freeze(); + let mut receivers = Vec::new(); + for _ in 0..10 { + let input = StorageProofInput::new( + B256::ZERO, + prefix_set.clone(), + Arc::new(B256Set::default()), + false, + None, + ); + let (sender, receiver) = crossbeam_channel::unbounded(); + handle.queue_task(ProofTaskKind::StorageProof(input, sender)).unwrap(); + receivers.push(receiver); + } + + for receiver in receivers { + let _ = receiver.recv().unwrap(); + } + + // Transaction count should still be 1 (worker reuses its transaction) + assert_eq!(calls.load(Ordering::SeqCst), initial_calls); + + drop(handle); + rt.block_on(join_handle).unwrap().unwrap(); + } + + /// Tests that the dual manager architecture handles heavy concurrent load without deadlocks, + /// validating unbounded channel backpressure behavior under stress. + #[test] + fn handles_backpressure_with_many_concurrent_storage_proofs() { + let inner_factory = create_test_provider_factory(); + let view = ConsistentDbView::new(inner_factory, None); + + let rt = Runtime::new().unwrap(); + let task_ctx = default_task_ctx(); + // 2 storage workers + 0 account workers = 2 total workers + let manager = ProofTaskManager::new(rt.handle().clone(), view, task_ctx, 2, 0, 4).unwrap(); + + let handle = manager.handle(); + let join_handle = rt.spawn_blocking(move || manager.run()); + + // Queue 50 storage proofs concurrently + let prefix_set = PrefixSetMut::default().freeze(); + let mut receivers = Vec::new(); + for _ in 0..50 { + let input = StorageProofInput::new( + B256::ZERO, + prefix_set.clone(), + Arc::new(B256Set::default()), + false, + None, + ); + let (sender, receiver) = crossbeam_channel::unbounded(); + handle.queue_task(ProofTaskKind::StorageProof(input, sender)).unwrap(); + receivers.push(receiver); + } + + // All tasks complete without deadlock + for receiver in receivers { + let _ = receiver.recv().unwrap(); + } + + drop(handle); + rt.block_on(join_handle).unwrap().unwrap(); + } }