diff --git a/src/system/stoppable_task.rs b/src/system/stoppable_task.rs index 95da7ce94..437ea8a2b 100644 --- a/src/system/stoppable_task.rs +++ b/src/system/stoppable_task.rs @@ -16,47 +16,29 @@ * along with this program. If not, see . */ -use std::sync::Arc; - +use log::trace; use rand::{rngs::OsRng, Rng}; use smol::{ - channel, future::{self, Future}, Executor, }; +use std::sync::Arc; use super::CondVar; pub type StoppableTaskPtr = Arc; pub struct StoppableTask { - // NOTE: we could send the error code from stop() instead of having it specified in start() - // but then that would introduce lifetimes to the entire struct. - stop_send: channel::Sender<()>, - stop_recv: channel::Receiver<()>, - stop_barrier: CondVar, + /// Used to signal to the main running process that it should stop. + signal: CondVar, + /// When we call `stop()`, we wait until the process is finished. This is used to prevent + /// `stop()` from exiting until the task has closed. + barrier: CondVar, - // Used so we can keep StoppableTask in HashMap/HashSet - task_id: usize, + /// Used so we can keep StoppableTask in HashMap/HashSet + task_id: u32, } -impl std::hash::Hash for StoppableTask { - fn hash(&self, state: &mut H) - where - H: std::hash::Hasher, - { - self.task_id.hash(state); - } -} - -impl std::cmp::PartialEq for StoppableTask { - fn eq(&self, other: &Self) -> bool { - self.task_id == other.task_id - } -} - -impl std::cmp::Eq for StoppableTask {} - /// A task that can be prematurely stopped at any time. /// /// ```rust @@ -72,15 +54,15 @@ impl std::cmp::Eq for StoppableTask {} /// Then at any time we can call `task.stop()` to close the task. impl StoppableTask { pub fn new() -> Arc { - let (stop_send, stop_recv) = channel::bounded(1); - Arc::new(Self { stop_send, stop_recv, stop_barrier: CondVar::new(), task_id: OsRng.gen() }) + Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() }) } - /// Stops the task. Will return when the process has fully closed. + /// Stops the task. On completion, guarantees the process has stopped. pub async fn stop(&self) { - // Ignore any errors from this send - let _ = self.stop_send.send(()).await; - self.stop_barrier.wait().await; + trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id); + self.signal.notify(); + self.barrier.wait().await; + trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id); } /// Starts the task. @@ -100,17 +82,57 @@ impl StoppableTask { StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a, Error: std::error::Error + Send + 'a, { + // NOTE: we could send the error code from stop() instead of having it specified in start() + trace!(target: "system::StoppableTask", "Starting task {}", self.task_id); + // Allow stopping and starting task again. + // NOTE: maybe we should disallow this with a panic? + self.signal.reset(); + self.barrier.reset(); + executor .spawn(async move { + // Task which waits for a stop signal let stop_fut = async { - let _ = self.stop_recv.recv().await; + self.signal.wait().await; + trace!( + target: "system::StoppableTask", + "Stop signal received for task {}", + self.task_id + ); Err(stop_value) }; + // Wait on our main task or stop task - whichever finishes first let result = future::or(main, stop_fut).await; + + trace!( + target: "system::StoppableTask", + "Closing task {} with result: {:?}", + self.task_id, + result + ); + stop_handler(result).await; - self.stop_barrier.notify(); + // Allow `stop()` to finish + self.barrier.notify(); }) .detach(); } } + +impl std::hash::Hash for StoppableTask { + fn hash(&self, state: &mut H) + where + H: std::hash::Hasher, + { + self.task_id.hash(state); + } +} + +impl std::cmp::PartialEq for StoppableTask { + fn eq(&self, other: &Self) -> bool { + self.task_id == other.task_id + } +} + +impl std::cmp::Eq for StoppableTask {}