test(sync): stage test suite

This commit is contained in:
Roman Krasiuk
2022-11-15 13:47:15 +02:00
parent 5ca2cab97f
commit 65ac844cb2
4 changed files with 417 additions and 256 deletions

View File

@@ -92,6 +92,12 @@ impl TestHeadersClient {
pub fn send_header_response(&self, id: u64, headers: Vec<Header>) {
self.res_tx.send((id, headers).into()).expect("failed to send header response");
}
/// Helper for pushing responses to the client after a delay.
///
/// Sleeps for `secs` seconds, then forwards `headers` as the response for
/// request `id` via `send_header_response`. Useful for simulating network
/// latency so the stage under test is already listening when the response
/// arrives.
pub async fn send_header_response_delayed(&self, id: u64, headers: Vec<Header>, secs: u64) {
tokio::time::sleep(Duration::from_secs(secs)).await;
self.send_header_response(id, headers);
}
}
#[async_trait::async_trait]
@@ -105,6 +111,9 @@ impl HeadersClient for TestHeadersClient {
}
/// Returns a stream of header responses broadcast by this test client.
async fn stream_headers(&self) -> HeadersStream {
// Warn if the receiver already holds messages — a sign that a response
// from a previous test interaction was never consumed.
// NOTE(review): whether those buffered messages are visible to the
// resubscribed receiver below depends on broadcast semantics — confirm.
if !self.res_rx.is_empty() {
println!("WARNING: broadcast receiver already contains messages.")
}
// Resubscribe so multiple streams can observe the same broadcast channel;
// receive errors (e.g. lagged) are silently dropped by `filter_map`.
Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok()))
}
}

View File

