StoppableTask: use CondVar instead of channels, add logs and make impl more robust

This commit is contained in:
x
2023-08-31 09:44:13 +02:00
parent 7ed79b1365
commit 2c94dfdfa9

View File

@@ -16,47 +16,29 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use std::sync::Arc;
use log::trace;
use rand::{rngs::OsRng, Rng};
use smol::{
channel,
future::{self, Future},
Executor,
};
use std::sync::Arc;
use super::CondVar;
pub type StoppableTaskPtr = Arc<StoppableTask>;
pub struct StoppableTask {
// NOTE: we could send the error code from stop() instead of having it specified in start()
// but then that would introduce lifetimes to the entire struct.
stop_send: channel::Sender<()>,
stop_recv: channel::Receiver<()>,
stop_barrier: CondVar,
/// Used to signal to the main running process that it should stop.
signal: CondVar,
/// When we call `stop()`, we wait until the process is finished. This is used to prevent
/// `stop()` from exiting until the task has closed.
barrier: CondVar,
// Used so we can keep StoppableTask in HashMap/HashSet
task_id: usize,
/// Used so we can keep StoppableTask in HashMap/HashSet
task_id: u32,
}
impl std::hash::Hash for StoppableTask {
fn hash<H>(&self, state: &mut H)
where
H: std::hash::Hasher,
{
self.task_id.hash(state);
}
}
impl std::cmp::PartialEq for StoppableTask {
fn eq(&self, other: &Self) -> bool {
self.task_id == other.task_id
}
}
impl std::cmp::Eq for StoppableTask {}
/// A task that can be prematurely stopped at any time.
///
/// ```rust
@@ -72,15 +54,15 @@ impl std::cmp::Eq for StoppableTask {}
/// Then at any time we can call `task.stop()` to close the task.
impl StoppableTask {
pub fn new() -> Arc<Self> {
let (stop_send, stop_recv) = channel::bounded(1);
Arc::new(Self { stop_send, stop_recv, stop_barrier: CondVar::new(), task_id: OsRng.gen() })
Arc::new(Self { signal: CondVar::new(), barrier: CondVar::new(), task_id: OsRng.gen() })
}
/// Stops the task. Will return when the process has fully closed.
/// Stops the task. On completion, guarantees the process has stopped.
pub async fn stop(&self) {
// Ignore any errors from this send
let _ = self.stop_send.send(()).await;
self.stop_barrier.wait().await;
trace!(target: "system::StoppableTask", "Stopping task {}", self.task_id);
self.signal.notify();
self.barrier.wait().await;
trace!(target: "system::StoppableTask", "Stopped task {}", self.task_id);
}
/// Starts the task.
@@ -100,17 +82,57 @@ impl StoppableTask {
StopFn: FnOnce(std::result::Result<(), Error>) -> StopFut + Send + 'a,
Error: std::error::Error + Send + 'a,
{
// NOTE: we could send the error code from stop() instead of having it specified in start()
trace!(target: "system::StoppableTask", "Starting task {}", self.task_id);
// Allow stopping and starting task again.
// NOTE: maybe we should disallow this with a panic?
self.signal.reset();
self.barrier.reset();
executor
.spawn(async move {
// Task which waits for a stop signal
let stop_fut = async {
let _ = self.stop_recv.recv().await;
self.signal.wait().await;
trace!(
target: "system::StoppableTask",
"Stop signal received for task {}",
self.task_id
);
Err(stop_value)
};
// Wait on our main task or stop task - whichever finishes first
let result = future::or(main, stop_fut).await;
trace!(
target: "system::StoppableTask",
"Closing task {} with result: {:?}",
self.task_id,
result
);
stop_handler(result).await;
self.stop_barrier.notify();
// Allow `stop()` to finish
self.barrier.notify();
})
.detach();
}
}
impl std::hash::Hash for StoppableTask {
fn hash<H>(&self, state: &mut H)
where
H: std::hash::Hasher,
{
self.task_id.hash(state);
}
}
impl std::cmp::PartialEq for StoppableTask {
fn eq(&self, other: &Self) -> bool {
self.task_id == other.task_id
}
}
impl std::cmp::Eq for StoppableTask {}