Rewrite Control to be a layer on top of Connection

Thomas Eizinger
2022-10-24 21:29:30 +11:00
parent ce57e251fb
commit 0ac90a0805
8 changed files with 376 additions and 338 deletions

View File

@@ -13,7 +13,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use futures::{channel::mpsc, future, io::AsyncReadExt, prelude::*};
use std::sync::Arc;
use tokio::{runtime::Runtime, task};
use yamux::{Config, Connection, Mode};
use yamux::{Config, Connection, Control, Mode};
criterion_group!(benches, concurrent);
criterion_main!(benches);
@@ -92,7 +92,10 @@ async fn oneway(
let server = async move {
let mut connection = Connection::new(server, config(), Mode::Server);
while let Some(mut stream) = connection.next_stream().await.unwrap() {
while let Some(Ok(mut stream)) = stream::poll_fn(|cx| connection.poll_next_inbound(cx))
.next()
.await
{
let tx = tx.clone();
task::spawn(async move {
@@ -113,8 +116,9 @@ async fn oneway(
task::spawn(server);
let conn = Connection::new(client, config(), Mode::Client);
let mut ctrl = conn.control().unwrap();
task::spawn(yamux::into_stream(conn).for_each(|r| {
let (mut ctrl, conn) = Control::new(conn);
task::spawn(conn.for_each(|r| {
r.unwrap();
future::ready(())
}));
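Taken together, the benchmark hunks above show the new accept pattern: `next_stream().await` is replaced by driving `poll_next_inbound` through `futures::stream::poll_fn`. A minimal sketch of that server loop, where the `accept_loop` name and error handling are illustrative rather than part of this commit:

use futures::{prelude::*, stream};
use yamux::{Config, Connection, ConnectionError, Mode};

// Illustrative accept loop mirroring the updated benchmark code above.
async fn accept_loop<T>(io: T) -> Result<(), ConnectionError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut connection = Connection::new(io, Config::default(), Mode::Server);

    // `poll_next_inbound` must be polled repeatedly; `stream::poll_fn` turns it
    // into a stream of inbound substreams and drives the connection as a side effect.
    while let Some(inbound) = stream::poll_fn(|cx| connection.poll_next_inbound(cx))
        .next()
        .await
    {
        let _stream = inbound?;
        // Handle the inbound stream here, e.g. spawn a task that echoes data back.
    }

    Ok(())
}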

View File

@@ -98,21 +98,15 @@ use crate::{
frame::{self, Frame},
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use futures::{
channel::{mpsc, oneshot},
future::{self, Either},
prelude::*,
sink::SinkExt,
stream::Fuse,
};
use cleanup::Cleanup;
use closing::Closing;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::VecDeque;
use std::task::Context;
use std::{fmt, sync::Arc, task::Poll};
use crate::connection::cleanup::Cleanup;
use crate::connection::closing::Closing;
pub use control::Control;
pub use control::{Control, ControlledConnection};
pub use stream::{Packet, State, Stream};
/// Arbitrary limit of our internal command channels.
@@ -169,39 +163,148 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
}
}
/// Get a controller for this connection.
pub fn control(&self) -> Result<Control> {
match &self.inner {
ConnectionState::Active(active) => Ok(active.control()),
ConnectionState::Closed
| ConnectionState::Closing { .. }
| ConnectionState::Cleanup(_) => Err(ConnectionError::Closed),
ConnectionState::Poisoned => unreachable!(),
/// Poll for a new outbound stream.
///
/// This function will fail if the current state does not allow opening new outbound streams.
pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
ConnectionState::Active(mut active) => match active.new_outbound() {
Ok(stream) => {
self.inner = ConnectionState::Active(active);
return Poll::Ready(Ok(stream));
}
Err(e) => {
self.inner = ConnectionState::Cleanup(active.cleanup(e));
continue;
}
},
ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx) {
Poll::Ready(Ok(())) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Err(ConnectionError::Closed));
}
Poll::Ready(Err(e)) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Err(e));
}
Poll::Pending => {
self.inner = ConnectionState::Closing(inner);
return Poll::Pending;
}
},
ConnectionState::Cleanup(mut inner) => match inner.poll_unpin(cx) {
Poll::Ready(e) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Err(e));
}
Poll::Pending => {
self.inner = ConnectionState::Cleanup(inner);
return Poll::Pending;
}
},
ConnectionState::Closed => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Err(ConnectionError::Closed));
}
ConnectionState::Poisoned => unreachable!(),
}
}
}
/// Get the next incoming stream, opened by the remote.
/// Poll for the next inbound stream.
///
/// This must be called repeatedly in order to make progress.
/// Once `Ok(None)` or `Err(_)` is returned the connection is
/// considered closed and no further invocation of this method
/// must be attempted.
///
/// # Cancellation
///
/// This function is cancellation-safe.
pub async fn next_stream(&mut self) -> Result<Option<Stream>> {
future::poll_fn(|cx| self.inner.poll_next(cx))
.await
.transpose()
/// If this function returns `None`, the underlying connection is closed.
pub fn poll_next_inbound(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
loop {
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
ConnectionState::Active(mut active) => match active.poll(cx) {
Poll::Ready(Ok(stream)) => {
self.inner = ConnectionState::Active(active);
return Poll::Ready(Some(Ok(stream)));
}
Poll::Ready(Err(e)) => {
self.inner = ConnectionState::Cleanup(active.cleanup(e));
continue;
}
Poll::Pending => {
self.inner = ConnectionState::Active(active);
return Poll::Pending;
}
},
ConnectionState::Closing(mut closing) => match closing.poll_unpin(cx) {
Poll::Ready(Ok(())) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(None);
}
Poll::Ready(Err(e)) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
self.inner = ConnectionState::Closing(closing);
return Poll::Pending;
}
},
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
Poll::Ready(ConnectionError::Closed) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(None);
}
Poll::Ready(other) => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Some(Err(other)));
}
Poll::Pending => {
self.inner = ConnectionState::Cleanup(cleanup);
return Poll::Pending;
}
},
ConnectionState::Closed => {
self.inner = ConnectionState::Closed;
return Poll::Ready(None);
}
ConnectionState::Poisoned => unreachable!(),
}
}
}
/// Have the underlying connection make progress.
///
/// If this returns `Poll::Ready(None)`, the connection is closed and no longer
/// needs to be polled.
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
self.inner.poll_next(cx)
/// Close the connection.
pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
ConnectionState::Active(active) => {
self.inner = ConnectionState::Closing(active.close());
}
ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx)? {
Poll::Ready(()) => {
self.inner = ConnectionState::Closed;
}
Poll::Pending => {
self.inner = ConnectionState::Closing(inner);
return Poll::Pending;
}
},
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
Poll::Ready(reason) => {
log::warn!("Failure while closing connection: {}", reason);
self.inner = ConnectionState::Closed;
return Poll::Ready(Ok(()));
}
Poll::Pending => {
self.inner = ConnectionState::Cleanup(cleanup);
return Poll::Pending;
}
},
ConnectionState::Closed => {
self.inner = ConnectionState::Closed;
return Poll::Ready(Ok(()));
}
ConnectionState::Poisoned => {
unreachable!()
}
}
}
}
}
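After this change, `poll_new_outbound`, `poll_next_inbound` and `poll_close` form the public poll-based surface of `Connection`. In async code they can be wrapped with `futures::future::poll_fn`; a sketch under that assumption, with the wrapper names being illustrative:

use futures::{future, AsyncRead, AsyncWrite};
use yamux::{Connection, ConnectionError};

// Illustrative async wrappers around the poll-based `Connection` API.
async fn open_outbound<T>(conn: &mut Connection<T>) -> Result<yamux::Stream, ConnectionError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // Resolves as soon as the connection state allows a new outbound stream.
    // Note: the socket itself only makes progress while `poll_next_inbound`
    // is being polled somewhere, e.g. by a `ControlledConnection`.
    future::poll_fn(|cx| conn.poll_new_outbound(cx)).await
}

async fn close<T>(conn: &mut Connection<T>) -> Result<(), ConnectionError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // Drives the shutdown (draining stream commands, flushing pending frames,
    // sending the termination frame) to completion.
    future::poll_fn(|cx| conn.poll_close(cx)).await
}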
@@ -209,7 +312,7 @@ impl<T> Drop for Connection<T> {
fn drop(&mut self) {
match &mut self.inner {
ConnectionState::Active(active) => active.drop_all_streams(),
ConnectionState::Closing { .. } => {}
ConnectionState::Closing(_) => {}
ConnectionState::Cleanup(_) => {}
ConnectionState::Closed => {}
ConnectionState::Poisoned => {}
@@ -221,10 +324,7 @@ enum ConnectionState<T> {
/// The connection is alive and healthy.
Active(Active<T>),
/// Our user requested to shutdown the connection, we are working on it.
Closing {
inner: Closing<T>,
reply: oneshot::Sender<()>,
},
Closing(Closing<T>),
/// An error occurred and we are cleaning up our resources.
Cleanup(Cleanup),
/// The connection is closed.
@@ -237,7 +337,7 @@ impl<T> fmt::Debug for ConnectionState<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectionState::Active(_) => write!(f, "Active"),
ConnectionState::Closing { .. } => write!(f, "Closing"),
ConnectionState::Closing(_) => write!(f, "Closing"),
ConnectionState::Cleanup(_) => write!(f, "Cleanup"),
ConnectionState::Closed => write!(f, "Closed"),
ConnectionState::Poisoned => write!(f, "Poisoned"),
@@ -245,86 +345,6 @@ impl<T> fmt::Debug for ConnectionState<T> {
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> ConnectionState<T> {
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
loop {
match std::mem::replace(self, ConnectionState::Poisoned) {
ConnectionState::Active(mut active) => {
match active.control_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(ControlCommand::OpenStream(reply))) => {
active.on_open_stream(reply)?;
*self = ConnectionState::Active(active);
continue;
}
Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => {
*self = ConnectionState::Closing {
inner: active.close(),
reply,
};
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
}
_ => {}
}
match active.poll(cx) {
Poll::Ready(Ok(stream)) => {
*self = ConnectionState::Active(active);
return Poll::Ready(Some(Ok(stream)));
}
Poll::Ready(Err(e)) => *self = ConnectionState::Cleanup(active.cleanup(e)),
Poll::Pending => {
*self = ConnectionState::Active(active);
return Poll::Pending;
}
}
}
ConnectionState::Closing {
inner: mut closing,
reply,
} => match closing.poll_unpin(cx) {
Poll::Ready(Ok(())) => {
let _ = reply.send(());
*self = ConnectionState::Closed;
return Poll::Ready(None);
}
Poll::Ready(Err(e)) => {
let _ = reply.send(());
*self = ConnectionState::Closed;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
*self = ConnectionState::Closing {
inner: closing,
reply,
};
return Poll::Pending;
}
},
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
Poll::Ready(ConnectionError::Closed) => {
*self = ConnectionState::Closed;
return Poll::Ready(None);
}
Poll::Ready(other) => {
*self = ConnectionState::Closed;
return Poll::Ready(Some(Err(other)));
}
Poll::Pending => {
*self = ConnectionState::Cleanup(cleanup);
return Poll::Pending;
}
},
ConnectionState::Closed => return Poll::Ready(None),
ConnectionState::Poisoned => unreachable!(),
}
}
}
}
/// A Yamux connection object.
///
/// Wraps the underlying I/O resource and makes progress via its
@@ -337,8 +357,6 @@ struct Active<T> {
socket: Fuse<frame::Io<T>>,
next_id: u32,
streams: IntMap<StreamId, Stream>,
control_sender: mpsc::Sender<ControlCommand>,
control_receiver: mpsc::Receiver<ControlCommand>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
garbage: Vec<StreamId>,
@@ -346,15 +364,6 @@ struct Active<T> {
pending_frames: VecDeque<Frame<()>>,
}
/// `Control` to `Connection` commands.
#[derive(Debug)]
pub(crate) enum ControlCommand {
/// Open a new stream to the remote end.
OpenStream(oneshot::Sender<Result<Stream>>),
/// Close the whole connection.
CloseConnection(oneshot::Sender<()>),
}
/// `Stream` to `Connection` commands.
#[derive(Debug)]
pub(crate) enum StreamCommand {
@@ -410,7 +419,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
let id = Id::random();
log::debug!("new connection: {} ({:?})", id, mode);
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let (control_sender, control_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
Active {
id,
@@ -418,8 +426,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
control_sender,
control_receiver,
stream_sender,
stream_receiver,
next_id: match mode {
@@ -431,18 +437,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
}
fn control(&self) -> Control {
Control::new(self.control_sender.clone())
}
/// Gracefully close the connection to the remote.
fn close(self) -> Closing<T> {
Closing::new(
self.control_receiver,
self.stream_receiver,
self.pending_frames,
self.socket,
)
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
}
/// Cleanup all our resources.
@@ -451,7 +448,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
self.drop_all_streams();
Cleanup::new(self.control_receiver, self.stream_receiver, error)
Cleanup::new(self.stream_receiver, error)
}
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
@@ -503,11 +500,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
}
fn on_open_stream(&mut self, reply: oneshot::Sender<Result<Stream>>) -> Result<()> {
fn new_outbound(&mut self) -> Result<Stream> {
if self.streams.len() >= self.config.max_num_streams {
log::error!("{}: maximum number of streams reached", self.id);
let _ = reply.send(Err(ConnectionError::TooManyStreams));
return Ok(());
return Err(ConnectionError::TooManyStreams);
}
log::trace!("{}: creating new outbound stream", self.id);
@@ -533,21 +529,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
stream
};
if reply.send(Ok(stream.clone())).is_ok() {
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream);
} else {
log::debug!("{}: open stream {} has been cancelled", self.id, id);
if extra_credit > 0 {
let mut header = Header::data(id, 0);
header.rst();
let frame = Frame::new(header);
log::trace!("{}/{}: sending reset", self.id, id);
self.pending_frames.push_back(frame.into());
}
}
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
Ok(())
Ok(stream)
}
fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
@@ -959,11 +944,3 @@ impl<T> Active<T> {
}
}
}
/// Turn a Yamux [`Connection`] into a [`futures::Stream`].
pub fn into_stream<T>(mut c: Connection<T>) -> impl futures::stream::Stream<Item = Result<Stream>>
where
T: AsyncRead + AsyncWrite + Unpin,
{
futures::stream::poll_fn(move |cx| c.poll_next(cx))
}
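Since the free function `into_stream` is removed here, downstream code that wants a `futures::Stream` of inbound streams can rebuild it from `poll_next_inbound`, exactly as the updated tests and benchmarks do. An equivalent helper, sketched outside the crate with an illustrative name:

use futures::{stream, AsyncRead, AsyncWrite, Stream};
use yamux::{Connection, ConnectionError};

// Illustrative stand-in for the removed `yamux::into_stream` helper.
fn inbound_streams<T>(
    mut connection: Connection<T>,
) -> impl Stream<Item = Result<yamux::Stream, ConnectionError>>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    // Polling the returned stream both accepts inbound streams and drives
    // the underlying socket, mirroring what `into_stream` used to do.
    stream::poll_fn(move |cx| connection.poll_next_inbound(cx))
}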

View File

@@ -1,4 +1,4 @@
use crate::connection::{ControlCommand, StreamCommand};
use crate::connection::StreamCommand;
use crate::ConnectionError;
use futures::channel::mpsc;
use futures::{ready, StreamExt};
@@ -10,20 +10,17 @@ use std::task::{Context, Poll};
#[must_use]
pub struct Cleanup {
state: State,
control_receiver: mpsc::Receiver<ControlCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
error: Option<ConnectionError>,
}
impl Cleanup {
pub(crate) fn new(
control_receiver: mpsc::Receiver<ControlCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
error: ConnectionError,
) -> Self {
Self {
state: State::ClosingControlReceiver,
control_receiver,
state: State::ClosingStreamReceiver,
stream_receiver,
error: Some(error),
}
@@ -38,21 +35,6 @@ impl Future for Cleanup {
loop {
match this.state {
State::ClosingControlReceiver => {
this.control_receiver.close();
this.state = State::DrainingControlReceiver;
}
State::DrainingControlReceiver => {
match ready!(this.control_receiver.poll_next_unpin(cx)) {
Some(ControlCommand::OpenStream(reply)) => {
let _ = reply.send(Err(ConnectionError::Closed));
}
Some(ControlCommand::CloseConnection(reply)) => {
let _ = reply.send(());
}
None => this.state = State::ClosingStreamReceiver,
}
}
State::ClosingStreamReceiver => {
this.stream_receiver.close();
this.state = State::DrainingStreamReceiver;
@@ -81,8 +63,6 @@ impl Future for Cleanup {
#[allow(clippy::enum_variant_names)]
enum State {
ClosingControlReceiver,
DrainingControlReceiver,
ClosingStreamReceiver,
DrainingStreamReceiver,
}

View File

@@ -1,7 +1,7 @@
use crate::connection::Result;
use crate::connection::{ControlCommand, StreamCommand};
use crate::connection::StreamCommand;
use crate::frame;
use crate::frame::Frame;
use crate::{frame, ConnectionError};
use futures::channel::mpsc;
use futures::stream::Fuse;
use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
@@ -14,7 +14,6 @@ use std::task::{Context, Poll};
#[must_use]
pub struct Closing<T> {
state: State,
control_receiver: mpsc::Receiver<ControlCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
pending_frames: VecDeque<Frame<()>>,
socket: Fuse<frame::Io<T>>,
@@ -25,14 +24,12 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
control_receiver: mpsc::Receiver<ControlCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
pending_frames: VecDeque<Frame<()>>,
socket: Fuse<frame::Io<T>>,
) -> Self {
Self {
state: State::ClosingControlReceiver,
control_receiver,
state: State::ClosingStreamReceiver,
stream_receiver,
pending_frames,
socket,
@@ -51,21 +48,6 @@ where
loop {
match this.state {
State::ClosingControlReceiver => {
this.control_receiver.close();
this.state = State::DrainingControlReceiver;
}
State::DrainingControlReceiver => {
match ready!(this.control_receiver.poll_next_unpin(cx)) {
Some(ControlCommand::OpenStream(reply)) => {
let _ = reply.send(Err(ConnectionError::Closed));
}
Some(ControlCommand::CloseConnection(reply)) => {
let _ = reply.send(());
}
None => this.state = State::ClosingStreamReceiver,
}
}
State::ClosingStreamReceiver => {
this.stream_receiver.close();
this.state = State::DrainingStreamReceiver;
@@ -106,9 +88,7 @@ where
}
}
pub enum State {
ClosingControlReceiver,
DrainingControlReceiver,
enum State {
ClosingStreamReceiver,
DrainingStreamReceiver,
SendingTermFrame,

View File

@@ -8,56 +8,40 @@
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.
use super::ControlCommand;
use crate::{error::ConnectionError, Stream};
use crate::connection::MAX_COMMAND_BACKLOG;
use crate::{error::ConnectionError, Connection, Stream};
use futures::{
channel::{mpsc, oneshot},
prelude::*,
ready,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
use std::pin::Pin;
use std::task::{Context, Poll};
type Result<T> = std::result::Result<T, ConnectionError>;
/// The Yamux `Connection` controller.
/// A Yamux [`Connection`] controller.
///
/// While a Yamux connection makes progress via its `next_stream` method,
/// this controller can be used to concurrently direct the connection,
/// e.g. to open a new stream to the remote or to close the connection.
/// This presents an alternative API for using a yamux [`Connection`].
///
/// The possible operations are implemented as async methods and redundantly
/// as poll-based variants which may be useful inside of other poll based
/// environments such as certain trait implementations.
#[derive(Debug)]
/// A [`Control`] communicates with a [`ControlledConnection`] via a channel. This allows
/// a [`Control`] to be cloned and shared between tasks and threads.
#[derive(Clone, Debug)]
pub struct Control {
/// Command channel to `Connection`.
/// Command channel to [`ControlledConnection`].
sender: mpsc::Sender<ControlCommand>,
/// Pending state of `poll_open_stream`.
pending_open: Option<oneshot::Receiver<Result<Stream>>>,
/// Pending state of `poll_close`.
pending_close: Option<oneshot::Receiver<()>>,
}
impl Clone for Control {
fn clone(&self) -> Self {
Control {
sender: self.sender.clone(),
pending_open: None,
pending_close: None,
}
}
}
impl Control {
pub(crate) fn new(sender: mpsc::Sender<ControlCommand>) -> Self {
Control {
sender,
pending_open: None,
pending_close: None,
}
pub fn new<T>(connection: Connection<T>) -> (Self, ControlledConnection<T>) {
let (sender, receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let control = Control { sender };
let connection = ControlledConnection {
state: State::Idle(connection),
commands: receiver,
};
(control, connection)
}
/// Open a new stream to the remote.
@@ -84,66 +68,181 @@ impl Control {
let _ = rx.await;
Ok(())
}
}
/// [`Poll`] based alternative to [`Control::open_stream`].
pub fn poll_open_stream(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Stream>> {
/// Wraps a [`Connection`] which can be controlled with a [`Control`].
pub struct ControlledConnection<T> {
state: State<T>,
commands: mpsc::Receiver<ControlCommand>,
}
impl<T> ControlledConnection<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
loop {
match self.pending_open.take() {
None => {
ready!(self.sender.poll_ready(cx)?);
let (tx, rx) = oneshot::channel();
self.sender.start_send(ControlCommand::OpenStream(tx))?;
self.pending_open = Some(rx)
}
Some(mut rx) => match rx.poll_unpin(cx)? {
Poll::Ready(result) => return Poll::Ready(result),
Poll::Pending => {
self.pending_open = Some(rx);
return Poll::Pending;
match std::mem::replace(&mut self.state, State::Poisoned) {
State::Idle(mut connection) => {
match connection.poll_next_inbound(cx) {
Poll::Ready(maybe_stream) => {
self.state = State::Idle(connection);
return Poll::Ready(maybe_stream);
}
Poll::Pending => {}
}
},
}
}
}
/// Abort an ongoing open stream operation started by `poll_open_stream`.
pub fn abort_open_stream(&mut self) {
self.pending_open = None
}
/// [`Poll`] based alternative to [`Control::close`].
pub fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
loop {
match self.pending_close.take() {
None => {
if ready!(self.sender.poll_ready(cx)).is_err() {
// The receiver is closed which means the connection is already closed.
return Poll::Ready(Ok(()));
}
let (tx, rx) = oneshot::channel();
if let Err(e) = self.sender.start_send(ControlCommand::CloseConnection(tx)) {
if e.is_full() {
match self.commands.poll_next_unpin(cx) {
Poll::Ready(Some(ControlCommand::OpenStream(reply))) => {
self.state = State::OpeningNewStream { reply, connection };
continue;
}
debug_assert!(e.is_disconnected());
// The receiver is closed which means the connection is already closed.
return Poll::Ready(Ok(()));
Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => {
self.commands.close();
self.state = State::Closing {
reply: Some(reply),
inner: Closing::DrainingControlCommands { connection },
};
continue;
}
Poll::Ready(None) => {
// Last `Control` sender was dropped, close the connection.
self.state = State::Closing {
reply: None,
inner: Closing::ClosingConnection { connection },
};
continue;
}
Poll::Pending => {}
}
self.pending_close = Some(rx)
self.state = State::Idle(connection);
return Poll::Pending;
}
Some(mut rx) => match rx.poll_unpin(cx) {
Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
Poll::Ready(Err(oneshot::Canceled)) => {
// A dropped `oneshot::Sender` means the `Connection` is gone,
// which is `Ok`ay for us here.
return Poll::Ready(Ok(()));
State::OpeningNewStream {
reply,
mut connection,
} => match connection.poll_new_outbound(cx) {
Poll::Ready(stream) => {
let _ = reply.send(stream);
self.state = State::Idle(connection);
continue;
}
Poll::Pending => {
self.pending_close = Some(rx);
self.state = State::OpeningNewStream { reply, connection };
return Poll::Pending;
}
},
State::Closing {
reply,
inner: Closing::DrainingControlCommands { connection },
} => match self.commands.poll_next_unpin(cx) {
Poll::Ready(Some(ControlCommand::OpenStream(new_reply))) => {
let _ = new_reply.send(Err(ConnectionError::Closed));
self.state = State::Closing {
reply,
inner: Closing::DrainingControlCommands { connection },
};
continue;
}
Poll::Ready(Some(ControlCommand::CloseConnection(new_reply))) => {
let _ = new_reply.send(());
self.state = State::Closing {
reply,
inner: Closing::DrainingControlCommands { connection },
};
continue;
}
Poll::Ready(None) => {
self.state = State::Closing {
reply,
inner: Closing::ClosingConnection { connection },
};
continue;
}
Poll::Pending => {
self.state = State::Closing {
reply,
inner: Closing::DrainingControlCommands { connection },
};
return Poll::Pending;
}
},
State::Closing {
reply,
inner: Closing::ClosingConnection { mut connection },
} => match connection.poll_close(cx) {
Poll::Ready(Ok(())) | Poll::Ready(Err(ConnectionError::Closed)) => {
if let Some(reply) = reply {
let _ = reply.send(());
}
return Poll::Ready(None);
}
Poll::Ready(Err(other)) => {
if let Some(reply) = reply {
let _ = reply.send(());
}
return Poll::Ready(Some(Err(other)));
}
Poll::Pending => {
self.state = State::Closing {
reply,
inner: Closing::ClosingConnection { connection },
};
return Poll::Pending;
}
},
State::Poisoned => unreachable!(),
}
}
}
}
#[derive(Debug)]
enum ControlCommand {
/// Open a new stream to the remote end.
OpenStream(oneshot::Sender<Result<Stream>>),
/// Close the whole connection.
CloseConnection(oneshot::Sender<()>),
}
/// The state of a [`ControlledConnection`].
enum State<T> {
Idle(Connection<T>),
OpeningNewStream {
reply: oneshot::Sender<Result<Stream>>,
connection: Connection<T>,
},
Closing {
/// A channel to the [`Control`] in case the close was requested. `None` if we are closing because the last [`Control`] was dropped.
reply: Option<oneshot::Sender<()>>,
inner: Closing<T>,
},
Poisoned,
}
/// A sub-state of our larger state machine for a [`ControlledConnection`].
///
/// Closing the connection involves two steps:
///
/// 1. Draining and answering all remaining [`ControlCommand`]s.
/// 1. Closing the underlying [`Connection`].
enum Closing<T> {
DrainingControlCommands { connection: Connection<T> },
ClosingConnection { connection: Connection<T> },
}
impl<T> futures::Stream for ControlledConnection<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Item = Result<Stream>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.get_mut().poll_next(cx)
}
}
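For clients, the new entry point is `Control::new`, which splits a `Connection` into a cloneable `Control` and a `ControlledConnection` that must be driven (it implements `futures::Stream` over inbound streams). A usage sketch, assuming a Tokio task as the driver like in the updated tests; the `client` function name is illustrative:

use futures::prelude::*;
use yamux::{Config, Connection, ConnectionError, Control, Mode};

// Illustrative client setup mirroring the updated tests and benchmarks.
async fn client<T>(io: T) -> Result<(), ConnectionError>
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let connection = Connection::new(io, Config::default(), Mode::Client);
    let (mut ctrl, driver) = Control::new(connection);

    // The `ControlledConnection` must be polled for the connection (and thus
    // `ctrl`) to make progress; this driver also discards any inbound streams.
    tokio::task::spawn(driver.for_each(|inbound| {
        drop(inbound);
        futures::future::ready(())
    }));

    // `Control` is `Clone`, so streams can be opened from any task.
    let mut stream = ctrl.open_stream().await?;
    stream.write_all(b"ping").await?;
    stream.close().await?;

    ctrl.close().await?;
    Ok(())
}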

View File

@@ -30,7 +30,7 @@ mod frame;
pub(crate) mod connection;
pub use crate::connection::{into_stream, Connection, Control, Mode, Packet, Stream};
pub use crate::connection::{Connection, Control, ControlledConnection, Mode, Packet, Stream};
pub use crate::error::ConnectionError;
pub use crate::frame::{
header::{HeaderDecodeError, StreamId},

View File

@@ -19,7 +19,9 @@ use std::{
use tokio::net::{TcpListener, TcpStream};
use tokio::{net::TcpSocket, runtime::Runtime, task};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::{Config, Connection, ConnectionError, Mode, WindowUpdateMode};
use yamux::{
Config, Connection, ConnectionError, Control, ControlledConnection, Mode, WindowUpdateMode,
};
const PAYLOAD_SIZE: usize = 128 * 1024;
@@ -36,7 +38,7 @@ fn concurrent_streams() {
task::spawn(echo_server(server));
let mut ctrl = client.control().unwrap();
let (mut ctrl, client) = Control::new(client);
task::spawn(noop_server(client));
let result = (0..n_streams)
@@ -70,11 +72,11 @@ fn concurrent_streams() {
}
/// For each incoming stream of `c` echo back to the sender.
async fn echo_server<T>(c: Connection<T>) -> Result<(), ConnectionError>
async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
yamux::into_stream(c)
stream::poll_fn(|cx| c.poll_next_inbound(cx))
.try_for_each_concurrent(None, |mut stream| async move {
log::debug!("S: accepted new stream");
@@ -93,16 +95,15 @@ where
}
/// For each incoming stream, do nothing.
async fn noop_server<T>(c: Connection<T>)
async fn noop_server<T>(c: ControlledConnection<T>)
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
yamux::into_stream(c)
.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
c.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
}
/// Sends the given data on the provided stream, length-prefixed.

View File

@@ -47,7 +47,7 @@ fn prop_config_send_recv_single() {
let server = echo_server(server);
let client = async {
let control = client.control().unwrap();
let (control, client) = Control::new(client);
task::spawn(noop_server(client));
send_on_single_stream(control, msgs).await?;
@@ -78,7 +78,7 @@ fn prop_config_send_recv_multi() {
let server = echo_server(server);
let client = async {
let control = client.control().unwrap();
let (control, client) = Control::new(client);
task::spawn(noop_server(client));
send_on_separate_streams(control, msgs).await?;
@@ -107,7 +107,7 @@ fn prop_send_recv() {
let server = echo_server(server);
let client = async {
let control = client.control().unwrap();
let (control, client) = Control::new(client);
task::spawn(noop_server(client));
send_on_separate_streams(control, msgs).await?;
@@ -134,7 +134,7 @@ fn prop_max_streams() {
task::spawn(echo_server(server));
let mut control = client.control().unwrap();
let (mut control, client) = Control::new(client);
task::spawn(noop_server(client));
let mut v = Vec::new();
@@ -161,8 +161,9 @@ fn prop_send_recv_half_closed() {
// Server should be able to write on a stream shutdown by the client.
let server = async {
let mut first_stream =
server.next_stream().await?.ok_or(ConnectionError::Closed)?;
let mut server = stream::poll_fn(move |cx| server.poll_next_inbound(cx));
let mut first_stream = server.next().await.ok_or(ConnectionError::Closed)??;
task::spawn(noop_server(server));
@@ -176,7 +177,7 @@ fn prop_send_recv_half_closed() {
// Client should be able to read after shutting down the stream.
let client = async {
let mut control = client.control().unwrap();
let (mut control, client) = Control::new(client);
task::spawn(noop_server(client));
let mut stream = control.open_stream().await?;
@@ -242,7 +243,7 @@ fn write_deadlock() {
// Create and spawn a "client" that sends messages expected to be echoed
// by the server.
let client = Connection::new(client_endpoint, Config::default(), Mode::Client);
let mut ctrl = client.control().unwrap();
let (mut ctrl, client) = Control::new(client);
// Continuously advance the Yamux connection of the client in a background task.
pool.spawner()
@@ -346,11 +347,11 @@ async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
}
/// For each incoming stream of `c` echo back to the sender.
async fn echo_server<T>(c: Connection<T>) -> Result<(), ConnectionError>
async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
yamux::into_stream(c)
stream::poll_fn(|cx| c.poll_next_inbound(cx))
.try_for_each_concurrent(None, |mut stream| async move {
{
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
@@ -363,16 +364,12 @@ where
}
/// For each incoming stream, do nothing.
async fn noop_server<T>(c: Connection<T>)
where
T: AsyncRead + AsyncWrite + Unpin,
{
yamux::into_stream(c)
.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
async fn noop_server(c: impl Stream<Item = Result<yamux::Stream, yamux::ConnectionError>>) {
c.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
}
/// Send all messages, opening a new stream for each one.