start revamping bodies testing

This commit is contained in:
Roman Krasiuk
2022-11-16 18:25:35 +02:00
parent a252f21bdc
commit 8c0222a3cc
4 changed files with 277 additions and 240 deletions

View File

@@ -232,7 +232,7 @@ impl<D: BodyDownloader, C: Consensus> BodyStage<D, C> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::StageTestRunner;
use crate::util::test_utils::{ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID};
use assert_matches::assert_matches;
use reth_eth_wire::BlockBody;
use reth_interfaces::{
@@ -260,7 +260,7 @@ mod tests {
async fn already_reached_target() {
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), 100)),
previous_stage: Some((PREV_STAGE_ID, 100)),
stage_progress: Some(100),
});
assert_matches!(
@@ -289,10 +289,11 @@ mod tests {
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
};
let rx = runner.execute(input);
// Check that we only synced around `batch_size` blocks even though the number of blocks
// synced by the previous stage is higher
@@ -301,9 +302,7 @@ mod tests {
output,
Ok(ExecOutput { stage_progress, reached_tip: true, done: false }) if stage_progress < 200
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Same as [partial_body_download] except the `batch_size` is not hit.
@@ -325,10 +324,11 @@ mod tests {
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
};
let rx = runner.execute(input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
@@ -337,9 +337,7 @@ mod tests {
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Same as [full_body_download] except we have made progress before
@@ -359,7 +357,7 @@ mod tests {
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
@@ -372,10 +370,11 @@ mod tests {
let first_run_progress = first_run.unwrap().stage_progress;
// Execute again on top of the previous run
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: Some(first_run_progress),
});
};
let rx = runner.execute(input);
// Check that we synced more blocks
let output = rx.await.unwrap();
@@ -383,9 +382,7 @@ mod tests {
output,
Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress > first_run_progress
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Checks that the stage asks to unwind if pre-validation of the block fails.
@@ -408,7 +405,7 @@ mod tests {
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
@@ -424,11 +421,13 @@ mod tests {
async fn unwind_empty_db() {
let unwind_to = 10;
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress: 20, unwind_to });
let input = UnwindInput { bad_block: None, stage_progress: 20, unwind_to };
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
)
);
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
/// Checks that the stage unwinds correctly with data.
@@ -450,10 +449,11 @@ mod tests {
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
let execute_input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
};
let rx = runner.execute(execute_input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
@@ -462,27 +462,19 @@ mod tests {
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
);
let stage_progress = output.unwrap().stage_progress;
runner.validate_db_blocks(stage_progress).expect("Written block data invalid");
let output = output.unwrap();
assert!(runner.validate_execution(execute_input, Some(output.clone())).is_ok(), "execution validation");
// Unwind all of it
let unwind_to = 1;
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
let unwind_input = UnwindInput { bad_block: None, stage_progress: output.stage_progress, unwind_to };
let rx = runner.unwind(unwind_input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 1
);
let last_body = runner.last_body().expect("Could not read last body");
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
runner
.db()
.check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
.expect("Did not unwind block bodies correctly.");
runner
.db()
.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
.expect("Did not unwind transactions correctly.")
assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation");
}
/// Checks that the stage unwinds correctly, even if a transaction in a block is missing.
@@ -505,7 +497,7 @@ mod tests {
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
previous_stage: Some((PREV_STAGE_ID, blocks.len() as BlockNumber)),
stage_progress: None,
});
@@ -520,38 +512,23 @@ mod tests {
runner.validate_db_blocks(stage_progress).expect("Written block data invalid");
// Delete a transaction
{
let mut db = runner.db().container();
let mut tx_cursor = db
.get_mut()
.cursor_mut::<tables::Transactions>()
.expect("Could not get transaction cursor");
tx_cursor
.last()
.expect("Could not read database")
.expect("Could not read last transaction");
tx_cursor.delete_current().expect("Could not delete last transaction");
db.commit().expect("Could not commit database");
}
runner.db().commit(|tx| {
let mut tx_cursor = tx.cursor_mut::<tables::Transactions>()?;
tx_cursor.last()?.expect("Could not read last transaction");
tx_cursor.delete_current()?;
Ok(())
}).expect("Could not delete a transaction");
// Unwind all of it
let unwind_to = 1;
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
let input = UnwindInput { bad_block: None, stage_progress, unwind_to };
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 1
);
let last_body = runner.last_body().expect("Could not read last body");
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
runner
.db()
.check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
.expect("Did not unwind block bodies correctly.");
runner
.db()
.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
.expect("Did not unwind transactions correctly.")
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
/// Checks that the stage exits if the downloader times out
@@ -574,7 +551,7 @@ mod tests {
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), 1)),
previous_stage: Some((PREV_STAGE_ID, 1)),
stage_progress: None,
});
@@ -585,7 +562,10 @@ mod tests {
mod test_utils {
use crate::{
stages::bodies::BodyStage,
util::test_utils::{StageTestDB, StageTestRunner},
util::test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner, TestRunnerError,
},
ExecInput, UnwindInput, ExecOutput,
};
use assert_matches::assert_matches;
use async_trait::async_trait;
@@ -680,6 +660,61 @@ mod tests {
}
}
impl<F> ExecuteStageTestRunner for BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
type Seed = ();
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), TestRunnerError> {
self.insert_genesis()?;
// TODO:
// self
// .insert_headers(blocks.iter().map(|block| &block.header))
// .expect("Could not insert headers");
Ok(())
}
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
if let Some(output) = output {
self.validate_db_blocks(output.stage_progress)?;
}
Ok(())
}
}
impl<F> UnwindStageTestRunner for BodyTestRunner<F>
where
    F: Fn() -> TestBodyDownloader,
{
    /// Seeding before unwind is not supported by this runner yet.
    fn seed_unwind(
        &mut self,
        input: UnwindInput,
        highest_entry: u64,
    ) -> Result<(), TestRunnerError> {
        unimplemented!()
    }

    /// Verify that no block bodies remain above the unwind target, and that
    /// no transactions remain past the last surviving body's transaction range.
    fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
        self.db
            .check_no_entry_above::<tables::BlockBodies, _>(input.unwind_to, |key| key.number())?;
        match self.last_body() {
            Some(body) => {
                let highest_tx_id = body.base_tx_id + body.tx_amount;
                self.db
                    .check_no_entry_above::<tables::Transactions, _>(highest_tx_id, |key| key)?;
                Ok(())
            }
            None => Ok(()),
        }
    }
}
impl<F> BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
@@ -690,13 +725,12 @@ mod tests {
/// same hash.
pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> {
self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?;
let mut db = self.db.container();
let tx = db.get_mut();
tx.put::<tables::BlockBodies>(
(0, GENESIS_HASH).into(),
StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
)?;
db.commit()?;
self.db.commit(|tx| {
tx.put::<tables::BlockBodies>(
(0, GENESIS_HASH).into(),
StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
)
})?;
Ok(())
}
@@ -707,6 +741,7 @@ mod tests {
}
/// Insert headers into tables
/// TODO: move to common inserter
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
where
I: Iterator<Item = &'a SealedHeader>,
@@ -733,9 +768,10 @@ mod tests {
}
pub(crate) fn last_body(&self) -> Option<StoredBlockBody> {
Some(
self.db.container().get().cursor::<tables::BlockBodies>().ok()?.last().ok()??.1,
)
self.db
.query(|tx| Ok(tx.cursor::<tables::BlockBodies>()?.last()?.map(|e| e.1)))
.ok()
.flatten()
}
/// Validate that the inserted block data is valid
@@ -743,35 +779,34 @@ mod tests {
&self,
highest_block: BlockNumber,
) -> Result<(), db::Error> {
let db = self.db.container();
let tx = db.get();
let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
let mut entry = block_body_cursor.first()?;
let mut prev_max_tx_id = 0;
while let Some((key, body)) = entry {
assert!(
key.number() <= highest_block,
"We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
key.number(), highest_block
);
assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
for num in 0..body.tx_amount {
let tx_id = body.base_tx_id + num;
assert_matches!(
transaction_cursor.seek_exact(tx_id),
Ok(Some(_)),
"A transaction is missing."
self.db.query(|tx| {
let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
let mut entry = block_body_cursor.first()?;
let mut prev_max_tx_id = 0;
while let Some((key, body)) = entry {
assert!(
key.number() <= highest_block,
"We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
key.number(), highest_block
);
assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
for num in 0..body.tx_amount {
let tx_id = body.base_tx_id + num;
assert_matches!(
transaction_cursor.seek_exact(tx_id),
Ok(Some(_)),
"A transaction is missing."
);
}
prev_max_tx_id = body.base_tx_id + body.tx_amount;
entry = block_body_cursor.next()?;
}
prev_max_tx_id = body.base_tx_id + body.tx_amount;
entry = block_body_cursor.next()?;
}
Ok(())
Ok(())
})
}
}

View File

@@ -198,6 +198,8 @@ mod tests {
stage_test_suite!(HeadersTestRunner);
// TODO: test consensus propagation error
/// Check that the execution errors on empty database or
/// prev progress missing from the database.
#[tokio::test]
@@ -208,12 +210,12 @@ mod tests {
runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
runner.consensus.update_tip(H256::from_low_u64_be(1));
let result = rx.await.unwrap();
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress: out_stage_progress })
if !done && !reached_tip && out_stage_progress == 0
result,
Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 0 })
);
assert!(runner.validate_execution(input).is_ok(), "validation failed");
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Execute the stage with linear downloader
@@ -225,11 +227,10 @@ mod tests {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
let headers = runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
// skip `after_execution` hook for linear downloader
let headers = runner.context.as_ref().unwrap();
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
@@ -242,21 +243,22 @@ mod tests {
})
.await;
let result = rx.await.unwrap();
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
result,
Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) if stage_progress == tip.number
);
assert!(runner.validate_execution(input).is_ok(), "validation failed");
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
mod test_runner {
use crate::{
stages::headers::HeaderStage,
util::test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, UnwindStageTestRunner,
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
},
ExecInput, UnwindInput,
ExecInput, ExecOutput, UnwindInput,
};
use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader};
use reth_interfaces::{
@@ -273,7 +275,6 @@ mod tests {
pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
pub(crate) consensus: Arc<TestConsensus>,
pub(crate) client: Arc<TestHeadersClient>,
pub(crate) context: Option<Vec<SealedHeader>>,
downloader: Arc<D>,
db: StageTestDB,
}
@@ -286,7 +287,6 @@ mod tests {
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(TestHeaderDownloader::new(client)),
db: StageTestDB::default(),
context: None,
}
}
}
@@ -309,38 +309,36 @@ mod tests {
#[async_trait::async_trait]
impl<D: HeaderDownloader + 'static> ExecuteStageTestRunner for HeadersTestRunner<D> {
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
type Seed = Vec<SealedHeader>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let start = input.stage_progress.unwrap_or_default();
let head = random_header(start, None);
self.insert_header(&head)?;
// use previous progress as seed size
let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
if end > start + 1 {
let mut headers = random_header_range(start + 1..end, head.hash());
headers.insert(0, head);
self.context = Some(headers);
if start + 1 >= end {
return Ok(Vec::default())
}
Ok(())
let mut headers = random_header_range(start + 1..end, head.hash());
headers.insert(0, head);
Ok(headers)
}
async fn after_execution(
&self,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (tip, headers) = match self.context {
Some(ref headers) if headers.len() > 1 => {
(headers.last().unwrap().hash(), headers.clone())
}
_ => (H256::from_low_u64_be(rand::random()), Vec::default()),
async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
let tip = if !headers.is_empty() {
headers.last().unwrap().hash()
} else {
H256::from_low_u64_be(rand::random())
};
self.consensus.update_tip(tip);
self.client
.send_header_response_delayed(
0,
headers.iter().cloned().map(|h| h.unseal()).collect(),
headers.into_iter().map(|h| h.unseal()).collect(),
1,
)
.await;
@@ -350,11 +348,13 @@ mod tests {
fn validate_execution(
&self,
_input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(ref headers) = self.context {
// skip head and validate each
headers.iter().skip(1).try_for_each(|h| self.validate_db_header(&h))?;
}
_output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
// TODO: refine
// if let Some(ref headers) = self.context {
// // skip head and validate each
// headers.iter().skip(1).try_for_each(|h| self.validate_db_header(&h))?;
// }
Ok(())
}
}
@@ -364,17 +364,14 @@ mod tests {
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
) -> Result<(), TestRunnerError> {
let lowest_entry = input.unwind_to.saturating_sub(100);
let headers = random_header_range(lowest_entry..highest_entry, H256::zero());
self.insert_headers(headers.iter())?;
Ok(())
}
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
let unwind_to = input.unwind_to;
self.db.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(
unwind_to,
@@ -397,7 +394,7 @@ mod tests {
let downloader = Arc::new(
LinearDownloadBuilder::default().build(consensus.clone(), client.clone()),
);
Self { client, consensus, downloader, db: StageTestDB::default(), context: None }
Self { client, consensus, downloader, db: StageTestDB::default() }
}
}
@@ -438,30 +435,31 @@ mod tests {
&self,
header: &SealedHeader,
) -> Result<(), db::Error> {
let db = self.db.container();
let tx = db.get();
let key: BlockNumHash = (header.number, header.hash()).into();
self.db.query(|tx| {
let key: BlockNumHash = (header.number, header.hash()).into();
let db_number = tx.get::<tables::HeaderNumbers>(header.hash())?;
assert_eq!(db_number, Some(header.number));
let db_number = tx.get::<tables::HeaderNumbers>(header.hash())?;
assert_eq!(db_number, Some(header.number));
let db_header = tx.get::<tables::Headers>(key)?;
assert_eq!(db_header, Some(header.clone().unseal()));
let db_header = tx.get::<tables::Headers>(key)?;
assert_eq!(db_header, Some(header.clone().unseal()));
let db_canonical_header = tx.get::<tables::CanonicalHeaders>(header.number)?;
assert_eq!(db_canonical_header, Some(header.hash()));
let db_canonical_header = tx.get::<tables::CanonicalHeaders>(header.number)?;
assert_eq!(db_canonical_header, Some(header.hash()));
if header.number != 0 {
let parent_key: BlockNumHash = (header.number - 1, header.parent_hash).into();
let parent_td = tx.get::<tables::HeaderTD>(parent_key)?;
let td = U256::from_big_endian(&tx.get::<tables::HeaderTD>(key)?.unwrap());
assert_eq!(
parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty),
Some(td)
);
}
if header.number != 0 {
let parent_key: BlockNumHash =
(header.number - 1, header.parent_hash).into();
let parent_td = tx.get::<tables::HeaderTD>(parent_key)?;
let td = U256::from_big_endian(&tx.get::<tables::HeaderTD>(key)?.unwrap());
assert_eq!(
parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty),
Some(td)
);
}
Ok(())
Ok(())
})
}
}
}