@@ -182,223 +182,160 @@ impl<D: Downloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::StageTestRunner;
use crate::util::test_utils::{
stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID,
};
use assert_matches::assert_matches;
use reth_interfaces::test_utils::{gen_random_header, gen_random_header_range};
use test_utils::{HeadersTestRunner, TestDownloader};
use test_runner::HeadersTestRunner;
const TEST_STAGE: StageId = StageId("Headers");
stage_test_suite!(HeadersTestRunner);
#[tokio::test]
// Check that execution errors on an empty database, i.e. when the
// previous progress is missing from the database.
async fn execute_empty_db() {
let runner = HeadersTestRunner::default();
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CannonicalHeader { .. }))
);
}
#[tokio::test]
// Check that the execution exits on downloader timeout.
// Validate that the execution does not fail on timeout
async fn execute_timeout() {
let head = gen_random_header(0, None);
let runner =
HeadersTestRunner::with_downloader(TestDownloader::new(Err(DownloadError::Timeout {
request_id: 0,
})));
runner.insert_header(&head).expect("failed to insert header");
let mut runner = HeadersTestRunner::default();
let (stage_progress, previous_stage) = (0, 0);
runner
.seed_execution(ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
})
.expect("failed to seed execution");
let rx = runner.execute(ExecInput::default());
runner.consensus.update_tip(H256::from_low_u64_be(1));
assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done);
}
#[tokio::test]
// Check that validation error is propagated during the execution.
async fn execute_validation_error() {
let head = gen_random_header(0, None);
let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Err(
DownloadError::HeaderValidation { hash: H256::zero(), details: "".to_owned() },
)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput::default());
runner.consensus.update_tip(H256::from_low_u64_be(1));
assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block }) if block == 0);
}
#[tokio::test]
// Validate that all necessary tables are updated after the
// header download when there is no previous progress.
async fn execute_no_progress() {
let (start, end) = (0, 100);
let head = gen_random_header(start, None);
let headers = gen_random_header_range(start + 1..end, head.hash());
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput::default());
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
runner.after_execution().await.expect("failed to run after execution hook");
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
Ok(ExecOutput { done, reached_tip, stage_progress: out_stage_progress })
if !done && !reached_tip && out_stage_progress == stage_progress
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok());
assert!(runner.validate_execution().is_ok(), "validation failed");
}
#[tokio::test]
// Validate that all necessary tables are updated after the
// header download with some previous progress.
async fn execute_prev_progress() {
let (start, end) = (10000, 10241);
let head = gen_random_header(start, None);
let headers = gen_random_header_range(start + 1..end, head.hash());
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput {
previous_stage: Some((TEST_STAGE, head.number)),
stage_progress: Some(head.number),
});
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
let mut runner = HeadersTestRunner::default();
let (stage_progress, previous_stage) = (10000, 10241);
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
runner.after_execution().await.expect("failed to run after execution hook");
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
if done && reached_tip && stage_progress == stage_progress
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok());
assert!(runner.validate_execution().is_ok(), "validation failed");
}
#[tokio::test]
// Execute the stage with linear downloader
async fn execute_with_linear_downloader() {
let (start, end) = (1000, 1200);
let head = gen_random_header(start, None);
let headers = gen_random_header_range(start + 1..end, head.hash());
// TODO:
// #[tokio::test]
// // Execute the stage with linear downloader
// async fn execute_with_linear_downloader() {
// let mut runner = HeadersTestRunner::with_linear_downloader();
// let (stage_progress, previous_stage) = (1000, 1200);
// let input = ExecInput {
// previous_stage: Some((PREV_STAGE_ID, previous_stage)),
// stage_progress: Some(stage_progress),
// };
// runner.seed_execution(input).expect("failed to seed execution");
// let rx = runner.execute(input);
let runner = HeadersTestRunner::with_linear_downloader();
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput {
previous_stage: Some((TEST_STAGE, head.number)),
stage_progress: Some(head.number),
});
// // skip hook for linear downloader
// let headers = runner.context.as_ref().unwrap();
// let tip = headers.first().unwrap();
// runner.consensus.update_tip(tip.hash());
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
// // TODO:
// let mut download_result = headers.clone();
// download_result.insert(0, headers.last().unwrap().clone());
// runner
// .client
// .on_header_request(1, |id, _| {
// let response = download_result.clone().into_iter().map(|h| h.unseal()).collect();
// runner.client.send_header_response(id, response)
// })
// .await;
let mut download_result = headers.clone();
download_result.insert(0, head);
runner
.client
.on_header_request(1, |id, _| {
runner.client.send_header_response(
id,
download_result.clone().into_iter().map(|h| h.unseal()).collect(),
)
})
.await;
// assert_matches!(
// rx.await.unwrap(),
// Ok(ExecOutput { done, reached_tip, stage_progress })
// if done && reached_tip && stage_progress == tip.number
// );
// assert!(runner.validate_execution().is_ok(), "validation failed");
// }
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok());
}
// TODO:
// #[tokio::test]
// // Check that unwind can remove headers across gaps
// async fn unwind_db_gaps() {
// let runner = HeadersTestRunner::default();
// let head = gen_random_header(0, None);
// let first_range = gen_random_header_range(1..20, head.hash());
// let second_range = gen_random_header_range(50..100, H256::zero());
// runner.insert_header(&head).expect("failed to insert header");
// runner
// .insert_headers(first_range.iter().chain(second_range.iter()))
// .expect("failed to insert headers");
#[tokio::test]
// Check that unwind does not panic on empty database.
async fn unwind_empty_db() {
let unwind_to = 100;
let runner = HeadersTestRunner::default();
let rx =
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
);
}
// let unwind_to = 15;
// let input = UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to };
// let rx = runner.unwind(input);
// assert_matches!(
// rx.await.unwrap(),
// Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
// );
// assert!(runner.validate_unwind(input).is_ok(), "validation failed");
// }
#[tokio::test]
// Check that unwind can remove headers across gaps
async fn unwind_db_gaps() {
let runner = HeadersTestRunner::default();
let head = gen_random_header(0, None);
let first_range = gen_random_header_range(1..20, head.hash());
let second_range = gen_random_header_range(50..100, H256::zero());
runner.insert_header(&head).expect("failed to insert header");
runner
.insert_headers(first_range.iter().chain(second_range.iter()))
.expect("failed to insert headers");
let unwind_to = 15;
let rx =
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
);
runner
.db()
.check_no_entry_above::<tables::CanonicalHeaders, _>(unwind_to, |key| key)
.expect("failed to check cannonical headers");
runner
.db()
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(unwind_to, |val| val)
.expect("failed to check header numbers");
runner
.db()
.check_no_entry_above::<tables::Headers, _>(unwind_to, |key| key.number())
.expect("failed to check headers");
runner
.db()
.check_no_entry_above::<tables::HeaderTD, _>(unwind_to, |key| key.number())
.expect("failed to check td");
}
mod test_utils {
mod test_runner {
use crate::{
stages::headers::HeaderStage,
util::test_utils::{StageTestDB, StageTestRunner},
util::test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner,
},
ExecInput, UnwindInput,
};
use async_trait::async_trait;
use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader};
use reth_interfaces::{
consensus::ForkchoiceState,
db::{self, models::blocks::BlockNumHash, tables, DbTx},
p2p::headers::downloader::{DownloadError, Downloader},
test_utils::{TestConsensus, TestHeadersClient},
p2p::headers::{
client::HeadersClient,
downloader::{DownloadError, Downloader},
},
test_utils::{
gen_random_header, gen_random_header_range, TestConsensus, TestHeadersClient,
},
};
use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256};
use std::{ops::Deref, sync::Arc, time::Duration};
use tokio_stream::StreamExt;
pub(crate) struct HeadersTestRunner<D: Downloader> {
pub(crate) consensus: Arc<TestConsensus>,
pub(crate) client: Arc<TestHeadersClient>,
pub(crate) context: Option<Vec<SealedHeader>>,
downloader: Arc<D>,
db: StageTestDB,
}
impl Default for HeadersTestRunner<TestDownloader> {
fn default() -> Self {
let client = Arc::new(TestHeadersClient::default());
Self {
client: Arc::new(TestHeadersClient::default()),
client: client.clone(),
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(TestDownloader::new(Ok(Vec::default()))),
downloader: Arc::new(TestDownloader::new(client)),
db: StageTestDB::default(),
context: None,
}
}
}
@@ -419,26 +356,99 @@ mod tests {
}
}
#[async_trait::async_trait]
impl<D: Downloader + 'static> ExecuteStageTestRunner for HeadersTestRunner<D> {
/// Seed the database with a random head header at the current stage
/// progress and pre-generate the header range the stage is expected to
/// download (stashed in `self.context` for the hooks below).
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let start = input.stage_progress.unwrap_or_default();
let head = gen_random_header(start, None);
self.insert_header(&head)?;
// use previous progress as seed size
let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
if end > start + 1 {
let mut headers = gen_random_header_range(start + 1..end, head.hash());
// Store in reverse (tip first) — `after_execution` treats the first
// element as the tip.
headers.reverse();
self.context = Some(headers);
}
Ok(())
}
/// Post-execution hook: publish the tip to consensus and feed the seeded
/// headers to the stage through the test client, delayed by one second so
/// the stage is already streaming when the response arrives.
async fn after_execution(
&self,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (tip, headers) = match self.context {
Some(ref headers) if headers.len() > 1 => {
// headers are in reverse
(headers.first().unwrap().hash(), headers.clone())
}
// No seeded headers: advertise a random tip so the stage still runs.
_ => (H256::from_low_u64_be(rand::random()), Vec::default()),
};
self.consensus.update_tip(tip);
if !headers.is_empty() {
self.client
.send_header_response_delayed(
0,
headers.iter().cloned().map(|h| h.unseal()).collect(),
1,
)
.await;
}
Ok(())
}
/// Verify that every seeded header was written to the database.
fn validate_execution(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(ref headers) = self.context {
headers.iter().try_for_each(|h| self.validate_db_header(&h))?;
}
Ok(())
}
}
impl<D: Downloader + 'static> UnwindStageTestRunner for HeadersTestRunner<D> {
/// Seed headers from just below the unwind target up to `highest_entry`,
/// so the unwind has entries both to keep and to remove.
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let lowest_entry = input.unwind_to.saturating_sub(100);
let headers = gen_random_header_range(lowest_entry..highest_entry, H256::zero());
self.insert_headers(headers.iter())?;
Ok(())
}
/// Assert that no header-related table contains entries above the unwind
/// target block number.
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let unwind_to = input.unwind_to;
// HeaderNumbers is checked by value — the block number is the value here.
self.db().check_no_entry_above_by_value::<tables::HeaderNumbers, _>(
unwind_to,
|val| val,
)?;
self.db()
.check_no_entry_above::<tables::CanonicalHeaders, _>(unwind_to, |key| key)?;
self.db()
.check_no_entry_above::<tables::Headers, _>(unwind_to, |key| key.number())?;
self.db()
.check_no_entry_above::<tables::HeaderTD, _>(unwind_to, |key| key.number())?;
Ok(())
}
}
impl HeadersTestRunner<LinearDownloader<TestConsensus, TestHeadersClient>> {
pub(crate) fn with_linear_downloader() -> Self {
let client = Arc::new(TestHeadersClient::default());
let consensus = Arc::new(TestConsensus::default());
let downloader =
Arc::new(LinearDownloadBuilder::new().build(consensus.clone(), client.clone()));
Self { client, consensus, downloader, db: StageTestDB::default() }
Self { client, consensus, downloader, db: StageTestDB::default(), context: None }
}
}
impl<D: Downloader> HeadersTestRunner<D> {
pub(crate) fn with_downloader(downloader: D) -> Self {
HeadersTestRunner {
client: Arc::new(TestHeadersClient::default()),
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(downloader),
db: StageTestDB::default(),
}
}
/// Insert header into tables
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
self.insert_headers(std::iter::once(header))
@@ -504,12 +514,12 @@ mod tests {
#[derive(Debug)]
pub(crate) struct TestDownloader {
result: Result<Vec<SealedHeader>, DownloadError>,
client: Arc<TestHeadersClient>,
}
impl TestDownloader {
pub(crate) fn new(result: Result<Vec<SealedHeader>, DownloadError>) -> Self {
Self { result }
pub(crate) fn new(client: Arc<TestHeadersClient>) -> Self {
Self { client }
}
}
@@ -527,7 +537,7 @@ mod tests {
}
fn client(&self) -> &Self::Client {
unimplemented!()
&self.client
}
async fn download(
@@ -535,7 +545,15 @@ mod tests {
_: &SealedHeader,
_: &ForkchoiceState,
) -> Result<Vec<SealedHeader>, DownloadError> {
self.result.clone()
let stream = self.client.stream_headers().await;
let stream = stream.timeout(Duration::from_secs(3));
match Box::pin(stream).try_next().await {
Ok(Some(res)) => {
Ok(res.headers.iter().map(|h| h.clone().seal()).collect::<Vec<_>>())
}
_ => Err(DownloadError::Timeout { request_id: 0 }),
}
}
}
}

View File

@@ -87,22 +87,15 @@ impl<DB: Database> Stage<DB> for TxIndex {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::{StageTestDB, StageTestRunner};
use crate::util::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner,
UnwindStageTestRunner, PREV_STAGE_ID,
};
use assert_matches::assert_matches;
use reth_interfaces::{db::models::BlockNumHash, test_utils::gen_random_header_range};
use reth_primitives::H256;
const TEST_STAGE: StageId = StageId("PrevStage");
#[tokio::test]
async fn execute_empty_db() {
let runner = TxIndexTestRunner::default();
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CannonicalHeader { .. }))
);
}
stage_test_suite!(TxIndexTestRunner);
#[tokio::test]
async fn execute_no_prev_tx_count() {
@@ -115,7 +108,7 @@ mod tests {
let (head, tail) = (headers.first().unwrap(), headers.last().unwrap());
let input = ExecInput {
previous_stage: Some((TEST_STAGE, tail.number)),
previous_stage: Some((PREV_STAGE_ID, tail.number)),
stage_progress: Some(head.number),
};
let rx = runner.execute(input);
@@ -125,48 +118,6 @@ mod tests {
);
}
#[tokio::test]
async fn execute() {
let runner = TxIndexTestRunner::default();
let (start, pivot, end) = (0, 100, 200);
let headers = gen_random_header_range(start..end, H256::zero());
runner
.db()
.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))
.expect("failed to insert");
runner
.db()
.transform_append::<tables::CumulativeTxCount, _, _>(&headers[..=pivot], |prev, h| {
(
BlockNumHash((h.number, h.hash())),
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
})
.expect("failed to insert");
let (pivot, tail) = (headers.get(pivot).unwrap(), headers.last().unwrap());
let input = ExecInput {
previous_stage: Some((TEST_STAGE, tail.number)),
stage_progress: Some(pivot.number),
};
let rx = runner.execute(input);
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress, done, reached_tip })
if done && reached_tip && stage_progress == tail.number
);
}
#[tokio::test]
async fn unwind_empty_db() {
let runner = TxIndexTestRunner::default();
let rx = runner.unwind(UnwindInput::default());
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 0
);
}
#[tokio::test]
async fn unwind_no_input() {
let runner = TxIndexTestRunner::default();
@@ -239,4 +190,65 @@ mod tests {
TxIndex {}
}
}
impl ExecuteStageTestRunner for TxIndexTestRunner {
/// Seed canonical headers covering `[stage_progress - 100, previous_stage)`
/// and cumulative tx counts up to the pivot, leaving the remainder for the
/// stage to index.
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let pivot = input.stage_progress.unwrap_or_default();
let start = pivot.saturating_sub(100);
let end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
let headers = gen_random_header_range(start..end, H256::zero());
self.db()
.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))?;
// NOTE(review): `pivot` is used as a slice index, but `headers` starts at
// block `start = pivot - 100`, so the index only corresponds to block
// `pivot` when `start == 0` — confirm this is intended.
self.db().transform_append::<tables::CumulativeTxCount, _, _>(
&headers[..=(pivot as usize)],
|prev, h| {
(
BlockNumHash((h.number, h.hash())),
// Random per-block tx counts, accumulated cumulatively.
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
},
)?;
Ok(())
}
/// Validate the execution results.
fn validate_execution(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// TODO:
Ok(())
}
}
impl UnwindStageTestRunner for TxIndexTestRunner {
/// Seed cumulative tx counts between the unwind target and `highest_entry`
/// so the unwind has entries to prune.
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// TODO: accept range
let headers = gen_random_header_range(input.unwind_to..highest_entry, H256::zero());
self.db().transform_append::<tables::CumulativeTxCount, _, _>(
&headers,
|prev, h| {
(
BlockNumHash((h.number, h.hash())),
// Random per-block tx counts, accumulated cumulatively.
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
},
)?;
Ok(())
}
/// Assert no cumulative tx count entries remain above the unwind target.
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.db()
.check_no_entry_above::<tables::CumulativeTxCount, _>(input.unwind_to, |h| {
h.number()
})?;
Ok(())
}
}
}

