mirror of
https://github.com/tlsnotary/rust-yamux.git
synced 2026-01-09 12:58:03 -05:00
fix: avoid race condition between pending frames and closing stream (#156)
Currently, we have a `garbage_collect` function that checks whether any of our streams have been dropped. This can cause a race condition where the channel between a `Stream` and the `Connection` still has pending frames for a stream but dropping a stream causes us to already send a `FIN` flag for the stream. We fix this by maintaining a single channel for each stream. When a stream gets dropped, the `Receiver` becomes disconnected. We use this information to queue the correct frame (`FIN` vs `RST`) into the buffer. At this point, all previous frames have already been processed and the race condition is thus not present. Additionally, this also allows us to implement `Stream::poll_flush` by forwarding to the underlying `Sender`. Note that at present day, this only checks whether there is _space_ in the channel, not whether the items have been emitted by the `Receiver`. We have a PR upstream that might fix this: https://github.com/rust-lang/futures-rs/pull/2746 Fixes: #117.
This commit is contained in:
@@ -16,4 +16,3 @@ log = "0.4.17"
|
||||
[dev-dependencies]
|
||||
env_logger = "0.10"
|
||||
constrained-connection = "0.1"
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ nohash-hasher = "0.2"
|
||||
parking_lot = "0.12"
|
||||
rand = "0.8.3"
|
||||
static_assertions = "1"
|
||||
pin-project = "1.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
@@ -26,6 +27,7 @@ quickcheck = "1.0"
|
||||
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
constrained-connection = "0.1"
|
||||
futures_ringbuf = "0.3.1"
|
||||
|
||||
[[bench]]
|
||||
name = "concurrent"
|
||||
|
||||
@@ -96,16 +96,18 @@ use crate::{
|
||||
error::ConnectionError,
|
||||
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
|
||||
frame::{self, Frame},
|
||||
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
|
||||
Config, WindowUpdateMode, DEFAULT_CREDIT,
|
||||
};
|
||||
use cleanup::Cleanup;
|
||||
use closing::Closing;
|
||||
use futures::stream::SelectAll;
|
||||
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::task::Context;
|
||||
use std::task::{Context, Waker};
|
||||
use std::{fmt, sync::Arc, task::Poll};
|
||||
|
||||
use crate::tagged_stream::TaggedStream;
|
||||
pub use stream::{Packet, State, Stream};
|
||||
|
||||
/// How the connection is used.
|
||||
@@ -347,10 +349,11 @@ struct Active<T> {
|
||||
config: Arc<Config>,
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
next_id: u32,
|
||||
|
||||
streams: IntMap<StreamId, Stream>,
|
||||
stream_sender: mpsc::Sender<StreamCommand>,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
dropped_streams: Vec<StreamId>,
|
||||
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
|
||||
no_streams_waker: Option<Waker>,
|
||||
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
}
|
||||
|
||||
@@ -360,7 +363,7 @@ pub(crate) enum StreamCommand {
|
||||
/// A new frame should be sent to the remote.
|
||||
SendFrame(Frame<Either<Data, WindowUpdate>>),
|
||||
/// Close a stream.
|
||||
CloseStream { id: StreamId, ack: bool },
|
||||
CloseStream { ack: bool },
|
||||
}
|
||||
|
||||
/// Possible actions as a result of incoming frame handling.
|
||||
@@ -408,7 +411,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
|
||||
let id = Id::random();
|
||||
log::debug!("new connection: {} ({:?})", id, mode);
|
||||
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
|
||||
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
|
||||
Active {
|
||||
id,
|
||||
@@ -416,20 +418,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
config: Arc::new(cfg),
|
||||
socket,
|
||||
streams: IntMap::default(),
|
||||
stream_sender,
|
||||
stream_receiver,
|
||||
stream_receivers: SelectAll::default(),
|
||||
no_streams_waker: None,
|
||||
next_id: match mode {
|
||||
Mode::Client => 1,
|
||||
Mode::Server => 2,
|
||||
},
|
||||
dropped_streams: Vec::new(),
|
||||
pending_frames: VecDeque::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gracefully close the connection to the remote.
|
||||
fn close(self) -> Closing<T> {
|
||||
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
|
||||
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
|
||||
}
|
||||
|
||||
/// Cleanup all our resources.
|
||||
@@ -438,13 +439,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
|
||||
self.drop_all_streams();
|
||||
|
||||
Cleanup::new(self.stream_receiver, error)
|
||||
Cleanup::new(self.stream_receivers, error)
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
|
||||
loop {
|
||||
self.garbage_collect();
|
||||
|
||||
if self.socket.poll_ready_unpin(cx).is_ready() {
|
||||
if let Some(frame) = self.pending_frames.pop_front() {
|
||||
self.socket.start_send_unpin(frame)?;
|
||||
@@ -457,17 +456,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
match self.stream_receiver.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
|
||||
self.on_send_frame(frame);
|
||||
match self.stream_receivers.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
|
||||
self.on_send_frame(frame.into());
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
|
||||
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
|
||||
self.on_close_stream(id, ack);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some((id, None))) => {
|
||||
self.on_drop_stream(id);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
debug_assert!(false, "Only closed during shutdown")
|
||||
self.no_streams_waker = Some(cx.waker().clone());
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
@@ -508,16 +511,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
self.pending_frames.push_back(frame.into());
|
||||
}
|
||||
|
||||
let stream = {
|
||||
let config = self.config.clone();
|
||||
let sender = self.stream_sender.clone();
|
||||
let window = self.config.receive_window;
|
||||
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
|
||||
if extra_credit == 0 {
|
||||
stream.set_flag(stream::Flag::Syn)
|
||||
}
|
||||
stream
|
||||
};
|
||||
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);
|
||||
|
||||
if extra_credit == 0 {
|
||||
stream.set_flag(stream::Flag::Syn)
|
||||
}
|
||||
|
||||
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
|
||||
self.streams.insert(id, stream.clone());
|
||||
@@ -541,6 +539,69 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
.push_back(Frame::close_stream(id, ack).into());
|
||||
}
|
||||
|
||||
fn on_drop_stream(&mut self, id: StreamId) {
|
||||
let stream = self.streams.remove(&id).expect("stream not found");
|
||||
|
||||
log::trace!("{}: removing dropped {}", self.id, stream);
|
||||
let stream_id = stream.id();
|
||||
let frame = {
|
||||
let mut shared = stream.shared();
|
||||
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
|
||||
// The stream was dropped without calling `poll_close`.
|
||||
// We reset the stream to inform the remote of the closure.
|
||||
State::Open => {
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.rst();
|
||||
Some(Frame::new(header))
|
||||
}
|
||||
// The stream was dropped without calling `poll_close`.
|
||||
// We have already received a FIN from remote and send one
|
||||
// back which closes the stream for good.
|
||||
State::RecvClosed => {
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.fin();
|
||||
Some(Frame::new(header))
|
||||
}
|
||||
// The stream was properly closed. We already sent our FIN frame.
|
||||
// The remote may be out of credit though and blocked on
|
||||
// writing more data. We may need to reset the stream.
|
||||
State::SendClosed => {
|
||||
if self.config.window_update_mode == WindowUpdateMode::OnRead
|
||||
&& shared.window == 0
|
||||
{
|
||||
// The remote may be waiting for a window update
|
||||
// which we will never send, so reset the stream now.
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.rst();
|
||||
Some(Frame::new(header))
|
||||
} else {
|
||||
// The remote has either still credit or will be given more
|
||||
// (due to an enqueued window update or because the update
|
||||
// mode is `OnReceive`) or we already have inbound frames in
|
||||
// the socket buffer which will be processed later. In any
|
||||
// case we will reply with an RST in `Connection::on_data`
|
||||
// because the stream will no longer be known.
|
||||
None
|
||||
}
|
||||
}
|
||||
// The stream was properly closed. We already have sent our FIN frame. The
|
||||
// remote end has already done so in the past.
|
||||
State::Closed => None,
|
||||
};
|
||||
if let Some(w) = shared.reader.take() {
|
||||
w.wake()
|
||||
}
|
||||
if let Some(w) = shared.writer.take() {
|
||||
w.wake()
|
||||
}
|
||||
frame
|
||||
};
|
||||
if let Some(f) = frame {
|
||||
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
|
||||
self.pending_frames.push_back(f.into());
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the result of reading from the socket.
|
||||
///
|
||||
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
|
||||
@@ -628,12 +689,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
log::error!("{}: maximum number of streams reached", self.id);
|
||||
return Action::Terminate(Frame::internal_error());
|
||||
}
|
||||
let mut stream = {
|
||||
let config = self.config.clone();
|
||||
let credit = DEFAULT_CREDIT;
|
||||
let sender = self.stream_sender.clone();
|
||||
Stream::new(stream_id, self.id, config, credit, credit, sender)
|
||||
};
|
||||
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
|
||||
let mut window_update = None;
|
||||
{
|
||||
let mut shared = stream.shared();
|
||||
@@ -748,15 +804,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
log::error!("{}: maximum number of streams reached", self.id);
|
||||
return Action::Terminate(Frame::protocol_error());
|
||||
}
|
||||
let stream = {
|
||||
let credit = frame.header().credit() + DEFAULT_CREDIT;
|
||||
let config = self.config.clone();
|
||||
let sender = self.stream_sender.clone();
|
||||
let mut stream =
|
||||
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
|
||||
stream.set_flag(stream::Flag::Ack);
|
||||
stream
|
||||
};
|
||||
|
||||
let credit = frame.header().credit() + DEFAULT_CREDIT;
|
||||
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
|
||||
stream.set_flag(stream::Flag::Ack);
|
||||
|
||||
if is_finish {
|
||||
stream
|
||||
.shared()
|
||||
@@ -821,6 +873,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Action::None
|
||||
}
|
||||
|
||||
fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
|
||||
let config = self.config.clone();
|
||||
|
||||
let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
|
||||
self.stream_receivers.push(TaggedStream::new(id, receiver));
|
||||
if let Some(waker) = self.no_streams_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
Stream::new(id, self.id, config, window, credit, sender)
|
||||
}
|
||||
|
||||
fn next_stream_id(&mut self) -> Result<StreamId> {
|
||||
let proposed = StreamId::new(self.next_id);
|
||||
self.next_id = self
|
||||
@@ -844,79 +908,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
|
||||
Mode::Server => id.is_client(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove stale streams and create necessary messages to be sent to the remote.
|
||||
fn garbage_collect(&mut self) {
|
||||
let conn_id = self.id;
|
||||
let win_update_mode = self.config.window_update_mode;
|
||||
for stream in self.streams.values_mut() {
|
||||
if stream.strong_count() > 1 {
|
||||
continue;
|
||||
}
|
||||
log::trace!("{}: removing dropped {}", conn_id, stream);
|
||||
let stream_id = stream.id();
|
||||
let frame = {
|
||||
let mut shared = stream.shared();
|
||||
let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
|
||||
// The stream was dropped without calling `poll_close`.
|
||||
// We reset the stream to inform the remote of the closure.
|
||||
State::Open => {
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.rst();
|
||||
Some(Frame::new(header))
|
||||
}
|
||||
// The stream was dropped without calling `poll_close`.
|
||||
// We have already received a FIN from remote and send one
|
||||
// back which closes the stream for good.
|
||||
State::RecvClosed => {
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.fin();
|
||||
Some(Frame::new(header))
|
||||
}
|
||||
// The stream was properly closed. We either already have
|
||||
// or will at some later point send our FIN frame.
|
||||
// The remote may be out of credit though and blocked on
|
||||
// writing more data. We may need to reset the stream.
|
||||
State::SendClosed => {
|
||||
if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
|
||||
// The remote may be waiting for a window update
|
||||
// which we will never send, so reset the stream now.
|
||||
let mut header = Header::data(stream_id, 0);
|
||||
header.rst();
|
||||
Some(Frame::new(header))
|
||||
} else {
|
||||
// The remote has either still credit or will be given more
|
||||
// (due to an enqueued window update or because the update
|
||||
// mode is `OnReceive`) or we already have inbound frames in
|
||||
// the socket buffer which will be processed later. In any
|
||||
// case we will reply with an RST in `Connection::on_data`
|
||||
// because the stream will no longer be known.
|
||||
None
|
||||
}
|
||||
}
|
||||
// The stream was properly closed. We either already have
|
||||
// or will at some later point send our FIN frame. The
|
||||
// remote end has already done so in the past.
|
||||
State::Closed => None,
|
||||
};
|
||||
if let Some(w) = shared.reader.take() {
|
||||
w.wake()
|
||||
}
|
||||
if let Some(w) = shared.writer.take() {
|
||||
w.wake()
|
||||
}
|
||||
frame
|
||||
};
|
||||
if let Some(f) = frame {
|
||||
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
|
||||
self.pending_frames.push_back(f.into());
|
||||
}
|
||||
self.dropped_streams.push(stream_id)
|
||||
}
|
||||
for id in self.dropped_streams.drain(..) {
|
||||
self.streams.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Active<T> {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::connection::StreamCommand;
|
||||
use crate::ConnectionError;
|
||||
use crate::tagged_stream::TaggedStream;
|
||||
use crate::{ConnectionError, StreamId};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{ready, StreamExt};
|
||||
use futures::stream::SelectAll;
|
||||
use futures::StreamExt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
@@ -10,18 +12,18 @@ use std::task::{Context, Poll};
|
||||
#[must_use]
|
||||
pub struct Cleanup {
|
||||
state: State,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
|
||||
error: Option<ConnectionError>,
|
||||
}
|
||||
|
||||
impl Cleanup {
|
||||
pub(crate) fn new(
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
|
||||
error: ConnectionError,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: State::ClosingStreamReceiver,
|
||||
stream_receiver,
|
||||
stream_receivers,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
@@ -36,26 +38,23 @@ impl Future for Cleanup {
|
||||
loop {
|
||||
match this.state {
|
||||
State::ClosingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
for stream in this.stream_receivers.iter_mut() {
|
||||
stream.inner_mut().close();
|
||||
}
|
||||
this.state = State::DrainingStreamReceiver;
|
||||
}
|
||||
|
||||
State::DrainingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
|
||||
match ready!(this.stream_receiver.poll_next_unpin(cx)) {
|
||||
Some(cmd) => {
|
||||
drop(cmd);
|
||||
}
|
||||
None => {
|
||||
return Poll::Ready(
|
||||
this.error
|
||||
.take()
|
||||
.expect("to not be called after completion"),
|
||||
);
|
||||
}
|
||||
State::DrainingStreamReceiver => match this.stream_receivers.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(cmd)) => {
|
||||
drop(cmd);
|
||||
}
|
||||
}
|
||||
Poll::Ready(None) | Poll::Pending => {
|
||||
return Poll::Ready(
|
||||
this.error
|
||||
.take()
|
||||
.expect("to not be called after completion"),
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::connection::StreamCommand;
|
||||
use crate::frame;
|
||||
use crate::frame::Frame;
|
||||
use crate::tagged_stream::TaggedStream;
|
||||
use crate::Result;
|
||||
use crate::{frame, StreamId};
|
||||
use futures::channel::mpsc;
|
||||
use futures::stream::Fuse;
|
||||
use futures::stream::{Fuse, SelectAll};
|
||||
use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
|
||||
use std::collections::VecDeque;
|
||||
use std::future::Future;
|
||||
@@ -14,7 +15,7 @@ use std::task::{Context, Poll};
|
||||
#[must_use]
|
||||
pub struct Closing<T> {
|
||||
state: State,
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
}
|
||||
@@ -24,13 +25,13 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
stream_receiver: mpsc::Receiver<StreamCommand>,
|
||||
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
|
||||
pending_frames: VecDeque<Frame<()>>,
|
||||
socket: Fuse<frame::Io<T>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: State::ClosingStreamReceiver,
|
||||
stream_receiver,
|
||||
stream_receivers,
|
||||
pending_frames,
|
||||
socket,
|
||||
}
|
||||
@@ -49,27 +50,30 @@ where
|
||||
loop {
|
||||
match this.state {
|
||||
State::ClosingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
for stream in this.stream_receivers.iter_mut() {
|
||||
stream.inner_mut().close();
|
||||
}
|
||||
this.state = State::DrainingStreamReceiver;
|
||||
}
|
||||
|
||||
State::DrainingStreamReceiver => {
|
||||
this.stream_receiver.close();
|
||||
|
||||
match ready!(this.stream_receiver.poll_next_unpin(cx)) {
|
||||
Some(StreamCommand::SendFrame(frame)) => {
|
||||
match this.stream_receivers.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
|
||||
this.pending_frames.push_back(frame.into())
|
||||
}
|
||||
Some(StreamCommand::CloseStream { id, ack }) => this
|
||||
.pending_frames
|
||||
.push_back(Frame::close_stream(id, ack).into()),
|
||||
None => this.state = State::SendingTermFrame,
|
||||
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
|
||||
this.pending_frames
|
||||
.push_back(Frame::close_stream(id, ack).into());
|
||||
}
|
||||
Poll::Ready(Some((_, None))) => {}
|
||||
Poll::Pending | Poll::Ready(None) => {
|
||||
// No more frames from streams, append `Term` frame and flush them all.
|
||||
this.pending_frames.push_back(Frame::term().into());
|
||||
this.state = State::FlushingPendingFrames;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
State::SendingTermFrame => {
|
||||
this.pending_frames.push_back(Frame::term().into());
|
||||
this.state = State::FlushingPendingFrames;
|
||||
}
|
||||
State::FlushingPendingFrames => {
|
||||
ready!(this.socket.poll_ready_unpin(cx))?;
|
||||
|
||||
@@ -91,7 +95,6 @@ where
|
||||
enum State {
|
||||
ClosingStreamReceiver,
|
||||
DrainingStreamReceiver,
|
||||
SendingTermFrame,
|
||||
FlushingPendingFrames,
|
||||
ClosingSocket,
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ use futures::{
|
||||
channel::mpsc,
|
||||
future::Either,
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
ready,
|
||||
ready, SinkExt,
|
||||
};
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use std::convert::TryInto;
|
||||
@@ -136,10 +136,6 @@ impl Stream {
|
||||
self.flag = flag
|
||||
}
|
||||
|
||||
pub(crate) fn strong_count(&self) -> usize {
|
||||
Arc::strong_count(&self.shared)
|
||||
}
|
||||
|
||||
pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
|
||||
self.shared.lock()
|
||||
}
|
||||
@@ -358,8 +354,10 @@ impl AsyncWrite for Stream {
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.sender
|
||||
.poll_flush_unpin(cx)
|
||||
.map_err(|_| self.write_zero_err())
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
@@ -377,7 +375,7 @@ impl AsyncWrite for Stream {
|
||||
false
|
||||
};
|
||||
log::trace!("{}/{}: close", self.conn, self.id);
|
||||
let cmd = StreamCommand::CloseStream { id: self.id, ack };
|
||||
let cmd = StreamCommand::CloseStream { ack };
|
||||
self.sender
|
||||
.start_send(cmd)
|
||||
.map_err(|_| self.write_zero_err())?;
|
||||
|
||||
@@ -30,6 +30,7 @@ mod error;
|
||||
mod frame;
|
||||
|
||||
pub(crate) mod connection;
|
||||
mod tagged_stream;
|
||||
|
||||
pub use crate::connection::{Connection, Mode, Packet, Stream};
|
||||
pub use crate::control::{Control, ControlledConnection};
|
||||
|
||||
52
yamux/src/tagged_stream.rs
Normal file
52
yamux/src/tagged_stream.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// A stream that yields its tag with every item.
|
||||
#[pin_project::pin_project]
|
||||
pub struct TaggedStream<K, S> {
|
||||
key: K,
|
||||
#[pin]
|
||||
inner: S,
|
||||
|
||||
reported_none: bool,
|
||||
}
|
||||
|
||||
impl<K, S> TaggedStream<K, S> {
|
||||
pub fn new(key: K, inner: S) -> Self {
|
||||
Self {
|
||||
key,
|
||||
inner,
|
||||
reported_none: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inner_mut(&mut self) -> &mut S {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, S> Stream for TaggedStream<K, S>
|
||||
where
|
||||
K: Copy,
|
||||
S: Stream,
|
||||
{
|
||||
type Item = (K, Option<S::Item>);
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
if *this.reported_none {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
match futures::ready!(this.inner.poll_next(cx)) {
|
||||
Some(item) => Poll::Ready(Some((*this.key, Some(item)))),
|
||||
None => {
|
||||
*this.reported_none = true;
|
||||
|
||||
Poll::Ready(Some((*this.key, None)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user