diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs index 3a72504e93..9bd4afa82e 100644 --- a/crates/net/eth-wire/src/capability.rs +++ b/crates/net/eth-wire/src/capability.rs @@ -376,17 +376,24 @@ impl SharedCapabilities { /// Returns the matching shared capability for the given capability offset. /// - /// `offset` is the multiplexed message id offset of the capability relative to - /// [`MAX_RESERVED_MESSAGE_ID`]. + /// `offset` is the multiplexed message id offset of the capability relative to the reserved + /// message id space. In other words, counting starts at [`MAX_RESERVED_MESSAGE_ID`] + 1, which + /// corresponds to the first non-reserved message id. + /// + /// For example: `offset == 0` corresponds to the first shared message across the shared + /// capabilities and will return the first shared capability that supports messages. #[inline] pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> { - self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID)) + self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1)) } /// Returns the matching shared capability for the given capability offset. /// /// `offset` is the multiplexed message id offset of the capability that includes the reserved /// message id space. + /// + /// This will always return None if `offset` is less than or equal to + /// [`MAX_RESERVED_MESSAGE_ID`] because the reserved message id space is not shared. #[inline] pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> { let mut iter = self.0.iter(); @@ -637,12 +644,14 @@ mod tests { let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap(); - assert!(shared.find_by_relative_offset(0).is_none()); - let shared_eth = shared.find_by_relative_offset(1).unwrap(); + let shared_eth = shared.find_by_relative_offset(0).unwrap(); assert_eq!(shared_eth.name(), "eth"); let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap(); assert_eq!(shared_eth.name(), "eth"); + + // reserved message id space + assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none()); } #[test] @@ -654,15 +663,14 @@ mod tests { let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap(); - assert!(shared.find_by_relative_offset(0).is_none()); - let shared_eth = shared.find_by_relative_offset(1).unwrap(); + let shared_eth = shared.find_by_relative_offset(0).unwrap(); assert_eq!(shared_eth.name(), proto.cap.name); let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap(); assert_eq!(shared_eth.name(), proto.cap.name); - // the 5th shared message is the last message of the aaa capability - let shared_eth = shared.find_by_relative_offset(5).unwrap(); + // the 5th shared message (0,1,2,3,4) is the last message of the aaa capability + let shared_eth = shared.find_by_relative_offset(4).unwrap(); assert_eq!(shared_eth.name(), proto.cap.name); let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap(); assert_eq!(shared_eth.name(), proto.cap.name); diff --git a/crates/net/eth-wire/src/multiplex.rs b/crates/net/eth-wire/src/multiplex.rs index d0dcf467e5..e3bb92ecad 100644 --- a/crates/net/eth-wire/src/multiplex.rs +++ b/crates/net/eth-wire/src/multiplex.rs @@ -65,15 +65,16 @@ impl RlpxProtocolMultiplexer { mut self, cap: &Capability, handshake: F, - ) -> Result, Self> + ) -> Result, Err> where F: FnOnce(ProtocolProxy) -> Fut, Fut: Future>, St: Stream> + Sink + Unpin, + P2PStreamError: Into, { let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned() else { - return Err(self) + return Err(P2PStreamError::CapabilityNotShared.into()) }; let (to_primary, from_wire) = mpsc::unbounded_channel(); @@ -87,20 +88,36 @@ impl RlpxProtocolMultiplexer { let f = handshake(proxy); pin_mut!(f); - // handle messages until the handshake is complete + // this polls the connection and the primary stream concurrently until the handshake is + // complete loop { - // TODO error handling tokio::select! { Some(Ok(msg)) = self.conn.next() => { - // TODO handle multiplex - let _ = to_primary.send(msg); + // Ensure the message belongs to the primary protocol + let offset = msg[0]; + if let Some(cap) = self.conn.shared_capabilities().find_by_relative_offset(offset) { + if cap == &shared_cap { + // delegate to primary + let _ = to_primary.send(msg); + } else { + // delegate to satellite + for proto in &self.protocols { + if proto.cap == *cap { + // TODO: need some form of backpressure here so buffering can't be abused + proto.send_raw(msg); + break + } + } + } + } else { + return Err(P2PStreamError::UnknownReservedMessageId(offset).into()) + } } Some(msg) = from_primary.recv() => { - // TODO error handling - self.conn.send(msg).await.unwrap(); + self.conn.send(msg).await.map_err(Into::into)?; } res = &mut f => { - let Ok(primary) = res else { return Err(self) }; + let primary = res?; return Ok(RlpxSatelliteStream { conn: self.conn, to_primary, @@ -117,24 +134,47 @@ impl RlpxProtocolMultiplexer { } /// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth") +/// +/// Only emits and sends _non-empty_ messages #[derive(Debug)] pub struct ProtocolProxy { cap: SharedCapability, + /// Receives _non-empty_ messages from the wire from_wire: UnboundedReceiverStream, + /// Sends _non-empty_ messages from the wire to_wire: UnboundedSender, } impl ProtocolProxy { + /// Sends a _non-empty_ message on the wire. + fn try_send(&self, msg: Bytes) -> Result<(), io::Error> { + if msg.is_empty() { + // message must not be empty + return Err(io::ErrorKind::InvalidInput.into()) + } + self.to_wire.send(self.mask_msg_id(msg)).map_err(|_| io::ErrorKind::BrokenPipe.into()) + } + + /// Masks the message ID of a message to be sent on the wire. + /// + /// # Panics + /// + /// If the message is empty. + #[inline] fn mask_msg_id(&self, msg: Bytes) -> Bytes { - // TODO handle empty messages let mut masked_bytes = BytesMut::zeroed(msg.len()); masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset(); masked_bytes[1..].copy_from_slice(&msg[1..]); masked_bytes.freeze() } + /// Unmasks the message ID of a message received from the wire. + /// + /// # Panics + /// + /// If the message is empty. + #[inline] fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { - // TODO handle empty messages msg[0] -= self.cap.relative_message_id_offset(); msg } @@ -157,8 +197,7 @@ impl Sink for ProtocolProxy { } fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - let msg = self.mask_msg_id(item); - self.to_wire.send(msg).map_err(|_| io::ErrorKind::BrokenPipe.into()) + self.get_mut().try_send(item) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { @@ -181,7 +220,7 @@ impl CanDisconnect for ProtocolProxy { } } -/// A connection channel to receive messages for the negotiated protocol. +/// A connection channel to receive _non_empty_ messages for the negotiated protocol. /// /// This is a [Stream] that returns raw bytes of the received messages for this protocol. #[derive(Debug)] @@ -287,34 +326,28 @@ where Poll::Ready(Some(Ok(msg))) => { delegated = true; let offset = msg[0]; - // find the protocol that matches the offset - // TODO optimize this by keeping a better index - let mut lowest_satellite = None; - // find the protocol with the lowest offset that is greater than the message - // offset - for (i, proto) in this.satellites.iter().enumerate() { - let proto_offset = proto.cap.relative_message_id_offset(); - if proto_offset >= offset { - if let Some((_, lowest_offset)) = lowest_satellite { - if proto_offset < lowest_offset { - lowest_satellite = Some((i, proto_offset)); + // delegate the multiplexed message to the correct protocol + if let Some(cap) = + this.conn.shared_capabilities().find_by_relative_offset(offset) + { + if cap == &this.primary_capability { + // delegate to primary + let _ = this.to_primary.send(msg); + } else { + // delegate to satellite + for proto in &this.satellites { + if proto.cap == *cap { + proto.send_raw(msg); + break } - } else { - lowest_satellite = Some((i, proto_offset)); } } + } else { + return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId( + offset, + ) + .into()))) } - - if let Some((idx, lowest_offset)) = lowest_satellite { - if lowest_offset < this.primary_capability.relative_message_id_offset() - { - // delegate to satellite - this.satellites[idx].send_raw(msg); - continue - } - } - // delegate to primary - let _ = this.to_primary.send(msg); } Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))), Poll::Ready(None) => { @@ -373,18 +406,29 @@ struct ProtocolStream { } impl ProtocolStream { + /// Masks the message ID of a message to be sent on the wire. + /// + /// # Panics + /// + /// If the message is empty. + #[inline] fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes { - // TODO handle empty messages msg[0] += self.cap.relative_message_id_offset(); msg.freeze() } + /// Unmasks the message ID of a message received from the wire. + /// + /// # Panics + /// + /// If the message is empty. + #[inline] fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { - // TODO handle empty messages msg[0] -= self.cap.relative_message_id_offset(); msg } + /// Sends the message to the satellite stream. fn send_raw(&self, msg: BytesMut) { let _ = self.to_satellite.send(self.unmask_id(msg)); } @@ -396,7 +440,7 @@ impl Stream for ProtocolStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let msg = ready!(this.satellite_st.as_mut().poll_next(cx)); - Poll::Ready(msg.map(|msg| this.mask_msg_id(msg))) + Poll::Ready(msg.filter(|msg| !msg.is_empty()).map(|msg| this.mask_msg_id(msg))) } } @@ -408,15 +452,13 @@ impl fmt::Debug for ProtocolStream { #[cfg(test)] mod tests { - use tokio::net::TcpListener; - use tokio_util::codec::Decoder; - + use super::*; use crate::{ test_utils::{connect_passthrough, eth_handshake, eth_hello}, UnauthedEthStream, UnauthedP2PStream, }; - - use super::*; + use tokio::net::TcpListener; + use tokio_util::codec::Decoder; #[tokio::test] async fn eth_satellite() { diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index ed6001fb93..8d407c4872 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -228,9 +228,10 @@ where /// /// See also /// -/// This stream emits Bytes that start with the normalized message id, so that the first byte of -/// each message starts from 0. If this stream only supports a single capability, for example `eth` -/// then the first byte of each message will match [EthMessageID](crate::types::EthMessageID). +/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first +/// byte of each message starts from 0. If this stream only supports a single capability, for +/// example `eth` then the first byte of each message will match +/// [EthMessageID](crate::types::EthMessageID). #[pin_project] #[derive(Debug)] pub struct P2PStream { @@ -405,6 +406,11 @@ where None => return Poll::Ready(None), }; + if bytes.is_empty() { + // empty messages are not allowed + return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage))) + } + // first check that the compressed message length does not exceed the max // payload size let decompressed_len = snap::raw::decompress_len(&bytes[1..])?; @@ -430,7 +436,7 @@ where err })?; - let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?; + let id = bytes[0]; match id { _ if id == P2PMessageID::Ping as u8 => { trace!("Received Ping, Sending Pong");