chore(pipeline): stage poll extension trait (#5484)

This commit is contained in:
Roman Krasiuk
2023-11-18 00:41:42 -08:00
committed by GitHub
parent f29e04dadc
commit c7a57a7031
7 changed files with 28 additions and 12 deletions

1
Cargo.lock generated
View File

@@ -6510,6 +6510,7 @@ dependencies = [
"aquamarine",
"assert_matches",
"async-trait",
"auto_impl",
"criterion",
"futures-util",
"itertools 0.11.0",

View File

@@ -12,7 +12,6 @@ use crate::{
version::SHORT_VERSION,
};
use clap::Parser;
use futures::future::poll_fn;
use reth_beacon_consensus::BeaconConsensus;
use reth_config::Config;
use reth_db::init_db;
@@ -25,7 +24,7 @@ use reth_stages::{
IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage,
StorageHashingStage, TransactionLookupStage,
},
ExecInput, Stage, UnwindInput,
ExecInput, Stage, StageExt, UnwindInput,
};
use std::{any::Any, net::SocketAddr, path::PathBuf, sync::Arc};
use tracing::*;
@@ -260,7 +259,7 @@ impl Command {
};
loop {
poll_fn(|cx| exec_stage.poll_execute_ready(cx, input)).await?;
exec_stage.execute_ready(input).await?;
let output = exec_stage.execute(&provider_rw, input)?;
input.checkpoint = Some(output.checkpoint);

View File

@@ -50,6 +50,7 @@ aquamarine.workspace = true
itertools.workspace = true
rayon.workspace = true
num-traits = "0.2.15"
auto_impl = "1"
[dev-dependencies]
# reth

View File

@@ -10,9 +10,9 @@ use reth_provider::ProviderFactory;
use reth_stages::{
stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage},
test_utils::TestTransaction,
ExecInput, Stage, UnwindInput,
ExecInput, Stage, StageExt, UnwindInput,
};
use std::{future::poll_fn, path::PathBuf, sync::Arc};
use std::{path::PathBuf, sync::Arc};
mod setup;
use setup::StageRange;
@@ -138,7 +138,8 @@ fn measure_stage_with_path<F, S>(
let mut stage = stage.clone();
let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
poll_fn(|cx| stage.poll_execute_ready(cx, input))
stage
.execute_ready(input)
.await
.and_then(|_| stage.execute(&provider, input))
.unwrap();

View File

@@ -1,6 +1,6 @@
use crate::{
error::*, BlockErrorKind, ExecInput, ExecOutput, MetricEvent, MetricEventsSender, Stage,
StageError, UnwindInput,
StageError, StageExt, UnwindInput,
};
use futures_util::Future;
use reth_db::database::Database;
@@ -11,7 +11,7 @@ use reth_primitives::{
};
use reth_provider::{ProviderFactory, StageCheckpointReader, StageCheckpointWriter};
use reth_tokio_util::EventListeners;
use std::{future::poll_fn, pin::Pin, sync::Arc};
use std::{pin::Pin, sync::Arc};
use tokio::sync::watch;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::*;
@@ -370,7 +370,7 @@ where
let exec_input = ExecInput { target, checkpoint: prev_checkpoint };
if let Err(err) = poll_fn(|cx| stage.poll_execute_ready(cx, exec_input)).await {
if let Err(err) = stage.execute_ready(exec_input).await {
self.listeners.notify(PipelineEvent::Error { stage_id });
match on_stage_error(&factory, stage_id, prev_checkpoint, err)? {
Some(ctrl) => return Ok(ctrl),

View File

@@ -7,6 +7,7 @@ use reth_primitives::{
use reth_provider::{BlockReader, DatabaseProviderRW, ProviderError, TransactionsProvider};
use std::{
cmp::{max, min},
future::poll_fn,
ops::{Range, RangeInclusive},
task::{Context, Poll},
};
@@ -189,6 +190,7 @@ pub struct UnwindOutput {
/// Stages are executed as part of a pipeline where they are executed serially.
///
/// Stages receive [`DatabaseProviderRW`].
#[auto_impl::auto_impl(Box)]
pub trait Stage<DB: Database>: Send + Sync {
/// Get the ID of the stage.
///
@@ -243,3 +245,15 @@ pub trait Stage<DB: Database>: Send + Sync {
input: UnwindInput,
) -> Result<UnwindOutput, StageError>;
}
/// [Stage] trait extension.
#[async_trait::async_trait]
pub trait StageExt<DB: Database>: Stage<DB> {
/// Utility extension for the `Stage` trait that invokes `Stage::poll_execute_ready`
/// with [poll_fn] context. For more information see [Stage::poll_execute_ready].
async fn execute_ready(&mut self, input: ExecInput) -> Result<(), StageError> {
poll_fn(|cx| self.poll_execute_ready(cx, input)).await
}
}
impl<DB: Database, S: Stage<DB>> StageExt<DB> for S {}

View File

@@ -1,10 +1,10 @@
use super::TestTransaction;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use crate::{ExecInput, ExecOutput, Stage, StageError, StageExt, UnwindInput, UnwindOutput};
use reth_db::DatabaseEnv;
use reth_interfaces::db::DatabaseError;
use reth_primitives::MAINNET;
use reth_provider::{ProviderError, ProviderFactory};
use std::{borrow::Borrow, future::poll_fn, sync::Arc};
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
#[derive(thiserror::Error, Debug)]
@@ -49,7 +49,7 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
tokio::spawn(async move {
let factory = ProviderFactory::new(db.db(), MAINNET.clone());
let result = poll_fn(|cx| stage.poll_execute_ready(cx, input)).await.and_then(|_| {
let result = stage.execute_ready(input).await.and_then(|_| {
let provider_rw = factory.provider_rw().unwrap();
let result = stage.execute(&provider_rw, input);
provider_rw.commit().expect("failed to commit");