mirror of
https://github.com/tlsnotary/tlsn-utils.git
synced 2026-01-09 12:48:03 -05:00
Partition Workspaces
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
[package]
|
||||
name = "tlsn-utils"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
[workspace]
|
||||
members = ["utils", "utils-aio"]
|
||||
|
||||
[lib]
|
||||
name = "utils"
|
||||
|
||||
[dependencies]
|
||||
|
||||
[dev-dependencies]
|
||||
rand = { workspace = true }
|
||||
[workspace.dependencies]
|
||||
rand = "0.8"
|
||||
thiserror = "1"
|
||||
async-trait = "0.1"
|
||||
prost = "0.9"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
tokio-util = "0.7"
|
||||
tokio = "1.23"
|
||||
async-tungstenite = "0.16"
|
||||
prost-build = "0.9"
|
||||
bytes = "1"
|
||||
async-std = "1"
|
||||
|
||||
34
utils/utils-aio/Cargo.toml
Normal file
34
utils/utils-aio/Cargo.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[package]
|
||||
name = "tlsn-utils-aio"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "utils_aio"
|
||||
|
||||
[features]
|
||||
default = ["mux", "duplex"]
|
||||
codec = []
|
||||
mux = []
|
||||
duplex = []
|
||||
|
||||
[dependencies]
|
||||
bytes.workspace = true
|
||||
prost.workspace = true
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
tokio-util = { workspace = true, features = ["codec", "compat"] }
|
||||
async-tungstenite.workspace = true
|
||||
futures.workspace = true
|
||||
futures-util.workspace = true
|
||||
async-trait.workspace = true
|
||||
thiserror.workspace = true
|
||||
async-std.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = [
|
||||
"macros",
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"io-util",
|
||||
"time",
|
||||
] }
|
||||
150
utils/utils-aio/src/adaptive_barrier.rs
Normal file
150
utils/utils-aio/src/adaptive_barrier.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use tokio::sync::broadcast::{channel, error::RecvError, Sender};
|
||||
|
||||
/// An adaptive barrier
|
||||
///
|
||||
/// This allows to change the number of barriers dynamically.
|
||||
/// Code is taken from https://users.rust-lang.org/t/a-poor-man-async-adaptive-barrier/68118
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdaptiveBarrier {
|
||||
inner: Sender<Empty>,
|
||||
}
|
||||
|
||||
impl AdaptiveBarrier {
|
||||
/// Wait in order to perform task synchronization
|
||||
///
|
||||
/// Waits for all other barriers who have been cloned from this one
|
||||
/// to also call `wait`
|
||||
pub async fn wait(self) {
|
||||
let mut receiver = self.inner.subscribe();
|
||||
drop(self.inner);
|
||||
match receiver.recv().await {
|
||||
Ok(_) => unreachable!(),
|
||||
Err(RecvError::Lagged(_)) => unreachable!(),
|
||||
Err(RecvError::Closed) => (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
// even though we will not receive any data, we still
|
||||
// need to set the channel's capacity to the required minimum 1
|
||||
inner: channel(1).0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AdaptiveBarrier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum Empty {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::AdaptiveBarrier;
|
||||
use std::{mem::replace, sync::Arc, time::Duration};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Waiter {
|
||||
counter: Arc<Mutex<Vec<usize>>>,
|
||||
barrier: AdaptiveBarrier,
|
||||
}
|
||||
|
||||
impl Waiter {
|
||||
fn new(counter: &Arc<Mutex<Vec<usize>>>) -> Self {
|
||||
Self {
|
||||
counter: Arc::clone(counter),
|
||||
barrier: AdaptiveBarrier::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// add a new value to the counter. if no values are present in the
|
||||
// counter, the new value will be 1.
|
||||
async fn count(self) {
|
||||
let mut counter = self.counter.lock().await;
|
||||
let last = counter.last().copied().unwrap_or_default();
|
||||
counter.push(last + 1);
|
||||
}
|
||||
|
||||
async fn count_wait(mut self) {
|
||||
// use `replace()` because we can't call `self.barrier.wait().await;` here
|
||||
let barrier = replace(&mut self.barrier, AdaptiveBarrier::new());
|
||||
barrier.wait().await;
|
||||
self.count().await;
|
||||
}
|
||||
}
|
||||
|
||||
// We expect that 0 is not the first number in the counter because we do not use
|
||||
// the barrier in this test
|
||||
#[tokio::test]
|
||||
async fn test_adaptive_barrier_no_wait() {
|
||||
let counter = Arc::new(Mutex::new(vec![]));
|
||||
|
||||
let waiter = Waiter::new(&counter);
|
||||
let waiter_2 = waiter.clone();
|
||||
let waiter_3 = waiter.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
waiter.count().await;
|
||||
});
|
||||
let task_2 = tokio::spawn(async move {
|
||||
waiter_2.count().await;
|
||||
});
|
||||
|
||||
// the reason why we are not using here:
|
||||
// _ = tokio::join!(task, task_2);
|
||||
// is to make this test comparable to `test_adaptive_barrier_wait`
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
{
|
||||
// Add 0 to counter. But this will not be the first number
|
||||
// since `task` and `task_2` were already able to add to
|
||||
// the counter.
|
||||
counter.lock().await.push(0);
|
||||
}
|
||||
|
||||
// both tasks must be finished now
|
||||
assert!(task.is_finished() && task_2.is_finished());
|
||||
|
||||
let task_3 = tokio::spawn(async move {
|
||||
waiter_3.count().await;
|
||||
});
|
||||
_ = tokio::join!(task, task_2, task_3);
|
||||
assert_ne!(*counter.lock().await.first().unwrap(), 0);
|
||||
}
|
||||
|
||||
// Now we use `count_wait` instead of `count` so 0 should be the first number
|
||||
#[tokio::test]
|
||||
async fn test_adaptive_barrier_wait() {
|
||||
let counter = Arc::new(Mutex::new(vec![]));
|
||||
|
||||
let waiter = Waiter::new(&counter);
|
||||
let waiter_2 = waiter.clone();
|
||||
let waiter_3 = waiter.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
waiter.count_wait().await;
|
||||
});
|
||||
let task_2 = tokio::spawn(async move {
|
||||
waiter_2.count_wait().await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
{
|
||||
counter.lock().await.push(0);
|
||||
}
|
||||
|
||||
// both tasks must NOT be finished yet
|
||||
assert!(!task.is_finished() && !task_2.is_finished());
|
||||
|
||||
// Now we wait for the last barrier, so all tasks can start counting
|
||||
let task_3 = tokio::spawn(async move {
|
||||
waiter_3.count_wait().await;
|
||||
});
|
||||
_ = tokio::join!(task, task_2, task_3);
|
||||
assert_eq!(*counter.lock().await.first().unwrap(), 0);
|
||||
}
|
||||
}
|
||||
80
utils/utils-aio/src/codec.rs
Normal file
80
utils/utils-aio/src/codec.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use bytes::BytesMut;
|
||||
use prost::Message;
|
||||
use std::marker::PhantomData;
|
||||
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProstCodec<T, U>(PhantomData<T>, PhantomData<U>);
|
||||
|
||||
impl<T, U: Message> Default for ProstCodec<T, U> {
|
||||
fn default() -> Self {
|
||||
ProstCodec(PhantomData, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U: Message + From<T>> Encoder<T> for ProstCodec<T, U> {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
U::from(item)
|
||||
.encode(buf)
|
||||
.expect("Message only errors if not enough space");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TryFrom<U>, U: Message + Default> Decoder for ProstCodec<T, U>
|
||||
where
|
||||
std::io::Error: From<<T as TryFrom<U>>::Error>,
|
||||
{
|
||||
type Item = T;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
let item: U = Message::decode(buf)?;
|
||||
|
||||
Ok(Some(T::try_from(item)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProstCodecDelimited<T, U> {
|
||||
_t: (PhantomData<T>, PhantomData<U>),
|
||||
inner: LengthDelimitedCodec,
|
||||
}
|
||||
|
||||
impl<T, U: Message> Default for ProstCodecDelimited<T, U> {
|
||||
fn default() -> Self {
|
||||
ProstCodecDelimited {
|
||||
_t: (PhantomData, PhantomData),
|
||||
inner: LengthDelimitedCodec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U: Message + From<T>> Encoder<T> for ProstCodecDelimited<T, U> {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
self.inner.encode(U::from(item).encode_to_vec().into(), buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TryFrom<U>, U: Message + Default> Decoder for ProstCodecDelimited<T, U>
|
||||
where
|
||||
std::io::Error: From<<T as TryFrom<U>>::Error>,
|
||||
{
|
||||
type Item = T;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
let b = self.inner.decode(buf)?;
|
||||
if let Some(b) = b {
|
||||
let item: U = Message::decode(b)?;
|
||||
Ok(Some(T::try_from(item)?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
86
utils/utils-aio/src/duplex.rs
Normal file
86
utils/utils-aio/src/duplex.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use futures::{channel::mpsc, AsyncRead, AsyncWrite, Sink, Stream};
|
||||
use std::{
|
||||
io::{Error, ErrorKind},
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
pub trait DuplexByteStream: AsyncWrite + AsyncRead + Unpin {}
|
||||
|
||||
impl<T> DuplexByteStream for T where T: AsyncWrite + AsyncRead + Unpin {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DuplexChannel<T> {
|
||||
sink: mpsc::Sender<T>,
|
||||
stream: mpsc::Receiver<T>,
|
||||
}
|
||||
|
||||
impl<T> DuplexChannel<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
pub fn new() -> (Self, Self) {
|
||||
let (sender, receiver) = mpsc::channel(10);
|
||||
let (sender_2, receiver_2) = mpsc::channel(10);
|
||||
(
|
||||
Self {
|
||||
sink: sender,
|
||||
stream: receiver_2,
|
||||
},
|
||||
Self {
|
||||
sink: sender_2,
|
||||
stream: receiver,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sink<T> for DuplexChannel<T>
|
||||
where
|
||||
T: Send + 'static,
|
||||
{
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll_ready(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_ready(cx)
|
||||
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
|
||||
}
|
||||
|
||||
fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.sink)
|
||||
.start_send(item)
|
||||
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_flush(cx)
|
||||
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.sink)
|
||||
.poll_close(cx)
|
||||
.map_err(|_| Error::new(ErrorKind::ConnectionAborted, "channel died"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Stream for DuplexChannel<T> {
|
||||
type Item = T;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.stream).poll_next(cx)
|
||||
}
|
||||
}
|
||||
22
utils/utils-aio/src/expect_msg.rs
Normal file
22
utils/utils-aio/src/expect_msg.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
/// Extract expected variant of an enum and handle errors
|
||||
///
|
||||
/// This macro is intended to simplify extracting the expected message
|
||||
/// when doing communication.
|
||||
/// - The first argument is the expression, which is matched
|
||||
/// - the second argument is the expected enum variant
|
||||
/// - the last argument is error which is retuned when the expected message is not present
|
||||
///
|
||||
/// The error needs to implement From for std::io::Error
|
||||
#[macro_export]
|
||||
macro_rules! expect_msg_or_err {
|
||||
($match: expr, $expected: path, $err: path) => {
|
||||
match $match {
|
||||
Some($expected(msg)) => Ok(msg),
|
||||
Some(other) => Err($err(other)),
|
||||
None => Err(From::from(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
"stream closed unexpectedly",
|
||||
))),
|
||||
}
|
||||
};
|
||||
}
|
||||
14
utils/utils-aio/src/factory.rs
Normal file
14
utils/utils-aio/src/factory.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// This trait is for factories which produce their items asynchronously
|
||||
#[async_trait]
|
||||
pub trait AsyncFactory<T> {
|
||||
type Config;
|
||||
type Error;
|
||||
|
||||
/// Creates new instance
|
||||
///
|
||||
/// * `id` - Unique ID of instance
|
||||
/// * `config` - Instance configuration
|
||||
async fn create(&mut self, id: String, config: Self::Config) -> Result<T, Self::Error>;
|
||||
}
|
||||
13
utils/utils-aio/src/lib.rs
Normal file
13
utils/utils-aio/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
pub mod adaptive_barrier;
|
||||
#[cfg(feature = "codec")]
|
||||
pub mod codec;
|
||||
#[cfg(feature = "duplex")]
|
||||
pub mod duplex;
|
||||
pub mod expect_msg;
|
||||
pub mod factory;
|
||||
#[cfg(feature = "mux")]
|
||||
pub mod mux;
|
||||
|
||||
pub trait Channel<T>: futures::Stream<Item = T> + futures::Sink<T> + Send + Unpin {}
|
||||
|
||||
impl<T, U> Channel<T> for U where U: futures::Stream<Item = T> + futures::Sink<T> + Send + Unpin {}
|
||||
37
utils/utils-aio/src/mux.rs
Normal file
37
utils/utils-aio/src/mux.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use super::Channel;
|
||||
use crate::duplex::DuplexByteStream;
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MuxerError {
|
||||
#[error("Connection error occurred: {0}")]
|
||||
ConnectionError(String),
|
||||
#[error("IO error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("Duplicate stream id: {0:?}")]
|
||||
DuplicateStreamId(String),
|
||||
#[error("Encountered internal error: {0:?}")]
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait MuxControl: Clone {
|
||||
/// Opens a new stream with the remote using the provided id
|
||||
async fn get_stream(
|
||||
&mut self,
|
||||
id: String,
|
||||
) -> Result<Box<dyn DuplexByteStream + Send>, MuxerError>;
|
||||
}
|
||||
|
||||
/// This trait is similar to [`MuxControl`] except it provides a stream
|
||||
/// with a codec attached which handles serialization.
|
||||
#[async_trait]
|
||||
pub trait MuxChannelControl<T> {
|
||||
/// Opens a new channel with the remote using the provided id
|
||||
///
|
||||
/// Attaches a codec to the underlying stream
|
||||
async fn get_channel(
|
||||
&mut self,
|
||||
id: String,
|
||||
) -> Result<Box<dyn Channel<T, Error = std::io::Error>>, MuxerError>;
|
||||
}
|
||||
12
utils/utils/Cargo.toml
Normal file
12
utils/utils/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "tlsn-utils"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "utils"
|
||||
|
||||
[dependencies]
|
||||
|
||||
[dev-dependencies]
|
||||
rand = { workspace = true }
|
||||
Reference in New Issue
Block a user