From c7a57a703135582914cb77eeabdc3db179e0fbb2 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Sat, 18 Nov 2023 00:41:42 -0800 Subject: [PATCH] chore(pipeline): stage poll extension trait (#5484) --- Cargo.lock | 1 + bin/reth/src/stage/run.rs | 5 ++--- crates/stages/Cargo.toml | 1 + crates/stages/benches/criterion.rs | 7 ++++--- crates/stages/src/pipeline/mod.rs | 6 +++--- crates/stages/src/stage.rs | 14 ++++++++++++++ crates/stages/src/test_utils/runner.rs | 6 +++--- 7 files changed, 28 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 463dccf88c..89e2cd91e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6510,6 +6510,7 @@ dependencies = [ "aquamarine", "assert_matches", "async-trait", + "auto_impl", "criterion", "futures-util", "itertools 0.11.0", diff --git a/bin/reth/src/stage/run.rs b/bin/reth/src/stage/run.rs index 589bcbf7d0..6c82145190 100644 --- a/bin/reth/src/stage/run.rs +++ b/bin/reth/src/stage/run.rs @@ -12,7 +12,6 @@ use crate::{ version::SHORT_VERSION, }; use clap::Parser; -use futures::future::poll_fn; use reth_beacon_consensus::BeaconConsensus; use reth_config::Config; use reth_db::init_db; @@ -25,7 +24,7 @@ use reth_stages::{ IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, TransactionLookupStage, }, - ExecInput, Stage, UnwindInput, + ExecInput, Stage, StageExt, UnwindInput, }; use std::{any::Any, net::SocketAddr, path::PathBuf, sync::Arc}; use tracing::*; @@ -260,7 +259,7 @@ impl Command { }; loop { - poll_fn(|cx| exec_stage.poll_execute_ready(cx, input)).await?; + exec_stage.execute_ready(input).await?; let output = exec_stage.execute(&provider_rw, input)?; input.checkpoint = Some(output.checkpoint); diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 6da02ad00a..890bc135ed 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -50,6 +50,7 @@ aquamarine.workspace = true itertools.workspace = true rayon.workspace = true num-traits = "0.2.15" +auto_impl = "1" [dev-dependencies] # reth diff --git a/crates/stages/benches/criterion.rs b/crates/stages/benches/criterion.rs index ad210165cb..98979ca5a6 100644 --- a/crates/stages/benches/criterion.rs +++ b/crates/stages/benches/criterion.rs @@ -10,9 +10,9 @@ use reth_provider::ProviderFactory; use reth_stages::{ stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage}, test_utils::TestTransaction, - ExecInput, Stage, UnwindInput, + ExecInput, Stage, StageExt, UnwindInput, }; -use std::{future::poll_fn, path::PathBuf, sync::Arc}; +use std::{path::PathBuf, sync::Arc}; mod setup; use setup::StageRange; @@ -138,7 +138,8 @@ fn measure_stage_with_path( let mut stage = stage.clone(); let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - poll_fn(|cx| stage.poll_execute_ready(cx, input)) + stage + .execute_ready(input) .await .and_then(|_| stage.execute(&provider, input)) .unwrap(); diff --git a/crates/stages/src/pipeline/mod.rs b/crates/stages/src/pipeline/mod.rs index 718809abc6..06f487858d 100644 --- a/crates/stages/src/pipeline/mod.rs +++ b/crates/stages/src/pipeline/mod.rs @@ -1,6 +1,6 @@ use crate::{ error::*, BlockErrorKind, ExecInput, ExecOutput, MetricEvent, MetricEventsSender, Stage, - StageError, UnwindInput, + StageError, StageExt, UnwindInput, }; use futures_util::Future; use reth_db::database::Database; @@ -11,7 +11,7 @@ use reth_primitives::{ }; use reth_provider::{ProviderFactory, StageCheckpointReader, StageCheckpointWriter}; use reth_tokio_util::EventListeners; -use std::{future::poll_fn, pin::Pin, sync::Arc}; +use std::{pin::Pin, sync::Arc}; use tokio::sync::watch; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; @@ -370,7 +370,7 @@ where let exec_input = ExecInput { target, checkpoint: prev_checkpoint }; - if let Err(err) = poll_fn(|cx| stage.poll_execute_ready(cx, exec_input)).await { + if let Err(err) = stage.execute_ready(exec_input).await { self.listeners.notify(PipelineEvent::Error { stage_id }); match on_stage_error(&factory, stage_id, prev_checkpoint, err)? { Some(ctrl) => return Ok(ctrl), diff --git a/crates/stages/src/stage.rs b/crates/stages/src/stage.rs index 55a491a83c..1fc2b29c1d 100644 --- a/crates/stages/src/stage.rs +++ b/crates/stages/src/stage.rs @@ -7,6 +7,7 @@ use reth_primitives::{ use reth_provider::{BlockReader, DatabaseProviderRW, ProviderError, TransactionsProvider}; use std::{ cmp::{max, min}, + future::poll_fn, ops::{Range, RangeInclusive}, task::{Context, Poll}, }; @@ -189,6 +190,7 @@ pub struct UnwindOutput { /// Stages are executed as part of a pipeline where they are executed serially. /// /// Stages receive [`DatabaseProviderRW`]. +#[auto_impl::auto_impl(Box)] pub trait Stage: Send + Sync { /// Get the ID of the stage. /// @@ -243,3 +245,15 @@ pub trait Stage: Send + Sync { input: UnwindInput, ) -> Result; } + +/// [Stage] trait extension. +#[async_trait::async_trait] +pub trait StageExt: Stage { + /// Utility extension for the `Stage` trait that invokes `Stage::poll_execute_ready` + /// with [poll_fn] context. For more information see [Stage::poll_execute_ready]. + async fn execute_ready(&mut self, input: ExecInput) -> Result<(), StageError> { + poll_fn(|cx| self.poll_execute_ready(cx, input)).await + } +} + +impl> StageExt for S {} diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs index 96c44cacb4..0be375edcd 100644 --- a/crates/stages/src/test_utils/runner.rs +++ b/crates/stages/src/test_utils/runner.rs @@ -1,10 +1,10 @@ use super::TestTransaction; -use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; +use crate::{ExecInput, ExecOutput, Stage, StageError, StageExt, UnwindInput, UnwindOutput}; use reth_db::DatabaseEnv; use reth_interfaces::db::DatabaseError; use reth_primitives::MAINNET; use reth_provider::{ProviderError, ProviderFactory}; -use std::{borrow::Borrow, future::poll_fn, sync::Arc}; +use std::{borrow::Borrow, sync::Arc}; use tokio::sync::oneshot; #[derive(thiserror::Error, Debug)] @@ -49,7 +49,7 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner { tokio::spawn(async move { let factory = ProviderFactory::new(db.db(), MAINNET.clone()); - let result = poll_fn(|cx| stage.poll_execute_ready(cx, input)).await.and_then(|_| { + let result = stage.execute_ready(input).await.and_then(|_| { let provider_rw = factory.provider_rw().unwrap(); let result = stage.execute(&provider_rw, input); provider_rw.commit().expect("failed to commit");