From 68e4ff1f7d9f5b40bc02ba7433077dcba0456783 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Wed, 11 Feb 2026 04:45:09 +0100 Subject: [PATCH] feat: global runtime (#21934) --- .github/scripts/check_wasm.sh | 1 + Cargo.lock | 3 + bin/reth/src/lib.rs | 3 +- bin/reth/src/ress.rs | 4 +- crates/cli/cli/src/lib.rs | 3 +- crates/cli/commands/src/common.rs | 8 +- crates/cli/commands/src/db/mod.rs | 22 +- crates/cli/commands/src/db/repair_trie.rs | 2 +- crates/cli/commands/src/prune.rs | 2 +- .../cli/commands/src/stage/dump/execution.rs | 2 + .../src/stage/dump/hashing_account.rs | 2 + .../src/stage/dump/hashing_storage.rs | 2 + crates/cli/commands/src/stage/dump/merkle.rs | 2 + crates/cli/runner/src/lib.rs | 152 ++- crates/e2e-test-utils/src/lib.rs | 5 +- crates/e2e-test-utils/src/setup_builder.rs | 10 +- crates/e2e-test-utils/src/setup_import.rs | 14 +- crates/e2e-test-utils/src/testsuite/setup.rs | 7 +- .../tests/e2e-testsuite/main.rs | 17 +- crates/e2e-test-utils/tests/rocksdb/main.rs | 14 +- crates/engine/service/src/service.rs | 1 + crates/engine/tree/Cargo.toml | 1 + crates/engine/tree/benches/state_root_task.rs | 5 +- crates/engine/tree/src/backfill.rs | 2 +- .../src/tree/payload_processor/executor.rs | 47 - .../tree/src/tree/payload_processor/mod.rs | 21 +- .../src/tree/payload_processor/multiproof.rs | 18 +- .../src/tree/payload_processor/prewarm.rs | 10 +- .../src/tree/payload_processor/sparse_trie.rs | 4 +- .../engine/tree/src/tree/payload_validator.rs | 11 +- crates/engine/tree/src/tree/tests.rs | 2 + crates/ethereum/cli/src/interface.rs | 4 +- crates/ethereum/node/Cargo.toml | 1 + crates/ethereum/node/src/node.rs | 25 +- crates/ethereum/node/tests/e2e/blobs.rs | 17 +- .../ethereum/node/tests/e2e/custom_genesis.rs | 4 +- crates/ethereum/node/tests/e2e/dev.rs | 12 +- crates/ethereum/node/tests/e2e/eth.rs | 23 +- .../node/tests/e2e/invalid_payload.rs | 6 +- crates/ethereum/node/tests/e2e/p2p.rs | 10 +- crates/ethereum/node/tests/e2e/pool.rs | 29 +- crates/ethereum/node/tests/e2e/prestate.rs | 7 +- crates/ethereum/node/tests/e2e/rpc.rs | 15 +- .../ethereum/node/tests/e2e/selfdestruct.rs | 8 +- crates/ethereum/node/tests/e2e/simulate.rs | 2 +- crates/ethereum/node/tests/it/builder.rs | 14 +- crates/ethereum/node/tests/it/testing.rs | 6 +- crates/ethereum/reth/Cargo.toml | 1 + crates/exex/test-utils/src/lib.rs | 15 +- crates/net/downloaders/Cargo.toml | 1 + crates/net/downloaders/src/bodies/task.rs | 2 +- crates/net/downloaders/src/headers/task.rs | 2 +- crates/net/network/Cargo.toml | 1 + crates/net/network/src/session/mod.rs | 2 +- crates/node/builder/Cargo.toml | 3 +- crates/node/builder/src/builder/mod.rs | 4 +- crates/node/builder/src/builder/states.rs | 6 +- crates/node/builder/src/components/payload.rs | 5 +- crates/node/builder/src/components/pool.rs | 2 +- crates/node/builder/src/launch/common.rs | 18 +- crates/node/builder/src/launch/debug.rs | 15 +- crates/node/builder/src/launch/engine.rs | 4 +- crates/node/builder/src/launch/exex.rs | 6 +- crates/node/builder/src/rpc.rs | 3 +- crates/node/metrics/src/server.rs | 12 +- crates/payload/basic/src/lib.rs | 6 +- crates/ress/provider/src/lib.rs | 2 +- crates/rpc/rpc-engine-api/src/engine_api.rs | 4 +- .../rpc-eth-api/src/helpers/blocking_task.rs | 8 +- crates/rpc/rpc-eth-types/src/cache/mod.rs | 71 +- crates/rpc/rpc/src/debug.rs | 2 +- crates/rpc/rpc/src/eth/builder.rs | 2 +- crates/rpc/rpc/src/eth/core.rs | 2 +- crates/rpc/rpc/src/eth/filter.rs | 4 +- crates/rpc/rpc/src/eth/pubsub.rs | 2 +- crates/rpc/rpc/src/reth.rs | 8 +- crates/rpc/rpc/src/validation.rs | 6 +- crates/stages/api/Cargo.toml | 2 + crates/stages/api/src/stage.rs | 1 + crates/stages/stages/Cargo.toml | 1 + .../stages/stages/src/test_utils/test_db.rs | 2 + crates/storage/db-common/Cargo.toml | 1 + crates/storage/db-common/src/init.rs | 1 + crates/storage/provider/Cargo.toml | 3 +- crates/storage/provider/src/lib.rs | 3 - .../src/providers/database/builder.rs | 75 +- .../provider/src/providers/database/mod.rs | 17 +- .../src/providers/database/provider.rs | 20 +- .../src/providers/rocksdb/provider.rs | 8 +- .../src/providers/static_file/manager.rs | 5 +- .../provider/src/storage_threadpool.rs | 23 - crates/storage/provider/src/test_utils/mod.rs | 1 + crates/tasks/Cargo.toml | 3 +- crates/tasks/src/lib.rs | 649 ++---------- crates/tasks/src/runtime.rs | 928 ++++++++++++++++++ crates/transaction-pool/Cargo.toml | 1 + crates/transaction-pool/src/lib.rs | 9 +- crates/transaction-pool/src/maintain.rs | 18 +- crates/transaction-pool/src/validate/eth.rs | 4 +- crates/trie/parallel/Cargo.toml | 3 +- crates/trie/parallel/benches/root.rs | 6 +- crates/trie/parallel/src/proof.rs | 7 +- crates/trie/parallel/src/proof_task.rs | 40 +- crates/trie/parallel/src/root.rs | 44 +- .../sdk/examples/standalone-components.mdx | 6 +- .../beacon-api-sidecar-fetcher/src/main.rs | 2 +- examples/beacon-api-sse/src/main.rs | 2 +- examples/custom-dev-node/src/main.rs | 6 +- examples/custom-engine-types/src/main.rs | 6 +- examples/custom-evm/src/main.rs | 6 +- examples/custom-inspector/src/main.rs | 2 +- examples/custom-node-components/src/main.rs | 2 +- examples/custom-payload-builder/src/main.rs | 2 +- examples/custom-rlpx-subprotocol/src/main.rs | 2 +- examples/db-access/src/main.rs | 8 +- examples/full-contract-state/src/main.rs | 8 +- examples/precompile-cache/src/main.rs | 6 +- examples/rpc-db/src/main.rs | 4 +- examples/txpool-tracing/src/main.rs | 2 +- 119 files changed, 1612 insertions(+), 1126 deletions(-) delete mode 100644 crates/engine/tree/src/tree/payload_processor/executor.rs delete mode 100644 crates/storage/provider/src/storage_threadpool.rs create mode 100644 crates/tasks/src/runtime.rs diff --git a/.github/scripts/check_wasm.sh b/.github/scripts/check_wasm.sh index 3472ac9e38..20376ea363 100755 --- a/.github/scripts/check_wasm.sh +++ b/.github/scripts/check_wasm.sh @@ -63,6 +63,7 @@ exclude_crates=( reth-provider # tokio reth-prune # tokio reth-prune-static-files # reth-provider + reth-tasks # tokio rt-multi-thread reth-stages-api # reth-provider, reth-prune reth-static-file # tokio reth-transaction-pool # c-kzg diff --git a/Cargo.lock b/Cargo.lock index 4c3b0443c7..a64f94d174 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8023,6 +8023,7 @@ dependencies = [ "reth-provider", "reth-stages-types", "reth-static-file-types", + "reth-tasks", "reth-trie", "reth-trie-db", "serde", @@ -10266,6 +10267,7 @@ dependencies = [ "reth-stages-types", "reth-static-file", "reth-static-file-types", + "reth-tasks", "reth-testing-utils", "reth-tokio-util", "thiserror 2.0.18", @@ -10671,6 +10673,7 @@ dependencies = [ "reth-primitives-traits", "reth-provider", "reth-storage-errors", + "reth-tasks", "reth-trie", "reth-trie-common", "reth-trie-db", diff --git a/bin/reth/src/lib.rs b/bin/reth/src/lib.rs index adaa87cec9..e7bc42a28e 100644 --- a/bin/reth/src/lib.rs +++ b/bin/reth/src/lib.rs @@ -15,7 +15,6 @@ //! - `asm-keccak`: Replaces the default, pure-Rust implementation of Keccak256 with one implemented //! in assembly; see [the `keccak-asm` crate](https://github.com/DaniPopes/keccak-asm) for more //! details and supported targets. -//! - `min-debug-logs`: Disables all logs below `debug` level. //! //! ### Allocator Features //! @@ -211,7 +210,7 @@ pub mod ress; // re-export for convenience #[doc(inline)] -pub use reth_cli_runner::{tokio_runtime, CliContext, CliRunner}; +pub use reth_cli_runner::{CliContext, CliRunner}; // for rendering diagrams use aquamarine as _; diff --git a/bin/reth/src/ress.rs b/bin/reth/src/ress.rs index 88d3e2aa69..c2cfa74c2a 100644 --- a/bin/reth/src/ress.rs +++ b/bin/reth/src/ress.rs @@ -30,7 +30,7 @@ where let pending_state = PendingState::default(); // Spawn maintenance task for pending state. - task_executor.spawn(maintain_pending_state( + task_executor.spawn_task(maintain_pending_state( engine_events, provider.clone(), pending_state.clone(), @@ -58,7 +58,7 @@ where ); info!(target: "reth::cli", "Ress subprotocol support enabled"); - task_executor.spawn(async move { + task_executor.spawn_task(async move { while let Some(event) = rx.recv().await { trace!(target: "reth::ress", ?event, "Received ress event"); } diff --git a/crates/cli/cli/src/lib.rs b/crates/cli/cli/src/lib.rs index 43dde328f1..8ef61c130d 100644 --- a/crates/cli/cli/src/lib.rs +++ b/crates/cli/cli/src/lib.rs @@ -66,7 +66,8 @@ pub trait RethCli: Sized { F: FnOnce(Self, CliRunner) -> R, { let cli = Self::parse_args()?; - let runner = CliRunner::try_default_runtime()?; + let runner = CliRunner::try_default_runtime() + .map_err(|e| Error::raw(clap::error::ErrorKind::Io, e))?; Ok(cli.with_runner(f, runner)) } diff --git a/crates/cli/commands/src/common.rs b/crates/cli/commands/src/common.rs index d2ae5d228a..44dedb9775 100644 --- a/crates/cli/commands/src/common.rs +++ b/crates/cli/commands/src/common.rs @@ -127,10 +127,14 @@ impl EnvironmentArgs { /// Initializes environment according to [`AccessRights`] and returns an instance of /// [`Environment`]. + /// + /// Internally builds a [`reth_tasks::Runtime`] attached to the current tokio handle for + /// parallel storage I/O. pub fn init(&self, access: AccessRights) -> eyre::Result> where C: ChainSpecParser, { + let runtime = reth_tasks::Runtime::with_existing_handle(tokio::runtime::Handle::current())?; let data_dir = self.datadir.clone().resolve_datadir(self.chain.chain()); let db_path = data_dir.db(); let sf_path = data_dir.static_files(); @@ -186,7 +190,7 @@ impl EnvironmentArgs { .build()?; let provider_factory = - self.create_provider_factory(&config, db, sfp, rocksdb_provider, access)?; + self.create_provider_factory(&config, db, sfp, rocksdb_provider, access, runtime)?; if access.is_read_write() { debug!(target: "reth::cli", chain=%self.chain.chain(), genesis=?self.chain.genesis_hash(), "Initializing genesis"); init_genesis_with_settings(&provider_factory, self.storage_settings())?; @@ -207,6 +211,7 @@ impl EnvironmentArgs { static_file_provider: StaticFileProvider, rocksdb_provider: RocksDBProvider, access: AccessRights, + runtime: reth_tasks::Runtime, ) -> eyre::Result>> where C: ChainSpecParser, @@ -217,6 +222,7 @@ impl EnvironmentArgs { self.chain.clone(), static_file_provider, rocksdb_provider, + runtime, )? .with_prune_modes(prune_modes.clone()); diff --git a/crates/cli/commands/src/db/mod.rs b/crates/cli/commands/src/db/mod.rs index 2508911e12..be21c05022 100644 --- a/crates/cli/commands/src/db/mod.rs +++ b/crates/cli/commands/src/db/mod.rs @@ -70,23 +70,23 @@ pub enum Subcommands { State(state::Command), } -/// Initializes a provider factory with specified access rights, and then execute with the provided -/// command -macro_rules! db_exec { - ($env:expr, $tool:ident, $N:ident, $access_rights:expr, $command:block) => { - let Environment { provider_factory, .. } = $env.init::<$N>($access_rights)?; - - let $tool = DbTool::new(provider_factory)?; - $command; - }; -} - impl> Command { /// Execute `db` command pub async fn execute>( self, ctx: CliContext, ) -> eyre::Result<()> { + /// Initializes a provider factory with specified access rights, and then executes the + /// provided command. + macro_rules! db_exec { + ($env:expr, $tool:ident, $N:ident, $access_rights:expr, $command:block) => { + let Environment { provider_factory, .. } = $env.init::<$N>($access_rights)?; + + let $tool = DbTool::new(provider_factory)?; + $command; + }; + } + let data_dir = self.env.datadir.clone().resolve_datadir(self.env.chain.chain()); let db_path = data_dir.db(); let static_files_path = data_dir.static_files(); diff --git a/crates/cli/commands/src/db/repair_trie.rs b/crates/cli/commands/src/db/repair_trie.rs index 6d4ed52ea6..8ea42d2fa2 100644 --- a/crates/cli/commands/src/db/repair_trie.rs +++ b/crates/cli/commands/src/db/repair_trie.rs @@ -64,7 +64,7 @@ impl Command { let executor = task_executor.clone(); let pprof_dump_dir = data_dir.pprof_dumps(); - let handle = task_executor.spawn_critical("metrics server", async move { + let handle = task_executor.spawn_critical_task("metrics server", async move { let config = MetricServerConfig::new( listen_addr, VersionInfo { diff --git a/crates/cli/commands/src/prune.rs b/crates/cli/commands/src/prune.rs index 9d1c9c62de..b40f1b71be 100644 --- a/crates/cli/commands/src/prune.rs +++ b/crates/cli/commands/src/prune.rs @@ -76,7 +76,7 @@ impl> PruneComma // Set up cancellation token for graceful shutdown on Ctrl+C let cancellation = CancellationToken::new(); let cancellation_clone = cancellation.clone(); - ctx.task_executor.spawn_critical("prune-ctrl-c", async move { + ctx.task_executor.spawn_critical_task("prune-ctrl-c", async move { tokio::signal::ctrl_c().await.expect("failed to listen for ctrl-c"); cancellation_clone.cancel(); }); diff --git a/crates/cli/commands/src/stage/dump/execution.rs b/crates/cli/commands/src/stage/dump/execution.rs index 9bb9da1d5b..912b99b8be 100644 --- a/crates/cli/commands/src/stage/dump/execution.rs +++ b/crates/cli/commands/src/stage/dump/execution.rs @@ -37,12 +37,14 @@ where unwind_and_copy(db_tool, from, tip_block_number, &output_db, evm_config.clone())?; if should_run { + let runtime = reth_tasks::Runtime::with_existing_handle(tokio::runtime::Handle::current())?; dry_run( ProviderFactory::::new( output_db, db_tool.chain(), StaticFileProvider::read_write(output_datadir.static_files())?, RocksDBProvider::builder(output_datadir.rocksdb()).build()?, + runtime, )?, to, from, diff --git a/crates/cli/commands/src/stage/dump/hashing_account.rs b/crates/cli/commands/src/stage/dump/hashing_account.rs index cc3ffac38d..42c23646cb 100644 --- a/crates/cli/commands/src/stage/dump/hashing_account.rs +++ b/crates/cli/commands/src/stage/dump/hashing_account.rs @@ -33,12 +33,14 @@ pub(crate) async fn dump_hashing_account_stage::new( output_db, db_tool.chain(), StaticFileProvider::read_write(output_datadir.static_files())?, RocksDBProvider::builder(output_datadir.rocksdb()).build()?, + runtime, )?, to, from, diff --git a/crates/cli/commands/src/stage/dump/hashing_storage.rs b/crates/cli/commands/src/stage/dump/hashing_storage.rs index 5d19bb1d26..538aa21dcc 100644 --- a/crates/cli/commands/src/stage/dump/hashing_storage.rs +++ b/crates/cli/commands/src/stage/dump/hashing_storage.rs @@ -23,12 +23,14 @@ pub(crate) async fn dump_hashing_storage_stage::new( output_db, db_tool.chain(), StaticFileProvider::read_write(output_datadir.static_files())?, RocksDBProvider::builder(output_datadir.rocksdb()).build()?, + runtime, )?, to, from, diff --git a/crates/cli/commands/src/stage/dump/merkle.rs b/crates/cli/commands/src/stage/dump/merkle.rs index 4376134637..0932eeaf8b 100644 --- a/crates/cli/commands/src/stage/dump/merkle.rs +++ b/crates/cli/commands/src/stage/dump/merkle.rs @@ -57,12 +57,14 @@ where unwind_and_copy(db_tool, (from, to), tip_block_number, &output_db, evm_config, consensus)?; if should_run { + let runtime = reth_tasks::Runtime::with_existing_handle(tokio::runtime::Handle::current())?; dry_run( ProviderFactory::::new( output_db, db_tool.chain(), StaticFileProvider::read_write(output_datadir.static_files())?, RocksDBProvider::builder(output_datadir.rocksdb()).build()?, + runtime, )?, to, from, diff --git a/crates/cli/runner/src/lib.rs b/crates/cli/runner/src/lib.rs index fdb04b1e2d..5a802d117a 100644 --- a/crates/cli/runner/src/lib.rs +++ b/crates/cli/runner/src/lib.rs @@ -10,8 +10,9 @@ //! Entrypoint for running commands. -use reth_tasks::{TaskExecutor, TaskManager}; +use reth_tasks::{PanickedTaskError, TaskExecutor}; use std::{future::Future, pin::pin, sync::mpsc, time::Duration}; +use tokio::task::JoinHandle; use tracing::{debug, error, trace}; /// Executes CLI commands. @@ -20,21 +21,18 @@ use tracing::{debug, error, trace}; #[derive(Debug)] pub struct CliRunner { config: CliRunnerConfig, - tokio_runtime: tokio::runtime::Runtime, + runtime: reth_tasks::Runtime, } impl CliRunner { - /// Attempts to create a new [`CliRunner`] using the default tokio - /// [`Runtime`](tokio::runtime::Runtime). + /// Attempts to create a new [`CliRunner`] using the default + /// [`Runtime`](reth_tasks::Runtime). /// - /// The default tokio runtime is multi-threaded, with both I/O and time drivers enabled. - pub fn try_default_runtime() -> Result { - Ok(Self { config: CliRunnerConfig::default(), tokio_runtime: tokio_runtime()? }) - } - - /// Create a new [`CliRunner`] from a provided tokio [`Runtime`](tokio::runtime::Runtime). - pub const fn from_runtime(tokio_runtime: tokio::runtime::Runtime) -> Self { - Self { config: CliRunnerConfig::new(), tokio_runtime } + /// The default runtime is multi-threaded, with both I/O and time drivers enabled. + pub fn try_default_runtime() -> Result { + let runtime = + reth_tasks::RuntimeBuilder::new(reth_tasks::RuntimeConfig::default()).build()?; + Ok(Self { config: CliRunnerConfig::default(), runtime }) } /// Sets the [`CliRunnerConfig`] for this runner. @@ -48,7 +46,7 @@ impl CliRunner { where F: Future, { - self.tokio_runtime.block_on(fut) + self.runtime.handle().block_on(fut) } /// Executes the given _async_ command on the tokio runtime until the command future resolves or @@ -64,12 +62,11 @@ impl CliRunner { F: Future>, E: Send + Sync + From + From + 'static, { - let AsyncCliRunner { context, mut task_manager, tokio_runtime } = - AsyncCliRunner::new(self.tokio_runtime); + let (context, task_manager_handle) = cli_context(&self.runtime); // Executes the command until it finished or ctrl-c was fired - let command_res = tokio_runtime.block_on(run_to_completion_or_panic( - &mut task_manager, + let command_res = self.runtime.handle().block_on(run_to_completion_or_panic( + task_manager_handle, run_until_ctrl_c(command(context)), )); @@ -77,13 +74,13 @@ impl CliRunner { error!(target: "reth::cli", "shutting down due to error"); } else { debug!(target: "reth::cli", "shutting down gracefully"); - // after the command has finished or exit signal was received we shutdown the task - // manager which fires the shutdown signal to all tasks spawned via the task + // after the command has finished or exit signal was received we shutdown the + // runtime which fires the shutdown signal to all tasks spawned via the task // executor and awaiting on tasks spawned with graceful shutdown - task_manager.graceful_shutdown_with_timeout(self.config.graceful_shutdown_timeout); + self.runtime.graceful_shutdown_with_timeout(self.config.graceful_shutdown_timeout); } - tokio_shutdown(tokio_runtime, true); + runtime_shutdown(self.runtime, true); command_res } @@ -99,17 +96,16 @@ impl CliRunner { F: Future> + Send + 'static, E: Send + Sync + From + From + 'static, { - let AsyncCliRunner { context, mut task_manager, tokio_runtime } = - AsyncCliRunner::new(self.tokio_runtime); + let (context, task_manager_handle) = cli_context(&self.runtime); // Spawn the command on the blocking thread pool - let handle = tokio_runtime.handle().clone(); - let command_handle = - tokio_runtime.handle().spawn_blocking(move || handle.block_on(command(context))); + let handle = self.runtime.handle().clone(); + let handle2 = handle.clone(); + let command_handle = handle.spawn_blocking(move || handle2.block_on(command(context))); // Wait for the command to complete or ctrl-c - let command_res = tokio_runtime.block_on(run_to_completion_or_panic( - &mut task_manager, + let command_res = self.runtime.handle().block_on(run_to_completion_or_panic( + task_manager_handle, run_until_ctrl_c( async move { command_handle.await.expect("Failed to join blocking task") }, ), @@ -119,10 +115,10 @@ impl CliRunner { error!(target: "reth::cli", "shutting down due to error"); } else { debug!(target: "reth::cli", "shutting down gracefully"); - task_manager.graceful_shutdown_with_timeout(self.config.graceful_shutdown_timeout); + self.runtime.graceful_shutdown_with_timeout(self.config.graceful_shutdown_timeout); } - tokio_shutdown(tokio_runtime, true); + runtime_shutdown(self.runtime, true); command_res } @@ -133,48 +129,40 @@ impl CliRunner { F: Future>, E: Send + Sync + From + 'static, { - self.tokio_runtime.block_on(run_until_ctrl_c(fut))?; + self.runtime.handle().block_on(run_until_ctrl_c(fut))?; Ok(()) } /// Executes a regular future as a spawned blocking task until completion or until external /// signal received. /// - /// See [`Runtime::spawn_blocking`](tokio::runtime::Runtime::spawn_blocking) . + /// See [`Runtime::spawn_blocking`](tokio::runtime::Runtime::spawn_blocking). pub fn run_blocking_until_ctrl_c(self, fut: F) -> Result<(), E> where F: Future> + Send + 'static, E: Send + Sync + From + 'static, { - let tokio_runtime = self.tokio_runtime; - let handle = tokio_runtime.handle().clone(); - let fut = tokio_runtime.handle().spawn_blocking(move || handle.block_on(fut)); - tokio_runtime + let handle = self.runtime.handle().clone(); + let handle2 = handle.clone(); + let fut = handle.spawn_blocking(move || handle2.block_on(fut)); + self.runtime + .handle() .block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))?; - tokio_shutdown(tokio_runtime, false); + runtime_shutdown(self.runtime, false); Ok(()) } } -/// [`CliRunner`] configuration when executing commands asynchronously -struct AsyncCliRunner { - context: CliContext, - task_manager: TaskManager, - tokio_runtime: tokio::runtime::Runtime, -} - -// === impl AsyncCliRunner === - -impl AsyncCliRunner { - /// Given a tokio [`Runtime`](tokio::runtime::Runtime), creates additional context required to - /// execute commands asynchronously. - fn new(tokio_runtime: tokio::runtime::Runtime) -> Self { - let task_manager = TaskManager::new(tokio_runtime.handle().clone()); - let task_executor = task_manager.executor(); - Self { context: CliContext { task_executor }, task_manager, tokio_runtime } - } +/// Extracts the task manager handle from the runtime and creates the [`CliContext`]. +fn cli_context( + runtime: &reth_tasks::Runtime, +) -> (CliContext, JoinHandle>) { + let handle = + runtime.take_task_manager_handle().expect("Runtime must contain a TaskManager handle"); + let context = CliContext { task_executor: runtime.clone() }; + (context, handle) } /// Additional context provided by the [`CliRunner`] when executing commands @@ -216,37 +204,25 @@ impl CliRunnerConfig { } } -/// Creates a new default tokio multi-thread [Runtime](tokio::runtime::Runtime) with all features -/// enabled -pub fn tokio_runtime() -> Result { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - // Keep the threads alive for at least the block time (12 seconds) plus buffer. - // This prevents the costly process of spawning new threads on every - // new block, and instead reuses the existing threads. - .thread_keep_alive(Duration::from_secs(15)) - .thread_name("tokio-rt") - .build() -} - /// Runs the given future to completion or until a critical task panicked. /// /// Returns the error if a task panicked, or the given future returned an error. -async fn run_to_completion_or_panic(tasks: &mut TaskManager, fut: F) -> Result<(), E> +async fn run_to_completion_or_panic( + task_manager_handle: JoinHandle>, + fut: F, +) -> Result<(), E> where F: Future>, E: Send + Sync + From + 'static, { - { - let fut = pin!(fut); - tokio::select! { - task_manager_result = tasks => { - if let Err(panicked_error) = task_manager_result { - return Err(panicked_error.into()); - } - }, - res = fut => res?, - } + let fut = pin!(fut); + tokio::select! { + task_manager_result = task_manager_handle => { + if let Ok(Err(panicked_error)) = task_manager_result { + return Err(panicked_error.into()); + } + }, + res = fut => res?, } Ok(()) } @@ -296,17 +272,17 @@ where Ok(()) } -/// Shut down the given Tokio runtime, and wait for it if `wait` is set. +/// Default timeout for waiting on the tokio runtime to shut down. +const DEFAULT_RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +/// Shut down the given [`Runtime`](reth_tasks::Runtime), and wait for it if `wait` is set. /// -/// `drop(tokio_runtime)` would block the current thread until its pools -/// (including blocking pool) are shutdown. Since we want to exit as soon as possible, drop -/// it on a separate thread and wait for up to 5 seconds for this operation to -/// complete. -fn tokio_shutdown(rt: tokio::runtime::Runtime, wait: bool) { - // Shutdown the runtime on a separate thread +/// Dropping the runtime on the current thread could block due to tokio pool teardown. +/// Instead, we drop it on a separate thread and optionally wait for completion. +fn runtime_shutdown(rt: reth_tasks::Runtime, wait: bool) { let (tx, rx) = mpsc::channel(); std::thread::Builder::new() - .name("tokio-shutdown".to_string()) + .name("rt-shutdown".to_string()) .spawn(move || { drop(rt); let _ = tx.send(()); @@ -314,8 +290,8 @@ fn tokio_shutdown(rt: tokio::runtime::Runtime, wait: bool) { .unwrap(); if wait { - let _ = rx.recv_timeout(Duration::from_secs(5)).inspect_err(|err| { - debug!(target: "reth::cli", %err, "tokio runtime shutdown timed out"); + let _ = rx.recv_timeout(DEFAULT_RUNTIME_SHUTDOWN_TIMEOUT).inspect_err(|err| { + tracing::warn!(target: "reth::cli", %err, "runtime shutdown timed out"); }); } } diff --git a/crates/e2e-test-utils/src/lib.rs b/crates/e2e-test-utils/src/lib.rs index aadf101eb7..d0fb281198 100644 --- a/crates/e2e-test-utils/src/lib.rs +++ b/crates/e2e-test-utils/src/lib.rs @@ -11,7 +11,6 @@ use reth_node_builder::{ PayloadTypes, }; use reth_provider::providers::{BlockchainProvider, NodeTypesForProvider}; -use reth_tasks::TaskManager; use std::sync::Arc; use wallet::Wallet; @@ -50,7 +49,7 @@ pub async fn setup( chain_spec: Arc, is_dev: bool, attributes_generator: impl Fn(u64) -> <::Payload as PayloadTypes>::PayloadBuilderAttributes + Send + Sync + Copy + 'static, -) -> eyre::Result<(Vec>, TaskManager, Wallet)> +) -> eyre::Result<(Vec>, Wallet)> where N: NodeBuilderHelper, { @@ -69,7 +68,6 @@ pub async fn setup_engine( attributes_generator: impl Fn(u64) -> <::Payload as PayloadTypes>::PayloadBuilderAttributes + Send + Sync + Copy + 'static, ) -> eyre::Result<( Vec>>>, - TaskManager, Wallet, )> where @@ -96,7 +94,6 @@ pub async fn setup_engine_with_connection( connect_nodes: bool, ) -> eyre::Result<( Vec>>>, - TaskManager, Wallet, )> where diff --git a/crates/e2e-test-utils/src/setup_builder.rs b/crates/e2e-test-utils/src/setup_builder.rs index 2bbfb66a0c..efc01c8856 100644 --- a/crates/e2e-test-utils/src/setup_builder.rs +++ b/crates/e2e-test-utils/src/setup_builder.rs @@ -14,7 +14,7 @@ use reth_node_core::args::{DiscoveryArgs, NetworkArgs, RpcServerArgs}; use reth_primitives_traits::AlloyBlockHeader; use reth_provider::providers::BlockchainProvider; use reth_rpc_server_types::RpcModuleSelection; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::sync::Arc; use tracing::{span, Instrument, Level}; @@ -110,11 +110,9 @@ where self, ) -> eyre::Result<( Vec>>>, - TaskManager, Wallet, )> { - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; let network_config = NetworkArgs { discovery: DiscoveryArgs { disable_discovery: true, ..DiscoveryArgs::default() }, @@ -153,7 +151,7 @@ where let span = span!(Level::INFO, "node", idx); let node = N::default(); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .with_types_and_provider::>() .with_components(node.components_builder()) .with_add_ons(node.add_ons()) @@ -197,7 +195,7 @@ where } } - Ok((nodes, tasks, Wallet::default().with_chain_id(self.chain_spec.chain().into()))) + Ok((nodes, Wallet::default().with_chain_id(self.chain_spec.chain().into()))) } } diff --git a/crates/e2e-test-utils/src/setup_import.rs b/crates/e2e-test-utils/src/setup_import.rs index 00d321afb3..b853b2ff1e 100644 --- a/crates/e2e-test-utils/src/setup_import.rs +++ b/crates/e2e-test-utils/src/setup_import.rs @@ -15,7 +15,7 @@ use reth_provider::{ }; use reth_rpc_server_types::RpcModuleSelection; use reth_stages_types::StageId; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::{path::Path, sync::Arc}; use tempfile::TempDir; use tracing::{debug, info, span, Level}; @@ -24,8 +24,6 @@ use tracing::{debug, info, span, Level}; pub struct ChainImportResult { /// The nodes that were created pub nodes: Vec>, - /// The task manager - pub task_manager: TaskManager, /// The wallet for testing pub wallet: Wallet, /// Temporary directories that must be kept alive for the duration of the test @@ -68,8 +66,7 @@ pub async fn setup_engine_with_chain_import( + Copy + 'static, ) -> eyre::Result { - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; let network_config = NetworkArgs { discovery: DiscoveryArgs { disable_discovery: true, ..DiscoveryArgs::default() }, @@ -129,6 +126,7 @@ pub async fn setup_engine_with_chain_import( .with_default_tables() .build() .unwrap(), + reth_tasks::Runtime::test(), )?; // Initialize genesis if needed @@ -221,7 +219,7 @@ pub async fn setup_engine_with_chain_import( let node = EthereumNode::default(); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node_with_datadir(exec.clone(), datadir.clone()) + .testing_node_with_datadir(runtime.clone(), datadir.clone()) .with_types_and_provider::>() .with_components(node.components_builder()) .with_add_ons(node.add_ons()) @@ -243,7 +241,6 @@ pub async fn setup_engine_with_chain_import( Ok(ChainImportResult { nodes, - task_manager: tasks, wallet: crate::Wallet::default().with_chain_id(chain_spec.chain.id()), _temp_dirs: temp_dirs, }) @@ -333,6 +330,7 @@ mod tests { .with_default_tables() .build() .unwrap(), + reth_tasks::Runtime::test(), ) .expect("failed to create provider factory"); @@ -397,6 +395,7 @@ mod tests { .with_default_tables() .build() .unwrap(), + reth_tasks::Runtime::test(), ) .expect("failed to create provider factory"); @@ -497,6 +496,7 @@ mod tests { .with_default_tables() .build() .unwrap(), + reth_tasks::Runtime::test(), ) .expect("failed to create provider factory"); diff --git a/crates/e2e-test-utils/src/testsuite/setup.rs b/crates/e2e-test-utils/src/testsuite/setup.rs index e7a57e7075..4577a5ab0b 100644 --- a/crates/e2e-test-utils/src/testsuite/setup.rs +++ b/crates/e2e-test-utils/src/testsuite/setup.rs @@ -210,7 +210,7 @@ where let mut node_clients = Vec::new(); match result { - Ok((nodes, executor, _wallet)) => { + Ok((nodes, _wallet)) => { // create HTTP clients for each node's RPC and Engine API endpoints for node in &nodes { node_clients.push(node.to_node_client()?); @@ -218,12 +218,11 @@ where // spawn a separate task just to handle the shutdown tokio::spawn(async move { - // keep nodes and executor in scope to ensure they're not dropped + // keep nodes in scope to ensure they're not dropped let _nodes = nodes; - let _executor = executor; // Wait for shutdown signal let _ = shutdown_rx.recv().await; - // nodes and executor will be dropped here when the test completes + // nodes will be dropped here when the test completes }); } Err(e) => { diff --git a/crates/e2e-test-utils/tests/e2e-testsuite/main.rs b/crates/e2e-test-utils/tests/e2e-testsuite/main.rs index 4a2ac77ec6..a8314c7918 100644 --- a/crates/e2e-test-utils/tests/e2e-testsuite/main.rs +++ b/crates/e2e-test-utils/tests/e2e-testsuite/main.rs @@ -370,15 +370,14 @@ async fn test_setup_builder_with_custom_tree_config() -> Result<()> { .build(), ); - let (nodes, _tasks, _wallet) = - E2ETestSetupBuilder::::new(1, chain_spec, |_| { - EthPayloadBuilderAttributes::default() - }) - .with_tree_config_modifier(|config| { - config.with_persistence_threshold(0).with_memory_block_buffer_target(5) - }) - .build() - .await?; + let (nodes, _wallet) = E2ETestSetupBuilder::::new(1, chain_spec, |_| { + EthPayloadBuilderAttributes::default() + }) + .with_tree_config_modifier(|config| { + config.with_persistence_threshold(0).with_memory_block_buffer_target(5) + }) + .build() + .await?; assert_eq!(nodes.len(), 1); diff --git a/crates/e2e-test-utils/tests/rocksdb/main.rs b/crates/e2e-test-utils/tests/rocksdb/main.rs index c49295ef0a..3f6a02c49f 100644 --- a/crates/e2e-test-utils/tests/rocksdb/main.rs +++ b/crates/e2e-test-utils/tests/rocksdb/main.rs @@ -119,7 +119,7 @@ async fn test_rocksdb_node_startup() -> Result<()> { let chain_spec = test_chain_spec(); - let (nodes, _tasks, _wallet) = + let (nodes, _wallet) = E2ETestSetupBuilder::::new(1, chain_spec, test_attributes_generator) .with_storage_v2() .build() @@ -147,7 +147,7 @@ async fn test_rocksdb_block_mining() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _wallet) = + let (mut nodes, _wallet) = E2ETestSetupBuilder::::new(1, chain_spec, test_attributes_generator) .with_storage_v2() .build() @@ -201,7 +201,7 @@ async fn test_rocksdb_transaction_queries() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _) = E2ETestSetupBuilder::::new( + let (mut nodes, _) = E2ETestSetupBuilder::::new( 1, chain_spec.clone(), test_attributes_generator, @@ -268,7 +268,7 @@ async fn test_rocksdb_multi_tx_same_block() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _) = E2ETestSetupBuilder::::new( + let (mut nodes, _) = E2ETestSetupBuilder::::new( 1, chain_spec.clone(), test_attributes_generator, @@ -336,7 +336,7 @@ async fn test_rocksdb_txs_across_blocks() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _) = E2ETestSetupBuilder::::new( + let (mut nodes, _) = E2ETestSetupBuilder::::new( 1, chain_spec.clone(), test_attributes_generator, @@ -421,7 +421,7 @@ async fn test_rocksdb_pending_tx_not_in_storage() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _) = E2ETestSetupBuilder::::new( + let (mut nodes, _) = E2ETestSetupBuilder::::new( 1, chain_spec.clone(), test_attributes_generator, @@ -485,7 +485,7 @@ async fn test_rocksdb_reorg_unwind() -> Result<()> { let chain_spec = test_chain_spec(); let chain_id = chain_spec.chain().id(); - let (mut nodes, _tasks, _) = E2ETestSetupBuilder::::new( + let (mut nodes, _) = E2ETestSetupBuilder::::new( 1, chain_spec.clone(), test_attributes_generator, diff --git a/crates/engine/service/src/service.rs b/crates/engine/service/src/service.rs index 4427147251..ab1f9ed306 100644 --- a/crates/engine/service/src/service.rs +++ b/crates/engine/service/src/service.rs @@ -201,6 +201,7 @@ mod tests { TreeConfig::default(), Box::new(NoopInvalidBlockHook::default()), changeset_cache.clone(), + reth_tasks::Runtime::test(), ); let (sync_metrics_tx, _sync_metrics_rx) = unbounded_channel(); diff --git a/crates/engine/tree/Cargo.toml b/crates/engine/tree/Cargo.toml index 6e359bba82..4968f28041 100644 --- a/crates/engine/tree/Cargo.toml +++ b/crates/engine/tree/Cargo.toml @@ -141,6 +141,7 @@ test-utils = [ "reth-ethereum-primitives/test-utils", "reth-node-ethereum/test-utils", "reth-evm-ethereum/test-utils", + "reth-tasks/test-utils", ] [[test]] diff --git a/crates/engine/tree/benches/state_root_task.rs b/crates/engine/tree/benches/state_root_task.rs index f271e18811..9ad59f6a50 100644 --- a/crates/engine/tree/benches/state_root_task.rs +++ b/crates/engine/tree/benches/state_root_task.rs @@ -12,8 +12,7 @@ use rand::Rng; use reth_chainspec::ChainSpec; use reth_db_common::init::init_genesis; use reth_engine_tree::tree::{ - executor::WorkloadExecutor, precompile_cache::PrecompileCacheMap, PayloadProcessor, - StateProviderBuilder, TreeConfig, + precompile_cache::PrecompileCacheMap, PayloadProcessor, StateProviderBuilder, TreeConfig, }; use reth_ethereum_primitives::TransactionSigned; use reth_evm::OnStateHook; @@ -219,7 +218,7 @@ fn bench_state_root(c: &mut Criterion) { setup_provider(&factory, &state_updates).expect("failed to setup provider"); let payload_processor = PayloadProcessor::new( - WorkloadExecutor::default(), + reth_tasks::Runtime::test(), EthEvmConfig::new(factory.chain_spec()), &TreeConfig::default(), PrecompileCacheMap::default(), diff --git a/crates/engine/tree/src/backfill.rs b/crates/engine/tree/src/backfill.rs index fb0a8b0d1c..53a5ac4f31 100644 --- a/crates/engine/tree/src/backfill.rs +++ b/crates/engine/tree/src/backfill.rs @@ -138,7 +138,7 @@ impl PipelineSync { let (tx, rx) = oneshot::channel(); let pipeline = pipeline.take().expect("exists"); - self.pipeline_task_spawner.spawn_critical_blocking( + self.pipeline_task_spawner.spawn_critical_blocking_task( "pipeline task", Box::pin(async move { let result = pipeline.run_as_fut(Some(target)).await; diff --git a/crates/engine/tree/src/tree/payload_processor/executor.rs b/crates/engine/tree/src/tree/payload_processor/executor.rs deleted file mode 100644 index 410c344a59..0000000000 --- a/crates/engine/tree/src/tree/payload_processor/executor.rs +++ /dev/null @@ -1,47 +0,0 @@ -//! Executor for mixed I/O and CPU workloads. - -use reth_trie_parallel::root::get_tokio_runtime_handle; -use tokio::{runtime::Handle, task::JoinHandle}; - -/// An executor for mixed I/O and CPU workloads. -/// -/// This type uses tokio to spawn blocking tasks and will reuse an existing tokio -/// runtime if available or create its own. -#[derive(Debug, Clone)] -pub struct WorkloadExecutor { - inner: WorkloadExecutorInner, -} - -impl Default for WorkloadExecutor { - fn default() -> Self { - Self { inner: WorkloadExecutorInner::new() } - } -} - -impl WorkloadExecutor { - /// Returns the handle to the tokio runtime - pub(super) const fn handle(&self) -> &Handle { - &self.inner.handle - } - - /// Runs the provided function on an executor dedicated to blocking operations. - #[track_caller] - pub fn spawn_blocking(&self, func: F) -> JoinHandle - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - self.inner.handle.spawn_blocking(func) - } -} - -#[derive(Debug, Clone)] -struct WorkloadExecutorInner { - handle: Handle, -} - -impl WorkloadExecutorInner { - fn new() -> Self { - Self { handle: get_tokio_runtime_handle() } - } -} diff --git a/crates/engine/tree/src/tree/payload_processor/mod.rs b/crates/engine/tree/src/tree/payload_processor/mod.rs index 8ee4d4e670..5798e24576 100644 --- a/crates/engine/tree/src/tree/payload_processor/mod.rs +++ b/crates/engine/tree/src/tree/payload_processor/mod.rs @@ -15,7 +15,6 @@ use alloy_eips::{eip1898::BlockWithParent, eip4895::Withdrawal}; use alloy_evm::block::StateChangeSource; use alloy_primitives::B256; use crossbeam_channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; -use executor::WorkloadExecutor; use metrics::{Counter, Histogram}; use multiproof::{SparseTrieUpdate, *}; use parking_lot::RwLock; @@ -34,6 +33,7 @@ use reth_provider::{ StateProviderFactory, StateReader, }; use reth_revm::{db::BundleState, state::EvmState}; +use reth_tasks::Runtime; use reth_trie::{hashed_cursor::HashedCursorFactory, trie_cursor::TrieCursorFactory}; use reth_trie_parallel::{ proof_task::{ProofTaskCtx, ProofWorkerHandle}, @@ -55,7 +55,6 @@ use std::{ use tracing::{debug, debug_span, instrument, warn, Span}; pub mod bal; -pub mod executor; pub mod multiproof; mod preserved_sparse_trie; pub mod prewarm; @@ -109,7 +108,7 @@ where Evm: ConfigureEvm, { /// The executor used by to spawn tasks. - executor: WorkloadExecutor, + executor: Runtime, /// The most recent cache used for execution. execution_cache: PayloadExecutionCache, /// Metrics for trie operations @@ -146,13 +145,13 @@ where Evm: ConfigureEvm, { /// Returns a reference to the workload executor driving payload tasks. - pub const fn executor(&self) -> &WorkloadExecutor { + pub const fn executor(&self) -> &Runtime { &self.executor } /// Creates a new payload processor. pub fn new( - executor: WorkloadExecutor, + executor: Runtime, evm_config: Evm, config: &TreeConfig, precompile_cache_map: PrecompileCacheMap>, @@ -280,7 +279,7 @@ where let storage_worker_count = config.storage_worker_count(); let account_worker_count = config.account_worker_count(); let proof_handle = ProofWorkerHandle::new( - self.executor.handle().clone(), + &self.executor, task_ctx, storage_worker_count, account_worker_count, @@ -1001,9 +1000,7 @@ mod tests { use super::PayloadExecutionCache; use crate::tree::{ cached_state::{CachedStateMetrics, ExecutionCache, SavedCache}, - payload_processor::{ - evm_state_to_hashed_post_state, executor::WorkloadExecutor, PayloadProcessor, - }, + payload_processor::{evm_state_to_hashed_post_state, PayloadProcessor}, precompile_cache::PrecompileCacheMap, StateProviderBuilder, TreeConfig, }; @@ -1105,7 +1102,7 @@ mod tests { #[test] fn on_inserted_executed_block_populates_cache() { let payload_processor = PayloadProcessor::new( - WorkloadExecutor::default(), + reth_tasks::Runtime::test(), EthEvmConfig::new(Arc::new(ChainSpec::default())), &TreeConfig::default(), PrecompileCacheMap::default(), @@ -1134,7 +1131,7 @@ mod tests { #[test] fn on_inserted_executed_block_skips_on_parent_mismatch() { let payload_processor = PayloadProcessor::new( - WorkloadExecutor::default(), + reth_tasks::Runtime::test(), EthEvmConfig::new(Arc::new(ChainSpec::default())), &TreeConfig::default(), PrecompileCacheMap::default(), @@ -1269,7 +1266,7 @@ mod tests { } let mut payload_processor = PayloadProcessor::new( - WorkloadExecutor::default(), + reth_tasks::Runtime::test(), EthEvmConfig::new(factory.chain_spec()), &TreeConfig::default(), PrecompileCacheMap::default(), diff --git a/crates/engine/tree/src/tree/payload_processor/multiproof.rs b/crates/engine/tree/src/tree/payload_processor/multiproof.rs index 2811efdef9..a1ae2e7c26 100644 --- a/crates/engine/tree/src/tree/payload_processor/multiproof.rs +++ b/crates/engine/tree/src/tree/payload_processor/multiproof.rs @@ -1547,17 +1547,11 @@ mod tests { use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofWorkerHandle}; use revm_primitives::{B256, U256}; use std::sync::{Arc, OnceLock}; - use tokio::runtime::{Handle, Runtime}; - /// Get a handle to the test runtime, creating it if necessary - fn get_test_runtime_handle() -> Handle { - static TEST_RT: OnceLock = OnceLock::new(); - TEST_RT - .get_or_init(|| { - tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap() - }) - .handle() - .clone() + /// Get a test runtime, creating it if necessary + fn get_test_runtime() -> &'static reth_tasks::Runtime { + static TEST_RT: OnceLock = OnceLock::new(); + TEST_RT.get_or_init(reth_tasks::Runtime::test) } fn create_test_state_root_task(factory: F) -> MultiProofTask @@ -1573,11 +1567,11 @@ mod tests { + Send + 'static, { - let rt_handle = get_test_runtime_handle(); + let runtime = get_test_runtime(); let changeset_cache = ChangesetCache::new(); let overlay_factory = OverlayStateProviderFactory::new(factory, changeset_cache); let task_ctx = ProofTaskCtx::new(overlay_factory); - let proof_handle = ProofWorkerHandle::new(rt_handle, task_ctx, 1, 1, false); + let proof_handle = ProofWorkerHandle::new(runtime, task_ctx, 1, 1, false); let (to_sparse_trie, _receiver) = std::sync::mpsc::channel(); let (tx, rx) = crossbeam_channel::unbounded(); diff --git a/crates/engine/tree/src/tree/payload_processor/prewarm.rs b/crates/engine/tree/src/tree/payload_processor/prewarm.rs index 83fe8ccf0e..2ce8d83b3e 100644 --- a/crates/engine/tree/src/tree/payload_processor/prewarm.rs +++ b/crates/engine/tree/src/tree/payload_processor/prewarm.rs @@ -15,7 +15,6 @@ use crate::tree::{ cached_state::{CachedStateProvider, SavedCache}, payload_processor::{ bal::{self, total_slots, BALSlotIter}, - executor::WorkloadExecutor, multiproof::{MultiProofMessage, VersionedMultiProofTargets}, PayloadExecutionCache, }, @@ -37,6 +36,7 @@ use reth_provider::{ StateReader, }; use reth_revm::{database::StateProviderDatabase, state::EvmState}; +use reth_tasks::Runtime; use reth_trie::MultiProofTargets; use std::{ ops::Range, @@ -78,7 +78,7 @@ where Evm: ConfigureEvm, { /// The executor used to spawn execution tasks. - executor: WorkloadExecutor, + executor: Runtime, /// Shared execution cache. execution_cache: PayloadExecutionCache, /// Context provided to execution tasks @@ -101,7 +101,7 @@ where { /// Initializes the task with the given transactions pending execution pub fn new( - executor: WorkloadExecutor, + executor: Runtime, execution_cache: PayloadExecutionCache, ctx: PrewarmContext, to_multi_proof: Option>, @@ -667,7 +667,7 @@ where fn spawn_workers( self, workers_needed: usize, - task_executor: &WorkloadExecutor, + task_executor: &Runtime, to_multi_proof: Option>, done_tx: Sender<()>, ) -> CrossbeamSender> @@ -704,7 +704,7 @@ where fn spawn_bal_worker( &self, idx: usize, - executor: &WorkloadExecutor, + executor: &Runtime, bal: Arc, range: Range, done_tx: Sender<()>, diff --git a/crates/engine/tree/src/tree/payload_processor/sparse_trie.rs b/crates/engine/tree/src/tree/payload_processor/sparse_trie.rs index d5925463bb..b79550e347 100644 --- a/crates/engine/tree/src/tree/payload_processor/sparse_trie.rs +++ b/crates/engine/tree/src/tree/payload_processor/sparse_trie.rs @@ -1,6 +1,5 @@ //! Sparse Trie task related functionality. -use super::executor::WorkloadExecutor; use crate::tree::{ multiproof::{ dispatch_with_chunking, evm_state_to_hashed_post_state, MultiProofMessage, @@ -13,6 +12,7 @@ use alloy_rlp::{Decodable, Encodable}; use crossbeam_channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; use rayon::iter::ParallelIterator; use reth_primitives_traits::{Account, ParallelBridgeBuffered}; +use reth_tasks::Runtime; use reth_trie::{ proof_v2::Target, updates::TrieUpdates, DecodedMultiProofV2, HashedPostState, Nibbles, TrieAccount, EMPTY_ROOT_HASH, TRIE_ACCOUNT_RLP_MAX_SIZE, @@ -282,7 +282,7 @@ where { /// Creates a new sparse trie, pre-populating with an existing [`SparseStateTrie`]. pub(super) fn new_with_trie( - executor: &WorkloadExecutor, + executor: &Runtime, updates: CrossbeamReceiver, proof_worker_handle: ProofWorkerHandle, metrics: MultiProofTaskMetrics, diff --git a/crates/engine/tree/src/tree/payload_validator.rs b/crates/engine/tree/src/tree/payload_validator.rs index d67b07e5fe..096547a4a9 100644 --- a/crates/engine/tree/src/tree/payload_validator.rs +++ b/crates/engine/tree/src/tree/payload_validator.rs @@ -4,7 +4,7 @@ use crate::tree::{ cached_state::CachedStateProvider, error::{InsertBlockError, InsertBlockErrorKind, InsertPayloadError}, instrumented_state::InstrumentedStateProvider, - payload_processor::{executor::WorkloadExecutor, PayloadProcessor}, + payload_processor::PayloadProcessor, precompile_cache::{CachedPrecompile, CachedPrecompileMetrics, PrecompileCacheMap}, sparse_trie::StateRootComputeOutcome, EngineApiMetrics, EngineApiTreeState, ExecutionEnv, PayloadHandle, StateProviderBuilder, @@ -134,6 +134,8 @@ where validator: V, /// Changeset cache for in-memory trie changesets changeset_cache: ChangesetCache, + /// Task runtime for spawning parallel work. + runtime: reth_tasks::Runtime, } impl BasicEngineValidator @@ -166,10 +168,11 @@ where config: TreeConfig, invalid_block_hook: Box>, changeset_cache: ChangesetCache, + runtime: reth_tasks::Runtime, ) -> Self { let precompile_cache_map = PrecompileCacheMap::default(); let payload_processor = PayloadProcessor::new( - WorkloadExecutor::default(), + runtime.clone(), evm_config.clone(), &config, precompile_cache_map.clone(), @@ -186,6 +189,7 @@ where metrics: EngineApiMetrics::default(), validator, changeset_cache, + runtime, } } @@ -874,7 +878,8 @@ where let prefix_sets = hashed_state.construct_prefix_sets().freeze(); let overlay_factory = overlay_factory.with_extended_hashed_state_overlay(hashed_state.clone_into_sorted()); - ParallelStateRoot::new(overlay_factory, prefix_sets).incremental_root_with_updates() + ParallelStateRoot::new(overlay_factory, prefix_sets, self.runtime.clone()) + .incremental_root_with_updates() } /// Compute state root for the given hashed post state in serial. diff --git a/crates/engine/tree/src/tree/tests.rs b/crates/engine/tree/src/tree/tests.rs index ec4efd4264..7e376537e5 100644 --- a/crates/engine/tree/src/tree/tests.rs +++ b/crates/engine/tree/src/tree/tests.rs @@ -203,6 +203,7 @@ impl TestHarness { TreeConfig::default(), Box::new(NoopInvalidBlockHook::default()), changeset_cache.clone(), + reth_tasks::Runtime::test(), ); let tree = EngineApiTreeHandler::new( @@ -404,6 +405,7 @@ impl ValidatorTestHarness { TreeConfig::default(), Box::new(NoopInvalidBlockHook::default()), changeset_cache, + reth_tasks::Runtime::test(), ); Self { harness, validator, metrics: TestMetrics::default() } diff --git a/crates/ethereum/cli/src/interface.rs b/crates/ethereum/cli/src/interface.rs index 4cfff2c62a..817a35d805 100644 --- a/crates/ethereum/cli/src/interface.rs +++ b/crates/ethereum/cli/src/interface.rs @@ -92,7 +92,7 @@ impl< /// This accepts a closure that is used to launch the node via the /// [`NodeCommand`](node::NodeCommand). /// - /// This command will be run on the [default tokio runtime](reth_cli_runner::tokio_runtime). + /// This command will be run on the default tokio runtime. /// /// /// # Example @@ -143,7 +143,7 @@ impl< /// This accepts a closure that is used to launch the node via the /// [`NodeCommand`](node::NodeCommand). /// - /// This command will be run on the [default tokio runtime](reth_cli_runner::tokio_runtime). + /// This command will be run on the default tokio runtime. pub fn run_with_components( self, components: impl CliComponentsBuilder, diff --git a/crates/ethereum/node/Cargo.toml b/crates/ethereum/node/Cargo.toml index 306bbf54fb..9041b60e49 100644 --- a/crates/ethereum/node/Cargo.toml +++ b/crates/ethereum/node/Cargo.toml @@ -112,4 +112,5 @@ test-utils = [ "reth-primitives-traits/test-utils", "reth-evm-ethereum/test-utils", "reth-stages-types/test-utils", + "reth-tasks/test-utils", ] diff --git a/crates/ethereum/node/src/node.rs b/crates/ethereum/node/src/node.rs index 19d0792947..242dbce1c1 100644 --- a/crates/ethereum/node/src/node.rs +++ b/crates/ethereum/node/src/node.rs @@ -107,9 +107,11 @@ impl EthereumNode { /// use reth_chainspec::MAINNET; /// use reth_node_ethereum::EthereumNode; /// - /// let factory = EthereumNode::provider_factory_builder() - /// .open_read_only(MAINNET.clone(), "datadir") - /// .unwrap(); + /// fn demo(runtime: reth_tasks::Runtime) { + /// let factory = EthereumNode::provider_factory_builder() + /// .open_read_only(MAINNET.clone(), "datadir", runtime) + /// .unwrap(); + /// } /// ``` /// /// # Open a Providerfactory manually with all required components @@ -120,12 +122,15 @@ impl EthereumNode { /// use reth_node_ethereum::EthereumNode; /// use reth_provider::providers::{RocksDBProvider, StaticFileProvider}; /// - /// let factory = EthereumNode::provider_factory_builder() - /// .db(open_db_read_only("db", Default::default()).unwrap()) - /// .chainspec(ChainSpecBuilder::mainnet().build().into()) - /// .static_file(StaticFileProvider::read_only("db/static_files", false).unwrap()) - /// .rocksdb_provider(RocksDBProvider::builder("db/rocksdb").build().unwrap()) - /// .build_provider_factory(); + /// fn demo(runtime: reth_tasks::Runtime) { + /// let factory = EthereumNode::provider_factory_builder() + /// .db(open_db_read_only("db", Default::default()).unwrap()) + /// .chainspec(ChainSpecBuilder::mainnet().build().into()) + /// .static_file(StaticFileProvider::read_only("db/static_files", false).unwrap()) + /// .rocksdb_provider(RocksDBProvider::builder("db/rocksdb").build().unwrap()) + /// .runtime(runtime) + /// .build_provider_factory(); + /// } /// ``` pub fn provider_factory_builder() -> ProviderFactoryBuilder { ProviderFactoryBuilder::default() @@ -513,7 +518,7 @@ where // it doesn't impact the first block or the first gossiped blob transaction, so we // initialize this in the background let kzg_settings = validator.validator().kzg_settings().clone(); - ctx.task_executor().spawn_blocking(async move { + ctx.task_executor().spawn_blocking_task(async move { let _ = kzg_settings.get(); debug!(target: "reth::cli", "Initialized KZG settings"); }); diff --git a/crates/ethereum/node/tests/e2e/blobs.rs b/crates/ethereum/node/tests/e2e/blobs.rs index 5f5b17127a..16810b9c0a 100644 --- a/crates/ethereum/node/tests/e2e/blobs.rs +++ b/crates/ethereum/node/tests/e2e/blobs.rs @@ -10,7 +10,7 @@ use reth_ethereum_primitives::PooledTransactionVariant; use reth_node_builder::{NodeBuilder, NodeHandle}; use reth_node_core::{args::RpcServerArgs, node_config::NodeConfig}; use reth_node_ethereum::EthereumNode; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use reth_transaction_pool::TransactionPool; use std::{ sync::Arc, @@ -20,8 +20,7 @@ use std::{ #[tokio::test] async fn can_handle_blobs() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); let chain_spec = Arc::new( @@ -37,7 +36,7 @@ async fn can_handle_blobs() -> eyre::Result<()> { .with_unused_ports() .with_rpc(RpcServerArgs::default().with_unused_ports().with_http()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; @@ -92,8 +91,7 @@ async fn can_handle_blobs() -> eyre::Result<()> { #[tokio::test] async fn can_send_legacy_sidecar_post_activation() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); let chain_spec = Arc::new( @@ -107,7 +105,7 @@ async fn can_send_legacy_sidecar_post_activation() -> eyre::Result<()> { .with_force_blob_sidecar_upcasting(), ); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; @@ -146,8 +144,7 @@ async fn can_send_legacy_sidecar_post_activation() -> eyre::Result<()> { #[tokio::test] async fn blob_conversion_at_osaka() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let current_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); // Osaka activates in 2 slots @@ -170,7 +167,7 @@ async fn blob_conversion_at_osaka() -> eyre::Result<()> { .with_force_blob_sidecar_upcasting(), ); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; diff --git a/crates/ethereum/node/tests/e2e/custom_genesis.rs b/crates/ethereum/node/tests/e2e/custom_genesis.rs index 6d1689655f..63e366f00d 100644 --- a/crates/ethereum/node/tests/e2e/custom_genesis.rs +++ b/crates/ethereum/node/tests/e2e/custom_genesis.rs @@ -27,7 +27,7 @@ async fn can_run_eth_node_with_custom_genesis_number() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = + let (mut nodes, wallet) = setup::(1, chain_spec, false, eth_payload_attributes).await?; let mut node = nodes.pop().unwrap(); @@ -81,7 +81,7 @@ async fn custom_genesis_block_query_boundaries() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, _wallet) = + let (mut nodes, _wallet) = setup::(1, chain_spec, false, eth_payload_attributes).await?; let node = nodes.pop().unwrap(); diff --git a/crates/ethereum/node/tests/e2e/dev.rs b/crates/ethereum/node/tests/e2e/dev.rs index bf022a514e..279f7d0cb7 100644 --- a/crates/ethereum/node/tests/e2e/dev.rs +++ b/crates/ethereum/node/tests/e2e/dev.rs @@ -9,20 +9,19 @@ use reth_node_core::args::DevArgs; use reth_node_ethereum::{node::EthereumAddOns, EthereumNode}; use reth_provider::{providers::BlockchainProvider, CanonStateSubscriptions}; use reth_rpc_eth_api::{helpers::EthTransactions, EthApiServer}; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::sync::Arc; #[tokio::test] async fn can_run_dev_node() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let node_config = NodeConfig::test() .with_chain(custom_chain()) .with_dev(DevArgs { dev: true, ..Default::default() }); let NodeHandle { node, .. } = NodeBuilder::new(node_config.clone()) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .with_types_and_provider::>() .with_components(EthereumNode::components()) .with_add_ons(EthereumAddOns::default()) @@ -37,15 +36,14 @@ async fn can_run_dev_node() -> eyre::Result<()> { #[tokio::test] async fn can_run_dev_node_custom_attributes() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let node_config = NodeConfig::test() .with_chain(custom_chain()) .with_dev(DevArgs { dev: true, ..Default::default() }); let fee_recipient = Address::random(); let NodeHandle { node, .. } = NodeBuilder::new(node_config.clone()) - .testing_node(exec.clone()) + .testing_node(runtime.clone()) .with_types_and_provider::>() .with_components(EthereumNode::components()) .with_add_ons(EthereumAddOns::default()) diff --git a/crates/ethereum/node/tests/e2e/eth.rs b/crates/ethereum/node/tests/e2e/eth.rs index 5111a56a3a..183643b7d7 100644 --- a/crates/ethereum/node/tests/e2e/eth.rs +++ b/crates/ethereum/node/tests/e2e/eth.rs @@ -14,14 +14,14 @@ use reth_node_core::{args::RpcServerArgs, node_config::NodeConfig}; use reth_node_ethereum::EthereumNode; use reth_provider::BlockNumReader; use reth_rpc_api::TestingBuildBlockRequestV1; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::sync::Arc; #[tokio::test] async fn can_run_eth_node() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let (mut nodes, _tasks, wallet) = setup::( + let (mut nodes, wallet) = setup::( 1, Arc::new( ChainSpecBuilder::default() @@ -57,8 +57,7 @@ async fn can_run_eth_node() -> eyre::Result<()> { #[cfg(unix)] async fn can_run_eth_node_with_auth_engine_api_over_ipc() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let exec = TaskManager::current(); - let exec = exec.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); // Chain spec with test allocs let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); @@ -76,7 +75,7 @@ async fn can_run_eth_node_with_auth_engine_api_over_ipc() -> eyre::Result<()> { .with_rpc(RpcServerArgs::default().with_unused_ports().with_http().with_auth_ipc()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec) + .testing_node(runtime) .node(EthereumNode::default()) .launch() .await?; @@ -105,8 +104,7 @@ async fn can_run_eth_node_with_auth_engine_api_over_ipc() -> eyre::Result<()> { #[cfg(unix)] async fn test_failed_run_eth_node_with_no_auth_engine_api_over_ipc_opts() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let exec = TaskManager::current(); - let exec = exec.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); // Chain spec with test allocs let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); @@ -121,7 +119,7 @@ async fn test_failed_run_eth_node_with_no_auth_engine_api_over_ipc_opts() -> eyr // Node setup let node_config = NodeConfig::test().with_chain(chain_spec); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec) + .testing_node(runtime) .node(EthereumNode::default()) .launch() .await?; @@ -139,7 +137,7 @@ async fn test_failed_run_eth_node_with_no_auth_engine_api_over_ipc_opts() -> eyr async fn test_engine_graceful_shutdown() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let (mut nodes, _tasks, wallet) = setup::( + let (mut nodes, wallet) = setup::( 1, Arc::new( ChainSpecBuilder::default() @@ -190,8 +188,7 @@ async fn test_engine_graceful_shutdown() -> eyre::Result<()> { #[tokio::test] async fn test_testing_build_block_v1_osaka() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let exec = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); let chain_spec = Arc::new( @@ -208,7 +205,7 @@ async fn test_testing_build_block_v1_osaka() -> eyre::Result<()> { ); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec) + .testing_node(runtime) .node(EthereumNode::default()) .launch() .await?; @@ -278,7 +275,7 @@ async fn test_sparse_trie_reuse_across_blocks() -> eyre::Result<()> { .with_sparse_trie_prune_depth(2) .with_sparse_trie_max_storage_tries(100); - let (mut nodes, _tasks, _wallet) = setup_engine::( + let (mut nodes, _wallet) = setup_engine::( 1, Arc::new( ChainSpecBuilder::default() diff --git a/crates/ethereum/node/tests/e2e/invalid_payload.rs b/crates/ethereum/node/tests/e2e/invalid_payload.rs index 03269e53b7..7e83098215 100644 --- a/crates/ethereum/node/tests/e2e/invalid_payload.rs +++ b/crates/ethereum/node/tests/e2e/invalid_payload.rs @@ -37,7 +37,7 @@ async fn can_handle_invalid_payload_then_valid() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 2, chain_spec.clone(), false, @@ -154,7 +154,7 @@ async fn can_handle_multiple_invalid_payloads() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 2, chain_spec.clone(), false, @@ -255,7 +255,7 @@ async fn can_handle_invalid_payload_with_transactions() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 2, chain_spec.clone(), false, diff --git a/crates/ethereum/node/tests/e2e/p2p.rs b/crates/ethereum/node/tests/e2e/p2p.rs index 74266b1675..90193b2c4e 100644 --- a/crates/ethereum/node/tests/e2e/p2p.rs +++ b/crates/ethereum/node/tests/e2e/p2p.rs @@ -18,7 +18,7 @@ use std::{sync::Arc, time::Duration}; async fn can_sync() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let (mut nodes, _tasks, wallet) = setup::( + let (mut nodes, wallet) = setup::( 2, Arc::new( ChainSpecBuilder::default() @@ -74,7 +74,7 @@ async fn e2e_test_send_transactions() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, _) = setup_engine::( + let (mut nodes, _) = setup_engine::( 2, chain_spec.clone(), false, @@ -116,7 +116,7 @@ async fn test_long_reorg() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, _) = setup_engine::( + let (mut nodes, _) = setup_engine::( 2, chain_spec.clone(), false, @@ -172,7 +172,7 @@ async fn test_reorg_through_backfill() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, _) = setup_engine::( + let (mut nodes, _) = setup_engine::( 2, chain_spec.clone(), false, @@ -236,7 +236,7 @@ async fn test_tx_propagation() -> eyre::Result<()> { }; // Setup 10 nodes - let (mut nodes, _tasks, _) = setup_engine_with_connection::( + let (mut nodes, _) = setup_engine_with_connection::( 10, chain_spec.clone(), false, diff --git a/crates/ethereum/node/tests/e2e/pool.rs b/crates/ethereum/node/tests/e2e/pool.rs index 3777c4945d..e30aa41420 100644 --- a/crates/ethereum/node/tests/e2e/pool.rs +++ b/crates/ethereum/node/tests/e2e/pool.rs @@ -12,7 +12,7 @@ use reth_node_core::{args::RpcServerArgs, node_config::NodeConfig}; use reth_node_ethereum::EthereumNode; use reth_primitives_traits::Recovered; use reth_provider::CanonStateSubscriptions; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use reth_transaction_pool::{ blobstore::InMemoryBlobStore, test_utils::OkValidator, BlockInfo, CoinbaseTipOrdering, EthPooledTransaction, Pool, PoolTransaction, TransactionOrigin, TransactionPool, @@ -24,8 +24,7 @@ use std::{sync::Arc, time::Duration}; #[tokio::test] async fn maintain_txpool_stale_eviction() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let executor = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let txpool = Pool::new( OkValidator::default(), @@ -49,7 +48,7 @@ async fn maintain_txpool_stale_eviction() -> eyre::Result<()> { .with_unused_ports() .with_rpc(RpcServerArgs::default().with_unused_ports().with_http()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(executor.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; @@ -63,13 +62,13 @@ async fn maintain_txpool_stale_eviction() -> eyre::Result<()> { ..Default::default() }; - executor.spawn_critical( + runtime.spawn_critical_task( "txpool maintenance task", reth_transaction_pool::maintain::maintain_transaction_pool_future( node.inner.provider.clone(), txpool.clone(), node.inner.provider.clone().canonical_state_stream(), - executor.clone(), + runtime.clone(), config, ), ); @@ -98,8 +97,7 @@ async fn maintain_txpool_stale_eviction() -> eyre::Result<()> { #[tokio::test] async fn maintain_txpool_reorg() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let executor = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let txpool = Pool::new( OkValidator::default(), @@ -124,7 +122,7 @@ async fn maintain_txpool_reorg() -> eyre::Result<()> { .with_unused_ports() .with_rpc(RpcServerArgs::default().with_unused_ports().with_http()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(executor.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; @@ -135,13 +133,13 @@ async fn maintain_txpool_reorg() -> eyre::Result<()> { let w1 = wallets.first().unwrap(); let w2 = wallets.last().unwrap(); - executor.spawn_critical( + runtime.spawn_critical_task( "txpool maintenance task", reth_transaction_pool::maintain::maintain_transaction_pool_future( node.inner.provider.clone(), txpool.clone(), node.inner.provider.clone().canonical_state_stream(), - executor.clone(), + runtime.clone(), reth_transaction_pool::maintain::MaintainPoolConfig::default(), ), ); @@ -231,8 +229,7 @@ async fn maintain_txpool_reorg() -> eyre::Result<()> { #[tokio::test] async fn maintain_txpool_commit() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let tasks = TaskManager::current(); - let executor = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let txpool = Pool::new( OkValidator::default(), @@ -256,7 +253,7 @@ async fn maintain_txpool_commit() -> eyre::Result<()> { .with_unused_ports() .with_rpc(RpcServerArgs::default().with_unused_ports().with_http()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config.clone()) - .testing_node(executor.clone()) + .testing_node(runtime.clone()) .node(EthereumNode::default()) .launch() .await?; @@ -265,13 +262,13 @@ async fn maintain_txpool_commit() -> eyre::Result<()> { let wallet = Wallet::default(); - executor.spawn_critical( + runtime.spawn_critical_task( "txpool maintenance task", reth_transaction_pool::maintain::maintain_transaction_pool_future( node.inner.provider.clone(), txpool.clone(), node.inner.provider.clone().canonical_state_stream(), - executor.clone(), + runtime.clone(), reth_transaction_pool::maintain::MaintainPoolConfig::default(), ), ); diff --git a/crates/ethereum/node/tests/e2e/prestate.rs b/crates/ethereum/node/tests/e2e/prestate.rs index 6c66f09bf7..42a3e62a00 100644 --- a/crates/ethereum/node/tests/e2e/prestate.rs +++ b/crates/ethereum/node/tests/e2e/prestate.rs @@ -12,7 +12,7 @@ use reth_node_builder::{NodeBuilder, NodeHandle}; use reth_node_core::{args::RpcServerArgs, node_config::NodeConfig}; use reth_node_ethereum::EthereumNode; use reth_rpc_server_types::RpcModuleSelection; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use serde::Deserialize; use std::sync::Arc; @@ -29,8 +29,7 @@ async fn debug_trace_call_matches_geth_prestate_snapshot() -> Result<()> { let mut genesis: Genesis = MAINNET.genesis().clone(); genesis.coinbase = address!("0x95222290dd7278aa3ddd389cc1e1d165cc4bafe5"); - let exec = TaskManager::current(); - let exec = exec.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let expected_frame = expected_snapshot_frame()?; let prestate_mode = match &expected_frame { @@ -63,7 +62,7 @@ async fn debug_trace_call_matches_geth_prestate_snapshot() -> Result<()> { ); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec) + .testing_node(runtime) .node(EthereumNode::default()) .launch() .await?; diff --git a/crates/ethereum/node/tests/e2e/rpc.rs b/crates/ethereum/node/tests/e2e/rpc.rs index c149580ca6..1edbcaac15 100644 --- a/crates/ethereum/node/tests/e2e/rpc.rs +++ b/crates/ethereum/node/tests/e2e/rpc.rs @@ -21,7 +21,7 @@ use reth_node_core::{ use reth_node_ethereum::EthereumNode; use reth_payload_primitives::BuiltPayload; use reth_rpc_api::servers::AdminApiServer; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::{ sync::Arc, time::{SystemTime, UNIX_EPOCH}, @@ -57,7 +57,7 @@ async fn test_fee_history() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, chain_spec.clone(), false, @@ -142,7 +142,7 @@ async fn test_flashbots_validate_v3() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, chain_spec.clone(), false, @@ -224,7 +224,7 @@ async fn test_flashbots_validate_v4() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, chain_spec.clone(), false, @@ -314,7 +314,7 @@ async fn test_eth_config() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, chain_spec.clone(), false, @@ -344,8 +344,7 @@ async fn test_eth_config() -> eyre::Result<()> { async fn test_admin_external_ip() -> eyre::Result<()> { reth_tracing::init_test_tracing(); - let exec = TaskManager::current(); - let exec = exec.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); // Chain spec with test allocs let genesis: Genesis = serde_json::from_str(include_str!("../assets/genesis.json")).unwrap(); @@ -363,7 +362,7 @@ async fn test_admin_external_ip() -> eyre::Result<()> { .with_rpc(RpcServerArgs::default().with_unused_ports().with_http()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(exec) + .testing_node(runtime) .node(EthereumNode::default()) .launch() .await?; diff --git a/crates/ethereum/node/tests/e2e/selfdestruct.rs b/crates/ethereum/node/tests/e2e/selfdestruct.rs index 8ffd9169ef..626f9d3b5e 100644 --- a/crates/ethereum/node/tests/e2e/selfdestruct.rs +++ b/crates/ethereum/node/tests/e2e/selfdestruct.rs @@ -142,7 +142,7 @@ async fn test_selfdestruct_post_dencun() -> eyre::Result<()> { reth_tracing::init_test_tracing(); let tree_config = TreeConfig::default().without_prewarming(true).without_state_cache(false); - let (mut nodes, _tasks, wallet) = + let (mut nodes, wallet) = setup_engine::(1, cancun_spec(), false, tree_config, eth_payload_attributes) .await?; let mut node = nodes.pop().unwrap(); @@ -236,7 +236,7 @@ async fn test_selfdestruct_same_tx_post_dencun() -> eyre::Result<()> { reth_tracing::init_test_tracing(); let tree_config = TreeConfig::default().without_prewarming(true).without_state_cache(false); - let (mut nodes, _tasks, wallet) = + let (mut nodes, wallet) = setup_engine::(1, cancun_spec(), false, tree_config, eth_payload_attributes) .await?; let mut node = nodes.pop().unwrap(); @@ -311,7 +311,7 @@ async fn test_selfdestruct_pre_dencun() -> eyre::Result<()> { reth_tracing::init_test_tracing(); let tree_config = TreeConfig::default().without_prewarming(true).without_state_cache(false); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, shanghai_spec(), false, @@ -421,7 +421,7 @@ async fn test_selfdestruct_same_tx_preexisting_account_post_dencun() -> eyre::Re reth_tracing::init_test_tracing(); let tree_config = TreeConfig::default().without_prewarming(true).without_state_cache(false); - let (mut nodes, _tasks, wallet) = + let (mut nodes, wallet) = setup_engine::(1, cancun_spec(), false, tree_config, eth_payload_attributes) .await?; let mut node = nodes.pop().unwrap(); diff --git a/crates/ethereum/node/tests/e2e/simulate.rs b/crates/ethereum/node/tests/e2e/simulate.rs index ee8c615e62..9619a1cfa9 100644 --- a/crates/ethereum/node/tests/e2e/simulate.rs +++ b/crates/ethereum/node/tests/e2e/simulate.rs @@ -29,7 +29,7 @@ async fn test_simulate_v1_with_max_fee_per_blob_gas_only() -> eyre::Result<()> { .build(), ); - let (mut nodes, _tasks, wallet) = setup_engine::( + let (mut nodes, wallet) = setup_engine::( 1, chain_spec.clone(), false, diff --git a/crates/ethereum/node/tests/it/builder.rs b/crates/ethereum/node/tests/it/builder.rs index 48f1e0da2f..103d766cc6 100644 --- a/crates/ethereum/node/tests/it/builder.rs +++ b/crates/ethereum/node/tests/it/builder.rs @@ -11,7 +11,7 @@ use reth_node_builder::{EngineNodeLauncher, FullNodeComponents, NodeBuilder, Nod use reth_node_ethereum::node::{EthereumAddOns, EthereumNode}; use reth_provider::providers::BlockchainProvider; use reth_rpc_builder::Identity; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; #[test] fn test_basic_setup() { @@ -46,13 +46,13 @@ fn test_basic_setup() { #[tokio::test] async fn test_eth_launcher() { - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let config = NodeConfig::test(); let db = create_test_rw_db(); let _builder = NodeBuilder::new(config) .with_database(db) - .with_launch_context(tasks.executor()) + .with_launch_context(runtime.clone()) .with_types_and_provider::>>, >>() @@ -64,7 +64,7 @@ async fn test_eth_launcher() { }) .launch_with_fn(|builder| { let launcher = EngineNodeLauncher::new( - tasks.executor(), + runtime.clone(), builder.config().datadir(), Default::default(), ); @@ -81,13 +81,13 @@ fn test_eth_launcher_with_tokio_runtime() { let custom_rt = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime"); main_rt.block_on(async { - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let config = NodeConfig::test(); let db = create_test_rw_db(); let _builder = NodeBuilder::new(config) .with_database(db) - .with_launch_context(tasks.executor()) + .with_launch_context(runtime.clone()) .with_types_and_provider::>>, >>() @@ -101,7 +101,7 @@ fn test_eth_launcher_with_tokio_runtime() { }) .launch_with_fn(|builder| { let launcher = EngineNodeLauncher::new( - tasks.executor(), + runtime.clone(), builder.config().datadir(), Default::default(), ); diff --git a/crates/ethereum/node/tests/it/testing.rs b/crates/ethereum/node/tests/it/testing.rs index eb25e7d013..97b13b1bf2 100644 --- a/crates/ethereum/node/tests/it/testing.rs +++ b/crates/ethereum/node/tests/it/testing.rs @@ -13,14 +13,14 @@ use reth_node_core::{ use reth_node_ethereum::{node::EthereumAddOns, EthereumNode}; use reth_rpc_api::TestingBuildBlockRequestV1; use reth_rpc_server_types::{RethRpcModule, RpcModuleSelection}; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use std::str::FromStr; use tempfile::tempdir; use tokio::sync::oneshot; #[tokio::test(flavor = "multi_thread")] async fn testing_rpc_build_block_works() -> eyre::Result<()> { - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let mut rpc_args = reth_node_core::args::RpcServerArgs::default().with_http(); rpc_args.http_api = Some(RpcModuleSelection::from_iter([RethRpcModule::Testing])); let tempdir = tempdir().expect("temp datadir"); @@ -41,7 +41,7 @@ async fn testing_rpc_build_block_works() -> eyre::Result<()> { let builder = NodeBuilder::new(config) .with_database(db) - .with_launch_context(tasks.executor()) + .with_launch_context(runtime) .with_types::() .with_components(EthereumNode::components()) .with_add_ons(EthereumAddOns::default()) diff --git a/crates/ethereum/reth/Cargo.toml b/crates/ethereum/reth/Cargo.toml index 4938f8687e..c96b28a7b0 100644 --- a/crates/ethereum/reth/Cargo.toml +++ b/crates/ethereum/reth/Cargo.toml @@ -100,6 +100,7 @@ test-utils = [ "reth-node-builder?/test-utils", "reth-trie-db?/test-utils", "reth-codecs?/test-utils", + "reth-tasks?/test-utils", ] full = [ diff --git a/crates/exex/test-utils/src/lib.rs b/crates/exex/test-utils/src/lib.rs index 6198935615..a1a9f20ef4 100644 --- a/crates/exex/test-utils/src/lib.rs +++ b/crates/exex/test-utils/src/lib.rs @@ -55,7 +55,7 @@ use reth_provider::{ providers::{BlockchainProvider, RocksDBProvider, StaticFileProvider}, BlockReader, EthStorage, ProviderFactory, }; -use reth_tasks::TaskManager; +use reth_tasks::Runtime; use reth_transaction_pool::test_utils::{testing_pool, TestPool}; use tempfile::TempDir; use thiserror::Error; @@ -175,8 +175,8 @@ pub struct TestExExHandle { pub events_rx: UnboundedReceiver, /// Channel for sending notifications to the Execution Extension pub notifications_tx: Sender, - /// Node task manager - pub tasks: TaskManager, + /// Node task runtime + pub runtime: Runtime, /// WAL temp directory handle _wal_directory: TempDir, } @@ -252,6 +252,7 @@ pub async fn test_exex_context_with_chain_spec( chain_spec.clone(), StaticFileProvider::read_write(static_dir.keep()).expect("static file provider"), RocksDBProvider::builder(rocksdb_dir.keep()).with_default_tables().build().unwrap(), + reth_tasks::Runtime::test(), )?; let genesis_hash = init_genesis(&provider_factory)?; @@ -265,9 +266,9 @@ pub async fn test_exex_context_with_chain_spec( ) .await?; let network = network_manager.handle().clone(); - let tasks = TaskManager::current(); - let task_executor = tasks.executor(); - tasks.executor().spawn(network_manager); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); + let task_executor = runtime.clone(); + runtime.spawn_task(network_manager); let (_, payload_builder_handle) = NoopPayloadBuilderService::::new(); @@ -320,7 +321,7 @@ pub async fn test_exex_context_with_chain_spec( provider_factory, events_rx, notifications_tx, - tasks, + runtime, _wal_directory: wal_directory, }, )) diff --git a/crates/net/downloaders/Cargo.toml b/crates/net/downloaders/Cargo.toml index 6ef34b858b..709c30fe51 100644 --- a/crates/net/downloaders/Cargo.toml +++ b/crates/net/downloaders/Cargo.toml @@ -83,4 +83,5 @@ test-utils = [ "reth-primitives-traits/test-utils", "dep:reth-ethereum-primitives", "reth-ethereum-primitives?/test-utils", + "reth-tasks/test-utils", ] diff --git a/crates/net/downloaders/src/bodies/task.rs b/crates/net/downloaders/src/bodies/task.rs index 7d9ceeb94e..763fac1812 100644 --- a/crates/net/downloaders/src/bodies/task.rs +++ b/crates/net/downloaders/src/bodies/task.rs @@ -86,7 +86,7 @@ impl TaskDownloader { downloader, }; - spawner.spawn(downloader.boxed()); + spawner.spawn_task(downloader.boxed()); Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader } } diff --git a/crates/net/downloaders/src/headers/task.rs b/crates/net/downloaders/src/headers/task.rs index 779ad7ab10..83a2dc76b5 100644 --- a/crates/net/downloaders/src/headers/task.rs +++ b/crates/net/downloaders/src/headers/task.rs @@ -78,7 +78,7 @@ impl TaskDownloader { updates: UnboundedReceiverStream::new(updates_rx), downloader, }; - spawner.spawn(downloader.boxed()); + spawner.spawn_task(downloader.boxed()); Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader } } diff --git a/crates/net/network/Cargo.toml b/crates/net/network/Cargo.toml index 62252155e3..9eb7718e5a 100644 --- a/crates/net/network/Cargo.toml +++ b/crates/net/network/Cargo.toml @@ -139,6 +139,7 @@ test-utils = [ "reth-ethereum-primitives/test-utils", "dep:reth-evm-ethereum", "reth-evm-ethereum?/test-utils", + "reth-tasks/test-utils", ] [[bench]] diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 5af69d77a5..c0a75c647b 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -229,7 +229,7 @@ impl SessionManager { where F: Future + Send + 'static, { - self.executor.spawn(f.boxed()); + self.executor.spawn_task(f.boxed()); } /// Invoked on a received status update. diff --git a/crates/node/builder/Cargo.toml b/crates/node/builder/Cargo.toml index 82753d943f..3171713a93 100644 --- a/crates/node/builder/Cargo.toml +++ b/crates/node/builder/Cargo.toml @@ -50,7 +50,7 @@ reth-rpc-eth-types.workspace = true reth-rpc-layer.workspace = true reth-stages.workspace = true reth-static-file.workspace = true -reth-tasks.workspace = true +reth-tasks = { workspace = true, features = ["rayon"] } reth-tokio-util.workspace = true reth-tracing.workspace = true reth-transaction-pool.workspace = true @@ -120,6 +120,7 @@ test-utils = [ "reth-evm-ethereum/test-utils", "reth-node-ethereum/test-utils", "reth-primitives-traits/test-utils", + "reth-tasks/test-utils", ] op = [ "reth-db/op", diff --git a/crates/node/builder/src/builder/mod.rs b/crates/node/builder/src/builder/mod.rs index b2adea4f0f..09f08ad82b 100644 --- a/crates/node/builder/src/builder/mod.rs +++ b/crates/node/builder/src/builder/mod.rs @@ -903,8 +903,8 @@ impl BuilderContext { .request_handler(self.provider().clone()) .split_with_handle(); - self.executor.spawn_critical_blocking("p2p txpool", Box::pin(txpool)); - self.executor.spawn_critical_blocking("p2p eth request handler", Box::pin(eth)); + self.executor.spawn_critical_blocking_task("p2p txpool", Box::pin(txpool)); + self.executor.spawn_critical_blocking_task("p2p eth request handler", Box::pin(eth)); let default_peers_path = self.config().datadir().known_peers(); let known_peers_file = self.config().network.persistent_peers_file(default_peers_path); diff --git a/crates/node/builder/src/builder/states.rs b/crates/node/builder/src/builder/states.rs index 21bd897a40..c1693c503b 100644 --- a/crates/node/builder/src/builder/states.rs +++ b/crates/node/builder/src/builder/states.rs @@ -324,7 +324,7 @@ mod test { use reth_node_ethereum::EthereumNode; use reth_payload_builder::PayloadBuilderHandle; use reth_provider::noop::NoopProvider; - use reth_tasks::TaskManager; + use reth_tasks::Runtime; use reth_transaction_pool::noop::NoopTransactionPool; #[test] @@ -345,9 +345,7 @@ mod test { let task_executor = { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle); - manager.executor() + Runtime::with_existing_handle(runtime.handle().clone()).unwrap() }; let node = NodeAdapter { components, task_executor, provider: NoopProvider::default() }; diff --git a/crates/node/builder/src/components/payload.rs b/crates/node/builder/src/components/payload.rs index b587889e86..0425cbefd2 100644 --- a/crates/node/builder/src/components/payload.rs +++ b/crates/node/builder/src/components/payload.rs @@ -107,7 +107,8 @@ where let (payload_service, payload_service_handle) = PayloadBuilderService::new(payload_generator, ctx.provider().canonical_state_stream()); - ctx.task_executor().spawn_critical("payload builder service", Box::pin(payload_service)); + ctx.task_executor() + .spawn_critical_task("payload builder service", Box::pin(payload_service)); Ok(payload_service_handle) } @@ -133,7 +134,7 @@ where ) -> eyre::Result::Payload>> { let (tx, mut rx) = mpsc::unbounded_channel(); - ctx.task_executor().spawn_critical("payload builder", async move { + ctx.task_executor().spawn_critical_task("payload builder", async move { #[allow(clippy::collection_is_never_read)] let mut subscriptions = Vec::new(); diff --git a/crates/node/builder/src/components/pool.rs b/crates/node/builder/src/components/pool.rs index 0816869d52..243bda0738 100644 --- a/crates/node/builder/src/components/pool.rs +++ b/crates/node/builder/src/components/pool.rs @@ -256,7 +256,7 @@ where let chain_events = ctx.provider().canonical_state_stream(); let client = ctx.provider().clone(); - ctx.task_executor().spawn_critical( + ctx.task_executor().spawn_critical_task( "txpool maintenance task", reth_transaction_pool::maintain::maintain_transaction_pool_future( client, diff --git a/crates/node/builder/src/launch/common.rs b/crates/node/builder/src/launch/common.rs index e385f4a1df..dd04d267c8 100644 --- a/crates/node/builder/src/launch/common.rs +++ b/crates/node/builder/src/launch/common.rs @@ -215,9 +215,7 @@ impl LaunchContext { /// Configure global settings this includes: /// /// - Raising the file descriptor limit - /// - Configuring the global rayon thread pool with available parallelism. Honoring - /// engine.reserved-cpu-cores to reserve given number of cores for O while using at least 1 - /// core for the rayon thread pool + /// - Configuring the global rayon thread pool for implicit `par_iter` usage pub fn configure_globals(&self, reserved_cpu_cores: usize) { // Raise the fd limit of the process. // Does not do anything on windows. @@ -229,9 +227,7 @@ impl LaunchContext { Err(err) => warn!(%err, "Failed to raise file descriptor limit"), } - // Reserving the given number of CPU cores for the rest of OS. - // Users can reserve more cores by setting engine.reserved-cpu-cores - // Note: The global rayon thread pool will use at least one core. + // Configure the implicit global rayon pool for `par_iter` usage. let num_threads = available_parallelism() .map_or(0, |num| num.get().saturating_sub(reserved_cpu_cores).max(1)); if let Err(err) = ThreadPoolBuilder::new() @@ -503,6 +499,7 @@ where self.chain_spec(), static_file_provider, rocksdb_provider, + self.task_executor().clone(), )? .with_prune_modes(self.prune_modes()) .with_changeset_cache(changeset_cache); @@ -558,7 +555,7 @@ where let (tx, rx) = oneshot::channel(); // Pipeline should be run as blocking and panic if it fails. - self.task_executor().spawn_critical_blocking( + self.task_executor().spawn_critical_blocking_task( "pipeline task", Box::pin(async move { let (_, result) = pipeline.run_as_fut(Some(unwind_target)).await; @@ -678,7 +675,8 @@ where debug!(target: "reth::cli", "Spawning stages metrics listener task"); let sync_metrics_listener = reth_stages::MetricsListener::new(metrics_receiver); - self.task_executor().spawn_critical("stages metrics listener task", sync_metrics_listener); + self.task_executor() + .spawn_critical_task("stages metrics listener task", sync_metrics_listener); LaunchContextWith { inner: self.inner, @@ -1105,7 +1103,7 @@ where // If engine events are provided, spawn listener for new payload reporting let ethstats_for_events = ethstats.clone(); let task_executor = self.task_executor().clone(); - task_executor.spawn(Box::pin(async move { + task_executor.spawn_task(Box::pin(async move { while let Some(event) = engine_events.next().await { use reth_engine_primitives::ConsensusEngineEvent; match event { @@ -1131,7 +1129,7 @@ where })); // Spawn main ethstats service - task_executor.spawn(Box::pin(async move { ethstats.run().await })); + task_executor.spawn_task(Box::pin(async move { ethstats.run().await })); Ok(()) } diff --git a/crates/node/builder/src/launch/debug.rs b/crates/node/builder/src/launch/debug.rs index 896f56fb61..8e12567e48 100644 --- a/crates/node/builder/src/launch/debug.rs +++ b/crates/node/builder/src/launch/debug.rs @@ -213,7 +213,7 @@ where handle .node .task_executor - .spawn_critical("custom debug block provider consensus client", async move { + .spawn_critical_task("custom debug block provider consensus client", async move { rpc_consensus_client.run().await }); } else if let Some(url) = config.debug.rpc_consensus_url.clone() { @@ -234,7 +234,7 @@ where Arc::new(block_provider), ); - handle.node.task_executor.spawn_critical("rpc-ws consensus client", async move { + handle.node.task_executor.spawn_critical_task("rpc-ws consensus client", async move { rpc_consensus_client.run().await }); } else if let Some(maybe_custom_etherscan_url) = config.debug.etherscan.clone() { @@ -262,9 +262,12 @@ where handle.node.add_ons_handle.beacon_engine_handle.clone(), Arc::new(block_provider), ); - handle.node.task_executor.spawn_critical("etherscan consensus client", async move { - rpc_consensus_client.run().await - }); + handle + .node + .task_executor + .spawn_critical_task("etherscan consensus client", async move { + rpc_consensus_client.run().await + }); } if config.dev.dev { @@ -289,7 +292,7 @@ where }; let dev_mining_mode = handle.node.config.dev_mining_mode(pool); - handle.node.task_executor.spawn_critical("local engine", async move { + handle.node.task_executor.spawn_critical_task("local engine", async move { LocalMiner::new( blockchain_db, builder, diff --git a/crates/node/builder/src/launch/engine.rs b/crates/node/builder/src/launch/engine.rs index 015a28c8ba..ce13198a34 100644 --- a/crates/node/builder/src/launch/engine.rs +++ b/crates/node/builder/src/launch/engine.rs @@ -248,7 +248,7 @@ impl EngineNodeLauncher { static_file_producer_events.map(Into::into), ); - ctx.task_executor().spawn_critical( + ctx.task_executor().spawn_critical_task( "events task", Box::pin(node::handle_events( Some(Box::new(ctx.components().network().clone())), @@ -371,7 +371,7 @@ impl EngineNodeLauncher { let _ = exit.send(res); }; - ctx.task_executor().spawn_critical("consensus engine", Box::pin(consensus_engine)); + ctx.task_executor().spawn_critical_task("consensus engine", Box::pin(consensus_engine)); let engine_events_for_ethstats = engine_events.new_listener(); diff --git a/crates/node/builder/src/launch/exex.rs b/crates/node/builder/src/launch/exex.rs index 0621ee6c6e..ad0dfdcd41 100644 --- a/crates/node/builder/src/launch/exex.rs +++ b/crates/node/builder/src/launch/exex.rs @@ -111,7 +111,7 @@ impl ExExLauncher { let exex = exex.launch(context).instrument(span.clone()).await?; // spawn it as a crit task - executor.spawn_critical( + executor.spawn_critical_task( "exex", async move { info!(target: "reth::cli", "ExEx started"); @@ -140,14 +140,14 @@ impl ExExLauncher { ) .with_wal_blocks_warning(wal_blocks_warning); let exex_manager_handle = exex_manager.handle(); - components.task_executor().spawn_critical("exex manager", async move { + components.task_executor().spawn_critical_task("exex manager", async move { exex_manager.await.expect("exex manager crashed"); }); // send notifications from the blockchain tree to exex manager let mut canon_state_notifications = components.provider().subscribe_to_canonical_state(); let mut handle = exex_manager_handle.clone(); - components.task_executor().spawn_critical( + components.task_executor().spawn_critical_task( "exex manager blockchain tree notifications", async move { while let Ok(notification) = canon_state_notifications.recv().await { diff --git a/crates/node/builder/src/rpc.rs b/crates/node/builder/src/rpc.rs index 827b823b59..f7a9e0fc88 100644 --- a/crates/node/builder/src/rpc.rs +++ b/crates/node/builder/src/rpc.rs @@ -992,7 +992,7 @@ where let new_canonical_blocks = node.provider().canonical_state_stream(); let c = cache.clone(); - node.task_executor().spawn_critical( + node.task_executor().spawn_critical_task( "cache canonical blocks task", Box::pin(async move { cache_new_blocks_task(c, new_canonical_blocks).await; @@ -1352,6 +1352,7 @@ where tree_config, invalid_block_hook, changeset_cache, + ctx.node.task_executor().clone(), )) } } diff --git a/crates/node/metrics/src/server.rs b/crates/node/metrics/src/server.rs index 9ef68cf303..df64d4cd7c 100644 --- a/crates/node/metrics/src/server.rs +++ b/crates/node/metrics/src/server.rs @@ -419,7 +419,7 @@ fn handle_pprof_heap(_pprof_dump_dir: &PathBuf) -> Response> { mod tests { use super::*; use reqwest::Client; - use reth_tasks::TaskManager; + use reth_tasks::Runtime; use socket2::{Domain, Socket, Type}; use std::net::{SocketAddr, TcpListener}; @@ -433,7 +433,7 @@ mod tests { listener.local_addr().unwrap() } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_metrics_endpoint() { let chain_spec_info = ChainSpecInfo { name: "test".to_string() }; let version_info = VersionInfo { @@ -445,8 +445,7 @@ mod tests { build_profile: "test", }; - let tasks = TaskManager::current(); - let executor = tasks.executor(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let hooks = Hooks::builder().build(); @@ -455,7 +454,7 @@ mod tests { listen_addr, version_info, chain_spec_info, - executor, + runtime.clone(), hooks, std::env::temp_dir(), ); @@ -471,5 +470,8 @@ mod tests { let body = response.text().await.unwrap(); assert!(body.contains("reth_process_cpu_seconds_total")); assert!(body.contains("reth_process_start_time_seconds")); + + // Make sure the runtime is dropped after the test runs. + drop(runtime); } } diff --git a/crates/payload/basic/src/lib.rs b/crates/payload/basic/src/lib.rs index 287ab53a3a..83b644ea39 100644 --- a/crates/payload/basic/src/lib.rs +++ b/crates/payload/basic/src/lib.rs @@ -349,7 +349,7 @@ where self.metrics.inc_initiated_payload_builds(); let cached_reads = self.cached_reads.take().unwrap_or_default(); let builder = self.builder.clone(); - self.executor.spawn_blocking(Box::pin(async move { + self.executor.spawn_blocking_task(Box::pin(async move { // acquire the permit for executing the task let _permit = guard.acquire().await; let args = @@ -495,7 +495,7 @@ where let (tx, rx) = oneshot::channel(); let config = self.config.clone(); let builder = self.builder.clone(); - self.executor.spawn_blocking(Box::pin(async move { + self.executor.spawn_blocking_task(Box::pin(async move { let res = builder.build_empty_payload(config); let _ = tx.send(res); })); @@ -506,7 +506,7 @@ where debug!(target: "payload_builder", id=%self.config.payload_id(), "racing fallback payload"); // race the in progress job with this job let (tx, rx) = oneshot::channel(); - self.executor.spawn_blocking(Box::pin(async move { + self.executor.spawn_blocking_task(Box::pin(async move { let _ = tx.send(job()); })); empty_payload = Some(rx); diff --git a/crates/ress/provider/src/lib.rs b/crates/ress/provider/src/lib.rs index 3a0b83673a..6faa395698 100644 --- a/crates/ress/provider/src/lib.rs +++ b/crates/ress/provider/src/lib.rs @@ -224,7 +224,7 @@ where let _permit = self.witness_semaphore.acquire().await.map_err(ProviderError::other)?; let this = self.clone(); let (tx, rx) = oneshot::channel(); - self.task_spawner.spawn_blocking(Box::pin(async move { + self.task_spawner.spawn_blocking_task(Box::pin(async move { let result = this.generate_witness(block_hash); let _ = tx.send(result); })); diff --git a/crates/rpc/rpc-engine-api/src/engine_api.rs b/crates/rpc/rpc-engine-api/src/engine_api.rs index 1abf4be911..0aec613660 100644 --- a/crates/rpc/rpc-engine-api/src/engine_api.rs +++ b/crates/rpc/rpc-engine-api/src/engine_api.rs @@ -532,7 +532,7 @@ where let (tx, rx) = oneshot::channel(); let inner = self.inner.clone(); - self.inner.task_spawner.spawn_blocking(Box::pin(async move { + self.inner.task_spawner.spawn_blocking_task(Box::pin(async move { if count > MAX_PAYLOAD_BODIES_LIMIT { tx.send(Err(EngineApiError::PayloadRequestTooLarge { len: count })).ok(); return; @@ -659,7 +659,7 @@ where let (tx, rx) = oneshot::channel(); let inner = self.inner.clone(); - self.inner.task_spawner.spawn_blocking(Box::pin(async move { + self.inner.task_spawner.spawn_blocking_task(Box::pin(async move { let mut result = Vec::with_capacity(hashes.len()); for hash in hashes { let block_result = inner.provider.block(BlockHashOrNumber::Hash(hash)); diff --git a/crates/rpc/rpc-eth-api/src/helpers/blocking_task.rs b/crates/rpc/rpc-eth-api/src/helpers/blocking_task.rs index c174cd9bde..3908d407c7 100644 --- a/crates/rpc/rpc-eth-api/src/helpers/blocking_task.rs +++ b/crates/rpc/rpc-eth-api/src/helpers/blocking_task.rs @@ -163,7 +163,7 @@ pub trait SpawnBlocking: EthApiTypes + Clone + Send + Sync + 'static { { let (tx, rx) = oneshot::channel(); let this = self.clone(); - self.io_task_spawner().spawn_blocking(Box::pin(async move { + self.io_task_spawner().spawn_blocking_task(Box::pin(async move { let res = f(this); let _ = tx.send(res); })); @@ -186,7 +186,7 @@ pub trait SpawnBlocking: EthApiTypes + Clone + Send + Sync + 'static { { let (tx, rx) = oneshot::channel(); let this = self.clone(); - self.io_task_spawner().spawn_blocking(Box::pin(async move { + self.io_task_spawner().spawn_blocking_task(Box::pin(async move { let res = f(this).await; let _ = tx.send(res); })); @@ -197,8 +197,8 @@ pub trait SpawnBlocking: EthApiTypes + Clone + Send + Sync + 'static { /// Executes a blocking task on the tracing pool. /// /// Note: This is expected for futures that are predominantly CPU bound, as it uses `rayon` - /// under the hood, for blocking IO futures use [`spawn_blocking`](Self::spawn_blocking_io). See - /// . + /// under the hood, for blocking IO futures use + /// [`spawn_blocking_task`](Self::spawn_blocking_io). See . fn spawn_tracing(&self, f: F) -> impl Future> + Send where F: FnOnce(Self) -> Result + Send + 'static, diff --git a/crates/rpc/rpc-eth-types/src/cache/mod.rs b/crates/rpc/rpc-eth-types/src/cache/mod.rs index 7ae10da83a..d070d1e01e 100644 --- a/crates/rpc/rpc-eth-types/src/cache/mod.rs +++ b/crates/rpc/rpc-eth-types/src/cache/mod.rs @@ -145,7 +145,7 @@ impl EthStateCache { max_concurrent_db_requests, max_cached_tx_hashes, ); - executor.spawn_critical("eth state cache", Box::pin(service)); + executor.spawn_critical_task("eth state cache", Box::pin(service)); this } @@ -494,19 +494,21 @@ where let rate_limiter = this.rate_limiter.clone(); let mut action_sender = ActionSender::new(CacheKind::Block, block_hash, action_tx); - this.action_task_spawner.spawn_blocking(Box::pin(async move { - // Acquire permit - let _permit = rate_limiter.acquire().await; - // Only look in the database to prevent situations where we - // looking up the tree is blocking - let block_sender = provider - .sealed_block_with_senders( - BlockHashOrNumber::Hash(block_hash), - TransactionVariant::WithHash, - ) - .map(|maybe_block| maybe_block.map(Arc::new)); - action_sender.send_block(block_sender); - })); + this.action_task_spawner.spawn_blocking_task(Box::pin( + async move { + // Acquire permit + let _permit = rate_limiter.acquire().await; + // Only look in the database to prevent situations where we + // looking up the tree is blocking + let block_sender = provider + .sealed_block_with_senders( + BlockHashOrNumber::Hash(block_hash), + TransactionVariant::WithHash, + ) + .map(|maybe_block| maybe_block.map(Arc::new)); + action_sender.send_block(block_sender); + }, + )); } } CacheAction::GetReceipts { block_hash, response_tx } => { @@ -523,15 +525,17 @@ where let rate_limiter = this.rate_limiter.clone(); let mut action_sender = ActionSender::new(CacheKind::Receipt, block_hash, action_tx); - this.action_task_spawner.spawn_blocking(Box::pin(async move { - // Acquire permit - let _permit = rate_limiter.acquire().await; - let res = provider - .receipts_by_block(block_hash.into()) - .map(|maybe_receipts| maybe_receipts.map(Arc::new)); + this.action_task_spawner.spawn_blocking_task(Box::pin( + async move { + // Acquire permit + let _permit = rate_limiter.acquire().await; + let res = provider + .receipts_by_block(block_hash.into()) + .map(|maybe_receipts| maybe_receipts.map(Arc::new)); - action_sender.send_receipts(res); - })); + action_sender.send_receipts(res); + }, + )); } } CacheAction::GetHeader { block_hash, response_tx } => { @@ -555,16 +559,19 @@ where let rate_limiter = this.rate_limiter.clone(); let mut action_sender = ActionSender::new(CacheKind::Header, block_hash, action_tx); - this.action_task_spawner.spawn_blocking(Box::pin(async move { - // Acquire permit - let _permit = rate_limiter.acquire().await; - let header = provider.header(block_hash).and_then(|header| { - header.ok_or_else(|| { - ProviderError::HeaderNotFound(block_hash.into()) - }) - }); - action_sender.send_header(header); - })); + this.action_task_spawner.spawn_blocking_task(Box::pin( + async move { + // Acquire permit + let _permit = rate_limiter.acquire().await; + let header = + provider.header(block_hash).and_then(|header| { + header.ok_or_else(|| { + ProviderError::HeaderNotFound(block_hash.into()) + }) + }); + action_sender.send_header(header); + }, + )); } } CacheAction::ReceiptsResult { block_hash, res } => { diff --git a/crates/rpc/rpc/src/debug.rs b/crates/rpc/rpc/src/debug.rs index 74231b4b14..fe0dbba420 100644 --- a/crates/rpc/rpc/src/debug.rs +++ b/crates/rpc/rpc/src/debug.rs @@ -70,7 +70,7 @@ where }); // Spawn a task caching bad blocks - executor.spawn(Box::pin(async move { + executor.spawn_task(Box::pin(async move { while let Some(event) = stream.next().await { if let ConsensusEngineEvent::InvalidBlock(block) = event && let Ok(recovered) = diff --git a/crates/rpc/rpc/src/eth/builder.rs b/crates/rpc/rpc/src/eth/builder.rs index a968dacb7f..c283a76c91 100644 --- a/crates/rpc/rpc/src/eth/builder.rs +++ b/crates/rpc/rpc/src/eth/builder.rs @@ -525,7 +525,7 @@ where let new_canonical_blocks = provider.canonical_state_stream(); let fhc = fee_history_cache.clone(); let cache = eth_cache.clone(); - task_spawner.spawn_critical( + task_spawner.spawn_critical_task( "cache canonical blocks for fee history task", Box::pin(async move { fee_history_cache_new_blocks_task(fhc, new_canonical_blocks, provider, cache).await; diff --git a/crates/rpc/rpc/src/eth/core.rs b/crates/rpc/rpc/src/eth/core.rs index 3e05c38814..640d1aa3f5 100644 --- a/crates/rpc/rpc/src/eth/core.rs +++ b/crates/rpc/rpc/src/eth/core.rs @@ -385,7 +385,7 @@ where // Create tx pool insertion batcher let (processor, tx_batch_sender) = BatchTxProcessor::new(components.pool().clone(), max_batch_size); - task_spawner.spawn_critical("tx-batcher", Box::pin(processor)); + task_spawner.spawn_critical_task("tx-batcher", Box::pin(processor)); Self { components, diff --git a/crates/rpc/rpc/src/eth/filter.rs b/crates/rpc/rpc/src/eth/filter.rs index 8c4010d00a..dcfa605ca3 100644 --- a/crates/rpc/rpc/src/eth/filter.rs +++ b/crates/rpc/rpc/src/eth/filter.rs @@ -152,7 +152,7 @@ where let eth_filter = Self { inner: Arc::new(inner) }; let this = eth_filter.clone(); - eth_filter.inner.task_spawner.spawn_critical( + eth_filter.inner.task_spawner.spawn_critical_task( "eth-filters_stale-filters-clean", Box::pin(async move { this.watch_and_clear_stale_filters().await; @@ -634,7 +634,7 @@ where let (tx, rx) = oneshot::channel(); let this = self.clone(); - self.task_spawner.spawn_blocking(Box::pin(async move { + self.task_spawner.spawn_blocking_task(Box::pin(async move { let res = this.get_logs_in_block_range_inner(&filter, from_block, to_block, limits).await; let _ = tx.send(res); diff --git a/crates/rpc/rpc/src/eth/pubsub.rs b/crates/rpc/rpc/src/eth/pubsub.rs index 7e702cbc79..1f7374fa5c 100644 --- a/crates/rpc/rpc/src/eth/pubsub.rs +++ b/crates/rpc/rpc/src/eth/pubsub.rs @@ -214,7 +214,7 @@ where ) -> jsonrpsee::core::SubscriptionResult { let sink = pending.accept().await?; let pubsub = self.clone(); - self.inner.subscription_task_spawner.spawn(Box::pin(async move { + self.inner.subscription_task_spawner.spawn_task(Box::pin(async move { let _ = pubsub.handle_accepted(sink, kind, params).await; })); diff --git a/crates/rpc/rpc/src/reth.rs b/crates/rpc/rpc/src/reth.rs index 6f4e79d291..4963c48689 100644 --- a/crates/rpc/rpc/src/reth.rs +++ b/crates/rpc/rpc/src/reth.rs @@ -55,7 +55,7 @@ where let (tx, rx) = oneshot::channel(); let this = self.clone(); let f = c(this); - self.inner.task_spawner.spawn_blocking(Box::pin(async move { + self.inner.task_spawner.spawn_blocking_task(Box::pin(async move { let res = f.await; let _ = tx.send(res); })); @@ -116,7 +116,7 @@ where ) -> jsonrpsee::core::SubscriptionResult { let sink = pending.accept().await?; let stream = self.provider().canonical_state_stream(); - self.inner.task_spawner.spawn(Box::pin(pipe_from_stream(sink, stream))); + self.inner.task_spawner.spawn_task(Box::pin(pipe_from_stream(sink, stream))); Ok(()) } @@ -128,7 +128,7 @@ where ) -> jsonrpsee::core::SubscriptionResult { let sink = pending.accept().await?; let stream = self.provider().persisted_block_stream(); - self.inner.task_spawner.spawn(Box::pin(pipe_from_stream(sink, stream))); + self.inner.task_spawner.spawn_task(Box::pin(pipe_from_stream(sink, stream))); Ok(()) } @@ -141,7 +141,7 @@ where let sink = pending.accept().await?; let canon_stream = self.provider().canonical_state_stream(); let finalized_stream = self.provider().finalized_block_stream(); - self.inner.task_spawner.spawn(Box::pin(finalized_chain_notifications( + self.inner.task_spawner.spawn_task(Box::pin(finalized_chain_notifications( sink, canon_stream, finalized_stream, diff --git a/crates/rpc/rpc/src/validation.rs b/crates/rpc/rpc/src/validation.rs index a11124a84a..71c1a730ac 100644 --- a/crates/rpc/rpc/src/validation.rs +++ b/crates/rpc/rpc/src/validation.rs @@ -511,7 +511,7 @@ where let this = self.clone(); let (tx, rx) = oneshot::channel(); - self.task_spawner.spawn_blocking(Box::pin(async move { + self.task_spawner.spawn_blocking_task(Box::pin(async move { let result = Self::validate_builder_submission_v3(&this, request) .await .map_err(ErrorObject::from); @@ -529,7 +529,7 @@ where let this = self.clone(); let (tx, rx) = oneshot::channel(); - self.task_spawner.spawn_blocking(Box::pin(async move { + self.task_spawner.spawn_blocking_task(Box::pin(async move { let result = Self::validate_builder_submission_v4(&this, request) .await .map_err(ErrorObject::from); @@ -547,7 +547,7 @@ where let this = self.clone(); let (tx, rx) = oneshot::channel(); - self.task_spawner.spawn_blocking(Box::pin(async move { + self.task_spawner.spawn_blocking_task(Box::pin(async move { let result = Self::validate_builder_submission_v5(&this, request) .await .map_err(ErrorObject::from); diff --git a/crates/stages/api/Cargo.toml b/crates/stages/api/Cargo.toml index 6cdf45f790..59e0069c6f 100644 --- a/crates/stages/api/Cargo.toml +++ b/crates/stages/api/Cargo.toml @@ -47,6 +47,7 @@ reth-chainspec.workspace = true reth-db = { workspace = true, features = ["test-utils"] } reth-db-api.workspace = true reth-provider = { workspace = true, features = ["test-utils"] } +reth-tasks.workspace = true tokio = { workspace = true, features = ["sync", "rt-multi-thread"] } tokio-stream.workspace = true reth-testing-utils.workspace = true @@ -61,4 +62,5 @@ test-utils = [ "reth-primitives-traits/test-utils", "reth-provider/test-utils", "reth-stages-types/test-utils", + "reth-tasks/test-utils", ] diff --git a/crates/stages/api/src/stage.rs b/crates/stages/api/src/stage.rs index 32b81d3f70..2dd7cb4c37 100644 --- a/crates/stages/api/src/stage.rs +++ b/crates/stages/api/src/stage.rs @@ -348,6 +348,7 @@ mod tests { .build() .unwrap(), RocksDBProvider::builder(create_test_rocksdb_dir().0.keep()).build().unwrap(), + reth_tasks::Runtime::test(), ) .unwrap(); diff --git a/crates/stages/stages/Cargo.toml b/crates/stages/stages/Cargo.toml index 3c61fb2556..b6d80c9527 100644 --- a/crates/stages/stages/Cargo.toml +++ b/crates/stages/stages/Cargo.toml @@ -121,6 +121,7 @@ test-utils = [ "dep:reth-ethereum-primitives", "reth-ethereum-primitives?/test-utils", "reth-evm-ethereum/test-utils", + "reth-tasks/test-utils", ] rocksdb = ["reth-provider/rocksdb", "reth-db-common/rocksdb"] edge = ["rocksdb"] diff --git a/crates/stages/stages/src/test_utils/test_db.rs b/crates/stages/stages/src/test_utils/test_db.rs index 5f00a498c4..18e36f6ffa 100644 --- a/crates/stages/stages/src/test_utils/test_db.rs +++ b/crates/stages/stages/src/test_utils/test_db.rs @@ -54,6 +54,7 @@ impl Default for TestStageDB { MAINNET.clone(), StaticFileProvider::read_write(static_dir_path).unwrap(), RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(), + reth_tasks::Runtime::test(), ) .expect("failed to create test provider factory"), } @@ -73,6 +74,7 @@ impl TestStageDB { MAINNET.clone(), StaticFileProvider::read_write(static_dir_path).unwrap(), RocksDBProvider::builder(rocksdb_dir_path).with_default_tables().build().unwrap(), + reth_tasks::Runtime::test(), ) .expect("failed to create test provider factory"), } diff --git a/crates/storage/db-common/Cargo.toml b/crates/storage/db-common/Cargo.toml index 13ee65d641..eb65423c6e 100644 --- a/crates/storage/db-common/Cargo.toml +++ b/crates/storage/db-common/Cargo.toml @@ -44,6 +44,7 @@ tracing.workspace = true [dev-dependencies] reth-db = { workspace = true, features = ["mdbx"] } reth-provider = { workspace = true, features = ["test-utils"] } +reth-tasks.workspace = true [features] rocksdb = ["reth-db-api/rocksdb", "reth-provider/rocksdb"] diff --git a/crates/storage/db-common/src/init.rs b/crates/storage/db-common/src/init.rs index 4b0027de21..fa28eb35ed 100644 --- a/crates/storage/db-common/src/init.rs +++ b/crates/storage/db-common/src/init.rs @@ -900,6 +900,7 @@ mod tests { MAINNET.clone(), static_file_provider, rocksdb_provider, + reth_tasks::Runtime::test(), ) .unwrap(), ); diff --git a/crates/storage/provider/Cargo.toml b/crates/storage/provider/Cargo.toml index bc0a892660..f5e9663125 100644 --- a/crates/storage/provider/Cargo.toml +++ b/crates/storage/provider/Cargo.toml @@ -24,7 +24,7 @@ reth-db = { workspace = true, features = ["mdbx"] } reth-db-api.workspace = true reth-prune-types.workspace = true reth-stages-types.workspace = true -reth-tasks.workspace = true +reth-tasks = { workspace = true, features = ["rayon"] } reth-trie = { workspace = true, features = ["metrics"] } reth-trie-db = { workspace = true, features = ["metrics"] } reth-nippy-jar.workspace = true @@ -105,4 +105,5 @@ test-utils = [ "reth-stages-types/test-utils", "revm-state", "tokio", + "reth-tasks/test-utils", ] diff --git a/crates/storage/provider/src/lib.rs b/crates/storage/provider/src/lib.rs index 3b9b57e475..6f3f2295f8 100644 --- a/crates/storage/provider/src/lib.rs +++ b/crates/storage/provider/src/lib.rs @@ -28,9 +28,6 @@ pub use providers::{ pub mod changeset_walker; pub mod changesets_utils; -mod storage_threadpool; -use storage_threadpool::STORAGE_POOL; - #[cfg(any(test, feature = "test-utils"))] /// Common test helpers for mocking the Provider. pub mod test_utils; diff --git a/crates/storage/provider/src/providers/database/builder.rs b/crates/storage/provider/src/providers/database/builder.rs index 9403754fea..e170e8bcb2 100644 --- a/crates/storage/provider/src/providers/database/builder.rs +++ b/crates/storage/provider/src/providers/database/builder.rs @@ -54,9 +54,11 @@ impl ProviderFactoryBuilder { /// use reth_chainspec::MAINNET; /// use reth_provider::providers::{NodeTypesForProvider, ProviderFactoryBuilder}; /// - /// fn demo>() { + /// fn demo>( + /// runtime: reth_tasks::Runtime, + /// ) { /// let provider_factory = ProviderFactoryBuilder::::default() - /// .open_read_only(MAINNET.clone(), "datadir") + /// .open_read_only(MAINNET.clone(), "datadir", runtime) /// .unwrap(); /// } /// ``` @@ -69,9 +71,15 @@ impl ProviderFactoryBuilder { /// use reth_chainspec::MAINNET; /// use reth_provider::providers::{NodeTypesForProvider, ProviderFactoryBuilder, ReadOnlyConfig}; /// - /// fn demo>() { + /// fn demo>( + /// runtime: reth_tasks::Runtime, + /// ) { /// let provider_factory = ProviderFactoryBuilder::::default() - /// .open_read_only(MAINNET.clone(), ReadOnlyConfig::from_datadir("datadir").no_watch()) + /// .open_read_only( + /// MAINNET.clone(), + /// ReadOnlyConfig::from_datadir("datadir").no_watch(), + /// runtime, + /// ) /// .unwrap(); /// } /// ``` @@ -87,11 +95,14 @@ impl ProviderFactoryBuilder { /// use reth_chainspec::MAINNET; /// use reth_provider::providers::{NodeTypesForProvider, ProviderFactoryBuilder, ReadOnlyConfig}; /// - /// fn demo>() { + /// fn demo>( + /// runtime: reth_tasks::Runtime, + /// ) { /// let provider_factory = ProviderFactoryBuilder::::default() /// .open_read_only( /// MAINNET.clone(), /// ReadOnlyConfig::from_datadir("datadir").disable_long_read_transaction_safety(), + /// runtime, /// ) /// .unwrap(); /// } @@ -100,6 +111,7 @@ impl ProviderFactoryBuilder { self, chainspec: Arc, config: impl Into, + runtime: reth_tasks::Runtime, ) -> eyre::Result>> where N: NodeTypesForProvider, @@ -110,6 +122,7 @@ impl ProviderFactoryBuilder { .chainspec(chainspec) .static_file(StaticFileProvider::read_only(static_files_dir, watch_static_files)?) .rocksdb_provider(RocksDBProvider::builder(&rocksdb_dir).with_default_tables().build()?) + .runtime(runtime) .build_provider_factory() .map_err(Into::into) } @@ -361,6 +374,54 @@ impl TypesAnd4 { } impl TypesAnd4, StaticFileProvider, RocksDBProvider> +where + N: NodeTypesForProvider, + DB: Database + DatabaseMetrics + Clone + Unpin + 'static, +{ + /// Sets the task runtime for the provider factory. + #[allow(clippy::type_complexity)] + pub fn runtime( + self, + runtime: reth_tasks::Runtime, + ) -> TypesAnd5< + N, + DB, + Arc, + StaticFileProvider, + RocksDBProvider, + reth_tasks::Runtime, + > { + TypesAnd5::new(self.val_1, self.val_2, self.val_3, self.val_4, runtime) + } +} + +/// This is staging type that contains the configured types and _five_ values. +#[derive(Debug)] +pub struct TypesAnd5 { + _types: PhantomData, + val_1: Val1, + val_2: Val2, + val_3: Val3, + val_4: Val4, + val_5: Val5, +} + +impl TypesAnd5 { + /// Creates a new instance with the given types and five values. + pub fn new(val_1: Val1, val_2: Val2, val_3: Val3, val_4: Val4, val_5: Val5) -> Self { + Self { _types: Default::default(), val_1, val_2, val_3, val_4, val_5 } + } +} + +impl + TypesAnd5< + N, + DB, + Arc, + StaticFileProvider, + RocksDBProvider, + reth_tasks::Runtime, + > where N: NodeTypesForProvider, DB: Database + DatabaseMetrics + Clone + Unpin + 'static, @@ -369,7 +430,7 @@ where pub fn build_provider_factory( self, ) -> ProviderResult>> { - let Self { _types, val_1, val_2, val_3, val_4 } = self; - ProviderFactory::new(val_1, val_2, val_3, val_4) + let Self { _types, val_1, val_2, val_3, val_4, val_5 } = self; + ProviderFactory::new(val_1, val_2, val_3, val_4, val_5) } } diff --git a/crates/storage/provider/src/providers/database/mod.rs b/crates/storage/provider/src/providers/database/mod.rs index 94cf7e3579..5ff4ddc7aa 100644 --- a/crates/storage/provider/src/providers/database/mod.rs +++ b/crates/storage/provider/src/providers/database/mod.rs @@ -78,6 +78,8 @@ pub struct ProviderFactory { rocksdb_provider: RocksDBProvider, /// Changeset cache for trie unwinding changeset_cache: ChangesetCache, + /// Task runtime for spawning parallel I/O work. + runtime: reth_tasks::Runtime, } impl ProviderFactory> { @@ -100,6 +102,7 @@ impl ProviderFactory { chain_spec: Arc, static_file_provider: StaticFileProvider, rocksdb_provider: RocksDBProvider, + runtime: reth_tasks::Runtime, ) -> ProviderResult { // Load storage settings from database at init time. Creates a temporary provider // to read persisted settings, falling back to legacy defaults if none exist. @@ -115,6 +118,7 @@ impl ProviderFactory { Arc::new(RwLock::new(legacy_settings)), rocksdb_provider.clone(), ChangesetCache::new(), + runtime.clone(), ) .storage_settings()? .unwrap_or(legacy_settings); @@ -128,6 +132,7 @@ impl ProviderFactory { storage_settings: Arc::new(RwLock::new(storage_settings)), rocksdb_provider, changeset_cache: ChangesetCache::new(), + runtime, }) } @@ -142,8 +147,9 @@ impl ProviderFactory { chain_spec: Arc, static_file_provider: StaticFileProvider, rocksdb_provider: RocksDBProvider, + runtime: reth_tasks::Runtime, ) -> ProviderResult { - Self::new(db, chain_spec, static_file_provider, rocksdb_provider) + Self::new(db, chain_spec, static_file_provider, rocksdb_provider, runtime) .and_then(Self::assert_consistent) } } @@ -208,12 +214,14 @@ impl> ProviderFactory { args: DatabaseArguments, static_file_provider: StaticFileProvider, rocksdb_provider: RocksDBProvider, + runtime: reth_tasks::Runtime, ) -> RethResult { Self::new( init_db(path, args).map_err(RethError::msg)?, chain_spec, static_file_provider, rocksdb_provider, + runtime, ) .map_err(RethError::Provider) } @@ -237,6 +245,7 @@ impl ProviderFactory { self.storage_settings.clone(), self.rocksdb_provider.clone(), self.changeset_cache.clone(), + self.runtime.clone(), )) } @@ -255,6 +264,7 @@ impl ProviderFactory { self.storage_settings.clone(), self.rocksdb_provider.clone(), self.changeset_cache.clone(), + self.runtime.clone(), ))) } @@ -274,6 +284,7 @@ impl ProviderFactory { self.storage_settings.clone(), self.rocksdb_provider.clone(), self.changeset_cache.clone(), + self.runtime.clone(), )) } @@ -732,6 +743,7 @@ where storage_settings, rocksdb_provider, changeset_cache, + runtime, } = self; f.debug_struct("ProviderFactory") .field("db", &db) @@ -742,6 +754,7 @@ where .field("storage_settings", &*storage_settings.read()) .field("rocksdb_provider", &rocksdb_provider) .field("changeset_cache", &changeset_cache) + .field("runtime", &runtime) .finish() } } @@ -757,6 +770,7 @@ impl Clone for ProviderFactory { storage_settings: self.storage_settings.clone(), rocksdb_provider: self.rocksdb_provider.clone(), changeset_cache: self.changeset_cache.clone(), + runtime: self.runtime.clone(), } } } @@ -822,6 +836,7 @@ mod tests { DatabaseArguments::new(Default::default()), StaticFileProvider::read_write(static_dir_path).unwrap(), RocksDBProvider::builder(&rocksdb_path).build().unwrap(), + reth_tasks::Runtime::test(), ) .unwrap(); let provider = factory.provider().unwrap(); diff --git a/crates/storage/provider/src/providers/database/provider.rs b/crates/storage/provider/src/providers/database/provider.rs index 73d5b22bf9..5990f3cf62 100644 --- a/crates/storage/provider/src/providers/database/provider.rs +++ b/crates/storage/provider/src/providers/database/provider.rs @@ -18,7 +18,7 @@ use crate::{ PruneCheckpointReader, PruneCheckpointWriter, RawRocksDBBatch, RevertsInit, RocksBatchArg, RocksDBProviderFactory, StageCheckpointReader, StateProviderBox, StateWriter, StaticFileProviderFactory, StatsReader, StorageReader, StorageTrieWriter, TransactionVariant, - TransactionsProvider, TransactionsProviderExt, TrieWriter, STORAGE_POOL, + TransactionsProvider, TransactionsProviderExt, TrieWriter, }; use alloy_consensus::{ transaction::{SignerRecoverable, TransactionMeta, TxHashRef}, @@ -199,6 +199,8 @@ pub struct DatabaseProvider { rocksdb_provider: RocksDBProvider, /// Changeset cache for trie unwinding changeset_cache: ChangesetCache, + /// Task runtime for spawning parallel I/O work. + runtime: reth_tasks::Runtime, /// Pending `RocksDB` batches to be committed at provider commit time. #[cfg_attr(not(all(unix, feature = "rocksdb")), allow(dead_code))] pending_rocksdb_batches: PendingRocksDBBatches, @@ -221,6 +223,7 @@ impl Debug for DatabaseProvider { .field("storage_settings", &self.storage_settings) .field("rocksdb_provider", &self.rocksdb_provider) .field("changeset_cache", &self.changeset_cache) + .field("runtime", &self.runtime) .field("pending_rocksdb_batches", &"") .field("commit_order", &self.commit_order) .field("minimum_pruning_distance", &self.minimum_pruning_distance) @@ -354,6 +357,7 @@ impl DatabaseProvider { storage_settings: Arc>, rocksdb_provider: RocksDBProvider, changeset_cache: ChangesetCache, + runtime: reth_tasks::Runtime, commit_order: CommitOrder, ) -> Self { Self { @@ -365,6 +369,7 @@ impl DatabaseProvider { storage_settings, rocksdb_provider, changeset_cache, + runtime, pending_rocksdb_batches: Default::default(), commit_order, minimum_pruning_distance: MINIMUM_UNWIND_SAFE_DISTANCE, @@ -383,6 +388,7 @@ impl DatabaseProvider { storage_settings: Arc>, rocksdb_provider: RocksDBProvider, changeset_cache: ChangesetCache, + runtime: reth_tasks::Runtime, ) -> Self { Self::new_rw_inner( tx, @@ -393,6 +399,7 @@ impl DatabaseProvider { storage_settings, rocksdb_provider, changeset_cache, + runtime, CommitOrder::Normal, ) } @@ -408,6 +415,7 @@ impl DatabaseProvider { storage_settings: Arc>, rocksdb_provider: RocksDBProvider, changeset_cache: ChangesetCache, + runtime: reth_tasks::Runtime, ) -> Self { Self::new_rw_inner( tx, @@ -418,6 +426,7 @@ impl DatabaseProvider { storage_settings, rocksdb_provider, changeset_cache, + runtime, CommitOrder::Unwind, ) } @@ -556,13 +565,14 @@ impl DatabaseProvider DatabaseProvider DatabaseProvider { storage_settings: Arc>, rocksdb_provider: RocksDBProvider, changeset_cache: ChangesetCache, + runtime: reth_tasks::Runtime, ) -> Self { Self { tx, @@ -971,6 +982,7 @@ impl DatabaseProvider { storage_settings, rocksdb_provider, changeset_cache, + runtime, pending_rocksdb_batches: Default::default(), commit_order: CommitOrder::Normal, minimum_pruning_distance: MINIMUM_UNWIND_SAFE_DISTANCE, diff --git a/crates/storage/provider/src/providers/rocksdb/provider.rs b/crates/storage/provider/src/providers/rocksdb/provider.rs index a8ee7ad74b..77b37e410c 100644 --- a/crates/storage/provider/src/providers/rocksdb/provider.rs +++ b/crates/storage/provider/src/providers/rocksdb/provider.rs @@ -1,8 +1,5 @@ use super::metrics::{RocksDBMetrics, RocksDBOperation, ROCKSDB_TABLES}; -use crate::{ - providers::{compute_history_rank, needs_prev_shard_check, HistoryInfo}, - STORAGE_POOL, -}; +use crate::providers::{compute_history_rank, needs_prev_shard_check, HistoryInfo}; use alloy_consensus::transaction::TxHashRef; use alloy_primitives::{ map::{AddressMap, HashMap}, @@ -1204,6 +1201,7 @@ impl RocksDBProvider { blocks: &[ExecutedBlock], tx_nums: &[TxNumber], ctx: RocksDBWriteCtx, + runtime: &reth_tasks::Runtime, ) -> ProviderResult<()> { if !ctx.storage_settings.any_in_rocksdb() { return Ok(()); @@ -1218,7 +1216,7 @@ impl RocksDBProvider { let write_account_history = ctx.storage_settings.account_history_in_rocksdb; let write_storage_history = ctx.storage_settings.storages_history_in_rocksdb; - STORAGE_POOL.in_place_scope(|s| { + runtime.storage_pool().in_place_scope(|s| { if write_tx_hash { s.spawn(|_| { r_tx_hash = Some(self.write_tx_hash_numbers(blocks, tx_nums, &ctx)); diff --git a/crates/storage/provider/src/providers/static_file/manager.rs b/crates/storage/provider/src/providers/static_file/manager.rs index 277474e506..8c94a2f6ce 100644 --- a/crates/storage/provider/src/providers/static_file/manager.rs +++ b/crates/storage/provider/src/providers/static_file/manager.rs @@ -6,7 +6,7 @@ use crate::{ changeset_walker::{StaticFileAccountChangesetWalker, StaticFileStorageChangesetWalker}, to_range, BlockHashReader, BlockNumReader, BlockReader, BlockSource, EitherWriter, EitherWriterDestination, HeaderProvider, ReceiptProvider, StageCheckpointReader, StatsReader, - TransactionVariant, TransactionsProvider, TransactionsProviderExt, STORAGE_POOL, + TransactionVariant, TransactionsProvider, TransactionsProviderExt, }; use alloy_consensus::{transaction::TransactionMeta, Header}; use alloy_eips::{eip2718::Encodable2718, BlockHashOrNumber}; @@ -682,6 +682,7 @@ impl StaticFileProvider { blocks: &[ExecutedBlock], tx_nums: &[TxNumber], ctx: StaticFileWriteCtx, + runtime: &reth_tasks::Runtime, ) -> ProviderResult<()> { if blocks.is_empty() { return Ok(()); @@ -696,7 +697,7 @@ impl StaticFileProvider { let mut r_account_changesets = None; let mut r_storage_changesets = None; - STORAGE_POOL.in_place_scope(|s| { + runtime.storage_pool().in_place_scope(|s| { s.spawn(|_| { r_headers = Some(self.write_segment(StaticFileSegment::Headers, first_block_number, |w| { diff --git a/crates/storage/provider/src/storage_threadpool.rs b/crates/storage/provider/src/storage_threadpool.rs deleted file mode 100644 index e98c16564b..0000000000 --- a/crates/storage/provider/src/storage_threadpool.rs +++ /dev/null @@ -1,23 +0,0 @@ -//! Dedicated thread pool for storage I/O operations. -//! -//! This module provides a static rayon thread pool used for parallel writes to static files, -//! `RocksDB`, and other storage backends during block persistence. - -use rayon::{ThreadPool, ThreadPoolBuilder}; -use std::sync::LazyLock; - -/// Number of threads in the storage I/O thread pool. -const STORAGE_POOL_THREADS: usize = 16; - -/// Static thread pool for storage I/O operations. -/// -/// This pool is used exclusively by [`save_blocks`](crate::DatabaseProvider::save_blocks) to -/// parallelize writes to different storage backends (static files, `RocksDB`). Since this is the -/// only call site, all threads are always available when needed. -pub(crate) static STORAGE_POOL: LazyLock = LazyLock::new(|| { - ThreadPoolBuilder::new() - .num_threads(STORAGE_POOL_THREADS) - .thread_name(|idx| format!("save-blocks-{idx}")) - .build() - .expect("failed to create storage thread pool") -}); diff --git a/crates/storage/provider/src/test_utils/mod.rs b/crates/storage/provider/src/test_utils/mod.rs index 3774f18412..40dc359369 100644 --- a/crates/storage/provider/src/test_utils/mod.rs +++ b/crates/storage/provider/src/test_utils/mod.rs @@ -71,6 +71,7 @@ pub fn create_test_provider_factory_with_node_types( .with_default_tables() .build() .expect("failed to create test RocksDB provider"), + reth_tasks::Runtime::test(), ) .expect("failed to create test provider factory") } diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 7ddc6a32f9..8150651c41 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -13,7 +13,7 @@ workspace = true [dependencies] # async -tokio = { workspace = true, features = ["sync", "rt"] } +tokio = { workspace = true, features = ["sync", "rt", "rt-multi-thread"] } tracing-futures.workspace = true futures-util = { workspace = true, features = ["std"] } @@ -36,3 +36,4 @@ tokio = { workspace = true, features = ["sync", "rt", "rt-multi-thread", "time", [features] rayon = ["dep:rayon", "pin-project"] +test-utils = [] diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 4294f3d40c..091d66ae1b 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -12,22 +12,16 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] #![cfg_attr(docsrs, feature(doc_cfg))] -use crate::{ - metrics::{IncCounterOnDrop, TaskExecutorMetrics}, - shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal}, -}; +use crate::shutdown::{signal, GracefulShutdown, Shutdown, Signal}; use dyn_clone::DynClone; -use futures_util::{ - future::{select, BoxFuture}, - Future, FutureExt, TryFutureExt, -}; +use futures_util::future::BoxFuture; use std::{ any::Any, fmt::{Display, Formatter}, - pin::{pin, Pin}, + pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, OnceLock, + Arc, }, task::{ready, Context, Poll}, thread, @@ -37,17 +31,21 @@ use tokio::{ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, task::JoinHandle, }; -use tracing::{debug, error}; -use tracing_futures::Instrument; +use tracing::debug; pub mod metrics; +pub mod runtime; pub mod shutdown; #[cfg(feature = "rayon")] pub mod pool; -/// Global [`TaskExecutor`] instance that can be accessed from anywhere. -static GLOBAL_EXECUTOR: OnceLock = OnceLock::new(); +#[cfg(feature = "rayon")] +pub use runtime::RayonConfig; +pub use runtime::{Runtime, RuntimeBuildError, RuntimeBuilder, RuntimeConfig, TokioConfig}; + +/// A [`TaskExecutor`] is now an alias for [`Runtime`]. +pub type TaskExecutor = Runtime; /// Spawns an OS thread with the current tokio runtime context propagated. /// @@ -96,7 +94,7 @@ where /// A type that can spawn tasks. /// -/// The main purpose of this type is to abstract over [`TaskExecutor`] so it's more convenient to +/// The main purpose of this type is to abstract over [`Runtime`] so it's more convenient to /// provide default impls for testing. /// /// @@ -109,23 +107,22 @@ where /// use reth_tasks::{TaskSpawner, TokioTaskExecutor}; /// let executor = TokioTaskExecutor::default(); /// -/// let task = executor.spawn(Box::pin(async { +/// let task = executor.spawn_task(Box::pin(async { /// // -- snip -- /// })); /// task.await.unwrap(); /// # } /// ``` /// -/// Use the [`TaskExecutor`] that spawns task directly onto the tokio runtime via the [Handle]. +/// Use the [`Runtime`] that spawns task directly onto the tokio runtime via the [Handle]. /// /// ``` -/// # use reth_tasks::TaskManager; +/// # use reth_tasks::Runtime; /// fn t() { /// use reth_tasks::TaskSpawner; /// let rt = tokio::runtime::Runtime::new().unwrap(); -/// let manager = TaskManager::new(rt.handle().clone()); -/// let executor = manager.executor(); -/// let task = TaskSpawner::spawn(&executor, Box::pin(async { +/// let runtime = Runtime::with_existing_handle(rt.handle().clone()).unwrap(); +/// let task = TaskSpawner::spawn_task(&runtime, Box::pin(async { /// // -- snip -- /// })); /// rt.block_on(task).unwrap(); @@ -137,16 +134,20 @@ where pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone { /// Spawns the task onto the runtime. /// See also [`Handle::spawn`]. - fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>; + fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>; /// This spawns a critical task onto the runtime. - fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>; + fn spawn_critical_task( + &self, + name: &'static str, + fut: BoxFuture<'static, ()>, + ) -> JoinHandle<()>; /// Spawns a blocking task onto the runtime. - fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>; + fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>; /// This spawns a critical blocking task onto the runtime. - fn spawn_critical_blocking( + fn spawn_critical_blocking_task( &self, name: &'static str, fut: BoxFuture<'static, ()>, @@ -168,19 +169,23 @@ impl TokioTaskExecutor { } impl TaskSpawner for TokioTaskExecutor { - fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { + fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { tokio::task::spawn(fut) } - fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { + fn spawn_critical_task( + &self, + _name: &'static str, + fut: BoxFuture<'static, ()>, + ) -> JoinHandle<()> { tokio::task::spawn(fut) } - fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { + fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut)) } - fn spawn_critical_blocking( + fn spawn_critical_blocking_task( &self, _name: &'static str, fut: BoxFuture<'static, ()>, @@ -189,88 +194,45 @@ impl TaskSpawner for TokioTaskExecutor { } } -/// Many reth components require to spawn tasks for long-running jobs. For example `discovery` -/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks -/// that handle the traffic to and from a peer. -/// -/// To unify how tasks are created, the [`TaskManager`] provides access to the configured Tokio -/// runtime. A [`TaskManager`] stores the [`tokio::runtime::Handle`] it is associated with. In this -/// way it is possible to configure on which runtime a task is executed. +/// Monitors critical tasks for panics and manages graceful shutdown. /// /// The main purpose of this type is to be able to monitor if a critical task panicked, for -/// diagnostic purposes, since tokio task essentially fail silently. Therefore, this type is a -/// Stream that yields the name of panicked task, See [`TaskExecutor::spawn_critical`]. In order to -/// execute Tasks use the [`TaskExecutor`] type [`TaskManager::executor`]. +/// diagnostic purposes, since tokio tasks essentially fail silently. Therefore, this type is a +/// Future that resolves with the name of the panicked task. See [`Runtime::spawn_critical_task`]. +/// +/// Automatically spawned as a background task when building a [`Runtime`]. Use +/// [`Runtime::take_task_manager_handle`] to extract the join handle if you need to poll for +/// panic errors directly. #[derive(Debug)] #[must_use = "TaskManager must be polled to monitor critical tasks"] pub struct TaskManager { - /// Handle to the tokio runtime this task manager is associated with. - /// - /// See [`Handle`] docs. - handle: Handle, - /// Sender half for sending task events to this type - task_events_tx: UnboundedSender, - /// Receiver for task events + /// Receiver for task events. task_events_rx: UnboundedReceiver, /// The [Signal] to fire when all tasks should be shutdown. /// /// This is fired when dropped. signal: Option, - /// Receiver of the shutdown signal. - on_shutdown: Shutdown, - /// How many [`GracefulShutdown`] tasks are currently active + /// How many [`GracefulShutdown`] tasks are currently active. graceful_tasks: Arc, } // === impl TaskManager === impl TaskManager { - /// Returns a __new__ [`TaskManager`] over the currently running Runtime. - /// - /// This must be polled for the duration of the program. - /// - /// To obtain the current [`TaskExecutor`] see [`TaskExecutor::current`]. - /// - /// # Panics - /// - /// This will panic if called outside the context of a Tokio runtime. - pub fn current() -> Self { - let handle = Handle::current(); - Self::new(handle) - } - - /// Create a new instance connected to the given handle's tokio runtime. - /// - /// This also sets the global [`TaskExecutor`]. - pub fn new(handle: Handle) -> Self { + /// Create a new [`TaskManager`] without an associated [`Runtime`], returning + /// the shutdown/event primitives for [`RuntimeBuilder`] to wire up. + pub(crate) fn new_parts( + _handle: Handle, + ) -> (Self, Shutdown, UnboundedSender, Arc) { let (task_events_tx, task_events_rx) = unbounded_channel(); let (signal, on_shutdown) = signal(); + let graceful_tasks = Arc::new(AtomicUsize::new(0)); let manager = Self { - handle, - task_events_tx, task_events_rx, signal: Some(signal), - on_shutdown, - graceful_tasks: Arc::new(AtomicUsize::new(0)), + graceful_tasks: Arc::clone(&graceful_tasks), }; - - let _ = GLOBAL_EXECUTOR - .set(manager.executor()) - .inspect_err(|_| error!("Global executor already set")); - - manager - } - - /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is - /// connected to. - pub fn executor(&self) -> TaskExecutor { - TaskExecutor { - handle: self.handle.clone(), - on_shutdown: self.on_shutdown.clone(), - task_events_tx: self.task_events_tx.clone(), - metrics: Default::default(), - graceful_tasks: Arc::clone(&self.graceful_tasks), - } + (manager, on_shutdown, task_events_tx, graceful_tasks) } /// Fires the shutdown signal and awaits until all tasks are shutdown. @@ -287,15 +249,14 @@ impl TaskManager { fn do_graceful_shutdown(self, timeout: Option) -> bool { drop(self.signal); - let when = timeout.map(|t| std::time::Instant::now() + t); - while self.graceful_tasks.load(Ordering::Relaxed) > 0 { - if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) { + let deadline = timeout.map(|t| std::time::Instant::now() + t); + while self.graceful_tasks.load(Ordering::SeqCst) > 0 { + if deadline.is_some_and(|d| std::time::Instant::now() > d) { debug!("graceful shutdown timed out"); - return false + return false; } - std::hint::spin_loop(); + thread::yield_now(); } - debug!("gracefully shut down"); true } @@ -303,8 +264,8 @@ impl TaskManager { /// An endless future that resolves if a critical task panicked. /// -/// See [`TaskExecutor::spawn_critical`] -impl Future for TaskManager { +/// See [`Runtime::spawn_critical_task`] +impl std::future::Future for TaskManager { type Output = Result<(), PanickedTaskError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -339,7 +300,7 @@ impl Display for PanickedTaskError { } impl PanickedTaskError { - fn new(task_name: &'static str, error: Box) -> Self { + pub(crate) fn new(task_name: &'static str, error: Box) -> Self { let error = match error.downcast::() { Ok(value) => Some(*value), Err(error) => match error.downcast::<&str>() { @@ -354,356 +315,13 @@ impl PanickedTaskError { /// Represents the events that the `TaskManager`'s main future can receive. #[derive(Debug)] -enum TaskEvent { +pub(crate) enum TaskEvent { /// Indicates that a critical task has panicked. Panic(PanickedTaskError), /// A signal requesting a graceful shutdown of the `TaskManager`. GracefulShutdown, } -/// A type that can spawn new tokio tasks -#[derive(Debug, Clone)] -pub struct TaskExecutor { - /// Handle to the tokio runtime this task manager is associated with. - /// - /// See [`Handle`] docs. - handle: Handle, - /// Receiver of the shutdown signal. - on_shutdown: Shutdown, - /// Sender half for sending task events to this type - task_events_tx: UnboundedSender, - /// Task Executor Metrics - metrics: TaskExecutorMetrics, - /// How many [`GracefulShutdown`] tasks are currently active - graceful_tasks: Arc, -} - -// === impl TaskExecutor === - -impl TaskExecutor { - /// Attempts to get the current `TaskExecutor` if one has been initialized. - /// - /// Returns an error if no [`TaskExecutor`] has been initialized via [`TaskManager`]. - pub fn try_current() -> Result { - GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default) - } - - /// Returns the current `TaskExecutor`. - /// - /// # Panics - /// - /// Panics if no global executor has been initialized. Use [`try_current`](Self::try_current) - /// for a non-panicking version. - pub fn current() -> Self { - Self::try_current().unwrap() - } - - /// Returns the [Handle] to the tokio runtime. - pub const fn handle(&self) -> &Handle { - &self.handle - } - - /// Returns the receiver of the shutdown signal. - pub const fn on_shutdown_signal(&self) -> &Shutdown { - &self.on_shutdown - } - - /// Spawns a future on the tokio runtime depending on the [`TaskKind`] - fn spawn_on_rt(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - match task_kind { - TaskKind::Default => self.handle.spawn(fut), - TaskKind::Blocking => { - let handle = self.handle.clone(); - self.handle.spawn_blocking(move || handle.block_on(fut)) - } - } - } - - /// Spawns a regular task depending on the given [`TaskKind`] - fn spawn_task_as(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let on_shutdown = self.on_shutdown.clone(); - - // Choose the appropriate finished counter based on task kind - let finished_counter = match task_kind { - TaskKind::Default => self.metrics.finished_regular_tasks_total.clone(), - TaskKind::Blocking => self.metrics.finished_regular_blocking_tasks_total.clone(), - }; - - // Wrap the original future to increment the finished tasks counter upon completion - let task = { - async move { - // Create an instance of IncCounterOnDrop with the counter to increment - let _inc_counter_on_drop = IncCounterOnDrop::new(finished_counter); - let fut = pin!(fut); - let _ = select(on_shutdown, fut).await; - } - } - .in_current_span(); - - self.spawn_on_rt(task, task_kind) - } - - /// Spawns the task onto the runtime. - /// The given future resolves as soon as the [Shutdown] signal is received. - /// - /// See also [`Handle::spawn`]. - pub fn spawn(&self, fut: F) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - self.spawn_task_as(fut, TaskKind::Default) - } - - /// Spawns a blocking task onto the runtime. - /// The given future resolves as soon as the [Shutdown] signal is received. - /// - /// See also [`Handle::spawn_blocking`]. - pub fn spawn_blocking(&self, fut: F) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - self.spawn_task_as(fut, TaskKind::Blocking) - } - - /// Spawns the task onto the runtime. - /// The given future resolves as soon as the [Shutdown] signal is received. - /// - /// See also [`Handle::spawn`]. - pub fn spawn_with_signal(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let on_shutdown = self.on_shutdown.clone(); - let fut = f(on_shutdown); - - let task = fut.in_current_span(); - - self.handle.spawn(task) - } - - /// Spawns a critical task depending on the given [`TaskKind`] - fn spawn_critical_as( - &self, - name: &'static str, - fut: F, - task_kind: TaskKind, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let panicked_tasks_tx = self.task_events_tx.clone(); - let on_shutdown = self.on_shutdown.clone(); - - // wrap the task in catch unwind - let task = std::panic::AssertUnwindSafe(fut) - .catch_unwind() - .map_err(move |error| { - let task_error = PanickedTaskError::new(name, error); - error!("{task_error}"); - let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); - }) - .in_current_span(); - - // Clone only the specific counter that we need. - let finished_critical_tasks_total_metrics = - self.metrics.finished_critical_tasks_total.clone(); - let task = async move { - // Create an instance of IncCounterOnDrop with the counter to increment - let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics); - let task = pin!(task); - let _ = select(on_shutdown, task).await; - }; - - self.spawn_on_rt(task, task_kind) - } - - /// This spawns a critical blocking task onto the runtime. - /// The given future resolves as soon as the [Shutdown] signal is received. - /// - /// If this task panics, the [`TaskManager`] is notified. - pub fn spawn_critical_blocking(&self, name: &'static str, fut: F) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - self.spawn_critical_as(name, fut, TaskKind::Blocking) - } - - /// This spawns a critical task onto the runtime. - /// The given future resolves as soon as the [Shutdown] signal is received. - /// - /// If this task panics, the [`TaskManager`] is notified. - pub fn spawn_critical(&self, name: &'static str, fut: F) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - self.spawn_critical_as(name, fut, TaskKind::Default) - } - - /// This spawns a critical task onto the runtime. - /// - /// If this task panics, the [`TaskManager`] is notified. - pub fn spawn_critical_with_shutdown_signal( - &self, - name: &'static str, - f: impl FnOnce(Shutdown) -> F, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let panicked_tasks_tx = self.task_events_tx.clone(); - let on_shutdown = self.on_shutdown.clone(); - let fut = f(on_shutdown); - - // wrap the task in catch unwind - let task = std::panic::AssertUnwindSafe(fut) - .catch_unwind() - .map_err(move |error| { - let task_error = PanickedTaskError::new(name, error); - error!("{task_error}"); - let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); - }) - .map(drop) - .in_current_span(); - - self.handle.spawn(task) - } - - /// This spawns a critical task onto the runtime. - /// - /// If this task panics, the [`TaskManager`] is notified. - /// The [`TaskManager`] will wait until the given future has completed before shutting down. - /// - /// # Example - /// - /// ```no_run - /// # async fn t(executor: reth_tasks::TaskExecutor) { - /// - /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { - /// // await the shutdown signal - /// let guard = shutdown.await; - /// // do work before exiting the program - /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; - /// // allow graceful shutdown - /// drop(guard); - /// }); - /// # } - /// ``` - pub fn spawn_critical_with_graceful_shutdown_signal( - &self, - name: &'static str, - f: impl FnOnce(GracefulShutdown) -> F, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let panicked_tasks_tx = self.task_events_tx.clone(); - let on_shutdown = GracefulShutdown::new( - self.on_shutdown.clone(), - GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)), - ); - let fut = f(on_shutdown); - - // wrap the task in catch unwind - let task = std::panic::AssertUnwindSafe(fut) - .catch_unwind() - .map_err(move |error| { - let task_error = PanickedTaskError::new(name, error); - error!("{task_error}"); - let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); - }) - .map(drop) - .in_current_span(); - - self.handle.spawn(task) - } - - /// This spawns a regular task onto the runtime. - /// - /// The [`TaskManager`] will wait until the given future has completed before shutting down. - /// - /// # Example - /// - /// ```no_run - /// # async fn t(executor: reth_tasks::TaskExecutor) { - /// - /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move { - /// // await the shutdown signal - /// let guard = shutdown.await; - /// // do work before exiting the program - /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; - /// // allow graceful shutdown - /// drop(guard); - /// }); - /// # } - /// ``` - pub fn spawn_with_graceful_shutdown_signal( - &self, - f: impl FnOnce(GracefulShutdown) -> F, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - let on_shutdown = GracefulShutdown::new( - self.on_shutdown.clone(), - GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)), - ); - let fut = f(on_shutdown); - - self.handle.spawn(fut) - } - - /// Sends a request to the `TaskManager` to initiate a graceful shutdown. - /// - /// Caution: This will terminate the entire program. - /// - /// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that - /// can be handled via the returned [`GracefulShutdown`]. - pub fn initiate_graceful_shutdown( - &self, - ) -> Result> { - self.task_events_tx - .send(TaskEvent::GracefulShutdown) - .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?; - - Ok(GracefulShutdown::new( - self.on_shutdown.clone(), - GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)), - )) - } -} - -impl TaskSpawner for TaskExecutor { - fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { - self.metrics.inc_regular_tasks(); - Self::spawn(self, fut) - } - - fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { - self.metrics.inc_critical_tasks(); - Self::spawn_critical(self, name, fut) - } - - fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { - self.metrics.inc_regular_blocking_tasks(); - Self::spawn_blocking(self, fut) - } - - fn spawn_critical_blocking( - &self, - name: &'static str, - fut: BoxFuture<'static, ()>, - ) -> JoinHandle<()> { - self.metrics.inc_critical_tasks(); - Self::spawn_critical_blocking(self, name, fut) - } -} - /// `TaskSpawner` with extended behaviour #[auto_impl::auto_impl(&, Arc)] pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone { @@ -717,7 +335,7 @@ pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone { f: impl FnOnce(GracefulShutdown) -> F, ) -> JoinHandle<()> where - F: Future + Send + 'static; + F: std::future::Future + Send + 'static; /// This spawns a regular task onto the runtime. /// @@ -727,50 +345,16 @@ pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone { f: impl FnOnce(GracefulShutdown) -> F, ) -> JoinHandle<()> where - F: Future + Send + 'static; + F: std::future::Future + Send + 'static; } -impl TaskSpawnerExt for TaskExecutor { - fn spawn_critical_with_graceful_shutdown_signal( - &self, - name: &'static str, - f: impl FnOnce(GracefulShutdown) -> F, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - Self::spawn_critical_with_graceful_shutdown_signal(self, name, f) - } - - fn spawn_with_graceful_shutdown_signal( - &self, - f: impl FnOnce(GracefulShutdown) -> F, - ) -> JoinHandle<()> - where - F: Future + Send + 'static, - { - Self::spawn_with_graceful_shutdown_signal(self, f) - } -} - -/// Determines how a task is spawned -enum TaskKind { - /// Spawn the task to the default executor [`Handle::spawn`] - Default, - /// Spawn the task to the blocking executor [`Handle::spawn_blocking`] - Blocking, -} - -/// Error returned by `try_current` when no task executor has been configured. -#[derive(Debug, Default, thiserror::Error)] -#[error("No current task executor available.")] -#[non_exhaustive] -pub struct NoCurrentTaskExecutorError; - #[cfg(test)] mod tests { use super::*; - use std::{sync::atomic::AtomicBool, time::Duration}; + use std::{ + sync::atomic::{AtomicBool, AtomicUsize, Ordering}, + time::Duration, + }; #[test] fn test_cloneable() { @@ -789,14 +373,13 @@ mod tests { #[test] fn test_critical() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); + let handle = rt.take_task_manager_handle().unwrap(); - executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") }); + rt.spawn_critical_task("this is a critical task", async { panic!("intentionally panic") }); runtime.block_on(async move { - let err_result = manager.await; + let err_result = handle.await.unwrap(); assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic"); let panicked_err = err_result.unwrap_err(); @@ -805,153 +388,127 @@ mod tests { }) } - // Tests that spawned tasks are terminated if the `TaskManager` drops #[test] fn test_manager_shutdown_critical() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle.clone()); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); let (signal, shutdown) = signal(); - executor.spawn_critical("this is a critical task", async move { + rt.spawn_critical_task("this is a critical task", async move { tokio::time::sleep(Duration::from_millis(200)).await; drop(signal); }); - drop(manager); + rt.graceful_shutdown(); - handle.block_on(shutdown); + runtime.block_on(shutdown); } - // Tests that spawned tasks are terminated if the `TaskManager` drops #[test] fn test_manager_shutdown() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle.clone()); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); let (signal, shutdown) = signal(); - executor.spawn(Box::pin(async move { + rt.spawn_task(Box::pin(async move { tokio::time::sleep(Duration::from_millis(200)).await; drop(signal); })); - drop(manager); + rt.graceful_shutdown(); - handle.block_on(shutdown); + runtime.block_on(shutdown); } #[test] fn test_manager_graceful_shutdown() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); let val = Arc::new(AtomicBool::new(false)); let c = val.clone(); - executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { + rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { let _guard = shutdown.await; tokio::time::sleep(Duration::from_millis(200)).await; c.store(true, Ordering::Relaxed); }); - manager.graceful_shutdown(); + rt.graceful_shutdown(); assert!(val.load(Ordering::Relaxed)); } #[test] fn test_manager_graceful_shutdown_many() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); let counter = Arc::new(AtomicUsize::new(0)); let num = 10; for _ in 0..num { let c = counter.clone(); - executor.spawn_critical_with_graceful_shutdown_signal( - "grace", - move |shutdown| async move { - let _guard = shutdown.await; - tokio::time::sleep(Duration::from_millis(200)).await; - c.fetch_add(1, Ordering::SeqCst); - }, - ); + rt.spawn_critical_with_graceful_shutdown_signal("grace", move |shutdown| async move { + let _guard = shutdown.await; + tokio::time::sleep(Duration::from_millis(200)).await; + c.fetch_add(1, Ordering::SeqCst); + }); } - manager.graceful_shutdown(); + rt.graceful_shutdown(); assert_eq!(counter.load(Ordering::Relaxed), num); } #[test] fn test_manager_graceful_shutdown_timeout() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let manager = TaskManager::new(handle); - let executor = manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); let timeout = Duration::from_millis(500); let val = Arc::new(AtomicBool::new(false)); let val2 = val.clone(); - executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { + rt.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { let _guard = shutdown.await; tokio::time::sleep(timeout * 3).await; val2.store(true, Ordering::Relaxed); unreachable!("should not be reached"); }); - manager.graceful_shutdown_with_timeout(timeout); + rt.graceful_shutdown_with_timeout(timeout); assert!(!val.load(Ordering::Relaxed)); } #[test] - fn can_access_global() { + fn can_build_runtime() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let handle = runtime.handle().clone(); - let _manager = TaskManager::new(handle); - let _executor = TaskExecutor::try_current().unwrap(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); + let _handle = rt.handle(); } #[test] fn test_graceful_shutdown_triggered_by_executor() { let runtime = tokio::runtime::Runtime::new().unwrap(); - let task_manager = TaskManager::new(runtime.handle().clone()); - let executor = task_manager.executor(); + let rt = Runtime::with_existing_handle(runtime.handle().clone()).unwrap(); + let task_manager_handle = rt.take_task_manager_handle().unwrap(); let task_did_shutdown_flag = Arc::new(AtomicBool::new(false)); let flag_clone = task_did_shutdown_flag.clone(); - let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move { + let spawned_task_handle = rt.spawn_with_signal(|shutdown_signal| async move { shutdown_signal.await; flag_clone.store(true, Ordering::SeqCst); }); - let manager_future_handle = runtime.spawn(task_manager); - - let send_result = executor.initiate_graceful_shutdown(); - assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future"); - - let manager_final_result = runtime.block_on(manager_future_handle); + let send_result = rt.initiate_graceful_shutdown(); + assert!(send_result.is_ok()); + let manager_final_result = runtime.block_on(task_manager_handle); assert!(manager_final_result.is_ok(), "TaskManager task should not panic"); - assert_eq!( - manager_final_result.unwrap(), - Ok(()), - "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request" - ); + assert_eq!(manager_final_result.unwrap(), Ok(())); let task_join_result = runtime.block_on(spawned_task_handle); - assert!(task_join_result.is_ok(), "Spawned task should complete without panic"); + assert!(task_join_result.is_ok()); - assert!( - task_did_shutdown_flag.load(Ordering::Relaxed), - "Task should have received the shutdown signal and set the flag" - ); + assert!(task_did_shutdown_flag.load(Ordering::Relaxed)); } } diff --git a/crates/tasks/src/runtime.rs b/crates/tasks/src/runtime.rs new file mode 100644 index 0000000000..6c43f83ba2 --- /dev/null +++ b/crates/tasks/src/runtime.rs @@ -0,0 +1,928 @@ +//! Centralized management of async and parallel execution. +//! +//! This module provides [`Runtime`], a cheaply cloneable handle that manages: +//! - Tokio runtime (either owned or attached) +//! - Task spawning with shutdown awareness and panic monitoring +//! - Dedicated rayon thread pools for different workloads (with `rayon` feature) +//! - [`BlockingTaskGuard`] for rate-limiting expensive operations (with `rayon` feature) + +#[cfg(feature = "rayon")] +use crate::pool::{BlockingTaskGuard, BlockingTaskPool}; +use crate::{ + metrics::{IncCounterOnDrop, TaskExecutorMetrics}, + shutdown::{GracefulShutdown, GracefulShutdownGuard, Shutdown}, + PanickedTaskError, TaskEvent, TaskManager, +}; +use futures_util::{ + future::{select, BoxFuture}, + Future, FutureExt, TryFutureExt, +}; +#[cfg(feature = "rayon")] +use std::thread::available_parallelism; +use std::{ + pin::pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; +use tokio::{runtime::Handle, sync::mpsc::UnboundedSender, task::JoinHandle}; +use tracing::{debug, error}; +use tracing_futures::Instrument; + +use tokio::runtime::Runtime as TokioRuntime; + +/// Default thread keep-alive duration for the tokio runtime. +pub const DEFAULT_THREAD_KEEP_ALIVE: Duration = Duration::from_secs(15); + +/// Default reserved CPU cores for OS and other processes. +pub const DEFAULT_RESERVED_CPU_CORES: usize = 2; + +/// Default number of threads for the storage I/O pool. +pub const DEFAULT_STORAGE_POOL_THREADS: usize = 16; + +/// Default maximum number of concurrent blocking tasks (for RPC tracing guard). +pub const DEFAULT_MAX_BLOCKING_TASKS: usize = 512; + +/// Configuration for the tokio runtime. +#[derive(Debug, Clone)] +pub enum TokioConfig { + /// Build and own a new tokio runtime. + Owned { + /// Number of worker threads. If `None`, uses tokio's default (number of CPU cores). + worker_threads: Option, + /// How long to keep worker threads alive when idle. + thread_keep_alive: Duration, + /// Thread name prefix. + thread_name: &'static str, + }, + /// Attach to an existing tokio runtime handle. + ExistingHandle(Handle), +} + +impl Default for TokioConfig { + fn default() -> Self { + Self::Owned { + worker_threads: None, + thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE, + thread_name: "tokio-rt", + } + } +} + +impl TokioConfig { + /// Create a config that attaches to an existing runtime handle. + pub const fn existing_handle(handle: Handle) -> Self { + Self::ExistingHandle(handle) + } + + /// Create a config for an owned runtime with the specified number of worker threads. + pub const fn with_worker_threads(worker_threads: usize) -> Self { + Self::Owned { + worker_threads: Some(worker_threads), + thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE, + thread_name: "tokio-rt", + } + } +} + +/// Configuration for the rayon thread pools. +#[derive(Debug, Clone)] +#[cfg(feature = "rayon")] +pub struct RayonConfig { + /// Number of threads for the general CPU pool. + /// If `None`, derived from available parallelism minus reserved cores. + pub cpu_threads: Option, + /// Number of CPU cores to reserve for OS and other processes. + pub reserved_cpu_cores: usize, + /// Number of threads for the RPC blocking pool (trace calls, `eth_getProof`, etc.). + /// If `None`, uses the same as `cpu_threads`. + pub rpc_threads: Option, + /// Number of threads for the trie proof computation pool. + /// If `None`, uses the same as `cpu_threads`. + pub trie_threads: Option, + /// Number of threads for the storage I/O pool (static file, `RocksDB` writes in + /// `save_blocks`). If `None`, uses [`DEFAULT_STORAGE_POOL_THREADS`]. + pub storage_threads: Option, + /// Maximum number of concurrent blocking tasks for the RPC guard semaphore. + pub max_blocking_tasks: usize, +} + +#[cfg(feature = "rayon")] +impl Default for RayonConfig { + fn default() -> Self { + Self { + cpu_threads: None, + reserved_cpu_cores: DEFAULT_RESERVED_CPU_CORES, + rpc_threads: None, + trie_threads: None, + storage_threads: None, + max_blocking_tasks: DEFAULT_MAX_BLOCKING_TASKS, + } + } +} + +#[cfg(feature = "rayon")] +impl RayonConfig { + /// Set the number of reserved CPU cores. + pub const fn with_reserved_cpu_cores(mut self, reserved_cpu_cores: usize) -> Self { + self.reserved_cpu_cores = reserved_cpu_cores; + self + } + + /// Set the maximum number of concurrent blocking tasks. + pub const fn with_max_blocking_tasks(mut self, max_blocking_tasks: usize) -> Self { + self.max_blocking_tasks = max_blocking_tasks; + self + } + + /// Set the number of threads for the RPC blocking pool. + pub const fn with_rpc_threads(mut self, rpc_threads: usize) -> Self { + self.rpc_threads = Some(rpc_threads); + self + } + + /// Set the number of threads for the trie proof pool. + pub const fn with_trie_threads(mut self, trie_threads: usize) -> Self { + self.trie_threads = Some(trie_threads); + self + } + + /// Set the number of threads for the storage I/O pool. + pub const fn with_storage_threads(mut self, storage_threads: usize) -> Self { + self.storage_threads = Some(storage_threads); + self + } + + /// Compute the default number of threads based on available parallelism. + fn default_thread_count(&self) -> usize { + self.cpu_threads.unwrap_or_else(|| { + available_parallelism() + .map_or(1, |num| num.get().saturating_sub(self.reserved_cpu_cores).max(1)) + }) + } +} + +/// Configuration for building a [`Runtime`]. +#[derive(Debug, Clone, Default)] +pub struct RuntimeConfig { + /// Tokio runtime configuration. + pub tokio: TokioConfig, + /// Rayon thread pool configuration. + #[cfg(feature = "rayon")] + pub rayon: RayonConfig, +} + +impl RuntimeConfig { + /// Create a config that attaches to an existing tokio runtime handle. + #[cfg_attr(not(feature = "rayon"), allow(clippy::missing_const_for_fn))] + pub fn with_existing_handle(handle: Handle) -> Self { + Self { + tokio: TokioConfig::ExistingHandle(handle), + #[cfg(feature = "rayon")] + rayon: RayonConfig::default(), + } + } + + /// Set the tokio configuration. + pub fn with_tokio(mut self, tokio: TokioConfig) -> Self { + self.tokio = tokio; + self + } + + /// Set the rayon configuration. + #[cfg(feature = "rayon")] + pub const fn with_rayon(mut self, rayon: RayonConfig) -> Self { + self.rayon = rayon; + self + } +} + +/// Error returned when [`RuntimeBuilder::build`] fails. +#[derive(Debug, thiserror::Error)] +pub enum RuntimeBuildError { + /// Failed to build the tokio runtime. + #[error("Failed to build tokio runtime: {0}")] + TokioBuild(#[from] std::io::Error), + /// Failed to build a rayon thread pool. + #[cfg(feature = "rayon")] + #[error("Failed to build rayon thread pool: {0}")] + RayonBuild(#[from] rayon::ThreadPoolBuildError), +} + +// ── RuntimeInner ────────────────────────────────────────────────────── + +struct RuntimeInner { + /// Owned tokio runtime, if we built one. Kept alive via the `Arc`. + _tokio_runtime: Option, + /// Handle to the tokio runtime. + handle: Handle, + /// Receiver of the shutdown signal. + on_shutdown: Shutdown, + /// Sender half for sending task events to the [`TaskManager`]. + task_events_tx: UnboundedSender, + /// Task executor metrics. + metrics: TaskExecutorMetrics, + /// How many [`GracefulShutdown`] tasks are currently active. + graceful_tasks: Arc, + /// General-purpose rayon CPU pool. + #[cfg(feature = "rayon")] + cpu_pool: rayon::ThreadPool, + /// RPC blocking pool. + #[cfg(feature = "rayon")] + rpc_pool: BlockingTaskPool, + /// Trie proof computation pool. + #[cfg(feature = "rayon")] + trie_pool: rayon::ThreadPool, + /// Storage I/O pool. + #[cfg(feature = "rayon")] + storage_pool: rayon::ThreadPool, + /// Rate limiter for expensive RPC operations. + #[cfg(feature = "rayon")] + blocking_guard: BlockingTaskGuard, + /// Handle to the spawned [`TaskManager`] background task. + /// The task monitors critical tasks for panics and fires the shutdown signal. + /// Can be taken via [`Runtime::take_task_manager_handle`] to poll for panic errors. + task_manager_handle: Mutex>>>, +} + +// ── Runtime ─────────────────────────────────────────────────────────── + +/// A cheaply cloneable handle to the runtime resources. +/// +/// Wraps an `Arc` and provides access to: +/// - The tokio [`Handle`] +/// - Task spawning with shutdown awareness and panic monitoring +/// - Rayon thread pools (with `rayon` feature) +#[derive(Clone)] +pub struct Runtime(Arc); + +impl std::fmt::Debug for Runtime { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Runtime").field("handle", &self.0.handle).finish() + } +} + +#[cfg(any(test, feature = "test-utils"))] +impl Default for Runtime { + fn default() -> Self { + let config = match Handle::try_current() { + Ok(handle) => RuntimeConfig::with_existing_handle(handle), + Err(_) => RuntimeConfig::default(), + }; + RuntimeBuilder::new(config).build().expect("failed to build default Runtime") + } +} + +// ── Constructors ────────────────────────────────────────────────────── + +impl Runtime { + /// Creates a [`Runtime`] that attaches to an existing tokio runtime handle. + pub fn with_existing_handle(handle: Handle) -> Result { + RuntimeBuilder::new(RuntimeConfig::with_existing_handle(handle)).build() + } +} + +// ── Pool accessors ──────────────────────────────────────────────────── + +impl Runtime { + /// Takes the [`TaskManager`] handle out of this runtime, if one is stored. + /// + /// The handle resolves with `Err(PanickedTaskError)` if a critical task panicked, + /// or `Ok(())` if shutdown was requested. If not taken, the background task still + /// runs and logs panics at `debug!` level. + pub fn take_task_manager_handle(&self) -> Option>> { + self.0.task_manager_handle.lock().unwrap().take() + } + + /// Returns the tokio runtime [`Handle`]. + pub fn handle(&self) -> &Handle { + &self.0.handle + } + + /// Get the general-purpose rayon CPU thread pool. + #[cfg(feature = "rayon")] + pub fn cpu_pool(&self) -> &rayon::ThreadPool { + &self.0.cpu_pool + } + + /// Get the RPC blocking task pool. + #[cfg(feature = "rayon")] + pub fn rpc_pool(&self) -> &BlockingTaskPool { + &self.0.rpc_pool + } + + /// Get the trie proof computation pool. + #[cfg(feature = "rayon")] + pub fn trie_pool(&self) -> &rayon::ThreadPool { + &self.0.trie_pool + } + + /// Get the storage I/O pool. + #[cfg(feature = "rayon")] + pub fn storage_pool(&self) -> &rayon::ThreadPool { + &self.0.storage_pool + } + + /// Get a clone of the [`BlockingTaskGuard`]. + #[cfg(feature = "rayon")] + pub fn blocking_guard(&self) -> BlockingTaskGuard { + self.0.blocking_guard.clone() + } + + /// Run a closure on the CPU pool, blocking the current thread until completion. + #[cfg(feature = "rayon")] + pub fn install_cpu(&self, f: F) -> R + where + F: FnOnce() -> R + Send, + R: Send, + { + self.cpu_pool().install(f) + } + + /// Spawn a CPU-bound task on the RPC pool and return an async handle. + #[cfg(feature = "rayon")] + pub fn spawn_rpc(&self, f: F) -> crate::pool::BlockingTaskHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.rpc_pool().spawn(f) + } + + /// Run a closure on the trie pool, blocking the current thread until completion. + #[cfg(feature = "rayon")] + pub fn install_trie(&self, f: F) -> R + where + F: FnOnce() -> R + Send, + R: Send, + { + self.trie_pool().install(f) + } +} + +// ── Test helpers ────────────────────────────────────────────────────── + +impl Runtime { + /// Creates a lightweight [`Runtime`] for tests with minimal thread pools. + /// + /// If called from within a tokio runtime (e.g. `#[tokio::test]`), attaches to the existing + /// handle to avoid shutdown panics when the test runtime is dropped. + pub fn test() -> Self { + let config = match Handle::try_current() { + Ok(handle) => Self::test_config().with_tokio(TokioConfig::existing_handle(handle)), + Err(_) => Self::test_config(), + }; + RuntimeBuilder::new(config).build().expect("failed to build test Runtime") + } + + /// Creates a lightweight [`Runtime`] for tests, attaching to the given tokio handle. + pub fn test_with_handle(handle: Handle) -> Self { + let config = Self::test_config().with_tokio(TokioConfig::existing_handle(handle)); + RuntimeBuilder::new(config).build().expect("failed to build test Runtime") + } + + const fn test_config() -> RuntimeConfig { + RuntimeConfig { + tokio: TokioConfig::Owned { + worker_threads: Some(2), + thread_keep_alive: DEFAULT_THREAD_KEEP_ALIVE, + thread_name: "tokio-test", + }, + #[cfg(feature = "rayon")] + rayon: RayonConfig { + cpu_threads: Some(2), + reserved_cpu_cores: 0, + rpc_threads: Some(2), + trie_threads: Some(2), + storage_threads: Some(2), + max_blocking_tasks: 16, + }, + } + } +} + +// ── Spawn methods ───────────────────────────────────────────────────── + +/// Determines how a task is spawned. +enum TaskKind { + /// Spawn the task to the default executor [`Handle::spawn`]. + Default, + /// Spawn the task to the blocking executor [`Handle::spawn_blocking`]. + Blocking, +} + +impl Runtime { + /// Returns the receiver of the shutdown signal. + pub fn on_shutdown_signal(&self) -> &Shutdown { + &self.0.on_shutdown + } + + /// Spawns a future on the tokio runtime depending on the [`TaskKind`]. + fn spawn_on_rt(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + match task_kind { + TaskKind::Default => self.0.handle.spawn(fut), + TaskKind::Blocking => { + let handle = self.0.handle.clone(); + self.0.handle.spawn_blocking(move || handle.block_on(fut)) + } + } + } + + /// Spawns a regular task depending on the given [`TaskKind`]. + fn spawn_task_as(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let on_shutdown = self.0.on_shutdown.clone(); + + let finished_counter = match task_kind { + TaskKind::Default => self.0.metrics.finished_regular_tasks_total.clone(), + TaskKind::Blocking => self.0.metrics.finished_regular_blocking_tasks_total.clone(), + }; + + let task = { + async move { + let _inc_counter_on_drop = IncCounterOnDrop::new(finished_counter); + let fut = pin!(fut); + let _ = select(on_shutdown, fut).await; + } + } + .in_current_span(); + + self.spawn_on_rt(task, task_kind) + } + + /// Spawns the task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// See also [`Handle::spawn`]. + pub fn spawn_task(&self, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_task_as(fut, TaskKind::Default) + } + + /// Spawns a blocking task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// See also [`Handle::spawn_blocking`]. + pub fn spawn_blocking_task(&self, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_task_as(fut, TaskKind::Blocking) + } + + /// Spawns a blocking closure directly on the tokio runtime, bypassing shutdown + /// awareness. Useful for raw CPU-bound work. + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.0.handle.spawn_blocking(func) + } + + /// Spawns the task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// See also [`Handle::spawn`]. + pub fn spawn_with_signal(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let on_shutdown = self.0.on_shutdown.clone(); + let fut = f(on_shutdown); + let task = fut.in_current_span(); + self.0.handle.spawn(task) + } + + /// Spawns a critical task depending on the given [`TaskKind`]. + fn spawn_critical_as( + &self, + name: &'static str, + fut: F, + task_kind: TaskKind, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let panicked_tasks_tx = self.0.task_events_tx.clone(); + let on_shutdown = self.0.on_shutdown.clone(); + + // wrap the task in catch unwind + let task = std::panic::AssertUnwindSafe(fut) + .catch_unwind() + .map_err(move |error| { + let task_error = PanickedTaskError::new(name, error); + error!("{task_error}"); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); + }) + .in_current_span(); + + let finished_critical_tasks_total_metrics = + self.0.metrics.finished_critical_tasks_total.clone(); + let task = async move { + let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics); + let task = pin!(task); + let _ = select(on_shutdown, task).await; + }; + + self.spawn_on_rt(task, task_kind) + } + + /// This spawns a critical task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical_task(&self, name: &'static str, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_critical_as(name, fut, TaskKind::Default) + } + + /// This spawns a critical blocking task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical_blocking_task(&self, name: &'static str, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_critical_as(name, fut, TaskKind::Blocking) + } + + /// This spawns a critical task onto the runtime. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical_with_shutdown_signal( + &self, + name: &'static str, + f: impl FnOnce(Shutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let panicked_tasks_tx = self.0.task_events_tx.clone(); + let on_shutdown = self.0.on_shutdown.clone(); + let fut = f(on_shutdown); + + // wrap the task in catch unwind + let task = std::panic::AssertUnwindSafe(fut) + .catch_unwind() + .map_err(move |error| { + let task_error = PanickedTaskError::new(name, error); + error!("{task_error}"); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); + }) + .map(drop) + .in_current_span(); + + self.0.handle.spawn(task) + } + + /// This spawns a critical task onto the runtime. + /// + /// If this task panics, the [`TaskManager`] is notified. + /// The [`TaskManager`] will wait until the given future has completed before shutting down. + /// + /// # Example + /// + /// ```no_run + /// # async fn t(executor: reth_tasks::TaskExecutor) { + /// + /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move { + /// // await the shutdown signal + /// let guard = shutdown.await; + /// // do work before exiting the program + /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; + /// // allow graceful shutdown + /// drop(guard); + /// }); + /// # } + /// ``` + pub fn spawn_critical_with_graceful_shutdown_signal( + &self, + name: &'static str, + f: impl FnOnce(GracefulShutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let panicked_tasks_tx = self.0.task_events_tx.clone(); + let on_shutdown = GracefulShutdown::new( + self.0.on_shutdown.clone(), + GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)), + ); + let fut = f(on_shutdown); + + // wrap the task in catch unwind + let task = std::panic::AssertUnwindSafe(fut) + .catch_unwind() + .map_err(move |error| { + let task_error = PanickedTaskError::new(name, error); + error!("{task_error}"); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); + }) + .map(drop) + .in_current_span(); + + self.0.handle.spawn(task) + } + + /// This spawns a regular task onto the runtime. + /// + /// The [`TaskManager`] will wait until the given future has completed before shutting down. + /// + /// # Example + /// + /// ```no_run + /// # async fn t(executor: reth_tasks::TaskExecutor) { + /// + /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move { + /// // await the shutdown signal + /// let guard = shutdown.await; + /// // do work before exiting the program + /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; + /// // allow graceful shutdown + /// drop(guard); + /// }); + /// # } + /// ``` + pub fn spawn_with_graceful_shutdown_signal( + &self, + f: impl FnOnce(GracefulShutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let on_shutdown = GracefulShutdown::new( + self.0.on_shutdown.clone(), + GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)), + ); + let fut = f(on_shutdown); + + self.0.handle.spawn(fut) + } + + /// Sends a request to the `TaskManager` to initiate a graceful shutdown. + /// + /// Caution: This will terminate the entire program. + pub fn initiate_graceful_shutdown( + &self, + ) -> Result> { + self.0 + .task_events_tx + .send(TaskEvent::GracefulShutdown) + .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?; + + Ok(GracefulShutdown::new( + self.0.on_shutdown.clone(), + GracefulShutdownGuard::new(Arc::clone(&self.0.graceful_tasks)), + )) + } + + /// Fires the shutdown signal and waits until all graceful tasks complete. + pub fn graceful_shutdown(&self) { + let _ = self.do_graceful_shutdown(None); + } + + /// Fires the shutdown signal and waits until all graceful tasks complete or the timeout + /// elapses. + /// + /// Returns `true` if all tasks completed before the timeout. + pub fn graceful_shutdown_with_timeout(&self, timeout: Duration) -> bool { + self.do_graceful_shutdown(Some(timeout)) + } + + fn do_graceful_shutdown(&self, timeout: Option) -> bool { + let _ = self.0.task_events_tx.send(TaskEvent::GracefulShutdown); + let deadline = timeout.map(|t| std::time::Instant::now() + t); + while self.0.graceful_tasks.load(Ordering::SeqCst) > 0 { + if deadline.is_some_and(|d| std::time::Instant::now() > d) { + debug!("graceful shutdown timed out"); + return false; + } + std::thread::yield_now(); + } + debug!("gracefully shut down"); + true + } +} + +// ── TaskSpawner impl ────────────────────────────────────────────────── + +impl crate::TaskSpawner for Runtime { + fn spawn_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { + self.0.metrics.inc_regular_tasks(); + Self::spawn_task(self, fut) + } + + fn spawn_critical_task( + &self, + name: &'static str, + fut: BoxFuture<'static, ()>, + ) -> JoinHandle<()> { + self.0.metrics.inc_critical_tasks(); + Self::spawn_critical_task(self, name, fut) + } + + fn spawn_blocking_task(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> { + self.0.metrics.inc_regular_blocking_tasks(); + Self::spawn_blocking_task(self, fut) + } + + fn spawn_critical_blocking_task( + &self, + name: &'static str, + fut: BoxFuture<'static, ()>, + ) -> JoinHandle<()> { + self.0.metrics.inc_critical_tasks(); + Self::spawn_critical_blocking_task(self, name, fut) + } +} + +// ── TaskSpawnerExt impl ────────────────────────────────────────────── + +impl crate::TaskSpawnerExt for Runtime { + fn spawn_critical_with_graceful_shutdown_signal( + &self, + name: &'static str, + f: impl FnOnce(GracefulShutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + Self::spawn_critical_with_graceful_shutdown_signal(self, name, f) + } + + fn spawn_with_graceful_shutdown_signal( + &self, + f: impl FnOnce(GracefulShutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + Self::spawn_with_graceful_shutdown_signal(self, f) + } +} + +// ── RuntimeBuilder ──────────────────────────────────────────────────── + +/// Builder for constructing a [`Runtime`]. +#[derive(Debug, Clone)] +pub struct RuntimeBuilder { + config: RuntimeConfig, +} + +impl RuntimeBuilder { + /// Create a new builder with the given configuration. + pub const fn new(config: RuntimeConfig) -> Self { + Self { config } + } + + /// Build the [`Runtime`]. + /// + /// The [`TaskManager`] is automatically spawned as a background task that monitors + /// critical tasks for panics. Use [`Runtime::take_task_manager_handle`] to extract + /// the join handle if you need to poll for panic errors. + pub fn build(self) -> Result { + let config = self.config; + + let (owned_runtime, handle) = match &config.tokio { + TokioConfig::Owned { worker_threads, thread_keep_alive, thread_name } => { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder + .enable_all() + .thread_keep_alive(*thread_keep_alive) + .thread_name(*thread_name); + + if let Some(threads) = worker_threads { + builder.worker_threads(*threads); + } + + let runtime = builder.build()?; + let h = runtime.handle().clone(); + (Some(runtime), h) + } + TokioConfig::ExistingHandle(h) => (None, h.clone()), + }; + + let (task_manager, on_shutdown, task_events_tx, graceful_tasks) = + TaskManager::new_parts(handle.clone()); + + #[cfg(feature = "rayon")] + let (cpu_pool, rpc_pool, trie_pool, storage_pool, blocking_guard) = { + let default_threads = config.rayon.default_thread_count(); + let rpc_threads = config.rayon.rpc_threads.unwrap_or(default_threads); + let trie_threads = config.rayon.trie_threads.unwrap_or(default_threads); + + let cpu_pool = rayon::ThreadPoolBuilder::new() + .num_threads(default_threads) + .thread_name(|i| format!("reth-cpu-{i}")) + .build()?; + + let rpc_raw = rayon::ThreadPoolBuilder::new() + .num_threads(rpc_threads) + .thread_name(|i| format!("reth-rpc-{i}")) + .build()?; + let rpc_pool = BlockingTaskPool::new(rpc_raw); + + let trie_pool = rayon::ThreadPoolBuilder::new() + .num_threads(trie_threads) + .thread_name(|i| format!("reth-trie-{i}")) + .build()?; + + let storage_threads = + config.rayon.storage_threads.unwrap_or(DEFAULT_STORAGE_POOL_THREADS); + let storage_pool = rayon::ThreadPoolBuilder::new() + .num_threads(storage_threads) + .thread_name(|i| format!("reth-storage-{i}")) + .build()?; + + let blocking_guard = BlockingTaskGuard::new(config.rayon.max_blocking_tasks); + + debug!( + default_threads, + rpc_threads, + trie_threads, + storage_threads, + max_blocking_tasks = config.rayon.max_blocking_tasks, + "Initialized rayon thread pools" + ); + + (cpu_pool, rpc_pool, trie_pool, storage_pool, blocking_guard) + }; + + let task_manager_handle = handle.spawn(async move { + let result = task_manager.await; + if let Err(ref err) = result { + debug!("{err}"); + } + result + }); + + let inner = RuntimeInner { + _tokio_runtime: owned_runtime, + handle, + on_shutdown, + task_events_tx, + metrics: Default::default(), + graceful_tasks, + #[cfg(feature = "rayon")] + cpu_pool, + #[cfg(feature = "rayon")] + rpc_pool, + #[cfg(feature = "rayon")] + trie_pool, + #[cfg(feature = "rayon")] + storage_pool, + #[cfg(feature = "rayon")] + blocking_guard, + task_manager_handle: Mutex::new(Some(task_manager_handle)), + }; + + Ok(Runtime(Arc::new(inner))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_runtime_config_default() { + let config = RuntimeConfig::default(); + assert!(matches!(config.tokio, TokioConfig::Owned { .. })); + } + + #[test] + fn test_runtime_config_existing_handle() { + let rt = TokioRuntime::new().unwrap(); + let config = RuntimeConfig::with_existing_handle(rt.handle().clone()); + assert!(matches!(config.tokio, TokioConfig::ExistingHandle(_))); + } + + #[cfg(feature = "rayon")] + #[test] + fn test_rayon_config_thread_count() { + let config = RayonConfig::default(); + let count = config.default_thread_count(); + assert!(count >= 1); + } + + #[test] + fn test_runtime_builder() { + let rt = TokioRuntime::new().unwrap(); + let config = RuntimeConfig::with_existing_handle(rt.handle().clone()); + let runtime = RuntimeBuilder::new(config).build().unwrap(); + let _ = runtime.handle(); + } +} diff --git a/crates/transaction-pool/Cargo.toml b/crates/transaction-pool/Cargo.toml index 2fe1b88a6b..d794c004eb 100644 --- a/crates/transaction-pool/Cargo.toml +++ b/crates/transaction-pool/Cargo.toml @@ -109,6 +109,7 @@ test-utils = [ "alloy-primitives/rand", "reth-evm/test-utils", "reth-evm-ethereum/test-utils", + "reth-tasks/test-utils", ] arbitrary = [ "proptest", diff --git a/crates/transaction-pool/src/lib.rs b/crates/transaction-pool/src/lib.rs index b47da0caf5..2a216955ad 100644 --- a/crates/transaction-pool/src/lib.rs +++ b/crates/transaction-pool/src/lib.rs @@ -237,7 +237,7 @@ //! use reth_storage_api::{BlockReaderIdExt, StateProviderFactory}; //! use reth_tasks::TokioTaskExecutor; //! use reth_tasks::TaskSpawner; -//! use reth_tasks::TaskManager; +//! use reth_tasks::Runtime; //! use reth_transaction_pool::{TransactionValidationTaskExecutor, Pool}; //! use reth_transaction_pool::blobstore::InMemoryBlobStore; //! use reth_transaction_pool::maintain::{maintain_transaction_pool_future}; @@ -252,16 +252,15 @@ //! { //! let blob_store = InMemoryBlobStore::default(); //! let rt = tokio::runtime::Runtime::new().unwrap(); -//! let manager = TaskManager::new(rt.handle().clone()); -//! let executor = manager.executor(); +//! let runtime = Runtime::with_existing_handle(rt.handle().clone()).unwrap(); //! let pool = Pool::eth_pool( -//! TransactionValidationTaskExecutor::eth(client.clone(), evm_config, blob_store.clone(), executor.clone()), +//! TransactionValidationTaskExecutor::eth(client.clone(), evm_config, blob_store.clone(), runtime.clone()), //! blob_store, //! Default::default(), //! ); //! //! // spawn a task that listens for new blocks and updates the pool's transactions, mined transactions etc.. -//! tokio::task::spawn(maintain_transaction_pool_future(client, pool, stream, executor.clone(), Default::default())); +//! tokio::task::spawn(maintain_transaction_pool_future(client, pool, stream, runtime.clone(), Default::default())); //! //! # } //! ``` diff --git a/crates/transaction-pool/src/maintain.rs b/crates/transaction-pool/src/maintain.rs index 2d6409014d..4a708e8755 100644 --- a/crates/transaction-pool/src/maintain.rs +++ b/crates/transaction-pool/src/maintain.rs @@ -229,7 +229,7 @@ pub async fn maintain_transaction_pool( .boxed() }; reload_accounts_fut = rx.fuse(); - task_spawner.spawn_blocking(fut); + task_spawner.spawn_blocking_task(fut); } // check if we have a new finalized block @@ -243,7 +243,7 @@ pub async fn maintain_transaction_pool( pool.delete_blobs(blobs); // and also do periodic cleanup let pool = pool.clone(); - task_spawner.spawn_blocking(Box::pin(async move { + task_spawner.spawn_blocking_task(Box::pin(async move { debug!(target: "txpool", finalized_block = %finalized, "cleaning up blob store"); pool.cleanup_blobs(); })); @@ -517,7 +517,7 @@ pub async fn maintain_transaction_pool( let pool = pool.clone(); let spawner = task_spawner.clone(); let client = client.clone(); - task_spawner.spawn(Box::pin(async move { + task_spawner.spawn_task(Box::pin(async move { // Start converting not eaerlier than 4 seconds into current slot to ensure // that our pool only contains valid transactions for the next block (as // it's not Osaka yet). @@ -565,7 +565,7 @@ pub async fn maintain_transaction_pool( let converter = BlobSidecarConverter::new(); let pool = pool.clone(); - spawner.spawn(Box::pin(async move { + spawner.spawn_task(Box::pin(async move { // Convert sidecar to EIP-7594 format let Some(sidecar) = converter.convert(sidecar).await else { return; @@ -867,7 +867,7 @@ mod tests { use reth_evm_ethereum::EthEvmConfig; use reth_fs_util as fs; use reth_provider::test_utils::{ExtendedAccount, MockEthProvider}; - use reth_tasks::TaskManager; + use reth_tasks::Runtime; #[test] fn changed_acc_entry() { @@ -906,10 +906,9 @@ mod tests { txpool.add_transaction(TransactionOrigin::Local, transaction.clone()).await.unwrap(); - let handle = tokio::runtime::Handle::current(); - let manager = TaskManager::new(handle); + let rt = Runtime::with_existing_handle(tokio::runtime::Handle::current()).unwrap(); let config = LocalTransactionBackupConfig::with_local_txs_backup(transactions_path.clone()); - manager.executor().spawn_critical_with_graceful_shutdown_signal("test task", |shutdown| { + rt.spawn_critical_with_graceful_shutdown_signal("test task", |shutdown| { backup_local_transactions_task(shutdown, txpool.clone(), config) }); @@ -918,8 +917,7 @@ mod tests { assert_eq!(*tx_to_cmp.hash(), *tx_on_finish.hash()); - // shutdown the executor - manager.graceful_shutdown(); + rt.graceful_shutdown(); let data = fs::read(transactions_path).unwrap(); diff --git a/crates/transaction-pool/src/validate/eth.rs b/crates/transaction-pool/src/validate/eth.rs index e3867630a0..fa63902ab9 100644 --- a/crates/transaction-pool/src/validate/eth.rs +++ b/crates/transaction-pool/src/validate/eth.rs @@ -1269,14 +1269,14 @@ impl EthTransactionValidatorBuilder { // Spawn validation tasks, they are blocking because they perform db lookups for _ in 0..additional_tasks { let task = task.clone(); - tasks.spawn_blocking(Box::pin(async move { + tasks.spawn_blocking_task(Box::pin(async move { task.run().await; })); } // we spawn them on critical tasks because validation, especially for EIP-4844 can be quite // heavy - tasks.spawn_critical_blocking( + tasks.spawn_critical_blocking_task( "transaction-validation-service", Box::pin(async move { task.run().await; diff --git a/crates/trie/parallel/Cargo.toml b/crates/trie/parallel/Cargo.toml index 22636091ec..236999ab9d 100644 --- a/crates/trie/parallel/Cargo.toml +++ b/crates/trie/parallel/Cargo.toml @@ -19,6 +19,7 @@ reth-provider.workspace = true reth-storage-errors.workspace = true reth-trie-common.workspace = true reth-trie-sparse = { workspace = true, features = ["std"] } +reth-tasks = { workspace = true, features = ["rayon"] } reth-trie.workspace = true # alloy @@ -33,7 +34,6 @@ thiserror.workspace = true derive_more.workspace = true rayon.workspace = true itertools.workspace = true -tokio = { workspace = true, features = ["rt-multi-thread"] } crossbeam-channel.workspace = true # `metrics` feature @@ -64,6 +64,7 @@ test-utils = [ "reth-trie-db/test-utils", "reth-trie-sparse/test-utils", "reth-trie/test-utils", + "reth-tasks/test-utils", ] [[bench]] diff --git a/crates/trie/parallel/benches/root.rs b/crates/trie/parallel/benches/root.rs index f07fce527a..b2359c9f39 100644 --- a/crates/trie/parallel/benches/root.rs +++ b/crates/trie/parallel/benches/root.rs @@ -68,7 +68,11 @@ pub fn calculate_state_root(c: &mut Criterion) { b.iter_with_setup( || { let trie_input = TrieInput::from_state(updated_state.clone()); - ParallelStateRoot::new(factory.clone(), trie_input.prefix_sets.freeze()) + ParallelStateRoot::new( + factory.clone(), + trie_input.prefix_sets.freeze(), + reth_tasks::Runtime::test(), + ) }, |calculator| calculator.incremental_root(), ); diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index d42534c271..5c34d0367e 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -257,7 +257,6 @@ mod tests { use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; use reth_trie::proof::Proof; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; - use tokio::runtime::Runtime; #[test] fn random_parallel_proof() { @@ -321,14 +320,12 @@ mod tests { let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref()); let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref()); - let rt = Runtime::new().unwrap(); - let changeset_cache = reth_trie_db::ChangesetCache::new(); let factory = reth_provider::providers::OverlayStateProviderFactory::new(factory, changeset_cache); let task_ctx = ProofTaskCtx::new(factory); - let proof_worker_handle = - ProofWorkerHandle::new(rt.handle().clone(), task_ctx, 1, 1, false); + let runtime = reth_tasks::Runtime::test(); + let proof_worker_handle = ProofWorkerHandle::new(&runtime, task_ctx, 1, 1, false); let parallel_result = ParallelProof::new(Default::default(), proof_worker_handle.clone()) .decoded_multiproof(targets.clone()) diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs index 340e0ba671..6ea3910d65 100644 --- a/crates/trie/parallel/src/proof_task.rs +++ b/crates/trie/parallel/src/proof_task.rs @@ -46,6 +46,7 @@ use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, StateProofErro use reth_primitives_traits::dashmap::{self, DashMap}; use reth_provider::{DatabaseProviderROFactory, ProviderError, ProviderResult}; use reth_storage_errors::db::DatabaseError; +use reth_tasks::Runtime; use reth_trie::{ hashed_cursor::{HashedCursorFactory, HashedCursorMetricsCache, InstrumentedHashedCursor}, node_iter::{TrieElement, TrieNodeIter}, @@ -74,7 +75,6 @@ use std::{ }, time::{Duration, Instant}, }; -use tokio::runtime::Handle; use tracing::{debug, debug_span, error, trace}; #[cfg(feature = "metrics")] @@ -132,13 +132,13 @@ impl ProofWorkerHandle { /// Workers run until the last handle is dropped. /// /// # Parameters - /// - `executor`: Tokio runtime handle for spawning blocking tasks + /// - `runtime`: The centralized runtime used to spawn blocking worker tasks /// - `task_ctx`: Shared context with database view and prefix sets /// - `storage_worker_count`: Number of storage workers to spawn /// - `account_worker_count`: Number of account workers to spawn /// - `v2_proofs_enabled`: Whether to enable V2 storage proofs pub fn new( - executor: Handle, + runtime: &Runtime, task_ctx: ProofTaskCtx, storage_worker_count: usize, account_worker_count: usize, @@ -178,7 +178,8 @@ impl ProofWorkerHandle { v2_proofs_enabled, }; - // Clone for the first spawn_blocking (storage workers) + let executor = runtime.handle().clone(); + let task_ctx_for_storage = task_ctx.clone(); let executor_for_storage = executor.clone(); let cached_storage_roots_for_storage = cached_storage_roots.clone(); @@ -2021,7 +2022,6 @@ enum AccountWorkerJob { mod tests { use super::*; use reth_provider::test_utils::create_test_provider_factory; - use tokio::{runtime::Builder, task}; fn test_ctx(factory: Factory) -> ProofTaskCtx { ProofTaskCtx::new(factory) @@ -2030,25 +2030,21 @@ mod tests { /// Ensures `ProofWorkerHandle::new` spawns workers correctly. #[test] fn spawn_proof_workers_creates_handle() { - let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap(); - runtime.block_on(async { - let handle = tokio::runtime::Handle::current(); - let provider_factory = create_test_provider_factory(); - let changeset_cache = reth_trie_db::ChangesetCache::new(); - let factory = reth_provider::providers::OverlayStateProviderFactory::new( - provider_factory, - changeset_cache, - ); - let ctx = test_ctx(factory); + let provider_factory = create_test_provider_factory(); + let changeset_cache = reth_trie_db::ChangesetCache::new(); + let factory = reth_provider::providers::OverlayStateProviderFactory::new( + provider_factory, + changeset_cache, + ); + let ctx = test_ctx(factory); - let proof_handle = ProofWorkerHandle::new(handle.clone(), ctx, 5, 3, false); + let runtime = reth_tasks::Runtime::test(); + let proof_handle = ProofWorkerHandle::new(&runtime, ctx, 5, 3, false); - // Verify handle can be cloned - let _cloned_handle = proof_handle.clone(); + // Verify handle can be cloned + let _cloned_handle = proof_handle.clone(); - // Workers shut down automatically when handle is dropped - drop(proof_handle); - task::yield_now().await; - }); + // Workers shut down automatically when handle is dropped + drop(proof_handle); } } diff --git a/crates/trie/parallel/src/root.rs b/crates/trie/parallel/src/root.rs index 1736fa3ff0..edd453cca1 100644 --- a/crates/trie/parallel/src/root.rs +++ b/crates/trie/parallel/src/root.rs @@ -7,6 +7,7 @@ use itertools::Itertools; use reth_execution_errors::{SparseTrieError, StateProofError, StorageRootError}; use reth_provider::{DatabaseProviderROFactory, ProviderError}; use reth_storage_errors::db::DatabaseError; +use reth_tasks::Runtime; use reth_trie::{ hashed_cursor::HashedCursorFactory, node_iter::{TrieElement, TrieNodeIter}, @@ -16,13 +17,8 @@ use reth_trie::{ walker::TrieWalker, HashBuilder, Nibbles, StorageRoot, TRIE_ACCOUNT_RLP_MAX_SIZE, }; -use std::{ - collections::HashMap, - sync::{mpsc, OnceLock}, - time::Duration, -}; +use std::{collections::HashMap, sync::mpsc}; use thiserror::Error; -use tokio::runtime::{Builder, Handle, Runtime}; use tracing::*; /// Parallel incremental state root calculator. @@ -41,6 +37,8 @@ pub struct ParallelStateRoot { factory: Factory, // Prefix sets indicating which portions of the trie need to be recomputed. prefix_sets: TriePrefixSets, + /// The runtime handle for spawning blocking tasks. + runtime: Runtime, /// Parallel state root metrics. #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics, @@ -48,10 +46,11 @@ pub struct ParallelStateRoot { impl ParallelStateRoot { /// Create new parallel state root calculator. - pub fn new(factory: Factory, prefix_sets: TriePrefixSets) -> Self { + pub fn new(factory: Factory, prefix_sets: TriePrefixSets, runtime: Runtime) -> Self { Self { factory, prefix_sets, + runtime, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), } @@ -97,8 +96,7 @@ where debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots"); let mut storage_roots = HashMap::with_capacity(storage_root_targets.len()); - // Get runtime handle once outside the loop - let handle = get_tokio_runtime_handle(); + let handle = self.runtime.handle().clone(); for (hashed_address, prefix_set) in storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) @@ -271,29 +269,6 @@ impl From for ParallelStateRootError { } } -/// Gets or creates a tokio runtime handle for spawning blocking tasks. -/// This ensures we always have a runtime available for I/O operations. -pub fn get_tokio_runtime_handle() -> Handle { - Handle::try_current().unwrap_or_else(|_| { - // Create a new runtime if no runtime is available - static RT: OnceLock = OnceLock::new(); - - let rt = RT.get_or_init(|| { - Builder::new_multi_thread() - .enable_all() - // Keep the threads alive for at least the block time (12 seconds) plus buffer. - // This prevents the costly process of spawning new threads on every - // new block, and instead reuses the existing threads. - .thread_keep_alive(Duration::from_secs(15)) - .thread_name("trie-tokio-rt") - .build() - .expect("Failed to create tokio runtime") - }); - - rt.handle().clone() - }) -} - #[cfg(test)] mod tests { use super::*; @@ -353,8 +328,9 @@ mod tests { provider_rw.commit().unwrap(); } + let runtime = reth_tasks::Runtime::test(); assert_eq!( - ParallelStateRoot::new(overlay_factory.clone(), Default::default()) + ParallelStateRoot::new(overlay_factory.clone(), Default::default(), runtime.clone()) .incremental_root() .unwrap(), test_utils::state_root(state.clone()) @@ -390,7 +366,7 @@ mod tests { overlay_factory.with_hashed_state_overlay(Some(Arc::new(hashed_state.into_sorted()))); assert_eq!( - ParallelStateRoot::new(overlay_factory, prefix_sets.freeze()) + ParallelStateRoot::new(overlay_factory, prefix_sets.freeze(), runtime) .incremental_root() .unwrap(), test_utils::state_root(state) diff --git a/docs/vocs/docs/pages/sdk/examples/standalone-components.mdx b/docs/vocs/docs/pages/sdk/examples/standalone-components.mdx index 9fa0f75552..f8436dc987 100644 --- a/docs/vocs/docs/pages/sdk/examples/standalone-components.mdx +++ b/docs/vocs/docs/pages/sdk/examples/standalone-components.mdx @@ -34,9 +34,9 @@ The safest way to access the database is through Reth's provider factory: use reth_ethereum::node::EthereumNode; use reth_ethereum::chainspec::MAINNET; -// Open with automatic configuration +// Open with automatic configuration (requires a Runtime for task spawning) let factory = EthereumNode::provider_factory_builder() - .open_read_only(MAINNET.clone(), "path/to/datadir")?; + .open_read_only(MAINNET.clone(), "path/to/datadir", runtime)?; // Get a provider for queries let provider = factory.provider()?; @@ -72,7 +72,7 @@ To opt out of this, this safety mechanism can be disabled: ```rust let factory = EthereumNode::provider_factory_builder() - .open_read_only(MAINNET.clone(), ReadOnlyConfig::from_datadir("datadir").disable_long_read_transaction_safety())?; + .open_read_only(MAINNET.clone(), ReadOnlyConfig::from_datadir("datadir").disable_long_read_transaction_safety(), runtime)?; ``` ### Real-time Block Access Configuration diff --git a/examples/beacon-api-sidecar-fetcher/src/main.rs b/examples/beacon-api-sidecar-fetcher/src/main.rs index 4ec1727bc4..10c44fb53e 100644 --- a/examples/beacon-api-sidecar-fetcher/src/main.rs +++ b/examples/beacon-api-sidecar-fetcher/src/main.rs @@ -42,7 +42,7 @@ fn main() { let pool = node.pool.clone(); - node.task_executor.spawn(async move { + node.task_executor.spawn_task(async move { let mut sidecar_stream = MinedSidecarStream { events: notifications, pool, diff --git a/examples/beacon-api-sse/src/main.rs b/examples/beacon-api-sse/src/main.rs index fee20e09b1..014b961fe7 100644 --- a/examples/beacon-api-sse/src/main.rs +++ b/examples/beacon-api-sse/src/main.rs @@ -33,7 +33,7 @@ fn main() { .run(|builder, args| async move { let handle = builder.node(EthereumNode::default()).launch().await?; - handle.node.task_executor.spawn(Box::pin(args.run())); + handle.node.task_executor.spawn_task(Box::pin(args.run())); handle.wait_for_node_exit().await }) diff --git a/examples/custom-dev-node/src/main.rs b/examples/custom-dev-node/src/main.rs index c5441a2b38..c55d6cc844 100644 --- a/examples/custom-dev-node/src/main.rs +++ b/examples/custom-dev-node/src/main.rs @@ -17,12 +17,12 @@ use reth_ethereum::{ }, provider::CanonStateSubscriptions, rpc::api::eth::helpers::EthTransactions, - tasks::TaskManager, + tasks::Runtime, }; #[tokio::main] async fn main() -> eyre::Result<()> { - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; // create node config let node_config = NodeConfig::test() @@ -31,7 +31,7 @@ async fn main() -> eyre::Result<()> { .with_chain(custom_chain()); let NodeHandle { node, node_exit_future: _ } = NodeBuilder::new(node_config) - .testing_node(tasks.executor()) + .testing_node(runtime) .node(EthereumNode::default()) .launch_with_debug_capabilities() .await?; diff --git a/examples/custom-engine-types/src/main.rs b/examples/custom-engine-types/src/main.rs index 9a0442ed81..ffbaa19002 100644 --- a/examples/custom-engine-types/src/main.rs +++ b/examples/custom-engine-types/src/main.rs @@ -54,7 +54,7 @@ use reth_ethereum::{ primitives::{Block, SealedBlock}, provider::{EthStorage, StateProviderFactory}, rpc::types::engine::ExecutionPayload, - tasks::TaskManager, + tasks::Runtime, EthPrimitives, TransactionSigned, }; use reth_ethereum_payload_builder::{EthereumBuilderConfig, EthereumExecutionPayloadValidator}; @@ -391,7 +391,7 @@ where async fn main() -> eyre::Result<()> { let _guard = RethTracer::new().init()?; - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; // create genesis with canyon at block 2 let spec = ChainSpec::builder() @@ -407,7 +407,7 @@ async fn main() -> eyre::Result<()> { NodeConfig::test().with_rpc(RpcServerArgs::default().with_http()).with_chain(spec); let handle = NodeBuilder::new(node_config) - .testing_node(tasks.executor()) + .testing_node(runtime) .launch_node(MyCustomNode::default()) .await .unwrap(); diff --git a/examples/custom-evm/src/main.rs b/examples/custom-evm/src/main.rs index e32f0be6bd..07a3b009d3 100644 --- a/examples/custom-evm/src/main.rs +++ b/examples/custom-evm/src/main.rs @@ -35,7 +35,7 @@ use reth_ethereum::{ node::EthereumAddOns, EthereumNode, }, - tasks::TaskManager, + tasks::Runtime, EthPrimitives, }; use reth_tracing::{RethTracer, Tracer}; @@ -121,7 +121,7 @@ pub fn prague_custom() -> &'static Precompiles { async fn main() -> eyre::Result<()> { let _guard = RethTracer::new().init()?; - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; // create a custom chain spec let spec = ChainSpec::builder() @@ -138,7 +138,7 @@ async fn main() -> eyre::Result<()> { NodeConfig::test().with_rpc(RpcServerArgs::default().with_http()).with_chain(spec); let handle = NodeBuilder::new(node_config) - .testing_node(tasks.executor()) + .testing_node(runtime) // configure the node with regular ethereum types .with_types::() // use default ethereum components but with our executor diff --git a/examples/custom-inspector/src/main.rs b/examples/custom-inspector/src/main.rs index 0249a4a258..f75781f5d7 100644 --- a/examples/custom-inspector/src/main.rs +++ b/examples/custom-inspector/src/main.rs @@ -50,7 +50,7 @@ fn main() { println!("Spawning trace task!"); // Spawn an async block to listen for transactions. - node.task_executor.spawn(Box::pin(async move { + node.task_executor.spawn_task(Box::pin(async move { // Waiting for new transactions while let Some(event) = pending_transactions.next().await { let tx = event.transaction; diff --git a/examples/custom-node-components/src/main.rs b/examples/custom-node-components/src/main.rs index ebb1246457..02620432a8 100644 --- a/examples/custom-node-components/src/main.rs +++ b/examples/custom-node-components/src/main.rs @@ -96,7 +96,7 @@ where ); // spawn the maintenance task - ctx.task_executor().spawn_critical( + ctx.task_executor().spawn_critical_task( "txpool maintenance task", reth_ethereum::pool::maintain::maintain_transaction_pool_future( client, diff --git a/examples/custom-payload-builder/src/main.rs b/examples/custom-payload-builder/src/main.rs index 3de8297090..33eb92483f 100644 --- a/examples/custom-payload-builder/src/main.rs +++ b/examples/custom-payload-builder/src/main.rs @@ -84,7 +84,7 @@ where PayloadBuilderService::new(payload_generator, ctx.provider().canonical_state_stream()); ctx.task_executor() - .spawn_critical("custom payload builder service", Box::pin(payload_service)); + .spawn_critical_task("custom payload builder service", Box::pin(payload_service)); Ok(payload_builder) } diff --git a/examples/custom-rlpx-subprotocol/src/main.rs b/examples/custom-rlpx-subprotocol/src/main.rs index 91ec308abf..d64022df93 100644 --- a/examples/custom-rlpx-subprotocol/src/main.rs +++ b/examples/custom-rlpx-subprotocol/src/main.rs @@ -61,7 +61,7 @@ fn main() -> eyre::Result<()> { let subnetwork_peer_id = *subnetwork.peer_id(); let subnetwork_peer_addr = subnetwork.local_addr(); let subnetwork_handle = subnetwork.peers_handle(); - node.task_executor.spawn(subnetwork); + node.task_executor.spawn_task(subnetwork); // connect the launched node to the subnetwork node.network.peers_handle().add_peer(subnetwork_peer_id, subnetwork_peer_addr); diff --git a/examples/db-access/src/main.rs b/examples/db-access/src/main.rs index 1042ac55be..25a8d3eabf 100644 --- a/examples/db-access/src/main.rs +++ b/examples/db-access/src/main.rs @@ -25,8 +25,12 @@ fn main() -> eyre::Result<()> { // Instantiate a provider factory for Ethereum mainnet using the provided datadir path. let spec = ChainSpecBuilder::mainnet().build(); - let factory = EthereumNode::provider_factory_builder() - .open_read_only(spec.into(), ReadOnlyConfig::from_datadir(datadir))?; + let runtime = reth_ethereum::tasks::Runtime::test(); + let factory = EthereumNode::provider_factory_builder().open_read_only( + spec.into(), + ReadOnlyConfig::from_datadir(datadir), + runtime, + )?; // This call opens a RO transaction on the database. To write to the DB you'd need to call // the `provider_rw` function and look for the `Writer` variants of the traits. diff --git a/examples/full-contract-state/src/main.rs b/examples/full-contract-state/src/main.rs index bad7707415..8536e1327e 100644 --- a/examples/full-contract-state/src/main.rs +++ b/examples/full-contract-state/src/main.rs @@ -72,8 +72,12 @@ fn main() -> eyre::Result<()> { let datadir = std::env::var("RETH_DATADIR")?; let spec = ChainSpecBuilder::mainnet().build(); - let factory = EthereumNode::provider_factory_builder() - .open_read_only(spec.into(), ReadOnlyConfig::from_datadir(datadir))?; + let runtime = reth_ethereum::tasks::Runtime::test(); + let factory = EthereumNode::provider_factory_builder().open_read_only( + spec.into(), + ReadOnlyConfig::from_datadir(datadir), + runtime, + )?; let provider = factory.provider()?; let state_provider = factory.latest()?; diff --git a/examples/precompile-cache/src/main.rs b/examples/precompile-cache/src/main.rs index fe748db463..ee387c583a 100644 --- a/examples/precompile-cache/src/main.rs +++ b/examples/precompile-cache/src/main.rs @@ -33,7 +33,7 @@ use reth_ethereum::{ node::EthereumAddOns, EthEvmConfig, EthereumNode, }, - tasks::TaskManager, + tasks::Runtime, EthPrimitives, }; use reth_tracing::{RethTracer, Tracer}; @@ -187,7 +187,7 @@ where async fn main() -> eyre::Result<()> { let _guard = RethTracer::new().init()?; - let tasks = TaskManager::current(); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; // create a custom chain spec let spec = ChainSpec::builder() @@ -203,7 +203,7 @@ async fn main() -> eyre::Result<()> { NodeConfig::test().with_rpc(RpcServerArgs::default().with_http()).with_chain(spec); let handle = NodeBuilder::new(node_config) - .testing_node(tasks.executor()) + .testing_node(runtime) // configure the node with regular ethereum types .with_types::() // use default ethereum components but with our executor diff --git a/examples/rpc-db/src/main.rs b/examples/rpc-db/src/main.rs index 1d3e159c28..5b391c380a 100644 --- a/examples/rpc-db/src/main.rs +++ b/examples/rpc-db/src/main.rs @@ -31,7 +31,7 @@ use reth_ethereum::{ builder::{RethRpcModule, RpcModuleBuilder, RpcServerConfig, TransportRpcModuleConfig}, EthApiBuilder, }, - tasks::TokioTaskExecutor, + tasks::{Runtime, TokioTaskExecutor}, }; // Configuring the network parts, ideally also wouldn't need to think about this. use myrpc_ext::{MyRpcExt, MyRpcExtApiServer}; @@ -49,11 +49,13 @@ async fn main() -> eyre::Result<()> { DatabaseArguments::new(ClientVersion::default()), )?; let spec = Arc::new(ChainSpecBuilder::mainnet().build()); + let runtime = Runtime::with_existing_handle(tokio::runtime::Handle::current())?; let factory = ProviderFactory::>::new( db.clone(), spec.clone(), StaticFileProvider::read_only(db_path.join("static_files"), true)?, RocksDBProvider::builder(db_path.join("rocksdb")).build().unwrap(), + runtime, )?; // 2. Set up the blockchain provider using only the database provider and a noop for the tree to diff --git a/examples/txpool-tracing/src/main.rs b/examples/txpool-tracing/src/main.rs index f510a3f68b..04c8c146ef 100644 --- a/examples/txpool-tracing/src/main.rs +++ b/examples/txpool-tracing/src/main.rs @@ -38,7 +38,7 @@ fn main() { println!("Spawning trace task!"); // Spawn an async block to listen for transactions. - node.task_executor.spawn(Box::pin(async move { + node.task_executor.spawn_task(Box::pin(async move { // Waiting for new transactions while let Some(event) = pending_transactions.next().await { let tx = event.transaction;