diff --git a/crates/stages/src/stages/bodies.rs b/crates/stages/src/stages/bodies.rs index b604e9df4c..e500417343 100644 --- a/crates/stages/src/stages/bodies.rs +++ b/crates/stages/src/stages/bodies.rs @@ -417,9 +417,9 @@ mod tests { // Unwind all of it let unwind_to = 1; let input = UnwindInput { bad_block: None, stage_progress, unwind_to }; - let rx = runner.unwind(input); + let res = runner.unwind(input).await; assert_matches!( - rx.await.unwrap(), + res, Ok(UnwindOutput { stage_progress }) if stage_progress == 1 ); diff --git a/crates/stages/src/test_utils/macros.rs b/crates/stages/src/test_utils/macros.rs index 0bb0ccdc8c..f92c96c7dc 100644 --- a/crates/stages/src/test_utils/macros.rs +++ b/crates/stages/src/test_utils/macros.rs @@ -88,9 +88,9 @@ macro_rules! stage_test_suite { let input = crate::stage::UnwindInput::default(); // Run stage unwind - let rx = runner.unwind(input); + let rx = runner.unwind(input).await; assert_matches!( - rx.await.unwrap(), + rx, Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to ); @@ -128,11 +128,11 @@ macro_rules! stage_test_suite { let unwind_input = crate::stage::UnwindInput { unwind_to: stage_progress, stage_progress, bad_block: None, }; - let rx = runner.unwind(unwind_input); + let rx = runner.unwind(unwind_input).await; // Assert the successful unwind result assert_matches!( - rx.await.unwrap(), + rx, Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to ); diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs index c81574dbd5..761e930959 100644 --- a/crates/stages/src/test_utils/runner.rs +++ b/crates/stages/src/test_utils/runner.rs @@ -59,15 +59,16 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner { } } +#[async_trait::async_trait] pub(crate) trait UnwindStageTestRunner: StageTestRunner { /// Validate the unwind fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>; /// Run [Stage::unwind] and return a receiver for the result. - fn unwind( + async fn unwind( &self, input: UnwindInput, - ) -> oneshot::Receiver>> { + ) -> Result> { let (tx, rx) = oneshot::channel(); let (db, mut stage) = (self.db().inner(), self.stage()); tokio::spawn(async move { @@ -76,6 +77,6 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner { db.commit().expect("failed to commit"); tx.send(result).expect("failed to send result"); }); - rx + Box::pin(rx).await.unwrap() } }