diff --git a/crates/net/network/src/transactions.rs b/crates/net/network/src/transactions.rs index b67675a6f0..fe2e5e579d 100644 --- a/crates/net/network/src/transactions.rs +++ b/crates/net/network/src/transactions.rs @@ -34,7 +34,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, oneshot::error::RecvError}; use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream}; use tracing::{debug, trace}; @@ -101,7 +101,7 @@ pub struct TransactionsManager { /// From which we get all new incoming transaction related messages. network_events: UnboundedReceiverStream, /// All currently active requests for pooled transactions. - inflight_requests: Vec, + inflight_requests: FuturesUnordered, /// All currently pending transactions grouped by peers. /// /// This way we can track incoming transactions and prevent multiple pool imports for the same @@ -349,7 +349,7 @@ where }; if peer.request_tx.try_send(req).is_ok() { - self.inflight_requests.push(GetPooledTxRequest { peer_id, response: rx }) + self.inflight_requests.push(GetPooledTxRequestFut::new(peer_id, rx)) } else { // peer channel is saturated, drop the request self.metrics.egress_peer_channel_full.increment(1); @@ -574,28 +574,23 @@ where } // Advance all requests. - // We remove each request one by one and add them back. - for idx in (0..this.inflight_requests.len()).rev() { - let mut req = this.inflight_requests.swap_remove(idx); - match req.response.poll_unpin(cx) { - Poll::Pending => { - this.inflight_requests.push(req); + while let Poll::Ready(Some(GetPooledTxResponse { peer_id, result })) = + this.inflight_requests.poll_next_unpin(cx) + { + match result { + Ok(Ok(txs)) => { + this.import_transactions(peer_id, txs.0, TransactionSource::Response); } - Poll::Ready(Ok(Ok(txs))) => { - this.import_transactions(req.peer_id, txs.0, TransactionSource::Response); + Ok(Err(req_err)) => { + this.on_request_error(peer_id, req_err); } - Poll::Ready(Ok(Err(req_err))) => { - this.on_request_error(req.peer_id, req_err); - } - Poll::Ready(Err(_)) => { + Err(_) => { // request channel closed/dropped - this.on_request_error(req.peer_id, RequestError::ChannelClosed) + this.on_request_error(peer_id, RequestError::ChannelClosed) } } } - this.inflight_requests.shrink_to_fit(); - this.update_import_metrics(); // Advance all imports @@ -756,12 +751,49 @@ impl TransactionSource { } /// An inflight request for `PooledTransactions` from a peer -#[allow(missing_docs)] struct GetPooledTxRequest { peer_id: PeerId, response: oneshot::Receiver>, } +struct GetPooledTxResponse { + peer_id: PeerId, + result: Result, RecvError>, +} + +#[must_use = "futures do nothing unless polled"] +#[pin_project::pin_project] +struct GetPooledTxRequestFut { + #[pin] + inner: Option, +} + +impl GetPooledTxRequestFut { + fn new( + peer_id: PeerId, + response: oneshot::Receiver>, + ) -> Self { + Self { inner: Some(GetPooledTxRequest { peer_id, response }) } + } +} + +impl Future for GetPooledTxRequestFut { + type Output = GetPooledTxResponse; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut req = self.as_mut().project().inner.take().expect("polled after completion"); + match req.response.poll_unpin(cx) { + Poll::Ready(result) => { + Poll::Ready(GetPooledTxResponse { peer_id: req.peer_id, result }) + } + Poll::Pending => { + self.project().inner.set(Some(req)); + Poll::Pending + } + } + } +} + /// Tracks a single peer struct Peer { /// Keeps track of transactions that we know the peer has seen.