Mirror of https://github.com/tlsnotary/tlsn-utils.git, synced 2026-01-08 20:28:06 -05:00
feat: uid-mux (#28)
* feat: uid-mux
* fix deadlock
* wake up after signalling close
* appease rustfmt
* doc: Create -> Creates
* simplify Waker usage
* test more cases, handle errors better
* remove tracing-subscriber dep
* Apply suggestions from code review

  Co-authored-by: dan <themighty1@users.noreply.github.com>

* feat(uid-mux): framed mux (#29)
  * feat(uid-mux): framed mux
  * add test utils
  * Apply suggestions from code review

  Co-authored-by: dan <themighty1@users.noreply.github.com>

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>
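In practice the new API looks like the crate's own tests below (test_yamux in uid-mux/src/yamux.rs and test_framed_mux in uid-mux/src/serio.rs). A condensed sketch of that flow, not part of the commit itself, assuming a tokio runtime with the macros feature enabled; the id b"ping" and the in-memory duplex transport are illustrative:

use std::future::IntoFuture;

use futures::{AsyncReadExt, AsyncWriteExt};
use tokio::io::duplex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use uid_mux::{
    yamux::{Config, Mode, Yamux},
    UidMux,
};

#[tokio::main]
async fn main() {
    // Any AsyncRead + AsyncWrite transport works; an in-memory duplex stands in here.
    let (client_io, server_io) = duplex(1024);
    let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
    let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

    // Control handles open streams; the connections themselves are driven as futures.
    let client_ctrl = client.control();
    let server_ctrl = server.control();
    let conn = tokio::spawn(async move {
        futures::try_join!(client.into_future(), server.into_future()).unwrap();
    });

    // Both sides request a stream under the same caller-chosen id and get paired up.
    let (mut a, mut b) =
        futures::try_join!(client_ctrl.open(b"ping"), server_ctrl.open(b"ping")).unwrap();

    a.write_all(b"hello").await.unwrap();
    let mut buf = [0u8; 5];
    b.read_exact(&mut buf).await.unwrap();
    assert_eq!(&buf, b"hello");

    client_ctrl.close();
    server_ctrl.close();
    conn.await.unwrap();
}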

Cargo.toml (modified)
@@ -1,11 +1,12 @@
[workspace]
members = ["utils", "utils-aio", "spansy", "serio"]
members = ["utils", "utils-aio", "spansy", "serio", "uid-mux"]

[workspace.dependencies]
tlsn-utils = { path = "utils" }
tlsn-utils-aio = { path = "utils-aio" }
spansy = { path = "spansy" }
serio = { path = "serio" }
uid-mux = { path = "uid-mux" }

rand = "0.8"
thiserror = "1"
@@ -14,6 +15,7 @@ prost = "0.9"
futures = "0.3"
futures-sink = "0.3"
futures-core = "0.3"
futures-io = "0.3"
futures-channel = "0.3"
futures-util = "0.3"
tokio-util = "0.7"
@@ -28,3 +30,5 @@ serde = "1"
cfg-if = "1"
bincode = "1.3"
pin-project-lite = "0.2"
tracing = "0.1"
tracing-subscriber = "0.3"

serio/Cargo.toml (modified)
@@ -7,7 +7,7 @@ edition = "2021"
default = ["compat", "channel", "codec", "bincode"]
compat = ["dep:futures-sink"]
channel = ["dep:futures-channel"]
codec = ["dep:tokio-util"]
codec = ["dep:tokio-util", "dep:futures-io"]
bincode = ["dep:bincode"]

[dependencies]
@@ -15,10 +15,14 @@ bytes.workspace = true
serde.workspace = true
pin-project-lite.workspace = true
futures-core.workspace = true
futures-io = { workspace = true, optional = true }
futures-channel = { workspace = true, optional = true }
futures-sink = { workspace = true, optional = true }
futures-util = { workspace = true, features = ["bilock", "unstable"] }
tokio-util = { workspace = true, features = ["codec"], optional = true }
tokio-util = { workspace = true, features = [
    "codec",
    "compat",
], optional = true }
bincode = { workspace = true, optional = true }

[dev-dependencies]

serio/src/codec.rs (modified)
@@ -8,8 +8,18 @@ use std::{

use bytes::{Bytes, BytesMut};
use futures_core::stream::TryStream;
use futures_io::{AsyncRead, AsyncWrite};

use crate::{Deserialize, Serialize, Sink, Stream};
use crate::{Deserialize, IoDuplex, Serialize, Sink, Stream};

/// A codec.
pub trait Codec<Io> {
    /// The framed transport type.
    type Framed: IoDuplex;

    /// Creates a new framed transport with the given IO.
    fn new_framed(&self, io: Io) -> Self::Framed;
}

/// A serializer.
pub trait Serializer {
@@ -53,6 +63,25 @@ mod bincode_impl {
            Ok(deserialize(buf)?)
        }
    }

    use tokio_util::{
        codec::{Framed as TokioFramed, LengthDelimitedCodec},
        compat::{Compat, FuturesAsyncReadCompatExt as _},
    };

    impl<Io> Codec<Io> for Bincode
    where
        Io: AsyncRead + AsyncWrite + Unpin,
    {
        type Framed = Framed<TokioFramed<Compat<Io>, LengthDelimitedCodec>, Self>;

        fn new_framed(&self, io: Io) -> Self::Framed {
            Framed::new(
                LengthDelimitedCodec::builder().new_framed(io.compat()),
                self.clone(),
            )
        }
    }
}

#[cfg(feature = "bincode")]
@@ -133,7 +162,7 @@ where
mod tests {
    use serde::{Deserialize, Serialize};
    use tokio::io::duplex;
    use tokio_util::codec::LengthDelimitedCodec;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    use crate::{SinkExt, StreamExt};

@@ -149,11 +178,8 @@ mod tests {
    fn test_framed() {
        let (a, b) = duplex(1024);

        let a = LengthDelimitedCodec::builder().new_framed(a);
        let b = LengthDelimitedCodec::builder().new_framed(b);

        let mut a = Framed::new(a, Bincode::default());
        let mut b = Framed::new(b, Bincode::default());
        let mut a = Bincode::default().new_framed(a.compat());
        let mut b = Bincode::default().new_framed(b.compat());

        let a = async {
            a.send(Ping).await.unwrap();

uid-mux/Cargo.toml (new file, 34 lines)
@@ -0,0 +1,34 @@
[package]
name = "uid-mux"
version = "0.1.0"
authors = ["TLSNotary Team"]
description = "Async multiplexing library."
keywords = ["multiplex", "channel", "futures", "async"]
categories = ["network-programming", "asynchronous"]
license = "MIT OR Apache-2.0"
edition = "2021"

[features]
default = ["tracing", "serio"]
tracing = ["dep:tracing"]
serio = ["dep:serio"]
test-utils = ["tokio/io-util", "dep:tokio-util"]

[dependencies]
async-trait = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
futures = { workspace = true }
tracing = { workspace = true, optional = true }
yamux = "0.13"
blake3 = "1.5"
hex = "0.4"
serio = { workspace = true, optional = true }
tokio-util = { version = "0.7", features = ["compat"], optional = true }

[dev-dependencies]
tokio-util = { version = "0.7", features = ["compat"] }
tokio = { workspace = true, features = [
    "io-util",
    "rt-multi-thread",
    "macros",
] }

uid-mux/src/future.rs (new file, 202 lines)
@@ -0,0 +1,202 @@
use std::{
    pin::{pin, Pin},
    task::{Context, Poll},
};

use futures::{ready, AsyncRead, AsyncWrite, Future};
use tokio::sync::oneshot;

use crate::{
    log::{error, trace},
    InternalId,
};

const BUF: usize = 32;

#[derive(Debug)]
struct Inner<Io> {
    io: Io,
    count: u8,
    id: [u8; BUF],
}

impl<Io> Inner<Io> {
    fn is_done(&self) -> bool {
        self.count == 32
    }
}

#[derive(Debug)]
enum State<Io> {
    Pending(Inner<Io>),
    Error,
}

impl<Io> State<Io> {
    fn take(&mut self) -> Self {
        std::mem::replace(self, Self::Error)
    }
}

/// A future that resolves when an id has been read.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct ReadId<Io>(State<Io>);

impl<Io> ReadId<Io> {
    /// Creates a new `ReadId` future.
    pub(crate) fn new(io: Io) -> Self {
        Self(State::Pending(Inner {
            io,
            count: 0,
            id: [0u8; BUF],
        }))
    }
}

impl<Io> Future for ReadId<Io>
where
    Io: AsyncRead + Unpin,
{
    type Output = Result<(InternalId, Io), std::io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let State::Pending(mut state) = self.0.take() else {
            panic!("poll after completion");
        };

        while let Poll::Ready(read) =
            pin!(&mut state.io).poll_read(cx, &mut state.id[state.count as usize..])?
        {
            state.count += read as u8;
            if state.is_done() {
                let id = InternalId(state.id);
                trace!("read id: {}", id);
                return Poll::Ready(Ok((id, state.io)));
            } else if read == 0 {
                error!("remote closed before sending id");
                return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
            }
        }

        self.0 = State::Pending(state);
        Poll::Pending
    }
}

/// A future that resolves when an id has been written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct WriteId<Io>(State<Io>);

impl<Io> WriteId<Io> {
    /// Creates a new `WriteId` future.
    pub(crate) fn new(io: Io, id: InternalId) -> Self {
        Self(State::Pending(Inner {
            io,
            count: 0,
            id: id.0,
        }))
    }
}

impl<Io> Future for WriteId<Io>
where
    Io: AsyncWrite + Unpin,
{
    type Output = Result<Io, std::io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let State::Pending(mut state) = self.0.take() else {
            panic!("poll after completion");
        };

        // If we haven't finished sending the id, keep sending it.
        if !state.is_done() {
            while let Poll::Ready(sent) =
                pin!(&mut state.io).poll_write(cx, &state.id[state.count as usize..])?
            {
                state.count += sent as u8;
                if state.is_done() {
                    break;
                }
            }
        }

        // If we've finished sending, flush the write buffer. If flushing
        // succeeds then we can return Ready, otherwise we need to keep
        // trying.
        if state.is_done() {
            if pin!(&mut state.io).poll_flush(cx)?.is_ready() {
                return Poll::Ready(Ok(state.io));
            }
        }

        self.0 = State::Pending(state);

        Poll::Pending
    }
}

/// A future that resolves when a stream has been returned to the caller.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct ReturnStream<Io> {
    fut: WriteId<Io>,
    sender: Option<oneshot::Sender<Io>>,
}

impl<Io> ReturnStream<Io> {
    /// Creates a new `ReturnStream` future.
    pub(crate) fn new(id: InternalId, io: Io, sender: oneshot::Sender<Io>) -> Self {
        Self {
            fut: WriteId::new(io, id),
            sender: Some(sender),
        }
    }
}

impl<Io> Future for ReturnStream<Io>
where
    Io: AsyncWrite + Unpin,
{
    type Output = Result<(), std::io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let io = ready!(pin!(&mut self.fut).poll(cx))?;

        _ = self
            .sender
            .take()
            .expect("future not polled after completion")
            .send(io);

        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt as _;

    #[test]
    fn test_id_future() {
        let id_0 = InternalId([42u8; 32]);

        // send 1 byte at a time
        let (io_0, io_1) = duplex(1);

        futures::executor::block_on(async {
            let (_, (id_1, _)) = futures::try_join!(
                WriteId::new(io_0.compat(), id_0),
                ReadId::new(io_1.compat())
            )
            .unwrap();

            assert_eq!(id_0, id_1);
        });
    }
}
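The two futures above implement the small id handshake that prefixes every yamux substream: the opener writes its 32-byte id, the acceptor reads it back. For orientation, the same exchange written with async/await (an illustrative sketch only; the crate's poll-based futures additionally cope with partial reads and writes and with waker re-registration):

use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

async fn write_id<Io: AsyncWrite + Unpin>(mut io: Io, id: [u8; 32]) -> std::io::Result<Io> {
    // Send the full 32-byte id, then flush so the peer can read it immediately.
    io.write_all(&id).await?;
    io.flush().await?;
    Ok(io)
}

async fn read_id<Io: AsyncRead + Unpin>(mut io: Io) -> std::io::Result<([u8; 32], Io)> {
    // Read exactly 32 bytes; a short read means the remote closed early.
    let mut id = [0u8; 32];
    io.read_exact(&mut id).await?;
    Ok((id, io))
}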

uid-mux/src/lib.rs (new file, 109 lines)
@@ -0,0 +1,109 @@
//! Multiplexing with unique channel ids.

#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]

pub(crate) mod future;
#[cfg(feature = "serio")]
mod serio;
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils;
pub mod yamux;

#[cfg(feature = "serio")]
pub use serio::{FramedMux, FramedUidMux};

use core::fmt;

use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite};

/// Internal stream identifier.
///
/// User provided ids are hashed to a fixed length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct InternalId([u8; 32]);

impl InternalId {
    /// Creates a new `InternalId` from a byte slice.
    pub(crate) fn new(bytes: &[u8]) -> Self {
        Self(blake3::hash(bytes).into())
    }
}

impl fmt::Display for InternalId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for byte in &self.0[..4] {
            write!(f, "{:02x}", byte)?;
        }
        Ok(())
    }
}

impl AsRef<[u8]> for InternalId {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

/// A multiplexer that opens streams with unique ids.
#[async_trait]
pub trait UidMux<Id> {
    /// Stream type.
    type Stream: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static;
    /// Error type.
    type Error;

    /// Open a new stream with the given id.
    async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error>;
}

pub(crate) mod log {
    macro_rules! error {
        ($( $tokens:tt )*) => {
            {
                #[cfg(feature = "tracing")]
                tracing::error!($( $tokens )*);
            }
        };
    }

    macro_rules! warn_ {
        ($( $tokens:tt )*) => {
            {
                #[cfg(feature = "tracing")]
                tracing::warn!($( $tokens )*);
            }
        };
    }

    macro_rules! trace {
        ($( $tokens:tt )*) => {
            {
                #[cfg(feature = "tracing")]
                tracing::trace!($( $tokens )*);
            }
        };
    }

    macro_rules! debug {
        ($( $tokens:tt )*) => {
            {
                #[cfg(feature = "tracing")]
                tracing::debug!($( $tokens )*);
            }
        };
    }

    macro_rules! info {
        ($( $tokens:tt )*) => {
            {
                #[cfg(feature = "tracing")]
                tracing::info!($( $tokens )*);
            }
        };
    }

    pub(crate) use {debug, error, info, trace, warn_ as warn};
}
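Since both peers hash the same caller-supplied bytes, they derive identical internal ids independently, which is what lets the connection pair up their streams. A standalone sketch of that derivation using the blake3 and hex crates from the manifest (the id bytes are an arbitrary example, not anything the crate prescribes):

fn internal_id(user_id: &[u8]) -> [u8; 32] {
    // Same derivation as InternalId::new: hash the caller-chosen bytes to a fixed length.
    blake3::hash(user_id).into()
}

fn main() {
    let a = internal_id(b"key-exchange");
    let b = internal_id(b"key-exchange");
    // Equal inputs on both sides yield equal ids.
    assert_eq!(a, b);
    // Display/logging abbreviates an id to its first 4 bytes, hex encoded.
    println!("{}", hex::encode(&a[..4]));
}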

uid-mux/src/serio.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
use ::serio::{codec::Codec, IoDuplex};
use async_trait::async_trait;

use crate::UidMux;

/// A multiplexer that opens framed streams with unique ids.
#[async_trait]
pub trait FramedUidMux<Id> {
    /// Stream type.
    type Framed: IoDuplex;
    /// Error type.
    type Error;

    /// Opens a new framed stream with the given id.
    async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error>;
}

/// A framed multiplexer.
#[derive(Debug)]
pub struct FramedMux<M, C> {
    mux: M,
    codec: C,
}

impl<M, C> FramedMux<M, C> {
    /// Creates a new `FramedMux`.
    pub fn new(mux: M, codec: C) -> Self {
        Self { mux, codec }
    }

    /// Returns a reference to the mux.
    pub fn mux(&self) -> &M {
        &self.mux
    }

    /// Returns a mutable reference to the mux.
    pub fn mux_mut(&mut self) -> &mut M {
        &mut self.mux
    }

    /// Returns a reference to the codec.
    pub fn codec(&self) -> &C {
        &self.codec
    }

    /// Returns a mutable reference to the codec.
    pub fn codec_mut(&mut self) -> &mut C {
        &mut self.codec
    }

    /// Splits the `FramedMux` into its parts.
    pub fn into_parts(self) -> (M, C) {
        (self.mux, self.codec)
    }
}

#[async_trait]
impl<Id, M, C> FramedUidMux<Id> for FramedMux<M, C>
where
    Id: Sync,
    M: UidMux<Id> + Sync,
    C: Codec<<M as UidMux<Id>>::Stream> + Sync,
{
    /// Stream type.
    type Framed = <C as Codec<<M as UidMux<Id>>::Stream>>::Framed;
    /// Error type.
    type Error = <M as UidMux<Id>>::Error;

    /// Opens a new framed stream with the given id.
    async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error> {
        let stream = self.mux.open(id).await?;
        Ok(self.codec.new_framed(stream))
    }
}

impl<M: Clone, C: Clone> Clone for FramedMux<M, C> {
    fn clone(&self) -> Self {
        Self {
            mux: self.mux.clone(),
            codec: self.codec.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::future::IntoFuture;

    use super::*;
    use crate::yamux::{Config, Mode, Yamux};

    use ::serio::codec::Bincode;
    use serio::{stream::IoStreamExt, SinkExt};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    #[tokio::test]
    async fn test_framed_mux() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = FramedMux::new(client.control(), Bincode);
        let server_ctrl = FramedMux::new(server.control(), Bincode);

        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        futures::join!(
            async {
                let mut stream = client_ctrl.open_framed(b"test").await.unwrap();

                stream.send(42u128).await.unwrap();

                client_ctrl.mux().close();
            },
            async {
                let mut stream = server_ctrl.open_framed(b"test").await.unwrap();

                let num: u128 = stream.expect_next().await.unwrap();

                server_ctrl.mux().close();

                assert_eq!(num, 42u128);
            }
        );

        conn_task.await.unwrap();
    }
}

uid-mux/src/test_utils.rs (new file, 47 lines)
@@ -0,0 +1,47 @@
//! Test utilities.

use tokio::io::{duplex, DuplexStream};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::{Config, Mode};

use crate::{
    yamux::{Yamux, YamuxCtrl},
    FramedMux,
};

/// Creates a test pair of yamux instances.
///
/// # Arguments
///
/// * `buffer` - The buffer size.
pub fn test_yamux_pair(
    buffer: usize,
) -> (Yamux<Compat<DuplexStream>>, Yamux<Compat<DuplexStream>>) {
    let (a, b) = duplex(buffer);

    let a = Yamux::new(a.compat(), Config::default(), Mode::Client);
    let b = Yamux::new(b.compat(), Config::default(), Mode::Server);

    (a, b)
}

/// Creates a test pair of framed yamux instances.
///
/// # Arguments
///
/// * `buffer` - The buffer size.
/// * `codec` - The codec.
pub fn test_yamux_pair_framed<C: Clone>(
    buffer: usize,
    codec: C,
) -> (
    (FramedMux<YamuxCtrl, C>, Yamux<Compat<DuplexStream>>),
    (FramedMux<YamuxCtrl, C>, Yamux<Compat<DuplexStream>>),
) {
    let (a, b) = test_yamux_pair(buffer);

    let ctrl_a = FramedMux::new(a.control(), codec.clone());
    let ctrl_b = FramedMux::new(b.control(), codec);

    ((ctrl_a, a), (ctrl_b, b))
}
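A sketch of how a downstream crate might lean on these helpers in its own tests, assuming the test-utils and serio features are enabled; the test name ping_pong and the channel id b"ping" are illustrative, not part of this commit:

use std::future::IntoFuture;

use serio::{codec::Bincode, stream::IoStreamExt, SinkExt};
use uid_mux::{test_utils::test_yamux_pair_framed, FramedUidMux};

#[tokio::test]
async fn ping_pong() {
    // Two framed yamux endpoints joined by an in-memory duplex.
    let ((ctrl_a, mux_a), (ctrl_b, mux_b)) = test_yamux_pair_framed(1024, Bincode);

    // Drive both connections in the background.
    let conn = tokio::spawn(async move {
        futures::try_join!(mux_a.into_future(), mux_b.into_future()).unwrap();
    });

    // Open the same logical channel from both sides and exchange a value.
    let (mut a, mut b) =
        futures::try_join!(ctrl_a.open_framed(b"ping"), ctrl_b.open_framed(b"ping")).unwrap();

    a.send(42u128).await.unwrap();
    let value: u128 = b.expect_next().await.unwrap();
    assert_eq!(value, 42);

    ctrl_a.mux().close();
    ctrl_b.mux().close();
    conn.await.unwrap();
}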

uid-mux/src/yamux.rs (new file, 543 lines)
@@ -0,0 +1,543 @@
//! Yamux multiplexer.
//!
//! This module provides a [`yamux`](https://crates.io/crates/yamux) wrapper which implements [`UidMux`](crate::UidMux).

use std::{
    collections::HashMap,
    fmt,
    future::IntoFuture,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex,
    },
    task::{Context, Poll, Waker},
};

use async_trait::async_trait;
use futures::{stream::FuturesUnordered, AsyncRead, AsyncWrite, Future, FutureExt, StreamExt};
use tokio::sync::{oneshot, Notify};
use yamux::Connection;

use crate::{
    future::{ReadId, ReturnStream},
    log::{debug, error, info, trace, warn},
    InternalId, UidMux,
};

pub use yamux::{Config, ConnectionError, Mode, Stream};

type Result<T, E = ConnectionError> = std::result::Result<T, E>;

#[derive(Debug, Clone, Copy)]
enum Role {
    Client,
    Server,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Role::Client => write!(f, "Client"),
            Role::Server => write!(f, "Server"),
        }
    }
}

/// A yamux multiplexer.
#[derive(Debug)]
pub struct Yamux<Io> {
    role: Role,
    conn: Connection<Io>,
    queue: Arc<Mutex<Queue>>,
    close_notify: Arc<Notify>,
    shutdown_notify: Arc<AtomicBool>,
}

#[derive(Debug)]
struct Queue {
    waiting: HashMap<InternalId, oneshot::Sender<Stream>>,
    ready: HashMap<InternalId, Stream>,
    waker: Option<Waker>,
}

impl Default for Queue {
    fn default() -> Self {
        Self {
            waiting: Default::default(),
            ready: Default::default(),
            waker: None,
        }
    }
}

impl<Io> Yamux<Io> {
    /// Returns a new control handle.
    pub fn control(&self) -> YamuxCtrl {
        YamuxCtrl {
            role: self.role,
            queue: self.queue.clone(),
            close_notify: self.close_notify.clone(),
            shutdown_notify: self.shutdown_notify.clone(),
        }
    }
}

impl<Io> Yamux<Io>
where
    Io: AsyncWrite + AsyncRead + Unpin,
{
    /// Creates a new yamux multiplexer.
    pub fn new(io: Io, config: Config, mode: Mode) -> Self {
        let role = match mode {
            Mode::Client => Role::Client,
            Mode::Server => Role::Server,
        };

        Self {
            role,
            conn: Connection::new(io, config, mode),
            queue: Default::default(),
            close_notify: Default::default(),
            shutdown_notify: Default::default(),
        }
    }
}

impl<Io> IntoFuture for Yamux<Io>
where
    Io: AsyncWrite + AsyncRead + Unpin,
{
    type Output = Result<()>;
    type IntoFuture = YamuxFuture<Io>;

    fn into_future(self) -> Self::IntoFuture {
        YamuxFuture {
            role: self.role,
            conn: self.conn,
            incoming: Default::default(),
            outgoing: Default::default(),
            queue: self.queue,
            closed: false,
            remote_closed: false,
            close_notify: self.close_notify,
            shutdown_notify: self.shutdown_notify,
        }
    }
}

/// A yamux connection future.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YamuxFuture<Io> {
    role: Role,
    conn: Connection<Io>,
    /// Pending incoming streams, waiting for ids to be received.
    incoming: FuturesUnordered<ReadId<Stream>>,
    /// Pending outgoing streams, waiting to send ids and return streams
    /// to callers.
    outgoing: FuturesUnordered<ReturnStream<Stream>>,
    queue: Arc<Mutex<Queue>>,
    /// Whether this side has closed the connection.
    closed: bool,
    /// Whether the remote has closed the connection.
    remote_closed: bool,
    close_notify: Arc<Notify>,
    shutdown_notify: Arc<AtomicBool>,
}

impl<Io> YamuxFuture<Io>
where
    Io: AsyncWrite + AsyncRead + Unpin,
{
    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
    fn client_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
        if let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
            if stream.is_some() {
                error!("client mux received incoming stream");
                return Err(
                    std::io::Error::other("client mode cannot accept incoming streams").into(),
                );
            }

            info!("remote closed connection");
            self.remote_closed = true;
        }

        Ok(())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
    fn client_handle_outbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
        // Putting this in a block so the lock is released as soon as possible.
        {
            let mut queue = self.queue.lock().unwrap();
            while !queue.waiting.is_empty() {
                if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
                    let id = *queue.waiting.keys().next().unwrap();
                    let sender = queue.waiting.remove(&id).unwrap();

                    debug!("opened new stream: {}", id);

                    self.outgoing.push(ReturnStream::new(id, stream, sender));
                } else {
                    break;
                }
            }

            // Set the waker so `YamuxCtrl` can wake up the connection.
            queue.waker = Some(cx.waker().clone());
        }

        while let Poll::Ready(Some(result)) = self.outgoing.poll_next_unpin(cx) {
            if let Err(err) = result {
                warn!("connection closed while opening stream: {}", err);
                self.remote_closed = true;
            } else {
                trace!("finished opening stream");
            }
        }

        Ok(())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
    fn server_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
        while let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
            let Some(stream) = stream else {
                if !self.remote_closed {
                    info!("remote closed connection");
                    self.remote_closed = true;
                }

                break;
            };

            debug!("received incoming stream");
            // The size of this is bounded by yamux max streams config.
            self.incoming.push(ReadId::new(stream));
        }

        Ok(())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
    fn server_process_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
        let mut queue = self.queue.lock().unwrap();
        while let Poll::Ready(Some(result)) = self.incoming.poll_next_unpin(cx) {
            match result {
                Ok((id, stream)) => {
                    debug!("received stream: {}", id);
                    if let Some(sender) = queue.waiting.remove(&id) {
                        _ = sender
                            .send(stream)
                            .inspect_err(|_| error!("caller dropped receiver"));
                        trace!("returned stream to caller: {}", id);
                    } else {
                        trace!("queuing stream: {}", id);
                        queue.ready.insert(id, stream);
                    }
                }
                Err(err) => {
                    warn!("connection closed while receiving stream: {}", err);
                    self.remote_closed = true;
                }
            }
        }

        // Set the waker so `YamuxCtrl` can wake up the connection.
        queue.waker = Some(cx.waker().clone());

        Ok(())
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
    fn handle_shutdown(&mut self, cx: &mut Context<'_>) -> Result<()> {
        // Attempt to close the connection if the shutdown notify has been set.
        if !self.closed && self.shutdown_notify.load(Ordering::Relaxed) {
            if let Poll::Ready(()) = self.conn.poll_close(cx)? {
                self.closed = true;
                info!("mux connection closed");
            }
        }

        Ok(())
    }

    fn is_complete(&self) -> bool {
        self.remote_closed || self.closed
    }

    fn poll_client(&mut self, cx: &mut Context<'_>) -> Result<()> {
        self.client_handle_inbound(cx)?;

        if !self.remote_closed {
            self.client_handle_outbound(cx)?;

            // We need to poll the inbound again to make sure the connection
            // flushes the write buffer.
            self.client_handle_inbound(cx)?;
        }

        self.handle_shutdown(cx)?;

        Ok(())
    }

    fn poll_server(&mut self, cx: &mut Context<'_>) -> Result<()> {
        self.server_handle_inbound(cx)?;
        self.server_process_inbound(cx)?;
        self.handle_shutdown(cx)?;

        Ok(())
    }
}

impl<Io> Future for YamuxFuture<Io>
where
    Io: AsyncWrite + AsyncRead + Unpin,
{
    type Output = Result<()>;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            fields(role = %self.role),
            skip_all
        )
    )]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.role {
            Role::Client => self.poll_client(cx)?,
            Role::Server => self.poll_server(cx)?,
        };

        if self.is_complete() {
            self.close_notify.notify_waiters();
            info!("connection complete");
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}

/// A yamux control handle.
#[derive(Debug, Clone)]
pub struct YamuxCtrl {
    role: Role,
    queue: Arc<Mutex<Queue>>,
    close_notify: Arc<Notify>,
    shutdown_notify: Arc<AtomicBool>,
}

impl YamuxCtrl {
    /// Closes the yamux connection.
    pub fn close(&self) {
        self.shutdown_notify.store(true, Ordering::Relaxed);

        // Wake up the connection.
        self.queue
            .lock()
            .unwrap()
            .waker
            .as_ref()
            .map(|waker| waker.wake_by_ref());
    }
}

#[async_trait]
impl<Id> UidMux<Id> for YamuxCtrl
where
    Id: fmt::Debug + AsRef<[u8]> + Sync,
{
    type Stream = Stream;
    type Error = std::io::Error;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            fields(role = %self.role, id = hex::encode(id)),
            skip_all,
            err
        )
    )]
    async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error> {
        let internal_id = InternalId::new(id.as_ref());

        debug!("opening stream: {}", internal_id);

        let receiver = {
            let mut queue = self.queue.lock().unwrap();
            if let Some(stream) = queue.ready.remove(&internal_id) {
                trace!("stream already opened");
                return Ok(stream);
            }

            let (sender, receiver) = oneshot::channel();

            // Insert the oneshot into the queue.
            queue.waiting.insert(internal_id, sender);
            // Wake up the connection.
            queue.waker.as_ref().map(|waker| waker.wake_by_ref());

            trace!("waiting for stream");

            receiver
        };

        futures::select! {
            stream = receiver.fuse() =>
                stream
                    .inspect(|_| debug!("caller received stream"))
                    .inspect_err(|_| error!("connection cancelled stream"))
                    .map_err(|_| {
                        std::io::Error::other(format!("connection cancelled stream"))
                    }),
            _ = self.close_notify.notified().fuse() => {
                error!("connection closed before stream opened");
                Err(std::io::ErrorKind::ConnectionAborted.into())
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    #[tokio::test]
    async fn test_yamux() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        futures::join!(
            async {
                let mut stream = client_ctrl.open(b"0").await.unwrap();
                let mut stream2 = client_ctrl.open(b"00").await.unwrap();

                stream.write_all(b"ping").await.unwrap();
                stream2.write_all(b"ping2").await.unwrap();
            },
            async {
                let mut stream = server_ctrl.open(b"0").await.unwrap();
                let mut stream2 = server_ctrl.open(b"00").await.unwrap();

                let mut buf = [0; 4];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping");

                let mut buf = [0; 5];
                stream2.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping2");
            }
        );

        client_ctrl.close();
        server_ctrl.close();

        conn_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_yamux_client_close() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();

        let mut fut = futures::future::try_join(client.into_future(), server.into_future());

        _ = futures::poll!(&mut fut);

        client_ctrl.close();

        // Both connections close cleanly.
        fut.await.unwrap();
    }

    // Test the case where the client closes the connection while the server is expecting a new stream.
    #[tokio::test]
    async fn test_yamux_client_close_early() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let mut fut_conn = futures::future::try_join(client.into_future(), server.into_future());
        _ = futures::poll!(&mut fut_conn);

        let mut fut_open = server_ctrl.open(b"0");
        _ = futures::poll!(&mut fut_open);

        client_ctrl.close();

        // Both connections close cleanly.
        fut_conn.await.unwrap();
        // But caller gets an error.
        assert!(fut_open.await.is_err());
    }

    #[tokio::test]
    async fn test_yamux_server_close() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let server_ctrl = server.control();

        let mut fut = futures::future::try_join(client.into_future(), server.into_future());

        _ = futures::poll!(&mut fut);

        server_ctrl.close();

        // Both connections close cleanly.
        fut.await.unwrap();
    }

    // Test the case where the server closes the connection while the client is opening a new stream.
    #[tokio::test]
    async fn test_yamux_server_close_early() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let mut fut_client = client.into_future();
        let mut fut_server = server.into_future();

        let mut fut_conn = futures::future::try_join(&mut fut_client, &mut fut_server);
        _ = futures::poll!(&mut fut_conn);
        drop(fut_conn);

        let mut fut_open = client_ctrl.open(b"0");
        _ = futures::poll!(&mut fut_open);

        // We need to prevent the client from beating us to the punch here.
        fut_client.queue.lock().unwrap().waiting.clear();

        server_ctrl.close();

        // Both connections close cleanly.
        futures::try_join!(fut_client, fut_server).unwrap();
        // But caller gets an error.
        assert!(fut_open.await.is_err());
    }
}