mirror of
https://github.com/tlsnotary/tlsn-utils.git
synced 2026-01-08 22:48:09 -05:00
refactor: remove utils-aio (#60)
This commit is contained in:
@@ -4,7 +4,6 @@ members = [
|
||||
"spansy",
|
||||
"uid-mux",
|
||||
"utils",
|
||||
"utils-aio",
|
||||
"websocket-relay",
|
||||
"futures-limit",
|
||||
"futures-plex",
|
||||
@@ -23,9 +22,7 @@ tlsn-utils-aio = { path = "utils-aio" }
|
||||
uid-mux = { path = "uid-mux" }
|
||||
rangeset = { path = "rangeset" }
|
||||
|
||||
async-std = "1"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.16"
|
||||
bincode = "1.3"
|
||||
bytes = "1"
|
||||
cfg-if = "1"
|
||||
@@ -38,8 +35,6 @@ futures-sink = "0.3"
|
||||
futures-util = "0.3"
|
||||
pin-project-lite = "0.2"
|
||||
pollster = "0.4"
|
||||
prost = "0.9"
|
||||
prost-build = "0.9"
|
||||
rand = "0.8"
|
||||
rayon = "1"
|
||||
serde = "1"
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
[package]
|
||||
name = "tlsn-utils-aio"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "utils_aio"
|
||||
|
||||
[features]
|
||||
default = ["mux", "duplex", "tokio_compat", "wasm_compat", "codec"]
|
||||
codec = ["serde", "tokio-serde", "tokio-serde/bincode"]
|
||||
mux = []
|
||||
duplex = []
|
||||
tokio_compat = ["tokio/rt"]
|
||||
wasm_compat = ["dep:wasm-bindgen-futures"]
|
||||
|
||||
[dependencies]
|
||||
bytes.workspace = true
|
||||
tokio = { workspace = true, features = ["sync", "io-util"] }
|
||||
tokio-util = { workspace = true, features = ["codec", "compat"] }
|
||||
tokio-serde = { version = "0.8", optional = true }
|
||||
wasm-bindgen-futures = { version = "0.4", optional = true }
|
||||
async-tungstenite.workspace = true
|
||||
futures.workspace = true
|
||||
futures-util.workspace = true
|
||||
async-trait.workspace = true
|
||||
thiserror.workspace = true
|
||||
async-std.workspace = true
|
||||
rayon.workspace = true
|
||||
serde = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = [
|
||||
"macros",
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"io-util",
|
||||
"time",
|
||||
] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
tokio-serde = { version = "0.8", features = ["bincode"] }
|
||||
@@ -1,150 +0,0 @@
|
||||
use tokio::sync::broadcast::{channel, error::RecvError, Sender};
|
||||
|
||||
/// An adaptive barrier
|
||||
///
|
||||
/// This allows to change the number of barriers dynamically.
|
||||
/// Code is taken from https://users.rust-lang.org/t/a-poor-man-async-adaptive-barrier/68118
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveBarrier {
|
||||
inner: Sender<Empty>,
|
||||
}
|
||||
|
||||
impl AdaptiveBarrier {
|
||||
/// Wait in order to perform task synchronization
|
||||
///
|
||||
/// Waits for all other barriers who have been cloned from this one
|
||||
/// to also call `wait`
|
||||
pub async fn wait(self) {
|
||||
let mut receiver = self.inner.subscribe();
|
||||
drop(self.inner);
|
||||
match receiver.recv().await {
|
||||
Ok(_) => unreachable!(),
|
||||
Err(RecvError::Lagged(_)) => unreachable!(),
|
||||
Err(RecvError::Closed) => (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
// even though we will not receive any data, we still
|
||||
// need to set the channel's capacity to the required minimum 1
|
||||
inner: channel(1).0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AdaptiveBarrier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum Empty {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::AdaptiveBarrier;
|
||||
use std::{mem::replace, sync::Arc, time::Duration};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Waiter {
|
||||
counter: Arc<Mutex<Vec<usize>>>,
|
||||
barrier: AdaptiveBarrier,
|
||||
}
|
||||
|
||||
impl Waiter {
|
||||
fn new(counter: &Arc<Mutex<Vec<usize>>>) -> Self {
|
||||
Self {
|
||||
counter: Arc::clone(counter),
|
||||
barrier: AdaptiveBarrier::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// add a new value to the counter. if no values are present in the
|
||||
// counter, the new value will be 1.
|
||||
async fn count(self) {
|
||||
let mut counter = self.counter.lock().await;
|
||||
let last = counter.last().copied().unwrap_or_default();
|
||||
counter.push(last + 1);
|
||||
}
|
||||
|
||||
async fn count_wait(mut self) {
|
||||
// use `replace()` because we can't call `self.barrier.wait().await;` here
|
||||
let barrier = replace(&mut self.barrier, AdaptiveBarrier::new());
|
||||
barrier.wait().await;
|
||||
self.count().await;
|
||||
}
|
||||
}
|
||||
|
||||
// We expect that 0 is not the first number in the counter because we do not use
|
||||
// the barrier in this test
|
||||
#[tokio::test]
|
||||
async fn test_adaptive_barrier_no_wait() {
|
||||
let counter = Arc::new(Mutex::new(vec![]));
|
||||
|
||||
let waiter = Waiter::new(&counter);
|
||||
let waiter_2 = waiter.clone();
|
||||
let waiter_3 = waiter.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
waiter.count().await;
|
||||
});
|
||||
let task_2 = tokio::spawn(async move {
|
||||
waiter_2.count().await;
|
||||
});
|
||||
|
||||
// the reason why we are not using here:
|
||||
// _ = tokio::join!(task, task_2);
|
||||
// is to make this test comparable to `test_adaptive_barrier_wait`
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
{
|
||||
// Add 0 to counter. But this will not be the first number
|
||||
// since `task` and `task_2` were already able to add to
|
||||
// the counter.
|
||||
counter.lock().await.push(0);
|
||||
}
|
||||
|
||||
// both tasks must be finished now
|
||||
assert!(task.is_finished() && task_2.is_finished());
|
||||
|
||||
let task_3 = tokio::spawn(async move {
|
||||
waiter_3.count().await;
|
||||
});
|
||||
_ = tokio::join!(task, task_2, task_3);
|
||||
assert_ne!(*counter.lock().await.first().unwrap(), 0);
|
||||
}
|
||||
|
||||
// Now we use `count_wait` instead of `count` so 0 should be the first number
|
||||
#[tokio::test]
|
||||
async fn test_adaptive_barrier_wait() {
|
||||
let counter = Arc::new(Mutex::new(vec![]));
|
||||
|
||||
let waiter = Waiter::new(&counter);
|
||||
let waiter_2 = waiter.clone();
|
||||
let waiter_3 = waiter.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
waiter.count_wait().await;
|
||||
});
|
||||
let task_2 = tokio::spawn(async move {
|
||||
waiter_2.count_wait().await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
{
|
||||
counter.lock().await.push(0);
|
||||
}
|
||||
|
||||
// both tasks must NOT be finished yet
|
||||
assert!(!task.is_finished() && !task_2.is_finished());
|
||||
|
||||
// Now we wait for the last barrier, so all tasks can start counting
|
||||
let task_3 = tokio::spawn(async move {
|
||||
waiter_3.count_wait().await;
|
||||
});
|
||||
_ = tokio::join!(task, task_2, task_3);
|
||||
assert_eq!(*counter.lock().await.first().unwrap(), 0);
|
||||
}
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{AsyncRead, AsyncWrite};
|
||||
use tokio_serde::formats::Bincode;
|
||||
use tokio_util::{codec::LengthDelimitedCodec, compat::FuturesAsyncReadCompatExt};
|
||||
|
||||
use crate::{
|
||||
duplex::Duplex,
|
||||
mux::{MuxChannelSerde, MuxStream, MuxerError},
|
||||
};
|
||||
|
||||
/// Wraps a [`MuxStream`] and provides a [`Channel`] with a bincode codec
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BincodeMux<M>(M);
|
||||
|
||||
impl<M> BincodeMux<M>
|
||||
where
|
||||
M: MuxStream,
|
||||
{
|
||||
/// Creates a new bincode mux
|
||||
pub fn new(mux: M) -> Self {
|
||||
Self(mux)
|
||||
}
|
||||
|
||||
/// Claim the inner type that implements MuxStream
|
||||
pub fn into_inner(self) -> M {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Attaches a bincode codec to the provided stream
|
||||
pub fn attach_codec<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static, T>(
|
||||
&self,
|
||||
stream: S,
|
||||
) -> impl Duplex<T>
|
||||
where
|
||||
T: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + Unpin + 'static,
|
||||
{
|
||||
let framed = LengthDelimitedCodec::builder().new_framed(stream.compat());
|
||||
|
||||
tokio_serde::Framed::new(framed, Bincode::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<M> MuxChannelSerde for BincodeMux<M>
|
||||
where
|
||||
M: MuxStream + Send + 'static,
|
||||
{
|
||||
async fn get_channel<
|
||||
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
|
||||
>(
|
||||
&mut self,
|
||||
id: &str,
|
||||
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
|
||||
let stream = self.0.get_stream(id).await?;
|
||||
|
||||
Ok(Box::new(self.attach_codec(stream)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::mux::mock::MockMuxChannelFactory;
|
||||
|
||||
use super::*;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct Foo {
|
||||
msg: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct Bar {
|
||||
msg: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mux_codec() {
|
||||
let mux = MockMuxChannelFactory::new();
|
||||
|
||||
let mut framed_mux = BincodeMux::new(mux);
|
||||
|
||||
let mut channel_0 = framed_mux.get_channel("foo").await.unwrap();
|
||||
let mut channel_1 = framed_mux.get_channel("foo").await.unwrap();
|
||||
|
||||
channel_0
|
||||
.send(Foo {
|
||||
msg: "hello".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let msg: Foo = channel_1.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(msg.msg, "hello");
|
||||
|
||||
let mut channel_0 = framed_mux.get_channel("bar").await.unwrap();
|
||||
let mut channel_1 = framed_mux.get_channel("bar").await.unwrap();
|
||||
|
||||
channel_0
|
||||
.send(Bar {
|
||||
msg: "world".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let msg: Bar = channel_1.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(msg.msg, "world");
|
||||
}
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
use std::{
|
||||
io::{Error, ErrorKind},
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
use futures::{channel::mpsc, AsyncRead, AsyncWrite, Sink, Stream};
|
||||
|
||||
use crate::{sink::IoSink, stream::IoStream};
|
||||
|
||||
pub trait DuplexByteStream: AsyncWrite + AsyncRead + Unpin {}
|
||||
|
||||
impl<T> DuplexByteStream for T where T: AsyncWrite + AsyncRead + Unpin {}
|
||||
|
||||
/// A channel that can be used to send and receive messages.
|
||||
pub trait Duplex<T>: IoStream<T> + IoSink<T> + Send + Sync + Unpin {}
|
||||
|
||||
impl<T, U> Duplex<T> for U where U: IoStream<T> + IoSink<T> + Send + Sync + Unpin {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryDuplex<T> {
|
||||
sink: mpsc::Sender<T>,
|
||||
stream: mpsc::Receiver<T>,
|
||||
}
|
||||
|
||||
impl<T> MemoryDuplex<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
pub fn new() -> (Self, Self) {
|
||||
let (sender, receiver) = mpsc::channel(10);
|
||||
let (sender_2, receiver_2) = mpsc::channel(10);
|
||||
(
|
||||
Self {
|
||||
sink: sender,
|
||||
stream: receiver_2,
|
||||
},
|
||||
Self {
|
||||
sink: sender_2,
|
||||
stream: receiver,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sink<T> for MemoryDuplex<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_ready(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
|
||||
}
|
||||
|
||||
fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.sink)
|
||||
.start_send(item)
|
||||
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_flush(cx)
|
||||
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_close(cx)
|
||||
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Stream for MemoryDuplex<T> {
|
||||
type Item = Result<T, std::io::Error>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.stream).poll_next(cx).map(|x| x.map(Ok))
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
use futures::task::Spawn;
|
||||
|
||||
/// Compatibility trait for implementing `futures::task::Spawn` for executors.
|
||||
pub trait SpawnCompatExt {
|
||||
/// Wrap the executor in a `Compat` wrapper.
|
||||
fn compat(self) -> Compat<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Wrap a reference to the executor in a `Compat` wrapper.
|
||||
fn compat_ref(&self) -> Compat<&Self>;
|
||||
|
||||
/// Wrap a mutable reference to the executor in a `Compat` wrapper.
|
||||
fn compat_mut(&mut self) -> Compat<&mut Self>;
|
||||
}
|
||||
|
||||
impl<T> SpawnCompatExt for T {
|
||||
fn compat(self) -> Compat<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Compat::new(self)
|
||||
}
|
||||
|
||||
fn compat_ref(&self) -> Compat<&Self> {
|
||||
Compat::new(self)
|
||||
}
|
||||
|
||||
fn compat_mut(&mut self) -> Compat<&mut Self> {
|
||||
Compat::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Compat<T>(T);
|
||||
|
||||
impl<T> Compat<T> {
|
||||
/// Create a new `Compat` wrapper around `inner`.
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self(inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio_compat")]
|
||||
impl Spawn for Compat<tokio::runtime::Runtime> {
|
||||
fn spawn_obj(
|
||||
&self,
|
||||
future: futures::future::FutureObj<'static, ()>,
|
||||
) -> Result<(), futures::task::SpawnError> {
|
||||
drop(self.0.spawn(future));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio_compat")]
|
||||
impl Spawn for Compat<tokio::runtime::Handle> {
|
||||
fn spawn_obj(
|
||||
&self,
|
||||
future: futures::future::FutureObj<'static, ()>,
|
||||
) -> Result<(), futures::task::SpawnError> {
|
||||
drop(self.0.spawn(future));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "wasm_compat")]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct WasmBindgenExecutor;
|
||||
|
||||
#[cfg(feature = "wasm_compat")]
|
||||
impl Spawn for WasmBindgenExecutor {
|
||||
fn spawn_obj(
|
||||
&self,
|
||||
future: futures::future::FutureObj<'static, ()>,
|
||||
) -> Result<(), futures::task::SpawnError> {
|
||||
wasm_bindgen_futures::spawn_local(future);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use futures::task::SpawnExt;
|
||||
|
||||
async fn foo(exec: &impl Spawn) {
|
||||
let task = exec.spawn_with_handle(async {}).unwrap();
|
||||
task.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tokio_compat() {
|
||||
foo(&tokio::runtime::Handle::current().compat()).await;
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
/// Extract expected variant of an enum and handle errors
|
||||
///
|
||||
/// This macro is intended to simplify extracting the expected message
|
||||
/// when doing communication.
|
||||
/// - The first argument is the expression, which is matched
|
||||
/// - the second argument is the expected enum variant
|
||||
#[macro_export]
|
||||
macro_rules! expect_msg_or_err {
|
||||
($stream:expr, $expected:path) => {
|
||||
match $stream.next().await {
|
||||
Some(Ok($expected(msg))) => Ok(msg),
|
||||
Some(Ok(other)) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("unexpected message: {:?}", other),
|
||||
)),
|
||||
Some(Err(e)) => Err(e),
|
||||
None => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
"stream closed unexpectedly",
|
||||
)),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
enum Msg {
|
||||
Foo(u8),
|
||||
Bar(u8),
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expect_msg_macro() -> std::io::Result<()> {
|
||||
let mut stream = Box::pin(futures::stream::once(async { Ok(Msg::Foo(0u8)) }));
|
||||
|
||||
let _ = expect_msg_or_err!(stream, Msg::Foo).unwrap();
|
||||
|
||||
let mut stream = Box::pin(futures::stream::once(async { Ok(Msg::Bar(0u8)) }));
|
||||
|
||||
let err = expect_msg_or_err!(stream, Msg::Foo).unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
|
||||
let mut stream = Box::pin(futures::stream::once(async {
|
||||
Err::<Msg, _>(std::io::Error::from(std::io::ErrorKind::BrokenPipe))
|
||||
}));
|
||||
|
||||
let err = expect_msg_or_err!(stream, Msg::Foo).unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// This trait is for factories which produce their items asynchronously
|
||||
#[async_trait]
|
||||
pub trait AsyncFactory<T> {
|
||||
type Config;
|
||||
type Error;
|
||||
|
||||
/// Creates new instance
|
||||
///
|
||||
/// * `id` - Unique ID of instance
|
||||
/// * `config` - Instance configuration
|
||||
async fn create(&mut self, id: String, config: Self::Config) -> Result<T, Self::Error>;
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
pub mod adaptive_barrier;
|
||||
#[cfg(feature = "codec")]
|
||||
pub mod codec;
|
||||
#[cfg(feature = "duplex")]
|
||||
pub mod duplex;
|
||||
pub mod executor;
|
||||
pub mod expect_msg;
|
||||
pub mod factory;
|
||||
#[cfg(feature = "mux")]
|
||||
pub mod mux;
|
||||
pub mod non_blocking_backend;
|
||||
pub mod sink;
|
||||
pub mod stream;
|
||||
@@ -1,174 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::duplex::Duplex;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MuxerError {
|
||||
#[error(transparent)]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("internal error: {0:?}")]
|
||||
InternalError(String),
|
||||
#[error("duplicate stream id: {0:?}")]
|
||||
DuplicateStreamId(String),
|
||||
}
|
||||
|
||||
/// A trait for opening a new duplex byte stream with a remote peer.
|
||||
#[async_trait]
|
||||
pub trait MuxStream: Clone {
|
||||
type Stream: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static;
|
||||
|
||||
/// Opens a new stream with the remote using the provided id
|
||||
async fn get_stream(&mut self, id: &str) -> Result<Self::Stream, MuxerError>;
|
||||
}
|
||||
|
||||
/// A trait for opening a new duplex channel with a remote peer.
|
||||
#[async_trait]
|
||||
pub trait MuxChannelSerde: Sized {
|
||||
/// Opens a new channel with the remote using the provided id
|
||||
async fn get_channel<
|
||||
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
|
||||
>(
|
||||
&mut self,
|
||||
id: &str,
|
||||
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError>;
|
||||
}
|
||||
|
||||
/// A trait for opening a new duplex channel with a remote peer.
|
||||
///
|
||||
/// This trait is similar to [`MuxChannelSized`] except it is object safe.
|
||||
#[async_trait]
|
||||
pub trait MuxChannel<T> {
|
||||
/// Opens a new channel with the remote using the provided id
|
||||
///
|
||||
/// Attaches a codec to the underlying stream
|
||||
async fn get_channel(&mut self, id: &str) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> MuxChannel<T> for U
|
||||
where
|
||||
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Unpin + 'static,
|
||||
U: MuxChannelSerde + Send,
|
||||
{
|
||||
async fn get_channel(&mut self, id: &str) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
|
||||
self.get_channel::<T>(id).await
|
||||
}
|
||||
}
|
||||
|
||||
pub mod mock {
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
|
||||
use super::*;
|
||||
|
||||
use std::{
|
||||
any::Any,
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use crate::duplex::MemoryDuplex;
|
||||
|
||||
#[derive(Default)]
|
||||
struct FactoryState {
|
||||
exists: HashSet<String>,
|
||||
buffer: HashMap<String, Box<dyn Any + Send + 'static>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct MockMuxChannelFactory {
|
||||
state: Arc<Mutex<FactoryState>>,
|
||||
}
|
||||
|
||||
impl MockMuxChannelFactory {
|
||||
/// Creates a new mock mux channel factory
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(FactoryState {
|
||||
exists: HashSet::new(),
|
||||
buffer: HashMap::new(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MuxStream for MockMuxChannelFactory {
|
||||
type Stream = tokio_util::compat::Compat<tokio::io::DuplexStream>;
|
||||
|
||||
async fn get_stream(&mut self, id: &str) -> Result<Self::Stream, MuxerError> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
if let Some(stream) = state.buffer.remove(id) {
|
||||
if let Ok(stream) = stream.downcast::<tokio::io::DuplexStream>() {
|
||||
Ok((*stream).compat())
|
||||
} else {
|
||||
Err(MuxerError::InternalError(
|
||||
"failed to downcast stream".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
if !state.exists.insert(id.to_string()) {
|
||||
return Err(MuxerError::DuplicateStreamId(id.to_string()));
|
||||
}
|
||||
|
||||
let (stream_0, stream_1) = tokio::io::duplex(1 << 23);
|
||||
state.buffer.insert(id.to_string(), Box::new(stream_1));
|
||||
Ok(stream_0.compat())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MuxChannelSerde for MockMuxChannelFactory {
|
||||
async fn get_channel<T: Send + 'static>(
|
||||
&mut self,
|
||||
id: &str,
|
||||
) -> Result<Box<dyn Duplex<T> + 'static>, MuxerError> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
if let Some(channel) = state.buffer.remove(id) {
|
||||
if let Ok(channel) = channel.downcast::<MemoryDuplex<T>>() {
|
||||
Ok(channel)
|
||||
} else {
|
||||
Err(MuxerError::InternalError(
|
||||
"failed to downcast channel".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
if !state.exists.insert(id.to_string()) {
|
||||
return Err(MuxerError::DuplicateStreamId(id.to_string()));
|
||||
}
|
||||
|
||||
let (channel_0, channel_1) = MemoryDuplex::new();
|
||||
state.buffer.insert(id.to_string(), Box::new(channel_1));
|
||||
|
||||
Ok(Box::new(channel_0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
use super::{MockMuxChannelFactory, MuxChannelSerde};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_mux_channel_factory() {
|
||||
let mut factory = MockMuxChannelFactory::new();
|
||||
let mut channel_0 = factory.get_channel("test").await.unwrap();
|
||||
let mut channel_1 = factory.get_channel("test").await.unwrap();
|
||||
|
||||
channel_0.send(0).await.unwrap();
|
||||
let received = channel_1.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(received, 0);
|
||||
|
||||
channel_1.send(0).await.unwrap();
|
||||
let received = channel_0.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(received, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use futures::channel::oneshot;
|
||||
|
||||
pub type Backend = RayonBackend;
|
||||
|
||||
/// Allows to spawn a closure on a thread outside of the async runtime
|
||||
///
|
||||
/// This allows to perform CPU-intensive tasks without blocking the runtime.
|
||||
#[async_trait]
|
||||
pub trait NonBlockingBackend {
|
||||
/// Spawn the closure in a separate thread and await the result
|
||||
async fn spawn<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(closure: F) -> T;
|
||||
}
|
||||
|
||||
/// A CPU backend that uses Rayon
|
||||
pub struct RayonBackend;
|
||||
|
||||
#[async_trait]
|
||||
impl NonBlockingBackend for RayonBackend {
|
||||
async fn spawn<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(closure: F) -> T {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
rayon::spawn(move || {
|
||||
_ = sender.send(closure());
|
||||
});
|
||||
|
||||
receiver.await.expect("channel should not be canceled")
|
||||
}
|
||||
}
|
||||
|
||||
/// A macro for asynchronously evaluating an expression on a non-blocking backend.
|
||||
///
|
||||
/// The expression must be `Send + 'static`, including it's returned type.
|
||||
///
|
||||
/// All variables referenced in the expression are moved to the backend scope.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use utils_aio::blocking;
|
||||
/// # futures::executor::block_on(async {
|
||||
/// let a = 1u8;
|
||||
/// let b = 2u8;
|
||||
///
|
||||
/// let sum = blocking!(a + b);
|
||||
///
|
||||
/// assert_eq!(sum, 3);
|
||||
/// # });
|
||||
/// ```
|
||||
///
|
||||
/// # Example: Returned arguments
|
||||
///
|
||||
/// When variables used in the expression they are moved into the backend scope. If you still need
|
||||
/// to use them after evaluating the expression you can have them returned using the following syntax:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use utils_aio::blocking;
|
||||
/// struct NotCopy(u32);
|
||||
///
|
||||
/// # futures::executor::block_on(async {
|
||||
/// let a = NotCopy(1);
|
||||
/// let b = NotCopy(2);
|
||||
///
|
||||
/// let (a, b, sum) = blocking! {
|
||||
/// (a, b) => a.0 + b.0
|
||||
/// };
|
||||
///
|
||||
/// assert_eq!(a.0, 1);
|
||||
/// assert_eq!(b.0, 2);
|
||||
/// assert_eq!(sum, 3);
|
||||
/// # });
|
||||
/// ```
|
||||
#[macro_export]
|
||||
macro_rules! blocking {
|
||||
(($($arg:ident),+) => $expr:expr) => {
|
||||
{
|
||||
use $crate::non_blocking_backend::NonBlockingBackend;
|
||||
$crate::non_blocking_backend::Backend::spawn(move || {
|
||||
let result = $expr;
|
||||
($($arg),+, result)
|
||||
}).await
|
||||
}
|
||||
};
|
||||
(() => $expr:expr) => {
|
||||
{
|
||||
use $crate::non_blocking_backend::NonBlockingBackend;
|
||||
$crate::non_blocking_backend::Backend::spawn(move || $expr).await
|
||||
}
|
||||
};
|
||||
($expr:expr) => {
|
||||
{
|
||||
use $crate::non_blocking_backend::NonBlockingBackend;
|
||||
$crate::non_blocking_backend::Backend::spawn(move || $expr).await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Backend, NonBlockingBackend};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn() {
|
||||
let sum = Backend::spawn(compute_sum).await;
|
||||
assert_eq!(sum, 4950);
|
||||
}
|
||||
|
||||
fn compute_sum() -> u32 {
|
||||
(0..100).sum()
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
use futures::Sink;
|
||||
|
||||
/// A sink with `std::io::Error` as the error type.
|
||||
pub trait IoSink<T>: Sink<T, Error = std::io::Error> {}
|
||||
|
||||
impl<T, U> IoSink<T> for U where U: Sink<T, Error = std::io::Error> {}
|
||||
@@ -1,96 +0,0 @@
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{future::FusedFuture, stream::FusedStream, Future, Stream, TryStream, TryStreamExt};
|
||||
|
||||
/// A stream which yields `std::io::Result<T>` items.
|
||||
pub trait IoStream<T>: Stream<Item = Result<T, std::io::Error>> {}
|
||||
|
||||
impl<T, U> IoStream<U> for T where T: Stream<Item = Result<U, std::io::Error>> {}
|
||||
|
||||
/// Future for the [`expect_next`](Duplex::expect_next) method.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless you `.await` or poll them"]
|
||||
pub struct ExpectNext<'a, St: ?Sized> {
|
||||
stream: &'a mut St,
|
||||
}
|
||||
|
||||
impl<St: ?Sized + Unpin> Unpin for ExpectNext<'_, St> {}
|
||||
|
||||
impl<'a, St: ?Sized + TryStream + Unpin> ExpectNext<'a, St> {
|
||||
pub(super) fn new(stream: &'a mut St) -> Self {
|
||||
Self { stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ?Sized + TryStream + Unpin + FusedStream> FusedFuture for ExpectNext<'_, St>
|
||||
where
|
||||
<St as TryStream>::Error: Into<io::Error>,
|
||||
{
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.stream.is_terminated()
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ?Sized + TryStream + Unpin> Future for ExpectNext<'_, St>
|
||||
where
|
||||
<St as TryStream>::Error: Into<io::Error>,
|
||||
{
|
||||
type Output = io::Result<St::Ok>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.stream
|
||||
.try_poll_next_unpin(cx)
|
||||
.map_err(|e| e.into())?
|
||||
.map(|item| {
|
||||
if let Some(item) = item {
|
||||
Ok(item)
|
||||
} else {
|
||||
Err(io::ErrorKind::UnexpectedEof.into())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension trait for [`TryStream`](futures::stream::TryStream).
|
||||
pub trait ExpectStreamExt: TryStream + Unpin {
|
||||
/// Creates a future that attempts to resolve the next item in the stream.
|
||||
/// If an error is encountered before the next item, the error is returned
|
||||
/// instead.
|
||||
///
|
||||
/// Additionally, if the stream ends before the next item, an error is
|
||||
/// returned.
|
||||
///
|
||||
/// This is similar to the [`TryStreamExt::try_next`](futures::stream::TryStreamExt::try_next)
|
||||
/// combinator, but returns an error if the stream ends before the next item instead of an
|
||||
/// `Option`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # futures::executor::block_on(async {
|
||||
/// use futures::{SinkExt, StreamExt};
|
||||
/// use utils_aio::stream::ExpectStreamExt;
|
||||
/// use utils_aio::duplex::MemoryDuplex;
|
||||
///
|
||||
/// let (mut a, mut b) = MemoryDuplex::new();
|
||||
///
|
||||
/// a.send(()).await.unwrap();
|
||||
/// a.close().await.unwrap();
|
||||
///
|
||||
/// assert!(b.expect_next().await.is_ok());
|
||||
/// assert!(b.expect_next().await.is_err());
|
||||
/// # })
|
||||
/// ```
|
||||
fn expect_next(&mut self) -> ExpectNext<'_, Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
ExpectNext::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ExpectStreamExt for T where T: TryStream + Unpin {}
|
||||
Reference in New Issue
Block a user