diff --git a/crates/cli/commands/src/db/mod.rs b/crates/cli/commands/src/db/mod.rs index dee7c71c1e..d27afab79c 100644 --- a/crates/cli/commands/src/db/mod.rs +++ b/crates/cli/commands/src/db/mod.rs @@ -2,6 +2,7 @@ use crate::common::{AccessRights, CliNodeTypes, Environment, EnvironmentArgs}; use clap::{Parser, Subcommand}; use reth_chainspec::{EthChainSpec, EthereumHardforks}; use reth_cli::chainspec::ChainSpecParser; +use reth_cli_runner::CliContext; use reth_db::version::{get_db_version, DatabaseVersionError, DB_VERSION}; use reth_db_common::DbTool; use std::{ @@ -79,7 +80,10 @@ macro_rules! db_exec { impl> Command { /// Execute `db` command - pub async fn execute>(self) -> eyre::Result<()> { + pub async fn execute>( + self, + ctx: CliContext, + ) -> eyre::Result<()> { let data_dir = self.env.datadir.clone().resolve_datadir(self.env.chain.chain()); let db_path = data_dir.db(); let static_files_path = data_dir.static_files(); @@ -158,7 +162,7 @@ impl> Command let access_rights = if command.dry_run { AccessRights::RO } else { AccessRights::RW }; db_exec!(self.env, tool, N, access_rights, { - command.execute(&tool)?; + command.execute(&tool, ctx.task_executor.clone())?; }); } Subcommands::StaticFileHeader(command) => { diff --git a/crates/cli/commands/src/db/repair_trie.rs b/crates/cli/commands/src/db/repair_trie.rs index 351a83716d..3ccda64afb 100644 --- a/crates/cli/commands/src/db/repair_trie.rs +++ b/crates/cli/commands/src/db/repair_trie.rs @@ -18,6 +18,7 @@ use reth_node_metrics::{ }; use reth_provider::{providers::ProviderNodeTypes, ChainSpecProvider, StageCheckpointReader}; use reth_stages::StageId; +use reth_tasks::TaskExecutor; use reth_trie::{ verify::{Output, Verifier}, Nibbles, @@ -48,52 +49,37 @@ pub struct Command { impl Command { /// Execute `db repair-trie` command - pub fn execute(self, tool: &DbTool) -> eyre::Result<()> { + pub fn execute( + self, + tool: &DbTool, + task_executor: TaskExecutor, + ) -> eyre::Result<()> { // Set up metrics server if requested let _metrics_handle = if let Some(listen_addr) = self.metrics { - // Spawn an OS thread with a single-threaded tokio runtime for the metrics server let chain_name = tool.provider_factory.chain_spec().chain().to_string(); + let executor = task_executor.clone(); - let handle = std::thread::Builder::new().name("metrics-server".to_string()).spawn( - move || { - // Create a single-threaded tokio runtime - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to create tokio runtime for metrics server"); + let handle = task_executor.spawn_critical("metrics server", async move { + let config = MetricServerConfig::new( + listen_addr, + VersionInfo { + version: version_metadata().cargo_pkg_version.as_ref(), + build_timestamp: version_metadata().vergen_build_timestamp.as_ref(), + cargo_features: version_metadata().vergen_cargo_features.as_ref(), + git_sha: version_metadata().vergen_git_sha.as_ref(), + target_triple: version_metadata().vergen_cargo_target_triple.as_ref(), + build_profile: version_metadata().build_profile_name.as_ref(), + }, + ChainSpecInfo { name: chain_name }, + executor, + Hooks::builder().build(), + ); - let handle = runtime.handle().clone(); - runtime.block_on(async move { - let task_manager = reth_tasks::TaskManager::new(handle.clone()); - let task_executor = task_manager.executor(); - - let config = MetricServerConfig::new( - listen_addr, - VersionInfo { - version: version_metadata().cargo_pkg_version.as_ref(), - build_timestamp: version_metadata().vergen_build_timestamp.as_ref(), - cargo_features: version_metadata().vergen_cargo_features.as_ref(), - git_sha: version_metadata().vergen_git_sha.as_ref(), - target_triple: version_metadata() - .vergen_cargo_target_triple - .as_ref(), - build_profile: version_metadata().build_profile_name.as_ref(), - }, - ChainSpecInfo { name: chain_name }, - task_executor, - Hooks::builder().build(), - ); - - // Spawn the metrics server - if let Err(e) = MetricServer::new(config).serve().await { - tracing::error!("Metrics server error: {}", e); - } - - // Block forever to keep the runtime alive - std::future::pending::<()>().await - }); - }, - )?; + // Spawn the metrics server + if let Err(e) = MetricServer::new(config).serve().await { + tracing::error!("Metrics server error: {}", e); + } + }); Some(handle) } else { diff --git a/crates/cli/runner/src/lib.rs b/crates/cli/runner/src/lib.rs index 79dc6b2114..63a9de7722 100644 --- a/crates/cli/runner/src/lib.rs +++ b/crates/cli/runner/src/lib.rs @@ -97,6 +97,57 @@ impl CliRunner { command_res } + /// Executes a command in a blocking context with access to `CliContext`. + /// + /// See [`Runtime::spawn_blocking`](tokio::runtime::Runtime::spawn_blocking). + pub fn run_blocking_command_until_exit( + self, + command: impl FnOnce(CliContext) -> F + Send + 'static, + ) -> Result<(), E> + where + F: Future> + Send + 'static, + E: Send + Sync + From + From + 'static, + { + let AsyncCliRunner { context, mut task_manager, tokio_runtime } = + AsyncCliRunner::new(self.tokio_runtime); + + // Spawn the command on the blocking thread pool + let handle = tokio_runtime.handle().clone(); + let command_handle = + tokio_runtime.handle().spawn_blocking(move || handle.block_on(command(context))); + + // Wait for the command to complete or ctrl-c + let command_res = tokio_runtime.block_on(run_to_completion_or_panic( + &mut task_manager, + run_until_ctrl_c( + async move { command_handle.await.expect("Failed to join blocking task") }, + ), + )); + + if command_res.is_err() { + error!(target: "reth::cli", "shutting down due to error"); + } else { + debug!(target: "reth::cli", "shutting down gracefully"); + task_manager.graceful_shutdown_with_timeout(Duration::from_secs(5)); + } + + // Shutdown the runtime on a separate thread + let (tx, rx) = mpsc::channel(); + std::thread::Builder::new() + .name("tokio-runtime-shutdown".to_string()) + .spawn(move || { + drop(tokio_runtime); + let _ = tx.send(()); + }) + .unwrap(); + + let _ = rx.recv_timeout(Duration::from_secs(5)).inspect_err(|err| { + debug!(target: "reth::cli", %err, "tokio runtime shutdown timed out"); + }); + + command_res + } + /// Executes a regular future until completion or until external signal received. pub fn run_until_ctrl_c(self, fut: F) -> Result<(), E> where diff --git a/crates/ethereum/cli/src/app.rs b/crates/ethereum/cli/src/app.rs index 00f50a479d..b9561c9c44 100644 --- a/crates/ethereum/cli/src/app.rs +++ b/crates/ethereum/cli/src/app.rs @@ -154,7 +154,9 @@ where Commands::ImportEra(command) => runner.run_blocking_until_ctrl_c(command.execute::()), Commands::ExportEra(command) => runner.run_blocking_until_ctrl_c(command.execute::()), Commands::DumpGenesis(command) => runner.run_blocking_until_ctrl_c(command.execute()), - Commands::Db(command) => runner.run_blocking_until_ctrl_c(command.execute::()), + Commands::Db(command) => { + runner.run_blocking_command_until_exit(|ctx| command.execute::(ctx)) + } Commands::Download(command) => runner.run_blocking_until_ctrl_c(command.execute::()), Commands::Stage(command) => { runner.run_command_until_exit(|ctx| command.execute::(ctx, components)) diff --git a/crates/optimism/cli/src/app.rs b/crates/optimism/cli/src/app.rs index 34048b3d94..a23873daad 100644 --- a/crates/optimism/cli/src/app.rs +++ b/crates/optimism/cli/src/app.rs @@ -98,7 +98,9 @@ where runner.run_blocking_until_ctrl_c(command.execute::()) } Commands::DumpGenesis(command) => runner.run_blocking_until_ctrl_c(command.execute()), - Commands::Db(command) => runner.run_blocking_until_ctrl_c(command.execute::()), + Commands::Db(command) => { + runner.run_blocking_command_until_exit(|ctx| command.execute::(ctx)) + } Commands::Stage(command) => { runner.run_command_until_exit(|ctx| command.execute::(ctx, components)) }