mirror of
https://github.com/tlsnotary/rust-yamux.git
synced 2026-01-09 12:58:03 -05:00
Rewrite Control to be a layer on top of Connection
This commit is contained in:
@@ -13,7 +13,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use futures::{channel::mpsc, future, io::AsyncReadExt, prelude::*};
|
||||
use std::sync::Arc;
|
||||
use tokio::{runtime::Runtime, task};
|
||||
use yamux::{Config, Connection, Mode};
|
||||
use yamux::{Config, Connection, Control, Mode};
|
||||
|
||||
criterion_group!(benches, concurrent);
|
||||
criterion_main!(benches);
|
||||
@@ -92,7 +92,10 @@ async fn oneway(
|
||||
let server = async move {
|
||||
let mut connection = Connection::new(server, config(), Mode::Server);
|
||||
|
||||
while let Some(mut stream) = connection.next_stream().await.unwrap() {
|
||||
while let Some(Ok(mut stream)) = stream::poll_fn(|cx| connection.poll_next_inbound(cx))
|
||||
.next()
|
||||
.await
|
||||
{
|
||||
let tx = tx.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
@@ -113,8 +116,9 @@ async fn oneway(
|
||||
task::spawn(server);
|
||||
|
||||
let conn = Connection::new(client, config(), Mode::Client);
|
||||
let mut ctrl = conn.control().unwrap();
|
||||
task::spawn(yamux::into_stream(conn).for_each(|r| {
|
||||
let (mut ctrl, conn) = Control::new(conn);
|
||||
|
||||
task::spawn(conn.for_each(|r| {
|
||||
r.unwrap();
|
||||
future::ready(())
|
||||
}));
|
||||
|
||||
@@ -98,21 +98,15 @@ use crate::{
|
||||
frame::{self, Frame},
|
||||
Config, WindowUpdateMode, DEFAULT_CREDIT,
|
||||
};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future::{self, Either},
|
||||
prelude::*,
|
||||
sink::SinkExt,
|
||||
stream::Fuse,
|
||||
};
|
||||
use cleanup::Cleanup;
|
||||
use closing::Closing;
|
||||
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::task::Context;
|
||||
use std::{fmt, sync::Arc, task::Poll};
|
||||
|
||||
use crate::connection::cleanup::Cleanup;
|
||||
use crate::connection::closing::Closing;
|
||||
pub use control::Control;
|
||||
pub use control::{Control, ControlledConnection};
|
||||
pub use stream::{Packet, State, Stream};
|
||||
|
||||
/// Arbitrary limit of our internal command channels.
|
||||
@@ -169,39 +163,148 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a controller for this connection.
|
||||
pub fn control(&self) -> Result<Control> {
|
||||
match &self.inner {
|
||||
ConnectionState::Active(active) => Ok(active.control()),
|
||||
ConnectionState::Closed
|
||||
| ConnectionState::Closing { .. }
|
||||
| ConnectionState::Cleanup(_) => Err(ConnectionError::Closed),
|
||||
ConnectionState::Poisoned => unreachable!(),
|
||||
/// Poll for a new outbound stream.
|
||||
///
|
||||
/// This function will fail if the current state does not allow opening new outbound streams.
|
||||
pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
|
||||
loop {
|
||||
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
|
||||
ConnectionState::Active(mut active) => match active.new_outbound() {
|
||||
Ok(stream) => {
|
||||
self.inner = ConnectionState::Active(active);
|
||||
return Poll::Ready(Ok(stream));
|
||||
}
|
||||
Err(e) => {
|
||||
self.inner = ConnectionState::Cleanup(active.cleanup(e));
|
||||
continue;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Err(ConnectionError::Closed));
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Closing(inner);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Cleanup(mut inner) => match inner.poll_unpin(cx) {
|
||||
Poll::Ready(e) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Cleanup(inner);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closed => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Err(ConnectionError::Closed));
|
||||
}
|
||||
ConnectionState::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the next incoming stream, opened by the remote.
|
||||
/// Poll for the next inbound stream.
|
||||
///
|
||||
/// This must be called repeatedly in order to make progress.
|
||||
/// Once `Ok(None)` or `Err(_)` is returned the connection is
|
||||
/// considered closed and no further invocation of this method
|
||||
/// must be attempted.
|
||||
///
|
||||
/// # Cancellation
|
||||
///
|
||||
/// This function is cancellation-safe.
|
||||
pub async fn next_stream(&mut self) -> Result<Option<Stream>> {
|
||||
future::poll_fn(|cx| self.inner.poll_next(cx))
|
||||
.await
|
||||
.transpose()
|
||||
/// If this function returns `None`, the underlying connection is closed.
|
||||
pub fn poll_next_inbound(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
|
||||
loop {
|
||||
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
|
||||
ConnectionState::Active(mut active) => match active.poll(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
self.inner = ConnectionState::Active(active);
|
||||
return Poll::Ready(Some(Ok(stream)));
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.inner = ConnectionState::Cleanup(active.cleanup(e));
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Active(active);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closing(mut closing) => match closing.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Some(Err(e)));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Closing(closing);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
|
||||
Poll::Ready(ConnectionError::Closed) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Ready(other) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Some(Err(other)));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Cleanup(cleanup);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closed => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
ConnectionState::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Have the underlying connection make progress.
|
||||
///
|
||||
/// If this returns `Poll::Ready(None)`, the connection is closed and does no longer
|
||||
/// need to be polled.
|
||||
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
|
||||
self.inner.poll_next(cx)
|
||||
/// Close the connection.
|
||||
pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
loop {
|
||||
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
|
||||
ConnectionState::Active(active) => {
|
||||
self.inner = ConnectionState::Closing(active.close());
|
||||
}
|
||||
ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx)? {
|
||||
Poll::Ready(()) => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Closing(inner);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
|
||||
Poll::Ready(reason) => {
|
||||
log::warn!("Failure while closing connection: {}", reason);
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Cleanup(cleanup);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closed => {
|
||||
self.inner = ConnectionState::Closed;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
ConnectionState::Poisoned => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,7 +312,7 @@ impl<T> Drop for Connection<T> {
|
||||
fn drop(&mut self) {
|
||||
match &mut self.inner {
|
||||
ConnectionState::Active(active) => active.drop_all_streams(),
|
||||
ConnectionState::Closing { .. } => {}
|
||||
ConnectionState::Closing(_) => {}
|
||||
ConnectionState::Cleanup(_) => {}
|
||||
ConnectionState::Closed => {}
|
||||
ConnectionState::Poisoned => {}
|
||||
@@ -221,10 +324,7 @@ enum ConnectionState<T> {
|
||||
/// The connection is alive and healthy.
|
||||
Active(Active<T>),
|
||||
/// Our user requested to shutdown the connection, we are working on it.
|
||||
Closing {
|
||||
inner: Closing<T>,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
Closing(Closing<T>),
|
||||
/// An error occurred and we are cleaning up our resources.
|
||||
Cleanup(Cleanup),
|
||||
/// The connection is closed.
|
||||
@@ -237,7 +337,7 @@ impl<T> fmt::Debug for ConnectionState<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ConnectionState::Active(_) => write!(f, "Active"),
|
||||
ConnectionState::Closing { .. } => write!(f, "Closing"),
|
||||
ConnectionState::Closing(_) => write!(f, "Closing"),
|
||||
ConnectionState::Cleanup(_) => write!(f, "Cleanup"),
|
||||
ConnectionState::Closed => write!(f, "Closed"),
|
||||
ConnectionState::Poisoned => write!(f, "Poisoned"),
|
||||
@@ -245,86 +345,6 @@ impl<T> fmt::Debug for ConnectionState<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> ConnectionState<T> {
|
||||
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
|
||||
loop {
|
||||
match std::mem::replace(self, ConnectionState::Poisoned) {
|
||||
ConnectionState::Active(mut active) => {
|
||||
match active.control_receiver.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ControlCommand::OpenStream(reply))) => {
|
||||
active.on_open_stream(reply)?;
|
||||
|
||||
*self = ConnectionState::Active(active);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => {
|
||||
*self = ConnectionState::Closing {
|
||||
inner: active.close(),
|
||||
reply,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
debug_assert!(false, "Only closed during shutdown")
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
match active.poll(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
*self = ConnectionState::Active(active);
|
||||
return Poll::Ready(Some(Ok(stream)));
|
||||
}
|
||||
Poll::Ready(Err(e)) => *self = ConnectionState::Cleanup(active.cleanup(e)),
|
||||
Poll::Pending => {
|
||||
*self = ConnectionState::Active(active);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionState::Closing {
|
||||
inner: mut closing,
|
||||
reply,
|
||||
} => match closing.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let _ = reply.send(());
|
||||
*self = ConnectionState::Closed;
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
let _ = reply.send(());
|
||||
*self = ConnectionState::Closed;
|
||||
return Poll::Ready(Some(Err(e)));
|
||||
}
|
||||
Poll::Pending => {
|
||||
*self = ConnectionState::Closing {
|
||||
inner: closing,
|
||||
reply,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
|
||||
Poll::Ready(ConnectionError::Closed) => {
|
||||
*self = ConnectionState::Closed;
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Ready(other) => {
|
||||
*self = ConnectionState::Closed;
|
||||
return Poll::Ready(Some(Err(other)));
|
||||
}
|
||||
Poll::Pending => {
|
||||
*self = ConnectionState::Cleanup(cleanup);
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
ConnectionState::Closed => return Poll::Ready(None),
|
||||
ConnectionState::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Yamux connection object.
|
||||
///
|
||||
/// Wraps the underlying I/O resource and makes progress via its
|
||||
@@ -337,8 +357,6 @@ struct Active<T> {
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
next_id: u32,
|
||||
streams: IntMap<StreamId, Stream>,
|
||||
control_sender: mpsc::Sender<ControlCommand>,
|
||||
control_receiver: mpsc::Receiver<ControlCommand>,
|
||||
stream_sender: mpsc::Sender<StreamCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
garbage: Vec<StreamId>,
|
||||
@@ -346,15 +364,6 @@ struct Active<T> {
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
}
|
||||
|
||||
/// `Control` to `Connection` commands.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ControlCommand {
|
||||
/// Open a new stream to the remote end.
|
||||
OpenStream(oneshot::Sender<Result<Stream>>),
|
||||
/// Close the whole connection.
|
||||
CloseConnection(oneshot::Sender<()>),
|
||||
}
|
||||
|
||||
/// `Stream` to `Connection` commands.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum StreamCommand {
|
||||
@@ -410,7 +419,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
let id = Id::random();
|
||||
log::debug!("new connection: {} ({:?})", id, mode);
|
||||
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
|
||||
let (control_sender, control_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
|
||||
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
|
||||
Active {
|
||||
id,
|
||||
@@ -418,8 +426,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
config: Arc::new(cfg),
|
||||
socket,
|
||||
streams: IntMap::default(),
|
||||
control_sender,
|
||||
control_receiver,
|
||||
stream_sender,
|
||||
stream_receiver,
|
||||
next_id: match mode {
|
||||
@@ -431,18 +437,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn control(&self) -> Control {
|
||||
Control::new(self.control_sender.clone())
|
||||
}
|
||||
|
||||
/// Gracefully close the connection to the remote.
|
||||
fn close(self) -> Closing<T> {
|
||||
Closing::new(
|
||||
self.control_receiver,
|
||||
self.stream_receiver,
|
||||
self.pending_frames,
|
||||
self.socket,
|
||||
)
|
||||
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
|
||||
}
|
||||
|
||||
/// Cleanup all our resources.
|
||||
@@ -451,7 +448,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
|
||||
self.drop_all_streams();
|
||||
|
||||
Cleanup::new(self.control_receiver, self.stream_receiver, error)
|
||||
Cleanup::new(self.stream_receiver, error)
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
|
||||
@@ -503,11 +500,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_open_stream(&mut self, reply: oneshot::Sender<Result<Stream>>) -> Result<()> {
|
||||
fn new_outbound(&mut self) -> Result<Stream> {
|
||||
if self.streams.len() >= self.config.max_num_streams {
|
||||
log::error!("{}: maximum number of streams reached", self.id);
|
||||
let _ = reply.send(Err(ConnectionError::TooManyStreams));
|
||||
return Ok(());
|
||||
return Err(ConnectionError::TooManyStreams);
|
||||
}
|
||||
|
||||
log::trace!("{}: creating new outbound stream", self.id);
|
||||
@@ -533,21 +529,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
stream
|
||||
};
|
||||
|
||||
if reply.send(Ok(stream.clone())).is_ok() {
|
||||
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
|
||||
self.streams.insert(id, stream);
|
||||
} else {
|
||||
log::debug!("{}: open stream {} has been cancelled", self.id, id);
|
||||
if extra_credit > 0 {
|
||||
let mut header = Header::data(id, 0);
|
||||
header.rst();
|
||||
let frame = Frame::new(header);
|
||||
log::trace!("{}/{}: sending reset", self.id, id);
|
||||
self.pending_frames.push_back(frame.into());
|
||||
}
|
||||
}
|
||||
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
|
||||
self.streams.insert(id, stream.clone());
|
||||
|
||||
Ok(())
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
|
||||
@@ -959,11 +944,3 @@ impl<T> Active<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn a Yamux [`Connection`] into a [`futures::Stream`].
|
||||
pub fn into_stream<T>(mut c: Connection<T>) -> impl futures::stream::Stream<Item = Result<Stream>>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
futures::stream::poll_fn(move |cx| c.poll_next(cx))
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::connection::{ControlCommand, StreamCommand};
|
||||
use crate::connection::StreamCommand;
|
||||
use crate::ConnectionError;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{ready, StreamExt};
|
||||
@@ -10,20 +10,17 @@ use std::task::{Context, Poll};
|
||||
#[must_use]
|
||||
pub struct Cleanup {
|
||||
state: State,
|
||||
control_receiver: mpsc::Receiver<ControlCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
error: Option<ConnectionError>,
|
||||
}
|
||||
|
||||
impl Cleanup {
|
||||
pub(crate) fn new(
|
||||
control_receiver: mpsc::Receiver<ControlCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
error: ConnectionError,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: State::ClosingControlReceiver,
|
||||
control_receiver,
|
||||
state: State::ClosingStreamReceiver,
|
||||
stream_receiver,
|
||||
error: Some(error),
|
||||
}
|
||||
@@ -38,21 +35,6 @@ impl Future for Cleanup {
|
||||
|
||||
loop {
|
||||
match this.state {
|
||||
State::ClosingControlReceiver => {
|
||||
this.control_receiver.close();
|
||||
this.state = State::DrainingControlReceiver;
|
||||
}
|
||||
State::DrainingControlReceiver => {
|
||||
match ready!(this.control_receiver.poll_next_unpin(cx)) {
|
||||
Some(ControlCommand::OpenStream(reply)) => {
|
||||
let _ = reply.send(Err(ConnectionError::Closed));
|
||||
}
|
||||
Some(ControlCommand::CloseConnection(reply)) => {
|
||||
let _ = reply.send(());
|
||||
}
|
||||
None => this.state = State::ClosingStreamReceiver,
|
||||
}
|
||||
}
|
||||
State::ClosingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
this.state = State::DrainingStreamReceiver;
|
||||
@@ -81,8 +63,6 @@ impl Future for Cleanup {
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum State {
|
||||
ClosingControlReceiver,
|
||||
DrainingControlReceiver,
|
||||
ClosingStreamReceiver,
|
||||
DrainingStreamReceiver,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::connection::Result;
|
||||
use crate::connection::{ControlCommand, StreamCommand};
|
||||
use crate::connection::StreamCommand;
|
||||
use crate::frame;
|
||||
use crate::frame::Frame;
|
||||
use crate::{frame, ConnectionError};
|
||||
use futures::channel::mpsc;
|
||||
use futures::stream::Fuse;
|
||||
use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
|
||||
@@ -14,7 +14,6 @@ use std::task::{Context, Poll};
|
||||
#[must_use]
|
||||
pub struct Closing<T> {
|
||||
state: State,
|
||||
control_receiver: mpsc::Receiver<ControlCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
@@ -25,14 +24,12 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
control_receiver: mpsc::Receiver<ControlCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: State::ClosingControlReceiver,
|
||||
control_receiver,
|
||||
state: State::ClosingStreamReceiver,
|
||||
stream_receiver,
|
||||
pending_frames,
|
||||
socket,
|
||||
@@ -51,21 +48,6 @@ where
|
||||
|
||||
loop {
|
||||
match this.state {
|
||||
State::ClosingControlReceiver => {
|
||||
this.control_receiver.close();
|
||||
this.state = State::DrainingControlReceiver;
|
||||
}
|
||||
State::DrainingControlReceiver => {
|
||||
match ready!(this.control_receiver.poll_next_unpin(cx)) {
|
||||
Some(ControlCommand::OpenStream(reply)) => {
|
||||
let _ = reply.send(Err(ConnectionError::Closed));
|
||||
}
|
||||
Some(ControlCommand::CloseConnection(reply)) => {
|
||||
let _ = reply.send(());
|
||||
}
|
||||
None => this.state = State::ClosingStreamReceiver,
|
||||
}
|
||||
}
|
||||
State::ClosingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
this.state = State::DrainingStreamReceiver;
|
||||
@@ -106,9 +88,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub enum State {
|
||||
ClosingControlReceiver,
|
||||
DrainingControlReceiver,
|
||||
enum State {
|
||||
ClosingStreamReceiver,
|
||||
DrainingStreamReceiver,
|
||||
SendingTermFrame,
|
||||
|
||||
@@ -8,56 +8,40 @@
|
||||
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
|
||||
// at https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::ControlCommand;
|
||||
use crate::{error::ConnectionError, Stream};
|
||||
use crate::connection::MAX_COMMAND_BACKLOG;
|
||||
use crate::{error::ConnectionError, Connection, Stream};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
ready,
|
||||
};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
type Result<T> = std::result::Result<T, ConnectionError>;
|
||||
|
||||
/// The Yamux `Connection` controller.
|
||||
/// A Yamux [`Connection`] controller.
|
||||
///
|
||||
/// While a Yamux connection makes progress via its `next_stream` method,
|
||||
/// this controller can be used to concurrently direct the connection,
|
||||
/// e.g. to open a new stream to the remote or to close the connection.
|
||||
/// This presents an alternative API for using a yamux [`Connection`].
|
||||
///
|
||||
/// The possible operations are implemented as async methods and redundantly
|
||||
/// as poll-based variants which may be useful inside of other poll based
|
||||
/// environments such as certain trait implementations.
|
||||
#[derive(Debug)]
|
||||
/// A [`Control`] communicates with a [`ControlledConnection`] via a channel. This allows
|
||||
/// a [`Control`] to be cloned and shared between tasks and threads.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Control {
|
||||
/// Command channel to `Connection`.
|
||||
/// Command channel to [`ControlledConnection`].
|
||||
sender: mpsc::Sender<ControlCommand>,
|
||||
/// Pending state of `poll_open_stream`.
|
||||
pending_open: Option<oneshot::Receiver<Result<Stream>>>,
|
||||
/// Pending state of `poll_close`.
|
||||
pending_close: Option<oneshot::Receiver<()>>,
|
||||
}
|
||||
|
||||
impl Clone for Control {
|
||||
fn clone(&self) -> Self {
|
||||
Control {
|
||||
sender: self.sender.clone(),
|
||||
pending_open: None,
|
||||
pending_close: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Control {
|
||||
pub(crate) fn new(sender: mpsc::Sender<ControlCommand>) -> Self {
|
||||
Control {
|
||||
sender,
|
||||
pending_open: None,
|
||||
pending_close: None,
|
||||
}
|
||||
pub fn new<T>(connection: Connection<T>) -> (Self, ControlledConnection<T>) {
|
||||
let (sender, receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
|
||||
|
||||
let control = Control { sender };
|
||||
let connection = ControlledConnection {
|
||||
state: State::Idle(connection),
|
||||
commands: receiver,
|
||||
};
|
||||
|
||||
(control, connection)
|
||||
}
|
||||
|
||||
/// Open a new stream to the remote.
|
||||
@@ -84,66 +68,181 @@ impl Control {
|
||||
let _ = rx.await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// [`Poll`] based alternative to [`Control::open_stream`].
|
||||
pub fn poll_open_stream(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Stream>> {
|
||||
/// Wraps a [`Connection`] which can be controlled with a [`Control`].
|
||||
pub struct ControlledConnection<T> {
|
||||
state: State<T>,
|
||||
commands: mpsc::Receiver<ControlCommand>,
|
||||
}
|
||||
|
||||
impl<T> ControlledConnection<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
|
||||
loop {
|
||||
match self.pending_open.take() {
|
||||
None => {
|
||||
ready!(self.sender.poll_ready(cx)?);
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.sender.start_send(ControlCommand::OpenStream(tx))?;
|
||||
self.pending_open = Some(rx)
|
||||
}
|
||||
Some(mut rx) => match rx.poll_unpin(cx)? {
|
||||
Poll::Ready(result) => return Poll::Ready(result),
|
||||
Poll::Pending => {
|
||||
self.pending_open = Some(rx);
|
||||
return Poll::Pending;
|
||||
match std::mem::replace(&mut self.state, State::Poisoned) {
|
||||
State::Idle(mut connection) => {
|
||||
match connection.poll_next_inbound(cx) {
|
||||
Poll::Ready(maybe_stream) => {
|
||||
self.state = State::Idle(connection);
|
||||
return Poll::Ready(maybe_stream);
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Abort an ongoing open stream operation started by `poll_open_stream`.
|
||||
pub fn abort_open_stream(&mut self) {
|
||||
self.pending_open = None
|
||||
}
|
||||
|
||||
/// [`Poll`] based alternative to [`Control::close`].
|
||||
pub fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
|
||||
loop {
|
||||
match self.pending_close.take() {
|
||||
None => {
|
||||
if ready!(self.sender.poll_ready(cx)).is_err() {
|
||||
// The receiver is closed which means the connection is already closed.
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = self.sender.start_send(ControlCommand::CloseConnection(tx)) {
|
||||
if e.is_full() {
|
||||
match self.commands.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ControlCommand::OpenStream(reply))) => {
|
||||
self.state = State::OpeningNewStream { reply, connection };
|
||||
continue;
|
||||
}
|
||||
debug_assert!(e.is_disconnected());
|
||||
// The receiver is closed which means the connection is already closed.
|
||||
return Poll::Ready(Ok(()));
|
||||
Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => {
|
||||
self.commands.close();
|
||||
|
||||
self.state = State::Closing {
|
||||
reply: Some(reply),
|
||||
inner: Closing::DrainingControlCommands { connection },
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Last `Control` sender was dropped, close te connection.
|
||||
self.state = State::Closing {
|
||||
reply: None,
|
||||
inner: Closing::ClosingConnection { connection },
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
self.pending_close = Some(rx)
|
||||
|
||||
self.state = State::Idle(connection);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Some(mut rx) => match rx.poll_unpin(cx) {
|
||||
Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(oneshot::Canceled)) => {
|
||||
// A dropped `oneshot::Sender` means the `Connection` is gone,
|
||||
// which is `Ok`ay for us here.
|
||||
return Poll::Ready(Ok(()));
|
||||
State::OpeningNewStream {
|
||||
reply,
|
||||
mut connection,
|
||||
} => match connection.poll_new_outbound(cx) {
|
||||
Poll::Ready(stream) => {
|
||||
let _ = reply.send(stream);
|
||||
|
||||
self.state = State::Idle(connection);
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.pending_close = Some(rx);
|
||||
self.state = State::OpeningNewStream { reply, connection };
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
State::Closing {
|
||||
reply,
|
||||
inner: Closing::DrainingControlCommands { connection },
|
||||
} => match self.commands.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ControlCommand::OpenStream(new_reply))) => {
|
||||
let _ = new_reply.send(Err(ConnectionError::Closed));
|
||||
|
||||
self.state = State::Closing {
|
||||
reply,
|
||||
inner: Closing::DrainingControlCommands { connection },
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(ControlCommand::CloseConnection(new_reply))) => {
|
||||
let _ = new_reply.send(());
|
||||
|
||||
self.state = State::Closing {
|
||||
reply,
|
||||
inner: Closing::DrainingControlCommands { connection },
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
self.state = State::Closing {
|
||||
reply,
|
||||
inner: Closing::ClosingConnection { connection },
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Closing {
|
||||
reply,
|
||||
inner: Closing::DrainingControlCommands { connection },
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
State::Closing {
|
||||
reply,
|
||||
inner: Closing::ClosingConnection { mut connection },
|
||||
} => match connection.poll_close(cx) {
|
||||
Poll::Ready(Ok(())) | Poll::Ready(Err(ConnectionError::Closed)) => {
|
||||
if let Some(reply) = reply {
|
||||
let _ = reply.send(());
|
||||
}
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Ready(Err(other)) => {
|
||||
if let Some(reply) = reply {
|
||||
let _ = reply.send(());
|
||||
}
|
||||
return Poll::Ready(Some(Err(other)));
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.state = State::Closing {
|
||||
reply,
|
||||
inner: Closing::ClosingConnection { connection },
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
State::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ControlCommand {
|
||||
/// Open a new stream to the remote end.
|
||||
OpenStream(oneshot::Sender<Result<Stream>>),
|
||||
/// Close the whole connection.
|
||||
CloseConnection(oneshot::Sender<()>),
|
||||
}
|
||||
|
||||
/// The state of a [`ControlledConnection`].
|
||||
enum State<T> {
|
||||
Idle(Connection<T>),
|
||||
OpeningNewStream {
|
||||
reply: oneshot::Sender<Result<Stream>>,
|
||||
connection: Connection<T>,
|
||||
},
|
||||
Closing {
|
||||
/// A channel to the [`Control`] in case the close was requested. `None` if we are closing because the last [`Control`] was dropped.
|
||||
reply: Option<oneshot::Sender<()>>,
|
||||
inner: Closing<T>,
|
||||
},
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
/// A sub-state of our larger state machine for a [`ControlledConnection`].
|
||||
///
|
||||
/// Closing connection involves two steps:
|
||||
///
|
||||
/// 1. Draining and answered all remaining [`ControlCommands`].
|
||||
/// 1. Closing the underlying [`Connection`].
|
||||
enum Closing<T> {
|
||||
DrainingControlCommands { connection: Connection<T> },
|
||||
ClosingConnection { connection: Connection<T> },
|
||||
}
|
||||
|
||||
impl<T> futures::Stream for ControlledConnection<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Item = Result<Stream>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.get_mut().poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ mod frame;
|
||||
|
||||
pub(crate) mod connection;
|
||||
|
||||
pub use crate::connection::{into_stream, Connection, Control, Mode, Packet, Stream};
|
||||
pub use crate::connection::{Connection, Control, ControlledConnection, Mode, Packet, Stream};
|
||||
pub use crate::error::ConnectionError;
|
||||
pub use crate::frame::{
|
||||
header::{HeaderDecodeError, StreamId},
|
||||
|
||||
@@ -19,7 +19,9 @@ use std::{
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::{net::TcpSocket, runtime::Runtime, task};
|
||||
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
|
||||
use yamux::{Config, Connection, ConnectionError, Mode, WindowUpdateMode};
|
||||
use yamux::{
|
||||
Config, Connection, ConnectionError, Control, ControlledConnection, Mode, WindowUpdateMode,
|
||||
};
|
||||
|
||||
const PAYLOAD_SIZE: usize = 128 * 1024;
|
||||
|
||||
@@ -36,7 +38,7 @@ fn concurrent_streams() {
|
||||
|
||||
task::spawn(echo_server(server));
|
||||
|
||||
let mut ctrl = client.control().unwrap();
|
||||
let (mut ctrl, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
|
||||
let result = (0..n_streams)
|
||||
@@ -70,11 +72,11 @@ fn concurrent_streams() {
|
||||
}
|
||||
|
||||
/// For each incoming stream of `c` echo back to the sender.
|
||||
async fn echo_server<T>(c: Connection<T>) -> Result<(), ConnectionError>
|
||||
async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
yamux::into_stream(c)
|
||||
stream::poll_fn(|cx| c.poll_next_inbound(cx))
|
||||
.try_for_each_concurrent(None, |mut stream| async move {
|
||||
log::debug!("S: accepted new stream");
|
||||
|
||||
@@ -93,16 +95,15 @@ where
|
||||
}
|
||||
|
||||
/// For each incoming stream, do nothing.
|
||||
async fn noop_server<T>(c: Connection<T>)
|
||||
async fn noop_server<T>(c: ControlledConnection<T>)
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
yamux::into_stream(c)
|
||||
.for_each(|maybe_stream| {
|
||||
drop(maybe_stream);
|
||||
future::ready(())
|
||||
})
|
||||
.await;
|
||||
c.for_each(|maybe_stream| {
|
||||
drop(maybe_stream);
|
||||
future::ready(())
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Sends the given data on the provided stream, length-prefixed.
|
||||
|
||||
@@ -47,7 +47,7 @@ fn prop_config_send_recv_single() {
|
||||
|
||||
let server = echo_server(server);
|
||||
let client = async {
|
||||
let control = client.control().unwrap();
|
||||
let (control, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
send_on_single_stream(control, msgs).await?;
|
||||
|
||||
@@ -78,7 +78,7 @@ fn prop_config_send_recv_multi() {
|
||||
|
||||
let server = echo_server(server);
|
||||
let client = async {
|
||||
let control = client.control().unwrap();
|
||||
let (control, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
send_on_separate_streams(control, msgs).await?;
|
||||
|
||||
@@ -107,7 +107,7 @@ fn prop_send_recv() {
|
||||
|
||||
let server = echo_server(server);
|
||||
let client = async {
|
||||
let control = client.control().unwrap();
|
||||
let (control, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
send_on_separate_streams(control, msgs).await?;
|
||||
|
||||
@@ -134,7 +134,7 @@ fn prop_max_streams() {
|
||||
|
||||
task::spawn(echo_server(server));
|
||||
|
||||
let mut control = client.control().unwrap();
|
||||
let (mut control, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
|
||||
let mut v = Vec::new();
|
||||
@@ -161,8 +161,9 @@ fn prop_send_recv_half_closed() {
|
||||
|
||||
// Server should be able to write on a stream shutdown by the client.
|
||||
let server = async {
|
||||
let mut first_stream =
|
||||
server.next_stream().await?.ok_or(ConnectionError::Closed)?;
|
||||
let mut server = stream::poll_fn(move |cx| server.poll_next_inbound(cx));
|
||||
|
||||
let mut first_stream = server.next().await.ok_or(ConnectionError::Closed)??;
|
||||
|
||||
task::spawn(noop_server(server));
|
||||
|
||||
@@ -176,7 +177,7 @@ fn prop_send_recv_half_closed() {
|
||||
|
||||
// Client should be able to read after shutting down the stream.
|
||||
let client = async {
|
||||
let mut control = client.control().unwrap();
|
||||
let (mut control, client) = Control::new(client);
|
||||
task::spawn(noop_server(client));
|
||||
|
||||
let mut stream = control.open_stream().await?;
|
||||
@@ -242,7 +243,7 @@ fn write_deadlock() {
|
||||
// Create and spawn a "client" that sends messages expected to be echoed
|
||||
// by the server.
|
||||
let client = Connection::new(client_endpoint, Config::default(), Mode::Client);
|
||||
let mut ctrl = client.control().unwrap();
|
||||
let (mut ctrl, client) = Control::new(client);
|
||||
|
||||
// Continuously advance the Yamux connection of the client in a background task.
|
||||
pool.spawner()
|
||||
@@ -346,11 +347,11 @@ async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
|
||||
}
|
||||
|
||||
/// For each incoming stream of `c` echo back to the sender.
|
||||
async fn echo_server<T>(c: Connection<T>) -> Result<(), ConnectionError>
|
||||
async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
yamux::into_stream(c)
|
||||
stream::poll_fn(|cx| c.poll_next_inbound(cx))
|
||||
.try_for_each_concurrent(None, |mut stream| async move {
|
||||
{
|
||||
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
|
||||
@@ -363,16 +364,12 @@ where
|
||||
}
|
||||
|
||||
/// For each incoming stream, do nothing.
|
||||
async fn noop_server<T>(c: Connection<T>)
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
yamux::into_stream(c)
|
||||
.for_each(|maybe_stream| {
|
||||
drop(maybe_stream);
|
||||
future::ready(())
|
||||
})
|
||||
.await;
|
||||
async fn noop_server(c: impl Stream<Item = Result<yamux::Stream, yamux::ConnectionError>>) {
|
||||
c.for_each(|maybe_stream| {
|
||||
drop(maybe_stream);
|
||||
future::ready(())
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Send all messages, opening a new stream for each one.
|
||||
|
||||
Reference in New Issue
Block a user