feat(sync): stage db helper (#240)

* feat(sync): stage db helper

* stagedb cont

* merge stage db & db container

* rename test stage db accessor methods

* clippy

* remove legacy test
This commit is contained in:
Roman Krasiuk
2022-11-25 17:12:13 +02:00
committed by GitHub
parent fb2861f112
commit 6e7928ab84
15 changed files with 299 additions and 379 deletions

View File

@@ -326,35 +326,3 @@ mod tests {
let _ = provider.latest();
}
}
#[cfg(test)]
// This ensures that we can use the GATs in the downstream staged exec pipeline.
mod gat_tests {
    use super::*;
    use reth_interfaces::db::{mock::DatabaseMock, DBContainer};

    /// Minimal stand-in for the pipeline's stage trait; exists only to exercise
    /// the GAT-based transaction container, not any real stage logic.
    #[async_trait::async_trait]
    trait Stage<DB: Database> {
        async fn run(&mut self, db: &mut DBContainer<'_, DB>) -> ();
    }

    /// Dummy stage holding a DB handle; `run` just commits the container's
    /// inner transaction to prove the borrows and GAT lifetimes line up.
    struct MyStage<'a, DB>(&'a DB);

    #[async_trait::async_trait]
    impl<'c, DB: Database> Stage<DB> for MyStage<'c, DB> {
        async fn run(&mut self, db: &mut DBContainer<'_, DB>) -> () {
            let _tx = db.commit().unwrap();
        }
    }

    #[test]
    #[should_panic] // no tokio runtime configured
    fn can_spawn() {
        // The spawn panics (no runtime), which is expected; the real assertion
        // is that this future compiles and satisfies `tokio::spawn`'s bounds.
        let db = DatabaseMock::default();
        tokio::spawn(async move {
            let mut container = DBContainer::new(&db).unwrap();
            let mut stage = MyStage(&db);
            stage.run(&mut container).await;
        });
    }
}

View File

@@ -1,100 +0,0 @@
use crate::db::{Database, DatabaseGAT, DbTx, Error};
/// A container for any DB transaction that will open a new inner transaction when the current
/// one is committed.
// NOTE: This container is needed since `Transaction::commit` takes `mut self`, so methods in
// the pipeline that just take a reference will not be able to commit their transaction and let
// the pipeline continue. Is there a better way to do this?
pub struct DBContainer<'this, DB: Database> {
    /// A handle to the DB.
    pub(crate) db: &'this DB,
    /// The current mutable inner transaction. `None` only between a call to
    /// [DBContainer::close] and a subsequent [DBContainer::open]; the accessor
    /// methods panic if used in that window.
    tx: Option<<DB as DatabaseGAT<'this>>::TXMut>,
}
impl<'this, DB> DBContainer<'this, DB>
where
    DB: Database,
{
    /// Create a new container with the given database handle.
    ///
    /// A new inner transaction will be opened.
    pub fn new(db: &'this DB) -> Result<Self, Error> {
        let tx = db.tx_mut()?;
        Ok(Self { db, tx: Some(tx) })
    }

    /// Commit the current inner transaction and open a new one.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [DBContainer::close] was called without following up with a call to [DBContainer::open].
    pub fn commit(&mut self) -> Result<bool, Error> {
        let current = self.tx.take().expect("Tried committing a non-existent transaction");
        let success = current.commit()?;
        // Immediately replace the committed transaction with a fresh one.
        self.tx = Some(self.db.tx_mut()?);
        Ok(success)
    }

    /// Get the inner transaction.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [DBContainer::close] was called without following up with a call to [DBContainer::open].
    pub fn get(&self) -> &<DB as DatabaseGAT<'this>>::TXMut {
        match self.tx.as_ref() {
            Some(tx) => tx,
            None => panic!("Tried getting a reference to a non-existent transaction"),
        }
    }

    /// Get a mutable reference to the inner transaction.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [DBContainer::close] was called without following up with a call to [DBContainer::open].
    pub fn get_mut(&mut self) -> &mut <DB as DatabaseGAT<'this>>::TXMut {
        match self.tx.as_mut() {
            Some(tx) => tx,
            None => panic!("Tried getting a mutable reference to a non-existent transaction"),
        }
    }

    /// Open a new inner transaction.
    pub fn open(&mut self) -> Result<(), Error> {
        let tx = self.db.tx_mut()?;
        self.tx = Some(tx);
        Ok(())
    }

    /// Close the current inner transaction.
    pub fn close(&mut self) {
        let _ = self.tx.take();
    }
}
#[cfg(test)]
// This ensures that we can use the GATs in the downstream staged exec pipeline.
mod tests {
    use super::*;
    use crate::db::mock::DatabaseMock;

    /// Tiny stage trait mirroring the pipeline's shape, used purely as a
    /// compile-time check of the GAT lifetimes.
    #[async_trait::async_trait]
    trait Stage<DB: Database> {
        async fn run(&mut self, db: &mut DBContainer<'_, DB>) -> ();
    }

    /// No-op stage over a borrowed database handle.
    struct MyStage<'a, DB>(&'a DB);

