diff --git a/crates/engine/tree/src/engine.rs b/crates/engine/tree/src/engine.rs index c2d0c17546..5bbc37a486 100644 --- a/crates/engine/tree/src/engine.rs +++ b/crates/engine/tree/src/engine.rs @@ -6,6 +6,7 @@ use crate::{ download::{BlockDownloader, DownloadAction, DownloadOutcome}, }; use alloy_primitives::B256; +use crossbeam_channel::Sender; use futures::{Stream, StreamExt}; use reth_chain_state::ExecutedBlock; use reth_engine_primitives::{BeaconEngineMessage, ConsensusEngineEvent}; @@ -15,7 +16,6 @@ use reth_primitives_traits::{Block, NodePrimitives, SealedBlock}; use std::{ collections::HashSet, fmt::Display, - sync::mpsc::Sender, task::{ready, Context, Poll}, }; use tokio::sync::mpsc::UnboundedReceiver; diff --git a/crates/engine/tree/src/persistence.rs b/crates/engine/tree/src/persistence.rs index a57ea8bed9..5dbaefcd29 100644 --- a/crates/engine/tree/src/persistence.rs +++ b/crates/engine/tree/src/persistence.rs @@ -1,5 +1,6 @@ use crate::metrics::PersistenceMetrics; use alloy_eips::BlockNumHash; +use crossbeam_channel::Sender as CrossbeamSender; use reth_chain_state::ExecutedBlock; use reth_errors::ProviderError; use reth_ethereum_primitives::EthPrimitives; @@ -15,7 +16,6 @@ use std::{ time::Instant, }; use thiserror::Error; -use tokio::sync::oneshot; use tracing::{debug, error}; /// Writes parts of reth's in memory tree state to the database and static files. @@ -183,13 +183,13 @@ pub enum PersistenceAction { /// /// First, header, transaction, and receipt-related data should be written to static files. /// Then the execution history-related data will be written to the database. - SaveBlocks(Vec>, oneshot::Sender>), + SaveBlocks(Vec>, CrossbeamSender>), /// Removes block data above the given block number from the database. /// /// This will first update checkpoints from the database, then remove actual block data from /// static files. - RemoveBlocksAbove(u64, oneshot::Sender>), + RemoveBlocksAbove(u64, CrossbeamSender>), /// Update the persisted finalized block on disk SaveFinalizedBlock(u64), @@ -261,7 +261,7 @@ impl PersistenceHandle { pub fn save_blocks( &self, blocks: Vec>, - tx: oneshot::Sender>, + tx: CrossbeamSender>, ) -> Result<(), SendError>> { self.send_action(PersistenceAction::SaveBlocks(blocks, tx)) } @@ -290,7 +290,7 @@ impl PersistenceHandle { pub fn remove_blocks_above( &self, block_num: u64, - tx: oneshot::Sender>, + tx: CrossbeamSender>, ) -> Result<(), SendError>> { self.send_action(PersistenceAction::RemoveBlocksAbove(block_num, tx)) } @@ -319,22 +319,22 @@ mod tests { PersistenceHandle::::spawn_service(provider, pruner, sync_metrics_tx) } - #[tokio::test] - async fn test_save_blocks_empty() { + #[test] + fn test_save_blocks_empty() { reth_tracing::init_test_tracing(); let persistence_handle = default_persistence_handle(); let blocks = vec![]; - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); persistence_handle.save_blocks(blocks, tx).unwrap(); - let hash = rx.await.unwrap(); + let hash = rx.recv().unwrap(); assert_eq!(hash, None); } - #[tokio::test] - async fn test_save_blocks_single_block() { + #[test] + fn test_save_blocks_single_block() { reth_tracing::init_test_tracing(); let persistence_handle = default_persistence_handle(); let block_number = 0; @@ -344,37 +344,35 @@ mod tests { let block_hash = executed.recovered_block().hash(); let blocks = vec![executed]; - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); persistence_handle.save_blocks(blocks, tx).unwrap(); - let BlockNumHash { hash: actual_hash, number: _ } = - tokio::time::timeout(std::time::Duration::from_secs(10), rx) - .await - .expect("test timed out") - .expect("channel closed unexpectedly") - .expect("no hash returned"); + let BlockNumHash { hash: actual_hash, number: _ } = rx + .recv_timeout(std::time::Duration::from_secs(10)) + .expect("test timed out") + .expect("no hash returned"); assert_eq!(block_hash, actual_hash); } - #[tokio::test] - async fn test_save_blocks_multiple_blocks() { + #[test] + fn test_save_blocks_multiple_blocks() { reth_tracing::init_test_tracing(); let persistence_handle = default_persistence_handle(); let mut test_block_builder = TestBlockBuilder::eth(); let blocks = test_block_builder.get_executed_blocks(0..5).collect::>(); let last_hash = blocks.last().unwrap().recovered_block().hash(); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); persistence_handle.save_blocks(blocks, tx).unwrap(); - let BlockNumHash { hash: actual_hash, number: _ } = rx.await.unwrap().unwrap(); + let BlockNumHash { hash: actual_hash, number: _ } = rx.recv().unwrap().unwrap(); assert_eq!(last_hash, actual_hash); } - #[tokio::test] - async fn test_save_blocks_multiple_calls() { + #[test] + fn test_save_blocks_multiple_calls() { reth_tracing::init_test_tracing(); let persistence_handle = default_persistence_handle(); @@ -383,11 +381,11 @@ mod tests { for range in ranges { let blocks = test_block_builder.get_executed_blocks(range).collect::>(); let last_hash = blocks.last().unwrap().recovered_block().hash(); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); persistence_handle.save_blocks(blocks, tx).unwrap(); - let BlockNumHash { hash: actual_hash, number: _ } = rx.await.unwrap().unwrap(); + let BlockNumHash { hash: actual_hash, number: _ } = rx.recv().unwrap().unwrap(); assert_eq!(last_hash, actual_hash); } } diff --git a/crates/engine/tree/src/tree/error.rs b/crates/engine/tree/src/tree/error.rs index 8589bc59d3..8899542af6 100644 --- a/crates/engine/tree/src/tree/error.rs +++ b/crates/engine/tree/src/tree/error.rs @@ -6,15 +6,13 @@ use reth_errors::{BlockExecutionError, BlockValidationError, ProviderError}; use reth_evm::execute::InternalBlockExecutionError; use reth_payload_primitives::NewPayloadError; use reth_primitives_traits::{Block, BlockBody, SealedBlock}; -use tokio::sync::oneshot::error::TryRecvError; -/// This is an error that can come from advancing persistence. Either this can be a -/// [`TryRecvError`], or this can be a [`ProviderError`] +/// This is an error that can come from advancing persistence. #[derive(Debug, thiserror::Error)] pub enum AdvancePersistenceError { - /// An error that can be from failing to receive a value from persistence - #[error(transparent)] - RecvError(#[from] TryRecvError), + /// The persistence channel was closed unexpectedly + #[error("persistence channel closed")] + ChannelClosed, /// A provider error #[error(transparent)] Provider(#[from] ProviderError), diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index cb2a715130..5ba703ff8d 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -37,18 +37,12 @@ use reth_revm::database::StateProviderDatabase; use reth_stages_api::ControlFlow; use revm::state::EvmState; use state::TreeState; -use std::{ - fmt::Debug, - ops, - sync::{ - mpsc::{Receiver, RecvError, RecvTimeoutError, Sender}, - Arc, - }, - time::Instant, -}; +use std::{fmt::Debug, ops, sync::Arc, time::Instant}; + +use crossbeam_channel::{Receiver, Sender}; use tokio::sync::{ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - oneshot::{self, error::TryRecvError}, + oneshot, }; use tracing::*; @@ -338,7 +332,7 @@ where engine_kind: EngineApiKind, evm_config: C, ) -> Self { - let (incoming_tx, incoming) = std::sync::mpsc::channel(); + let (incoming_tx, incoming) = crossbeam_channel::unbounded(); Self { provider, @@ -423,8 +417,8 @@ where /// This will block the current thread and process incoming messages. pub fn run(mut self) { loop { - match self.try_recv_engine_message() { - Ok(Some(msg)) => { + match self.wait_for_event() { + LoopEvent::EngineMessage(msg) => { debug!(target: "engine::tree", %msg, "received new engine message"); match self.on_engine_message(msg) { Ok(ops::ControlFlow::Break(())) => return, @@ -435,15 +429,22 @@ where } } } - Ok(None) => { - debug!(target: "engine::tree", "received no engine message for some time, while waiting for persistence task to complete"); + LoopEvent::PersistenceComplete { result, start_time } => { + if let Err(err) = self.on_persistence_complete(result, start_time) { + error!(target: "engine::tree", %err, "Persistence complete handling failed"); + return + } } - Err(_err) => { - error!(target: "engine::tree", "Engine channel disconnected"); + LoopEvent::Disconnected => { + error!(target: "engine::tree", "Channel disconnected"); return } } + // Always check if we need to trigger new persistence after any event: + // - After engine messages: new blocks may have been inserted that exceed the + // persistence threshold + // - After persistence completion: we can now persist more blocks if needed if let Err(err) = self.advance_persistence() { error!(target: "engine::tree", %err, "Advancing persistence failed"); return @@ -451,6 +452,47 @@ where } } + /// Blocks until the next event is ready: either an incoming engine message or a persistence + /// completion (if one is in progress). + /// + /// Uses biased selection to prioritize persistence completion to update in-memory state and + /// unblock further writes. + fn wait_for_event(&mut self) -> LoopEvent { + // Take ownership of persistence rx if present + let maybe_persistence = self.persistence_state.rx.take(); + + if let Some((persistence_rx, start_time, action)) = maybe_persistence { + // Biased select prioritizes persistence completion to update in memory state and + // unblock further writes + crossbeam_channel::select_biased! { + recv(persistence_rx) -> result => { + // Don't put it back - consumed (oneshot-like behavior) + match result { + Ok(value) => LoopEvent::PersistenceComplete { + result: value, + start_time, + }, + Err(_) => LoopEvent::Disconnected, + } + }, + recv(self.incoming) -> msg => { + // Put the persistence rx back - we didn't consume it + self.persistence_state.rx = Some((persistence_rx, start_time, action)); + match msg { + Ok(m) => LoopEvent::EngineMessage(m), + Err(_) => LoopEvent::Disconnected, + } + }, + } + } else { + // No persistence in progress - just wait on incoming + match self.incoming.recv() { + Ok(m) => LoopEvent::EngineMessage(m), + Err(_) => LoopEvent::Disconnected, + } + } + } + /// Invoked when previously requested blocks were downloaded. /// /// If the block count exceeds the configured batch size we're allowed to execute at once, this @@ -1191,39 +1233,13 @@ where .with_event(TreeEvent::Download(DownloadRequest::single_block(target)))) } - /// Attempts to receive the next engine request. - /// - /// If there's currently no persistence action in progress, this will block until a new request - /// is received. If there's a persistence action in progress, this will try to receive the - /// next request with a timeout to not block indefinitely and return `Ok(None)` if no request is - /// received in time. - /// - /// Returns an error if the engine channel is disconnected. - #[expect(clippy::type_complexity)] - fn try_recv_engine_message( - &self, - ) -> Result, N::Block>>, RecvError> { - if self.persistence_state.in_progress() { - // try to receive the next request with a timeout to not block indefinitely - match self.incoming.recv_timeout(std::time::Duration::from_millis(500)) { - Ok(msg) => Ok(Some(msg)), - Err(err) => match err { - RecvTimeoutError::Timeout => Ok(None), - RecvTimeoutError::Disconnected => Err(RecvError), - }, - } - } else { - self.incoming.recv().map(Some) - } - } - /// Helper method to remove blocks and set the persistence state. This ensures we keep track of /// the current persistence action while we're removing blocks. fn remove_blocks(&mut self, new_tip_num: u64) { debug!(target: "engine::tree", ?new_tip_num, last_persisted_block_number=?self.persistence_state.last_persisted_block.number, "Removing blocks using persistence task"); if new_tip_num < self.persistence_state.last_persisted_block.number { debug!(target: "engine::tree", ?new_tip_num, "Starting remove blocks job"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); let _ = self.persistence.remove_blocks_above(new_tip_num, tx); self.persistence_state.start_remove(new_tip_num, rx); } @@ -1245,35 +1261,17 @@ where .expect("Checked non-empty persisting blocks"); debug!(target: "engine::tree", count=blocks_to_persist.len(), blocks = ?blocks_to_persist.iter().map(|block| block.recovered_block().num_hash()).collect::>(), "Persisting blocks"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = crossbeam_channel::bounded(1); let _ = self.persistence.save_blocks(blocks_to_persist, tx); self.persistence_state.start_save(highest_num_hash, rx); } - /// Attempts to advance the persistence state. + /// Triggers new persistence actions if no persistence task is currently in progress. /// - /// If we're currently awaiting a response this will try to receive the response (non-blocking) - /// or send a new persistence action if necessary. + /// This checks if we need to remove blocks (disk reorg) or save new blocks to disk. + /// Persistence completion is handled separately via the `wait_for_event` method. fn advance_persistence(&mut self) -> Result<(), AdvancePersistenceError> { - if self.persistence_state.in_progress() { - let (mut rx, start_time, current_action) = self - .persistence_state - .rx - .take() - .expect("if a persistence task is in progress Receiver must be Some"); - // Check if persistence has complete - match rx.try_recv() { - Ok(last_persisted_hash_num) => { - self.on_persistence_complete(last_persisted_hash_num, start_time)?; - } - Err(TryRecvError::Closed) => return Err(TryRecvError::Closed.into()), - Err(TryRecvError::Empty) => { - self.persistence_state.rx = Some((rx, start_time, current_action)) - } - } - } - if !self.persistence_state.in_progress() { if let Some(new_tip_num) = self.find_disk_reorg()? { self.remove_blocks(new_tip_num) @@ -1306,7 +1304,7 @@ where loop { // Wait for any in-progress persistence to complete (blocking) if let Some((rx, start_time, _action)) = self.persistence_state.rx.take() { - let result = rx.blocking_recv().map_err(|_| TryRecvError::Closed)?; + let result = rx.recv().map_err(|_| AdvancePersistenceError::ChannelClosed)?; self.on_persistence_complete(result, start_time)?; } @@ -1322,6 +1320,31 @@ where } } + /// Tries to poll for a completed persistence task (non-blocking). + /// + /// Returns `true` if a persistence task was completed, `false` otherwise. + #[cfg(test)] + pub fn try_poll_persistence(&mut self) -> Result { + let Some((rx, start_time, action)) = self.persistence_state.rx.take() else { + return Ok(false); + }; + + match rx.try_recv() { + Ok(result) => { + self.on_persistence_complete(result, start_time)?; + Ok(true) + } + Err(crossbeam_channel::TryRecvError::Empty) => { + // Not ready yet, put it back + self.persistence_state.rx = Some((rx, start_time, action)); + Ok(false) + } + Err(crossbeam_channel::TryRecvError::Disconnected) => { + Err(AdvancePersistenceError::ChannelClosed) + } + } + } + /// Handles a completed persistence task. fn on_persistence_complete( &mut self, @@ -2848,6 +2871,26 @@ where } } +/// Events received in the main engine loop. +#[derive(Debug)] +enum LoopEvent +where + N: NodePrimitives, + T: PayloadTypes, +{ + /// An engine API message was received. + EngineMessage(FromEngine, N::Block>), + /// A persistence task completed. + PersistenceComplete { + /// The result of the persistence operation. + result: Option, + /// When the persistence operation started. + start_time: Instant, + }, + /// A channel was disconnected. + Disconnected, +} + /// Block inclusion can be valid, accepted, or invalid. Invalid blocks are returned as an error /// variant. /// diff --git a/crates/engine/tree/src/tree/persistence_state.rs b/crates/engine/tree/src/tree/persistence_state.rs index 82a8078447..847904c0dd 100644 --- a/crates/engine/tree/src/tree/persistence_state.rs +++ b/crates/engine/tree/src/tree/persistence_state.rs @@ -22,12 +22,12 @@ use alloy_eips::BlockNumHash; use alloy_primitives::B256; +use crossbeam_channel::Receiver as CrossbeamReceiver; use std::time::Instant; -use tokio::sync::oneshot; use tracing::trace; /// The state of the persistence task. -#[derive(Default, Debug)] +#[derive(Debug)] pub struct PersistenceState { /// Hash and number of the last block persisted. /// @@ -36,7 +36,7 @@ pub struct PersistenceState { /// Receiver end of channel where the result of the persistence task will be /// sent when done. A None value means there's no persistence task in progress. pub(crate) rx: - Option<(oneshot::Receiver>, Instant, CurrentPersistenceAction)>, + Option<(CrossbeamReceiver>, Instant, CurrentPersistenceAction)>, } impl PersistenceState { @@ -50,7 +50,7 @@ impl PersistenceState { pub(crate) fn start_remove( &mut self, new_tip_num: u64, - rx: oneshot::Receiver>, + rx: CrossbeamReceiver>, ) { self.rx = Some((rx, Instant::now(), CurrentPersistenceAction::RemovingBlocks { new_tip_num })); @@ -60,7 +60,7 @@ impl PersistenceState { pub(crate) fn start_save( &mut self, highest: BlockNumHash, - rx: oneshot::Receiver>, + rx: CrossbeamReceiver>, ) { self.rx = Some((rx, Instant::now(), CurrentPersistenceAction::SavingBlocks { highest })); } diff --git a/crates/engine/tree/src/tree/tests.rs b/crates/engine/tree/src/tree/tests.rs index b1725bc36c..c04d50414e 100644 --- a/crates/engine/tree/src/tree/tests.rs +++ b/crates/engine/tree/src/tree/tests.rs @@ -31,7 +31,7 @@ use std::{ collections::BTreeMap, str::FromStr, sync::{ - mpsc::{channel, Receiver, Sender}, + mpsc::{Receiver, Sender}, Arc, }, }; @@ -97,6 +97,7 @@ struct TestChannel { impl TestChannel { /// Creates a new test channel fn spawn_channel() -> (Sender, Receiver, TestChannelHandle) { + use std::sync::mpsc::channel; let (original_tx, original_rx) = channel(); let (wrapped_tx, wrapped_rx) = channel(); let (release_tx, release_rx) = channel(); @@ -143,7 +144,9 @@ struct TestHarness { BasicEngineValidator, MockEvmConfig, >, - to_tree_tx: Sender, Block>>, + to_tree_tx: crossbeam_channel::Sender< + FromEngine, Block>, + >, from_tree_rx: UnboundedReceiver, blocks: Vec, action_rx: Receiver, @@ -153,6 +156,7 @@ struct TestHarness { impl TestHarness { fn new(chain_spec: Arc) -> Self { + use std::sync::mpsc::channel; let (action_tx, action_rx) = channel(); Self::with_persistence_channel(chain_spec, action_tx, action_rx) } @@ -205,7 +209,7 @@ impl TestHarness { engine_api_tree_state, canonical_in_memory_state, persistence_handle, - PersistenceState::default(), + PersistenceState { last_persisted_block: BlockNumHash::default(), rx: None }, payload_builder, // always assume enough parallelism for tests TreeConfig::default().with_legacy_state_root(false).with_has_enough_parallelism(true), @@ -399,10 +403,8 @@ impl ValidatorTestHarness { /// Configure `PersistenceState` for specific persistence scenarios fn start_persistence_operation(&mut self, action: CurrentPersistenceAction) { - use tokio::sync::oneshot; - // Create a dummy receiver for testing - it will never receive a value - let (_tx, rx) = oneshot::channel(); + let (_tx, rx) = crossbeam_channel::bounded(1); match action { CurrentPersistenceAction::SavingBlocks { highest } => { @@ -498,11 +500,17 @@ fn test_tree_persist_block_batch() { test_harness.to_tree_tx.send(FromEngine::DownloadedBlocks(blocks)).unwrap(); // process the message - let msg = test_harness.tree.try_recv_engine_message().unwrap().unwrap(); + let msg = match test_harness.tree.wait_for_event() { + super::LoopEvent::EngineMessage(msg) => msg, + other => panic!("unexpected event: {other:?}"), + }; let _ = test_harness.tree.on_engine_message(msg).unwrap(); // we now should receive the other batch - let msg = test_harness.tree.try_recv_engine_message().unwrap().unwrap(); + let msg = match test_harness.tree.wait_for_event() { + super::LoopEvent::EngineMessage(msg) => msg, + other => panic!("unexpected event: {other:?}"), + }; match msg { FromEngine::DownloadedBlocks(blocks) => { assert_eq!(blocks.len(), tree_config.max_execute_block_batch_size()); @@ -753,8 +761,8 @@ async fn test_tree_state_on_new_head_reorg() { }) ); - // after advancing persistence, we should be at `None` for the next action - test_harness.tree.advance_persistence().unwrap(); + // after polling persistence completion, we should be at `None` for the next action + test_harness.tree.try_poll_persistence().unwrap(); let current_action = test_harness.tree.persistence_state.current_action().cloned(); assert_eq!(current_action, None);