From 4936d467c91177df58a4b24f49fdb8064cc6273d Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Sat, 19 Nov 2022 03:57:29 +0200 Subject: [PATCH] test(sync): stage test suite (#204) * test(sync): stage test suite * cleanup txindex tests * nit * start revamping bodies testing * revamp body testing * add comments to suite tests * fmt * cleanup dup code * cleanup insert_headers helper fn * fix tests * linter * switch mutex to atomic * cleanup * revert * test: make unwind runner return value instead of channel * test: make execute runner return value instead of channel * Revert "test: make execute runner return value instead of channel" This reverts commit f8608654f2e4cf97f60ce6aa95c28009f71d5331. Co-authored-by: Georgios Konstantopoulos --- crates/db/src/kv/mod.rs | 2 +- crates/interfaces/src/test_utils/headers.rs | 72 ++- crates/net/headers-downloaders/src/linear.rs | 2 +- crates/stages/src/lib.rs | 3 + crates/stages/src/stages/bodies.rs | 578 +++++++------------ crates/stages/src/stages/headers.rs | 401 ++++++------- crates/stages/src/stages/tx_index.rs | 219 +++---- crates/stages/src/test_utils/macros.rs | 145 +++++ crates/stages/src/test_utils/mod.rs | 14 + crates/stages/src/test_utils/runner.rs | 82 +++ crates/stages/src/test_utils/stage_db.rs | 179 ++++++ crates/stages/src/util.rs | 186 ------ 12 files changed, 951 insertions(+), 932 deletions(-) create mode 100644 crates/stages/src/test_utils/macros.rs create mode 100644 crates/stages/src/test_utils/mod.rs create mode 100644 crates/stages/src/test_utils/runner.rs create mode 100644 crates/stages/src/test_utils/stage_db.rs diff --git a/crates/db/src/kv/mod.rs b/crates/db/src/kv/mod.rs index 9333989f61..d97876ae08 100644 --- a/crates/db/src/kv/mod.rs +++ b/crates/db/src/kv/mod.rs @@ -61,7 +61,7 @@ impl Env { inner: Environment::new() .set_max_dbs(TABLES.len()) .set_geometry(Geometry { - size: Some(0..0x100000), // TODO: reevaluate + size: Some(0..0x1000000), // TODO: reevaluate growth_step: Some(0x100000), // TODO: reevaluate shrink_threshold: None, page_size: Some(PageSize::Set(default_page_size())), diff --git a/crates/interfaces/src/test_utils/headers.rs b/crates/interfaces/src/test_utils/headers.rs index e0f73fbe54..087594c092 100644 --- a/crates/interfaces/src/test_utils/headers.rs +++ b/crates/interfaces/src/test_utils/headers.rs @@ -1,6 +1,6 @@ //! Testing support for headers related interfaces. use crate::{ - consensus::{self, Consensus, Error}, + consensus::{self, Consensus}, p2p::headers::{ client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream}, downloader::HeaderDownloader, @@ -9,20 +9,28 @@ use crate::{ }; use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512}; use reth_rpc_types::engine::ForkchoiceState; -use std::{collections::HashSet, sync::Arc, time::Duration}; +use std::{ + collections::HashSet, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; use tokio::sync::{broadcast, mpsc, watch}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; /// A test downloader which just returns the values that have been pushed to it. #[derive(Debug)] pub struct TestHeaderDownloader { - result: Result, DownloadError>, + client: Arc, + consensus: Arc, } impl TestHeaderDownloader { /// Instantiates the downloader with the mock responses - pub fn new(result: Result, DownloadError>) -> Self { - Self { result } + pub fn new(client: Arc, consensus: Arc) -> Self { + Self { client, consensus } } } @@ -36,11 +44,11 @@ impl HeaderDownloader for TestHeaderDownloader { } fn consensus(&self) -> &Self::Consensus { - unimplemented!() + &self.consensus } fn client(&self) -> &Self::Client { - unimplemented!() + &self.client } async fn download( @@ -48,7 +56,27 @@ impl HeaderDownloader for TestHeaderDownloader { _: &SealedHeader, _: &ForkchoiceState, ) -> Result, DownloadError> { - self.result.clone() + // call consensus stub first. fails if the flag is set + let empty = SealedHeader::default(); + self.consensus + .validate_header(&empty, &empty) + .map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?; + + let stream = self.client.stream_headers().await; + let stream = stream.timeout(Duration::from_secs(1)); + + match Box::pin(stream).try_next().await { + Ok(Some(res)) => { + let mut headers = res.headers.iter().map(|h| h.clone().seal()).collect::>(); + if !headers.is_empty() { + headers.sort_unstable_by_key(|h| h.number); + headers.remove(0); // remove head from response + headers.reverse(); + } + Ok(headers) + } + _ => Err(DownloadError::Timeout { request_id: 0 }), + } } } @@ -93,6 +121,12 @@ impl TestHeadersClient { pub fn send_header_response(&self, id: u64, headers: Vec
) { self.res_tx.send((id, headers).into()).expect("failed to send header response"); } + + /// Helper for pushing responses to the client + pub async fn send_header_response_delayed(&self, id: u64, headers: Vec
, secs: u64) { + tokio::time::sleep(Duration::from_secs(secs)).await; + self.send_header_response(id, headers); + } } #[async_trait::async_trait] @@ -106,6 +140,9 @@ impl HeadersClient for TestHeadersClient { } async fn stream_headers(&self) -> HeadersStream { + if !self.res_rx.is_empty() { + println!("WARNING: broadcast receiver already contains messages.") + } Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok())) } } @@ -116,7 +153,7 @@ pub struct TestConsensus { /// Watcher over the forkchoice state channel: (watch::Sender, watch::Receiver), /// Flag whether the header validation should purposefully fail - fail_validation: bool, + fail_validation: AtomicBool, } impl Default for TestConsensus { @@ -127,7 +164,7 @@ impl Default for TestConsensus { finalized_block_hash: H256::zero(), safe_block_hash: H256::zero(), }), - fail_validation: false, + fail_validation: AtomicBool::new(false), } } } @@ -143,9 +180,14 @@ impl TestConsensus { self.channel.0.send(state).expect("updating fork choice state failed"); } + /// Get the failed validation flag + pub fn fail_validation(&self) -> bool { + self.fail_validation.load(Ordering::SeqCst) + } + /// Update the validation flag - pub fn set_fail_validation(&mut self, val: bool) { - self.fail_validation = val; + pub fn set_fail_validation(&self, val: bool) { + self.fail_validation.store(val, Ordering::SeqCst) } } @@ -160,15 +202,15 @@ impl Consensus for TestConsensus { _header: &SealedHeader, _parent: &SealedHeader, ) -> Result<(), consensus::Error> { - if self.fail_validation { + if self.fail_validation() { Err(consensus::Error::BaseFeeMissing) } else { Ok(()) } } - fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), Error> { - if self.fail_validation { + fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), consensus::Error> { + if self.fail_validation() { Err(consensus::Error::BaseFeeMissing) } else { Ok(()) diff --git a/crates/net/headers-downloaders/src/linear.rs b/crates/net/headers-downloaders/src/linear.rs index b2a2d34ca0..1dc2dae6d9 100644 --- a/crates/net/headers-downloaders/src/linear.rs +++ b/crates/net/headers-downloaders/src/linear.rs @@ -215,7 +215,7 @@ mod tests { static CONSENSUS: Lazy> = Lazy::new(|| Arc::new(TestConsensus::default())); static CONSENSUS_FAIL: Lazy> = Lazy::new(|| { - let mut consensus = TestConsensus::default(); + let consensus = TestConsensus::default(); consensus.set_fail_validation(true); Arc::new(consensus) }); diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 697e2f525f..b44b183ce6 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -20,6 +20,9 @@ mod pipeline; mod stage; mod util; +#[cfg(test)] +mod test_utils; + /// Implementations of stages. pub mod stages; diff --git a/crates/stages/src/stages/bodies.rs b/crates/stages/src/stages/bodies.rs index 5c32ebb3af..76ecec26da 100644 --- a/crates/stages/src/stages/bodies.rs +++ b/crates/stages/src/stages/bodies.rs @@ -15,7 +15,7 @@ use reth_primitives::{ proofs::{EMPTY_LIST_HASH, EMPTY_ROOT}, BlockLocked, BlockNumber, SealedHeader, H256, }; -use std::fmt::Debug; +use std::{fmt::Debug, sync::Arc}; use tracing::warn; const BODIES: StageId = StageId("Bodies"); @@ -51,9 +51,9 @@ const BODIES: StageId = StageId("Bodies"); #[derive(Debug)] pub struct BodyStage { /// The body downloader. - pub downloader: D, + pub downloader: Arc, /// The consensus engine. - pub consensus: C, + pub consensus: Arc, /// The maximum amount of block bodies to process in one stage execution. /// /// Smaller batch sizes result in less memory usage, but more disk I/O. Larger batch sizes @@ -81,6 +81,9 @@ impl Stage for BodyStage BodyStage { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::StageTestRunner; - use assert_matches::assert_matches; - use reth_eth_wire::BlockBody; - use reth_interfaces::{ - consensus, - p2p::bodies::error::DownloadError, - test_utils::generators::{random_block, random_block_range}, + use crate::test_utils::{ + stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, + PREV_STAGE_ID, }; - use reth_primitives::{BlockNumber, H256}; + use assert_matches::assert_matches; + use reth_interfaces::{consensus, p2p::bodies::error::DownloadError}; use std::collections::HashMap; use test_utils::*; - /// Check that the execution is short-circuited if the database is empty. - #[tokio::test] - async fn empty_db() { - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true }) - ) - } - - /// Check that the execution is short-circuited if the target was already reached. - #[tokio::test] - async fn already_reached_target() { - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), 100)), - stage_progress: Some(100), - }); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress: 100, reached_tip: true, done: true }) - ) - } + stage_test_suite!(BodyTestRunner); /// Checks that the stage downloads at most `batch_size` blocks. #[tokio::test] async fn partial_body_download() { - // Generate blocks - let blocks = random_block_range(1..200, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 200); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size (max we sync per stage execution) to less than the number of blocks // the previous stage synced (10 vs 20) runner.set_batch_size(10); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we only synced around `batch_size` blocks even though the number of blocks // synced by the previous stage is higher @@ -299,34 +271,27 @@ mod tests { output, Ok(ExecOutput { stage_progress, reached_tip: true, done: false }) if stage_progress < 200 ); - runner - .validate_db_blocks(output.unwrap().stage_progress) - .expect("Written block data invalid"); + assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation"); } /// Same as [partial_body_download] except the `batch_size` is not hit. #[tokio::test] async fn full_body_download() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 20); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size to more than what the previous stage synced (40 vs 20) runner.set_batch_size(40); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we synced all blocks successfully, even though our `batch_size` allows us to // sync more (if there were more headers) @@ -335,31 +300,26 @@ mod tests { output, Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true }) ); - runner - .validate_db_blocks(output.unwrap().stage_progress) - .expect("Written block data invalid"); + assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation"); } /// Same as [full_body_download] except we have made progress before #[tokio::test] async fn sync_from_previous_progress() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 21); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); + + runner.set_batch_size(10); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we synced at least 10 blocks let first_run = rx.await.unwrap(); @@ -370,10 +330,11 @@ mod tests { let first_run_progress = first_run.unwrap().stage_progress; // Execute again on top of the previous run - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), stage_progress: Some(first_run_progress), - }); + }; + let rx = runner.execute(input); // Check that we synced more blocks let output = rx.await.unwrap(); @@ -381,175 +342,86 @@ mod tests { output, Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress > first_run_progress ); - runner - .validate_db_blocks(output.unwrap().stage_progress) - .expect("Written block data invalid"); + assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation"); } /// Checks that the stage asks to unwind if pre-validation of the block fails. #[tokio::test] async fn pre_validation_failure() { - // Generate blocks #1-19 - let blocks = random_block_range(1..20, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 20); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Fail validation - runner.set_fail_validation(true); - - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); + runner.consensus.set_fail_validation(true); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that the error bubbles up assert_matches!( rx.await.unwrap(), - Err(StageError::Validation { block: 1, error: consensus::Error::BaseFeeMissing }) + Err(StageError::Validation { error: consensus::Error::BaseFeeMissing, .. }) ); - } - - /// Checks that the stage unwinds correctly with no data. - #[tokio::test] - async fn unwind_empty_db() { - let unwind_to = 10; - let runner = BodyTestRunner::new(TestBodyDownloader::default); - let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress: 20, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to - ) - } - - /// Checks that the stage unwinds correctly with data. - #[tokio::test] - async fn unwind() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); - - // Set the batch size to more than what the previous stage synced (40 vs 20) - runner.set_batch_size(40); - - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - - // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); - - // Check that we synced all blocks successfully, even though our `batch_size` allows us to - // sync more (if there were more headers) - let output = rx.await.unwrap(); - assert_matches!( - output, - Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true }) - ); - let stage_progress = output.unwrap().stage_progress; - runner.validate_db_blocks(stage_progress).expect("Written block data invalid"); - - // Unwind all of it - let unwind_to = 1; - let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == 1 - ); - - let last_body = runner.last_body().expect("Could not read last body"); - let last_tx_id = last_body.base_tx_id + last_body.tx_amount; - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("Did not unwind block bodies correctly."); - runner - .db() - .check_no_entry_above::(last_tx_id, |key| key) - .expect("Did not unwind transactions correctly.") + assert!(runner.validate_execution(input, None).is_ok(), "execution validation"); } /// Checks that the stage unwinds correctly, even if a transaction in a block is missing. #[tokio::test] async fn unwind_missing_tx() { - // Generate blocks #1-20 - let blocks = random_block_range(1..21, GENESIS_HASH); - let bodies: HashMap> = - blocks.iter().map(body_by_hash).collect(); - let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone())); + let (stage_progress, previous_stage) = (1, 20); + + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); // Set the batch size to more than what the previous stage synced (40 vs 20) runner.set_batch_size(40); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner - .insert_headers(blocks.iter().map(|block| &block.header)) - .expect("Could not insert headers"); - // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that we synced all blocks successfully, even though our `batch_size` allows us to // sync more (if there were more headers) let output = rx.await.unwrap(); assert_matches!( output, - Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true }) + Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress == previous_stage ); let stage_progress = output.unwrap().stage_progress; runner.validate_db_blocks(stage_progress).expect("Written block data invalid"); // Delete a transaction - { - let mut db = runner.db().container(); - let mut tx_cursor = db - .get_mut() - .cursor_mut::() - .expect("Could not get transaction cursor"); - tx_cursor - .last() - .expect("Could not read database") - .expect("Could not read last transaction"); - tx_cursor.delete_current().expect("Could not delete last transaction"); - db.commit().expect("Could not commit database"); - } + runner + .db() + .commit(|tx| { + let mut tx_cursor = tx.cursor_mut::()?; + tx_cursor.last()?.expect("Could not read last transaction"); + tx_cursor.delete_current()?; + Ok(()) + }) + .expect("Could not delete a transaction"); // Unwind all of it let unwind_to = 1; - let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to }); + let input = UnwindInput { bad_block: None, stage_progress, unwind_to }; + let res = runner.unwind(input).await; assert_matches!( - rx.await.unwrap(), + res, Ok(UnwindOutput { stage_progress }) if stage_progress == 1 ); - let last_body = runner.last_body().expect("Could not read last body"); - let last_tx_id = last_body.base_tx_id + last_body.tx_amount; - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("Did not unwind block bodies correctly."); - runner - .db() - .check_no_entry_above::(last_tx_id, |key| key) - .expect("Did not unwind transactions correctly.") + assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); } /// Checks that the stage exits if the downloader times out @@ -557,54 +429,53 @@ mod tests { /// try again? #[tokio::test] async fn downloader_timeout() { - // Generate a header - let header = random_block(1, Some(GENESIS_HASH)).header; - let runner = BodyTestRunner::new(|| { - TestBodyDownloader::new(HashMap::from([( - header.hash(), - Err(DownloadError::Timeout { header_hash: header.hash() }), - )])) - }); + let (stage_progress, previous_stage) = (1, 2); - // Insert required state - runner.insert_genesis().expect("Could not insert genesis block"); - runner.insert_header(&header).expect("Could not insert header"); + // Set up test runner + let mut runner = BodyTestRunner::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let blocks = runner.seed_execution(input).expect("failed to seed execution"); + + // overwrite responses + let header = blocks.last().unwrap(); + runner.set_responses(HashMap::from([( + header.hash(), + Err(DownloadError::Timeout { header_hash: header.hash() }), + )])); // Run the stage - let rx = runner.execute(ExecInput { - previous_stage: Some((StageId("Headers"), 1)), - stage_progress: None, - }); + let rx = runner.execute(input); // Check that the error bubbles up assert_matches!(rx.await.unwrap(), Err(StageError::Internal(_))); + assert!(runner.validate_execution(input, None).is_ok(), "execution validation"); } mod test_utils { use crate::{ stages::bodies::BodyStage, - util::test_utils::{StageTestDB, StageTestRunner}, + test_utils::{ + ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError, + UnwindStageTestRunner, + }, + ExecInput, ExecOutput, UnwindInput, }; use assert_matches::assert_matches; - use async_trait::async_trait; use reth_eth_wire::BlockBody; use reth_interfaces::{ - db, - db::{ - models::{BlockNumHash, StoredBlockBody}, - tables, DbCursorRO, DbTx, DbTxMut, - }, + db::{models::StoredBlockBody, tables, DbCursorRO, DbTx, DbTxMut}, p2p::bodies::{ client::BodiesClient, downloader::{BodiesStream, BodyDownloader}, error::{BodiesClientError, DownloadError}, }, - test_utils::TestConsensus, + test_utils::{generators::random_block_range, TestConsensus}, }; - use reth_primitives::{ - BigEndianHash, BlockLocked, BlockNumber, Header, SealedHeader, H256, U256, - }; - use std::{collections::HashMap, ops::Deref, time::Duration}; + use reth_primitives::{BlockLocked, BlockNumber, Header, SealedHeader, H256}; + use std::{collections::HashMap, sync::Arc, time::Duration}; /// The block hash of the genesis block. pub(crate) const GENESIS_HASH: H256 = H256::zero(); @@ -623,43 +494,38 @@ mod tests { } /// A helper struct for running the [BodyStage]. - pub(crate) struct BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - downloader_builder: F, + pub(crate) struct BodyTestRunner { + pub(crate) consensus: Arc, + responses: HashMap>, db: StageTestDB, batch_size: u64, - fail_validation: bool, } - impl BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { - /// Build a new test runner. - pub(crate) fn new(downloader_builder: F) -> Self { - BodyTestRunner { - downloader_builder, + impl Default for BodyTestRunner { + fn default() -> Self { + Self { + consensus: Arc::new(TestConsensus::default()), + responses: HashMap::default(), db: StageTestDB::default(), - batch_size: 10, - fail_validation: false, + batch_size: 1000, } } + } + impl BodyTestRunner { pub(crate) fn set_batch_size(&mut self, batch_size: u64) { self.batch_size = batch_size; } - pub(crate) fn set_fail_validation(&mut self, fail_validation: bool) { - self.fail_validation = fail_validation; + pub(crate) fn set_responses( + &mut self, + responses: HashMap>, + ) { + self.responses = responses; } } - impl StageTestRunner for BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { + impl StageTestRunner for BodyTestRunner { type S = BodyStage; fn db(&self) -> &StageTestDB { @@ -667,115 +533,115 @@ mod tests { } fn stage(&self) -> Self::S { - let mut consensus = TestConsensus::default(); - consensus.set_fail_validation(self.fail_validation); - BodyStage { - downloader: (self.downloader_builder)(), - consensus, + downloader: Arc::new(TestBodyDownloader::new(self.responses.clone())), + consensus: self.consensus.clone(), batch_size: self.batch_size, } } } - impl BodyTestRunner - where - F: Fn() -> TestBodyDownloader, - { + #[async_trait::async_trait] + impl ExecuteStageTestRunner for BodyTestRunner { + type Seed = Vec; + + fn seed_execution(&mut self, input: ExecInput) -> Result { + let start = input.stage_progress.unwrap_or_default(); + let end = + input.previous_stage.as_ref().map(|(_, num)| *num + 1).unwrap_or_default(); + let blocks = random_block_range(start..end, GENESIS_HASH); + self.insert_genesis()?; + self.db.insert_headers(blocks.iter().map(|block| &block.header))?; + self.set_responses(blocks.iter().map(body_by_hash).collect()); + Ok(blocks) + } + + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError> { + let highest_block = match output.as_ref() { + Some(output) => output.stage_progress, + None => input.stage_progress.unwrap_or_default(), + }; + self.validate_db_blocks(highest_block) + } + } + + impl UnwindStageTestRunner for BodyTestRunner { + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { + self.db.check_no_entry_above::(input.unwind_to, |key| { + key.number() + })?; + if let Some(last_body) = self.last_body() { + let last_tx_id = last_body.base_tx_id + last_body.tx_amount; + self.db + .check_no_entry_above::(last_tx_id, |key| key)?; + } + Ok(()) + } + } + + impl BodyTestRunner { /// Insert the genesis block into the appropriate tables /// /// The genesis block always has no transactions and no ommers, and it always has the /// same hash. - pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> { - self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?; - let mut db = self.db.container(); - let tx = db.get_mut(); - tx.put::( - (0, GENESIS_HASH).into(), - StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] }, - )?; - db.commit()?; - - Ok(()) - } - - /// Insert header into tables - pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> { - self.insert_headers(std::iter::once(header)) - } - - /// Insert headers into tables - pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> - where - I: Iterator, - { - let headers = headers.collect::>(); - self.db - .map_put::(&headers, |h| (h.hash(), h.number))?; - self.db.map_put::(&headers, |h| { - (BlockNumHash((h.number, h.hash())), h.deref().clone().unseal()) - })?; - self.db.map_put::(&headers, |h| { - (h.number, h.hash()) - })?; - - self.db.transform_append::(&headers, |prev, h| { - let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default()); - ( - BlockNumHash((h.number, h.hash())), - H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(), + pub(crate) fn insert_genesis(&self) -> Result<(), TestRunnerError> { + let header = SealedHeader::new(Header::default(), GENESIS_HASH); + self.db.insert_headers(std::iter::once(&header))?; + self.db.commit(|tx| { + tx.put::( + (0, GENESIS_HASH).into(), + StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] }, ) })?; Ok(()) } + /// Retrieve the last body from the database pub(crate) fn last_body(&self) -> Option { - Some( - self.db() - .container() - .get() - .cursor::() - .ok()? - .last() - .ok()?? - .1, - ) + self.db + .query(|tx| Ok(tx.cursor::()?.last()?.map(|e| e.1))) + .ok() + .flatten() } /// Validate that the inserted block data is valid pub(crate) fn validate_db_blocks( &self, highest_block: BlockNumber, - ) -> Result<(), db::Error> { - let db = self.db.container(); - let tx = db.get(); + ) -> Result<(), TestRunnerError> { + self.db.query(|tx| { + let mut block_body_cursor = tx.cursor::()?; + let mut transaction_cursor = tx.cursor::()?; - let mut block_body_cursor = tx.cursor::()?; - let mut transaction_cursor = tx.cursor::()?; - - let mut entry = block_body_cursor.first()?; - let mut prev_max_tx_id = 0; - while let Some((key, body)) = entry { - assert!( - key.number() <= highest_block, - "We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}", - key.number(), highest_block - ); - - assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed."); - for num in 0..body.tx_amount { - let tx_id = body.base_tx_id + num; - assert_matches!( - transaction_cursor.seek_exact(tx_id), - Ok(Some(_)), - "A transaction is missing." + let mut entry = block_body_cursor.first()?; + let mut prev_max_tx_id = 0; + while let Some((key, body)) = entry { + assert!( + key.number() <= highest_block, + "We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}", + key.number(), highest_block ); - } - prev_max_tx_id = body.base_tx_id + body.tx_amount; - entry = block_body_cursor.next()?; - } + assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed."); + for num in 0..body.tx_amount { + let tx_id = body.base_tx_id + num; + assert_matches!( + transaction_cursor.seek_exact(tx_id), + Ok(Some(_)), + "A transaction is missing." + ); + } + prev_max_tx_id = body.base_tx_id + body.tx_amount; + entry = block_body_cursor.next()?; + } + + Ok(()) + })?; Ok(()) } } @@ -785,7 +651,7 @@ mod tests { #[derive(Debug)] pub(crate) struct NoopClient; - #[async_trait] + #[async_trait::async_trait] impl BodiesClient for NoopClient { async fn get_block_body(&self, _: H256) -> Result { panic!("Noop client should not be called") @@ -794,7 +660,7 @@ mod tests { // TODO(onbjerg): Move /// A [BodyDownloader] that is backed by an internal [HashMap] for testing. - #[derive(Debug, Default)] + #[derive(Debug, Default, Clone)] pub(crate) struct TestBodyDownloader { responses: HashMap>, } @@ -824,14 +690,12 @@ mod tests { { Box::pin(futures_util::stream::iter(hashes.into_iter().map( |(block_number, hash)| { - Ok(( - *block_number, - *hash, - self.responses - .get(hash) - .expect("Stage tried downloading a block we do not have.") - .clone()?, - )) + let result = self + .responses + .get(hash) + .expect("Stage tried downloading a block we do not have.") + .clone()?; + Ok((*block_number, *hash, result)) }, ))) } diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index a20fed002f..52abdf4884 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -58,8 +58,8 @@ impl Stage(tx, last_block_num).await?; + // TODO: add batch size // download the headers - // TODO: handle input.max_block let last_hash = tx .get::(last_block_num)? .ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?; @@ -190,214 +190,99 @@ impl HeaderStage { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::StageTestRunner; - use assert_matches::assert_matches; - use reth_interfaces::{ - consensus, - test_utils::{ - generators::{random_header, random_header_range}, - TestHeaderDownloader, - }, + use crate::test_utils::{ + stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, }; - use test_utils::HeadersTestRunner; + use assert_matches::assert_matches; + use test_runner::HeadersTestRunner; - const TEST_STAGE: StageId = StageId("Headers"); + stage_test_suite!(HeadersTestRunner); /// Check that the execution errors on empty database or /// prev progress missing from the database. #[tokio::test] - async fn execute_empty_db() { - let runner = HeadersTestRunner::default(); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. })) - ); - } - - /// Check that the execution exits on downloader timeout. - #[tokio::test] + // Validate that the execution does not fail on timeout async fn execute_timeout() { - let head = random_header(0, None); - let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err( - DownloadError::Timeout { request_id: 0 }, - ))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput::default()); + let mut runner = HeadersTestRunner::default(); + let input = ExecInput::default(); + runner.seed_execution(input).expect("failed to seed execution"); + let rx = runner.execute(input); runner.consensus.update_tip(H256::from_low_u64_be(1)); - assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done); + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 0 }) + ); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } /// Check that validation error is propagated during the execution. #[tokio::test] async fn execute_validation_error() { - let head = random_header(0, None); - let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err( - DownloadError::HeaderValidation { - hash: H256::zero(), - error: consensus::Error::BaseFeeMissing, - }, - ))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput::default()); - runner.consensus.update_tip(H256::from_low_u64_be(1)); - assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block, error: consensus::Error::BaseFeeMissing, }) if block == 0); - } - - /// Validate that all necessary tables are updated after the - /// header download on no previous progress. - #[tokio::test] - async fn execute_no_progress() { - let (start, end) = (0, 100); - let head = random_header(start, None); - let headers = random_header_range(start + 1..end, head.hash()); - - let result = headers.iter().rev().cloned().collect::>(); - let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput::default()); - let tip = headers.last().unwrap(); - runner.consensus.update_tip(tip.hash()); - - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number - ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok()); - } - - /// Validate that all necessary tables are updated after the - /// header download with some previous progress. - #[tokio::test] - async fn execute_prev_progress() { - let (start, end) = (10000, 10241); - let head = random_header(start, None); - let headers = random_header_range(start + 1..end, head.hash()); - - let result = headers.iter().rev().cloned().collect::>(); - let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput { - previous_stage: Some((TEST_STAGE, head.number)), - stage_progress: Some(head.number), - }); - let tip = headers.last().unwrap(); - runner.consensus.update_tip(tip.hash()); - - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number - ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok()); + let mut runner = HeadersTestRunner::default(); + runner.consensus.set_fail_validation(true); + let input = ExecInput::default(); + let seed = runner.seed_execution(input).expect("failed to seed execution"); + let rx = runner.execute(input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + let result = rx.await.unwrap(); + assert_matches!(result, Err(StageError::Validation { .. })); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } /// Execute the stage with linear downloader #[tokio::test] async fn execute_with_linear_downloader() { - let (start, end) = (1000, 1200); - let head = random_header(start, None); - let headers = random_header_range(start + 1..end, head.hash()); - - let runner = HeadersTestRunner::with_linear_downloader(); - runner.insert_header(&head).expect("failed to insert header"); - let rx = runner.execute(ExecInput { - previous_stage: Some((TEST_STAGE, head.number)), - stage_progress: Some(head.number), - }); + let mut runner = HeadersTestRunner::with_linear_downloader(); + let (stage_progress, previous_stage) = (1000, 1200); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let headers = runner.seed_execution(input).expect("failed to seed execution"); + let rx = runner.execute(input); + // skip `after_execution` hook for linear downloader let tip = headers.last().unwrap(); runner.consensus.update_tip(tip.hash()); - let mut download_result = headers.clone(); - download_result.insert(0, head); + let download_result = headers.clone(); runner .client .on_header_request(1, |id, _| { - runner.client.send_header_response( - id, - download_result.clone().into_iter().map(|h| h.unseal()).collect(), - ) + let response = download_result.iter().map(|h| h.clone().unseal()).collect(); + runner.client.send_header_response(id, response) }) .await; + let result = rx.await.unwrap(); assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number + result, + Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) if stage_progress == tip.number ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok()); + assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } - /// Check that unwind does not panic on empty database. - #[tokio::test] - async fn unwind_empty_db() { - let unwind_to = 100; - let runner = HeadersTestRunner::default(); - let rx = - runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to - ); - } - - /// Check that unwind can remove headers across gaps - #[tokio::test] - async fn unwind_db_gaps() { - let runner = HeadersTestRunner::default(); - let head = random_header(0, None); - let first_range = random_header_range(1..20, head.hash()); - let second_range = random_header_range(50..100, H256::zero()); - runner.insert_header(&head).expect("failed to insert header"); - runner - .insert_headers(first_range.iter().chain(second_range.iter())) - .expect("failed to insert headers"); - - let unwind_to = 15; - let rx = - runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to - ); - - runner - .db() - .check_no_entry_above::(unwind_to, |key| key) - .expect("failed to check cannonical headers"); - runner - .db() - .check_no_entry_above_by_value::(unwind_to, |val| val) - .expect("failed to check header numbers"); - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("failed to check headers"); - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("failed to check td"); - } - - mod test_utils { + mod test_runner { use crate::{ stages::headers::HeaderStage, - util::test_utils::{StageTestDB, StageTestRunner}, + test_utils::{ + ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError, + UnwindStageTestRunner, + }, + ExecInput, ExecOutput, UnwindInput, }; use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader}; use reth_interfaces::{ - db::{self, models::blocks::BlockNumHash, tables, DbTx}, + db::{models::blocks::BlockNumHash, tables, DbTx}, p2p::headers::downloader::HeaderDownloader, - test_utils::{TestConsensus, TestHeaderDownloader, TestHeadersClient}, + test_utils::{ + generators::{random_header, random_header_range}, + TestConsensus, TestHeaderDownloader, TestHeadersClient, + }, }; - use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256}; - use std::{ops::Deref, sync::Arc}; + use reth_primitives::{BlockNumber, SealedHeader, H256, U256}; + use std::sync::Arc; pub(crate) struct HeadersTestRunner { pub(crate) consensus: Arc, @@ -408,10 +293,12 @@ mod tests { impl Default for HeadersTestRunner { fn default() -> Self { + let client = Arc::new(TestHeadersClient::default()); + let consensus = Arc::new(TestConsensus::default()); Self { - client: Arc::new(TestHeadersClient::default()), - consensus: Arc::new(TestConsensus::default()), - downloader: Arc::new(TestHeaderDownloader::new(Ok(Vec::default()))), + client: client.clone(), + consensus: consensus.clone(), + downloader: Arc::new(TestHeaderDownloader::new(client, consensus)), db: StageTestDB::default(), } } @@ -433,6 +320,99 @@ mod tests { } } + #[async_trait::async_trait] + impl ExecuteStageTestRunner for HeadersTestRunner { + type Seed = Vec; + + fn seed_execution(&mut self, input: ExecInput) -> Result { + let start = input.stage_progress.unwrap_or_default(); + let head = random_header(start, None); + self.db.insert_headers(std::iter::once(&head))?; + + // use previous progress as seed size + let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1; + + if start + 1 >= end { + return Ok(Vec::default()) + } + + let mut headers = random_header_range(start + 1..end, head.hash()); + headers.insert(0, head); + Ok(headers) + } + + async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> { + let tip = if !headers.is_empty() { + headers.last().unwrap().hash() + } else { + H256::from_low_u64_be(rand::random()) + }; + self.consensus.update_tip(tip); + self.client + .send_header_response_delayed( + 0, + headers.into_iter().map(|h| h.unseal()).collect(), + 1, + ) + .await; + Ok(()) + } + + /// Validate stored headers + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError> { + let initial_stage_progress = input.stage_progress.unwrap_or_default(); + match output { + Some(output) if output.stage_progress > initial_stage_progress => { + self.db.query(|tx| { + for block_num in (initial_stage_progress..output.stage_progress).rev() { + // look up the header hash + let hash = tx + .get::(block_num)? + .expect("no header hash"); + let key: BlockNumHash = (block_num, hash).into(); + + // validate the header number + assert_eq!(tx.get::(hash)?, Some(block_num)); + + // validate the header + let header = tx.get::(key)?; + assert!(header.is_some()); + let header = header.unwrap().seal(); + assert_eq!(header.hash(), hash); + + // validate td consistency in the database + if header.number > initial_stage_progress { + let parent_td = tx.get::( + (header.number - 1, header.parent_hash).into(), + )?; + let td = tx.get::(key)?.unwrap(); + assert_eq!( + parent_td.map( + |td| U256::from_big_endian(&td) + header.difficulty + ), + Some(U256::from_big_endian(&td)) + ); + } + } + Ok(()) + })?; + } + _ => self.check_no_header_entry_above(initial_stage_progress)?, + }; + Ok(()) + } + } + + impl UnwindStageTestRunner for HeadersTestRunner { + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { + self.check_no_header_entry_above(input.unwind_to) + } + } + impl HeadersTestRunner> { pub(crate) fn with_linear_downloader() -> Self { let client = Arc::new(TestHeadersClient::default()); @@ -445,74 +425,15 @@ mod tests { } impl HeadersTestRunner { - pub(crate) fn with_downloader(downloader: D) -> Self { - HeadersTestRunner { - client: Arc::new(TestHeadersClient::default()), - consensus: Arc::new(TestConsensus::default()), - downloader: Arc::new(downloader), - db: StageTestDB::default(), - } - } - - /// Insert header into tables - pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> { - self.insert_headers(std::iter::once(header)) - } - - /// Insert headers into tables - pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> - where - I: Iterator, - { - let headers = headers.collect::>(); - self.db - .map_put::(&headers, |h| (h.hash(), h.number))?; - self.db.map_put::(&headers, |h| { - (BlockNumHash((h.number, h.hash())), h.deref().clone().unseal()) - })?; - self.db.map_put::(&headers, |h| { - (h.number, h.hash()) - })?; - - self.db.transform_append::(&headers, |prev, h| { - let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default()); - ( - BlockNumHash((h.number, h.hash())), - H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(), - ) - })?; - - Ok(()) - } - - /// Validate stored header against provided - pub(crate) fn validate_db_header( + pub(crate) fn check_no_header_entry_above( &self, - header: &SealedHeader, - ) -> Result<(), db::Error> { - let db = self.db.container(); - let tx = db.get(); - let key: BlockNumHash = (header.number, header.hash()).into(); - - let db_number = tx.get::(header.hash())?; - assert_eq!(db_number, Some(header.number)); - - let db_header = tx.get::(key)?; - assert_eq!(db_header, Some(header.clone().unseal())); - - let db_canonical_header = tx.get::(header.number)?; - assert_eq!(db_canonical_header, Some(header.hash())); - - if header.number != 0 { - let parent_key: BlockNumHash = (header.number - 1, header.parent_hash).into(); - let parent_td = tx.get::(parent_key)?; - let td = U256::from_big_endian(&tx.get::(key)?.unwrap()); - assert_eq!( - parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty), - Some(td) - ); - } - + block: BlockNumber, + ) -> Result<(), TestRunnerError> { + self.db + .check_no_entry_above_by_value::(block, |val| val)?; + self.db.check_no_entry_above::(block, |key| key)?; + self.db.check_no_entry_above::(block, |key| key.number())?; + self.db.check_no_entry_above::(block, |key| key.number())?; Ok(()) } } diff --git a/crates/stages/src/stages/tx_index.rs b/crates/stages/src/stages/tx_index.rs index 64ac970212..aec3610015 100644 --- a/crates/stages/src/stages/tx_index.rs +++ b/crates/stages/src/stages/tx_index.rs @@ -87,141 +87,18 @@ impl Stage for TxIndex { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::{StageTestDB, StageTestRunner}; + use crate::test_utils::{ + stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError, + UnwindStageTestRunner, + }; use assert_matches::assert_matches; - use reth_interfaces::{db::models::BlockNumHash, test_utils::generators::random_header_range}; + use reth_interfaces::{ + db::models::{BlockNumHash, StoredBlockBody}, + test_utils::generators::random_header_range, + }; use reth_primitives::H256; - const TEST_STAGE: StageId = StageId("PrevStage"); - - #[tokio::test] - async fn execute_empty_db() { - let runner = TxIndexTestRunner::default(); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. })) - ); - } - - #[tokio::test] - async fn execute_no_prev_tx_count() { - let runner = TxIndexTestRunner::default(); - let headers = random_header_range(0..10, H256::zero()); - runner - .db() - .map_put::(&headers, |h| (h.number, h.hash())) - .expect("failed to insert"); - - let (head, tail) = (headers.first().unwrap(), headers.last().unwrap()); - let input = ExecInput { - previous_stage: Some((TEST_STAGE, tail.number)), - stage_progress: Some(head.number), - }; - let rx = runner.execute(input); - assert_matches!( - rx.await.unwrap(), - Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CumulativeTxCount { .. })) - ); - } - - #[tokio::test] - async fn execute() { - let runner = TxIndexTestRunner::default(); - let (start, pivot, end) = (0, 100, 200); - let headers = random_header_range(start..end, H256::zero()); - runner - .db() - .map_put::(&headers, |h| (h.number, h.hash())) - .expect("failed to insert"); - runner - .db() - .transform_append::(&headers[..=pivot], |prev, h| { - ( - BlockNumHash((h.number, h.hash())), - prev.unwrap_or_default() + (rand::random::() as u64), - ) - }) - .expect("failed to insert"); - - let (pivot, tail) = (headers.get(pivot).unwrap(), headers.last().unwrap()); - let input = ExecInput { - previous_stage: Some((TEST_STAGE, tail.number)), - stage_progress: Some(pivot.number), - }; - let rx = runner.execute(input); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress, done, reached_tip }) - if done && reached_tip && stage_progress == tail.number - ); - } - - #[tokio::test] - async fn unwind_empty_db() { - let runner = TxIndexTestRunner::default(); - let rx = runner.unwind(UnwindInput::default()); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == 0 - ); - } - - #[tokio::test] - async fn unwind_no_input() { - let runner = TxIndexTestRunner::default(); - let headers = random_header_range(0..10, H256::zero()); - runner - .db() - .transform_append::(&headers, |prev, h| { - ( - BlockNumHash((h.number, h.hash())), - prev.unwrap_or_default() + (rand::random::() as u64), - ) - }) - .expect("failed to insert"); - - let rx = runner.unwind(UnwindInput::default()); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == 0 - ); - runner - .db() - .check_no_entry_above::(0, |h| h.number()) - .expect("failed to check tx count"); - } - - #[tokio::test] - async fn unwind_with_db_gaps() { - let runner = TxIndexTestRunner::default(); - let first_range = random_header_range(0..20, H256::zero()); - let second_range = random_header_range(50..100, H256::zero()); - runner - .db() - .transform_append::( - &first_range.iter().chain(second_range.iter()).collect::>(), - |prev, h| { - ( - BlockNumHash((h.number, h.hash())), - prev.unwrap_or_default() + (rand::random::() as u64), - ) - }, - ) - .expect("failed to insert"); - - let unwind_to = 10; - let input = UnwindInput { unwind_to, ..Default::default() }; - let rx = runner.unwind(input); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to - ); - runner - .db() - .check_no_entry_above::(unwind_to, |h| h.number()) - .expect("failed to check tx count"); - } + stage_test_suite!(TxIndexTestRunner); #[derive(Default)] pub(crate) struct TxIndexTestRunner { @@ -239,4 +116,82 @@ mod tests { TxIndex {} } } + + impl ExecuteStageTestRunner for TxIndexTestRunner { + type Seed = (); + + fn seed_execution(&mut self, input: ExecInput) -> Result { + let pivot = input.stage_progress.unwrap_or_default(); + let start = pivot.saturating_sub(100); + let mut end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(); + end += 2; // generate 2 additional headers to account for start header lookup + let headers = random_header_range(start..end, H256::zero()); + + let headers = + headers.into_iter().map(|h| (h, rand::random::())).collect::>(); + + self.db.map_put::(&headers, |(h, _)| { + (h.number, h.hash()) + })?; + self.db.map_put::(&headers, |(h, count)| { + ( + BlockNumHash((h.number, h.hash())), + StoredBlockBody { base_tx_id: 0, tx_amount: *count as u64, ommers: vec![] }, + ) + })?; + + let slice_up_to = + std::cmp::min(pivot.saturating_sub(start) as usize, headers.len() - 1); + self.db.transform_append::( + &headers[..=slice_up_to], + |prev, (h, count)| { + (BlockNumHash((h.number, h.hash())), prev.unwrap_or_default() + (*count as u64)) + }, + )?; + + Ok(()) + } + + fn validate_execution( + &self, + input: ExecInput, + _output: Option, + ) -> Result<(), TestRunnerError> { + self.db.query(|tx| { + let (start, end) = ( + input.stage_progress.unwrap_or_default(), + input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(), + ); + if start >= end { + return Ok(()) + } + + let start_hash = + tx.get::(start)?.expect("no canonical found"); + let mut tx_count_cursor = tx.cursor::()?; + let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?; + let mut count = tx_count_walker.next().unwrap()?.1; + let mut last_num = start; + while let Some(entry) = tx_count_walker.next() { + let (key, db_count) = entry?; + count += tx.get::(key)?.unwrap().tx_amount as u64; + assert_eq!(db_count, count); + last_num = key.number(); + } + assert_eq!(last_num, end); + + Ok(()) + })?; + Ok(()) + } + } + + impl UnwindStageTestRunner for TxIndexTestRunner { + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { + self.db.check_no_entry_above::(input.unwind_to, |h| { + h.number() + })?; + Ok(()) + } + } } diff --git a/crates/stages/src/test_utils/macros.rs b/crates/stages/src/test_utils/macros.rs new file mode 100644 index 0000000000..f92c96c7dc --- /dev/null +++ b/crates/stages/src/test_utils/macros.rs @@ -0,0 +1,145 @@ +macro_rules! stage_test_suite { + ($runner:ident) => { + /// Check that the execution is short-circuited if the database is empty. + #[tokio::test] + async fn execute_empty_db() { + // Set up the runner + let runner = $runner::default(); + + // Execute the stage with empty database + let input = crate::stage::ExecInput::default(); + + // Run stage execution + let result = runner.execute(input).await.unwrap(); + assert_matches!( + result, + Err(crate::error::StageError::DatabaseIntegrity(_)) + ); + + // Validate the stage execution + assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + /// Check that the execution is short-circuited if the target was already reached. + #[tokio::test] + async fn execute_already_reached_target() { + let stage_progress = 1000; + + // Set up the runner + let mut runner = $runner::default(); + let input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, stage_progress)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(input).expect("failed to seed"); + + // Run stage execution + let rx = runner.execute(input); + + // Run `after_execution` hook + runner.after_execution(seed).await.expect("failed to run after execution hook"); + + // Assert the successful result + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == stage_progress + ); + + // Validate the stage execution + assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + // Run the complete stage execution flow. + #[tokio::test] + async fn execute() { + let (previous_stage, stage_progress) = (500, 100); + + // Set up the runner + let mut runner = $runner::default(); + let input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(input).expect("failed to seed"); + let rx = runner.execute(input); + + // Run `after_execution` hook + runner.after_execution(seed).await.expect("failed to run after execution hook"); + + // Assert the successful result + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == previous_stage + ); + + // Validate the stage execution + assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation"); + } + + // Check that unwind does not panic on empty database. + #[tokio::test] + async fn unwind_empty_db() { + // Set up the runner + let runner = $runner::default(); + let input = crate::stage::UnwindInput::default(); + + // Run stage unwind + let rx = runner.unwind(input).await; + assert_matches!( + rx, + Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to + ); + + // Validate the stage unwind + assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); + } + + // Run complete execute and unwind flow. + #[tokio::test] + async fn unwind() { + let (previous_stage, stage_progress) = (500, 100); + + // Set up the runner + let mut runner = $runner::default(); + let execute_input = crate::stage::ExecInput { + previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + let seed = runner.seed_execution(execute_input).expect("failed to seed"); + + // Run stage execution + let rx = runner.execute(execute_input); + runner.after_execution(seed).await.expect("failed to run after execution hook"); + + // Assert the successful execution result + let result = rx.await.unwrap(); + assert_matches!( + result, + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == previous_stage + ); + assert!(runner.validate_execution(execute_input, result.ok()).is_ok(), "execution validation"); + + // Run stage unwind + let unwind_input = crate::stage::UnwindInput { + unwind_to: stage_progress, stage_progress, bad_block: None, + }; + let rx = runner.unwind(unwind_input).await; + + // Assert the successful unwind result + assert_matches!( + rx, + Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to + ); + + // Validate the stage unwind + assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation"); + } + }; +} + +pub(crate) use stage_test_suite; diff --git a/crates/stages/src/test_utils/mod.rs b/crates/stages/src/test_utils/mod.rs new file mode 100644 index 0000000000..5fae3794f5 --- /dev/null +++ b/crates/stages/src/test_utils/mod.rs @@ -0,0 +1,14 @@ +use crate::StageId; + +mod macros; +mod runner; +mod stage_db; + +pub(crate) use macros::*; +pub(crate) use runner::{ + ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner, +}; +pub(crate) use stage_db::StageTestDB; + +/// The previous test stage id mock used for testing +pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage"); diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs new file mode 100644 index 0000000000..761e930959 --- /dev/null +++ b/crates/stages/src/test_utils/runner.rs @@ -0,0 +1,82 @@ +use reth_db::{kv::Env, mdbx::WriteMap}; +use reth_interfaces::db::{self, DBContainer}; +use std::borrow::Borrow; +use tokio::sync::oneshot; + +use super::StageTestDB; +use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; + +#[derive(thiserror::Error, Debug)] +pub(crate) enum TestRunnerError { + #[error("Database error occured.")] + Database(#[from] db::Error), + #[error("Internal runner error occured.")] + Internal(#[from] Box), +} + +/// A generic test runner for stages. +#[async_trait::async_trait] +pub(crate) trait StageTestRunner { + type S: Stage> + 'static; + + /// Return a reference to the database. + fn db(&self) -> &StageTestDB; + + /// Return an instance of a Stage. + fn stage(&self) -> Self::S; +} + +#[async_trait::async_trait] +pub(crate) trait ExecuteStageTestRunner: StageTestRunner { + type Seed: Send + Sync; + + /// Seed database for stage execution + fn seed_execution(&mut self, input: ExecInput) -> Result; + + /// Validate stage execution + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError>; + + /// Run [Stage::execute] and return a receiver for the result. + fn execute(&self, input: ExecInput) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + let (db, mut stage) = (self.db().inner(), self.stage()); + tokio::spawn(async move { + let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); + let result = stage.execute(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send message") + }); + rx + } + + /// Run a hook after [Stage::execute]. Required for Headers & Bodies stages. + async fn after_execution(&self, _seed: Self::Seed) -> Result<(), TestRunnerError> { + Ok(()) + } +} + +#[async_trait::async_trait] +pub(crate) trait UnwindStageTestRunner: StageTestRunner { + /// Validate the unwind + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>; + + /// Run [Stage::unwind] and return a receiver for the result. + async fn unwind( + &self, + input: UnwindInput, + ) -> Result> { + let (tx, rx) = oneshot::channel(); + let (db, mut stage) = (self.db().inner(), self.stage()); + tokio::spawn(async move { + let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); + let result = stage.unwind(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send result"); + }); + Box::pin(rx).await.unwrap() + } +} diff --git a/crates/stages/src/test_utils/stage_db.rs b/crates/stages/src/test_utils/stage_db.rs new file mode 100644 index 0000000000..f37d4038b0 --- /dev/null +++ b/crates/stages/src/test_utils/stage_db.rs @@ -0,0 +1,179 @@ +use reth_db::{ + kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind}, + mdbx::{WriteMap, RW}, +}; +use reth_interfaces::db::{ + self, models::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table, +}; +use reth_primitives::{BigEndianHash, BlockNumber, SealedHeader, H256, U256}; +use std::{borrow::Borrow, sync::Arc}; + +/// The [StageTestDB] is used as an internal +/// database for testing stage implementation. +/// +/// ```rust +/// let db = StageTestDB::default(); +/// stage.execute(&mut db.container(), input); +/// ``` +pub(crate) struct StageTestDB { + db: Arc>, +} + +impl Default for StageTestDB { + /// Create a new instance of [StageTestDB] + fn default() -> Self { + Self { db: create_test_db::(EnvKind::RW) } + } +} + +impl StageTestDB { + /// Return a database wrapped in [DBContainer]. + fn container(&self) -> DBContainer<'_, Env> { + DBContainer::new(self.db.borrow()).expect("failed to create db container") + } + + /// Get a pointer to an internal database. + pub(crate) fn inner(&self) -> Arc> { + self.db.clone() + } + + /// Invoke a callback with transaction committing it afterwards + pub(crate) fn commit(&self, f: F) -> Result<(), db::Error> + where + F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>, + { + let mut db = self.container(); + let tx = db.get_mut(); + f(tx)?; + db.commit()?; + Ok(()) + } + + /// Invoke a callback with a read transaction + pub(crate) fn query(&self, f: F) -> Result + where + F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result, + { + f(self.container().get()) + } + + /// Map a collection of values and store them in the database. + /// This function commits the transaction before exiting. + /// + /// ```rust + /// let db = StageTestDB::default(); + /// db.map_put::(&items, |item| item)?; + /// ``` + pub(crate) fn map_put(&self, values: &[S], mut map: F) -> Result<(), db::Error> + where + T: Table, + S: Clone, + F: FnMut(&S) -> (T::Key, T::Value), + { + self.commit(|tx| { + values.iter().try_for_each(|src| { + let (k, v) = map(src); + tx.put::(k, v) + }) + }) + } + + /// Transform a collection of values using a callback and store + /// them in the database. The callback additionally accepts the + /// optional last element that was stored. + /// This function commits the transaction before exiting. + /// + /// ```rust + /// let db = StageTestDB::default(); + /// db.transform_append::(&items, |prev, item| prev.unwrap_or_default() + item)?; + /// ``` + pub(crate) fn transform_append( + &self, + values: &[S], + mut transform: F, + ) -> Result<(), db::Error> + where + T: Table, + ::Value: Clone, + S: Clone, + F: FnMut(&Option<::Value>, &S) -> (T::Key, T::Value), + { + self.commit(|tx| { + let mut cursor = tx.cursor_mut::()?; + let mut last = cursor.last()?.map(|(_, v)| v); + values.iter().try_for_each(|src| { + let (k, v) = transform(&last, src); + last = Some(v.clone()); + cursor.append(k, v) + }) + }) + } + + /// Check that there is no table entry above a given + /// block by [Table::Key] + pub(crate) fn check_no_entry_above( + &self, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + T: Table, + F: FnMut(T::Key) -> BlockNumber, + { + self.query(|tx| { + let mut cursor = tx.cursor::()?; + if let Some((key, _)) = cursor.last()? { + assert!(selector(key) <= block); + } + Ok(()) + }) + } + + /// Check that there is no table entry above a given + /// block by [Table::Value] + pub(crate) fn check_no_entry_above_by_value( + &self, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + T: Table, + F: FnMut(T::Value) -> BlockNumber, + { + self.query(|tx| { + let mut cursor = tx.cursor::()?; + if let Some((_, value)) = cursor.last()? { + assert!(selector(value) <= block); + } + Ok(()) + }) + } + + /// Insert ordered collection of [SealedHeader] into the corresponding tables + /// that are supposed to be populated by the headers stage. + pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> + where + I: Iterator, + { + self.commit(|tx| { + let headers = headers.collect::>(); + + let mut td = U256::from_big_endian( + &tx.cursor::()?.last()?.map(|(_, v)| v).unwrap_or_default(), + ); + + for header in headers { + let key: BlockNumHash = (header.number, header.hash()).into(); + + tx.put::(header.number, header.hash())?; + tx.put::(header.hash(), header.number)?; + tx.put::(key, header.clone().unseal())?; + + td += header.difficulty; + tx.put::(key, H256::from_uint(&td).as_bytes().to_vec())?; + } + + Ok(()) + }) + } +} diff --git a/crates/stages/src/util.rs b/crates/stages/src/util.rs index af221916d6..62ec6a9322 100644 --- a/crates/stages/src/util.rs +++ b/crates/stages/src/util.rs @@ -135,189 +135,3 @@ pub(crate) mod unwind { Ok(()) } } - -#[cfg(test)] -pub(crate) mod test_utils { - use reth_db::{ - kv::{test_utils::create_test_db, Env, EnvKind}, - mdbx::WriteMap, - }; - use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table}; - use reth_primitives::BlockNumber; - use std::{borrow::Borrow, sync::Arc}; - use tokio::sync::oneshot; - - use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; - - /// The [StageTestDB] is used as an internal - /// database for testing stage implementation. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// stage.execute(&mut db.container(), input); - /// ``` - pub(crate) struct StageTestDB { - db: Arc>, - } - - impl Default for StageTestDB { - /// Create a new instance of [StageTestDB] - fn default() -> Self { - Self { db: create_test_db::(EnvKind::RW) } - } - } - - impl StageTestDB { - /// Get a pointer to an internal database. - pub(crate) fn inner(&self) -> Arc> { - self.db.clone() - } - - /// Return a database wrapped in [DBContainer]. - pub(crate) fn container(&self) -> DBContainer<'_, Env> { - DBContainer::new(self.db.borrow()).expect("failed to create db container") - } - - /// Map a collection of values and store them in the database. - /// This function commits the transaction before exiting. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// db.map_put::(&items, |item| item)?; - /// ``` - pub(crate) fn map_put(&self, values: &[S], mut map: F) -> Result<(), Error> - where - T: Table, - S: Clone, - F: FnMut(&S) -> (T::Key, T::Value), - { - let mut db = self.container(); - let tx = db.get_mut(); - values.iter().try_for_each(|src| { - let (k, v) = map(src); - tx.put::(k, v) - })?; - db.commit()?; - Ok(()) - } - - /// Transform a collection of values using a callback and store - /// them in the database. The callback additionally accepts the - /// optional last element that was stored. - /// This function commits the transaction before exiting. - /// - /// ```rust - /// let db = StageTestDB::default(); - /// db.transform_append::(&items, |prev, item| prev.unwrap_or_default() + item)?; - /// ``` - pub(crate) fn transform_append( - &self, - values: &[S], - mut transform: F, - ) -> Result<(), Error> - where - T: Table, - ::Value: Clone, - S: Clone, - F: FnMut(&Option<::Value>, &S) -> (T::Key, T::Value), - { - let mut db = self.container(); - let tx = db.get_mut(); - let mut cursor = tx.cursor_mut::()?; - let mut last = cursor.last()?.map(|(_, v)| v); - values.iter().try_for_each(|src| { - let (k, v) = transform(&last, src); - last = Some(v.clone()); - cursor.append(k, v) - })?; - db.commit()?; - Ok(()) - } - - /// Check that there is no table entry above a given - /// block by [Table::Key] - pub(crate) fn check_no_entry_above( - &self, - block: BlockNumber, - mut selector: F, - ) -> Result<(), Error> - where - T: Table, - F: FnMut(T::Key) -> BlockNumber, - { - let db = self.container(); - let tx = db.get(); - - let mut cursor = tx.cursor::()?; - if let Some((key, _)) = cursor.last()? { - assert!(selector(key) <= block); - } - - Ok(()) - } - - /// Check that there is no table entry above a given - /// block by [Table::Value] - pub(crate) fn check_no_entry_above_by_value( - &self, - block: BlockNumber, - mut selector: F, - ) -> Result<(), Error> - where - T: Table, - F: FnMut(T::Value) -> BlockNumber, - { - let db = self.container(); - let tx = db.get(); - - let mut cursor = tx.cursor::()?; - if let Some((_, value)) = cursor.last()? { - assert!(selector(value) <= block); - } - - Ok(()) - } - } - - /// A generic test runner for stages. - #[async_trait::async_trait] - pub(crate) trait StageTestRunner { - type S: Stage> + 'static; - - /// Return a reference to the database. - fn db(&self) -> &StageTestDB; - - /// Return an instance of a Stage. - fn stage(&self) -> Self::S; - - /// Run [Stage::execute] and return a receiver for the result. - fn execute(&self, input: ExecInput) -> oneshot::Receiver> { - let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.db().inner(), self.stage()); - tokio::spawn(async move { - let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); - let result = stage.execute(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send message") - }); - rx - } - - /// Run [Stage::unwind] and return a receiver for the result. - fn unwind( - &self, - input: UnwindInput, - ) -> oneshot::Receiver>> - { - let (tx, rx) = oneshot::channel(); - let (db, mut stage) = (self.db().inner(), self.stage()); - tokio::spawn(async move { - let mut db = DBContainer::new(db.borrow()).expect("failed to create db container"); - let result = stage.unwind(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send result"); - }); - rx - } - } -}