View File

@@ -139,15 +139,18 @@ pub(crate) mod unwind {
#[cfg(test)]
pub(crate) mod test_utils {
use reth_db::{
kv::{test_utils::create_test_db, Env, EnvKind},
mdbx::WriteMap,
kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table};
use reth_primitives::BlockNumber;
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput};
/// The previous test stage id mock used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
@@ -178,6 +181,30 @@ pub(crate) mod test_utils {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Run `f` inside a mutable database transaction and commit the result.
///
/// Shared helper so the table-mutating methods below don't repeat the
/// container/get_mut/commit boilerplate.
fn commit<F>(&self, f: F) -> Result<(), Error>
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), Error>,
{
let mut db = self.container();
let tx = db.get_mut();
f(tx)?;
db.commit()?;
Ok(())
}
/// Put a single value into the table, committing the transaction.
pub(crate) fn put<T: Table>(&self, k: T::Key, v: T::Value) -> Result<(), Error> {
self.commit(|tx| tx.put::<T>(k, v))
}
/// Delete a single value from the table, committing the transaction.
pub(crate) fn delete<T: Table>(&self, k: T::Key) -> Result<(), Error> {
self.commit(|tx| {
// `None` — presumably deletes the entry for `k` regardless of its
// stored value; TODO confirm against the `DbTxMut::delete` contract.
tx.delete::<T>(k, None)?;
Ok(())
})
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
@@ -191,14 +218,12 @@ pub(crate) mod test_utils {
S: Clone,
F: FnMut(&S) -> (T::Key, T::Value),
{
let mut db = self.container();
let tx = db.get_mut();
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})?;
db.commit()?;
Ok(())
self.commit(|tx| {
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})
})
}
/// Transform a collection of values using a callback and store
@@ -221,17 +246,15 @@ pub(crate) mod test_utils {
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
{
let mut db = self.container();
let tx = db.get_mut();
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})?;
db.commit()?;
Ok(())
self.commit(|tx| {
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})
})
}
/// Check that there is no table entry above a given
@@ -289,6 +312,18 @@ pub(crate) mod test_utils {
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
}
#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
/// Seed database for stage execution
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
/// Validate stage execution
fn validate_execution(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
@@ -303,6 +338,26 @@ pub(crate) mod test_utils {
rx
}
/// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
async fn after_execution(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
pub(crate) trait UnwindStageTestRunner: StageTestRunner {
/// Seed database for stage unwind
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
/// Validate the unwind
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
/// Run [Stage::unwind] and return a receiver for the result.
fn unwind(
&self,
@@ -320,4 +375,71 @@ pub(crate) mod test_utils {
rx
}
}
/// Generates a common suite of stage tests for a test runner type:
/// `execute_empty_db`, `execute`, `unwind_empty_db` and `unwind`.
///
/// `$runner` must implement `Default`, `ExecuteStageTestRunner` and
/// `UnwindStageTestRunner`.
macro_rules! stage_test_suite {
($runner:ident) => {
#[tokio::test]
// Check that execution errors on an empty database, i.e. when the
// previous progress is missing from the database.
async fn execute_empty_db() {
let runner = $runner::default();
let rx = runner.execute(crate::stage::ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Err(crate::error::StageError::DatabaseIntegrity(_))
);
assert!(runner.validate_execution().is_ok(), "execution validation");
}
#[tokio::test]
// Run a full execution cycle: seed, execute, run the post-execution hook
// and validate the results against the seeded data.
async fn execute() {
let (previous_stage, stage_progress) = (1000, 100);
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution().await.expect("failed to run after execution hook");
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
assert!(runner.validate_execution().is_ok(), "execution validation");
}
#[tokio::test]
// Check that unwind does not panic on an empty database.
async fn unwind_empty_db() {
let runner = $runner::default();
let input = crate::stage::UnwindInput::default();
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
#[tokio::test]
// Seed entries above the unwind target, unwind, and verify removal.
async fn unwind() {
let (unwind_to, highest_entry) = (100, 200);
let mut runner = $runner::default();
let input = crate::stage::UnwindInput {
unwind_to, stage_progress: unwind_to, bad_block: None,
};
runner.seed_unwind(input, highest_entry).expect("failed to seed");
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
};
}
// Re-export so stage modules can invoke the macro via `test_utils`.
pub(crate) use stage_test_suite;
}