Partition Workspaces

This commit is contained in:
sinu
2023-01-12 12:00:05 -08:00
parent fdf08da8ed
commit 6c65591b76
12 changed files with 463 additions and 11 deletions

View File

@@ -1,12 +1,16 @@
[package]
name = "tlsn-utils"
version = "0.1.0"
edition = "2021"
[workspace]
members = ["utils", "utils-aio"]
[lib]
name = "utils"
[dependencies]
[dev-dependencies]
rand = { workspace = true }
[workspace.dependencies]
rand = "0.8"
thiserror = "1"
async-trait = "0.1"
prost = "0.9"
futures = "0.3"
futures-util = "0.3"
tokio-util = "0.7"
tokio = "1.23"
async-tungstenite = "0.16"
prost-build = "0.9"
bytes = "1"
async-std = "1"

View File

@@ -0,0 +1,34 @@
[package]
name = "tlsn-utils-aio"
version = "0.1.0"
edition = "2021"
[lib]
name = "utils_aio"
[features]
default = ["mux", "duplex"]
codec = []
mux = []
duplex = []
[dependencies]
bytes.workspace = true
prost.workspace = true
tokio = { workspace = true, features = ["sync"] }
tokio-util = { workspace = true, features = ["codec", "compat"] }
async-tungstenite.workspace = true
futures.workspace = true
futures-util.workspace = true
async-trait.workspace = true
thiserror.workspace = true
async-std.workspace = true
[dev-dependencies]
tokio = { workspace = true, features = [
"macros",
"rt",
"rt-multi-thread",
"io-util",
"time",
] }

View File

@@ -0,0 +1,150 @@
use tokio::sync::broadcast::{channel, error::RecvError, Sender};
/// An adaptive barrier
///
/// This allows to change the number of barriers dynamically.
/// Code is taken from https://users.rust-lang.org/t/a-poor-man-async-adaptive-barrier/68118
#[derive(Debug, Clone)]
pub struct AdaptiveBarrier {
inner: Sender<Empty>,
}
impl AdaptiveBarrier {
/// Wait in order to perform task synchronization
///
/// Waits for all other barriers who have been cloned from this one
/// to also call `wait`
pub async fn wait(self) {
let mut receiver = self.inner.subscribe();
drop(self.inner);
match receiver.recv().await {
Ok(_) => unreachable!(),
Err(RecvError::Lagged(_)) => unreachable!(),
Err(RecvError::Closed) => (),
}
}
pub fn new() -> Self {
Self {
// even though we will not receive any data, we still
// need to set the channel's capacity to the required minimum 1
inner: channel(1).0,
}
}
}
impl Default for AdaptiveBarrier {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Copy)]
enum Empty {}
#[cfg(test)]
mod tests {
use super::AdaptiveBarrier;
use std::{mem::replace, sync::Arc, time::Duration};
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
struct Waiter {
counter: Arc<Mutex<Vec<usize>>>,
barrier: AdaptiveBarrier,
}
impl Waiter {
fn new(counter: &Arc<Mutex<Vec<usize>>>) -> Self {
Self {
counter: Arc::clone(counter),
barrier: AdaptiveBarrier::new(),
}
}
// add a new value to the counter. if no values are present in the
// counter, the new value will be 1.
async fn count(self) {
let mut counter = self.counter.lock().await;
let last = counter.last().copied().unwrap_or_default();
counter.push(last + 1);
}
async fn count_wait(mut self) {
// use `replace()` because we can't call `self.barrier.wait().await;` here
let barrier = replace(&mut self.barrier, AdaptiveBarrier::new());
barrier.wait().await;
self.count().await;
}
}
// We expect that 0 is not the first number in the counter because we do not use
// the barrier in this test
#[tokio::test]
async fn test_adaptive_barrier_no_wait() {
let counter = Arc::new(Mutex::new(vec![]));
let waiter = Waiter::new(&counter);
let waiter_2 = waiter.clone();
let waiter_3 = waiter.clone();
let task = tokio::spawn(async move {
waiter.count().await;
});
let task_2 = tokio::spawn(async move {
waiter_2.count().await;
});
// the reason why we are not using here:
// _ = tokio::join!(task, task_2);
// is to make this test comparable to `test_adaptive_barrier_wait`
tokio::time::sleep(Duration::from_millis(1000)).await;
{
// Add 0 to counter. But this will not be the first number
// since `task` and `task_2` were already able to add to
// the counter.
counter.lock().await.push(0);
}
// both tasks must be finished now
assert!(task.is_finished() && task_2.is_finished());
let task_3 = tokio::spawn(async move {
waiter_3.count().await;
});
_ = tokio::join!(task, task_2, task_3);
assert_ne!(*counter.lock().await.first().unwrap(), 0);
}
// Now we use `count_wait` instead of `count` so 0 should be the first number
#[tokio::test]
async fn test_adaptive_barrier_wait() {
let counter = Arc::new(Mutex::new(vec![]));
let waiter = Waiter::new(&counter);
let waiter_2 = waiter.clone();
let waiter_3 = waiter.clone();
let task = tokio::spawn(async move {
waiter.count_wait().await;
});
let task_2 = tokio::spawn(async move {
waiter_2.count_wait().await;
});
tokio::time::sleep(Duration::from_millis(1000)).await;
{
counter.lock().await.push(0);
}
// both tasks must NOT be finished yet
assert!(!task.is_finished() && !task_2.is_finished());
// Now we wait for the last barrier, so all tasks can start counting
let task_3 = tokio::spawn(async move {
waiter_3.count_wait().await;
});
_ = tokio::join!(task, task_2, task_3);
assert_eq!(*counter.lock().await.first().unwrap(), 0);
}
}

