mirror of
https://github.com/tlsnotary/rust-yamux.git
synced 2026-01-09 12:58:03 -05:00
feat: don't allow to open more than 256 unacknowledged streams (#153)
Co-authored-by: Max Inden <mail@max-inden.de>
This commit is contained in:
256
test-harness/tests/ack_backlog.rs
Normal file
256
test-harness/tests/ack_backlog.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::future::FutureExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::{future, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt};
|
||||
use std::future::Future;
|
||||
use std::mem;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use test_harness::bind;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use yamux::{Config, Connection, ConnectionError, Mode, Stream};
|
||||
|
||||
#[tokio::test]
|
||||
async fn honours_ack_backlog_of_256() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let (listener, address) = bind(None).await.expect("bind");
|
||||
|
||||
let server = async {
|
||||
let socket = listener.accept().await.expect("accept").0.compat();
|
||||
let connection = Connection::new(socket, Config::default(), Mode::Server);
|
||||
|
||||
Server::new(connection, rx).await
|
||||
};
|
||||
|
||||
let client = async {
|
||||
let socket = TcpStream::connect(address).await.expect("connect").compat();
|
||||
let connection = Connection::new(socket, Config::default(), Mode::Client);
|
||||
|
||||
Client::new(connection, tx).await
|
||||
};
|
||||
|
||||
let (server_processed, client_processed) = future::try_join(server, client).await.unwrap();
|
||||
|
||||
assert_eq!(server_processed, 257);
|
||||
assert_eq!(client_processed, 257);
|
||||
}
|
||||
|
||||
enum Server<T> {
|
||||
Idle {
|
||||
connection: Connection<T>,
|
||||
trigger: oneshot::Receiver<()>,
|
||||
},
|
||||
Accepting {
|
||||
connection: Connection<T>,
|
||||
worker_streams: FuturesUnordered<BoxFuture<'static, yamux::Result<()>>>,
|
||||
streams_processed: usize,
|
||||
},
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
impl<T> Server<T> {
|
||||
fn new(connection: Connection<T>, trigger: oneshot::Receiver<()>) -> Self {
|
||||
Server::Idle {
|
||||
connection,
|
||||
trigger,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Server<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = yamux::Result<usize>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match mem::replace(this, Server::Poisoned) {
|
||||
Server::Idle {
|
||||
mut trigger,
|
||||
connection,
|
||||
} => match trigger.poll_unpin(cx) {
|
||||
Poll::Ready(_) => {
|
||||
*this = Server::Accepting {
|
||||
connection,
|
||||
worker_streams: Default::default(),
|
||||
streams_processed: 0,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
*this = Server::Idle {
|
||||
trigger,
|
||||
connection,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
Server::Accepting {
|
||||
mut connection,
|
||||
mut streams_processed,
|
||||
mut worker_streams,
|
||||
} => {
|
||||
match connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(stream)) => {
|
||||
worker_streams.push(pong_ping(stream).boxed());
|
||||
*this = Server::Accepting {
|
||||
connection,
|
||||
streams_processed,
|
||||
worker_streams,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(Ok(streams_processed));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
match worker_streams.poll_next_unpin(cx)? {
|
||||
Poll::Ready(Some(())) => {
|
||||
streams_processed += 1;
|
||||
*this = Server::Accepting {
|
||||
connection,
|
||||
streams_processed,
|
||||
worker_streams,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) | Poll::Pending => {}
|
||||
}
|
||||
|
||||
*this = Server::Accepting {
|
||||
connection,
|
||||
streams_processed,
|
||||
worker_streams,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
Server::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Client<T> {
|
||||
connection: Connection<T>,
|
||||
worker_streams: FuturesUnordered<BoxFuture<'static, yamux::Result<()>>>,
|
||||
trigger: Option<oneshot::Sender<()>>,
|
||||
streams_processed: usize,
|
||||
}
|
||||
|
||||
impl<T> Client<T> {
|
||||
fn new(connection: Connection<T>, trigger: oneshot::Sender<()>) -> Self {
|
||||
Self {
|
||||
connection,
|
||||
trigger: Some(trigger),
|
||||
worker_streams: FuturesUnordered::default(),
|
||||
streams_processed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Client<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = yamux::Result<usize>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
// First, try to open 256 streams
|
||||
if this.worker_streams.len() < 256 && this.streams_processed == 0 {
|
||||
match this.connection.poll_new_outbound(cx)? {
|
||||
Poll::Ready(stream) => {
|
||||
this.worker_streams.push(ping_pong(stream).boxed());
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
panic!("Should be able to open 256 streams without yielding")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if this.worker_streams.len() == 256 && this.streams_processed == 0 {
|
||||
let poll_result = this.connection.poll_new_outbound(cx);
|
||||
|
||||
match (poll_result, this.trigger.take()) {
|
||||
(Poll::Pending, Some(trigger)) => {
|
||||
// This is what we want, our task gets parked because we have hit the limit.
|
||||
// Tell the server to start processing streams and wait until we get woken.
|
||||
|
||||
trigger.send(()).unwrap();
|
||||
return Poll::Pending;
|
||||
}
|
||||
(Poll::Ready(stream), None) => {
|
||||
// We got woken because the server has started to acknowledge streams.
|
||||
this.worker_streams.push(ping_pong(stream.unwrap()).boxed());
|
||||
continue;
|
||||
}
|
||||
(Poll::Ready(_), Some(_)) => {
|
||||
panic!("should not be able to open stream if server hasn't acknowledged existing streams")
|
||||
}
|
||||
(Poll::Pending, None) => {}
|
||||
}
|
||||
}
|
||||
|
||||
match this.worker_streams.poll_next_unpin(cx)? {
|
||||
Poll::Ready(Some(())) => {
|
||||
this.streams_processed += 1;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) if this.streams_processed > 0 => {
|
||||
return Poll::Ready(Ok(this.streams_processed));
|
||||
}
|
||||
Poll::Ready(None) | Poll::Pending => {}
|
||||
}
|
||||
|
||||
// Allow the connection to make progress
|
||||
match this.connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(_)) => {
|
||||
panic!("server never opens stream")
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(Ok(this.streams_processed));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ping_pong(mut stream: Stream) -> Result<(), ConnectionError> {
|
||||
let mut buffer = [0u8; 4];
|
||||
stream.write_all(b"ping").await?;
|
||||
stream.read_exact(&mut buffer).await?;
|
||||
|
||||
assert_eq!(&buffer, b"pong");
|
||||
|
||||
stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pong_ping(mut stream: Stream) -> Result<(), ConnectionError> {
|
||||
let mut buffer = [0u8; 4];
|
||||
stream.write_all(b"pong").await?;
|
||||
stream.read_exact(&mut buffer).await?;
|
||||
|
||||
assert_eq!(&buffer, b"ping");
|
||||
|
||||
stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
243
test-harness/tests/ack_timing.rs
Normal file
243
test-harness/tests/ack_timing.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use futures::future::BoxFuture;
|
||||
use futures::future::FutureExt;
|
||||
use futures::{future, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use std::future::Future;
|
||||
use std::mem;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use test_harness::bind;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
use yamux::{Config, Connection, ConnectionError, Mode, Stream};
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_is_acknowledged_on_first_use() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (listener, address) = bind(None).await.expect("bind");
|
||||
|
||||
let server = async {
|
||||
let socket = listener.accept().await.expect("accept").0.compat();
|
||||
let connection = Connection::new(socket, Config::default(), Mode::Server);
|
||||
|
||||
Server::new(connection).await
|
||||
};
|
||||
|
||||
let client = async {
|
||||
let socket = TcpStream::connect(address).await.expect("connect").compat();
|
||||
let connection = Connection::new(socket, Config::default(), Mode::Client);
|
||||
|
||||
Client::new(connection).await
|
||||
};
|
||||
|
||||
let ((), ()) = future::try_join(server, client).await.unwrap();
|
||||
}
|
||||
|
||||
enum Server<T> {
|
||||
Accepting {
|
||||
connection: Connection<T>,
|
||||
},
|
||||
Working {
|
||||
connection: Connection<T>,
|
||||
stream: BoxFuture<'static, yamux::Result<()>>,
|
||||
},
|
||||
Idle {
|
||||
connection: Connection<T>,
|
||||
},
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
impl<T> Server<T> {
|
||||
fn new(connection: Connection<T>) -> Self {
|
||||
Server::Accepting { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Server<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = yamux::Result<()>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match mem::replace(this, Self::Poisoned) {
|
||||
Self::Accepting { mut connection } => match connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(stream)) => {
|
||||
*this = Self::Working {
|
||||
connection,
|
||||
stream: pong_ping(stream).boxed(),
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
panic!("connection closed before receiving a new stream")
|
||||
}
|
||||
Poll::Pending => {
|
||||
*this = Self::Accepting { connection };
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
Self::Working {
|
||||
mut connection,
|
||||
mut stream,
|
||||
} => {
|
||||
match stream.poll_unpin(cx)? {
|
||||
Poll::Ready(()) => {
|
||||
*this = Self::Idle { connection };
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
match connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(_)) => {
|
||||
panic!("not expecting new stream");
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
panic!("connection closed before stream completed")
|
||||
}
|
||||
Poll::Pending => {
|
||||
*this = Self::Working { connection, stream };
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::Idle { mut connection } => match connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(_)) => {
|
||||
panic!("not expecting new stream");
|
||||
}
|
||||
Poll::Ready(None) => return Poll::Ready(Ok(())),
|
||||
Poll::Pending => {
|
||||
*this = Self::Idle { connection };
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
Self::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Client<T> {
|
||||
Opening {
|
||||
connection: Connection<T>,
|
||||
},
|
||||
Working {
|
||||
connection: Connection<T>,
|
||||
stream: BoxFuture<'static, yamux::Result<()>>,
|
||||
},
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
impl<T> Client<T> {
|
||||
fn new(connection: Connection<T>) -> Self {
|
||||
Self::Opening { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Client<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = yamux::Result<()>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match mem::replace(this, Self::Poisoned) {
|
||||
Self::Opening { mut connection } => match connection.poll_new_outbound(cx)? {
|
||||
Poll::Ready(stream) => {
|
||||
*this = Self::Working {
|
||||
connection,
|
||||
stream: ping_pong(stream).boxed(),
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
*this = Self::Opening { connection };
|
||||
return Poll::Pending;
|
||||
}
|
||||
},
|
||||
Self::Working {
|
||||
mut connection,
|
||||
mut stream,
|
||||
} => {
|
||||
match stream.poll_unpin(cx)? {
|
||||
Poll::Ready(()) => {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
match connection.poll_next_inbound(cx)? {
|
||||
Poll::Ready(Some(_)) => {
|
||||
panic!("not expecting new stream");
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
panic!("connection closed before stream completed")
|
||||
}
|
||||
Poll::Pending => {
|
||||
*this = Self::Working { connection, stream };
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::Poisoned => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handler for the **outbound** stream on the client.
|
||||
///
|
||||
/// Initially, the stream is not acknowledged. The server will only acknowledge the stream with the first frame.
|
||||
async fn ping_pong(mut stream: Stream) -> Result<(), ConnectionError> {
|
||||
assert!(
|
||||
stream.is_pending_ack(),
|
||||
"newly returned stream should not be acknowledged"
|
||||
);
|
||||
|
||||
let mut buffer = [0u8; 4];
|
||||
stream.write_all(b"ping").await?;
|
||||
stream.read_exact(&mut buffer).await?;
|
||||
|
||||
assert!(
|
||||
!stream.is_pending_ack(),
|
||||
"stream should be acknowledged once we received the first data"
|
||||
);
|
||||
assert_eq!(&buffer, b"pong");
|
||||
|
||||
stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handler for the **inbound** stream on the server.
|
||||
///
|
||||
/// Initially, the stream is not acknowledged. We only include the ACK flag in the first frame.
|
||||
async fn pong_ping(mut stream: Stream) -> Result<(), ConnectionError> {
|
||||
assert!(
|
||||
stream.is_pending_ack(),
|
||||
"before sending anything we should not have acknowledged the stream to the remote"
|
||||
);
|
||||
|
||||
let mut buffer = [0u8; 4];
|
||||
stream.write_all(b"pong").await?;
|
||||
|
||||
assert!(
|
||||
!stream.is_pending_ack(),
|
||||
"we should have sent an ACK flag with the first payload"
|
||||
);
|
||||
|
||||
stream.read_exact(&mut buffer).await?;
|
||||
|
||||
assert_eq!(&buffer, b"ping");
|
||||
|
||||
stream.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -91,13 +91,14 @@ mod cleanup;
|
||||
mod closing;
|
||||
mod stream;
|
||||
|
||||
use crate::Result;
|
||||
use crate::tagged_stream::TaggedStream;
|
||||
use crate::{
|
||||
error::ConnectionError,
|
||||
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
|
||||
frame::{self, Frame},
|
||||
Config, WindowUpdateMode, DEFAULT_CREDIT,
|
||||
};
|
||||
use crate::{Result, MAX_ACK_BACKLOG};
|
||||
use cleanup::Cleanup;
|
||||
use closing::Closing;
|
||||
use futures::stream::SelectAll;
|
||||
@@ -107,7 +108,6 @@ use std::collections::VecDeque;
|
||||
use std::task::{Context, Waker};
|
||||
use std::{fmt, sync::Arc, task::Poll};
|
||||
|
||||
use crate::tagged_stream::TaggedStream;
|
||||
pub use stream::{Packet, State, Stream};
|
||||
|
||||
/// How the connection is used.
|
||||
@@ -162,12 +162,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
|
||||
pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
|
||||
loop {
|
||||
match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
|
||||
ConnectionState::Active(mut active) => match active.new_outbound() {
|
||||
Ok(stream) => {
|
||||
ConnectionState::Active(mut active) => match active.poll_new_outbound(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
self.inner = ConnectionState::Active(active);
|
||||
return Poll::Ready(Ok(stream));
|
||||
}
|
||||
Err(e) => {
|
||||
Poll::Pending => {
|
||||
self.inner = ConnectionState::Active(active);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.inner = ConnectionState::Cleanup(active.cleanup(e));
|
||||
continue;
|
||||
}
|
||||
@@ -355,6 +359,7 @@ struct Active<T> {
|
||||
no_streams_waker: Option<Waker>,
|
||||
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
new_outbound_stream_waker: Option<Waker>,
|
||||
}
|
||||
|
||||
/// `Stream` to `Connection` commands.
|
||||
@@ -425,6 +430,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Mode::Server => 2,
|
||||
},
|
||||
pending_frames: VecDeque::default(),
|
||||
new_outbound_stream_waker: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,10 +499,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn new_outbound(&mut self) -> Result<Stream> {
|
||||
fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
|
||||
if self.streams.len() >= self.config.max_num_streams {
|
||||
log::error!("{}: maximum number of streams reached", self.id);
|
||||
return Err(ConnectionError::TooManyStreams);
|
||||
return Poll::Ready(Err(ConnectionError::TooManyStreams));
|
||||
}
|
||||
|
||||
if self.ack_backlog() >= MAX_ACK_BACKLOG {
|
||||
log::debug!("{MAX_ACK_BACKLOG} streams waiting for ACK, registering task for wake-up until remote acknowledges at least one stream");
|
||||
self.new_outbound_stream_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
log::trace!("{}: creating new outbound stream", self.id);
|
||||
@@ -511,7 +523,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
self.pending_frames.push_back(frame.into());
|
||||
}
|
||||
|
||||
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);
|
||||
let mut stream = self.make_new_outbound_stream(id, self.config.receive_window);
|
||||
|
||||
if extra_credit == 0 {
|
||||
stream.set_flag(stream::Flag::Syn)
|
||||
@@ -520,7 +532,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
|
||||
self.streams.insert(id, stream.clone());
|
||||
|
||||
Ok(stream)
|
||||
Poll::Ready(Ok(stream))
|
||||
}
|
||||
|
||||
fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
|
||||
@@ -549,7 +561,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
|
||||
// The stream was dropped without calling `poll_close`.
|
||||
// We reset the stream to inform the remote of the closure.
|
||||
State::Open => {
|
||||
State::Open { .. } => {
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.rst();
|
||||
Some(Frame::new(header))
|
||||
@@ -610,6 +622,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
/// if one was opened by the remote.
|
||||
fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
|
||||
log::trace!("{}: received: {}", self.id, frame.header());
|
||||
|
||||
if frame.header().flags().contains(header::ACK) {
|
||||
let id = frame.header().stream_id();
|
||||
if let Some(stream) = self.streams.get(&id) {
|
||||
stream
|
||||
.shared()
|
||||
.update_state(self.id, id, State::Open { acknowledged: true });
|
||||
}
|
||||
if let Some(waker) = self.new_outbound_stream_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
let action = match frame.header().tag() {
|
||||
Tag::Data => self.on_data(frame.into_data()),
|
||||
Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()),
|
||||
@@ -689,7 +714,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
log::error!("{}: maximum number of streams reached", self.id);
|
||||
return Action::Terminate(Frame::internal_error());
|
||||
}
|
||||
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
|
||||
let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
|
||||
let mut window_update = None;
|
||||
{
|
||||
let mut shared = stream.shared();
|
||||
@@ -806,7 +831,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
}
|
||||
|
||||
let credit = frame.header().credit() + DEFAULT_CREDIT;
|
||||
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
|
||||
let mut stream = self.make_new_inbound_stream(stream_id, credit);
|
||||
stream.set_flag(stream::Flag::Ack);
|
||||
|
||||
if is_finish {
|
||||
@@ -873,7 +898,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Action::None
|
||||
}
|
||||
|
||||
fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
|
||||
fn make_new_inbound_stream(&mut self, id: StreamId, credit: u32) -> Stream {
|
||||
let config = self.config.clone();
|
||||
|
||||
let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
|
||||
@@ -882,7 +907,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Stream::new(id, self.id, config, window, credit, sender)
|
||||
Stream::new_inbound(id, self.id, config, credit, sender)
|
||||
}
|
||||
|
||||
fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream {
|
||||
let config = self.config.clone();
|
||||
|
||||
let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
|
||||
self.stream_receivers.push(TaggedStream::new(id, receiver));
|
||||
if let Some(waker) = self.no_streams_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Stream::new_outbound(id, self.id, config, window, sender)
|
||||
}
|
||||
|
||||
fn next_stream_id(&mut self) -> Result<StreamId> {
|
||||
@@ -898,6 +935,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Ok(proposed)
|
||||
}
|
||||
|
||||
/// The ACK backlog is defined as the number of outbound streams that have not yet been acknowledged.
|
||||
fn ack_backlog(&mut self) -> usize {
|
||||
self.streams
|
||||
.values()
|
||||
.filter(|s| s.is_outbound(self.mode))
|
||||
.filter(|s| s.is_pending_ack())
|
||||
.count()
|
||||
}
|
||||
|
||||
// Check if the given stream ID is valid w.r.t. the provided tag and our connection mode.
|
||||
fn is_valid_remote_id(&self, id: StreamId, tag: Tag) -> bool {
|
||||
if tag == Tag::Ping || tag == Tag::GoAway {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
|
||||
// at https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::frame::header::ACK;
|
||||
use crate::{
|
||||
chunks::Chunks,
|
||||
connection::{self, StreamCommand},
|
||||
@@ -15,7 +16,7 @@ use crate::{
|
||||
header::{Data, Header, StreamId, WindowUpdate},
|
||||
Frame,
|
||||
},
|
||||
Config, WindowUpdateMode,
|
||||
Config, Mode, WindowUpdateMode, DEFAULT_CREDIT,
|
||||
};
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
@@ -36,7 +37,18 @@ use std::{
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum State {
|
||||
/// Open bidirectionally.
|
||||
Open,
|
||||
Open {
|
||||
/// Whether the stream is acknowledged.
|
||||
///
|
||||
/// For outbound streams, this tracks whether the remote has acknowledged our stream.
|
||||
/// For inbound streams, this tracks whether we have acknowledged the stream to the remote.
|
||||
///
|
||||
/// This starts out with `false` and is set to `true` when we receive or send an `ACK` flag for this stream.
|
||||
/// We may also directly transition:
|
||||
/// - from `Open` to `RecvClosed` if the remote immediately sends `FIN`.
|
||||
/// - from `Open` to `Closed` if the remote immediately sends `RST`.
|
||||
acknowledged: bool,
|
||||
},
|
||||
/// Open for incoming messages.
|
||||
SendClosed,
|
||||
/// Open for outgoing messages.
|
||||
@@ -100,21 +112,37 @@ impl fmt::Display for Stream {
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
pub(crate) fn new(
|
||||
pub(crate) fn new_inbound(
|
||||
id: StreamId,
|
||||
conn: connection::Id,
|
||||
config: Arc<Config>,
|
||||
window: u32,
|
||||
credit: u32,
|
||||
sender: mpsc::Sender<StreamCommand>,
|
||||
) -> Self {
|
||||
Stream {
|
||||
Self {
|
||||
id,
|
||||
conn,
|
||||
config: config.clone(),
|
||||
sender,
|
||||
flag: Flag::None,
|
||||
shared: Arc::new(Mutex::new(Shared::new(window, credit, config))),
|
||||
shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_outbound(
|
||||
id: StreamId,
|
||||
conn: connection::Id,
|
||||
config: Arc<Config>,
|
||||
window: u32,
|
||||
sender: mpsc::Sender<StreamCommand>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
conn,
|
||||
config: config.clone(),
|
||||
sender,
|
||||
flag: Flag::None,
|
||||
shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +159,30 @@ impl Stream {
|
||||
matches!(self.shared().state(), State::Closed)
|
||||
}
|
||||
|
||||
/// Whether we are still waiting for the remote to acknowledge this stream.
|
||||
pub fn is_pending_ack(&self) -> bool {
|
||||
matches!(
|
||||
self.shared().state(),
|
||||
State::Open {
|
||||
acknowledged: false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Whether this is an outbound stream.
|
||||
///
|
||||
/// Clients use odd IDs and servers use even IDs.
|
||||
/// A stream is outbound if:
|
||||
///
|
||||
/// - Its ID is odd and we are the client.
|
||||
/// - Its ID is even and we are the server.
|
||||
pub(crate) fn is_outbound(&self, our_mode: Mode) -> bool {
|
||||
match our_mode {
|
||||
Mode::Client => self.id.is_client(),
|
||||
Mode::Server => self.id.is_server(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the flag that should be set on the next outbound frame header.
|
||||
pub(crate) fn set_flag(&mut self, flag: Flag) {
|
||||
self.flag = flag
|
||||
@@ -347,6 +399,16 @@ impl AsyncWrite for Stream {
|
||||
let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
|
||||
self.add_flag(frame.header_mut());
|
||||
log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
|
||||
|
||||
// technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending
|
||||
// We are tracking this information:
|
||||
// a) to be consistent with outbound streams
|
||||
// b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
|
||||
if frame.header().flags().contains(ACK) {
|
||||
self.shared()
|
||||
.update_state(self.conn, self.id, State::Open { acknowledged: true });
|
||||
}
|
||||
|
||||
let cmd = StreamCommand::SendFrame(frame);
|
||||
self.sender
|
||||
.start_send(cmd)
|
||||
@@ -399,7 +461,9 @@ pub(crate) struct Shared {
|
||||
impl Shared {
|
||||
fn new(window: u32, credit: u32, config: Arc<Config>) -> Self {
|
||||
Shared {
|
||||
state: State::Open,
|
||||
state: State::Open {
|
||||
acknowledged: false,
|
||||
},
|
||||
window,
|
||||
credit,
|
||||
buffer: Chunks::new(),
|
||||
@@ -426,19 +490,19 @@ impl Shared {
|
||||
|
||||
match (current, next) {
|
||||
(Closed, _) => {}
|
||||
(Open, _) => self.state = next,
|
||||
(Open { .. }, _) => self.state = next,
|
||||
(RecvClosed, Closed) => self.state = Closed,
|
||||
(RecvClosed, Open) => {}
|
||||
(RecvClosed, Open { .. }) => {}
|
||||
(RecvClosed, RecvClosed) => {}
|
||||
(RecvClosed, SendClosed) => self.state = Closed,
|
||||
(SendClosed, Closed) => self.state = Closed,
|
||||
(SendClosed, Open) => {}
|
||||
(SendClosed, Open { .. }) => {}
|
||||
(SendClosed, RecvClosed) => self.state = Closed,
|
||||
(SendClosed, SendClosed) => {}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"{}/{}: update state: ({:?} {:?} {:?})",
|
||||
"{}/{}: update state: (from {:?} to {:?} -> {:?})",
|
||||
cid,
|
||||
sid,
|
||||
current,
|
||||
|
||||
@@ -42,6 +42,11 @@ pub const DEFAULT_CREDIT: u32 = 256 * 1024; // as per yamux specification
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ConnectionError>;
|
||||
|
||||
/// The maximum number of streams we will open without an acknowledgement from the other peer.
|
||||
///
|
||||
/// This enables a very basic form of backpressure on the creation of streams.
|
||||
const MAX_ACK_BACKLOG: usize = 256;
|
||||
|
||||
/// Default maximum number of bytes a Yamux data frame might carry as its
|
||||
/// payload when being send. Larger Payloads will be split.
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user