diff --git a/crates/stages/src/stages/senders.rs b/crates/stages/src/stages/senders.rs index 2e29a61865..a81e96a317 100644 --- a/crates/stages/src/stages/senders.rs +++ b/crates/stages/src/stages/senders.rs @@ -23,6 +23,9 @@ const SENDERS: StageId = StageId("Senders"); pub struct SendersStage { /// The size of the chunk for parallel sender recovery pub batch_size: usize, + /// The size of inserted items after which the control + /// flow will be returned to the pipeline for commit + pub commit_threshold: u64, } #[derive(Error, Debug)] @@ -60,7 +63,8 @@ impl Stage for SendersStage { let start_tx_index = db.get_first_tx_id(stage_progress + 1)?; // Look up the end index for transaction range (inclusive) - let max_block_num = input.previous_stage_progress(); + let previous_stage_progress = input.previous_stage_progress(); + let max_block_num = previous_stage_progress.min(stage_progress + self.commit_threshold); let end_tx_index = match db.get_latest_tx_id(max_block_num) { Ok(id) => id, // No transactions in the database @@ -101,7 +105,8 @@ impl Stage for SendersStage { recovered.into_iter().try_for_each(|(id, sender)| senders_cursor.append(id, sender))?; } - Ok(ExecOutput { stage_progress: max_block_num, done: true, reached_tip: true }) + let done = max_block_num >= previous_stage_progress; + Ok(ExecOutput { stage_progress: max_block_num, done, reached_tip: done }) } /// Unwind the stage. @@ -119,20 +124,71 @@ impl Stage for SendersStage { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use reth_interfaces::test_utils::generators::random_block_range; use reth_primitives::{BlockLocked, BlockNumber, H256}; use super::*; use crate::test_utils::{ stage_test_suite, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB, - UnwindStageTestRunner, + UnwindStageTestRunner, PREV_STAGE_ID, }; stage_test_suite!(SendersTestRunner); - #[derive(Default)] + #[tokio::test] + async fn execute_intermediate_commit() { + let threshold = 50; + let mut runner = SendersTestRunner::default(); + runner.set_threshold(threshold); + let (stage_progress, previous_stage) = (1000, 1100); // input exceeds threshold + let first_input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + + // Seed only once with full input range + runner.seed_execution(first_input).expect("failed to seed execution"); + + // Execute first time + let result = runner.execute(first_input).await.unwrap(); + let expected_progress = stage_progress + threshold; + assert_matches!( + result, + Ok(ExecOutput { done: false, reached_tip: false, stage_progress }) + if stage_progress == expected_progress + ); + + // Execute second time + let second_input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(expected_progress), + }; + let result = runner.execute(second_input).await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) + if stage_progress == previous_stage + ); + + assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed"); + } + struct SendersTestRunner { db: TestStageDB, + threshold: u64, + } + + impl Default for SendersTestRunner { + fn default() -> Self { + Self { threshold: 1000, db: TestStageDB::default() } + } + } + + impl SendersTestRunner { + fn set_threshold(&mut self, threshold: u64) { + self.threshold = threshold; + } } impl StageTestRunner for SendersTestRunner { @@ -143,7 +199,7 @@ mod tests { } fn stage(&self) -> Self::S { - SendersStage { batch_size: 100 } + SendersStage { batch_size: 100, commit_threshold: self.threshold } } }