feat: support custom stages

Author: Arsenii Kulikov
Date:   2025-10-13 22:10:07 +04:00
parent 169a1fb97b
commit 2537cc40f9
11 changed files with 124 additions and 56 deletions

Cargo.lock (generated)

@@ -8892,6 +8892,7 @@ dependencies = [
  "reth-payload-builder-primitives",
  "reth-payload-primitives",
  "reth-provider",
+ "reth-stages-api",
  "reth-tasks",
  "reth-tokio-util",
  "reth-transaction-pool",


@@ -21,10 +21,10 @@ use reth_node_types::{BlockTy, NodeTypes};
 use reth_payload_builder::PayloadBuilderHandle;
 use reth_provider::{
     providers::{BlockchainProvider, ProviderNodeTypes},
-    ProviderFactory,
+    DatabaseProviderFactory, ProviderFactory,
 };
 use reth_prune::PrunerWithFactory;
-use reth_stages_api::{MetricEventsSender, Pipeline};
+use reth_stages_api::{BoxedStage, MetricEventsSender, Pipeline};
 use reth_tasks::TaskSpawner;
 use std::{
     pin::Pin,
@@ -84,6 +84,7 @@ where
         tree_config: TreeConfig,
         sync_metrics_tx: MetricEventsSender,
         evm_config: C,
+        custom_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
     ) -> Self
     where
         V: EngineValidator<N::Payload>,
@@ -94,8 +95,12 @@ where
         let downloader = BasicBlockDownloader::new(client, consensus.clone());
-        let persistence_handle =
-            PersistenceHandle::<EthPrimitives>::spawn_service(provider, pruner, sync_metrics_tx);
+        let persistence_handle = PersistenceHandle::<EthPrimitives>::spawn_service(
+            provider,
+            pruner,
+            sync_metrics_tx,
+            custom_stages,
+        );
         let canonical_in_memory_state = blockchain_db.canonical_in_memory_state();
@@ -214,6 +219,7 @@ mod tests {
             TreeConfig::default(),
             sync_metrics_tx,
             evm_config,
+            Default::default(),
         );
     }
 }


@@ -7,9 +7,11 @@ use reth_ethereum_primitives::EthPrimitives;
 use reth_primitives_traits::NodePrimitives;
 use reth_provider::{
     providers::ProviderNodeTypes, BlockExecutionWriter, BlockHashReader, ChainStateBlockWriter,
-    DBProvider, DatabaseProviderFactory, ProviderFactory,
+    DBProvider, DatabaseProviderFactory, ProviderFactory, StageCheckpointReader,
+    StageCheckpointWriter,
 };
 use reth_prune::{PrunerError, PrunerOutput, PrunerWithFactory};
+use reth_stages::{BoxedStage, ExecInput, ExecOutput, StageError, UnwindInput, UnwindOutput};
 use reth_stages_api::{MetricEvent, MetricEventsSender};
 use std::{
     sync::mpsc::{Receiver, SendError, Sender},
@@ -26,7 +28,7 @@ use tracing::{debug, error};
 ///
 /// This should be spawned in its own thread with [`std::thread::spawn`], since this performs
 /// blocking I/O operations in an endless loop.
-#[derive(Debug)]
+#[expect(missing_debug_implementations)]
 pub struct PersistenceService<N>
 where
     N: ProviderNodeTypes,
@@ -41,6 +43,8 @@ where
     metrics: PersistenceMetrics,
     /// Sender for sync metrics - we only submit sync metrics for persisted blocks
     sync_metrics_tx: MetricEventsSender,
+    /// Custom pipeline stages advanced on new blocks.
+    custom_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
 }

 impl<N> PersistenceService<N>
@@ -53,8 +57,16 @@ where
         incoming: Receiver<PersistenceAction<N::Primitives>>,
         pruner: PrunerWithFactory<ProviderFactory<N>>,
         sync_metrics_tx: MetricEventsSender,
+        custom_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
     ) -> Self {
-        Self { provider, incoming, pruner, metrics: PersistenceMetrics::default(), sync_metrics_tx }
+        Self {
+            provider,
+            incoming,
+            pruner,
+            metrics: PersistenceMetrics::default(),
+            sync_metrics_tx,
+            custom_stages,
+        }
     }

     /// Prunes block data before the given block number according to the configured prune
@@ -67,12 +79,7 @@ where
         self.metrics.prune_before_duration_seconds.record(start_time.elapsed());
         result
     }
-}

-impl<N> PersistenceService<N>
-where
-    N: ProviderNodeTypes,
-{
     /// This is the main loop, that will listen to database events and perform the requested
     /// database actions
     pub fn run(mut self) -> Result<(), PersistenceError> {
@@ -122,7 +129,7 @@ where
     }

     fn on_remove_blocks_above(
-        &self,
+        &mut self,
         new_tip_num: u64,
     ) -> Result<Option<BlockNumHash>, PersistenceError> {
         debug!(target: "engine::persistence", ?new_tip_num, "Removing blocks");
@@ -130,6 +137,17 @@ where
         let provider_rw = self.provider.database_provider_rw()?;
         let new_tip_hash = provider_rw.block_hash(new_tip_num)?;

+        for stage in self.custom_stages.iter_mut().rev() {
+            if let Some(checkpoint) = provider_rw.get_stage_checkpoint(stage.id())? {
+                let UnwindOutput { checkpoint } = stage.unwind(
+                    &provider_rw,
+                    UnwindInput { checkpoint, unwind_to: new_tip_num, bad_block: None },
+                )?;
+                provider_rw.save_stage_checkpoint(stage.id(), checkpoint)?;
+            }
+        }
+
         provider_rw.remove_block_and_execution_above(new_tip_num)?;
         provider_rw.commit()?;
@@ -139,7 +157,7 @@ where
     }

     fn on_save_blocks(
-        &self,
+        &mut self,
         blocks: Vec<ExecutedBlockWithTrieUpdates<N::Primitives>>,
     ) -> Result<Option<BlockNumHash>, PersistenceError> {
         debug!(target: "engine::persistence", first=?blocks.first().map(|b| b.recovered_block.num_hash()), last=?blocks.last().map(|b| b.recovered_block.num_hash()), "Saving range of blocks");
@@ -149,10 +167,28 @@
             number: block.recovered_block().header().number(),
         });

-        if last_block_hash_num.is_some() {
+        if let Some(num_hash) = last_block_hash_num {
             let provider_rw = self.provider.database_provider_rw()?;
             provider_rw.save_blocks(blocks)?;

+            for stage in &mut self.custom_stages {
+                loop {
+                    let checkpoint = provider_rw.get_stage_checkpoint(stage.id())?;
+                    let ExecOutput { checkpoint, done } = stage.execute(
+                        &provider_rw,
+                        ExecInput { target: Some(num_hash.number), checkpoint },
+                    )?;
+                    provider_rw.save_stage_checkpoint(stage.id(), checkpoint)?;
+                    if done {
+                        break
+                    }
+                }
+            }
+
             provider_rw.commit()?;
         }
         self.metrics.save_blocks_duration_seconds.record(start_time.elapsed());
@@ -170,6 +206,10 @@ pub enum PersistenceError {
     /// A provider error
     #[error(transparent)]
     ProviderError(#[from] ProviderError),
+
+    /// A stage error
+    #[error(transparent)]
+    StageError(#[from] StageError),
 }

 /// A signal to the persistence service that part of the tree state can be persisted.
@@ -213,6 +253,7 @@ impl<T: NodePrimitives> PersistenceHandle<T> {
         provider_factory: ProviderFactory<N>,
         pruner: PrunerWithFactory<ProviderFactory<N>>,
         sync_metrics_tx: MetricEventsSender,
+        custom_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
     ) -> PersistenceHandle<N::Primitives>
     where
         N: ProviderNodeTypes,
@@ -224,8 +265,13 @@ impl<T: NodePrimitives> PersistenceHandle<T> {
         let persistence_handle = PersistenceHandle::new(db_service_tx);

         // spawn the persistence service
-        let db_service =
-            PersistenceService::new(provider_factory, db_service_rx, pruner, sync_metrics_tx);
+        let db_service = PersistenceService::new(
+            provider_factory,
+            db_service_rx,
+            pruner,
+            sync_metrics_tx,
+            custom_stages,
+        );
         std::thread::Builder::new()
             .name("Persistence Service".to_string())
             .spawn(|| {
@@ -313,7 +359,12 @@ mod tests {
             Pruner::new_with_factory(provider.clone(), vec![], 5, 0, None, finished_exex_height_rx);
         let (sync_metrics_tx, _sync_metrics_rx) = unbounded_channel();
-        PersistenceHandle::<EthPrimitives>::spawn_service(provider, pruner, sync_metrics_tx)
+        PersistenceHandle::<EthPrimitives>::spawn_service(
+            provider,
+            pruner,
+            sync_metrics_tx,
+            Default::default(),
+        )
     }

     #[tokio::test]
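Note: a minimal sketch of a custom stage that this service could drive. It assumes only the `reth_stages_api` surface visible in this diff (`id`/`execute`/`unwind`, `ExecInput`/`ExecOutput`, `StageCheckpoint::new`, `ExecInput::target`) and that the trait's polling hooks have default implementations; `MyIndexStage` and the `StageId::Other("MyIndex")` name are hypothetical.

use reth_stages_api::{
    ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
};

/// Hypothetical custom stage driven by the persistence service above.
#[derive(Debug, Default)]
pub struct MyIndexStage;

impl<Provider> Stage<Provider> for MyIndexStage {
    fn id(&self) -> StageId {
        StageId::Other("MyIndex")
    }

    fn execute(&mut self, _provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        // Index data up to the target picked by `on_save_blocks` (the number of
        // the last saved block), then report `done: true` to end the loop above.
        Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
    }

    fn unwind(&mut self, _provider: &Provider, input: UnwindInput) -> Result<UnwindOutput, StageError> {
        // Drop indexed data above the new tip chosen by `on_remove_blocks_above`.
        Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
    }
}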


@@ -26,6 +26,7 @@ reth-tasks.workspace = true
 reth-network-api.workspace = true
 reth-node-types.workspace = true
 reth-node-core.workspace = true
+reth-stages-api.workspace = true
 reth-tokio-util.workspace = true
 alloy-rpc-types-engine.workspace = true


@@ -11,7 +11,8 @@ use reth_network_api::FullNetwork;
 use reth_node_core::node_config::NodeConfig;
 use reth_node_types::{NodeTypes, NodeTypesWithDBAdapter, TxTy};
 use reth_payload_builder::PayloadBuilderHandle;
-use reth_provider::FullProvider;
+use reth_provider::{DatabaseProviderFactory, FullProvider};
+use reth_stages_api::BoxedStage;
 use reth_tasks::TaskExecutor;
 use reth_tokio_util::EventSender;
 use reth_transaction_pool::{PoolTransaction, TransactionPool};
@@ -203,6 +204,13 @@ pub trait NodeAddOns<N: FullNodeComponents>: Send {
         self,
         ctx: AddOnsContext<'_, N>,
     ) -> impl Future<Output = eyre::Result<Self::Handle>> + Send;
+
+    /// Returns additional stages to be added to the pipeline.
+    fn extra_stages(
+        &self,
+    ) -> Vec<BoxedStage<<N::Provider as DatabaseProviderFactory>::ProviderRW>> {
+        Vec::new()
+    }
 }

 impl<N: FullNodeComponents> NodeAddOns<N> for () {
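Usage sketch (hypothetical, not part of this change): an add-ons type overrides `extra_stages` to register its own stages. `MyAddOns` is invented and the remaining `NodeAddOns` items are elided, so this shows the shape of the override rather than a complete impl.

impl<N: FullNodeComponents> NodeAddOns<N> for MyAddOns {
    // ... Handle type and launch_add_ons elided ...

    fn extra_stages(
        &self,
    ) -> Vec<BoxedStage<<N::Provider as DatabaseProviderFactory>::ProviderRW>> {
        // Boxed custom stages; the launcher hands these to both the pipeline
        // builder and the persistence service (see the launcher diff below).
        vec![Box::new(MyIndexStage)]
    }
}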


@@ -151,6 +151,7 @@ impl EngineNodeLauncher {
             ctx.components().evm_config().clone(),
             maybe_exex_manager_handle.clone().unwrap_or_else(ExExManagerHandle::empty),
             ctx.era_import_source(),
+            add_ons.extra_stages(),
         )?;

         // The new engine writes directly to static files. This ensures that they're up to the tip.
@@ -232,6 +233,7 @@
             engine_tree_config,
             ctx.sync_metrics_tx(),
             ctx.components().evm_config().clone(),
+            add_ons.extra_stages(),
         );

         info!(target: "reth::cli", "Consensus engine initialized");


@@ -16,11 +16,11 @@ use reth_network_p2p::{
     bodies::downloader::BodyDownloader, headers::downloader::HeaderDownloader, BlockClient,
 };
 use reth_node_api::HeaderTy;
-use reth_provider::{providers::ProviderNodeTypes, ProviderFactory};
+use reth_provider::{providers::ProviderNodeTypes, DatabaseProviderFactory, ProviderFactory};
 use reth_stages::{
     prelude::DefaultStages,
     stages::{EraImportSource, ExecutionStage},
-    Pipeline, StageSet,
+    BoxedStage, Pipeline, StageId, StageSet,
 };
 use reth_static_file::StaticFileProducer;
 use reth_tasks::TaskExecutor;
@@ -42,6 +42,7 @@ pub fn build_networked_pipeline<N, Client, Evm>(
     evm_config: Evm,
     exex_manager_handle: ExExManagerHandle<N::Primitives>,
     era_import_source: Option<EraImportSource>,
+    extra_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
 ) -> eyre::Result<Pipeline<N>>
 where
     N: ProviderNodeTypes,
@@ -70,6 +71,7 @@
         evm_config,
         exex_manager_handle,
         era_import_source,
+        extra_stages,
     )?;

     Ok(pipeline)
@@ -90,6 +92,7 @@ pub fn build_pipeline<N, H, B, Evm>(
     evm_config: Evm,
     exex_manager_handle: ExExManagerHandle<N::Primitives>,
     era_import_source: Option<EraImportSource>,
+    extra_stages: Vec<BoxedStage<<ProviderFactory<N> as DatabaseProviderFactory>::ProviderRW>>,
 ) -> eyre::Result<Pipeline<N>>
 where
     N: ProviderNodeTypes,
@@ -129,7 +132,8 @@ where
                 stage_config.execution.into(),
                 stage_config.execution_external_clean_threshold(),
                 exex_manager_handle,
-            )),
+            ))
+            .add_stages_before(extra_stages, StageId::Finish),
         )
         .build(provider_factory, static_file_producer);


@@ -35,7 +35,7 @@ use reth_errors::{ProviderResult, RethResult};
 pub use set::*;

 /// A container for a queued stage.
-pub(crate) type BoxedStage<DB> = Box<dyn Stage<DB>>;
+pub type BoxedStage<Provider> = Box<dyn Stage<Provider>>;

 /// The future that returns the owned pipeline and the result of the pipeline run. See
 /// [`Pipeline::run_as_fut`].


@@ -172,6 +172,24 @@ impl<Provider> StageSetBuilder<Provider> {
         self
     }

+    /// Adds given [`Stage`]s before the stage with the given [`StageId`].
+    ///
+    /// If the stage was already in the group, it is removed from its previous place.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the dependency stage is not in this set.
+    pub fn add_stages_before<S: Stage<Provider> + 'static>(
+        mut self,
+        stages: Vec<S>,
+        before: StageId,
+    ) -> Self {
+        for stage in stages {
+            self = self.add_before(stage, before);
+        }
+        self
+    }
+
     /// Adds the given [`Stage`] after the stage with the given [`StageId`].
     ///
     /// If the stage was already in the group, it is removed from its previous place.
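A usage sketch for the new helper, assuming (as `build_pipeline` above relies on when passing `Vec<BoxedStage<_>>`) that `Box<dyn Stage<Provider>>` itself implements `Stage<Provider>`:

use reth_stages_api::{BoxedStage, StageId, StageSetBuilder};

// Splice already-boxed custom stages in front of the Finish stage, mirroring
// the `add_stages_before(extra_stages, StageId::Finish)` call in build_pipeline.
// Per the docs above, this panics if the Finish stage is not already in the set.
fn with_extra_stages<Provider>(
    builder: StageSetBuilder<Provider>,
    extra: Vec<BoxedStage<Provider>>,
) -> StageSetBuilder<Provider> {
    builder.add_stages_before(extra, StageId::Finish)
}

Boxing keeps the stage list heterogeneous: `add_stages_before` takes a `Vec<S>` of one concrete stage type, so mixed custom stages must be passed as `BoxedStage`s.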


@@ -135,30 +135,6 @@
     }
 }

-impl<P, H, B, E> DefaultStages<P, H, B, E>
-where
-    E: ConfigureEvm,
-    H: HeaderDownloader,
-    B: BodyDownloader,
-{
-    /// Appends the default offline stages and default finish stage to the given builder.
-    pub fn add_offline_stages<Provider>(
-        default_offline: StageSetBuilder<Provider>,
-        evm_config: E,
-        consensus: Arc<dyn FullConsensus<E::Primitives, Error = ConsensusError>>,
-        stages_config: StageConfig,
-        prune_modes: PruneModes,
-    ) -> StageSetBuilder<Provider>
-    where
-        OfflineStages<E>: StageSet<Provider>,
-    {
-        StageSetBuilder::default()
-            .add_set(default_offline)
-            .add_set(OfflineStages::new(evm_config, consensus, stages_config, prune_modes))
-            .add_stage(FinishStage)
-    }
-}
-
 impl<P, H, B, E, Provider> StageSet<Provider> for DefaultStages<P, H, B, E>
 where
     P: HeaderSyncGapProvider + 'static,
@@ -169,13 +145,11 @@ where
     OfflineStages<E>: StageSet<Provider>,
 {
     fn builder(self) -> StageSetBuilder<Provider> {
-        Self::add_offline_stages(
-            self.online.builder(),
-            self.evm_config,
-            self.consensus,
-            self.stages_config.clone(),
-            self.prune_modes,
-        )
+        let Self { online, evm_config, consensus, stages_config, prune_modes } = self;
+        StageSetBuilder::default()
+            .add_set(online.builder())
+            .add_set(OfflineStages::new(evm_config, consensus, stages_config, prune_modes))
+            .add_stage(FinishStage)
     }
 }


@@ -25,7 +25,10 @@ use reth_stages_api::{
 };
 use reth_static_file_types::StaticFileSegment;
 use reth_storage_errors::provider::ProviderError;
-use std::task::{ready, Context, Poll};
+use std::{
+    fmt::Debug,
+    task::{ready, Context, Poll},
+};
 use tokio::sync::watch;
 use tracing::*;
@@ -186,7 +189,7 @@
 impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
 where
-    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
+    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory + Debug,
     P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
     D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
     <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,