diff --git a/crates/interfaces/src/test_utils/headers.rs b/crates/interfaces/src/test_utils/headers.rs index 84ed8d79c9..4945627e37 100644 --- a/crates/interfaces/src/test_utils/headers.rs +++ b/crates/interfaces/src/test_utils/headers.rs @@ -1,6 +1,6 @@ //! Testing support for headers related interfaces. use crate::{ - consensus::{self, Consensus, Error}, + consensus::{self, Consensus}, p2p::headers::{ client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream}, downloader::HeaderDownloader, @@ -9,7 +9,11 @@ use crate::{ }; use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512}; use reth_rpc_types::engine::ForkchoiceState; -use std::{collections::HashSet, sync::Arc, time::Duration}; +use std::{ + collections::HashSet, + sync::{Arc, Mutex, MutexGuard}, + time::Duration, +}; use tokio::sync::{broadcast, mpsc, watch}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; @@ -17,12 +21,13 @@ use tokio_stream::{wrappers::BroadcastStream, StreamExt}; #[derive(Debug)] pub struct TestHeaderDownloader { client: Arc, + consensus: Arc, } impl TestHeaderDownloader { /// Instantiates the downloader with the mock responses - pub fn new(client: Arc) -> Self { - Self { client } + pub fn new(client: Arc, consensus: Arc) -> Self { + Self { client, consensus } } } @@ -36,7 +41,7 @@ impl HeaderDownloader for TestHeaderDownloader { } fn consensus(&self) -> &Self::Consensus { - unimplemented!() + &self.consensus } fn client(&self) -> &Self::Client { @@ -48,6 +53,12 @@ impl HeaderDownloader for TestHeaderDownloader { _: &SealedHeader, _: &ForkchoiceState, ) -> Result, DownloadError> { + // call consensus stub first. fails if the flag is set + let empty = SealedHeader::default(); + self.consensus + .validate_header(&empty, &empty) + .map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?; + let stream = self.client.stream_headers().await; let stream = stream.timeout(Duration::from_secs(1)); @@ -139,7 +150,7 @@ pub struct TestConsensus { /// Watcher over the forkchoice state channel: (watch::Sender, watch::Receiver), /// Flag whether the header validation should purposefully fail - fail_validation: bool, + fail_validation: Mutex, } impl Default for TestConsensus { @@ -150,7 +161,7 @@ impl Default for TestConsensus { finalized_block_hash: H256::zero(), safe_block_hash: H256::zero(), }), - fail_validation: false, + fail_validation: Mutex::new(false), } } } @@ -166,9 +177,14 @@ impl TestConsensus { self.channel.0.send(state).expect("updating fork choice state failed"); } + /// Acquire lock on failed validation flag + pub fn fail_validation(&self) -> MutexGuard<'_, bool> { + self.fail_validation.lock().expect("failed to acquite consensus mutex") + } + /// Update the validation flag - pub fn set_fail_validation(&mut self, val: bool) { - self.fail_validation = val; + pub fn set_fail_validation(&self, val: bool) { + *self.fail_validation() = val; } } @@ -183,15 +199,15 @@ impl Consensus for TestConsensus { _header: &SealedHeader, _parent: &SealedHeader, ) -> Result<(), consensus::Error> { - if self.fail_validation { + if *self.fail_validation() { Err(consensus::Error::BaseFeeMissing) } else { Ok(()) } } - fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), Error> { - if self.fail_validation { + fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), consensus::Error> { + if *self.fail_validation() { Err(consensus::Error::BaseFeeMissing) } else { Ok(()) diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 697e2f525f..b44b183ce6 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -20,6 +20,9 @@ mod pipeline; mod stage; mod util; +#[cfg(test)] +mod test_utils; + /// Implementations of stages. pub mod stages; diff --git a/crates/stages/src/stages/bodies.rs b/crates/stages/src/stages/bodies.rs index 498034e995..c349d0ecc4 100644 --- a/crates/stages/src/stages/bodies.rs +++ b/crates/stages/src/stages/bodies.rs @@ -15,7 +15,7 @@ use reth_primitives::{ proofs::{EMPTY_LIST_HASH, EMPTY_ROOT}, BlockLocked, BlockNumber, SealedHeader, H256, }; -use std::fmt::Debug; +use std::{fmt::Debug, sync::Arc}; use tracing::warn; const BODIES: StageId = StageId("Bodies"); @@ -51,9 +51,9 @@ const BODIES: StageId = StageId("Bodies"); #[derive(Debug)] pub struct BodyStage { /// The body downloader. - pub downloader: D, + pub downloader: Arc, /// The consensus engine. - pub consensus: C, + pub consensus: Arc, /// The maximum amount of block bodies to process in one stage execution. /// /// Smaller batch sizes result in less memory usage, but more disk I/O. Larger batch sizes @@ -232,67 +232,47 @@ impl BodyStage { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::{ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID}; + use crate::test_utils::{ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, stage_test_suite}; use assert_matches::assert_matches; - use reth_eth_wire::BlockBody; use reth_interfaces::{ consensus, p2p::bodies::error::DownloadError, - test_utils::generators::{random_block, random_block_range}, }; - use reth_primitives::{BlockNumber, H256}; use std::collections::HashMap; use test_utils::*; - /// Check that the execution is short-circuited if the database is empty. - #[tokio::test] - async fn empty_db() { - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true }) - ) - } + stage_test_suite!(BodyTestRunner); - /// Check that the execution is short-circuited if the target was already reached. - #[tokio::test] - async fn already_reached_target() { - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let rx = runner.execute(ExecInput { - previous_stage: Some((PREV_STAGE_ID, 100)), - stage_progress: Some(100), - }); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress: 100, reached_tip: true, done: true }) - ) - } + /// Check that the execution is short-circuited if the database is empty. + // #[tokio::test] + // TODO: + // async fn empty_db() { + // let runner = BodyTestRunner::new(TestBodyDownloader::default); + // let rx = runner.execute(ExecInput::default()); + // assert_matches!( + // rx.await.unwrap(), + // Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true }) + // ) + // } /// Checks that the stage downloads at most `batch_size` blocks. #[tokio::test] async fn partial_body_download() { - // Generate blocks - let blocks = random_block_range(1..200, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 200); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size (max we sync per stage execution) to less than the number of blocks // the previous stage synced (10 vs 20) runner.set_batch_size(10); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let input = ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }; let rx = runner.execute(input); // Check that we only synced around `batch_size` blocks even though the number of blocks @@ -308,26 +288,20 @@ mod tests { /// Same as [partial_body_download] except the `batch_size` is not hit. #[tokio::test] async fn full_body_download() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 21); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size to more than what the previous stage synced (40 vs 20) runner.set_batch_size(40); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let input = ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }; let rx = runner.execute(input); // Check that we synced all blocks successfully, even though our `batch_size` allows us to @@ -343,23 +317,18 @@ mod tests { /// Same as [full_body_download] except we have made progress before #[tokio::test] async fn sync_from_previous_progress() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 21); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we synced at least 10 blocks let first_run = rx.await.unwrap(); @@ -371,7 +340,7 @@ mod tests { // Execute again on top of the previous run let input = ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), + previous_stage: Some((PREV_STAGE_ID, previous_stage)), stage_progress: Some(first_run_progress), }; let rx = runner.execute(input); @@ -388,118 +357,48 @@ mod tests { /// Checks that the stage asks to unwind if pre-validation of the block fails. #[tokio::test] async fn pre_validation_failure() { - // Generate blocks #1-19 - let blocks = random_block_range(1..20, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 20); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Fail validation - runner.set_fail_validation(true); - - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); + runner.consensus.set_fail_validation(true); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that the error bubbles up assert_matches!( rx.await.unwrap(), - Err(StageError::Validation { block: 1, error: consensus::Error::BaseFeeMissing }) + Err(StageError::Validation { error: consensus::Error::BaseFeeMissing, .. }) ); - } - - /// Checks that the stage unwinds correctly with no data. - #[tokio::test] - async fn unwind_empty_db() { - let unwind_to = 10; - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let input = UnwindInput { bad_block: None, stage_progress: 20, unwind_to }; - let rx = runner.unwind(input); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to - ); - assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); - } - - /// Checks that the stage unwinds correctly with data. - #[tokio::test] - async fn unwind() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); - - // Set the batch size to more than what the previous stage synced (40 vs 20) - runner.set_batch_size(40); - - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - - // Run the stage - let execute_input = ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }; - let rx = runner.execute(execute_input); - - // Check that we synced all blocks successfully, even though our `batch_size` allows us to - // sync more (if there were more headers) - let output = rx.await.unwrap(); - assert_matches!( - output, - Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true }) - ); - let output = output.unwrap(); - assert!(runner.validate_execution(execute_input, Some(output.clone())).is_ok(), "execution validation"); - - // Unwind all of it - let unwind_to = 1; - let unwind_input = UnwindInput { bad_block: None, stage_progress: output.stage_progress, unwind_to }; - let rx = runner.unwind(unwind_input); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == 1 - ); - - assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation"); + assert!(runner.validate_execution(input, None).is_ok(), "execution validation"); } /// Checks that the stage unwinds correctly, even if a transaction in a block is missing. #[tokio::test] async fn unwind_missing_tx() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 21); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size to more than what the previous stage synced (40 vs 20) runner.set_batch_size(40); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we synced all blocks successfully, even though our `batch_size` allows us to // sync more (if there were more headers) @@ -536,42 +435,42 @@ mod tests { /// try again? #[tokio::test] async fn downloader_timeout() { - // Generate a header - let header = random_block(1, Some(GENESIS_HASH)).header; - let runner = BodyTestRunner::new(|| { - TestBodyDownloader::new(HashMap::from([( - header.hash(), - Err(DownloadError::Timeout { header_hash: header.hash() }), - )])) - }); + let (stage_progress, previous_stage) = (1, 3); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner.insert_header(&header).expect("Could not insert header"); + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), // TODO: None? + }; + let blocks = runner.seed_execution(input).expect("failed to seed execution"); + + // overwrite responses + let header = blocks.last().unwrap(); + runner.set_responses(HashMap::from([( + header.hash(), + Err(DownloadError::Timeout { header_hash: header.hash() }), + )])); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((PREV_STAGE_ID, 1)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that the error bubbles up assert_matches!(rx.await.unwrap(), Err(StageError::Internal(_))); + assert!(runner.validate_execution(input, None).is_ok(), "execution validation"); } mod test_utils { use crate::{ stages::bodies::BodyStage, - util::test_utils::{ + test_utils::{ ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner, TestRunnerError, }, ExecInput, UnwindInput, ExecOutput, }; use assert_matches::assert_matches; - use async_trait::async_trait; use reth_eth_wire::BlockBody; use reth_interfaces::{ - db, db::{ models::{BlockNumHash, StoredBlockBody}, tables, DbCursorRO, DbTx, DbTxMut, @@ -581,12 +480,12 @@ mod tests { downloader::{BodiesStream, BodyDownloader}, error::{BodiesClientError, DownloadError}, }, - test_utils::TestConsensus, + test_utils::{TestConsensus, generators::random_block_range}, }; use reth_primitives::{ BigEndianHash, BlockLocked, BlockNumber, Header, SealedHeader, H256, U256, }; - use std::{collections::HashMap, ops::Deref, time::Duration}; + use std::{collections::HashMap, ops::Deref, time::Duration, sync::Arc}; /// The block hash of the genesis block. pub(crate) const GENESIS_HASH: H256 = H256::zero(); @@ -605,43 +504,35 @@ mod tests { } /// A helper struct for running the [BodyStage]. - pub(crate) struct BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - downloader_builder: F, + pub(crate) struct BodyTestRunner { + pub(crate) consensus: Arc, + responses: HashMap>, db: StageTestDB, batch_size: u64, - fail_validation: bool, } - impl BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - /// Build a new test runner. - pub(crate) fn new(downloader_builder: F) -> Self { - BodyTestRunner { - downloader_builder, + impl Default for BodyTestRunner { + fn default() -> Self { + Self { + consensus: Arc::new(TestConsensus::default()), + responses: HashMap::default(), db: StageTestDB::default(), batch_size: 10, - fail_validation: false, } } + } + impl BodyTestRunner { pub(crate) fn set_batch_size(&mut self, batch_size: u64) { self.batch_size = batch_size; } - pub(crate) fn set_fail_validation(&mut self, fail_validation: bool) { - self.fail_validation = fail_validation; + pub(crate) fn set_responses(&mut self, responses: HashMap>) { + self.responses = responses; } } - impl StageTestRunner for BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { + impl StageTestRunner for BodyTestRunner { type S = BodyStage; fn db(&self) -> &StageTestDB { @@ -649,33 +540,29 @@ mod tests { } fn stage(&self) -> Self::S { - let mut consensus = TestConsensus::default(); - consensus.set_fail_validation(self.fail_validation); - BodyStage { - downloader: (self.downloader_builder)(), - consensus, + downloader: Arc::new(TestBodyDownloader::new(self.responses.clone())), + consensus: self.consensus.clone(), batch_size: self.batch_size, } } } - impl ExecuteStageTestRunner for BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - type Seed = (); + #[async_trait::async_trait] + impl ExecuteStageTestRunner for BodyTestRunner { + type Seed = Vec; fn seed_execution( &mut self, input: ExecInput, - ) -> Result<(), TestRunnerError> { + ) -> Result { + let start = input.stage_progress.unwrap_or_default(); + let end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(); + let blocks = random_block_range(start..end, GENESIS_HASH); self.insert_genesis()?; - // TODO: - // self - // .insert_headers(blocks.iter().map(|block| &block.header)) - // .expect("Could not insert headers"); - Ok(()) + self.insert_headers(blocks.iter().map(|block| &block.header))?; + self.set_responses(blocks.iter().map(body_by_hash).collect()); + Ok(blocks) } fn validate_execution( @@ -683,31 +570,21 @@ mod tests { input: ExecInput, output: Option, ) -> Result<(), TestRunnerError> { - if let Some(output) = output { - self.validate_db_blocks(output.stage_progress)?; - } - Ok(()) + let highest_block = match output.as_ref() { + Some(output) => output.stage_progress, + None => input.stage_progress.unwrap_or_default(), + }; + self.validate_db_blocks(highest_block) } } - impl UnwindStageTestRunner for BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - fn seed_unwind( - &mut self, - input: UnwindInput, - highest_entry: u64, - ) -> Result<(), TestRunnerError> { - unimplemented!() - } - + impl UnwindStageTestRunner for BodyTestRunner { fn validate_unwind( &self, input: UnwindInput, ) -> Result<(), TestRunnerError> { self.db.check_no_entry_above::(input.unwind_to, |key| key.number())?; - if let Some(last_body) =self.last_body(){ + if let Some(last_body) = self.last_body() { let last_tx_id = last_body.base_tx_id + last_body.tx_amount; self.db.check_no_entry_above::(last_tx_id, |key| key)?; } @@ -715,15 +592,12 @@ mod tests { } } - impl BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { + impl BodyTestRunner { /// Insert the genesis block into the appropriate tables /// /// The genesis block always has no transactions and no ommers, and it always has the /// same hash. - pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> { + pub(crate) fn insert_genesis(&self) -> Result<(), TestRunnerError> { self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?; self.db.commit(|tx| { tx.put::( @@ -736,13 +610,14 @@ mod tests { } /// Insert header into tables - pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> { - self.insert_headers(std::iter::once(header)) + pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), TestRunnerError> { + self.insert_headers(std::iter::once(header))?; + Ok(()) } /// Insert headers into tables /// TODO: move to common inserter - pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> + pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), TestRunnerError> where I: Iterator, { @@ -778,7 +653,7 @@ mod tests { pub(crate) fn validate_db_blocks( &self, highest_block: BlockNumber, - ) -> Result<(), db::Error> { + ) -> Result<(), TestRunnerError> { self.db.query(|tx| { let mut block_body_cursor = tx.cursor::()?; let mut transaction_cursor = tx.cursor::()?; @@ -806,7 +681,8 @@ mod tests { } Ok(()) - }) + })?; + Ok(()) } } @@ -815,7 +691,7 @@ mod tests { #[derive(Debug)] pub(crate) struct NoopClient; - #[async_trait] + #[async_trait::async_trait] impl BodiesClient for NoopClient { async fn get_block_body(&self, _: H256) -> Result { panic!("Noop client should not be called") @@ -824,7 +700,7 @@ mod tests { // TODO(onbjerg): Move /// A [BodyDownloader] that is backed by an internal [HashMap] for testing. - #[derive(Debug, Default)] + #[derive(Debug, Default, Clone)] pub(crate) struct TestBodyDownloader { responses: HashMap>, } @@ -854,14 +730,11 @@ mod tests { { Box::pin(futures_util::stream::iter(hashes.into_iter().map( |(block_number, hash)| { - Ok(( - *block_number, - *hash, - self.responses - .get(hash) - .expect("Stage tried downloading a block we do not have.") - .clone()?, - )) + let result = self.responses + .get(hash) + .expect("Stage tried downloading a block we do not have.") + .clone()?; + Ok((*block_number, *hash, result)) }, ))) } diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 0ea0c86dea..49039a4c0c 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -190,7 +190,7 @@ impl HeaderStage { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::{ + use crate::test_utils::{ stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, }; use assert_matches::assert_matches; @@ -198,8 +198,6 @@ mod tests { stage_test_suite!(HeadersTestRunner); - // TODO: test consensus propagation error - /// Check that the execution errors on empty database or /// prev progress missing from the database. #[tokio::test] @@ -218,6 +216,20 @@ mod tests { assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } + /// Check that validation error is propagated during the execution. + #[tokio::test] + async fn execute_validation_error() { + let mut runner = HeadersTestRunner::default(); + runner.consensus.set_fail_validation(true); + let input = ExecInput::default(); + let seed = runner.seed_execution(input).expect("failed to seed execution"); + let rx = runner.execute(input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + let result = rx.await.unwrap(); + assert_matches!(result, Err(StageError::Validation { .. })); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); + } + /// Execute the stage with linear downloader #[tokio::test] async fn execute_with_linear_downloader() { @@ -254,7 +266,7 @@ mod tests { mod test_runner { use crate::{ stages::headers::HeaderStage, - util::test_utils::{ + test_utils::{ ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError, UnwindStageTestRunner, }, @@ -262,14 +274,14 @@ mod tests { }; use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader}; use reth_interfaces::{ - db::{self, models::blocks::BlockNumHash, tables, DbTx}, + db::{models::blocks::BlockNumHash, tables, DbTx}, p2p::headers::downloader::HeaderDownloader, test_utils::{ generators::{random_header, random_header_range}, TestConsensus, TestHeaderDownloader, TestHeadersClient, }, }; - use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256}; + use reth_primitives::{rpc::BigEndianHash, BlockNumber, SealedHeader, H256, U256}; use std::{ops::Deref, sync::Arc}; pub(crate) struct HeadersTestRunner { @@ -282,10 +294,11 @@ mod tests { impl Default for HeadersTestRunner { fn default() -> Self { let client = Arc::new(TestHeadersClient::default()); + let consensus = Arc::new(TestConsensus::default()); Self { client: client.clone(), - consensus: Arc::new(TestConsensus::default()), - downloader: Arc::new(TestHeaderDownloader::new(client)), + consensus: consensus.clone(), + downloader: Arc::new(TestHeaderDownloader::new(client, consensus)), db: StageTestDB::default(), } } @@ -345,45 +358,58 @@ mod tests { Ok(()) } + /// Validate stored headers fn validate_execution( &self, - _input: ExecInput, - _output: Option, + input: ExecInput, + output: Option, ) -> Result<(), TestRunnerError> { - // TODO: refine - // if let Some(ref headers) = self.context { - // // skip head and validate each - // headers.iter().skip(1).try_for_each(|h| self.validate_db_header(&h))?; - // } + let initial_stage_progress = input.stage_progress.unwrap_or_default(); + match output { + Some(output) if output.stage_progress > initial_stage_progress => { + self.db.query(|tx| { + for block_num in (initial_stage_progress..output.stage_progress).rev() { + // look up the header hash + let hash = tx + .get::(block_num)? + .expect("no header hash"); + let key: BlockNumHash = (block_num, hash).into(); + + // validate the header number + assert_eq!(tx.get::(hash)?, Some(block_num)); + + // validate the header + let header = tx.get::(key)?; + assert!(header.is_some()); + let header = header.unwrap().seal(); + assert_eq!(header.hash(), hash); + + // validate td consistency in the database + if header.number > initial_stage_progress { + let parent_td = tx.get::( + (header.number - 1, header.parent_hash).into(), + )?; + let td = tx.get::(key)?.unwrap(); + assert_eq!( + parent_td.map( + |td| U256::from_big_endian(&td) + header.difficulty + ), + Some(U256::from_big_endian(&td)) + ); + } + } + Ok(()) + })?; + } + _ => self.check_no_header_entry_above(initial_stage_progress)?, + }; Ok(()) } } impl UnwindStageTestRunner for HeadersTestRunner { - fn seed_unwind( - &mut self, - input: UnwindInput, - highest_entry: u64, - ) -> Result<(), TestRunnerError> { - let lowest_entry = input.unwind_to.saturating_sub(100); - let headers = random_header_range(lowest_entry..highest_entry, H256::zero()); - self.insert_headers(headers.iter())?; - Ok(()) - } - fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { - let unwind_to = input.unwind_to; - self.db.check_no_entry_above_by_value::( - unwind_to, - |val| val, - )?; - self.db - .check_no_entry_above::(unwind_to, |key| key)?; - self.db - .check_no_entry_above::(unwind_to, |key| key.number())?; - self.db - .check_no_entry_above::(unwind_to, |key| key.number())?; - Ok(()) + self.check_no_header_entry_above(input.unwind_to) } } @@ -400,12 +426,15 @@ mod tests { impl HeadersTestRunner { /// Insert header into tables - pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> { + pub(crate) fn insert_header( + &self, + header: &SealedHeader, + ) -> Result<(), TestRunnerError> { self.insert_headers(std::iter::once(header)) } /// Insert headers into tables - pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> + pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), TestRunnerError> where I: Iterator, { @@ -430,36 +459,16 @@ mod tests { Ok(()) } - /// Validate stored header against provided - pub(crate) fn validate_db_header( + pub(crate) fn check_no_header_entry_above( &self, - header: &SealedHeader, - ) -> Result<(), db::Error> { - self.db.query(|tx| { - let key: BlockNumHash = (header.number, header.hash()).into(); - - let db_number = tx.get::(header.hash())?; - assert_eq!(db_number, Some(header.number)); - - let db_header = tx.get::(key)?; - assert_eq!(db_header, Some(header.clone().unseal())); - - let db_canonical_header = tx.get::(header.number)?; - assert_eq!(db_canonical_header, Some(header.hash())); - - if header.number != 0 { - let parent_key: BlockNumHash = - (header.number - 1, header.parent_hash).into(); - let parent_td = tx.get::(parent_key)?; - let td = U256::from_big_endian(&tx.get::(key)?.unwrap()); - assert_eq!( - parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty), - Some(td) - ); - } - - Ok(()) - }) + block: BlockNumber, + ) -> Result<(), TestRunnerError> { + self.db + .check_no_entry_above_by_value::(block, |val| val)?; + self.db.check_no_entry_above::(block, |key| key)?; + self.db.check_no_entry_above::(block, |key| key.number())?; + self.db.check_no_entry_above::(block, |key| key.number())?; + Ok(()) } } } diff --git a/crates/stages/src/stages/tx_index.rs b/crates/stages/src/stages/tx_index.rs index fd4ac04f2e..aec3610015 100644 --- a/crates/stages/src/stages/tx_index.rs +++ b/crates/stages/src/stages/tx_index.rs @@ -87,7 +87,7 @@ impl Stage for TxIndex { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::{ + use crate::test_utils::{ stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError, UnwindStageTestRunner, }; @@ -96,7 +96,7 @@ mod tests { db::models::{BlockNumHash, StoredBlockBody}, test_utils::generators::random_header_range, }; - use reth_primitives::{SealedHeader, H256}; + use reth_primitives::H256; stage_test_suite!(TxIndexTestRunner); @@ -155,7 +155,7 @@ mod tests { fn validate_execution( &self, input: ExecInput, - output: Option, + _output: Option, ) -> Result<(), TestRunnerError> { self.db.query(|tx| { let (start, end) = ( @@ -187,21 +187,6 @@ mod tests { } impl UnwindStageTestRunner for TxIndexTestRunner { - fn seed_unwind( - &mut self, - input: UnwindInput, - highest_entry: u64, - ) -> Result<(), TestRunnerError> { - let headers = random_header_range(input.unwind_to..highest_entry, H256::zero()); - self.db.transform_append::(&headers, |prev, h| { - ( - BlockNumHash((h.number, h.hash())), - prev.unwrap_or_default() + (rand::random::() as u64), - ) - })?; - Ok(()) - } - fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { self.db.check_no_entry_above::(input.unwind_to, |h| { h.number() diff --git a/crates/stages/src/test_utils/macros.rs b/crates/stages/src/test_utils/macros.rs new file mode 100644 index 0000000000..bc2c8410b0 --- /dev/null +++ b/crates/stages/src/test_utils/macros.rs @@ -0,0 +1,106 @@ +// TODO: add comments +macro_rules! stage_test_suite { + ($runner:ident) => { + /// Check that the execution is short-circuited if the database is empty. + #[tokio::test] + async fn execute_empty_db() { + let runner = $runner::default(); + let input = crate::stage::ExecInput::default(); + let result = runner.execute(input).await.unwrap(); + assert_matches!( + result, + Err(crate::error::StageError::DatabaseIntegrity(_)) + ); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + /// Check that the execution is short-circuited if the target was already reached. + #[tokio::test] + async fn execute_already_reached_target() { + let stage_progress = 1000; + let mut runner = $runner::default(); + let input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, stage_progress)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(input).expect("failed to seed"); + let rx = runner.execute(input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == stage_progress + ); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + #[tokio::test] + async fn execute() { + let (previous_stage, stage_progress) = (1000, 100); + let mut runner = $runner::default(); + let input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(input).expect("failed to seed"); + let rx = runner.execute(input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + let result = rx.await.unwrap(); + println!("RESULT >>> {:?}", result.unwrap_err().to_string()); + // assert_matches!( + // result, + // Ok(ExecOutput { done, reached_tip, stage_progress }) + // if done && reached_tip && stage_progress == previous_stage + // ); + // assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + #[tokio::test] + // Check that unwind does not panic on empty database. + async fn unwind_empty_db() { + let runner = $runner::default(); + let input = crate::stage::UnwindInput::default(); + let rx = runner.unwind(input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to + ); + assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); + } + + #[tokio::test] + async fn unwind() { + let (previous_stage, stage_progress) = (1000, 100); + let mut runner = $runner::default(); + + // Run execute + let execute_input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(execute_input).expect("failed to seed"); + let rx = runner.execute(execute_input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == previous_stage + ); + assert!(runner.validate_execution(execute_input, result.ok()).is_ok(), "execution validation"); + + let unwind_input = crate::stage::UnwindInput { + unwind_to: stage_progress, stage_progress, bad_block: None, + }; + let rx = runner.unwind(unwind_input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to + ); + assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation"); + } + }; +} + +pub(crate) use stage_test_suite; diff --git a/crates/stages/src/test_utils/mod.rs b/crates/stages/src/test_utils/mod.rs new file mode 100644 index 0000000000..5fae3794f5 --- /dev/null +++ b/crates/stages/src/test_utils/mod.rs @@ -0,0 +1,14 @@ +use crate::StageId; + +mod macros; +mod runner; +mod stage_db; + +pub(crate) use macros::*; +pub(crate) use runner::{ + ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner, +}; +pub(crate) use stage_db::StageTestDB; + +/// The previous test stage id mock used for testing +pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage"); diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs new file mode 100644 index 0000000000..c81574dbd5 --- /dev/null +++ b/crates/stages/src/test_utils/runner.rs @@ -0,0 +1,81 @@ +use reth_db::{kv::Env, mdbx::WriteMap}; +use reth_interfaces::db::{self, DBContainer}; +use std::borrow::Borrow; +use tokio::sync::oneshot; + +use super::StageTestDB; +use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; + +#[derive(thiserror::Error, Debug)] +pub(crate) enum TestRunnerError { + #[error("Database error occured.")] + Database(#[from] db::Error), + #[error("Internal runner error occured.")] + Internal(#[from] Box), +} + +/// A generic test runner for stages. +#[async_trait::async_trait] +pub(crate) trait StageTestRunner { + type S: Stage> + 'static; + + /// Return a reference to the database. + fn db(&self) -> &StageTestDB; + + /// Return an instance of a Stage. + fn stage(&self) -> Self::S; +} + +#[async_trait::async_trait] +pub(crate) trait ExecuteStageTestRunner: StageTestRunner { + type Seed: Send + Sync; + + /// Seed database for stage execution + fn seed_execution(&mut self, input: ExecInput) -> Result; + + /// Validate stage execution + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError>; + + /// Run [Stage::execute] and return a receiver for the result. + fn execute(&self, input: ExecInput) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + let (db, mut stage) = (self.db().inner(), self.stage()); + tokio::spawn(async move { + let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); + let result = stage.execute(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send message") + }); + rx + } + + /// Run a hook after [Stage::execute]. Required for Headers & Bodies stages. + async fn after_execution(&self, _seed: Self::Seed) -> Result<(), TestRunnerError> { + Ok(()) + } +} + +pub(crate) trait UnwindStageTestRunner: StageTestRunner { + /// Validate the unwind + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>; + + /// Run [Stage::unwind] and return a receiver for the result. + fn unwind( + &self, + input: UnwindInput, + ) -> oneshot::Receiver>> { + let (tx, rx) = oneshot::channel(); + let (db, mut stage) = (self.db().inner(), self.stage()); + tokio::spawn(async move { + let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); + let result = stage.unwind(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send result"); + }); + rx + } +} diff --git a/crates/stages/src/test_utils/stage_db.rs b/crates/stages/src/test_utils/stage_db.rs new file mode 100644 index 0000000000..f14b1cb4f2 --- /dev/null +++ b/crates/stages/src/test_utils/stage_db.rs @@ -0,0 +1,149 @@ +use reth_db::{ + kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind}, + mdbx::{WriteMap, RW}, +}; +use reth_interfaces::db::{self, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table}; +use reth_primitives::BlockNumber; +use std::{borrow::Borrow, sync::Arc}; + +/// The [StageTestDB] is used as an internal +/// database for testing stage implementation. +/// +/// ```rust +/// let db = StageTestDB::default(); +/// stage.execute(&mut db.container(), input); +/// ``` +pub(crate) struct StageTestDB { + db: Arc>, +} + +impl Default for StageTestDB { + /// Create a new instance of [StageTestDB] + fn default() -> Self { + Self { db: create_test_db::(EnvKind::RW) } + } +} + +impl StageTestDB { + /// Return a database wrapped in [DBContainer]. + fn container(&self) -> DBContainer<'_, Env> { + DBContainer::new(self.db.borrow()).expect("failed to create db container") + } + + /// Get a pointer to an internal database. + pub(crate) fn inner(&self) -> Arc> { + self.db.clone() + } + + /// Invoke a callback with transaction committing it afterwards + pub(crate) fn commit(&self, f: F) -> Result<(), db::Error> + where + F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>, + { + let mut db = self.container(); + let tx = db.get_mut(); + f(tx)?; + db.commit()?; + Ok(()) + } + + /// Invoke a callback with a read transaction + pub(crate) fn query(&self, f: F) -> Result + where + F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result, + { + f(self.container().get()) + } + + /// Map a collection of values and store them in the database. + /// This function commits the transaction before exiting. + /// + /// ```rust + /// let db = StageTestDB::default(); + /// db.map_put::(&items, |item| item)?; + /// ``` + pub(crate) fn map_put(&self, values: &[S], mut map: F) -> Result<(), db::Error> + where + T: Table, + S: Clone, + F: FnMut(&S) -> (T::Key, T::Value), + { + self.commit(|tx| { + values.iter().try_for_each(|src| { + let (k, v) = map(src); + tx.put::(k, v) + }) + }) + } + + /// Transform a collection of values using a callback and store + /// them in the database. The callback additionally accepts the + /// optional last element that was stored. + /// This function commits the transaction before exiting. + /// + /// ```rust + /// let db = StageTestDB::default(); + /// db.transform_append::(&items, |prev, item| prev.unwrap_or_default() + item)?; + /// ``` + pub(crate) fn transform_append( + &self, + values: &[S], + mut transform: F, + ) -> Result<(), db::Error> + where + T: Table, + ::Value: Clone, + S: Clone, + F: FnMut(&Option<::Value>, &S) -> (T::Key, T::Value), + { + self.commit(|tx| { + let mut cursor = tx.cursor_mut::()?; + let mut last = cursor.last()?.map(|(_, v)| v); + values.iter().try_for_each(|src| { + let (k, v) = transform(&last, src); + last = Some(v.clone()); + cursor.append(k, v) + }) + }) + } + + /// Check that there is no table entry above a given + /// block by [Table::Key] + pub(crate) fn check_no_entry_above( + &self, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + T: Table, + F: FnMut(T::Key) -> BlockNumber, + { + self.query(|tx| { + let mut cursor = tx.cursor::()?; + if let Some((key, _)) = cursor.last()? { + assert!(selector(key) <= block); + } + Ok(()) + }) + } + + /// Check that there is no table entry above a given + /// block by [Table::Value] + pub(crate) fn check_no_entry_above_by_value( + &self, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + T: Table, + F: FnMut(T::Value) -> BlockNumber, + { + self.query(|tx| { + let mut cursor = tx.cursor::()?; + if let Some((_, value)) = cursor.last()? { + assert!(selector(value) <= block); + } + Ok(()) + }) + } +} diff --git a/crates/stages/src/util.rs b/crates/stages/src/util.rs index 6c5497385d..62ec6a9322 100644 --- a/crates/stages/src/util.rs +++ b/crates/stages/src/util.rs @@ -135,332 +135,3 @@ pub(crate) mod unwind { Ok(()) } } - -#[cfg(test)] -pub(crate) mod test_utils { - use reth_db::{ - kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind}, - mdbx::{WriteMap, RW}, - }; - use reth_interfaces::db::{self, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table}; - use reth_primitives::BlockNumber; - use std::{borrow::Borrow, sync::Arc}; - use tokio::sync::oneshot; - - use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; - - /// The previous test stage id mock used for testing - pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage"); - - /// The [StageTestDB] is used as an internal - /// database for testing stage implementation. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// stage.execute(&mut db.container(), input); - /// ``` - pub(crate) struct StageTestDB { - db: Arc>, - } - - impl Default for StageTestDB { - /// Create a new instance of [StageTestDB] - fn default() -> Self { - Self { db: create_test_db::(EnvKind::RW) } - } - } - - impl StageTestDB { - /// Return a database wrapped in [DBContainer]. - fn container(&self) -> DBContainer<'_, Env> { - DBContainer::new(self.db.borrow()).expect("failed to create db container") - } - - /// Get a pointer to an internal database. - pub(crate) fn inner(&self) -> Arc> { - self.db.clone() - } - - /// Invoke a callback with transaction committing it afterwards - pub(crate) fn commit(&self, f: F) -> Result<(), db::Error> - where - F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>, - { - let mut db = self.container(); - let tx = db.get_mut(); - f(tx)?; - db.commit()?; - Ok(()) - } - - /// Invoke a callback with a read transaction - pub(crate) fn query(&self, f: F) -> Result - where - F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result, - { - f(self.container().get()) - } - - /// Map a collection of values and store them in the database. - /// This function commits the transaction before exiting. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// db.map_put::(&items, |item| item)?; - /// ``` - pub(crate) fn map_put(&self, values: &[S], mut map: F) -> Result<(), db::Error> - where - T: Table, - S: Clone, - F: FnMut(&S) -> (T::Key, T::Value), - { - self.commit(|tx| { - values.iter().try_for_each(|src| { - let (k, v) = map(src); - tx.put::(k, v) - }) - }) - } - - /// Transform a collection of values using a callback and store - /// them in the database. The callback additionally accepts the - /// optional last element that was stored. - /// This function commits the transaction before exiting. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// db.transform_append::(&items, |prev, item| prev.unwrap_or_default() + item)?; - /// ``` - pub(crate) fn transform_append( - &self, - values: &[S], - mut transform: F, - ) -> Result<(), db::Error> - where - T: Table, - ::Value: Clone, - S: Clone, - F: FnMut(&Option<::Value>, &S) -> (T::Key, T::Value), - { - self.commit(|tx| { - let mut cursor = tx.cursor_mut::()?; - let mut last = cursor.last()?.map(|(_, v)| v); - values.iter().try_for_each(|src| { - let (k, v) = transform(&last, src); - last = Some(v.clone()); - cursor.append(k, v) - }) - }) - } - - /// Check that there is no table entry above a given - /// block by [Table::Key] - pub(crate) fn check_no_entry_above( - &self, - block: BlockNumber, - mut selector: F, - ) -> Result<(), db::Error> - where - T: Table, - F: FnMut(T::Key) -> BlockNumber, - { - self.query(|tx| { - let mut cursor = tx.cursor::()?; - if let Some((key, _)) = cursor.last()? { - assert!(selector(key) <= block); - } - Ok(()) - }) - } - - /// Check that there is no table entry above a given - /// block by [Table::Value] - pub(crate) fn check_no_entry_above_by_value( - &self, - block: BlockNumber, - mut selector: F, - ) -> Result<(), db::Error> - where - T: Table, - F: FnMut(T::Value) -> BlockNumber, - { - self.query(|tx| { - let mut cursor = tx.cursor::()?; - if let Some((_, value)) = cursor.last()? { - assert!(selector(value) <= block); - } - Ok(()) - }) - } - } - - #[derive(thiserror::Error, Debug)] - pub(crate) enum TestRunnerError { - #[error("Database error occured.")] - Database(#[from] db::Error), - #[error("Internal runner error occured.")] - Internal(#[from] Box), - } - - /// A generic test runner for stages. - #[async_trait::async_trait] - pub(crate) trait StageTestRunner { - type S: Stage> + 'static; - - /// Return a reference to the database. - fn db(&self) -> &StageTestDB; - - /// Return an instance of a Stage. - fn stage(&self) -> Self::S; - } - - #[async_trait::async_trait] - pub(crate) trait ExecuteStageTestRunner: StageTestRunner { - type Seed: Send + Sync; - - /// Seed database for stage execution - fn seed_execution(&mut self, input: ExecInput) -> Result; - - /// Validate stage execution - fn validate_execution( - &self, - input: ExecInput, - output: Option, - ) -> Result<(), TestRunnerError>; - - /// Run [Stage::execute] and return a receiver for the result. - fn execute(&self, input: ExecInput) -> oneshot::Receiver> { - let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.db().inner(), self.stage()); - tokio::spawn(async move { - let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); - let result = stage.execute(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send message") - }); - rx - } - - /// Run a hook after [Stage::execute]. Required for Headers & Bodies stages. - async fn after_execution(&self, seed: Self::Seed) -> Result<(), TestRunnerError> { - Ok(()) - } - } - - pub(crate) trait UnwindStageTestRunner: StageTestRunner { - /// Seed database for stage unwind - fn seed_unwind( - &mut self, - input: UnwindInput, - highest_entry: u64, - ) -> Result<(), TestRunnerError>; - - /// Validate the unwind - fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>; - - /// Run [Stage::unwind] and return a receiver for the result. - fn unwind( - &self, - input: UnwindInput, - ) -> oneshot::Receiver>> - { - let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.db().inner(), self.stage()); - tokio::spawn(async move { - let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); - let result = stage.unwind(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send result"); - }); - rx - } - } - - macro_rules! stage_test_suite { - ($runner:ident) => { - /// Check that the execution is short-circuited if the database is empty. - #[tokio::test] - async fn execute_empty_db() { - let runner = $runner::default(); - let input = crate::stage::ExecInput::default(); - let result = runner.execute(input).await.unwrap(); - assert_matches!( - result, - Err(crate::error::StageError::DatabaseIntegrity(_)) - ); - assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); - } - - #[tokio::test] - async fn execute_already_reached_target() { - let stage_progress = 1000; - let mut runner = $runner::default(); - let input = crate::stage::ExecInput { - previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, stage_progress)), - stage_progress: Some(stage_progress), - }; - let seed = runner.seed_execution(input).expect("failed to seed"); - let rx = runner.execute(input); - runner.after_execution(seed).await.expect("failed to run after execution hook"); - let result = rx.await.unwrap(); - assert_matches!( - result, - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == stage_progress - ); - assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); - } - - #[tokio::test] - async fn execute() { - let (previous_stage, stage_progress) = (1000, 100); - let mut runner = $runner::default(); - let input = crate::stage::ExecInput { - previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, previous_stage)), - stage_progress: Some(stage_progress), - }; - let seed = runner.seed_execution(input).expect("failed to seed"); - let rx = runner.execute(input); - runner.after_execution(seed).await.expect("failed to run after execution hook"); - let result = rx.await.unwrap(); - assert_matches!( - result, - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == previous_stage - ); - assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); - } - - #[tokio::test] - // Check that unwind does not panic on empty database. - async fn unwind_empty_db() { - let runner = $runner::default(); - let input = crate::stage::UnwindInput::default(); - let rx = runner.unwind(input); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to - ); - assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); - } - - #[tokio::test] - async fn unwind() { - let (unwind_to, highest_entry) = (100, 200); - let mut runner = $runner::default(); - let input = crate::stage::UnwindInput { - unwind_to, stage_progress: unwind_to, bad_block: None, - }; - runner.seed_unwind(input, highest_entry).expect("failed to seed"); - let rx = runner.unwind(input); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to - ); - assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); - } - }; - } - - pub(crate) use stage_test_suite; -}