diff --git a/crates/interfaces/src/test_utils/headers.rs b/crates/interfaces/src/test_utils/headers.rs index ddd9c0d138..50bdcf270c 100644 --- a/crates/interfaces/src/test_utils/headers.rs +++ b/crates/interfaces/src/test_utils/headers.rs @@ -92,6 +92,12 @@ impl TestHeadersClient { pub fn send_header_response(&self, id: u64, headers: Vec
) { self.res_tx.send((id, headers).into()).expect("failed to send header response"); } + + /// Helper for pushing responses to the client + pub async fn send_header_response_delayed(&self, id: u64, headers: Vec
, secs: u64) { + tokio::time::sleep(Duration::from_secs(secs)).await; + self.send_header_response(id, headers); + } } #[async_trait::async_trait] @@ -105,6 +111,9 @@ impl HeadersClient for TestHeadersClient { } async fn stream_headers(&self) -> HeadersStream { + if !self.res_rx.is_empty() { + println!("WARNING: broadcast receiver already contains messages.") + } Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok())) } } diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index f24c81547c..e9e0d09ef0 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -182,223 +182,160 @@ impl HeaderStage { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::StageTestRunner; + use crate::util::test_utils::{ + stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, + }; use assert_matches::assert_matches; use reth_interfaces::test_utils::{gen_random_header, gen_random_header_range}; - use test_utils::{HeadersTestRunner, TestDownloader}; + use test_runner::HeadersTestRunner; - const TEST_STAGE: StageId = StageId("Headers"); + stage_test_suite!(HeadersTestRunner); #[tokio::test] - // Check that the execution errors on empty database or - // prev progress missing from the database. - async fn execute_empty_db() { - let runner = HeadersTestRunner::default(); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CannonicalHeader { .. })) - ); - } - - #[tokio::test] - // Check that the execution exits on downloader timeout. + // Validate that the execution does not fail on timeout async fn execute_timeout() { - let head = gen_random_header(0, None); - let runner = - HeadersTestRunner::with_downloader(TestDownloader::new(Err(DownloadError::Timeout { - request_id: 0, - }))); - runner.insert_header(&head).expect("failed to insert header"); + let mut runner = HeadersTestRunner::default(); + let (stage_progress, previous_stage) = (0, 0); + runner + .seed_execution(ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }) + .expect("failed to seed execution"); let rx = runner.execute(ExecInput::default()); - runner.consensus.update_tip(H256::from_low_u64_be(1)); - assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done); - } - - #[tokio::test] - // Check that validation error is propagated during the execution. - async fn execute_validation_error() { - let head = gen_random_header(0, None); - let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Err( - DownloadError::HeaderValidation { hash: H256::zero(), details: "".to_owned() }, - ))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput::default()); - runner.consensus.update_tip(H256::from_low_u64_be(1)); - assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block }) if block == 0); - } - - #[tokio::test] - // Validate that all necessary tables are updated after the - // header download on no previous progress. - async fn execute_no_progress() { - let (start, end) = (0, 100); - let head = gen_random_header(start, None); - let headers = gen_random_header_range(start + 1..end, head.hash()); - - let result = headers.iter().rev().cloned().collect::>(); - let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput::default()); - let tip = headers.last().unwrap(); - runner.consensus.update_tip(tip.hash()); - + runner.after_execution().await.expect("failed to run after execution hook"); assert_matches!( rx.await.unwrap(), - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number + Ok(ExecOutput { done, reached_tip, stage_progress: out_stage_progress }) + if !done && !reached_tip && out_stage_progress == stage_progress ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); + assert!(runner.validate_execution().is_ok(), "validation failed"); } #[tokio::test] // Validate that all necessary tables are updated after the // header download with some previous progress. async fn execute_prev_progress() { - let (start, end) = (10000, 10241); - let head = gen_random_header(start, None); - let headers = gen_random_header_range(start + 1..end, head.hash()); - - let result = headers.iter().rev().cloned().collect::>(); - let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result))); - runner.insert_header(&head).expect("failed to insert header"); - - let rx = runner.execute(ExecInput { - previous_stage: Some((TEST_STAGE, head.number)), - stage_progress: Some(head.number), - }); - let tip = headers.last().unwrap(); - runner.consensus.update_tip(tip.hash()); - + let mut runner = HeadersTestRunner::default(); + let (stage_progress, previous_stage) = (10000, 10241); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed execution"); + let rx = runner.execute(input); + runner.after_execution().await.expect("failed to run after execution hook"); assert_matches!( rx.await.unwrap(), Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number + if done && reached_tip && stage_progress == stage_progress ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); + assert!(runner.validate_execution().is_ok(), "validation failed"); } - #[tokio::test] - // Execute the stage with linear downloader - async fn execute_with_linear_downloader() { - let (start, end) = (1000, 1200); - let head = gen_random_header(start, None); - let headers = gen_random_header_range(start + 1..end, head.hash()); + // TODO: + // #[tokio::test] + // // Execute the stage with linear downloader + // async fn execute_with_linear_downloader() { + // let mut runner = HeadersTestRunner::with_linear_downloader(); + // let (stage_progress, previous_stage) = (1000, 1200); + // let input = ExecInput { + // previous_stage: Some((PREV_STAGE_ID, previous_stage)), + // stage_progress: Some(stage_progress), + // }; + // runner.seed_execution(input).expect("failed to seed execution"); + // let rx = runner.execute(input); - let runner = HeadersTestRunner::with_linear_downloader(); - runner.insert_header(&head).expect("failed to insert header"); - let rx = runner.execute(ExecInput { - previous_stage: Some((TEST_STAGE, head.number)), - stage_progress: Some(head.number), - }); + // // skip hook for linear downloader + // let headers = runner.context.as_ref().unwrap(); + // let tip = headers.first().unwrap(); + // runner.consensus.update_tip(tip.hash()); - let tip = headers.last().unwrap(); - runner.consensus.update_tip(tip.hash()); + // // TODO: + // let mut download_result = headers.clone(); + // download_result.insert(0, headers.last().unwrap().clone()); + // runner + // .client + // .on_header_request(1, |id, _| { + // let response = download_result.clone().into_iter().map(|h| h.unseal()).collect(); + // runner.client.send_header_response(id, response) + // }) + // .await; - let mut download_result = headers.clone(); - download_result.insert(0, head); - runner - .client - .on_header_request(1, |id, _| { - runner.client.send_header_response( - id, - download_result.clone().into_iter().map(|h| h.unseal()).collect(), - ) - }) - .await; + // assert_matches!( + // rx.await.unwrap(), + // Ok(ExecOutput { done, reached_tip, stage_progress }) + // if done && reached_tip && stage_progress == tip.number + // ); + // assert!(runner.validate_execution().is_ok(), "validation failed"); + // } - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { done, reached_tip, stage_progress }) - if done && reached_tip && stage_progress == tip.number - ); - assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); - } + // TODO: + // #[tokio::test] + // // Check that unwind can remove headers across gaps + // async fn unwind_db_gaps() { + // let runner = HeadersTestRunner::default(); + // let head = gen_random_header(0, None); + // let first_range = gen_random_header_range(1..20, head.hash()); + // let second_range = gen_random_header_range(50..100, H256::zero()); + // runner.insert_header(&head).expect("failed to insert header"); + // runner + // .insert_headers(first_range.iter().chain(second_range.iter())) + // .expect("failed to insert headers"); - #[tokio::test] - // Check that unwind does not panic on empty database. - async fn unwind_empty_db() { - let unwind_to = 100; - let runner = HeadersTestRunner::default(); - let rx = - runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to - ); - } + // let unwind_to = 15; + // let input = UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }; + // let rx = runner.unwind(input); + // assert_matches!( + // rx.await.unwrap(), + // Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to + // ); + // assert!(runner.validate_unwind(input).is_ok(), "validation failed"); + // } - #[tokio::test] - // Check that unwind can remove headers across gaps - async fn unwind_db_gaps() { - let runner = HeadersTestRunner::default(); - let head = gen_random_header(0, None); - let first_range = gen_random_header_range(1..20, head.hash()); - let second_range = gen_random_header_range(50..100, H256::zero()); - runner.insert_header(&head).expect("failed to insert header"); - runner - .insert_headers(first_range.iter().chain(second_range.iter())) - .expect("failed to insert headers"); - - let unwind_to = 15; - let rx = - runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to - ); - - runner - .db() - .check_no_entry_above::(unwind_to, |key| key) - .expect("failed to check cannonical headers"); - runner - .db() - .check_no_entry_above_by_value::(unwind_to, |val| val) - .expect("failed to check header numbers"); - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("failed to check headers"); - runner - .db() - .check_no_entry_above::(unwind_to, |key| key.number()) - .expect("failed to check td"); - } - - mod test_utils { + mod test_runner { use crate::{ stages::headers::HeaderStage, - util::test_utils::{StageTestDB, StageTestRunner}, + util::test_utils::{ + ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner, + }, + ExecInput, UnwindInput, }; use async_trait::async_trait; use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader}; use reth_interfaces::{ consensus::ForkchoiceState, db::{self, models::blocks::BlockNumHash, tables, DbTx}, - p2p::headers::downloader::{DownloadError, Downloader}, - test_utils::{TestConsensus, TestHeadersClient}, + p2p::headers::{ + client::HeadersClient, + downloader::{DownloadError, Downloader}, + }, + test_utils::{ + gen_random_header, gen_random_header_range, TestConsensus, TestHeadersClient, + }, }; use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256}; use std::{ops::Deref, sync::Arc, time::Duration}; + use tokio_stream::StreamExt; pub(crate) struct HeadersTestRunner { pub(crate) consensus: Arc, pub(crate) client: Arc, + pub(crate) context: Option>, downloader: Arc, db: StageTestDB, } impl Default for HeadersTestRunner { fn default() -> Self { + let client = Arc::new(TestHeadersClient::default()); Self { - client: Arc::new(TestHeadersClient::default()), + client: client.clone(), consensus: Arc::new(TestConsensus::default()), - downloader: Arc::new(TestDownloader::new(Ok(Vec::default()))), + downloader: Arc::new(TestDownloader::new(client)), db: StageTestDB::default(), + context: None, } } } @@ -419,26 +356,99 @@ mod tests { } } + #[async_trait::async_trait] + impl ExecuteStageTestRunner for HeadersTestRunner { + fn seed_execution( + &mut self, + input: ExecInput, + ) -> Result<(), Box> { + let start = input.stage_progress.unwrap_or_default(); + let head = gen_random_header(start, None); + self.insert_header(&head)?; + + // use previous progress as seed size + let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1; + if end > start + 1 { + let mut headers = gen_random_header_range(start + 1..end, head.hash()); + headers.reverse(); + self.context = Some(headers); + } + Ok(()) + } + + async fn after_execution( + &self, + ) -> Result<(), Box> { + let (tip, headers) = match self.context { + Some(ref headers) if headers.len() > 1 => { + // headers are in reverse + (headers.first().unwrap().hash(), headers.clone()) + } + _ => (H256::from_low_u64_be(rand::random()), Vec::default()), + }; + self.consensus.update_tip(tip); + if !headers.is_empty() { + self.client + .send_header_response_delayed( + 0, + headers.iter().cloned().map(|h| h.unseal()).collect(), + 1, + ) + .await; + } + Ok(()) + } + + fn validate_execution(&self) -> Result<(), Box> { + if let Some(ref headers) = self.context { + headers.iter().try_for_each(|h| self.validate_db_header(&h))?; + } + Ok(()) + } + } + + impl UnwindStageTestRunner for HeadersTestRunner { + fn seed_unwind( + &mut self, + input: UnwindInput, + highest_entry: u64, + ) -> Result<(), Box> { + let lowest_entry = input.unwind_to.saturating_sub(100); + let headers = gen_random_header_range(lowest_entry..highest_entry, H256::zero()); + self.insert_headers(headers.iter())?; + Ok(()) + } + + fn validate_unwind( + &self, + input: UnwindInput, + ) -> Result<(), Box> { + let unwind_to = input.unwind_to; + self.db().check_no_entry_above_by_value::( + unwind_to, + |val| val, + )?; + self.db() + .check_no_entry_above::(unwind_to, |key| key)?; + self.db() + .check_no_entry_above::(unwind_to, |key| key.number())?; + self.db() + .check_no_entry_above::(unwind_to, |key| key.number())?; + Ok(()) + } + } + impl HeadersTestRunner> { pub(crate) fn with_linear_downloader() -> Self { let client = Arc::new(TestHeadersClient::default()); let consensus = Arc::new(TestConsensus::default()); let downloader = Arc::new(LinearDownloadBuilder::new().build(consensus.clone(), client.clone())); - Self { client, consensus, downloader, db: StageTestDB::default() } + Self { client, consensus, downloader, db: StageTestDB::default(), context: None } } } impl HeadersTestRunner { - pub(crate) fn with_downloader(downloader: D) -> Self { - HeadersTestRunner { - client: Arc::new(TestHeadersClient::default()), - consensus: Arc::new(TestConsensus::default()), - downloader: Arc::new(downloader), - db: StageTestDB::default(), - } - } - /// Insert header into tables pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> { self.insert_headers(std::iter::once(header)) @@ -504,12 +514,12 @@ mod tests { #[derive(Debug)] pub(crate) struct TestDownloader { - result: Result, DownloadError>, + client: Arc, } impl TestDownloader { - pub(crate) fn new(result: Result, DownloadError>) -> Self { - Self { result } + pub(crate) fn new(client: Arc) -> Self { + Self { client } } } @@ -527,7 +537,7 @@ mod tests { } fn client(&self) -> &Self::Client { - unimplemented!() + &self.client } async fn download( @@ -535,7 +545,15 @@ mod tests { _: &SealedHeader, _: &ForkchoiceState, ) -> Result, DownloadError> { - self.result.clone() + let stream = self.client.stream_headers().await; + let stream = stream.timeout(Duration::from_secs(3)); + + match Box::pin(stream).try_next().await { + Ok(Some(res)) => { + Ok(res.headers.iter().map(|h| h.clone().seal()).collect::>()) + } + _ => Err(DownloadError::Timeout { request_id: 0 }), + } } } } diff --git a/crates/stages/src/stages/tx_index.rs b/crates/stages/src/stages/tx_index.rs index a2d986ba71..05b45cb383 100644 --- a/crates/stages/src/stages/tx_index.rs +++ b/crates/stages/src/stages/tx_index.rs @@ -87,22 +87,15 @@ impl Stage for TxIndex { #[cfg(test)] mod tests { use super::*; - use crate::util::test_utils::{StageTestDB, StageTestRunner}; + use crate::util::test_utils::{ + stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, + UnwindStageTestRunner, PREV_STAGE_ID, + }; use assert_matches::assert_matches; use reth_interfaces::{db::models::BlockNumHash, test_utils::gen_random_header_range}; use reth_primitives::H256; - const TEST_STAGE: StageId = StageId("PrevStage"); - - #[tokio::test] - async fn execute_empty_db() { - let runner = TxIndexTestRunner::default(); - let rx = runner.execute(ExecInput::default()); - assert_matches!( - rx.await.unwrap(), - Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CannonicalHeader { .. })) - ); - } + stage_test_suite!(TxIndexTestRunner); #[tokio::test] async fn execute_no_prev_tx_count() { @@ -115,7 +108,7 @@ mod tests { let (head, tail) = (headers.first().unwrap(), headers.last().unwrap()); let input = ExecInput { - previous_stage: Some((TEST_STAGE, tail.number)), + previous_stage: Some((PREV_STAGE_ID, tail.number)), stage_progress: Some(head.number), }; let rx = runner.execute(input); @@ -125,48 +118,6 @@ mod tests { ); } - #[tokio::test] - async fn execute() { - let runner = TxIndexTestRunner::default(); - let (start, pivot, end) = (0, 100, 200); - let headers = gen_random_header_range(start..end, H256::zero()); - runner - .db() - .map_put::(&headers, |h| (h.number, h.hash())) - .expect("failed to insert"); - runner - .db() - .transform_append::(&headers[..=pivot], |prev, h| { - ( - BlockNumHash((h.number, h.hash())), - prev.unwrap_or_default() + (rand::random::() as u64), - ) - }) - .expect("failed to insert"); - - let (pivot, tail) = (headers.get(pivot).unwrap(), headers.last().unwrap()); - let input = ExecInput { - previous_stage: Some((TEST_STAGE, tail.number)), - stage_progress: Some(pivot.number), - }; - let rx = runner.execute(input); - assert_matches!( - rx.await.unwrap(), - Ok(ExecOutput { stage_progress, done, reached_tip }) - if done && reached_tip && stage_progress == tail.number - ); - } - - #[tokio::test] - async fn unwind_empty_db() { - let runner = TxIndexTestRunner::default(); - let rx = runner.unwind(UnwindInput::default()); - assert_matches!( - rx.await.unwrap(), - Ok(UnwindOutput { stage_progress }) if stage_progress == 0 - ); - } - #[tokio::test] async fn unwind_no_input() { let runner = TxIndexTestRunner::default(); @@ -239,4 +190,65 @@ mod tests { TxIndex {} } } + + impl ExecuteStageTestRunner for TxIndexTestRunner { + fn seed_execution( + &mut self, + input: ExecInput, + ) -> Result<(), Box> { + let pivot = input.stage_progress.unwrap_or_default(); + let start = pivot.saturating_sub(100); + let end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(); + let headers = gen_random_header_range(start..end, H256::zero()); + self.db() + .map_put::(&headers, |h| (h.number, h.hash()))?; + self.db().transform_append::( + &headers[..=(pivot as usize)], + |prev, h| { + ( + BlockNumHash((h.number, h.hash())), + prev.unwrap_or_default() + (rand::random::() as u64), + ) + }, + )?; + Ok(()) + } + + fn validate_execution(&self) -> Result<(), Box> { + // TODO: + Ok(()) + } + } + + impl UnwindStageTestRunner for TxIndexTestRunner { + fn seed_unwind( + &mut self, + input: UnwindInput, + highest_entry: u64, + ) -> Result<(), Box> { + // TODO: accept range + let headers = gen_random_header_range(input.unwind_to..highest_entry, H256::zero()); + self.db().transform_append::( + &headers, + |prev, h| { + ( + BlockNumHash((h.number, h.hash())), + prev.unwrap_or_default() + (rand::random::() as u64), + ) + }, + )?; + Ok(()) + } + + fn validate_unwind( + &self, + input: UnwindInput, + ) -> Result<(), Box> { + self.db() + .check_no_entry_above::(input.unwind_to, |h| { + h.number() + })?; + Ok(()) + } + } } diff --git a/crates/stages/src/util.rs b/crates/stages/src/util.rs index af221916d6..a9817964fe 100644 --- a/crates/stages/src/util.rs +++ b/crates/stages/src/util.rs @@ -139,15 +139,18 @@ pub(crate) mod unwind { #[cfg(test)] pub(crate) mod test_utils { use reth_db::{ - kv::{test_utils::create_test_db, Env, EnvKind}, - mdbx::WriteMap, + kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind}, + mdbx::{WriteMap, RW}, }; use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table}; use reth_primitives::BlockNumber; use std::{borrow::Borrow, sync::Arc}; use tokio::sync::oneshot; - use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; + use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; + + /// The previous test stage id mock used for testing + pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage"); /// The [StageTestDB] is used as an internal /// database for testing stage implementation. @@ -178,6 +181,30 @@ pub(crate) mod test_utils { DBContainer::new(self.db.borrow()).expect("failed to create db container") } + fn commit(&self, f: F) -> Result<(), Error> + where + F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), Error>, + { + let mut db = self.container(); + let tx = db.get_mut(); + f(tx)?; + db.commit()?; + Ok(()) + } + + /// Put a single value into the table + pub(crate) fn put(&self, k: T::Key, v: T::Value) -> Result<(), Error> { + self.commit(|tx| tx.put::(k, v)) + } + + /// Delete a single value from the table + pub(crate) fn delete(&self, k: T::Key) -> Result<(), Error> { + self.commit(|tx| { + tx.delete::(k, None)?; + Ok(()) + }) + } + /// Map a collection of values and store them in the database. /// This function commits the transaction before exiting. /// @@ -191,14 +218,12 @@ pub(crate) mod test_utils { S: Clone, F: FnMut(&S) -> (T::Key, T::Value), { - let mut db = self.container(); - let tx = db.get_mut(); - values.iter().try_for_each(|src| { - let (k, v) = map(src); - tx.put::(k, v) - })?; - db.commit()?; - Ok(()) + self.commit(|tx| { + values.iter().try_for_each(|src| { + let (k, v) = map(src); + tx.put::(k, v) + }) + }) } /// Transform a collection of values using a callback and store @@ -221,17 +246,15 @@ pub(crate) mod test_utils { S: Clone, F: FnMut(&Option<::Value>, &S) -> (T::Key, T::Value), { - let mut db = self.container(); - let tx = db.get_mut(); - let mut cursor = tx.cursor_mut::()?; - let mut last = cursor.last()?.map(|(_, v)| v); - values.iter().try_for_each(|src| { - let (k, v) = transform(&last, src); - last = Some(v.clone()); - cursor.append(k, v) - })?; - db.commit()?; - Ok(()) + self.commit(|tx| { + let mut cursor = tx.cursor_mut::()?; + let mut last = cursor.last()?.map(|(_, v)| v); + values.iter().try_for_each(|src| { + let (k, v) = transform(&last, src); + last = Some(v.clone()); + cursor.append(k, v) + }) + }) } /// Check that there is no table entry above a given @@ -289,6 +312,18 @@ pub(crate) mod test_utils { /// Return an instance of a Stage. fn stage(&self) -> Self::S; + } + + #[async_trait::async_trait] + pub(crate) trait ExecuteStageTestRunner: StageTestRunner { + /// Seed database for stage execution + fn seed_execution( + &mut self, + input: ExecInput, + ) -> Result<(), Box>; + + /// Validate stage execution + fn validate_execution(&self) -> Result<(), Box>; /// Run [Stage::execute] and return a receiver for the result. fn execute(&self, input: ExecInput) -> oneshot::Receiver> { @@ -303,6 +338,26 @@ pub(crate) mod test_utils { rx } + /// Run a hook after [Stage::execute]. Required for Headers & Bodies stages. + async fn after_execution(&self) -> Result<(), Box> { + Ok(()) + } + } + + pub(crate) trait UnwindStageTestRunner: StageTestRunner { + /// Seed database for stage unwind + fn seed_unwind( + &mut self, + input: UnwindInput, + highest_entry: u64, + ) -> Result<(), Box>; + + /// Validate the unwind + fn validate_unwind( + &self, + input: UnwindInput, + ) -> Result<(), Box>; + /// Run [Stage::unwind] and return a receiver for the result. fn unwind( &self, @@ -320,4 +375,71 @@ pub(crate) mod test_utils { rx } } + + macro_rules! stage_test_suite { + ($runner:ident) => { + #[tokio::test] + // Check that the execution errors on empty database or + // prev progress missing from the database. + async fn execute_empty_db() { + let runner = $runner::default(); + let rx = runner.execute(crate::stage::ExecInput::default()); + assert_matches!( + rx.await.unwrap(), + Err(crate::error::StageError::DatabaseIntegrity(_)) + ); + assert!(runner.validate_execution().is_ok(), "execution validation"); + } + + #[tokio::test] + async fn execute() { + let (previous_stage, stage_progress) = (1000, 100); + let mut runner = $runner::default(); + let input = crate::stage::ExecInput { + previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; + runner.seed_execution(input).expect("failed to seed"); + let rx = runner.execute(input); + runner.after_execution().await.expect("failed to run after execution hook"); + assert_matches!( + rx.await.unwrap(), + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == previous_stage + ); + assert!(runner.validate_execution().is_ok(), "execution validation"); + } + + #[tokio::test] + // Check that unwind does not panic on empty database. + async fn unwind_empty_db() { + let runner = $runner::default(); + let input = crate::stage::UnwindInput::default(); + let rx = runner.unwind(input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to + ); + assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); + } + + #[tokio::test] + async fn unwind() { + let (unwind_to, highest_entry) = (100, 200); + let mut runner = $runner::default(); + let input = crate::stage::UnwindInput { + unwind_to, stage_progress: unwind_to, bad_block: None, + }; + runner.seed_unwind(input, highest_entry).expect("failed to seed"); + let rx = runner.unwind(input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to + ); + assert!(runner.validate_unwind(input).is_ok(), "unwind validation"); + } + }; + } + + pub(crate) use stage_test_suite; }