View File

@@ -0,0 +1,80 @@
use bytes::BytesMut;
use prost::Message;
use std::marker::PhantomData;
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U>(PhantomData<T>, PhantomData<U>);
impl<T, U: Message> Default for ProstCodec<T, U> {
fn default() -> Self {
ProstCodec(PhantomData, PhantomData)
}
}
impl<T, U: Message + From<T>> Encoder<T> for ProstCodec<T, U> {
type Error = std::io::Error;
fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
U::from(item)
.encode(buf)
.expect("Message only errors if not enough space");
Ok(())
}
}
impl<T: TryFrom<U>, U: Message + Default> Decoder for ProstCodec<T, U>
where
std::io::Error: From<<T as TryFrom<U>>::Error>,
{
type Item = T;
type Error = std::io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let item: U = Message::decode(buf)?;
Ok(Some(T::try_from(item)?))
}
}
#[derive(Debug, Clone)]
pub struct ProstCodecDelimited<T, U> {
_t: (PhantomData<T>, PhantomData<U>),
inner: LengthDelimitedCodec,
}
impl<T, U: Message> Default for ProstCodecDelimited<T, U> {
fn default() -> Self {
ProstCodecDelimited {
_t: (PhantomData, PhantomData),
inner: LengthDelimitedCodec::new(),
}
}
}
impl<T, U: Message + From<T>> Encoder<T> for ProstCodecDelimited<T, U> {
type Error = std::io::Error;
fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
self.inner.encode(U::from(item).encode_to_vec().into(), buf)
}
}
impl<T: TryFrom<U>, U: Message + Default> Decoder for ProstCodecDelimited<T, U>
where
std::io::Error: From<<T as TryFrom<U>>::Error>,
{
type Item = T;
type Error = std::io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let b = self.inner.decode(buf)?;
if let Some(b) = b {
let item: U = Message::decode(b)?;
Ok(Some(T::try_from(item)?))
} else {
Ok(None)
}
}
}

View File

