feat: uid-mux (#28)

* feat: uid-mux

* fix deadlock

* wake up after signalling close

* appease rustfmt

* doc: Create -> Creates

* simplify Waker usage

* test more cases, handle errors better

* remove tracing-subscriber dep

* Apply suggestions from code review

Co-authored-by: dan <themighty1@users.noreply.github.com>

* feat(uid-mux): framed mux (#29)

* feat(uid-mux): framed mux

* add test utils

* Apply suggestions from code review

Co-authored-by: dan <themighty1@users.noreply.github.com>

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>
This commit is contained in:
sinu.eth
2024-05-24 12:20:29 -08:00
committed by GitHub
parent 8d9f3c3a8d
commit 6cf35e0047
9 changed files with 1110 additions and 10 deletions

View File

@@ -1,11 +1,12 @@
[workspace]
members = ["utils", "utils-aio", "spansy", "serio"]
members = ["utils", "utils-aio", "spansy", "serio", "uid-mux"]
[workspace.dependencies]
tlsn-utils = { path = "utils" }
tlsn-utils-aio = { path = "utils-aio" }
spansy = { path = "spansy" }
serio = { path = "serio" }
uid-mux = { path = "uid-mux" }
rand = "0.8"
thiserror = "1"
@@ -14,6 +15,7 @@ prost = "0.9"
futures = "0.3"
futures-sink = "0.3"
futures-core = "0.3"
futures-io = "0.3"
futures-channel = "0.3"
futures-util = "0.3"
tokio-util = "0.7"
@@ -28,3 +30,5 @@ serde = "1"
cfg-if = "1"
bincode = "1.3"
pin-project-lite = "0.2"
tracing = "0.1"
tracing-subscriber = "0.3"

View File

@@ -7,7 +7,7 @@ edition = "2021"
default = ["compat", "channel", "codec", "bincode"]
compat = ["dep:futures-sink"]
channel = ["dep:futures-channel"]
codec = ["dep:tokio-util"]
codec = ["dep:tokio-util", "dep:futures-io"]
bincode = ["dep:bincode"]
[dependencies]
@@ -15,10 +15,14 @@ bytes.workspace = true
serde.workspace = true
pin-project-lite.workspace = true
futures-core.workspace = true
futures-io = { workspace = true, optional = true }
futures-channel = { workspace = true, optional = true }
futures-sink = { workspace = true, optional = true }
futures-util = { workspace = true, features = ["bilock", "unstable"] }
tokio-util = { workspace = true, features = ["codec"], optional = true }
tokio-util = { workspace = true, features = [
"codec",
"compat",
], optional = true }
bincode = { workspace = true, optional = true }
[dev-dependencies]

View File

@@ -8,8 +8,18 @@ use std::{
use bytes::{Bytes, BytesMut};
use futures_core::stream::TryStream;
use futures_io::{AsyncRead, AsyncWrite};
use crate::{Deserialize, Serialize, Sink, Stream};
use crate::{Deserialize, IoDuplex, Serialize, Sink, Stream};
/// A codec.
pub trait Codec<Io> {
/// The framed transport type.
type Framed: IoDuplex;
/// Creates a new framed transport with the given IO.
fn new_framed(&self, io: Io) -> Self::Framed;
}
/// A serializer.
pub trait Serializer {
@@ -53,6 +63,25 @@ mod bincode_impl {
Ok(deserialize(buf)?)
}
}
use tokio_util::{
codec::{Framed as TokioFramed, LengthDelimitedCodec},
compat::{Compat, FuturesAsyncReadCompatExt as _},
};
impl<Io> Codec<Io> for Bincode
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Framed = Framed<TokioFramed<Compat<Io>, LengthDelimitedCodec>, Self>;
fn new_framed(&self, io: Io) -> Self::Framed {
Framed::new(
LengthDelimitedCodec::builder().new_framed(io.compat()),
self.clone(),
)
}
}
}
#[cfg(feature = "bincode")]
@@ -133,7 +162,7 @@ where
mod tests {
use serde::{Deserialize, Serialize};
use tokio::io::duplex;
use tokio_util::codec::LengthDelimitedCodec;
use tokio_util::compat::TokioAsyncReadCompatExt;
use crate::{SinkExt, StreamExt};
@@ -149,11 +178,8 @@ mod tests {
fn test_framed() {
let (a, b) = duplex(1024);
let a = LengthDelimitedCodec::builder().new_framed(a);
let b = LengthDelimitedCodec::builder().new_framed(b);
let mut a = Framed::new(a, Bincode::default());
let mut b = Framed::new(b, Bincode::default());
let mut a = Bincode::default().new_framed(a.compat());
let mut b = Bincode::default().new_framed(b.compat());
let a = async {
a.send(Ping).await.unwrap();

34
uid-mux/Cargo.toml Normal file
View File

@@ -0,0 +1,34 @@
[package]
name = "uid-mux"
version = "0.1.0"
authors = ["TLSNotary Team"]
description = "Async multiplexing library."
keywords = ["multiplex", "channel", "futures", "async"]
categories = ["network-programming", "asynchronous"]
license = "MIT OR Apache-2.0"
edition = "2021"
[features]
default = ["tracing", "serio"]
tracing = ["dep:tracing"]
serio = ["dep:serio"]
test-utils = ["tokio/io-util", "dep:tokio-util"]
[dependencies]
async-trait = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
futures = { workspace = true }
tracing = { workspace = true, optional = true }
yamux = "0.13"
blake3 = "1.5"
hex = "0.4"
serio = { workspace = true, optional = true }
tokio-util = { version = "0.7", features = ["compat"], optional = true }
[dev-dependencies]
tokio-util = { version = "0.7", features = ["compat"] }
tokio = { workspace = true, features = [
"io-util",
"rt-multi-thread",
"macros",
] }

202
uid-mux/src/future.rs Normal file
View File

@@ -0,0 +1,202 @@
use std::{
pin::{pin, Pin},
task::{Context, Poll},
};
use futures::{ready, AsyncRead, AsyncWrite, Future};
use tokio::sync::oneshot;
use crate::{
log::{error, trace},
InternalId,
};
const BUF: usize = 32;
#[derive(Debug)]
struct Inner<Io> {
io: Io,
count: u8,
id: [u8; BUF],
}
impl<Io> Inner<Io> {
fn is_done(&self) -> bool {
self.count == 32
}
}
#[derive(Debug)]
enum State<Io> {
Pending(Inner<Io>),
Error,
}
impl<Io> State<Io> {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::Error)
}
}
/// A future that resolves when an id has been read.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct ReadId<Io>(State<Io>);
impl<Io> ReadId<Io> {
/// Creates a new `ReadId` future.
pub(crate) fn new(io: Io) -> Self {
Self(State::Pending(Inner {
io,
count: 0,
id: [0u8; BUF],
}))
}
}
impl<Io> Future for ReadId<Io>
where
Io: AsyncRead + Unpin,
{
type Output = Result<(InternalId, Io), std::io::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let State::Pending(mut state) = self.0.take() else {
panic!("poll after completion");
};
while let Poll::Ready(read) =
pin!(&mut state.io).poll_read(cx, &mut state.id[state.count as usize..])?
{
state.count += read as u8;
if state.is_done() {
let id = InternalId(state.id);
trace!("read id: {}", id);
return Poll::Ready(Ok((id, state.io)));
} else if read == 0 {
error!("remote closed before sending id");
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
}
}
self.0 = State::Pending(state);
Poll::Pending
}
}
/// A future that resolves when an id has been written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct WriteId<Io>(State<Io>);
impl<Io> WriteId<Io> {
/// Creates a new `WriteId` future.
pub(crate) fn new(io: Io, id: InternalId) -> Self {
Self(State::Pending(Inner {
io,
count: 0,
id: id.0,
}))
}
}
impl<Io> Future for WriteId<Io>
where
Io: AsyncWrite + Unpin,
{
type Output = Result<Io, std::io::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let State::Pending(mut state) = self.0.take() else {
panic!("poll after completion");
};
// If we haven't finished sending the id, keep sending it.
if !state.is_done() {
while let Poll::Ready(sent) =
pin!(&mut state.io).poll_write(cx, &state.id[state.count as usize..])?
{
state.count += sent as u8;
if state.is_done() {
break;
}
}
}
// If we've finished sending, flush the write buffer. If flushing
// succeeds then we can return Ready, otherwise we need to keep
// trying.
if state.is_done() {
if pin!(&mut state.io).poll_flush(cx)?.is_ready() {
return Poll::Ready(Ok(state.io));
}
}
self.0 = State::Pending(state);
Poll::Pending
}
}
/// A future that resolves when a stream has been returned to the caller.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct ReturnStream<Io> {
fut: WriteId<Io>,
sender: Option<oneshot::Sender<Io>>,
}
impl<Io> ReturnStream<Io> {
/// Creates a new `ReturnStream` future.
pub(crate) fn new(id: InternalId, io: Io, sender: oneshot::Sender<Io>) -> Self {
Self {
fut: WriteId::new(io, id),
sender: Some(sender),
}
}
}
impl<Io> Future for ReturnStream<Io>
where
Io: AsyncWrite + Unpin,
{
type Output = Result<(), std::io::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let io = ready!(pin!(&mut self.fut).poll(cx))?;
_ = self
.sender
.take()
.expect("future not polled after completion")
.send(io);
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
use tokio_util::compat::TokioAsyncReadCompatExt as _;
#[test]
fn test_id_future() {
let id_0 = InternalId([42u8; 32]);
// send 1 byte at a time
let (io_0, io_1) = duplex(1);
futures::executor::block_on(async {
let (_, (id_1, _)) = futures::try_join!(
WriteId::new(io_0.compat(), id_0),
ReadId::new(io_1.compat())
)
.unwrap();
assert_eq!(id_0, id_1);
});
}
}