View File

@@ -88,7 +88,7 @@ impl<DB: Database> Stage<DB> for TxIndex {
mod tests {
use super::*;
use crate::util::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner,
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
};
use assert_matches::assert_matches;
@@ -96,7 +96,7 @@ mod tests {
db::models::{BlockNumHash, StoredBlockBody},
test_utils::generators::random_header_range,
};
use reth_primitives::H256;
use reth_primitives::{SealedHeader, H256};
stage_test_suite!(TxIndexTestRunner);
@@ -118,10 +118,9 @@ mod tests {
}
impl ExecuteStageTestRunner for TxIndexTestRunner {
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
type Seed = ();
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let pivot = input.stage_progress.unwrap_or_default();
let start = pivot.saturating_sub(100);
let mut end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
@@ -156,31 +155,33 @@ mod tests {
fn validate_execution(
&self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let db = self.db.container();
let tx = db.get();
let (start, end) = (
input.stage_progress.unwrap_or_default(),
input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(),
);
if start >= end {
return Ok(())
}
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
self.db.query(|tx| {
let (start, end) = (
input.stage_progress.unwrap_or_default(),
input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(),
);
if start >= end {
return Ok(())
}
let start_hash =
tx.get::<tables::CanonicalHeaders>(start)?.expect("no canonical found");
let mut tx_count_cursor = tx.cursor::<tables::CumulativeTxCount>()?;
let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?;
let mut count = tx_count_walker.next().unwrap()?.1;
let mut last_num = start;
while let Some(entry) = tx_count_walker.next() {
let (key, db_count) = entry?;
count += tx.get::<tables::BlockBodies>(key)?.unwrap().tx_amount as u64;
assert_eq!(db_count, count);
last_num = key.number();
}
assert_eq!(last_num, end);
let start_hash =
tx.get::<tables::CanonicalHeaders>(start)?.expect("no canonical found");
let mut tx_count_cursor = tx.cursor::<tables::CumulativeTxCount>()?;
let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?;
let mut count = tx_count_walker.next().unwrap()?.1;
let mut last_num = start;
while let Some(entry) = tx_count_walker.next() {
let (key, db_count) = entry?;
count += tx.get::<tables::BlockBodies>(key)?.unwrap().tx_amount as u64;
assert_eq!(db_count, count);
last_num = key.number();
}
assert_eq!(last_num, end);
Ok(())
})?;
Ok(())
}
}
@@ -190,7 +191,7 @@ mod tests {
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
) -> Result<(), TestRunnerError> {
let headers = random_header_range(input.unwind_to..highest_entry, H256::zero());
self.db.transform_append::<tables::CumulativeTxCount, _, _>(&headers, |prev, h| {
(
@@ -201,10 +202,7 @@ mod tests {
Ok(())
}
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.db.check_no_entry_above::<tables::CumulativeTxCount, _>(input.unwind_to, |h| {
h.number()
})?;

View File

@@ -142,7 +142,7 @@ pub(crate) mod test_utils {
kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table};
use reth_interfaces::db::{self, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table};
use reth_primitives::BlockNumber;
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
@@ -171,20 +171,20 @@ pub(crate) mod test_utils {
}
impl StageTestDB {
/// Return a database wrapped in [DBContainer].
fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
/// Return a database wrapped in [DBContainer].
pub(crate) fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Invoke a callback with transaction committing it afterwards
fn commit<F>(&self, f: F) -> Result<(), Error>
pub(crate) fn commit<F>(&self, f: F) -> Result<(), db::Error>
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), Error>,
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
{
let mut db = self.container();
let tx = db.get_mut();
@@ -193,10 +193,10 @@ pub(crate) mod test_utils {
Ok(())
}
/// Invoke a callback with transaction
fn query<F, R>(&self, f: F) -> Result<R, Error>
/// Invoke a callback with a read transaction
pub(crate) fn query<F, R>(&self, f: F) -> Result<R, db::Error>
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, Error>,
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
{
f(self.container().get())
}
@@ -208,7 +208,7 @@ pub(crate) mod test_utils {
/// let db = StageTestDB::default();
/// db.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), Error>
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), db::Error>
where
T: Table,
S: Clone,
@@ -235,7 +235,7 @@ pub(crate) mod test_utils {
&self,
values: &[S],
mut transform: F,
) -> Result<(), Error>
) -> Result<(), db::Error>
where
T: Table,
<T as Table>::Value: Clone,
@@ -259,7 +259,7 @@ pub(crate) mod test_utils {
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), Error>
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
@@ -279,7 +279,7 @@ pub(crate) mod test_utils {
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), Error>
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
@@ -294,6 +294,14 @@ pub(crate) mod test_utils {
}
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
#[error("Database error occured.")]
Database(#[from] db::Error),
#[error("Internal runner error occured.")]
Internal(#[from] Box<dyn std::error::Error>),
}
/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
@@ -308,17 +316,17 @@ pub(crate) mod test_utils {
#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
type Seed: Send + Sync;
/// Seed database for stage execution
fn seed_execution(
&mut self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError>;
/// Validate stage execution
fn validate_execution(
&self,
input: ExecInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError>;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
@@ -334,7 +342,7 @@ pub(crate) mod test_utils {
}
/// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
async fn after_execution(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async fn after_execution(&self, seed: Self::Seed) -> Result<(), TestRunnerError> {
Ok(())
}
}
@@ -345,13 +353,10 @@ pub(crate) mod test_utils {
&mut self,
input: UnwindInput,
highest_entry: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
) -> Result<(), TestRunnerError>;
/// Validate the unwind
fn validate_unwind(
&self,
input: UnwindInput,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>;
/// Run [Stage::unwind] and return a receiver for the result.
fn unwind(
@@ -373,37 +378,37 @@ pub(crate) mod test_utils {
macro_rules! stage_test_suite {
($runner:ident) => {
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
// Check that the execution errors on empty database or
// prev progress missing from the database.
async fn execute_empty_db() {
let runner = $runner::default();
let input = crate::stage::ExecInput::default();
let rx = runner.execute(input);
let result = runner.execute(input).await.unwrap();
assert_matches!(
rx.await.unwrap(),
result,
Err(crate::error::StageError::DatabaseIntegrity(_))
);
assert!(runner.validate_execution(input).is_ok(), "execution validation");
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
async fn execute_no_progress() {
async fn execute_already_reached_target() {
let stage_progress = 1000;
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, stage_progress)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed");
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution().await.expect("failed to run after execution hook");
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
rx.await.unwrap(),
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == stage_progress
);
assert!(runner.validate_execution(input).is_ok(), "execution validation");
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]
@@ -414,15 +419,16 @@ pub(crate) mod test_utils {
previous_stage: Some((crate::util::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed");
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
runner.after_execution().await.expect("failed to run after execution hook");
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(
rx.await.unwrap(),
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
assert!(runner.validate_execution(input).is_ok(), "execution validation");
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
#[tokio::test]