feat(tasks): enable graceful shutdown request via TaskExecutor (#16386)

Signed-off-by: 7suyash7 <suyashnyn1@gmail.com>
Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
Suyash Nayan
2025-05-22 15:57:26 +05:30
committed by GitHub
parent 6cf363ba88
commit 9a1e4ffd7e
2 changed files with 103 additions and 26 deletions

View File

@@ -174,8 +174,10 @@ where
{
let fut = pin!(fut);
tokio::select! {
err = tasks => {
return Err(err.into())
task_manager_result = tasks => {
if let Err(panicked_error) = task_manager_result {
return Err(panicked_error.into());
}
},
res = fut => res?,
}

View File

@@ -162,10 +162,10 @@ pub struct TaskManager {
///
/// See [`Handle`] docs.
handle: Handle,
/// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
/// Listens for panicked tasks
panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
/// Sender half for sending task events to this type
task_events_tx: UnboundedSender<TaskEvent>,
/// Receiver for task events
task_events_rx: UnboundedReceiver<TaskEvent>,
/// The [Signal] to fire when all tasks should be shutdown.
///
/// This is fired when dropped.
@@ -197,12 +197,12 @@ impl TaskManager {
///
/// This also sets the global [`TaskExecutor`].
pub fn new(handle: Handle) -> Self {
let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
let (task_events_tx, task_events_rx) = unbounded_channel();
let (signal, on_shutdown) = signal();
let manager = Self {
handle,
panicked_tasks_tx,
panicked_tasks_rx,
task_events_tx,
task_events_rx,
signal: Some(signal),
on_shutdown,
graceful_tasks: Arc::new(AtomicUsize::new(0)),
@@ -221,7 +221,7 @@ impl TaskManager {
TaskExecutor {
handle: self.handle.clone(),
on_shutdown: self.on_shutdown.clone(),
panicked_tasks_tx: self.panicked_tasks_tx.clone(),
task_events_tx: self.task_events_tx.clone(),
metrics: Default::default(),
graceful_tasks: Arc::clone(&self.graceful_tasks),
}
@@ -259,16 +259,23 @@ impl TaskManager {
///
/// See [`TaskExecutor::spawn_critical`]
impl Future for TaskManager {
type Output = PanickedTaskError;
type Output = Result<(), PanickedTaskError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
Poll::Ready(err.expect("stream can not end"))
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
Some(TaskEvent::GracefulShutdown) | None => {
if let Some(signal) = self.get_mut().signal.take() {
signal.fire();
}
Poll::Ready(Ok(()))
}
}
}
}
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub struct PanickedTaskError {
task_name: &'static str,
error: Option<String>,
@@ -299,6 +306,15 @@ impl PanickedTaskError {
}
}
/// Represents the events that the `TaskManager`'s main future can receive.
#[derive(Debug)]
enum TaskEvent {
/// Indicates that a critical task has panicked.
Panic(PanickedTaskError),
/// A signal requesting a graceful shutdown of the `TaskManager`.
GracefulShutdown,
}
/// A type that can spawn new tokio tasks
#[derive(Debug, Clone)]
pub struct TaskExecutor {
@@ -308,8 +324,8 @@ pub struct TaskExecutor {
handle: Handle,
/// Receiver of the shutdown signal.
on_shutdown: Shutdown,
/// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
/// Sender half for sending task events to this type
task_events_tx: UnboundedSender<TaskEvent>,
/// Task Executor Metrics
metrics: TaskExecutorMetrics,
/// How many [`GracefulShutdown`] tasks are currently active
@@ -433,7 +449,7 @@ impl TaskExecutor {
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let panicked_tasks_tx = self.task_events_tx.clone();
let on_shutdown = self.on_shutdown.clone();
// wrap the task in catch unwind
@@ -442,7 +458,7 @@ impl TaskExecutor {
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
})
.in_current_span();
@@ -492,7 +508,7 @@ impl TaskExecutor {
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let panicked_tasks_tx = self.task_events_tx.clone();
let on_shutdown = self.on_shutdown.clone();
let fut = f(on_shutdown);
@@ -502,7 +518,7 @@ impl TaskExecutor {
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
})
.map(drop)
.in_current_span();
@@ -538,7 +554,7 @@ impl TaskExecutor {
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let panicked_tasks_tx = self.task_events_tx.clone();
let on_shutdown = GracefulShutdown::new(
self.on_shutdown.clone(),
GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
@@ -551,7 +567,7 @@ impl TaskExecutor {
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
})
.map(drop)
.in_current_span();
@@ -593,6 +609,25 @@ impl TaskExecutor {
self.handle.spawn(fut)
}
/// Sends a request to the `TaskManager` to initiate a graceful shutdown.
///
/// Caution: This will terminate the entire program.
///
/// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that
/// can be handled via the returned [`GracefulShutdown`].
pub fn initiate_graceful_shutdown(
&self,
) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
self.task_events_tx
.send(TaskEvent::GracefulShutdown)
.map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
Ok(GracefulShutdown::new(
self.on_shutdown.clone(),
GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
))
}
}
impl TaskSpawner for TaskExecutor {
@@ -711,9 +746,12 @@ mod tests {
executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });
runtime.block_on(async move {
let err = manager.await;
assert_eq!(err.task_name, "this is a critical task");
assert_eq!(err.error, Some("intentionally panic".to_string()));
let err_result = manager.await;
assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
let panicked_err = err_result.unwrap_err();
assert_eq!(panicked_err.task_name, "this is a critical task");
assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
})
}
@@ -829,4 +867,41 @@ mod tests {
let _manager = TaskManager::new(handle);
let _executor = TaskExecutor::try_current().unwrap();
}
#[test]
fn test_graceful_shutdown_triggered_by_executor() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let task_manager = TaskManager::new(runtime.handle().clone());
let executor = task_manager.executor();
let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
let flag_clone = task_did_shutdown_flag.clone();
let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
shutdown_signal.await;
flag_clone.store(true, Ordering::SeqCst);
});
let manager_future_handle = runtime.spawn(task_manager);
let send_result = executor.initiate_graceful_shutdown();
assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future");
let manager_final_result = runtime.block_on(manager_future_handle);
assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
assert_eq!(
manager_final_result.unwrap(),
Ok(()),
"TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
);
let task_join_result = runtime.block_on(spawned_task_handle);
assert!(task_join_result.is_ok(), "Spawned task should complete without panic");
assert!(
task_did_shutdown_flag.load(Ordering::Relaxed),
"Task should have received the shutdown signal and set the flag"
);
}
}