109
uid-mux/src/lib.rs Normal file
View File

@@ -0,0 +1,109 @@
//! Multiplexing with unique channel ids.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub(crate) mod future;
#[cfg(feature = "serio")]
mod serio;
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils;
pub mod yamux;
#[cfg(feature = "serio")]
pub use serio::{FramedMux, FramedUidMux};
use core::fmt;
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite};
/// Internal stream identifier.
///
/// User provided ids are hashed to a fixed length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct InternalId([u8; 32]);
impl InternalId {
/// Creates a new `InternalId` from a byte slice.
pub(crate) fn new(bytes: &[u8]) -> Self {
Self(blake3::hash(bytes).into())
}
}
impl fmt::Display for InternalId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for byte in &self.0[..4] {
write!(f, "{:02x}", byte)?;
}
Ok(())
}
}
impl AsRef<[u8]> for InternalId {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
/// A multiplexer that opens streams with unique ids.
#[async_trait]
pub trait UidMux<Id> {
/// Stream type.
type Stream: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static;
/// Error type.
type Error;
/// Open a new stream with the given id.
async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error>;
}
pub(crate) mod log {
macro_rules! error {
($( $tokens:tt )*) => {
{
#[cfg(feature = "tracing")]
tracing::error!($( $tokens )*);
}
};
}
macro_rules! warn_ {
($( $tokens:tt )*) => {
{
#[cfg(feature = "tracing")]
tracing::warn!($( $tokens )*);
}
};
}
macro_rules! trace {
($( $tokens:tt )*) => {
{
#[cfg(feature = "tracing")]
tracing::trace!($( $tokens )*);
}
};
}
macro_rules! debug {
($( $tokens:tt )*) => {
{
#[cfg(feature = "tracing")]
tracing::debug!($( $tokens )*);
}
};
}
macro_rules! info {
($( $tokens:tt )*) => {
{
#[cfg(feature = "tracing")]
tracing::info!($( $tokens )*);
}
};
}
pub(crate) use {debug, error, info, trace, warn_ as warn};
}

131
uid-mux/src/serio.rs Normal file
View File

@@ -0,0 +1,131 @@
use ::serio::{codec::Codec, IoDuplex};
use async_trait::async_trait;
use crate::UidMux;
/// A multiplexer that opens framed streams with unique ids.
#[async_trait]
pub trait FramedUidMux<Id> {
/// Stream type.
type Framed: IoDuplex;
/// Error type.
type Error;
/// Opens a new framed stream with the given id.
async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error>;
}
/// A framed multiplexer.
#[derive(Debug)]
pub struct FramedMux<M, C> {
mux: M,
codec: C,
}
impl<M, C> FramedMux<M, C> {
/// Creates a new `FramedMux`.
pub fn new(mux: M, codec: C) -> Self {
Self { mux, codec }
}
/// Returns a reference to the mux.
pub fn mux(&self) -> &M {
&self.mux
}
/// Returns a mutable reference to the mux.
pub fn mux_mut(&mut self) -> &mut M {
&mut self.mux
}
/// Returns a reference to the codec.
pub fn codec(&self) -> &C {
&self.codec
}
/// Returns a mutable reference to the codec.
pub fn codec_mut(&mut self) -> &mut C {
&mut self.codec
}
/// Splits the `FramedMux` into its parts.
pub fn into_parts(self) -> (M, C) {
(self.mux, self.codec)
}
}
#[async_trait]
impl<Id, M, C> FramedUidMux<Id> for FramedMux<M, C>
where
Id: Sync,
M: UidMux<Id> + Sync,
C: Codec<<M as UidMux<Id>>::Stream> + Sync,
{
/// Stream type.
type Framed = <C as Codec<<M as UidMux<Id>>::Stream>>::Framed;
/// Error type.
type Error = <M as UidMux<Id>>::Error;
/// Opens a new framed stream with the given id.
async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error> {
let stream = self.mux.open(id).await?;
Ok(self.codec.new_framed(stream))
}
}
impl<M: Clone, C: Clone> Clone for FramedMux<M, C> {
fn clone(&self) -> Self {
Self {
mux: self.mux.clone(),
codec: self.codec.clone(),
}
}
}
#[cfg(test)]
mod tests {
use std::future::IntoFuture;
use super::*;
use crate::yamux::{Config, Mode, Yamux};
use ::serio::codec::Bincode;
use serio::{stream::IoStreamExt, SinkExt};
use tokio::io::duplex;
use tokio_util::compat::TokioAsyncReadCompatExt;
#[tokio::test]
async fn test_framed_mux() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let client_ctrl = FramedMux::new(client.control(), Bincode);
let server_ctrl = FramedMux::new(server.control(), Bincode);
let conn_task = tokio::spawn(async {
futures::try_join!(client.into_future(), server.into_future()).unwrap();
});
futures::join!(
async {
let mut stream = client_ctrl.open_framed(b"test").await.unwrap();
stream.send(42u128).await.unwrap();
client_ctrl.mux().close();
},
async {
let mut stream = server_ctrl.open_framed(b"test").await.unwrap();
let num: u128 = stream.expect_next().await.unwrap();
server_ctrl.mux().close();
assert_eq!(num, 42u128);
}
);
conn_task.await.unwrap();
}
}

