mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-02-08 14:05:16 -05:00
feat(tasks): enable graceful shutdown request via TaskExecutor (#16386)
Signed-off-by: 7suyash7 <suyashnyn1@gmail.com> Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
@@ -174,8 +174,10 @@ where
|
||||
{
|
||||
let fut = pin!(fut);
|
||||
tokio::select! {
|
||||
err = tasks => {
|
||||
return Err(err.into())
|
||||
task_manager_result = tasks => {
|
||||
if let Err(panicked_error) = task_manager_result {
|
||||
return Err(panicked_error.into());
|
||||
}
|
||||
},
|
||||
res = fut => res?,
|
||||
}
|
||||
|
||||
@@ -162,10 +162,10 @@ pub struct TaskManager {
|
||||
///
|
||||
/// See [`Handle`] docs.
|
||||
handle: Handle,
|
||||
/// Sender half for sending panic signals to this type
|
||||
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
|
||||
/// Listens for panicked tasks
|
||||
panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
|
||||
/// Sender half for sending task events to this type
|
||||
task_events_tx: UnboundedSender<TaskEvent>,
|
||||
/// Receiver for task events
|
||||
task_events_rx: UnboundedReceiver<TaskEvent>,
|
||||
/// The [Signal] to fire when all tasks should be shutdown.
|
||||
///
|
||||
/// This is fired when dropped.
|
||||
@@ -197,12 +197,12 @@ impl TaskManager {
|
||||
///
|
||||
/// This also sets the global [`TaskExecutor`].
|
||||
pub fn new(handle: Handle) -> Self {
|
||||
let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
|
||||
let (task_events_tx, task_events_rx) = unbounded_channel();
|
||||
let (signal, on_shutdown) = signal();
|
||||
let manager = Self {
|
||||
handle,
|
||||
panicked_tasks_tx,
|
||||
panicked_tasks_rx,
|
||||
task_events_tx,
|
||||
task_events_rx,
|
||||
signal: Some(signal),
|
||||
on_shutdown,
|
||||
graceful_tasks: Arc::new(AtomicUsize::new(0)),
|
||||
@@ -221,7 +221,7 @@ impl TaskManager {
|
||||
TaskExecutor {
|
||||
handle: self.handle.clone(),
|
||||
on_shutdown: self.on_shutdown.clone(),
|
||||
panicked_tasks_tx: self.panicked_tasks_tx.clone(),
|
||||
task_events_tx: self.task_events_tx.clone(),
|
||||
metrics: Default::default(),
|
||||
graceful_tasks: Arc::clone(&self.graceful_tasks),
|
||||
}
|
||||
@@ -259,16 +259,23 @@ impl TaskManager {
|
||||
///
|
||||
/// See [`TaskExecutor::spawn_critical`]
|
||||
impl Future for TaskManager {
|
||||
type Output = PanickedTaskError;
|
||||
type Output = Result<(), PanickedTaskError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
|
||||
Poll::Ready(err.expect("stream can not end"))
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
|
||||
Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
|
||||
Some(TaskEvent::GracefulShutdown) | None => {
|
||||
if let Some(signal) = self.get_mut().signal.take() {
|
||||
signal.fire();
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
|
||||
pub struct PanickedTaskError {
|
||||
task_name: &'static str,
|
||||
error: Option<String>,
|
||||
@@ -299,6 +306,15 @@ impl PanickedTaskError {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the events that the `TaskManager`'s main future can receive.
|
||||
#[derive(Debug)]
|
||||
enum TaskEvent {
|
||||
/// Indicates that a critical task has panicked.
|
||||
Panic(PanickedTaskError),
|
||||
/// A signal requesting a graceful shutdown of the `TaskManager`.
|
||||
GracefulShutdown,
|
||||
}
|
||||
|
||||
/// A type that can spawn new tokio tasks
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskExecutor {
|
||||
@@ -308,8 +324,8 @@ pub struct TaskExecutor {
|
||||
handle: Handle,
|
||||
/// Receiver of the shutdown signal.
|
||||
on_shutdown: Shutdown,
|
||||
/// Sender half for sending panic signals to this type
|
||||
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
|
||||
/// Sender half for sending task events to this type
|
||||
task_events_tx: UnboundedSender<TaskEvent>,
|
||||
/// Task Executor Metrics
|
||||
metrics: TaskExecutorMetrics,
|
||||
/// How many [`GracefulShutdown`] tasks are currently active
|
||||
@@ -433,7 +449,7 @@ impl TaskExecutor {
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
||||
let panicked_tasks_tx = self.task_events_tx.clone();
|
||||
let on_shutdown = self.on_shutdown.clone();
|
||||
|
||||
// wrap the task in catch unwind
|
||||
@@ -442,7 +458,7 @@ impl TaskExecutor {
|
||||
.map_err(move |error| {
|
||||
let task_error = PanickedTaskError::new(name, error);
|
||||
error!("{task_error}");
|
||||
let _ = panicked_tasks_tx.send(task_error);
|
||||
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
|
||||
})
|
||||
.in_current_span();
|
||||
|
||||
@@ -492,7 +508,7 @@ impl TaskExecutor {
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
||||
let panicked_tasks_tx = self.task_events_tx.clone();
|
||||
let on_shutdown = self.on_shutdown.clone();
|
||||
let fut = f(on_shutdown);
|
||||
|
||||
@@ -502,7 +518,7 @@ impl TaskExecutor {
|
||||
.map_err(move |error| {
|
||||
let task_error = PanickedTaskError::new(name, error);
|
||||
error!("{task_error}");
|
||||
let _ = panicked_tasks_tx.send(task_error);
|
||||
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
|
||||
})
|
||||
.map(drop)
|
||||
.in_current_span();
|
||||
@@ -538,7 +554,7 @@ impl TaskExecutor {
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
||||
let panicked_tasks_tx = self.task_events_tx.clone();
|
||||
let on_shutdown = GracefulShutdown::new(
|
||||
self.on_shutdown.clone(),
|
||||
GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
|
||||
@@ -551,7 +567,7 @@ impl TaskExecutor {
|
||||
.map_err(move |error| {
|
||||
let task_error = PanickedTaskError::new(name, error);
|
||||
error!("{task_error}");
|
||||
let _ = panicked_tasks_tx.send(task_error);
|
||||
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
|
||||
})
|
||||
.map(drop)
|
||||
.in_current_span();
|
||||
@@ -593,6 +609,25 @@ impl TaskExecutor {
|
||||
|
||||
self.handle.spawn(fut)
|
||||
}
|
||||
|
||||
/// Sends a request to the `TaskManager` to initiate a graceful shutdown.
|
||||
///
|
||||
/// Caution: This will terminate the entire program.
|
||||
///
|
||||
/// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that
|
||||
/// can be handled via the returned [`GracefulShutdown`].
|
||||
pub fn initiate_graceful_shutdown(
|
||||
&self,
|
||||
) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
|
||||
self.task_events_tx
|
||||
.send(TaskEvent::GracefulShutdown)
|
||||
.map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
|
||||
|
||||
Ok(GracefulShutdown::new(
|
||||
self.on_shutdown.clone(),
|
||||
GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskSpawner for TaskExecutor {
|
||||
@@ -711,9 +746,12 @@ mod tests {
|
||||
executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });
|
||||
|
||||
runtime.block_on(async move {
|
||||
let err = manager.await;
|
||||
assert_eq!(err.task_name, "this is a critical task");
|
||||
assert_eq!(err.error, Some("intentionally panic".to_string()));
|
||||
let err_result = manager.await;
|
||||
assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
|
||||
let panicked_err = err_result.unwrap_err();
|
||||
|
||||
assert_eq!(panicked_err.task_name, "this is a critical task");
|
||||
assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
|
||||
})
|
||||
}
|
||||
|
||||
@@ -829,4 +867,41 @@ mod tests {
|
||||
let _manager = TaskManager::new(handle);
|
||||
let _executor = TaskExecutor::try_current().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graceful_shutdown_triggered_by_executor() {
|
||||
let runtime = tokio::runtime::Runtime::new().unwrap();
|
||||
let task_manager = TaskManager::new(runtime.handle().clone());
|
||||
let executor = task_manager.executor();
|
||||
|
||||
let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
|
||||
let flag_clone = task_did_shutdown_flag.clone();
|
||||
|
||||
let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
|
||||
shutdown_signal.await;
|
||||
flag_clone.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
let manager_future_handle = runtime.spawn(task_manager);
|
||||
|
||||
let send_result = executor.initiate_graceful_shutdown();
|
||||
assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future");
|
||||
|
||||
let manager_final_result = runtime.block_on(manager_future_handle);
|
||||
|
||||
assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
|
||||
assert_eq!(
|
||||
manager_final_result.unwrap(),
|
||||
Ok(()),
|
||||
"TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
|
||||
);
|
||||
|
||||
let task_join_result = runtime.block_on(spawned_task_handle);
|
||||
assert!(task_join_result.is_ok(), "Spawned task should complete without panic");
|
||||
|
||||
assert!(
|
||||
task_did_shutdown_flag.load(Ordering::Relaxed),
|
||||
"Task should have received the shutdown signal and set the flag"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user