mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-02-15 17:35:18 -05:00
159 lines
4.2 KiB
Rust
159 lines
4.2 KiB
Rust
//! Helper for shutdown signals
|
|
|
|
use futures_util::{
|
|
future::{FusedFuture, Shared},
|
|
FutureExt,
|
|
};
|
|
use std::{
|
|
future::Future,
|
|
pin::Pin,
|
|
sync::{atomic::AtomicUsize, Arc},
|
|
task::{ready, Context, Poll},
|
|
};
|
|
use tokio::sync::oneshot;
|
|
|
|
/// A Future that resolves when the shutdown event has been fired.
|
|
#[derive(Debug)]
|
|
pub struct GracefulShutdown {
|
|
shutdown: Shutdown,
|
|
guard: Option<GracefulShutdownGuard>,
|
|
}
|
|
|
|
impl GracefulShutdown {
|
|
pub(crate) const fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self {
|
|
Self { shutdown, guard: Some(guard) }
|
|
}
|
|
|
|
/// Returns a new shutdown future that is ignores the returned [`GracefulShutdownGuard`].
|
|
///
|
|
/// This just maps the return value of the future to `()`, it does not drop the guard.
|
|
pub fn ignore_guard(self) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
|
|
self.map(drop)
|
|
}
|
|
}
|
|
|
|
impl Future for GracefulShutdown {
|
|
type Output = GracefulShutdownGuard;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
ready!(self.shutdown.poll_unpin(cx));
|
|
Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion"))
|
|
}
|
|
}
|
|
|
|
impl Clone for GracefulShutdown {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
shutdown: self.shutdown.clone(),
|
|
guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A guard that fires once dropped to signal the [`TaskManager`](crate::TaskManager) that the
|
|
/// [`GracefulShutdown`] has completed.
|
|
#[derive(Debug)]
|
|
#[must_use = "if unused the task will not be gracefully shutdown"]
|
|
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);
|
|
|
|
impl GracefulShutdownGuard {
|
|
pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
|
|
counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
|
Self(counter)
|
|
}
|
|
}
|
|
|
|
impl Drop for GracefulShutdownGuard {
|
|
fn drop(&mut self) {
|
|
self.0.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
|
|
}
|
|
}
|
|
|
|
/// A Future that resolves when the shutdown event has been fired.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Shutdown(Shared<oneshot::Receiver<()>>);
|
|
|
|
impl Future for Shutdown {
|
|
type Output = ();
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let pin = self.get_mut();
|
|
if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
|
|
Poll::Ready(())
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Shutdown signal that fires either manually or on drop by closing the channel
|
|
#[derive(Debug)]
|
|
pub struct Signal(oneshot::Sender<()>);
|
|
|
|
impl Signal {
|
|
/// Fire the signal manually.
|
|
pub fn fire(self) {
|
|
let _ = self.0.send(());
|
|
}
|
|
}
|
|
|
|
/// Create a channel pair that's used to propagate shutdown event
|
|
pub fn signal() -> (Signal, Shutdown) {
|
|
let (sender, receiver) = oneshot::channel();
|
|
(Signal(sender), Shutdown(receiver.shared()))
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use futures_util::future::join_all;
|
|
use std::time::Duration;
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn test_shutdown() {
|
|
let (_signal, _shutdown) = signal();
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn test_drop_signal() {
|
|
let (signal, shutdown) = signal();
|
|
|
|
tokio::task::spawn(async move {
|
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
|
drop(signal)
|
|
});
|
|
|
|
shutdown.await;
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn test_multi_shutdowns() {
|
|
let (signal, shutdown) = signal();
|
|
|
|
let mut tasks = Vec::with_capacity(100);
|
|
for _ in 0..100 {
|
|
let shutdown = shutdown.clone();
|
|
let task = tokio::task::spawn(async move {
|
|
shutdown.await;
|
|
});
|
|
tasks.push(task);
|
|
}
|
|
|
|
drop(signal);
|
|
|
|
join_all(tasks).await;
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn test_drop_signal_from_thread() {
|
|
let (signal, shutdown) = signal();
|
|
|
|
let _thread = std::thread::spawn(|| {
|
|
std::thread::sleep(Duration::from_millis(500));
|
|
drop(signal)
|
|
});
|
|
|
|
shutdown.await;
|
|
}
|
|
}
|