diff --git a/crates/net/eth-wire/src/multiplex.rs b/crates/net/eth-wire/src/multiplex.rs index d44f5ea7eb..9eb4f15f0b 100644 --- a/crates/net/eth-wire/src/multiplex.rs +++ b/crates/net/eth-wire/src/multiplex.rs @@ -13,15 +13,17 @@ use std::{ future::Future, io, pin::{pin, Pin}, + sync::Arc, task::{ready, Context, Poll}, }; use crate::{ capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError}, errors::{EthStreamError, P2PStreamError}, + handshake::EthRlpxHandshake, p2pstream::DisconnectP2P, - CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnauthedEthStream, - UnifiedStatus, + CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus, + HANDSHAKE_TIMEOUT, }; use bytes::{Bytes, BytesMut}; use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; @@ -135,7 +137,7 @@ impl RlpxProtocolMultiplexer { /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the /// primary stream and extra data. /// - /// See also [`UnauthedEthStream::handshake`] + /// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream) pub async fn into_satellite_stream_with_tuple_handshake( mut self, cap: &Capability, @@ -167,6 +169,7 @@ impl RlpxProtocolMultiplexer { // complete loop { tokio::select! { + biased; Some(Ok(msg)) = self.inner.conn.next() => { // Ensure the message belongs to the primary protocol let Some(offset) = msg.first().copied() @@ -188,6 +191,10 @@ impl RlpxProtocolMultiplexer { Some(msg) = from_primary.recv() => { self.inner.conn.send(msg).await.map_err(Into::into)?; } + // Poll all subprotocols for new messages + msg = ProtocolsPoller::new(&mut self.inner.protocols) => { + self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?; + } res = &mut f => { let (st, extra) = res?; return Ok((RlpxSatelliteStream { @@ -205,22 +212,28 @@ impl RlpxProtocolMultiplexer { } /// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given - /// primary protocol. + /// primary protocol and the handshake implementation. pub async fn into_eth_satellite_stream( self, status: UnifiedStatus, fork_filter: ForkFilter, + handshake: Arc, ) -> Result<(RlpxSatelliteStream>, UnifiedStatus), EthStreamError> where St: Stream> + Sink + Unpin, { let eth_cap = self.inner.conn.shared_capabilities().eth_version()?; - self.into_satellite_stream_with_tuple_handshake( - &Capability::eth(eth_cap), - move |proxy| async move { - UnauthedEthStream::new(proxy).handshake(status, fork_filter).await - }, - ) + self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| { + let handshake = handshake.clone(); + async move { + let mut unauth = UnauthProxy { inner: proxy }; + let their_status = handshake + .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT) + .await?; + let eth_stream = EthStream::new(eth_cap, unauth.into_inner()); + Ok((eth_stream, their_status)) + } + }) .await } } @@ -377,6 +390,57 @@ impl CanDisconnect for ProtocolProxy { } } +/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy` +/// using the same error type expectations (`P2PStreamError`). +#[derive(Debug)] +struct UnauthProxy { + inner: ProtocolProxy, +} + +impl UnauthProxy { + fn into_inner(self) -> ProtocolProxy { + self.inner + } +} + +impl Stream for UnauthProxy { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from))) + } +} + +impl Sink for UnauthProxy { + type Error = P2PStreamError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.inner.start_send_unpin(item).map_err(P2PStreamError::from) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from) + } +} + +impl CanDisconnect for UnauthProxy { + fn disconnect( + &mut self, + reason: DisconnectReason, + ) -> Pin>::Error>> + Send + '_>> { + let fut = self.inner.disconnect(reason); + Box::pin(async move { fut.await.map_err(P2PStreamError::from) }) + } +} + /// A connection channel to receive _`non_empty`_ messages for the negotiated protocol. /// /// This is a [Stream] that returns raw bytes of the received messages for this protocol. @@ -666,15 +730,56 @@ impl fmt::Debug for ProtocolStream { } } +/// Helper to poll multiple protocol streams in a `tokio::select`! branch +struct ProtocolsPoller<'a> { + protocols: &'a mut Vec, +} + +impl<'a> ProtocolsPoller<'a> { + const fn new(protocols: &'a mut Vec) -> Self { + Self { protocols } + } +} + +impl<'a> Future for ProtocolsPoller<'a> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Process protocols in reverse order, like the existing pattern + for idx in (0..self.protocols.len()).rev() { + let mut proto = self.protocols.swap_remove(idx); + match proto.poll_next_unpin(cx) { + Poll::Ready(Some(Err(err))) => { + self.protocols.push(proto); + return Poll::Ready(Err(P2PStreamError::from(err))) + } + Poll::Ready(Some(Ok(msg))) => { + // Got a message, put protocol back and return the message + self.protocols.push(proto); + return Poll::Ready(Ok(msg)); + } + _ => { + // push it back because we still want to complete the handshake first + self.protocols.push(proto); + } + } + } + + // All protocols processed, nothing ready + Poll::Pending + } +} + #[cfg(test)] mod tests { use super::*; use crate::{ + handshake::EthHandshake, test_utils::{ connect_passthrough, eth_handshake, eth_hello, proto::{test_hello, TestProtoMessage}, }, - UnauthedP2PStream, + UnauthedEthStream, UnauthedP2PStream, }; use reth_eth_wire_types::EthNetworkPrimitives; use tokio::{net::TcpListener, sync::oneshot}; @@ -736,7 +841,11 @@ mod tests { let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap(); let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn) - .into_eth_satellite_stream::(other_status, other_fork_filter) + .into_eth_satellite_stream::( + other_status, + other_fork_filter, + Arc::new(EthHandshake::default()), + ) .await .unwrap(); @@ -767,7 +876,11 @@ mod tests { let conn = connect_passthrough(local_addr, test_hello().0).await; let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn) - .into_eth_satellite_stream::(status, fork_filter) + .into_eth_satellite_stream::( + status, + fork_filter, + Arc::new(EthHandshake::default()), + ) .await .unwrap(); diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index e94376948c..c6bdb198b1 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -1150,18 +1150,20 @@ async fn authenticate_stream( .ok(); } - let (multiplex_stream, their_status) = - match multiplex_stream.into_eth_satellite_stream(status, fork_filter).await { - Ok((multiplex_stream, their_status)) => (multiplex_stream, their_status), - Err(err) => { - return PendingSessionEvent::Disconnected { - remote_addr, - session_id, - direction, - error: Some(PendingSessionHandshakeError::Eth(err)), - } + let (multiplex_stream, their_status) = match multiplex_stream + .into_eth_satellite_stream(status, fork_filter, handshake) + .await + { + Ok((multiplex_stream, their_status)) => (multiplex_stream, their_status), + Err(err) => { + return PendingSessionEvent::Disconnected { + remote_addr, + session_id, + direction, + error: Some(PendingSessionHandshakeError::Eth(err)), } - }; + } + }; (multiplex_stream.into(), their_status) };