revamp body testing

This commit is contained in:
Roman Krasiuk
2022-11-17 15:16:43 +02:00
parent 8c0222a3cc
commit fac647c602
10 changed files with 597 additions and 690 deletions

View File

@@ -1,6 +1,6 @@
//! Testing support for headers related interfaces.
use crate::{
consensus::{self, Consensus, Error},
consensus::{self, Consensus},
p2p::headers::{
client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream},
downloader::HeaderDownloader,
@@ -9,7 +9,11 @@ use crate::{
};
use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512};
use reth_rpc_types::engine::ForkchoiceState;
use std::{collections::HashSet, sync::Arc, time::Duration};
use std::{
collections::HashSet,
sync::{Arc, Mutex, MutexGuard},
time::Duration,
};
use tokio::sync::{broadcast, mpsc, watch};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
@@ -17,12 +21,13 @@ use tokio_stream::{wrappers::BroadcastStream, StreamExt};
#[derive(Debug)]
pub struct TestHeaderDownloader {
client: Arc<TestHeadersClient>,
consensus: Arc<TestConsensus>,
}
impl TestHeaderDownloader {
/// Instantiates the downloader with the mock responses
pub fn new(client: Arc<TestHeadersClient>) -> Self {
Self { client }
pub fn new(client: Arc<TestHeadersClient>, consensus: Arc<TestConsensus>) -> Self {
Self { client, consensus }
}
}
@@ -36,7 +41,7 @@ impl HeaderDownloader for TestHeaderDownloader {
}
fn consensus(&self) -> &Self::Consensus {
unimplemented!()
&self.consensus
}
fn client(&self) -> &Self::Client {
@@ -48,6 +53,12 @@ impl HeaderDownloader for TestHeaderDownloader {
_: &SealedHeader,
_: &ForkchoiceState,
) -> Result<Vec<SealedHeader>, DownloadError> {
// call consensus stub first. fails if the flag is set
let empty = SealedHeader::default();
self.consensus
.validate_header(&empty, &empty)
.map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?;
let stream = self.client.stream_headers().await;
let stream = stream.timeout(Duration::from_secs(1));
@@ -139,7 +150,7 @@ pub struct TestConsensus {
/// Watcher over the forkchoice state
channel: (watch::Sender<ForkchoiceState>, watch::Receiver<ForkchoiceState>),
/// Flag whether the header validation should purposefully fail
fail_validation: bool,
fail_validation: Mutex<bool>,
}
impl Default for TestConsensus {
@@ -150,7 +161,7 @@ impl Default for TestConsensus {
finalized_block_hash: H256::zero(),
safe_block_hash: H256::zero(),
}),
fail_validation: false,
fail_validation: Mutex::new(false),
}
}
}
@@ -166,9 +177,14 @@ impl TestConsensus {
self.channel.0.send(state).expect("updating fork choice state failed");
}
/// Acquire lock on failed validation flag
pub fn fail_validation(&self) -> MutexGuard<'_, bool> {
self.fail_validation.lock().expect("failed to acquite consensus mutex")
}
/// Update the validation flag
pub fn set_fail_validation(&mut self, val: bool) {
self.fail_validation = val;
pub fn set_fail_validation(&self, val: bool) {
*self.fail_validation() = val;
}
}
@@ -183,15 +199,15 @@ impl Consensus for TestConsensus {
_header: &SealedHeader,
_parent: &SealedHeader,
) -> Result<(), consensus::Error> {
if self.fail_validation {
if *self.fail_validation() {
Err(consensus::Error::BaseFeeMissing)
} else {
Ok(())
}
}
fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), Error> {
if self.fail_validation {
fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), consensus::Error> {
if *self.fail_validation() {
Err(consensus::Error::BaseFeeMissing)
} else {
Ok(())

View File

@@ -20,6 +20,9 @@ mod pipeline;
mod stage;
mod util;
#[cfg(test)]
mod test_utils;
/// Implementations of stages.
pub mod stages;

View File

@@ -15,7 +15,7 @@ use reth_primitives::{
proofs::{EMPTY_LIST_HASH, EMPTY_ROOT},
BlockLocked, BlockNumber, SealedHeader, H256,
};
use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};
use tracing::warn;
const BODIES: StageId = StageId("Bodies");
@@ -51,9 +51,9 @@ const BODIES: StageId = StageId("Bodies");
#[derive(Debug)]
pub struct BodyStage<D: BodyDownloader, C: Consensus> {
/// The body downloader.
pub downloader: D,
pub downloader: Arc<D>,
/// The consensus engine.
pub consensus: C,
pub consensus: Arc<C>,
/// The maximum amount of block bodies to process in one stage execution.
///
/// Smaller batch sizes result in less memory usage, but more disk I/O. Larger batch sizes
@@ -232,67 +232,47 @@ impl<D: BodyDownloader, C: Consensus> BodyStage<D, C> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::{ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID};
use crate::test_utils::{ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, stage_test_suite};
use assert_matches::assert_matches;
use reth_eth_wire::BlockBody;
use reth_interfaces::{
consensus,
p2p::bodies::error::DownloadError,
test_utils::generators::{random_block, random_block_range},
};
use reth_primitives::{BlockNumber, H256};
use std::collections::HashMap;
use test_utils::*;
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
async fn empty_db() {
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true })
)
}
stage_test_suite!(BodyTestRunner);
/// Check that the execution is short-circuited if the target was already reached.
#[tokio::test]
async fn already_reached_target() {
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.execute(ExecInput {
previous_stage: Some((PREV_STAGE_ID, 100)),
stage_progress: Some(100),
});
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress: 100, reached_tip: true, done: true })
)
}
/// Check that the execution is short-circuited if the database is empty.
// #[tokio::test]
// TODO:
// async fn empty_db() {
// let runner = BodyTestRunner::new(TestBodyDownloader::default);
// let rx = runner.execute(ExecInput::default());
// assert_matches!(
// rx.await.unwrap(),
// Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true })
// )
// }
/// Checks that the stage downloads at most `batch_size` blocks.
#[tokio::test]
async fn partial_body_download() {
// Generate blocks
let blocks = random_block_range(1..200, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 200);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size (max we sync per stage execution) to less than the number of blocks
// the previous stage synced (10 vs 20)
runner.set_batch_size(10);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
};
let rx = runner.execute(input);
// Check that we only synced around `batch_size` blocks even though the number of blocks
@@ -308,26 +288,20 @@ mod tests {
/// Same as [partial_body_download] except the `batch_size` is not hit.
#[tokio::test]
async fn full_body_download() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 21);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
};
let rx = runner.execute(input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
@@ -343,23 +317,18 @@ mod tests {
/// Same as [full_body_download] except we have made progress before
#[tokio::test]
async fn sync_from_previous_progress() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 21);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we synced at least 10 blocks
let first_run = rx.await.unwrap();
@@ -371,7 +340,7 @@ mod tests {
// Execute again on top of the previous run
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(first_run_progress),
};
let rx = runner.execute(input);
@@ -388,118 +357,48 @@ mod tests {
/// Checks that the stage asks to unwind if pre-validation of the block fails.
#[tokio::test]
async fn pre_validation_failure() {
// Generate blocks #1-19
let blocks = random_block_range(1..20, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 20);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Fail validation
runner.set_fail_validation(true);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
runner.consensus.set_fail_validation(true);
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that the error bubbles up
assert_matches!(
rx.await.unwrap(),
Err(StageError::Validation { block: 1, error: consensus::Error::BaseFeeMissing })
Err(StageError::Validation { error: consensus::Error::BaseFeeMissing, .. })
);
}
/// Checks that the stage unwinds correctly with no data.
#[tokio::test]
async fn unwind_empty_db() {
let unwind_to = 10;
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let input = UnwindInput { bad_block: None, stage_progress: 20, unwind_to };
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
/// Checks that the stage unwinds correctly with data.
#[tokio::test]
async fn unwind() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let execute_input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
};
let rx = runner.execute(execute_input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
let output = rx.await.unwrap();
assert_matches!(
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
);
let output = output.unwrap();
assert!(runner.validate_execution(execute_input, Some(output.clone())).is_ok(), "execution validation");
// Unwind all of it
let unwind_to = 1;
let unwind_input = UnwindInput { bad_block: None, stage_progress: output.stage_progress, unwind_to };
let rx = runner.unwind(unwind_input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 1
);
assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation");
assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
}
/// Checks that the stage unwinds correctly, even if a transaction in a block is missing.
#[tokio::test]
async fn unwind_missing_tx() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 21);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
@@ -536,42 +435,42 @@ mod tests {
/// try again?
#[tokio::test]
async fn downloader_timeout() {
// Generate a header
let header = random_block(1, Some(GENESIS_HASH)).header;
let runner = BodyTestRunner::new(|| {
TestBodyDownloader::new(HashMap::from([(
header.hash(),
Err(DownloadError::Timeout { header_hash: header.hash() }),
)]))
});
let (stage_progress, previous_stage) = (1, 3);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner.insert_header(&header).expect("Could not insert header");
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress), // TODO: None?
};
let blocks = runner.seed_execution(input).expect("failed to seed execution");
// overwrite responses
let header = blocks.last().unwrap();
runner.set_responses(HashMap::from([(
header.hash(),
Err(DownloadError::Timeout { header_hash: header.hash() }),
)]));
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((PREV_STAGE_ID, 1)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that the error bubbles up
assert_matches!(rx.await.unwrap(), Err(StageError::Internal(_)));
assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
}
mod test_utils {
use crate::{
stages::bodies::BodyStage,
util::test_utils::{
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner, TestRunnerError,
},
ExecInput, UnwindInput, ExecOutput,
};
use assert_matches::assert_matches;
use async_trait::async_trait;
use reth_eth_wire::BlockBody;
use reth_interfaces::{
db,
db::{
models::{BlockNumHash, StoredBlockBody},
tables, DbCursorRO, DbTx, DbTxMut,
@@ -581,12 +480,12 @@ mod tests {
downloader::{BodiesStream, BodyDownloader},
error::{BodiesClientError, DownloadError},
},
test_utils::TestConsensus,
test_utils::{TestConsensus, generators::random_block_range},
};
use reth_primitives::{
BigEndianHash, BlockLocked, BlockNumber, Header, SealedHeader, H256, U256,
};
use std::{collections::HashMap, ops::Deref, time::Duration};
use std::{collections::HashMap, ops::Deref, time::Duration, sync::Arc};
/// The block hash of the genesis block.
pub(crate) const GENESIS_HASH: H256 = H256::zero();
@@ -605,43 +504,35 @@ mod tests {
}
/// A helper struct for running the [BodyStage].
pub(crate) struct BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
downloader_builder: F,
pub(crate) struct BodyTestRunner {
pub(crate) consensus: Arc<TestConsensus>,
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
db: StageTestDB,
batch_size: u64,
fail_validation: bool,
}
impl<F> BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
/// Build a new test runner.
pub(crate) fn new(downloader_builder: F) -> Self {
BodyTestRunner {
downloader_builder,
impl Default for BodyTestRunner {
fn default() -> Self {
Self {
consensus: Arc::new(TestConsensus::default()),
responses: HashMap::default(),
db: StageTestDB::default(),
batch_size: 10,
fail_validation: false,
}
}
}
impl BodyTestRunner {
pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
self.batch_size = batch_size;
}
pub(crate) fn set_fail_validation(&mut self, fail_validation: bool) {
self.fail_validation = fail_validation;
pub(crate) fn set_responses(&mut self, responses: HashMap<H256, Result<BlockBody, DownloadError>>) {
self.responses = responses;
}
}
impl<F> StageTestRunner for BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
impl StageTestRunner for BodyTestRunner {
type S = BodyStage<TestBodyDownloader, TestConsensus>;
fn db(&self) -> &StageTestDB {
@@ -649,33 +540,29 @@ mod tests {
}
fn stage(&self) -> Self::S {
let mut consensus = TestConsensus::default();
consensus.set_fail_validation(self.fail_validation);
BodyStage {
downloader: (self.downloader_builder)(),
consensus,
downloader: Arc::new(TestBodyDownloader::new(self.responses.clone())),
consensus: self.consensus.clone(),
batch_size: self.batch_size,
}
}
}
impl<F> ExecuteStageTestRunner for BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
type Seed = ();
#[async_trait::async_trait]
impl ExecuteStageTestRunner for BodyTestRunner {
type Seed = Vec<BlockLocked>;
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), TestRunnerError> {
) -> Result<Self::Seed, TestRunnerError> {
let start = input.stage_progress.unwrap_or_default();
let end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
let blocks = random_block_range(start..end, GENESIS_HASH);
self.insert_genesis()?;
// TODO:
// self
// .insert_headers(blocks.iter().map(|block| &block.header))
// .expect("Could not insert headers");
Ok(())
self.insert_headers(blocks.iter().map(|block| &block.header))?;
self.set_responses(blocks.iter().map(body_by_hash).collect());
Ok(blocks)
}
fn validate_execution(
@@ -683,31 +570,21 @@ mod tests {
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
if let Some(output) = output {
self.validate_db_blocks(output.stage_progress)?;
}
Ok(())
let highest_block = match output.as_ref() {
Some(output) => output.stage_progress,
None => input.stage_progress.unwrap_or_default(),
};
self.validate_db_blocks(highest_block)
}
}
impl<F> UnwindStageTestRunner for BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), TestRunnerError> {
unimplemented!()
}
impl UnwindStageTestRunner for BodyTestRunner {
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), TestRunnerError> {
self.db.check_no_entry_above::<tables::BlockBodies, _>(input.unwind_to, |key| key.number())?;
if let Some(last_body) =self.last_body(){
if let Some(last_body) = self.last_body() {
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
self.db.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
}
@@ -715,15 +592,12 @@ mod tests {
}
}
impl<F> BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
impl BodyTestRunner {
/// Insert the genesis block into the appropriate tables
///
/// The genesis block always has no transactions and no ommers, and it always has the
/// same hash.
pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> {
pub(crate) fn insert_genesis(&self) -> Result<(), TestRunnerError> {
self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?;
self.db.commit(|tx| {
tx.put::<tables::BlockBodies>(
@@ -736,13 +610,14 @@ mod tests {
}
/// Insert header into tables
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
self.insert_headers(std::iter::once(header))
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), TestRunnerError> {
self.insert_headers(std::iter::once(header))?;
Ok(())
}
/// Insert headers into tables
/// TODO: move to common inserter
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), TestRunnerError>
where
I: Iterator<Item = &'a SealedHeader>,
{
@@ -778,7 +653,7 @@ mod tests {
pub(crate) fn validate_db_blocks(
&self,
highest_block: BlockNumber,
) -> Result<(), db::Error> {
) -> Result<(), TestRunnerError> {
self.db.query(|tx| {
let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
@@ -806,7 +681,8 @@ mod tests {
}
Ok(())
})
})?;
Ok(())
}
}
@@ -815,7 +691,7 @@ mod tests {
#[derive(Debug)]
pub(crate) struct NoopClient;
#[async_trait]
#[async_trait::async_trait]
impl BodiesClient for NoopClient {
async fn get_block_body(&self, _: H256) -> Result<BlockBody, BodiesClientError> {
panic!("Noop client should not be called")
@@ -824,7 +700,7 @@ mod tests {
// TODO(onbjerg): Move
/// A [BodyDownloader] that is backed by an internal [HashMap] for testing.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub(crate) struct TestBodyDownloader {
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
}
@@ -854,14 +730,11 @@ mod tests {
{
Box::pin(futures_util::stream::iter(hashes.into_iter().map(
|(block_number, hash)| {
Ok((
*block_number,
*hash,
self.responses
.get(hash)
.expect("Stage tried downloading a block we do not have.")
.clone()?,
))
let result = self.responses
.get(hash)
.expect("Stage tried downloading a block we do not have.")
.clone()?;
Ok((*block_number, *hash, result))
},
)))
}

View File

@@ -190,7 +190,7 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::{
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID,
};
use assert_matches::assert_matches;
@@ -198,8 +198,6 @@ mod tests {
stage_test_suite!(HeadersTestRunner);
// TODO: test consensus propagation error
/// Check that the execution errors on empty database or
/// prev progress missing from the database.
#[tokio::test]
@@ -218,6 +216,20 @@ mod tests {
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Check that validation error is propagated during the execution.
#[tokio::test]
async fn execute_validation_error() {
let mut runner = HeadersTestRunner::default();
runner.consensus.set_fail_validation(true);
let input = ExecInput::default();
let seed = runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(result, Err(StageError::Validation { .. }));
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Execute the stage with linear downloader
#[tokio::test]
async fn execute_with_linear_downloader() {
@@ -254,7 +266,7 @@ mod tests {
mod test_runner {
use crate::{
stages::headers::HeaderStage,
util::test_utils::{
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
},
@@ -262,14 +274,14 @@ mod tests {
};
use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader};
use reth_interfaces::{
db::{self, models::blocks::BlockNumHash, tables, DbTx},
db::{models::blocks::BlockNumHash, tables, DbTx},
p2p::headers::downloader::HeaderDownloader,
test_utils::{
generators::{random_header, random_header_range},
TestConsensus, TestHeaderDownloader, TestHeadersClient,
},
};
use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256};
use reth_primitives::{rpc::BigEndianHash, BlockNumber, SealedHeader, H256, U256};
use std::{ops::Deref, sync::Arc};
pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
@@ -282,10 +294,11 @@ mod tests {
impl Default for HeadersTestRunner<TestHeaderDownloader> {
fn default() -> Self {
let client = Arc::new(TestHeadersClient::default());
let consensus = Arc::new(TestConsensus::default());
Self {
client: client.clone(),
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(TestHeaderDownloader::new(client)),
consensus: consensus.clone(),
downloader: Arc::new(TestHeaderDownloader::new(client, consensus)),
db: StageTestDB::default(),
}
}
@@ -345,45 +358,58 @@ mod tests {
Ok(())
}
/// Validate stored headers
fn validate_execution(
&self,
_input: ExecInput,
_output: Option<ExecOutput>,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
// TODO: refine
// if let Some(ref headers) = self.context {
// // skip head and validate each
// headers.iter().skip(1).try_for_each(|h| self.validate_db_header(&h))?;
// }
let initial_stage_progress = input.stage_progress.unwrap_or_default();
match output {
Some(output) if output.stage_progress > initial_stage_progress => {
self.db.query(|tx| {
for block_num in (initial_stage_progress..output.stage_progress).rev() {
// look up the header hash
let hash = tx
.get::<tables::CanonicalHeaders>(block_num)?
.expect("no header hash");
let key: BlockNumHash = (block_num, hash).into();
// validate the header number
assert_eq!(tx.get::<tables::HeaderNumbers>(hash)?, Some(block_num));
// validate the header
let header = tx.get::<tables::Headers>(key)?;
assert!(header.is_some());
let header = header.unwrap().seal();
assert_eq!(header.hash(), hash);
// validate td consistency in the database
if header.number > initial_stage_progress {
let parent_td = tx.get::<tables::HeaderTD>(
(header.number - 1, header.parent_hash).into(),
)?;
let td = tx.get::<tables::HeaderTD>(key)?.unwrap();
assert_eq!(
parent_td.map(
|td| U256::from_big_endian(&td) + header.difficulty
),
Some(U256::from_big_endian(&td))
);
}
}
Ok(())
})?;
}
_ => self.check_no_header_entry_above(initial_stage_progress)?,
};
Ok(())
}
}
impl<D: HeaderDownloader + 'static> UnwindStageTestRunner for HeadersTestRunner<D> {
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), TestRunnerError> {
let lowest_entry = input.unwind_to.saturating_sub(100);
let headers = random_header_range(lowest_entry..highest_entry, H256::zero());
self.insert_headers(headers.iter())?;
Ok(())
}
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
let unwind_to = input.unwind_to;
self.db.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(
unwind_to,
|val| val,
)?;
self.db
.check_no_entry_above::<tables::CanonicalHeaders, _>(unwind_to, |key| key)?;
self.db
.check_no_entry_above::<tables::Headers, _>(unwind_to, |key| key.number())?;
self.db
.check_no_entry_above::<tables::HeaderTD, _>(unwind_to, |key| key.number())?;
Ok(())
self.check_no_header_entry_above(input.unwind_to)
}
}
@@ -400,12 +426,15 @@ mod tests {
impl<D: HeaderDownloader> HeadersTestRunner<D> {
/// Insert header into tables
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
pub(crate) fn insert_header(
&self,
header: &SealedHeader,
) -> Result<(), TestRunnerError> {
self.insert_headers(std::iter::once(header))
}
/// Insert headers into tables
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), TestRunnerError>
where
I: Iterator<Item = &'a SealedHeader>,
{
@@ -430,36 +459,16 @@ mod tests {
Ok(())
}
/// Validate stored header against provided
pub(crate) fn validate_db_header(
pub(crate) fn check_no_header_entry_above(
&self,
header: &SealedHeader,
) -> Result<(), db::Error> {
self.db.query(|tx| {
let key: BlockNumHash = (header.number, header.hash()).into();
let db_number = tx.get::<tables::HeaderNumbers>(header.hash())?;
assert_eq!(db_number, Some(header.number));
let db_header = tx.get::<tables::Headers>(key)?;
assert_eq!(db_header, Some(header.clone().unseal()));
let db_canonical_header = tx.get::<tables::CanonicalHeaders>(header.number)?;
assert_eq!(db_canonical_header, Some(header.hash()));
if header.number != 0 {
let parent_key: BlockNumHash =
(header.number - 1, header.parent_hash).into();
let parent_td = tx.get::<tables::HeaderTD>(parent_key)?;
let td = U256::from_big_endian(&tx.get::<tables::HeaderTD>(key)?.unwrap());
assert_eq!(
parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty),
Some(td)
);
}
Ok(())
})
block: BlockNumber,
) -> Result<(), TestRunnerError> {
self.db
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
self.db.check_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
self.db.check_no_entry_above::<tables::Headers, _>(block, |key| key.number())?;
self.db.check_no_entry_above::<tables::HeaderTD, _>(block, |key| key.number())?;
Ok(())
}
}
}

View File

@@ -87,7 +87,7 @@ impl<DB: Database> Stage<DB> for TxIndex {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::{
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
};
@@ -96,7 +96,7 @@ mod tests {
db::models::{BlockNumHash, StoredBlockBody},
test_utils::generators::random_header_range,
};
use reth_primitives::{SealedHeader, H256};
use reth_primitives::H256;
stage_test_suite!(TxIndexTestRunner);
@@ -155,7 +155,7 @@ mod tests {
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
_output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
self.db.query(|tx| {
let (start, end) = (
@@ -187,21 +187,6 @@ mod tests {
}
impl UnwindStageTestRunner for TxIndexTestRunner {
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), TestRunnerError> {
let headers = random_header_range(input.unwind_to..highest_entry, H256::zero());
self.db.transform_append::<tables::CumulativeTxCount, _, _>(&headers, |prev, h| {
(
BlockNumHash((h.number, h.hash())),
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
})?;
Ok(())
}
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.db.check_no_entry_above::<tables::CumulativeTxCount, _>(input.unwind_to, |h| {
h.number()

View File

@@ -0,0 +1,106 @@
// TODO: add comments
macro_rules! stage_test_suite {
($runner:ident) => {
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
async fn execute_empty_db() {
let runner = $runner::default();
let input = crate::stage::ExecInput::default();
let result = runner.execute(input).await.unwrap();
assert_matches!(
result,
Err(crate::error::StageError::DatabaseIntegrity(_))
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
/// Check that the execution is short-circuited if the target was already reached.
#[tokio::test]
async fn execute_already_reached_target() {
let stage_progress = 1000;
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, stage_progress)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == stage_progress
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
async fn execute() {
let (previous_stage, stage_progress) = (1000, 100);
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
println!("RESULT >>> {:?}", result.unwrap_err().to_string());
// assert_matches!(
// result,
// Ok(ExecOutput { done, reached_tip, stage_progress })
// if done && reached_tip && stage_progress == previous_stage
// );
// assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
// Check that unwind does not panic on empty database.
async fn unwind_empty_db() {
let runner = $runner::default();
let input = crate::stage::UnwindInput::default();
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
#[tokio::test]
async fn unwind() {
let (previous_stage, stage_progress) = (1000, 100);
let mut runner = $runner::default();
// Run execute
let execute_input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(execute_input).expect("failed to seed");
let rx = runner.execute(execute_input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
assert!(runner.validate_execution(execute_input, result.ok()).is_ok(), "execution validation");
let unwind_input = crate::stage::UnwindInput {
unwind_to: stage_progress, stage_progress, bad_block: None,
};
let rx = runner.unwind(unwind_input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to
);
assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation");
}
};
}
pub(crate) use stage_test_suite;

View File

@@ -0,0 +1,14 @@
use crate::StageId;
mod macros;
mod runner;
mod stage_db;
pub(crate) use macros::*;
pub(crate) use runner::{
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner,
};
pub(crate) use stage_db::StageTestDB;
/// The previous test stage id mock used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");

View File

@@ -0,0 +1,81 @@
use reth_db::{kv::Env, mdbx::WriteMap};
use reth_interfaces::db::{self, DBContainer};
use std::borrow::Borrow;
use tokio::sync::oneshot;
use super::StageTestDB;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
#[error("Database error occured.")]
Database(#[from] db::Error),
#[error("Internal runner error occured.")]
Internal(#[from] Box<dyn std::error::Error>),
}
/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
type S: Stage<Env<WriteMap>> + 'static;
/// Return a reference to the database.
fn db(&self) -> &StageTestDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
}
#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
type Seed: Send + Sync;
/// Seed database for stage execution
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError>;
/// Validate stage execution
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError>;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
});
rx
}
/// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
async fn after_execution(&self, _seed: Self::Seed) -> Result<(), TestRunnerError> {
Ok(())
}
}
pub(crate) trait UnwindStageTestRunner: StageTestRunner {
/// Validate the unwind
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>;
/// Run [Stage::unwind] and return a receiver for the result.
fn unwind(
&self,
input: UnwindInput,
) -> oneshot::Receiver<Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");
});
rx
}
}

View File

@@ -0,0 +1,149 @@
use reth_db::{
kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{self, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table};
use reth_primitives::BlockNumber;
use std::{borrow::Borrow, sync::Arc};
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
///
/// ```rust
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
db: Arc<Env<WriteMap>>,
}
impl Default for StageTestDB {
/// Create a new instance of [StageTestDB]
fn default() -> Self {
Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
}
}
impl StageTestDB {
/// Return a database wrapped in [DBContainer].
fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
/// Invoke a callback with transaction committing it afterwards
pub(crate) fn commit<F>(&self, f: F) -> Result<(), db::Error>
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
{
let mut db = self.container();
let tx = db.get_mut();
f(tx)?;
db.commit()?;
Ok(())
}
/// Invoke a callback with a read transaction
pub(crate) fn query<F, R>(&self, f: F) -> Result<R, db::Error>
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
{
f(self.container().get())
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), db::Error>
where
T: Table,
S: Clone,
F: FnMut(&S) -> (T::Key, T::Value),
{
self.commit(|tx| {
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})
})
}
/// Transform a collection of values using a callback and store
/// them in the database. The callback additionally accepts the
/// optional last element that was stored.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
/// ```
pub(crate) fn transform_append<T, S, F>(
&self,
values: &[S],
mut transform: F,
) -> Result<(), db::Error>
where
T: Table,
<T as Table>::Value: Clone,
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
{
self.commit(|tx| {
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})
})
}
/// Check that there is no table entry above a given
/// block by [Table::Key]
pub(crate) fn check_no_entry_above<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((key, _)) = cursor.last()? {
assert!(selector(key) <= block);
}
Ok(())
})
}
/// Check that there is no table entry above a given
/// block by [Table::Value]
pub(crate) fn check_no_entry_above_by_value<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((_, value)) = cursor.last()? {
assert!(selector(value) <= block);
}
Ok(())
})
}
}

View File

@@ -135,332 +135,3 @@ pub(crate) mod unwind {
Ok(())
}
}
#[cfg(test)]
pub(crate) mod test_utils {
use reth_db::{
kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{self, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table};
use reth_primitives::BlockNumber;
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput};
/// The previous test stage id mock used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
///
/// ```rust
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
db: Arc<Env<WriteMap>>,
}
impl Default for StageTestDB {
/// Create a new instance of [StageTestDB]
fn default() -> Self {
Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
}
}
impl StageTestDB {
/// Return a database wrapped in [DBContainer].
fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
/// Invoke a callback with transaction committing it afterwards
pub(crate) fn commit<F>(&self, f: F) -> Result<(), db::Error>
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
{
let mut db = self.container();
let tx = db.get_mut();
f(tx)?;
db.commit()?;
Ok(())
}
/// Invoke a callback with a read transaction
pub(crate) fn query<F, R>(&self, f: F) -> Result<R, db::Error>
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
{
f(self.container().get())
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), db::Error>
where
T: Table,
S: Clone,
F: FnMut(&S) -> (T::Key, T::Value),
{
self.commit(|tx| {
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})
})
}
/// Transform a collection of values using a callback and store
/// them in the database. The callback additionally accepts the
/// optional last element that was stored.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
/// ```
pub(crate) fn transform_append<T, S, F>(
&self,
values: &[S],
mut transform: F,
) -> Result<(), db::Error>
where
T: Table,
<T as Table>::Value: Clone,
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
{
self.commit(|tx| {
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})
})
}
/// Check that there is no table entry above a given
/// block by [Table::Key]
pub(crate) fn check_no_entry_above<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((key, _)) = cursor.last()? {
assert!(selector(key) <= block);
}
Ok(())
})
}
/// Check that there is no table entry above a given
/// block by [Table::Value]
pub(crate) fn check_no_entry_above_by_value<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((_, value)) = cursor.last()? {
assert!(selector(value) <= block);
}
Ok(())
})
}
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
#[error("Database error occured.")]
Database(#[from] db::Error),
#[error("Internal runner error occured.")]
Internal(#[from] Box<dyn std::error::Error>),
}
/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
type S: Stage<Env<WriteMap>> + 'static;
/// Return a reference to the database.
fn db(&self) -> &StageTestDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
}
#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
type Seed: Send + Sync;
/// Seed database for stage execution
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError>;
/// Validate stage execution
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError>;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
});
rx
}
/// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
async fn after_execution(&self, seed: Self::Seed) -> Result<(), TestRunnerError> {
Ok(())
}
}
pub(crate) trait UnwindStageTestRunner: StageTestRunner {
/// Seed database for stage unwind
fn seed_unwind(
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), TestRunnerError>;
/// Validate the unwind
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>;
/// Run [Stage::unwind] and return a receiver for the result.
fn unwind(
&self,
input: UnwindInput,
) -> oneshot::Receiver<Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>>>
{
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");
});
rx
}
}
macro_rules! stage_test_suite {
($runner:ident) => {
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
async fn execute_empty_db() {
let runner = $runner::default();
let input = crate::stage::ExecInput::default();
let result = runner.execute(input).await.unwrap();
assert_matches!(
result,
Err(crate::error::StageError::DatabaseIntegrity(_))
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
async fn execute_already_reached_target() {
let stage_progress = 1000;
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, stage_progress)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == stage_progress
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
async fn execute() {
let (previous_stage, stage_progress) = (1000, 100);
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
// Check that unwind does not panic on empty database.
async fn unwind_empty_db() {
let runner = $runner::default();
let input = crate::stage::UnwindInput::default();
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
#[tokio::test]
async fn unwind() {
let (unwind_to, highest_entry) = (100, 200);
let mut runner = $runner::default();
let input = crate::stage::UnwindInput {
unwind_to, stage_progress: unwind_to, bad_block: None,
};
runner.seed_unwind(input, highest_entry).expect("failed to seed");
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
};
}
pub(crate) use stage_test_suite;
}