refactor: remove utils-aio (#60)

sinu.eth
2025-04-02 11:17:49 +07:00
committed by GitHub
parent 90c5ef84b8
commit c1db19e6cc
13 changed files with 0 additions and 968 deletions

View File

@@ -4,7 +4,6 @@ members = [
"spansy",
"uid-mux",
"utils",
"utils-aio",
"websocket-relay",
"futures-limit",
"futures-plex",
@@ -23,9 +22,7 @@ tlsn-utils-aio = { path = "utils-aio" }
uid-mux = { path = "uid-mux" }
rangeset = { path = "rangeset" }
async-std = "1"
async-trait = "0.1"
async-tungstenite = "0.16"
bincode = "1.3"
bytes = "1"
cfg-if = "1"
@@ -38,8 +35,6 @@ futures-sink = "0.3"
futures-util = "0.3"
pin-project-lite = "0.2"
pollster = "0.4"
prost = "0.9"
prost-build = "0.9"
rand = "0.8"
rayon = "1"
serde = "1"

View File

@@ -1,41 +0,0 @@
[package]
name = "tlsn-utils-aio"
version = "0.1.0"
edition = "2021"
[lib]
name = "utils_aio"
[features]
default = ["mux", "duplex", "tokio_compat", "wasm_compat", "codec"]
codec = ["serde", "tokio-serde", "tokio-serde/bincode"]
mux = []
duplex = []
tokio_compat = ["tokio/rt"]
wasm_compat = ["dep:wasm-bindgen-futures"]
[dependencies]
bytes.workspace = true
tokio = { workspace = true, features = ["sync", "io-util"] }
tokio-util = { workspace = true, features = ["codec", "compat"] }
tokio-serde = { version = "0.8", optional = true }
wasm-bindgen-futures = { version = "0.4", optional = true }
async-tungstenite.workspace = true
futures.workspace = true
futures-util.workspace = true
async-trait.workspace = true
thiserror.workspace = true
async-std.workspace = true
rayon.workspace = true
serde = { workspace = true, optional = true }
[dev-dependencies]
tokio = { workspace = true, features = [
"macros",
"rt",
"rt-multi-thread",
"io-util",
"time",
] }
serde = { workspace = true, features = ["derive"] }
tokio-serde = { version = "0.8", features = ["bincode"] }
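# A minimal sketch (not part of this file) of how another workspace member
# could have consumed this crate via the workspace dependency declared in the
# root Cargo.toml; the extra feature selection shown here is illustrative only,
# since features are additive on top of the defaults.
#
# [dependencies]
# tlsn-utils-aio = { workspace = true, features = ["codec"] }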

View File

@@ -1,150 +0,0 @@
use tokio::sync::broadcast::{channel, error::RecvError, Sender};
/// An adaptive barrier
///
/// Unlike a fixed barrier, the number of participants can be changed dynamically:
/// each clone of the barrier adds a participant.
/// Code is taken from https://users.rust-lang.org/t/a-poor-man-async-adaptive-barrier/68118
#[derive(Debug, Clone)]
pub struct AdaptiveBarrier {
inner: Sender<Empty>,
}
impl AdaptiveBarrier {
/// Waits in order to perform task synchronization
///
/// Completes once all other barriers cloned from this one have also
/// called `wait` (or have been dropped)
pub async fn wait(self) {
let mut receiver = self.inner.subscribe();
drop(self.inner);
match receiver.recv().await {
Ok(_) => unreachable!(),
Err(RecvError::Lagged(_)) => unreachable!(),
Err(RecvError::Closed) => (),
}
}
pub fn new() -> Self {
Self {
// even though we will never receive any data, the channel's
// capacity must be set to the required minimum of 1
inner: channel(1).0,
}
}
}
impl Default for AdaptiveBarrier {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Copy)]
enum Empty {}
#[cfg(test)]
mod tests {
use super::AdaptiveBarrier;
use std::{mem::replace, sync::Arc, time::Duration};
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
struct Waiter {
counter: Arc<Mutex<Vec<usize>>>,
barrier: AdaptiveBarrier,
}
impl Waiter {
fn new(counter: &Arc<Mutex<Vec<usize>>>) -> Self {
Self {
counter: Arc::clone(counter),
barrier: AdaptiveBarrier::new(),
}
}
// Adds a new value to the counter: one greater than the last value,
// or 1 if the counter is empty.
async fn count(self) {
let mut counter = self.counter.lock().await;
let last = counter.last().copied().unwrap_or_default();
counter.push(last + 1);
}
async fn count_wait(mut self) {
// use `replace()` because `wait()` consumes the barrier, so we can't call `self.barrier.wait().await;` here
let barrier = replace(&mut self.barrier, AdaptiveBarrier::new());
barrier.wait().await;
self.count().await;
}
}
// We expect that 0 is not the first number in the counter because we do not use
// the barrier in this test
#[tokio::test]
async fn test_adaptive_barrier_no_wait() {
let counter = Arc::new(Mutex::new(vec![]));
let waiter = Waiter::new(&counter);
let waiter_2 = waiter.clone();
let waiter_3 = waiter.clone();
let task = tokio::spawn(async move {
waiter.count().await;
});
let task_2 = tokio::spawn(async move {
waiter_2.count().await;
});
// We deliberately do not use `_ = tokio::join!(task, task_2);` here,
// in order to keep this test comparable to `test_adaptive_barrier_wait`
tokio::time::sleep(Duration::from_millis(1000)).await;
{
// Add 0 to counter. But this will not be the first number
// since `task` and `task_2` were already able to add to
// the counter.
counter.lock().await.push(0);
}
// both tasks must be finished now
assert!(task.is_finished() && task_2.is_finished());
let task_3 = tokio::spawn(async move {
waiter_3.count().await;
});
_ = tokio::join!(task, task_2, task_3);
assert_ne!(*counter.lock().await.first().unwrap(), 0);
}
// Now we use `count_wait` instead of `count` so 0 should be the first number
#[tokio::test]
async fn test_adaptive_barrier_wait() {
let counter = Arc::new(Mutex::new(vec![]));
let waiter = Waiter::new(&counter);
let waiter_2 = waiter.clone();
let waiter_3 = waiter.clone();
let task = tokio::spawn(async move {
waiter.count_wait().await;
});
let task_2 = tokio::spawn(async move {
waiter_2.count_wait().await;
});
tokio::time::sleep(Duration::from_millis(1000)).await;
{
counter.lock().await.push(0);
}
// both tasks must NOT be finished yet
assert!(!task.is_finished() && !task_2.is_finished());
// Now we wait for the last barrier, so all tasks can start counting
let task_3 = tokio::spawn(async move {
waiter_3.count_wait().await;
});
_ = tokio::join!(task, task_2, task_3);
assert_eq!(*counter.lock().await.first().unwrap(), 0);
}
}
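// A minimal usage sketch, assuming a Tokio runtime: every clone of the barrier
// counts as one participant, and no `wait` call resolves until all clones have
// reached `wait` (or been dropped). The spawned task bodies are illustrative only.
#[cfg(test)]
mod usage_sketch {
    use super::AdaptiveBarrier;

    #[tokio::test]
    async fn example_adaptive_barrier() {
        let barrier = AdaptiveBarrier::new();
        let mut handles = Vec::new();
        for _ in 0..3 {
            // Each clone dynamically adds one participant to the barrier.
            let barrier = barrier.clone();
            handles.push(tokio::spawn(async move {
                // ... per-task setup work would run here ...
                barrier.wait().await;
            }));
        }
        // The original barrier is a participant too; once it calls `wait`,
        // all waiting tasks are released together.
        barrier.wait().await;
        for handle in handles {
            handle.await.unwrap();
        }
    }
}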

View File

@@ -1,112 +0,0 @@
use async_trait::async_trait;
use futures_util::{AsyncRead, AsyncWrite};
use tokio_serde::formats::Bincode;
use tokio_util::{codec::LengthDelimitedCodec, compat::FuturesAsyncReadCompatExt};
use crate::{
duplex::Duplex,
mux::{MuxChannelSerde, MuxStream, MuxerError},
};
/// Wraps a [`MuxStream`] and provides [`Duplex`] channels with a bincode codec
#[derive(Debug, Clone)]
pub struct BincodeMux<M>(M);
impl<M> BincodeMux<M>
where
M: MuxStream,
{
/// Creates a new bincode mux
pub fn new(mux: M) -> Self {
Self(mux)
}
/// Consumes `self`, returning the inner type that implements [`MuxStream`]
pub fn into_inner(self) -> M {
self.0
}
/// Attaches a bincode codec to the provided stream
pub fn attach_codec<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static, T>(
&self,
stream: S,
) -> impl Duplex<T>
where
T: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + Unpin + 'static,
{
let framed = LengthDelimitedCodec::builder().new_framed(stream.compat());
tokio_serde::Framed::new(framed, Bincode::default())
}
}
#[async_trait]
impl<M> MuxChannelSerde for BincodeMux<M>
where
M: MuxStream + Send + 'static,
{
async fn get_channel<
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
>(
&mut self,
id: &str,
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
let stream = self.0.get_stream(id).await?;
Ok(Box::new(self.attach_codec(stream)))
}
}
#[cfg(test)]
mod tests {
use crate::mux::mock::MockMuxChannelFactory;
use super::*;
use futures::{SinkExt, StreamExt};
#[derive(serde::Serialize, serde::Deserialize)]
struct Foo {
msg: String,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Bar {
msg: String,
}
#[tokio::test]
async fn test_mux_codec() {
let mux = MockMuxChannelFactory::new();
let mut framed_mux = BincodeMux::new(mux);
let mut channel_0 = framed_mux.get_channel("foo").await.unwrap();
let mut channel_1 = framed_mux.get_channel("foo").await.unwrap();
channel_0
.send(Foo {
msg: "hello".to_string(),
})
.await
.unwrap();
let msg: Foo = channel_1.next().await.unwrap().unwrap();
assert_eq!(msg.msg, "hello");
let mut channel_0 = framed_mux.get_channel("bar").await.unwrap();
let mut channel_1 = framed_mux.get_channel("bar").await.unwrap();
channel_0
.send(Bar {
msg: "world".to_string(),
})
.await
.unwrap();
let msg: Bar = channel_1.next().await.unwrap().unwrap();
assert_eq!(msg.msg, "world");
}
}

View File

@@ -1,94 +0,0 @@
use std::{
io::{Error, ErrorKind},
pin::Pin,
};
use futures::{channel::mpsc, AsyncRead, AsyncWrite, Sink, Stream};
use crate::{sink::IoSink, stream::IoStream};
/// A duplex byte stream.
pub trait DuplexByteStream: AsyncWrite + AsyncRead + Unpin {}
impl<T> DuplexByteStream for T where T: AsyncWrite + AsyncRead + Unpin {}
/// A channel that can be used to send and receive messages.
pub trait Duplex<T>: IoStream<T> + IoSink<T> + Send + Sync + Unpin {}
impl<T, U> Duplex<T> for U where U: IoStream<T> + IoSink<T> + Send + Sync + Unpin {}
/// An in-memory duplex channel backed by a pair of bounded mpsc channels.
#[derive(Debug)]
pub struct MemoryDuplex<T> {
sink: mpsc::Sender<T>,
stream: mpsc::Receiver<T>,
}
impl<T> MemoryDuplex<T>
where
T: Send + 'static,
{
pub fn new() -> (Self, Self) {
let (sender, receiver) = mpsc::channel(10);
let (sender_2, receiver_2) = mpsc::channel(10);
(
Self {
sink: sender,
stream: receiver_2,
},
Self {
sink: sender_2,
stream: receiver,
},
)
}
}
impl<T> Sink<T> for MemoryDuplex<T>
where
T: Send + 'static,
{
type Error = std::io::Error;
fn poll_ready(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_ready(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
}
fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
Pin::new(&mut self.sink)
.start_send(item)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_flush(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_close(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
}
}
impl<T> Stream for MemoryDuplex<T> {
type Item = Result<T, std::io::Error>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
Pin::new(&mut self.stream).poll_next(cx).map(|x| x.map(Ok))
}
}
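// A minimal usage sketch, assuming a Tokio runtime: the two halves returned by
// `MemoryDuplex::new` are connected, so an item sent on one side is yielded by
// the other. The message type and values here are illustrative only.
#[cfg(test)]
mod usage_sketch {
    use super::MemoryDuplex;
    use futures::{SinkExt, StreamExt};

    #[tokio::test]
    async fn example_memory_duplex() {
        let (mut left, mut right) = MemoryDuplex::new();

        left.send("ping".to_string()).await.unwrap();
        assert_eq!(right.next().await.unwrap().unwrap(), "ping");

        // The channel is full-duplex, so the other direction works the same way.
        right.send("pong".to_string()).await.unwrap();
        assert_eq!(left.next().await.unwrap().unwrap(), "pong");
    }
}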

View File

@@ -1,95 +0,0 @@
use futures::task::Spawn;
/// Compatibility trait for implementing `futures::task::Spawn` for executors.
pub trait SpawnCompatExt {
/// Wrap the executor in a `Compat` wrapper.
fn compat(self) -> Compat<Self>
where
Self: Sized;
/// Wrap a reference to the executor in a `Compat` wrapper.
fn compat_ref(&self) -> Compat<&Self>;
/// Wrap a mutable reference to the executor in a `Compat` wrapper.
fn compat_mut(&mut self) -> Compat<&mut Self>;
}
impl<T> SpawnCompatExt for T {
fn compat(self) -> Compat<Self>
where
Self: Sized,
{
Compat::new(self)
}
fn compat_ref(&self) -> Compat<&Self> {
Compat::new(self)
}
fn compat_mut(&mut self) -> Compat<&mut Self> {
Compat::new(self)
}
}
pub struct Compat<T>(T);
impl<T> Compat<T> {
/// Create a new `Compat` wrapper around `inner`.
pub fn new(inner: T) -> Self {
Self(inner)
}
}
#[cfg(feature = "tokio_compat")]
impl Spawn for Compat<tokio::runtime::Runtime> {
fn spawn_obj(
&self,
future: futures::future::FutureObj<'static, ()>,
) -> Result<(), futures::task::SpawnError> {
drop(self.0.spawn(future));
Ok(())
}
}
#[cfg(feature = "tokio_compat")]
impl Spawn for Compat<tokio::runtime::Handle> {
fn spawn_obj(
&self,
future: futures::future::FutureObj<'static, ()>,
) -> Result<(), futures::task::SpawnError> {
drop(self.0.spawn(future));
Ok(())
}
}
#[cfg(feature = "wasm_compat")]
#[derive(Debug, Clone, Copy)]
pub struct WasmBindgenExecutor;
#[cfg(feature = "wasm_compat")]
impl Spawn for WasmBindgenExecutor {
fn spawn_obj(
&self,
future: futures::future::FutureObj<'static, ()>,
) -> Result<(), futures::task::SpawnError> {
wasm_bindgen_futures::spawn_local(future);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::task::SpawnExt;
async fn foo(exec: &impl Spawn) {
let task = exec.spawn_with_handle(async {}).unwrap();
task.await
}
#[tokio::test]
async fn test_tokio_compat() {
foo(&tokio::runtime::Handle::current().compat()).await;
}
}

View File

@@ -1,58 +0,0 @@
/// Extracts the expected variant of an enum from a stream and handles errors
///
/// This macro is intended to simplify extracting the expected message
/// when doing communication.
/// - The first argument is the stream from which the next message is read
/// - The second argument is the expected enum variant
#[macro_export]
macro_rules! expect_msg_or_err {
($stream:expr, $expected:path) => {
match $stream.next().await {
Some(Ok($expected(msg))) => Ok(msg),
Some(Ok(other)) => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("unexpected message: {:?}", other),
)),
Some(Err(e)) => Err(e),
None => Err(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"stream closed unexpectedly",
)),
}
};
}
#[cfg(test)]
mod tests {
use futures_util::StreamExt;
#[derive(Debug)]
#[allow(dead_code)]
enum Msg {
Foo(u8),
Bar(u8),
}
#[tokio::test]
async fn test_expect_msg_macro() -> std::io::Result<()> {
let mut stream = Box::pin(futures::stream::once(async { Ok(Msg::Foo(0u8)) }));
let _ = expect_msg_or_err!(stream, Msg::Foo).unwrap();
let mut stream = Box::pin(futures::stream::once(async { Ok(Msg::Bar(0u8)) }));
let err = expect_msg_or_err!(stream, Msg::Foo).unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
let mut stream = Box::pin(futures::stream::once(async {
Err::<Msg, _>(std::io::Error::from(std::io::ErrorKind::BrokenPipe))
}));
let err = expect_msg_or_err!(stream, Msg::Foo).unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
Ok(())
}
}

View File

@@ -1,14 +0,0 @@
use async_trait::async_trait;
/// This trait is for factories which produce their items asynchronously
#[async_trait]
pub trait AsyncFactory<T> {
type Config;
type Error;
/// Creates a new instance
///
/// * `id` - Unique ID of the instance
/// * `config` - Instance configuration
async fn create(&mut self, id: String, config: Self::Config) -> Result<T, Self::Error>;
}
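// A minimal sketch of implementing the trait; the `MockConnFactory` and `Conn`
// types, the config/error choices, and all values are hypothetical and exist
// only for illustration.
#[cfg(test)]
mod usage_sketch {
    use super::AsyncFactory;
    use async_trait::async_trait;

    struct Conn {
        id: String,
        timeout_ms: u64,
    }

    #[derive(Default)]
    struct MockConnFactory;

    #[async_trait]
    impl AsyncFactory<Conn> for MockConnFactory {
        type Config = u64;
        type Error = std::io::Error;

        // "Creates" a connection asynchronously from the id and config.
        async fn create(&mut self, id: String, config: Self::Config) -> Result<Conn, Self::Error> {
            Ok(Conn {
                id,
                timeout_ms: config,
            })
        }
    }

    #[tokio::test]
    async fn example_async_factory() {
        let mut factory = MockConnFactory::default();
        let conn = factory.create("conn-0".to_string(), 5_000).await.unwrap();
        assert_eq!(conn.id, "conn-0");
        assert_eq!(conn.timeout_ms, 5_000);
    }
}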

View File

@@ -1,13 +0,0 @@
pub mod adaptive_barrier;
#[cfg(feature = "codec")]
pub mod codec;
#[cfg(feature = "duplex")]
pub mod duplex;
pub mod executor;
pub mod expect_msg;
pub mod factory;
#[cfg(feature = "mux")]
pub mod mux;
pub mod non_blocking_backend;
pub mod sink;
pub mod stream;

View File

@@ -1,174 +0,0 @@
use async_trait::async_trait;
use futures_util::{AsyncRead, AsyncWrite};
use crate::duplex::Duplex;
#[derive(Debug, thiserror::Error)]
pub enum MuxerError {
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("internal error: {0:?}")]
InternalError(String),
#[error("duplicate stream id: {0:?}")]
DuplicateStreamId(String),
}
/// A trait for opening a new duplex byte stream with a remote peer.
#[async_trait]
pub trait MuxStream: Clone {
type Stream: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static;
/// Opens a new stream with the remote using the provided id
async fn get_stream(&mut self, id: &str) -> Result<Self::Stream, MuxerError>;
}
/// A trait for opening a new duplex channel with a remote peer.
#[async_trait]
pub trait MuxChannelSerde: Sized {
/// Opens a new channel with the remote using the provided id
async fn get_channel<
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
>(
&mut self,
id: &str,
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError>;
}
/// A trait for opening a new duplex channel with a remote peer.
///
/// This trait is similar to [`MuxChannelSerde`] except it is object safe.
#[async_trait]
pub trait MuxChannel<T> {
/// Opens a new channel with the remote using the provided id
///
/// Attaches a codec to the underlying stream
async fn get_channel(&mut self, id: &str) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError>;
}
#[async_trait]
impl<T, U> MuxChannel<T> for U
where
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
U: MuxChannelSerde + Send,
{
async fn get_channel(&mut self, id: &str) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
self.get_channel::<T>(id).await
}
}
pub mod mock {
use tokio_util::compat::TokioAsyncReadCompatExt;
use super::*;
use std::{
any::Any,
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use crate::duplex::MemoryDuplex;
#[derive(Default)]
struct FactoryState {
exists: HashSet<String>,
buffer: HashMap<String, Box<dyn Any + Send + 'static>>,
}
/// A mock mux factory which pairs up streams and channels opened with the same id.
#[derive(Default, Clone)]
pub struct MockMuxChannelFactory {
state: Arc<Mutex<FactoryState>>,
}
impl MockMuxChannelFactory {
/// Creates a new mock mux channel factory
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(FactoryState {
exists: HashSet::new(),
buffer: HashMap::new(),
})),
}
}
}
#[async_trait]
impl MuxStream for MockMuxChannelFactory {
type Stream = tokio_util::compat::Compat<tokio::io::DuplexStream>;
async fn get_stream(&mut self, id: &str) -> Result<Self::Stream, MuxerError> {
let mut state = self.state.lock().unwrap();
if let Some(stream) = state.buffer.remove(id) {
if let Ok(stream) = stream.downcast::<tokio::io::DuplexStream>() {
Ok((*stream).compat())
} else {
Err(MuxerError::InternalError(
"failed to downcast stream".to_string(),
))
}
} else {
if !state.exists.insert(id.to_string()) {
return Err(MuxerError::DuplicateStreamId(id.to_string()));
}
let (stream_0, stream_1) = tokio::io::duplex(1 << 23);
state.buffer.insert(id.to_string(), Box::new(stream_1));
Ok(stream_0.compat())
}
}
}
#[async_trait]
impl MuxChannelSerde for MockMuxChannelFactory {
async fn get_channel<T: Send + 'static>(
&mut self,
id: &str,
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
let mut state = self.state.lock().unwrap();
if let Some(channel) = state.buffer.remove(id) {
if let Ok(channel) = channel.downcast::<MemoryDuplex<T>>() {
Ok(channel)
} else {
Err(MuxerError::InternalError(
"failed to downcast channel".to_string(),
))
}
} else {
if !state.exists.insert(id.to_string()) {
return Err(MuxerError::DuplicateStreamId(id.to_string()));
}
let (channel_0, channel_1) = MemoryDuplex::new();
state.buffer.insert(id.to_string(), Box::new(channel_1));
Ok(Box::new(channel_0))
}
}
}
#[cfg(test)]
mod test {
use futures::{SinkExt, StreamExt};
use super::{MockMuxChannelFactory, MuxChannelSerde};
#[tokio::test]
async fn test_mock_mux_channel_factory() {
let mut factory = MockMuxChannelFactory::new();
let mut channel_0 = factory.get_channel("test").await.unwrap();
let mut channel_1 = factory.get_channel("test").await.unwrap();
channel_0.send(0).await.unwrap();
let received = channel_1.next().await.unwrap().unwrap();
assert_eq!(received, 0);
channel_1.send(0).await.unwrap();
let received = channel_0.next().await.unwrap().unwrap();
assert_eq!(received, 0);
}
}
}

View File

@@ -1,110 +0,0 @@
use async_trait::async_trait;
use futures::channel::oneshot;
pub type Backend = RayonBackend;
/// Allows spawning a closure on a thread outside of the async runtime
///
/// This makes it possible to perform CPU-intensive tasks without blocking the runtime.
#[async_trait]
pub trait NonBlockingBackend {
/// Spawns the closure on a separate thread and awaits the result
async fn spawn<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(closure: F) -> T;
}
/// A CPU backend that uses Rayon
pub struct RayonBackend;
#[async_trait]
impl NonBlockingBackend for RayonBackend {
async fn spawn<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(closure: F) -> T {
let (sender, receiver) = oneshot::channel();
rayon::spawn(move || {
_ = sender.send(closure());
});
receiver.await.expect("channel should not be canceled")
}
}
/// A macro for asynchronously evaluating an expression on a non-blocking backend.
///
/// The expression must be `Send + 'static`, including its return type.
///
/// All variables referenced in the expression are moved to the backend scope.
///
/// # Example
///
/// ```rust
/// # use utils_aio::blocking;
/// # futures::executor::block_on(async {
/// let a = 1u8;
/// let b = 2u8;
///
/// let sum = blocking!(a + b);
///
/// assert_eq!(sum, 3);
/// # });
/// ```
///
/// # Example: Returned arguments
///
/// When variables are used in the expression, they are moved into the backend scope. If you still
/// need to use them after evaluating the expression, you can have them returned using the following syntax:
///
/// ```rust
/// # use utils_aio::blocking;
/// struct NotCopy(u32);
///
/// # futures::executor::block_on(async {
/// let a = NotCopy(1);
/// let b = NotCopy(2);
///
/// let (a, b, sum) = blocking! {
/// (a, b) => a.0 + b.0
/// };
///
/// assert_eq!(a.0, 1);
/// assert_eq!(b.0, 2);
/// assert_eq!(sum, 3);
/// # });
/// ```
#[macro_export]
macro_rules! blocking {
(($($arg:ident),+) => $expr:expr) => {
{
use $crate::non_blocking_backend::NonBlockingBackend;
$crate::non_blocking_backend::Backend::spawn(move || {
let result = $expr;
($($arg),+, result)
}).await
}
};
(() => $expr:expr) => {
{
use $crate::non_blocking_backend::NonBlockingBackend;
$crate::non_blocking_backend::Backend::spawn(move || $expr).await
}
};
($expr:expr) => {
{
use $crate::non_blocking_backend::NonBlockingBackend;
$crate::non_blocking_backend::Backend::spawn(move || $expr).await
}
};
}
#[cfg(test)]
mod tests {
use super::{Backend, NonBlockingBackend};
#[tokio::test]
async fn test_spawn() {
let sum = Backend::spawn(compute_sum).await;
assert_eq!(sum, 4950);
}
fn compute_sum() -> u32 {
(0..100).sum()
}
}

View File

@@ -1,6 +0,0 @@
use futures::Sink;
/// A sink with `std::io::Error` as the error type.
pub trait IoSink<T>: Sink<T, Error = std::io::Error> {}
impl<T, U> IoSink<T> for U where U: Sink<T, Error = std::io::Error> {}

View File

@@ -1,96 +0,0 @@
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use futures::{future::FusedFuture, stream::FusedStream, Future, Stream, TryStream, TryStreamExt};
/// A stream which yields `std::io::Result<T>` items.
pub trait IoStream<T>: Stream<Item = Result<T, std::io::Error>> {}
impl<T, U> IoStream<U> for T where T: Stream<Item = Result<U, std::io::Error>> {}
/// Future for the [`expect_next`](ExpectStreamExt::expect_next) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ExpectNext<'a, St: ?Sized> {
stream: &'a mut St,
}
impl<St: ?Sized + Unpin> Unpin for ExpectNext<'_, St> {}
impl<'a, St: ?Sized + TryStream + Unpin> ExpectNext<'a, St> {
pub(super) fn new(stream: &'a mut St) -> Self {
Self { stream }
}
}
impl<St: ?Sized + TryStream + Unpin + FusedStream> FusedFuture for ExpectNext<'_, St>
where
<St as TryStream>::Error: Into<io::Error>,
{
fn is_terminated(&self) -> bool {
self.stream.is_terminated()
}
}
impl<St: ?Sized + TryStream + Unpin> Future for ExpectNext<'_, St>
where
<St as TryStream>::Error: Into<io::Error>,
{
type Output = io::Result<St::Ok>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.stream
.try_poll_next_unpin(cx)
.map_err(|e| e.into())?
.map(|item| {
if let Some(item) = item {
Ok(item)
} else {
Err(io::ErrorKind::UnexpectedEof.into())
}
})
}
}
/// Extension trait for [`TryStream`](futures::stream::TryStream).
pub trait ExpectStreamExt: TryStream + Unpin {
/// Creates a future that attempts to resolve the next item in the stream.
/// If an error is encountered before the next item, the error is returned
/// instead.
///
/// Additionally, if the stream ends before the next item, an error is
/// returned.
///
/// This is similar to the [`TryStreamExt::try_next`](futures::stream::TryStreamExt::try_next)
/// combinator, but returns an error if the stream ends before the next item instead of an
/// `Option`.
///
/// # Examples
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::{SinkExt, StreamExt};
/// use utils_aio::stream::ExpectStreamExt;
/// use utils_aio::duplex::MemoryDuplex;
///
/// let (mut a, mut b) = MemoryDuplex::new();
///
/// a.send(()).await.unwrap();
/// a.close().await.unwrap();
///
/// assert!(b.expect_next().await.is_ok());
/// assert!(b.expect_next().await.is_err());
/// # })
/// ```
fn expect_next(&mut self) -> ExpectNext<'_, Self>
where
Self: Sized,
{
ExpectNext::new(self)
}
}
impl<T> ExpectStreamExt for T where T: TryStream + Unpin {}