47
uid-mux/src/test_utils.rs Normal file
View File

@@ -0,0 +1,47 @@
//! Test utilities.
use tokio::io::{duplex, DuplexStream};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::{Config, Mode};
use crate::{
yamux::{Yamux, YamuxCtrl},
FramedMux,
};
/// Creates a test pair of yamux instances.
///
/// # Arguments
///
/// * `buffer` - The buffer size.
pub fn test_yamux_pair(
buffer: usize,
) -> (Yamux<Compat<DuplexStream>>, Yamux<Compat<DuplexStream>>) {
let (a, b) = duplex(buffer);
let a = Yamux::new(a.compat(), Config::default(), Mode::Client);
let b = Yamux::new(b.compat(), Config::default(), Mode::Server);
(a, b)
}
/// Creates a test pair of framed yamux instances.
///
/// # Arguments
///
/// * `buffer` - The buffer size.
/// * `codec` - The codec.
pub fn test_yamux_pair_framed<C: Clone>(
buffer: usize,
codec: C,
) -> (
(FramedMux<YamuxCtrl, C>, Yamux<Compat<DuplexStream>>),
(FramedMux<YamuxCtrl, C>, Yamux<Compat<DuplexStream>>),
) {
let (a, b) = test_yamux_pair(buffer);
let ctrl_a = FramedMux::new(a.control(), codec.clone());
let ctrl_b = FramedMux::new(b.control(), codec);
((ctrl_a, a), (ctrl_b, b))
}

