fix: avoid race condition between pending frames and closing stream (#156)

Currently, we have a `garbage_collect` function that checks whether any of our
streams have been dropped. This can cause a race condition: the channel between
a `Stream` and the `Connection` may still hold pending frames for a stream, yet
dropping the stream makes us send a `FIN` flag for it right away, ahead of
those frames.
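
For illustration, here is a simplified, self-contained sketch of that
interleaving (made-up types, not the crate's actual code): detecting the drop
via `Arc::strong_count` says nothing about whether the shared command channel
has been drained yet.

```rust
use futures::channel::mpsc;
use std::sync::Arc;

fn main() {
    // One command channel shared by all streams (the old design).
    let (mut command_tx, mut command_rx) = mpsc::channel::<&'static [u8]>(8);

    // The connection and the user each hold a handle to the stream's shared state.
    let shared = Arc::new(());
    let user_handle = shared.clone();

    // The user writes some data, which is queued as a frame, and then drops the stream.
    command_tx.try_send(&b"hello"[..]).unwrap();
    drop(user_handle);

    // `garbage_collect` detects the drop right away and would enqueue a `FIN`,
    // even though the data frame is still buffered in the shared channel.
    assert_eq!(Arc::strong_count(&shared), 1);
    assert!(command_rx.try_next().unwrap().is_some());
}
```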

We fix this by maintaining a dedicated channel for each stream. When a stream
gets dropped, its `Receiver` becomes disconnected. We use this information to
queue the correct frame (`FIN` vs `RST`) into the buffer. At that point, all
previous frames from that stream have already been processed, so the race
condition can no longer occur.
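
A minimal sketch of the ordering guarantee this relies on, using plain
`futures::channel::mpsc` and no yamux types: a `Receiver` yields every item
that was buffered before its `Sender` was dropped, and only then reports the
disconnect.

```rust
use futures::channel::mpsc;
use futures::executor::block_on;
use futures::StreamExt;

fn main() {
    block_on(async {
        let (mut tx, mut rx) = mpsc::channel::<u32>(8);
        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        drop(tx); // the user drops the `Stream`, which owns the `Sender`

        // All frames that were queued before the drop are still delivered ...
        assert_eq!(rx.next().await, Some(1));
        assert_eq!(rx.next().await, Some(2));
        // ... and only afterwards does the `Receiver` report the disconnect,
        // at which point the connection queues the final `FIN`/`RST` frame.
        assert_eq!(rx.next().await, None);
    });
}
```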

Additionally, this allows us to implement `Stream::poll_flush` by forwarding to
the underlying `Sender`. Note that, at present, this only checks whether there
is _space_ in the channel, not whether the items have been yielded by the
`Receiver`.
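
A small sketch of that caveat, assuming the behaviour described above: flushing
a `Sender` resolves as soon as the channel has spare capacity, regardless of
whether the `Receiver` has consumed the item.

```rust
use futures::channel::mpsc;
use futures::executor::block_on;
use futures::SinkExt;

fn main() {
    block_on(async {
        let (mut tx, rx) = mpsc::channel::<u32>(8);
        tx.try_send(42).unwrap();

        // Resolves immediately: the channel has spare capacity, even though
        // `rx` has not consumed the item yet.
        tx.flush().await.unwrap();

        drop(rx);
    });
}
```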

We have a PR upstream that might fix this:
https://github.com/rust-lang/futures-rs/pull/2746

Fixes: #117.
Author: Thomas Eizinger (committed via GitHub)
Date:   2023-05-24 04:34:11 +02:00
Commit: 52c725b365 (parent: 88ed4dfc7a)

8 changed files with 212 additions and 167 deletions


@@ -16,4 +16,3 @@ log = "0.4.17"
[dev-dependencies]
env_logger = "0.10"
constrained-connection = "0.1"


@@ -16,6 +16,7 @@ nohash-hasher = "0.2"
parking_lot = "0.12"
rand = "0.8.3"
static_assertions = "1"
pin-project = "1.1.0"
[dev-dependencies]
anyhow = "1"
@@ -26,6 +27,7 @@ quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.3.1"
[[bench]]
name = "concurrent"


@@ -96,16 +96,18 @@ use crate::{
error::ConnectionError,
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
frame::{self, Frame},
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use cleanup::Cleanup;
use closing::Closing;
use futures::stream::SelectAll;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::VecDeque;
use std::task::Context;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};
use crate::tagged_stream::TaggedStream;
pub use stream::{Packet, State, Stream};
/// How the connection is used.
@@ -347,10 +349,11 @@ struct Active<T> {
config: Arc<Config>,
socket: Fuse<frame::Io<T>>,
next_id: u32,
streams: IntMap<StreamId, Stream>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
dropped_streams: Vec<StreamId>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,
pending_frames: VecDeque<Frame<()>>,
}
@@ -360,7 +363,7 @@ pub(crate) enum StreamCommand {
/// A new frame should be sent to the remote.
SendFrame(Frame<Either<Data, WindowUpdate>>),
/// Close a stream.
CloseStream { id: StreamId, ack: bool },
CloseStream { ack: bool },
}
/// Possible actions as a result of incoming frame handling.
@@ -408,7 +411,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
let id = Id::random();
log::debug!("new connection: {} ({:?})", id, mode);
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
Active {
id,
@@ -416,20 +418,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
stream_sender,
stream_receiver,
stream_receivers: SelectAll::default(),
no_streams_waker: None,
next_id: match mode {
Mode::Client => 1,
Mode::Server => 2,
},
dropped_streams: Vec::new(),
pending_frames: VecDeque::default(),
}
}
/// Gracefully close the connection to the remote.
fn close(self) -> Closing<T> {
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
}
/// Cleanup all our resources.
@@ -438,13 +439,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
self.drop_all_streams();
Cleanup::new(self.stream_receiver, error)
Cleanup::new(self.stream_receivers, error)
}
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
self.garbage_collect();
if self.socket.poll_ready_unpin(cx).is_ready() {
if let Some(frame) = self.pending_frames.pop_front() {
self.socket.start_send_unpin(frame)?;
@@ -457,17 +456,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Pending => {}
}
match self.stream_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
self.on_send_frame(frame);
match self.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
self.on_send_frame(frame.into());
continue;
}
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(Some((id, None))) => {
self.on_drop_stream(id);
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
self.no_streams_waker = Some(cx.waker().clone());
}
Poll::Pending => {}
}
@@ -508,16 +511,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
self.pending_frames.push_back(frame.into());
}
let stream = {
let config = self.config.clone();
let sender = self.stream_sender.clone();
let window = self.config.receive_window;
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
stream
};
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
@@ -541,6 +539,69 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.push_back(Frame::close_stream(id, ack).into());
}
fn on_drop_stream(&mut self, id: StreamId) {
let stream = self.streams.remove(&id).expect("stream not found");
log::trace!("{}: removing dropped {}", self.id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// The stream was properly closed. We already sent our FIN frame.
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if self.config.window_update_mode == WindowUpdateMode::OnRead
&& shared.window == 0
{
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
}
// The stream was properly closed. We already have sent our FIN frame. The
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
}
/// Process the result of reading from the socket.
///
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
@@ -628,12 +689,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let mut stream = {
let config = self.config.clone();
let credit = DEFAULT_CREDIT;
let sender = self.stream_sender.clone();
Stream::new(stream_id, self.id, config, credit, credit, sender)
};
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
let mut window_update = None;
{
let mut shared = stream.shared();
@@ -748,15 +804,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::protocol_error());
}
let stream = {
let credit = frame.header().credit() + DEFAULT_CREDIT;
let config = self.config.clone();
let sender = self.stream_sender.clone();
let mut stream =
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
stream.set_flag(stream::Flag::Ack);
stream
};
let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
stream.set_flag(stream::Flag::Ack);
if is_finish {
stream
.shared()
@@ -821,6 +873,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::None
}
fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
let config = self.config.clone();
let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
self.stream_receivers.push(TaggedStream::new(id, receiver));
if let Some(waker) = self.no_streams_waker.take() {
waker.wake();
}
Stream::new(id, self.id, config, window, credit, sender)
}
fn next_stream_id(&mut self) -> Result<StreamId> {
let proposed = StreamId::new(self.next_id);
self.next_id = self
@@ -844,79 +908,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Server => id.is_client(),
}
}
/// Remove stale streams and create necessary messages to be sent to the remote.
fn garbage_collect(&mut self) {
let conn_id = self.id;
let win_update_mode = self.config.window_update_mode;
for stream in self.streams.values_mut() {
if stream.strong_count() > 1 {
continue;
}
log::trace!("{}: removing dropped {}", conn_id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame.
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame. The
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
self.dropped_streams.push(stream_id)
}
for id in self.dropped_streams.drain(..) {
self.streams.remove(&id);
}
}
}
impl<T> Active<T> {


@@ -1,7 +1,9 @@
use crate::connection::StreamCommand;
use crate::ConnectionError;
use crate::tagged_stream::TaggedStream;
use crate::{ConnectionError, StreamId};
use futures::channel::mpsc;
use futures::{ready, StreamExt};
use futures::stream::SelectAll;
use futures::StreamExt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -10,18 +12,18 @@ use std::task::{Context, Poll};
#[must_use]
pub struct Cleanup {
state: State,
stream_receiver: mpsc::Receiver<StreamCommand>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
error: Option<ConnectionError>,
}
impl Cleanup {
pub(crate) fn new(
stream_receiver: mpsc::Receiver<StreamCommand>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
error: ConnectionError,
) -> Self {
Self {
state: State::ClosingStreamReceiver,
stream_receiver,
stream_receivers,
error: Some(error),
}
}
@@ -36,26 +38,23 @@ impl Future for Cleanup {
loop {
match this.state {
State::ClosingStreamReceiver => {
this.stream_receiver.close();
for stream in this.stream_receivers.iter_mut() {
stream.inner_mut().close();
}
this.state = State::DrainingStreamReceiver;
}
State::DrainingStreamReceiver => {
this.stream_receiver.close();
match ready!(this.stream_receiver.poll_next_unpin(cx)) {
Some(cmd) => {
drop(cmd);
}
None => {
return Poll::Ready(
this.error
.take()
.expect("to not be called after completion"),
);
}
State::DrainingStreamReceiver => match this.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some(cmd)) => {
drop(cmd);
}
}
Poll::Ready(None) | Poll::Pending => {
return Poll::Ready(
this.error
.take()
.expect("to not be called after completion"),
)
}
},
}
}
}


@@ -1,9 +1,10 @@
use crate::connection::StreamCommand;
use crate::frame;
use crate::frame::Frame;
use crate::tagged_stream::TaggedStream;
use crate::Result;
use crate::{frame, StreamId};
use futures::channel::mpsc;
use futures::stream::Fuse;
use futures::stream::{Fuse, SelectAll};
use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
use std::collections::VecDeque;
use std::future::Future;
@@ -14,7 +15,7 @@ use std::task::{Context, Poll};
#[must_use]
pub struct Closing<T> {
state: State,
stream_receiver: mpsc::Receiver<StreamCommand>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
pending_frames: VecDeque<Frame<()>>,
socket: Fuse<frame::Io<T>>,
}
@@ -24,13 +25,13 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
stream_receiver: mpsc::Receiver<StreamCommand>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
pending_frames: VecDeque<Frame<()>>,
socket: Fuse<frame::Io<T>>,
) -> Self {
Self {
state: State::ClosingStreamReceiver,
stream_receiver,
stream_receivers,
pending_frames,
socket,
}
@@ -49,27 +50,30 @@ where
loop {
match this.state {
State::ClosingStreamReceiver => {
this.stream_receiver.close();
for stream in this.stream_receivers.iter_mut() {
stream.inner_mut().close();
}
this.state = State::DrainingStreamReceiver;
}
State::DrainingStreamReceiver => {
this.stream_receiver.close();
match ready!(this.stream_receiver.poll_next_unpin(cx)) {
Some(StreamCommand::SendFrame(frame)) => {
match this.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
this.pending_frames.push_back(frame.into())
}
Some(StreamCommand::CloseStream { id, ack }) => this
.pending_frames
.push_back(Frame::close_stream(id, ack).into()),
None => this.state = State::SendingTermFrame,
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
this.pending_frames
.push_back(Frame::close_stream(id, ack).into());
}
Poll::Ready(Some((_, None))) => {}
Poll::Pending | Poll::Ready(None) => {
// No more frames from streams, append `Term` frame and flush them all.
this.pending_frames.push_back(Frame::term().into());
this.state = State::FlushingPendingFrames;
continue;
}
}
}
State::SendingTermFrame => {
this.pending_frames.push_back(Frame::term().into());
this.state = State::FlushingPendingFrames;
}
State::FlushingPendingFrames => {
ready!(this.socket.poll_ready_unpin(cx))?;
@@ -91,7 +95,6 @@ where
enum State {
ClosingStreamReceiver,
DrainingStreamReceiver,
SendingTermFrame,
FlushingPendingFrames,
ClosingSocket,
}


@@ -21,7 +21,7 @@ use futures::{
channel::mpsc,
future::Either,
io::{AsyncRead, AsyncWrite},
ready,
ready, SinkExt,
};
use parking_lot::{Mutex, MutexGuard};
use std::convert::TryInto;
@@ -136,10 +136,6 @@ impl Stream {
self.flag = flag
}
pub(crate) fn strong_count(&self) -> usize {
Arc::strong_count(&self.shared)
}
pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
self.shared.lock()
}
@@ -358,8 +354,10 @@ impl AsyncWrite for Stream {
Poll::Ready(Ok(n))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.sender
.poll_flush_unpin(cx)
.map_err(|_| self.write_zero_err())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
@@ -377,7 +375,7 @@ impl AsyncWrite for Stream {
false
};
log::trace!("{}/{}: close", self.conn, self.id);
let cmd = StreamCommand::CloseStream { id: self.id, ack };
let cmd = StreamCommand::CloseStream { ack };
self.sender
.start_send(cmd)
.map_err(|_| self.write_zero_err())?;


@@ -30,6 +30,7 @@ mod error;
mod frame;
pub(crate) mod connection;
mod tagged_stream;
pub use crate::connection::{Connection, Mode, Packet, Stream};
pub use crate::control::{Control, ControlledConnection};


@@ -0,0 +1,52 @@
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A stream that yields its tag with every item.
#[pin_project::pin_project]
pub struct TaggedStream<K, S> {
key: K,
#[pin]
inner: S,
reported_none: bool,
}
impl<K, S> TaggedStream<K, S> {
pub fn new(key: K, inner: S) -> Self {
Self {
key,
inner,
reported_none: false,
}
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner
}
}
impl<K, S> Stream for TaggedStream<K, S>
where
K: Copy,
S: Stream,
{
type Item = (K, Option<S::Item>);
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.reported_none {
return Poll::Ready(None);
}
match futures::ready!(this.inner.poll_next(cx)) {
Some(item) => Poll::Ready(Some((*this.key, Some(item)))),
None => {
*this.reported_none = true;
Poll::Ready(Some((*this.key, None)))
}
}
}
}
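
For illustration only (this sketch is not part of the diff or the crate's test
suite, and it assumes the crate-private `TaggedStream` above is in scope):
composing `TaggedStream` with `SelectAll` on the connection side tags every
buffered command with its stream id and reports a dropped `Sender` exactly once
as `(id, None)`.

```rust
// Assumes the crate-private `TaggedStream` defined above is in scope.
use futures::channel::mpsc;
use futures::executor::block_on;
use futures::stream::SelectAll;
use futures::StreamExt;

fn main() {
    block_on(async {
        let mut receivers = SelectAll::new();

        let (mut tx, rx) = mpsc::channel::<&'static str>(8);
        receivers.push(TaggedStream::new(1u64, rx));

        tx.try_send("send-frame").unwrap();
        drop(tx); // the user drops the `Stream`

        // The buffered command comes out first, tagged with its stream id ...
        assert_eq!(receivers.next().await, Some((1, Some("send-frame"))));
        // ... and the drop is reported exactly once, which is where the
        // connection runs `on_drop_stream` and queues the `FIN`/`RST`.
        assert_eq!(receivers.next().await, Some((1, None)));
    });
}
```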