    #[async_trait::async_trait]
    impl<'a, DB: Database> Stage<DB> for MyStage<'a, DB> {
        async fn run(&mut self, db: &mut DBContainer<'_, DB>) -> () {
            let _tx = db.commit().unwrap();
        }
    }

    #[test]
    #[should_panic] // no tokio runtime configured
    fn can_spawn() {
        let database = DatabaseMock::default();
        // The panic comes from spawning without a runtime; the point of the
        // test is that this future type-checks under `tokio::spawn`.
        tokio::spawn(async move {
            let mut my_stage = MyStage(&database);
            let mut holder = DBContainer::new(&database).unwrap();
            my_stage.run(&mut holder).await;
        });
    }
}

View File

@@ -1,5 +1,4 @@
pub mod codecs;
mod container;
mod error;
pub mod mock;
pub mod models;
@@ -8,7 +7,6 @@ pub mod tables;
use std::marker::PhantomData;
pub use container::DBContainer;
pub use error::Error;
pub use table::*;

174
crates/stages/src/db.rs Normal file
View File

@@ -0,0 +1,174 @@
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
};
use reth_interfaces::db::{
models::{BlockNumHash, NumTransactions},
tables, Database, DatabaseGAT, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table,
};
use reth_primitives::{BlockHash, BlockNumber};
use crate::{DatabaseIntegrityError, StageError};
/// A container for any DB transaction that will open a new inner transaction when the current
/// one is committed.
// NOTE: This container is needed since `Transaction::commit` takes `mut self`, so methods in
// the pipeline that just take a reference will not be able to commit their transaction and let
// the pipeline continue. Is there a better way to do this?
pub struct StageDB<'this, DB: Database> {
    /// A handle to the DB.
    pub(crate) db: &'this DB,
    /// The current mutable inner transaction. `None` only between a call to
    /// [StageDB::close] and a subsequent [StageDB::open]; the `Deref` impls
    /// panic if the container is used in that window.
    tx: Option<<DB as DatabaseGAT<'this>>::TXMut>,
}
impl<'a, DB: Database> Debug for StageDB<'a, DB> {
    // Only the type name is emitted: the inner GAT transaction type is not
    // required to implement `Debug`, so no fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StageDB").finish()
    }
}
impl<'a, DB: Database> Deref for StageDB<'a, DB> {
    type Target = <DB as DatabaseGAT<'a>>::TXMut;

    /// Dereference as the inner transaction.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [StageDB::close] was called without following up with a call to [StageDB::open].
    fn deref(&self) -> &Self::Target {
        match self.tx.as_ref() {
            Some(tx) => tx,
            None => panic!("Tried getting a reference to a non-existent transaction"),
        }
    }
}
impl<'a, DB: Database> DerefMut for StageDB<'a, DB> {
    /// Dereference as a mutable reference to the inner transaction.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [StageDB::close] was called without following up with a call to [StageDB::open].
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.tx.as_mut() {
            Some(tx) => tx,
            None => panic!("Tried getting a mutable reference to a non-existent transaction"),
        }
    }
}
impl<'this, DB> StageDB<'this, DB>
where
    DB: Database,
{
    /// Create a new container with the given database handle.
    ///
    /// A new inner transaction will be opened.
    pub fn new(db: &'this DB) -> Result<Self, Error> {
        Ok(Self { db, tx: Some(db.tx_mut()?) })
    }

    /// Commit the current inner transaction and open a new one.
    ///
    /// # Panics
    ///
    /// Panics if an inner transaction does not exist. This should never be the case unless
    /// [StageDB::close] was called without following up with a call to [StageDB::open].
    pub fn commit(&mut self) -> Result<bool, Error> {
        let success =
            self.tx.take().expect("Tried committing a non-existent transaction").commit()?;
        self.tx = Some(self.db.tx_mut()?);
        Ok(success)
    }

    /// Open a new inner transaction.
    ///
    /// If a transaction is already open it is replaced and dropped without being committed.
    pub fn open(&mut self) -> Result<(), Error> {
        self.tx = Some(self.db.tx_mut()?);
        Ok(())
    }

    /// Close the current inner transaction without committing it.
    pub fn close(&mut self) {
        self.tx.take();
    }

    /// Query [tables::CanonicalHeaders] table for block hash by block number.
    pub(crate) fn get_block_hash(&self, number: BlockNumber) -> Result<BlockHash, StageError> {
        let hash = self
            .get::<tables::CanonicalHeaders>(number)?
            .ok_or(DatabaseIntegrityError::CanonicalHash { number })?;
        Ok(hash)
    }

    /// Query for block hash by block number and return it as a [BlockNumHash] key.
    pub(crate) fn get_block_numhash(
        &self,
        number: BlockNumber,
    ) -> Result<BlockNumHash, StageError> {
        Ok((number, self.get_block_hash(number)?).into())
    }

    /// Query [tables::CumulativeTxCount] table for the total transaction
    /// count of a block by [BlockNumHash] key.
    pub(crate) fn get_tx_count(&self, key: BlockNumHash) -> Result<NumTransactions, StageError> {
        let count = self.get::<tables::CumulativeTxCount>(key)?.ok_or(
            DatabaseIntegrityError::CumulativeTxCount { number: key.number(), hash: key.hash() },
        )?;
        Ok(count)
    }

    /// Unwind table by some number key.
    // NOTE: the `DB: Database` bound is already imposed by the impl block;
    // repeating it per-method was redundant and has been removed.
    #[inline]
    pub(crate) fn unwind_table_by_num<T>(&self, num: u64) -> Result<(), Error>
    where
        T: Table<Key = u64>,
    {
        self.unwind_table::<T, _>(num, |key| key)
    }

    /// Unwind table by composite block number hash key.
    #[inline]
    pub(crate) fn unwind_table_by_num_hash<T>(&self, block: BlockNumber) -> Result<(), Error>
    where
        T: Table<Key = BlockNumHash>,
    {
        self.unwind_table::<T, _>(block, |key| key.number())
    }

    /// Unwind the table to a provided block.
    ///
    /// Walks the table backwards from its last entry, deleting every entry whose
    /// `selector(key)` is greater than `block`; entries at or below `block` are kept.
    pub(crate) fn unwind_table<T, F>(
        &self,
        block: BlockNumber,
        mut selector: F,
    ) -> Result<(), Error>
    where
        T: Table,
        F: FnMut(T::Key) -> BlockNumber,
    {
        let mut cursor = self.cursor_mut::<T>()?;
        let mut entry = cursor.last()?;
        while let Some((key, _)) = entry {
            if selector(key) <= block {
                break
            }
            cursor.delete_current()?;
            entry = cursor.prev()?;
        }
        Ok(())
    }

    /// Unwind a table forward by walking `T1` from `start_at` and deleting the
    /// `T2` entry keyed by each encountered `T1` value.
    pub(crate) fn unwind_table_by_walker<T1, T2>(&self, start_at: T1::Key) -> Result<(), Error>
    where
        T1: Table,
        T2: Table<Key = T1::Value>,
    {
        let mut cursor = self.cursor_mut::<T1>()?;
        let mut walker = cursor.walk(start_at)?;
        while let Some((_, value)) = walker.next().transpose()? {
            self.delete::<T2>(value, None)?;
        }
        Ok(())
    }
}