543
uid-mux/src/yamux.rs Normal file
View File

@@ -0,0 +1,543 @@
//! Yamux multiplexer.
//!
//! This module provides a [`yamux`](https://crates.io/crates/yamux) wrapper which implements [`UidMux`](crate::UidMux).
use std::{
collections::HashMap,
fmt,
future::IntoFuture,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
task::{Context, Poll, Waker},
};
use async_trait::async_trait;
use futures::{stream::FuturesUnordered, AsyncRead, AsyncWrite, Future, FutureExt, StreamExt};
use tokio::sync::{oneshot, Notify};
use yamux::Connection;
use crate::{
future::{ReadId, ReturnStream},
log::{debug, error, info, trace, warn},
InternalId, UidMux,
};
pub use yamux::{Config, ConnectionError, Mode, Stream};
type Result<T, E = ConnectionError> = std::result::Result<T, E>;
#[derive(Debug, Clone, Copy)]
enum Role {
Client,
Server,
}
impl fmt::Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Role::Client => write!(f, "Client"),
Role::Server => write!(f, "Server"),
}
}
}
/// A yamux multiplexer.
#[derive(Debug)]
pub struct Yamux<Io> {
role: Role,
conn: Connection<Io>,
queue: Arc<Mutex<Queue>>,
close_notify: Arc<Notify>,
shutdown_notify: Arc<AtomicBool>,
}
#[derive(Debug)]
struct Queue {
waiting: HashMap<InternalId, oneshot::Sender<Stream>>,
ready: HashMap<InternalId, Stream>,
waker: Option<Waker>,
}
impl Default for Queue {
fn default() -> Self {
Self {
waiting: Default::default(),
ready: Default::default(),
waker: None,
}
}
}
impl<Io> Yamux<Io> {
/// Returns a new control handle.
pub fn control(&self) -> YamuxCtrl {
YamuxCtrl {
role: self.role,
queue: self.queue.clone(),
close_notify: self.close_notify.clone(),
shutdown_notify: self.shutdown_notify.clone(),
}
}
}
impl<Io> Yamux<Io>
where
Io: AsyncWrite + AsyncRead + Unpin,
{
/// Creates a new yamux multiplexer.
pub fn new(io: Io, config: Config, mode: Mode) -> Self {
let role = match mode {
Mode::Client => Role::Client,
Mode::Server => Role::Server,
};
Self {
role,
conn: Connection::new(io, config, mode),
queue: Default::default(),
close_notify: Default::default(),
shutdown_notify: Default::default(),
}
}
}
impl<Io> IntoFuture for Yamux<Io>
where
Io: AsyncWrite + AsyncRead + Unpin,
{
type Output = Result<()>;
type IntoFuture = YamuxFuture<Io>;
fn into_future(self) -> Self::IntoFuture {
YamuxFuture {
role: self.role,
conn: self.conn,
incoming: Default::default(),
outgoing: Default::default(),
queue: self.queue,
closed: false,
remote_closed: false,
close_notify: self.close_notify,
shutdown_notify: self.shutdown_notify,
}
}
}
/// A yamux connection future.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YamuxFuture<Io> {
role: Role,
conn: Connection<Io>,
/// Pending incoming streams, waiting for ids to be received.
incoming: FuturesUnordered<ReadId<Stream>>,
/// Pending outgoing streams, waiting to send ids and return streams
/// to callers.
outgoing: FuturesUnordered<ReturnStream<Stream>>,
queue: Arc<Mutex<Queue>>,
/// Whether this side has closed the connection.
closed: bool,
/// Whether the remote has closed the connection.
remote_closed: bool,
close_notify: Arc<Notify>,
shutdown_notify: Arc<AtomicBool>,
}
impl<Io> YamuxFuture<Io>
where
Io: AsyncWrite + AsyncRead + Unpin,
{
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
fn client_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
if let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
if stream.is_some() {
error!("client mux received incoming stream");
return Err(
std::io::Error::other("client mode cannot accept incoming streams").into(),
);
}
info!("remote closed connection");
self.remote_closed = true;
}
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
fn client_handle_outbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
// Putting this in a block so the lock is released as soon as possible.
{
let mut queue = self.queue.lock().unwrap();
while !queue.waiting.is_empty() {
if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
let id = *queue.waiting.keys().next().unwrap();
let sender = queue.waiting.remove(&id).unwrap();
debug!("opened new stream: {}", id);
self.outgoing.push(ReturnStream::new(id, stream, sender));
} else {
break;
}
}
// Set the waker so `YamuxCtrl` can wake up the connection.
queue.waker = Some(cx.waker().clone());
}
while let Poll::Ready(Some(result)) = self.outgoing.poll_next_unpin(cx) {
if let Err(err) = result {
warn!("connection closed while opening stream: {}", err);
self.remote_closed = true;
} else {
trace!("finished opening stream");
}
}
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
fn server_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
while let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
let Some(stream) = stream else {
if !self.remote_closed {
info!("remote closed connection");
self.remote_closed = true;
}
break;
};
debug!("received incoming stream");
// The size of this is bounded by yamux max streams config.
self.incoming.push(ReadId::new(stream));
}
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
fn server_process_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
let mut queue = self.queue.lock().unwrap();
while let Poll::Ready(Some(result)) = self.incoming.poll_next_unpin(cx) {
match result {
Ok((id, stream)) => {
debug!("received stream: {}", id);
if let Some(sender) = queue.waiting.remove(&id) {
_ = sender
.send(stream)
.inspect_err(|_| error!("caller dropped receiver"));
trace!("returned stream to caller: {}", id);
} else {
trace!("queuing stream: {}", id);
queue.ready.insert(id, stream);
}
}
Err(err) => {
warn!("connection closed while receiving stream: {}", err);
self.remote_closed = true;
}
}
}
// Set the waker so `YamuxCtrl` can wake up the connection.
queue.waker = Some(cx.waker().clone());
Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
fn handle_shutdown(&mut self, cx: &mut Context<'_>) -> Result<()> {
// Attempt to close the connection if the shutdown notify has been set.
if !self.closed && self.shutdown_notify.load(Ordering::Relaxed) {
if let Poll::Ready(()) = self.conn.poll_close(cx)? {
self.closed = true;
info!("mux connection closed");
}
}
Ok(())
}
fn is_complete(&self) -> bool {
self.remote_closed || self.closed
}
fn poll_client(&mut self, cx: &mut Context<'_>) -> Result<()> {
self.client_handle_inbound(cx)?;
if !self.remote_closed {
self.client_handle_outbound(cx)?;
// We need to poll the inbound again to make sure the connection
// flushes the write buffer.
self.client_handle_inbound(cx)?;
}
self.handle_shutdown(cx)?;
Ok(())
}
fn poll_server(&mut self, cx: &mut Context<'_>) -> Result<()> {
self.server_handle_inbound(cx)?;
self.server_process_inbound(cx)?;
self.handle_shutdown(cx)?;
Ok(())
}
}
impl<Io> Future for YamuxFuture<Io>
where
Io: AsyncWrite + AsyncRead + Unpin,
{
type Output = Result<()>;
#[cfg_attr(
feature = "tracing",
tracing::instrument(
fields(role = %self.role),
skip_all
)
)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.role {
Role::Client => self.poll_client(cx)?,
Role::Server => self.poll_server(cx)?,
};
if self.is_complete() {
self.close_notify.notify_waiters();
info!("connection complete");
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
/// A yamux control handle.
#[derive(Debug, Clone)]
pub struct YamuxCtrl {
role: Role,
queue: Arc<Mutex<Queue>>,
close_notify: Arc<Notify>,
shutdown_notify: Arc<AtomicBool>,
}
impl YamuxCtrl {
/// Closes the yamux connection.
pub fn close(&self) {
self.shutdown_notify.store(true, Ordering::Relaxed);
// Wake up the connection.
self.queue
.lock()
.unwrap()
.waker
.as_ref()
.map(|waker| waker.wake_by_ref());
}
}
#[async_trait]
impl<Id> UidMux<Id> for YamuxCtrl
where
Id: fmt::Debug + AsRef<[u8]> + Sync,
{
type Stream = Stream;
type Error = std::io::Error;
#[cfg_attr(
feature = "tracing",
tracing::instrument(
fields(role = %self.role, id = hex::encode(id)),
skip_all,
err
)
)]
async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error> {
let internal_id = InternalId::new(id.as_ref());
debug!("opening stream: {}", internal_id);
let receiver = {
let mut queue = self.queue.lock().unwrap();
if let Some(stream) = queue.ready.remove(&internal_id) {
trace!("stream already opened");
return Ok(stream);
}
let (sender, receiver) = oneshot::channel();
// Insert the oneshot into the queue.
queue.waiting.insert(internal_id, sender);
// Wake up the connection.
queue.waker.as_ref().map(|waker| waker.wake_by_ref());
trace!("waiting for stream");
receiver
};
futures::select! {
stream = receiver.fuse() =>
stream
.inspect(|_| debug!("caller received stream"))
.inspect_err(|_| error!("connection cancelled stream"))
.map_err(|_| {
std::io::Error::other(format!("connection cancelled stream"))
}),
_ = self.close_notify.notified().fuse() => {
error!("connection closed before stream opened");
Err(std::io::ErrorKind::ConnectionAborted.into())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{AsyncReadExt, AsyncWriteExt};
use tokio::io::duplex;
use tokio_util::compat::TokioAsyncReadCompatExt;
#[tokio::test]
async fn test_yamux() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let client_ctrl = client.control();
let server_ctrl = server.control();
let conn_task = tokio::spawn(async {
futures::try_join!(client.into_future(), server.into_future()).unwrap();
});
futures::join!(
async {
let mut stream = client_ctrl.open(b"0").await.unwrap();
let mut stream2 = client_ctrl.open(b"00").await.unwrap();
stream.write_all(b"ping").await.unwrap();
stream2.write_all(b"ping2").await.unwrap();
},
async {
let mut stream = server_ctrl.open(b"0").await.unwrap();
let mut stream2 = server_ctrl.open(b"00").await.unwrap();
let mut buf = [0; 4];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");
let mut buf = [0; 5];
stream2.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping2");
}
);
client_ctrl.close();
server_ctrl.close();
conn_task.await.unwrap();
}
#[tokio::test]
async fn test_yamux_client_close() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let client_ctrl = client.control();
let mut fut = futures::future::try_join(client.into_future(), server.into_future());
_ = futures::poll!(&mut fut);
client_ctrl.close();
// Both connections close cleanly.
fut.await.unwrap();
}
// Test the case where the client closes the connection while the server is expecting a new stream.
#[tokio::test]
async fn test_yamux_client_close_early() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let client_ctrl = client.control();
let server_ctrl = server.control();
let mut fut_conn = futures::future::try_join(client.into_future(), server.into_future());
_ = futures::poll!(&mut fut_conn);
let mut fut_open = server_ctrl.open(b"0");
_ = futures::poll!(&mut fut_open);
client_ctrl.close();
// Both connections close cleanly.
fut_conn.await.unwrap();
// But caller gets an error.
assert!(fut_open.await.is_err());
}
#[tokio::test]
async fn test_yamux_server_close() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let server_ctrl = server.control();
let mut fut = futures::future::try_join(client.into_future(), server.into_future());
_ = futures::poll!(&mut fut);
server_ctrl.close();
// Both connections close cleanly.
fut.await.unwrap();
}
// Test the case where the server closes the connection while the client is opening a new stream.
#[tokio::test]
async fn test_yamux_server_close_early() {
let (client_io, server_io) = duplex(1024);
let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
let client_ctrl = client.control();
let server_ctrl = server.control();
let mut fut_client = client.into_future();
let mut fut_server = server.into_future();
let mut fut_conn = futures::future::try_join(&mut fut_client, &mut fut_server);
_ = futures::poll!(&mut fut_conn);
drop(fut_conn);
let mut fut_open = client_ctrl.open(b"0");
_ = futures::poll!(&mut fut_open);
// We need to prevent the client from beating us to the punch here.
fut_client.queue.lock().unwrap().waiting.clear();
server_ctrl.close();
// Both connections close cleanly.
futures::try_join!(fut_client, fut_server).unwrap();
// But caller gets an error.
assert!(fut_open.await.is_err());
}
}