diff --git a/crates/cli/runner/src/lib.rs b/crates/cli/runner/src/lib.rs index 48caf171ee..3060391d97 100644 --- a/crates/cli/runner/src/lib.rs +++ b/crates/cli/runner/src/lib.rs @@ -174,8 +174,10 @@ where { let fut = pin!(fut); tokio::select! { - err = tasks => { - return Err(err.into()) + task_manager_result = tasks => { + if let Err(panicked_error) = task_manager_result { + return Err(panicked_error.into()); + } }, res = fut => res?, } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 3213f03824..a4776798e2 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -162,10 +162,10 @@ pub struct TaskManager { /// /// See [`Handle`] docs. handle: Handle, - /// Sender half for sending panic signals to this type - panicked_tasks_tx: UnboundedSender, - /// Listens for panicked tasks - panicked_tasks_rx: UnboundedReceiver, + /// Sender half for sending task events to this type + task_events_tx: UnboundedSender, + /// Receiver for task events + task_events_rx: UnboundedReceiver, /// The [Signal] to fire when all tasks should be shutdown. /// /// This is fired when dropped. @@ -197,12 +197,12 @@ impl TaskManager { /// /// This also sets the global [`TaskExecutor`]. pub fn new(handle: Handle) -> Self { - let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel(); + let (task_events_tx, task_events_rx) = unbounded_channel(); let (signal, on_shutdown) = signal(); let manager = Self { handle, - panicked_tasks_tx, - panicked_tasks_rx, + task_events_tx, + task_events_rx, signal: Some(signal), on_shutdown, graceful_tasks: Arc::new(AtomicUsize::new(0)), @@ -221,7 +221,7 @@ impl TaskManager { TaskExecutor { handle: self.handle.clone(), on_shutdown: self.on_shutdown.clone(), - panicked_tasks_tx: self.panicked_tasks_tx.clone(), + task_events_tx: self.task_events_tx.clone(), metrics: Default::default(), graceful_tasks: Arc::clone(&self.graceful_tasks), } @@ -259,16 +259,23 @@ impl TaskManager { /// /// See [`TaskExecutor::spawn_critical`] impl Future for TaskManager { - type Output = PanickedTaskError; + type Output = Result<(), PanickedTaskError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx)); - Poll::Ready(err.expect("stream can not end")) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) { + Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)), + Some(TaskEvent::GracefulShutdown) | None => { + if let Some(signal) = self.get_mut().signal.take() { + signal.fire(); + } + Poll::Ready(Ok(())) + } + } } } /// Error with the name of the task that panicked and an error downcasted to string, if possible. -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq, Eq)] pub struct PanickedTaskError { task_name: &'static str, error: Option, @@ -299,6 +306,15 @@ impl PanickedTaskError { } } +/// Represents the events that the `TaskManager`'s main future can receive. +#[derive(Debug)] +enum TaskEvent { + /// Indicates that a critical task has panicked. + Panic(PanickedTaskError), + /// A signal requesting a graceful shutdown of the `TaskManager`. + GracefulShutdown, +} + /// A type that can spawn new tokio tasks #[derive(Debug, Clone)] pub struct TaskExecutor { @@ -308,8 +324,8 @@ pub struct TaskExecutor { handle: Handle, /// Receiver of the shutdown signal. on_shutdown: Shutdown, - /// Sender half for sending panic signals to this type - panicked_tasks_tx: UnboundedSender, + /// Sender half for sending task events to this type + task_events_tx: UnboundedSender, /// Task Executor Metrics metrics: TaskExecutorMetrics, /// How many [`GracefulShutdown`] tasks are currently active @@ -433,7 +449,7 @@ impl TaskExecutor { where F: Future + Send + 'static, { - let panicked_tasks_tx = self.panicked_tasks_tx.clone(); + let panicked_tasks_tx = self.task_events_tx.clone(); let on_shutdown = self.on_shutdown.clone(); // wrap the task in catch unwind @@ -442,7 +458,7 @@ impl TaskExecutor { .map_err(move |error| { let task_error = PanickedTaskError::new(name, error); error!("{task_error}"); - let _ = panicked_tasks_tx.send(task_error); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); }) .in_current_span(); @@ -492,7 +508,7 @@ impl TaskExecutor { where F: Future + Send + 'static, { - let panicked_tasks_tx = self.panicked_tasks_tx.clone(); + let panicked_tasks_tx = self.task_events_tx.clone(); let on_shutdown = self.on_shutdown.clone(); let fut = f(on_shutdown); @@ -502,7 +518,7 @@ impl TaskExecutor { .map_err(move |error| { let task_error = PanickedTaskError::new(name, error); error!("{task_error}"); - let _ = panicked_tasks_tx.send(task_error); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); }) .map(drop) .in_current_span(); @@ -538,7 +554,7 @@ impl TaskExecutor { where F: Future + Send + 'static, { - let panicked_tasks_tx = self.panicked_tasks_tx.clone(); + let panicked_tasks_tx = self.task_events_tx.clone(); let on_shutdown = GracefulShutdown::new( self.on_shutdown.clone(), GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)), @@ -551,7 +567,7 @@ impl TaskExecutor { .map_err(move |error| { let task_error = PanickedTaskError::new(name, error); error!("{task_error}"); - let _ = panicked_tasks_tx.send(task_error); + let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error)); }) .map(drop) .in_current_span(); @@ -593,6 +609,25 @@ impl TaskExecutor { self.handle.spawn(fut) } + + /// Sends a request to the `TaskManager` to initiate a graceful shutdown. + /// + /// Caution: This will terminate the entire program. + /// + /// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that + /// can be handled via the returned [`GracefulShutdown`]. + pub fn initiate_graceful_shutdown( + &self, + ) -> Result> { + self.task_events_tx + .send(TaskEvent::GracefulShutdown) + .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?; + + Ok(GracefulShutdown::new( + self.on_shutdown.clone(), + GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)), + )) + } } impl TaskSpawner for TaskExecutor { @@ -711,9 +746,12 @@ mod tests { executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") }); runtime.block_on(async move { - let err = manager.await; - assert_eq!(err.task_name, "this is a critical task"); - assert_eq!(err.error, Some("intentionally panic".to_string())); + let err_result = manager.await; + assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic"); + let panicked_err = err_result.unwrap_err(); + + assert_eq!(panicked_err.task_name, "this is a critical task"); + assert_eq!(panicked_err.error, Some("intentionally panic".to_string())); }) } @@ -829,4 +867,41 @@ mod tests { let _manager = TaskManager::new(handle); let _executor = TaskExecutor::try_current().unwrap(); } + + #[test] + fn test_graceful_shutdown_triggered_by_executor() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let task_manager = TaskManager::new(runtime.handle().clone()); + let executor = task_manager.executor(); + + let task_did_shutdown_flag = Arc::new(AtomicBool::new(false)); + let flag_clone = task_did_shutdown_flag.clone(); + + let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move { + shutdown_signal.await; + flag_clone.store(true, Ordering::SeqCst); + }); + + let manager_future_handle = runtime.spawn(task_manager); + + let send_result = executor.initiate_graceful_shutdown(); + assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future"); + + let manager_final_result = runtime.block_on(manager_future_handle); + + assert!(manager_final_result.is_ok(), "TaskManager task should not panic"); + assert_eq!( + manager_final_result.unwrap(), + Ok(()), + "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request" + ); + + let task_join_result = runtime.block_on(spawned_task_handle); + assert!(task_join_result.is_ok(), "Spawned task should complete without panic"); + + assert!( + task_did_shutdown_flag.load(Ordering::Relaxed), + "Task should have received the shutdown signal and set the flag" + ); + } }