View File

@@ -14,6 +14,7 @@
//!
//! - `stage.progress{stage}`: The block number each stage has currently reached.
mod db;
mod error;
mod id;
mod pipeline;

View File

@@ -1,11 +1,12 @@
use crate::{
error::*, util::opt::MaybeSender, ExecInput, ExecOutput, Stage, StageError, StageId,
UnwindInput,
db::StageDB, error::*, util::opt::MaybeSender, ExecInput, ExecOutput, Stage, StageError,
StageId, UnwindInput,
};
use reth_interfaces::db::{DBContainer, Database, DbTx};
use reth_interfaces::db::{Database, DbTx};
use reth_primitives::BlockNumber;
use std::{
fmt::{Debug, Formatter},
ops::Deref,
sync::Arc,
};
use tokio::sync::mpsc::Sender;
@@ -221,14 +222,14 @@ impl<DB: Database> Pipeline<DB> {
};
// Unwind stages in reverse order of priority (i.e. higher priority = first)
let mut db = DBContainer::new(db)?;
let mut db = StageDB::new(db)?;
for (_, QueuedStage { stage, .. }) in unwind_pipeline.iter_mut() {
let stage_id = stage.id();
let span = info_span!("Unwinding", stage = %stage_id);
let _enter = span.enter();
let mut stage_progress = stage_id.get_progress(db.get())?.unwrap_or_default();
let mut stage_progress = stage_id.get_progress(db.deref())?.unwrap_or_default();
if stage_progress < to {
debug!(from = %stage_progress, %to, "Unwind point too far for stage");
self.events_sender.send(PipelineEvent::Skipped { stage_id }).await?;
@@ -244,7 +245,7 @@ impl<DB: Database> Pipeline<DB> {
match output {
Ok(unwind_output) => {
stage_progress = unwind_output.stage_progress;
stage_id.save_progress(db.get(), stage_progress)?;
stage_id.save_progress(db.deref(), stage_progress)?;
self.events_sender
.send(PipelineEvent::Unwound { stage_id, result: unwind_output })
@@ -293,9 +294,9 @@ impl<DB: Database> QueuedStage<DB> {
}
loop {
let mut db = DBContainer::new(db)?;
let mut db = StageDB::new(db)?;
let prev_progress = stage_id.get_progress(db.get())?;
let prev_progress = stage_id.get_progress(db.deref())?;
let stage_reached_max_block = prev_progress
.zip(state.max_block)
@@ -321,7 +322,7 @@ impl<DB: Database> QueuedStage<DB> {
{
Ok(out @ ExecOutput { stage_progress, done, reached_tip }) => {
debug!(stage = %stage_id, %stage_progress, %done, "Stage made progress");
stage_id.save_progress(db.get(), stage_progress)?;
stage_id.save_progress(db.deref(), stage_progress)?;
state
.events_sender
@@ -762,7 +763,7 @@ mod tests {
async fn execute(
&mut self,
_: &mut DBContainer<'_, DB>,
_: &mut StageDB<'_, DB>,
_input: ExecInput,
) -> Result<ExecOutput, StageError> {
self.exec_outputs
@@ -772,7 +773,7 @@ mod tests {
async fn unwind(
&mut self,
_: &mut DBContainer<'_, DB>,
_: &mut StageDB<'_, DB>,
_input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
self.unwind_outputs

View File

@@ -1,6 +1,6 @@
use crate::{error::StageError, id::StageId};
use crate::{db::StageDB, error::StageError, id::StageId};
use async_trait::async_trait;
use reth_interfaces::db::{DBContainer, Database};
use reth_interfaces::db::Database;
use reth_primitives::BlockNumber;
/// Stage execution input, see [Stage::execute].
@@ -58,8 +58,8 @@ pub struct UnwindOutput {
///
/// Stages are executed as part of a pipeline where they are executed serially.
///
/// Stages receive a [`DBContainer`] which manages the lifecycle of a transaction, such
/// as when to commit / reopen a new one etc.
/// Stages receive [`StageDB`] which manages the lifecycle of a transaction,
/// such as when to commit / reopen a new one etc.
#[async_trait]
pub trait Stage<DB: Database>: Send + Sync {
/// Get the ID of the stage.
@@ -70,14 +70,14 @@ pub trait Stage<DB: Database>: Send + Sync {
/// Execute the stage.
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError>;
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>>;
}

View File

@@ -1,13 +1,13 @@
use crate::{
DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput,
UnwindOutput,
db::StageDB, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId,
UnwindInput, UnwindOutput,
};
use futures_util::TryStreamExt;
use reth_interfaces::{
consensus::Consensus,
db::{
models::StoredBlockBody, tables, DBContainer, Database, DatabaseGAT, DbCursorRO,
DbCursorRW, DbTx, DbTxMut,
models::StoredBlockBody, tables, Database, DatabaseGAT, DbCursorRO, DbCursorRW, DbTx,
DbTxMut,
},
p2p::bodies::downloader::BodyDownloader,
};
@@ -72,11 +72,9 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
/// header, limited by the stage's batch size.
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
let previous_stage_progress = input.previous_stage_progress();
if previous_stage_progress == 0 {
warn!("The body stage seems to be running first, no work can be completed.");
@@ -95,18 +93,18 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
return Ok(ExecOutput { stage_progress: target, reached_tip: true, done: true })
}
let bodies_to_download = self.bodies_to_download::<DB>(tx, starting_block, target)?;
let bodies_to_download = self.bodies_to_download::<DB>(db, starting_block, target)?;
// Cursors used to write bodies and transactions
let mut bodies_cursor = tx.cursor_mut::<tables::BlockBodies>()?;
let mut tx_cursor = tx.cursor_mut::<tables::Transactions>()?;
let mut bodies_cursor = db.cursor_mut::<tables::BlockBodies>()?;
let mut tx_cursor = db.cursor_mut::<tables::Transactions>()?;
let mut base_tx_id = bodies_cursor
.last()?
.map(|(_, body)| body.base_tx_id + body.tx_amount)
.ok_or(DatabaseIntegrityError::BlockBody { number: starting_block })?;
// Cursor used to look up headers for block pre-validation
let mut header_cursor = tx.cursor::<tables::Headers>()?;
let mut header_cursor = db.cursor::<tables::Headers>()?;
// NOTE(onbjerg): The stream needs to live here otherwise it will just create a new iterator
// on every iteration of the while loop -_-
@@ -167,12 +165,11 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
let tx = db.get_mut();
let mut block_body_cursor = tx.cursor_mut::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor_mut::<tables::Transactions>()?;
let mut block_body_cursor = db.cursor_mut::<tables::BlockBodies>()?;
let mut transaction_cursor = db.cursor_mut::<tables::Transactions>()?;
let mut entry = block_body_cursor.last()?;
while let Some((key, body)) = entry {
@@ -457,7 +454,7 @@ mod tests {
use crate::{
stages::bodies::BodyStage,
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
},
ExecInput, ExecOutput, UnwindInput,
@@ -496,7 +493,7 @@ mod tests {
pub(crate) struct BodyTestRunner {
pub(crate) consensus: Arc<TestConsensus>,
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
db: StageTestDB,
db: TestStageDB,
batch_size: u64,
}
@@ -505,7 +502,7 @@ mod tests {
Self {
consensus: Arc::new(TestConsensus::default()),
responses: HashMap::default(),
db: StageTestDB::default(),
db: TestStageDB::default(),
batch_size: 1000,
}
}
@@ -527,7 +524,7 @@ mod tests {
impl StageTestRunner for BodyTestRunner {
type S = BodyStage<TestBodyDownloader, TestConsensus>;
fn db(&self) -> &StageTestDB {
fn db(&self) -> &TestStageDB {
&self.db
}

View File

@@ -1,14 +1,10 @@
use crate::{
util::unwind::{unwind_table_by_num, unwind_table_by_num_hash, unwind_table_by_walker},
DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput,
UnwindOutput,
db::StageDB, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId,
UnwindInput, UnwindOutput,
};
use reth_interfaces::{
consensus::{Consensus, ForkchoiceState},
db::{
models::blocks::BlockNumHash, tables, DBContainer, Database, DatabaseGAT, DbCursorRO,
DbCursorRW, DbTx, DbTxMut,
},
db::{models::blocks::BlockNumHash, tables, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut},
p2p::headers::{client::HeadersClient, downloader::HeaderDownloader, error::DownloadError},
};
use reth_primitives::{rpc::BigEndianHash, BlockNumber, SealedHeader, H256, U256};
@@ -51,20 +47,17 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient> Stage<DB
/// starting from the tip
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
let last_block_num = input.stage_progress.unwrap_or_default();
self.update_head::<DB>(tx, last_block_num).await?;
self.update_head::<DB>(db, last_block_num).await?;
// TODO: add batch size
// download the headers
let last_hash = tx
.get::<tables::CanonicalHeaders>(last_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?;
let last_hash = db.get_block_hash(last_block_num)?;
let last_header =
tx.get::<tables::Headers>((last_block_num, last_hash).into())?.ok_or({
db.get::<tables::Headers>((last_block_num, last_hash).into())?.ok_or({
DatabaseIntegrityError::Header { number: last_block_num, hash: last_hash }
})?;
let head = SealedHeader::new(last_header, last_hash);
@@ -96,27 +89,23 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient> Stage<DB
_ => unreachable!(),
},
};
let stage_progress = self.write_headers::<DB>(tx, headers).await?.unwrap_or(last_block_num);
let stage_progress = self.write_headers::<DB>(db, headers).await?.unwrap_or(last_block_num);
Ok(ExecOutput { stage_progress, reached_tip: true, done: true })
}
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
// TODO: handle bad block
let tx = db.get_mut();
unwind_table_by_walker::<DB, tables::CanonicalHeaders, tables::HeaderNumbers>(
tx,
db.unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
input.unwind_to + 1,
)?;
unwind_table_by_num::<DB, tables::CanonicalHeaders>(tx, input.unwind_to)?;
unwind_table_by_num_hash::<DB, tables::Headers>(tx, input.unwind_to)?;
unwind_table_by_num_hash::<DB, tables::HeaderTD>(tx, input.unwind_to)?;
db.unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
db.unwind_table_by_num_hash::<tables::Headers>(input.unwind_to)?;
db.unwind_table_by_num_hash::<tables::HeaderTD>(input.unwind_to)?;
Ok(UnwindOutput { stage_progress: input.unwind_to })
}
}
@@ -124,14 +113,12 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient> Stage<DB
impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
async fn update_head<DB: Database>(
&self,
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
db: &StageDB<'_, DB>,
height: BlockNumber,
) -> Result<(), StageError> {
let hash = tx
.get::<tables::CanonicalHeaders>(height)?
.ok_or(DatabaseIntegrityError::CanonicalHeader { number: height })?;
let td: Vec<u8> = tx.get::<tables::HeaderTD>((height, hash).into())?.unwrap(); // TODO:
self.client.update_status(height, hash, H256::from_slice(&td).into_uint());
let block_key = db.get_block_numhash(height)?;
let td: Vec<u8> = db.get::<tables::HeaderTD>(block_key)?.unwrap(); // TODO:
self.client.update_status(height, block_key.hash(), H256::from_slice(&td).into_uint());
Ok(())
}
@@ -149,12 +136,12 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
/// Write downloaded headers to the database
async fn write_headers<DB: Database>(
&self,
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
db: &StageDB<'_, DB>,
headers: Vec<SealedHeader>,
) -> Result<Option<BlockNumber>, StageError> {
let mut cursor_header = tx.cursor_mut::<tables::Headers>()?;
let mut cursor_canonical = tx.cursor_mut::<tables::CanonicalHeaders>()?;
let mut cursor_td = tx.cursor_mut::<tables::HeaderTD>()?;
let mut cursor_header = db.cursor_mut::<tables::Headers>()?;
let mut cursor_canonical = db.cursor_mut::<tables::CanonicalHeaders>()?;
let mut cursor_td = db.cursor_mut::<tables::HeaderTD>()?;
let mut td = U256::from_big_endian(&cursor_td.last()?.map(|(_, v)| v).unwrap());
let mut latest = None;
@@ -174,7 +161,7 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
// TODO: investigate default write flags
// NOTE: HeaderNumbers are not sorted and can't be inserted with cursor.
tx.put::<tables::HeaderNumbers>(block_hash, header.number)?;
db.put::<tables::HeaderNumbers>(block_hash, header.number)?;
cursor_header.append(key, header)?;
cursor_canonical.append(key.number(), key.hash())?;
cursor_td.append(key, H256::from_uint(&td).as_bytes().to_vec())?;
@@ -264,7 +251,7 @@ mod tests {
use crate::{
stages::headers::HeaderStage,
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
},
ExecInput, ExecOutput, UnwindInput,
@@ -285,7 +272,7 @@ mod tests {
pub(crate) consensus: Arc<TestConsensus>,
pub(crate) client: Arc<TestHeadersClient>,
downloader: Arc<D>,
db: StageTestDB,
db: TestStageDB,
}
impl Default for HeadersTestRunner<TestHeaderDownloader> {
@@ -296,7 +283,7 @@ mod tests {
client: client.clone(),
consensus: consensus.clone(),
downloader: Arc::new(TestHeaderDownloader::new(client, consensus, 1000)),
db: StageTestDB::default(),
db: TestStageDB::default(),
}
}
}
@@ -304,7 +291,7 @@ mod tests {
impl<D: HeaderDownloader + 'static> StageTestRunner for HeadersTestRunner<D> {
type S = HeaderStage<Arc<D>, TestConsensus, TestHeadersClient>;
fn db(&self) -> &StageTestDB {
fn db(&self) -> &TestStageDB {
&self.db
}
@@ -412,7 +399,7 @@ mod tests {
let downloader = Arc::new(
LinearDownloadBuilder::default().build(consensus.clone(), client.clone()),
);
Self { client, consensus, downloader, db: StageTestDB::default() }
Self { client, consensus, downloader, db: TestStageDB::default() }
}
}

View File

@@ -1,12 +1,9 @@
use crate::{
util::unwind::unwind_table_by_num, DatabaseIntegrityError, ExecInput, ExecOutput, Stage,
StageError, StageId, UnwindInput, UnwindOutput,
db::StageDB, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput,
};
use itertools::Itertools;
use rayon::prelude::*;
use reth_interfaces::db::{
self, tables, DBContainer, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut,
};
use reth_interfaces::db::{self, tables, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut};
use reth_primitives::TxNumber;
use std::fmt::Debug;
use thiserror::Error;
@@ -48,40 +45,22 @@ impl<DB: Database> Stage<DB> for SendersStage {
/// the [`TxSenders`][reth_interfaces::db::tables::TxSenders] table.
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
// Look up the start index for transaction range
let last_block_num = input.stage_progress.unwrap_or_default();
let last_block_hash = tx
.get::<tables::CanonicalHeaders>(last_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?;
let start_tx_index = tx
.get::<tables::CumulativeTxCount>((last_block_num, last_block_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: last_block_num,
hash: last_block_hash,
})?;
let last_block = db.get_block_numhash(input.stage_progress.unwrap_or_default())?;
let start_tx_index = db.get_tx_count(last_block)?;
// Look up the end index for transaction range (exclusive)
let max_block_num = input.previous_stage_progress();
let max_block_hash = tx
.get::<tables::CanonicalHeaders>(max_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: max_block_num })?;
let end_tx_index = tx
.get::<tables::CumulativeTxCount>((max_block_num, max_block_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: last_block_num,
hash: last_block_hash,
})?;
let max_block = db.get_block_numhash(input.previous_stage_progress())?;
let end_tx_index = db.get_tx_count(max_block)?;
// Acquire the cursor for inserting elements
let mut senders_cursor = tx.cursor_mut::<tables::TxSenders>()?;
let mut senders_cursor = db.cursor_mut::<tables::TxSenders>()?;
// Acquire the cursor over the transactions
let mut tx_cursor = tx.cursor::<tables::Transactions>()?;
let mut tx_cursor = db.cursor::<tables::Transactions>()?;
// Walk the transactions from start to end index (exclusive)
let entries = tx_cursor
.walk(start_tx_index)?
@@ -105,28 +84,20 @@ impl<DB: Database> Stage<DB> for SendersStage {
recovered.into_iter().try_for_each(|(id, sender)| senders_cursor.append(id, sender))?;
}
Ok(ExecOutput { stage_progress: max_block_num, done: true, reached_tip: true })
Ok(ExecOutput { stage_progress: max_block.number(), done: true, reached_tip: true })
}
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
let tx = db.get_mut();
// Look up the hash of the unwind block
if let Some(unwind_hash) = tx.get::<tables::CanonicalHeaders>(input.unwind_to)? {
if let Some(unwind_hash) = db.get::<tables::CanonicalHeaders>(input.unwind_to)? {
// Look up the cumulative tx count at unwind block
let latest_tx = tx
.get::<tables::CumulativeTxCount>((input.unwind_to, unwind_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: input.unwind_to,
hash: unwind_hash,
})?;
unwind_table_by_num::<DB, tables::TxSenders>(tx, latest_tx - 1)?;
let latest_tx = db.get_tx_count((input.unwind_to, unwind_hash).into())?;
db.unwind_table_by_num::<tables::TxSenders>(latest_tx - 1)?;
}
Ok(UnwindOutput { stage_progress: input.unwind_to })
@@ -142,7 +113,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
stage_test_suite, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
};
@@ -150,13 +121,13 @@ mod tests {
#[derive(Default)]
struct SendersTestRunner {
db: StageTestDB,
db: TestStageDB,
}
impl StageTestRunner for SendersTestRunner {
type S = SendersStage;
fn db(&self) -> &StageTestDB {
fn db(&self) -> &TestStageDB {
&self.db
}

View File

@@ -1,8 +1,8 @@
use crate::{
util::unwind::unwind_table_by_num_hash, DatabaseIntegrityError, ExecInput, ExecOutput, Stage,
StageError, StageId, UnwindInput, UnwindOutput,
db::StageDB, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId,
UnwindInput, UnwindOutput,
};
use reth_interfaces::db::{tables, DBContainer, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut};
use reth_interfaces::db::{tables, Database, DbCursorRO, DbCursorRW, DbTxMut};
use std::fmt::Debug;
const TX_INDEX: StageId = StageId("TxIndex");
@@ -28,36 +28,30 @@ impl<DB: Database> Stage<DB> for TxIndex {
/// Execute the stage
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
// The progress of this stage during last iteration
let last_block = input.stage_progress.unwrap_or_default();
let last_hash = tx
.get::<tables::CanonicalHeaders>(last_block)?
.ok_or(DatabaseIntegrityError::CanonicalHeader { number: last_block })?;
let last_block = db.get_block_numhash(input.stage_progress.unwrap_or_default())?;
// The start block for this iteration
let start_block = last_block + 1;
let start_hash = tx
.get::<tables::CanonicalHeaders>(start_block)?
.ok_or(DatabaseIntegrityError::CanonicalHeader { number: start_block })?;
let start_block = db.get_block_numhash(last_block.number() + 1)?;
// The maximum block that this stage should insert to
let max_block = input.previous_stage_progress();
// Get the cursor over the table
let mut cursor = tx.cursor_mut::<tables::CumulativeTxCount>()?;
let mut cursor = db.cursor_mut::<tables::CumulativeTxCount>()?;
// Find the last count that was inserted during previous iteration
let (_, mut count) = cursor.seek_exact((last_block, last_hash).into())?.ok_or(
DatabaseIntegrityError::CumulativeTxCount { number: last_block, hash: last_hash },
)?;
let (_, mut count) =
cursor.seek_exact(last_block)?.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: last_block.number(),
hash: last_block.hash(),
})?;
// Get the cursor over block bodies
let mut body_cursor = tx.cursor_mut::<tables::BlockBodies>()?;
let walker = body_cursor.walk((start_block, start_hash).into())?;
let mut body_cursor = db.cursor_mut::<tables::BlockBodies>()?;
let walker = body_cursor.walk(start_block)?;
// Walk the block body entries up to maximum block (including)
let entries = walker
@@ -76,10 +70,10 @@ impl<DB: Database> Stage<DB> for TxIndex {
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
db: &mut StageDB<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
unwind_table_by_num_hash::<DB, tables::CumulativeTxCount>(db.get_mut(), input.unwind_to)?;
db.unwind_table_by_num_hash::<tables::CumulativeTxCount>(input.unwind_to)?;
Ok(UnwindOutput { stage_progress: input.unwind_to })
}
}
@@ -88,11 +82,14 @@ impl<DB: Database> Stage<DB> for TxIndex {
mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
stage_test_suite, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
};
use reth_interfaces::{
db::models::{BlockNumHash, StoredBlockBody},
db::{
models::{BlockNumHash, StoredBlockBody},
DbTx,
},
test_utils::generators::random_header_range,
};
use reth_primitives::H256;
@@ -101,13 +98,13 @@ mod tests {
#[derive(Default)]
pub(crate) struct TxIndexTestRunner {
db: StageTestDB,
db: TestStageDB,
}
impl StageTestRunner for TxIndexTestRunner {
type S = TxIndex;
fn db(&self) -> &StageTestDB {
fn db(&self) -> &TestStageDB {
&self.db
}

View File

@@ -2,13 +2,13 @@ use crate::StageId;
mod macros;
mod runner;
mod stage_db;
mod test_db;
pub(crate) use macros::*;
pub(crate) use runner::{
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner,
};
pub(crate) use stage_db::StageTestDB;
pub(crate) use test_db::TestStageDB;
/// The previous test stage id mock used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");

View File

@@ -1,15 +1,14 @@
use reth_db::{kv::Env, mdbx::WriteMap};
use reth_interfaces::db::{self, DBContainer};
use std::borrow::Borrow;
use tokio::sync::oneshot;
use super::StageTestDB;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use super::TestStageDB;
use crate::{db::StageDB, ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
#[error("Database error occured.")]
Database(#[from] db::Error),
Database(#[from] reth_interfaces::db::Error),
#[error("Internal runner error occured.")]
Internal(#[from] Box<dyn std::error::Error>),
}
@@ -20,7 +19,7 @@ pub(crate) trait StageTestRunner {
type S: Stage<Env<WriteMap>> + 'static;
/// Return a reference to the database.
fn db(&self) -> &StageTestDB;
fn db(&self) -> &TestStageDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
@@ -43,9 +42,9 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
let (db, mut stage) = (self.db().inner_raw(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let mut db = StageDB::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
@@ -70,9 +69,9 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner {
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
let (db, mut stage) = (self.db().inner_raw(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let mut db = StageDB::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");

View File

@@ -3,11 +3,13 @@ use reth_db::{
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{
self, models::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table,
self, models::BlockNumHash, tables, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table,
};
use reth_primitives::{BigEndianHash, BlockNumber, SealedHeader, H256, U256};
use std::{borrow::Borrow, sync::Arc};
use crate::db::StageDB;
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
///
@@ -15,25 +17,25 @@ use std::{borrow::Borrow, sync::Arc};
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
pub(crate) struct TestStageDB {
db: Arc<Env<WriteMap>>,
}
impl Default for StageTestDB {
impl Default for TestStageDB {
/// Create a new instance of [StageTestDB]
fn default() -> Self {
Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
}
}
impl StageTestDB {
/// Return a database wrapped in [DBContainer].
fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
impl TestStageDB {
/// Return a database wrapped in [StageDB].
fn inner(&self) -> StageDB<'_, Env<WriteMap>> {
StageDB::new(self.db.borrow()).expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
pub(crate) fn inner_raw(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
@@ -42,9 +44,8 @@ impl StageTestDB {
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
{
let mut db = self.container();
let tx = db.get_mut();
f(tx)?;
let mut db = self.inner();
f(&mut db)?;
db.commit()?;
Ok(())
}
@@ -54,7 +55,7 @@ impl StageTestDB {
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
{
f(self.container().get())
f(&self.inner())
}
/// Check if the table is empty

View File

@@ -61,77 +61,3 @@ pub(crate) mod opt {
}
}
}
/// Legacy helpers for unwinding (rolling back) database tables to a given block.
/// Each helper operates on an open mutable transaction and deletes entries whose
/// key maps to a block number greater than the target.
pub(crate) mod unwind {
use reth_interfaces::db::{
models::BlockNumHash, Database, DatabaseGAT, DbCursorRO, DbCursorRW, DbTxMut, Error, Table,
};
use reth_primitives::BlockNumber;
/// Unwind table by some number key.
///
/// Thin wrapper over [`unwind_table`] for tables keyed directly by a `u64`
/// block number (the key itself is the selector).
#[inline]
pub(crate) fn unwind_table_by_num<DB, T>(
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
num: u64,
) -> Result<(), Error>
where
DB: Database,
T: Table<Key = u64>,
{
unwind_table::<DB, T, _>(tx, num, |key| key)
}
/// Unwind table by composite block number hash key.
///
/// Thin wrapper over [`unwind_table`] for tables keyed by [`BlockNumHash`];
/// only the block-number half of the composite key is compared.
#[inline]
pub(crate) fn unwind_table_by_num_hash<DB, T>(
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
block: BlockNumber,
) -> Result<(), Error>
where
DB: Database,
T: Table<Key = BlockNumHash>,
{
unwind_table::<DB, T, _>(tx, block, |key| key.number())
}
/// Unwind the table to a provided block.
///
/// Walks the table backwards from its last entry and deletes every row whose
/// key (mapped through `selector` to a block number) is greater than `block`.
/// Entries at `block` itself are kept — the target block is inclusive.
///
/// NOTE(review): correctness assumes cursor iteration order is ascending by
/// key, so `last()` yields the highest block and the first key `<= block`
/// terminates the scan — confirm against the MDBX cursor semantics.
pub(crate) fn unwind_table<DB, T, F>(
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
block: BlockNumber,
mut selector: F,
) -> Result<(), Error>
where
DB: Database,
T: Table,
F: FnMut(T::Key) -> BlockNumber,
{
let mut cursor = tx.cursor_mut::<T>()?;
// Start from the highest key and delete until we reach the target block.
let mut entry = cursor.last()?;
while let Some((key, _)) = entry {
if selector(key) <= block {
// Reached an entry at or below the target block; everything below stays.
break
}
cursor.delete_current()?;
entry = cursor.prev()?;
}
Ok(())
}
/// Unwind a table forward by a walker on another table.
///
/// Iterates `T1` starting at `start_at` and, for each visited entry, deletes
/// the `T2` row keyed by that entry's value (requires `T2::Key == T1::Value`).
/// `T1` itself is left untouched; only `T2` rows are removed.
pub(crate) fn unwind_table_by_walker<DB, T1, T2>(
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
start_at: T1::Key,
) -> Result<(), Error>
where
DB: Database,
T1: Table,
T2: Table<Key = T1::Value>,
{
let mut cursor = tx.cursor_mut::<T1>()?;
let mut walker = cursor.walk(start_at)?;
// `transpose()` turns Option<Result<..>> into Result<Option<..>> so `?`
// can propagate walker errors while the loop ends cleanly on None.
while let Some((_, value)) = walker.next().transpose()? {
// `None` value means: delete all entries under this key (no dup filter).
tx.delete::<T2>(value, None)?;
}
Ok(())
}
}