@@ -0,0 +1,86 @@
use futures::{channel::mpsc, AsyncRead, AsyncWrite, Sink, Stream};
use std::{
io::{Error, ErrorKind},
pin::Pin,
};
pub trait DuplexByteStream: AsyncWrite + AsyncRead + Unpin {}
impl<T> DuplexByteStream for T where T: AsyncWrite + AsyncRead + Unpin {}
#[derive(Debug)]
pub struct DuplexChannel<T> {
sink: mpsc::Sender<T>,
stream: mpsc::Receiver<T>,
}
impl<T> DuplexChannel<T>
where
T: Send + 'static,
{
pub fn new() -> (Self, Self) {
let (sender, receiver) = mpsc::channel(10);
let (sender_2, receiver_2) = mpsc::channel(10);
(
Self {
sink: sender,
stream: receiver_2,
},
Self {
sink: sender_2,
stream: receiver,
},
)
}
}
impl<T> Sink<T> for DuplexChannel<T>
where
T: Send + 'static,
{
type Error = std::io::Error;
fn poll_ready(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_ready(cx)
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
}
fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
Pin::new(&mut self.sink)
.start_send(item)
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_flush(cx)
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink)
.poll_close(cx)
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
}
}
impl<T> Stream for DuplexChannel<T> {
type Item = T;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
Pin::new(&mut self.stream).poll_next(cx)
}
}

View File

@@ -0,0 +1,22 @@
/// Extract expected variant of an enum and handle errors
///
/// This macro is intended to simplify extracting the expected message
/// when doing communication.
/// - The first argument is the expression, which is matched
/// - the second argument is the expected enum variant
/// - the last argument is error which is retuned when the expected message is not present
///
/// The error needs to implement From for std::io::Error
#[macro_export]
macro_rules! expect_msg_or_err {
($match: expr, $expected: path, $err: path) => {
match $match {
Some($expected(msg)) => Ok(msg),
Some(other) => Err($err(other)),
None => Err(From::from(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"stream closed unexpectedly",
))),
}
};
}

View File

@@ -0,0 +1,14 @@
use async_trait::async_trait;
/// This trait is for factories which produce their items asynchronously
#[async_trait]
pub trait AsyncFactory<T> {
type Config;
type Error;
/// Creates new instance
///
/// * `id` - Unique ID of instance
/// * `config` - Instance configuration
async fn create(&mut self, id: String, config: Self::Config) -> Result<T, Self::Error>;
}

View File

@@ -0,0 +1,13 @@
pub mod adaptive_barrier;
#[cfg(feature = "codec")]
pub mod codec;
#[cfg(feature = "duplex")]
pub mod duplex;
pub mod expect_msg;
pub mod factory;
#[cfg(feature = "mux")]
pub mod mux;
pub trait Channel<T>: futures::Stream<Item = T> + futures::Sink<T> + Send + Unpin {}
impl<T, U> Channel<T> for U where U: futures::Stream<Item = T> + futures::Sink<T> + Send + Unpin {}

View File

@@ -0,0 +1,37 @@
use super::Channel;
use crate::duplex::DuplexByteStream;
use async_trait::async_trait;
#[derive(Debug, thiserror::Error)]
pub enum MuxerError {
#[error("Connection error occurred: {0}")]
ConnectionError(String),
#[error("IO error")]
IOError(#[from] std::io::Error),
#[error("Duplicate stream id: {0:?}")]
DuplicateStreamId(String),
#[error("Encountered internal error: {0:?}")]
InternalError(String),
}
#[async_trait]
pub trait MuxControl: Clone {
/// Opens a new stream with the remote using the provided id
async fn get_stream(
&mut self,
id: String,
) -> Result<Box<dyn DuplexByteStream + Send>, MuxerError>;
}
/// This trait is similar to [`MuxControl`] except it provides a stream
/// with a codec attached which handles serialization.
#[async_trait]
pub trait MuxChannelControl<T> {
/// Opens a new channel with the remote using the provided id
///
/// Attaches a codec to the underlying stream
async fn get_channel(
&mut self,
id: String,
) -> Result<Box<dyn Channel<T, Error = std::io::Error>>, MuxerError>;
}

12
utils/utils/Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
name = "tlsn-utils"
version = "0.1.0"
edition = "2021"
[lib]
name = "utils"
[dependencies]
[dev-dependencies]
rand = { workspace = true }