From cd4d6c52b0d9c501dcf340052623d29592df168c Mon Sep 17 00:00:00 2001 From: Emilia Hane Date: Fri, 8 Dec 2023 09:21:01 +0100 Subject: [PATCH] Cap mux simple (#5577) Signed-off-by: Emilia Hane --- Cargo.lock | 2 + crates/net/eth-wire/Cargo.toml | 3 + crates/net/eth-wire/src/capability.rs | 41 +- crates/net/eth-wire/src/disconnect.rs | 2 +- crates/net/eth-wire/src/errors/eth.rs | 9 +- crates/net/eth-wire/src/errors/mod.rs | 2 + crates/net/eth-wire/src/errors/muxdemux.rs | 47 ++ crates/net/eth-wire/src/ethstream.rs | 2 +- crates/net/eth-wire/src/lib.rs | 6 +- crates/net/eth-wire/src/muxdemux.rs | 592 +++++++++++++++++++++ crates/net/eth-wire/src/p2pstream.rs | 20 +- crates/net/network/src/session/active.rs | 2 +- 12 files changed, 703 insertions(+), 25 deletions(-) create mode 100644 crates/net/eth-wire/src/errors/muxdemux.rs create mode 100644 crates/net/eth-wire/src/muxdemux.rs diff --git a/Cargo.lock b/Cargo.lock index eb7fc739da..1a19bc7ab7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5898,6 +5898,7 @@ dependencies = [ "arbitrary", "async-trait", "bytes", + "derive_more", "ethers-core", "futures", "metrics", @@ -5909,6 +5910,7 @@ dependencies = [ "reth-discv4", "reth-ecies", "reth-metrics", + "reth-net-common", "reth-primitives", "reth-tracing", "secp256k1 0.27.0", diff --git a/crates/net/eth-wire/Cargo.toml b/crates/net/eth-wire/Cargo.toml index a5f691ef38..88af65762b 100644 --- a/crates/net/eth-wire/Cargo.toml +++ b/crates/net/eth-wire/Cargo.toml @@ -21,6 +21,7 @@ reth-metrics.workspace = true metrics.workspace = true bytes.workspace = true +derive_more = "0.99.17" thiserror.workspace = true serde = { workspace = true, optional = true } tokio = { workspace = true, features = ["full"] } @@ -38,10 +39,12 @@ proptest = { workspace = true, optional = true } proptest-derive = { workspace = true, optional = true } [dev-dependencies] +reth-net-common.workspace = true reth-primitives = { workspace = true, features = ["arbitrary"] } reth-tracing.workspace = true ethers-core = { workspace = true, default-features = false } + test-fuzz = "4" tokio-util = { workspace = true, features = ["io", "codec"] } rand.workspace = true diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs index c1e6d84ab6..8696109c68 100644 --- a/crates/net/eth-wire/src/capability.rs +++ b/crates/net/eth-wire/src/capability.rs @@ -8,6 +8,7 @@ use crate::{ EthMessage, EthMessageID, EthVersion, }; use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable}; +use derive_more::{Deref, DerefMut}; use reth_codecs::add_arbitrary_tests; use reth_primitives::bytes::{BufMut, Bytes}; #[cfg(feature = "serde")] @@ -249,14 +250,23 @@ pub enum SharedCapability { /// This represents the message ID offset for the first message of the eth capability in /// the message id space. offset: u8, + /// The number of messages of this capability. Needed to calculate range of message IDs in + /// demuxing. + messages: u8, }, } impl SharedCapability { - /// Creates a new [`SharedCapability`] based on the given name, offset, and version. + /// Creates a new [`SharedCapability`] based on the given name, offset, version (and messages + /// if the capability is custom). /// /// Returns an error if the offset is equal or less than [`MAX_RESERVED_MESSAGE_ID`]. - pub(crate) fn new(name: &str, version: u8, offset: u8) -> Result { + pub(crate) fn new( + name: &str, + version: u8, + offset: u8, + messages: u8, + ) -> Result { if offset <= MAX_RESERVED_MESSAGE_ID { return Err(SharedCapabilityError::ReservedMessageIdOffset(offset)) } @@ -266,6 +276,7 @@ impl SharedCapability { _ => Ok(Self::UnknownCapability { cap: Capability::new(name.to_string(), version as usize), offset, + messages, }), } } @@ -324,10 +335,10 @@ impl SharedCapability { } /// Returns the number of protocol messages supported by this capability. - pub fn num_messages(&self) -> Result { + pub fn num_messages(&self) -> u8 { match self { - SharedCapability::Eth { version: _version, .. } => Ok(EthMessageID::max() + 1), - _ => Err(SharedCapabilityError::UnknownCapability), + SharedCapability::Eth { version: _version, .. } => EthMessageID::max() + 1, + SharedCapability::UnknownCapability { messages, .. } => *messages, } } } @@ -335,7 +346,7 @@ impl SharedCapability { /// Non-empty,ordered list of recognized shared capabilities. /// /// Shared capabilities are ordered alphabetically by case sensitive name. -#[derive(Debug)] +#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)] pub struct SharedCapabilities(Vec); impl SharedCapabilities { @@ -500,9 +511,14 @@ pub fn shared_capability_offsets( for name in shared_capability_names { let proto_version = shared_capabilities.get(&name).expect("shared; qed"); - let shared_capability = SharedCapability::new(&name, proto_version.version as u8, offset)?; + let shared_capability = SharedCapability::new( + &name, + proto_version.version as u8, + offset, + proto_version.messages, + )?; - offset += proto_version.messages; + offset += shared_capability.num_messages(); shared_with_offsets.push(shared_capability); } @@ -519,9 +535,6 @@ pub enum SharedCapabilityError { /// Unsupported `eth` version. #[error(transparent)] UnsupportedVersion(#[from] ParseVersionError), - /// Cannot determine the number of messages for unknown capabilities. - #[error("cannot determine the number of messages for unknown capabilities")] - UnknownCapability, /// Thrown when the message id for a [SharedCapability] overlaps with the reserved p2p message /// id space [`MAX_RESERVED_MESSAGE_ID`]. #[error("message id offset `{0}` is reserved")] @@ -541,7 +554,7 @@ mod tests { #[test] fn from_eth_68() { - let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1).unwrap(); + let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap(); assert_eq!(capability.name(), "eth"); assert_eq!(capability.version(), 68); @@ -556,7 +569,7 @@ mod tests { #[test] fn from_eth_67() { - let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1).unwrap(); + let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap(); assert_eq!(capability.name(), "eth"); assert_eq!(capability.version(), 67); @@ -571,7 +584,7 @@ mod tests { #[test] fn from_eth_66() { - let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1).unwrap(); + let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1, 15).unwrap(); assert_eq!(capability.name(), "eth"); assert_eq!(capability.version(), 66); diff --git a/crates/net/eth-wire/src/disconnect.rs b/crates/net/eth-wire/src/disconnect.rs index aa3c6d220e..e03f99f07f 100644 --- a/crates/net/eth-wire/src/disconnect.rs +++ b/crates/net/eth-wire/src/disconnect.rs @@ -150,7 +150,7 @@ impl Decodable for DisconnectReason { /// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the /// underlying stream supports it. #[async_trait::async_trait] -pub trait CanDisconnect: Sink + Unpin + Sized { +pub trait CanDisconnect: Sink + Unpin { /// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect /// information if the stream implements a protocol that can carry the additional disconnect /// metadata. diff --git a/crates/net/eth-wire/src/errors/eth.rs b/crates/net/eth-wire/src/errors/eth.rs index 21645def4a..e05fa4e52b 100644 --- a/crates/net/eth-wire/src/errors/eth.rs +++ b/crates/net/eth-wire/src/errors/eth.rs @@ -1,6 +1,8 @@ //! Error handling for (`EthStream`)[crate::EthStream] use crate::{ - errors::P2PStreamError, version::ParseVersionError, DisconnectReason, EthMessageID, EthVersion, + errors::{MuxDemuxError, P2PStreamError}, + version::ParseVersionError, + DisconnectReason, EthMessageID, EthVersion, }; use reth_primitives::{Chain, GotExpected, GotExpectedBoxed, ValidationError, B256}; use std::io; @@ -13,6 +15,9 @@ pub enum EthStreamError { /// Error of the underlying P2P connection. P2PStreamError(#[from] P2PStreamError), #[error(transparent)] + /// Error of the underlying de-/muxed P2P connection. + MuxDemuxError(#[from] MuxDemuxError), + #[error(transparent)] /// Failed to parse peer's version. ParseVersionError(#[from] ParseVersionError), #[error(transparent)] @@ -43,6 +48,8 @@ impl EthStreamError { pub fn as_disconnected(&self) -> Option { if let EthStreamError::P2PStreamError(err) = self { err.as_disconnected() + } else if let EthStreamError::MuxDemuxError(MuxDemuxError::P2PStreamError(err)) = self { + err.as_disconnected() } else { None } diff --git a/crates/net/eth-wire/src/errors/mod.rs b/crates/net/eth-wire/src/errors/mod.rs index be3f8ced7f..c231e48608 100644 --- a/crates/net/eth-wire/src/errors/mod.rs +++ b/crates/net/eth-wire/src/errors/mod.rs @@ -1,7 +1,9 @@ //! Error types for stream variants mod eth; +mod muxdemux; mod p2p; pub use eth::*; +pub use muxdemux::*; pub use p2p::*; diff --git a/crates/net/eth-wire/src/errors/muxdemux.rs b/crates/net/eth-wire/src/errors/muxdemux.rs new file mode 100644 index 0000000000..74ca6e2fcf --- /dev/null +++ b/crates/net/eth-wire/src/errors/muxdemux.rs @@ -0,0 +1,47 @@ +use thiserror::Error; + +use crate::capability::{SharedCapabilityError, UnsupportedCapabilityError}; + +use super::P2PStreamError; + +/// Errors thrown by de-/muxing. +#[derive(Error, Debug)] +pub enum MuxDemuxError { + /// Error of the underlying P2P connection. + #[error(transparent)] + P2PStreamError(#[from] P2PStreamError), + /// Stream is in use by secondary stream impeding disconnect. + #[error("secondary streams are still running")] + StreamInUse, + /// Stream has already been set up for this capability stream type. + #[error("stream already init for stream type")] + StreamAlreadyExists, + /// Capability stream type is not shared with peer on underlying p2p connection. + #[error("stream type is not shared on this p2p connection")] + CapabilityNotShared, + /// Capability stream type has not been configured in [`crate::muxdemux::MuxDemuxer`]. + #[error("stream type is not configured")] + CapabilityNotConfigured, + /// Capability stream type has not been configured for + /// [`crate::capability::SharedCapabilities`] type. + #[error("stream type is not recognized")] + CapabilityNotRecognized, + /// Message ID is out of range. + #[error("message id out of range, {0}")] + MessageIdOutOfRange(u8), + /// Demux channel failed. + #[error("sending demuxed bytes to secondary stream failed")] + SendIngressBytesFailed, + /// Mux channel failed. + #[error("sending bytes from secondary stream to mux failed")] + SendEgressBytesFailed, + /// Attempt to disconnect the p2p stream via a stream clone. + #[error("secondary stream cannot disconnect p2p stream")] + CannotDisconnectP2PStream, + /// Shared capability error. + #[error(transparent)] + SharedCapabilityError(#[from] SharedCapabilityError), + /// Capability not supported on the p2p connection. + #[error(transparent)] + UnsupportedCapabilityError(#[from] UnsupportedCapabilityError), +} diff --git a/crates/net/eth-wire/src/ethstream.rs b/crates/net/eth-wire/src/ethstream.rs index 23f64040a4..f1162f7ee4 100644 --- a/crates/net/eth-wire/src/ethstream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -283,7 +283,7 @@ where // start_disconnect, which would ideally be a part of the CanDisconnect trait, or at // least similar. // - // Other parts of reth do not need traits like CanDisconnect because they work + // Other parts of reth do not yet need traits like CanDisconnect because atm they work // exclusively with EthStream>, where the inner P2PStream is accessible, // allowing for its start_disconnect method to be called. // diff --git a/crates/net/eth-wire/src/lib.rs b/crates/net/eth-wire/src/lib.rs index a1c7c70bdf..c090deb5f6 100644 --- a/crates/net/eth-wire/src/lib.rs +++ b/crates/net/eth-wire/src/lib.rs @@ -21,6 +21,7 @@ pub mod errors; mod ethstream; mod hello; pub mod multiplex; +pub mod muxdemux; mod p2pstream; mod pinger; pub mod protocol; @@ -37,11 +38,14 @@ pub use tokio_util::codec::{ }; pub use crate::{ + capability::Capability, disconnect::{CanDisconnect, DisconnectReason}, ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE}, hello::{HelloMessage, HelloMessageBuilder, HelloMessageWithProtocols}, + muxdemux::{MuxDemuxStream, StreamClone}, p2pstream::{ - P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream, + DisconnectP2P, P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream, MAX_RESERVED_MESSAGE_ID, }, + types::EthVersion, }; diff --git a/crates/net/eth-wire/src/muxdemux.rs b/crates/net/eth-wire/src/muxdemux.rs new file mode 100644 index 0000000000..621eb64d9b --- /dev/null +++ b/crates/net/eth-wire/src/muxdemux.rs @@ -0,0 +1,592 @@ +//! [`MuxDemuxer`] allows for multiple capability streams to share the same p2p connection. De-/ +//! muxing the connection offers two stream types [`MuxDemuxStream`] and [`StreamClone`]. +//! [`MuxDemuxStream`] is the main stream that wraps the p2p connection, only this stream can +//! advance transfer across the network. One [`MuxDemuxStream`] can have many [`StreamClone`]s, +//! these are weak clones of the stream and depend on advancing the [`MuxDemuxStream`] to make +//! progress. +//! +//! [`MuxDemuxer`] filters bytes according to message ID offset. The message ID offset is +//! negotiated upon start of the p2p connection. Bytes received by polling the [`MuxDemuxStream`] +//! or a [`StreamClone`] are specific to the capability stream wrapping it. When received the +//! message IDs are unmasked so that all message IDs start at 0x0. [`MuxDemuxStream`] and +//! [`StreamClone`] mask message IDs before sinking bytes to the [`MuxDemuxer`]. +//! +//! For example, `EthStream>>` is the main capability stream. +//! Subsequent capability streams clone the p2p connection via EthStream. +//! +//! When [`MuxDemuxStream`] is polled, [`MuxDemuxer`] receives bytes from the network. If these +//! bytes belong to the capability stream wrapping the [`MuxDemuxStream`] then they are passed up +//! directly. If these bytes however belong to another capability stream, then they are buffered +//! on a channel. When [`StreamClone`] is polled, bytes are read from this buffer. Similarly +//! [`StreamClone`] buffers egress bytes for [`MuxDemuxer`] that are read and sent to the network +//! when [`MuxDemuxStream`] is polled. + +use std::{ + collections::HashMap, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use derive_more::{Deref, DerefMut}; +use futures::{Sink, SinkExt, StreamExt}; +use reth_primitives::bytes::{Bytes, BytesMut}; +use tokio::sync::mpsc; +use tokio_stream::Stream; + +use crate::{ + capability::{Capability, SharedCapabilities, SharedCapability}, + errors::MuxDemuxError, + CanDisconnect, DisconnectP2P, DisconnectReason, +}; + +use MuxDemuxError::*; + +/// Stream MUX/DEMUX acts like a regular stream and sink for the owning stream, and handles bytes +/// belonging to other streams over their respective channels. +#[derive(Debug)] +pub struct MuxDemuxer { + // receive and send muxed p2p outputs + inner: S, + // owner of the stream. stores message id offset for this capability. + owner: SharedCapability, + // receive muxed p2p inputs from stream clones + mux: mpsc::UnboundedReceiver, + // send demuxed p2p outputs to app + demux: HashMap>, + // sender to mux stored to make new stream clones + mux_tx: mpsc::UnboundedSender, + // capabilities supported by underlying p2p stream (makes testing easier to store here too). + shared_capabilities: SharedCapabilities, +} + +/// The main stream on top of the p2p stream. Wraps [`MuxDemuxer`] and enforces it can't be dropped +/// before all secondary streams are dropped (stream clones). +#[derive(Debug, Deref, DerefMut)] +pub struct MuxDemuxStream(MuxDemuxer); + +impl MuxDemuxStream { + /// Creates a new [`MuxDemuxer`]. + pub fn try_new( + inner: S, + cap: Capability, + shared_capabilities: SharedCapabilities, + ) -> Result { + let owner = Self::shared_cap(&cap, &shared_capabilities)?.clone(); + + let demux = HashMap::new(); + let (mux_tx, mux) = mpsc::unbounded_channel(); + + Ok(Self(MuxDemuxer { inner, owner, mux, demux, mux_tx, shared_capabilities })) + } + + /// Clones the stream if the given capability stream type is shared on the underlying p2p + /// connection. + pub fn try_clone_stream(&mut self, cap: &Capability) -> Result { + let cap = self.shared_capabilities.ensure_matching_capability(cap)?.clone(); + let ingress = self.reg_new_ingress_buffer(&cap)?; + let mux_tx = self.mux_tx.clone(); + + Ok(StreamClone { stream: ingress, sink: mux_tx, cap }) + } + + /// Starts a graceful disconnect. + pub fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), MuxDemuxError> + where + S: DisconnectP2P, + { + if !self.can_drop() { + return Err(StreamInUse) + } + + self.inner.start_disconnect(reason).map_err(|e| e.into()) + } + + /// Returns `true` if the connection is about to disconnect. + pub fn is_disconnecting(&self) -> bool + where + S: DisconnectP2P, + { + self.inner.is_disconnecting() + } + + /// Shared capabilities of underlying p2p connection as negotiated by peers at connection + /// open. + pub fn shared_capabilities(&self) -> &SharedCapabilities { + &self.shared_capabilities + } + + fn shared_cap<'a>( + cap: &Capability, + shared_capabilities: &'a SharedCapabilities, + ) -> Result<&'a SharedCapability, MuxDemuxError> { + for shared_cap in shared_capabilities.iter_caps() { + match shared_cap { + SharedCapability::Eth { .. } if cap.is_eth() => return Ok(shared_cap), + SharedCapability::UnknownCapability { cap: unknown_cap, .. } + if cap == unknown_cap => + { + return Ok(shared_cap) + } + _ => continue, + } + } + + Err(CapabilityNotShared) + } + + fn reg_new_ingress_buffer( + &mut self, + cap: &SharedCapability, + ) -> Result, MuxDemuxError> { + if let Some(tx) = self.demux.get(cap) { + if !tx.is_closed() { + return Err(StreamAlreadyExists) + } + } + let (ingress_tx, ingress) = mpsc::unbounded_channel(); + self.demux.insert(cap.clone(), ingress_tx); + + Ok(ingress) + } + + fn unmask_msg_id(&self, id: &mut u8) -> Result<&SharedCapability, MuxDemuxError> { + for cap in self.shared_capabilities.iter_caps() { + let offset = cap.relative_message_id_offset(); + let next_offset = offset + cap.num_messages(); + if *id < next_offset { + *id -= offset; + return Ok(cap) + } + } + + Err(MessageIdOutOfRange(*id)) + } + + /// Masks message id with offset relative to the message id suffix reserved for capability + /// message ids. The p2p stream further masks the message id (todo: mask whole message id at + /// once to avoid copying message to mutate id byte or sink BytesMut). + fn mask_msg_id(&self, msg: Bytes) -> Bytes { + let mut masked_bytes = BytesMut::zeroed(msg.len()); + masked_bytes[0] = msg[0] + self.owner.relative_message_id_offset(); + masked_bytes[1..].copy_from_slice(&msg[1..]); + + masked_bytes.freeze() + } + + /// Checks if all clones of this shared stream have been dropped, if true then returns // + /// function to drop the stream. + fn can_drop(&mut self) -> bool { + for tx in self.demux.values() { + if !tx.is_closed() { + return false + } + } + + true + } +} + +impl Stream for MuxDemuxStream +where + S: Stream> + CanDisconnect + Unpin, + MuxDemuxError: From + From<>::Error>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut send_count = 0; + let mut mux_exhausted = false; + + loop { + // send buffered bytes from `StreamClone`s. try send at least as many messages as + // there are stream clones. + if self.inner.poll_ready_unpin(cx).is_ready() { + if let Poll::Ready(Some(item)) = self.mux.poll_recv(cx) { + self.inner.start_send_unpin(item)?; + if send_count < self.demux.len() { + send_count += 1; + continue + } + } else { + mux_exhausted = true; + } + } + + // advances the wire and either yields message for the owner or delegates message to a + // stream clone + let res = self.inner.poll_next_unpin(cx); + if res.is_pending() { + // no message is received. continue to send messages from stream clones as long as + // there are messages left to send. + if !mux_exhausted && self.inner.poll_ready_unpin(cx).is_ready() { + continue + } + // flush before returning pending + _ = self.inner.poll_flush_unpin(cx)?; + } + let mut bytes = match ready!(res) { + Some(Ok(bytes)) => bytes, + Some(Err(err)) => { + _ = self.inner.poll_flush_unpin(cx)?; + return Poll::Ready(Some(Err(err.into()))) + } + None => { + _ = self.inner.poll_flush_unpin(cx)?; + return Poll::Ready(None) + } + }; + + // normalize message id suffix for capability + let cap = self.unmask_msg_id(&mut bytes[0])?; + + // yield message for main stream + if *cap == self.owner { + _ = self.inner.poll_flush_unpin(cx)?; + return Poll::Ready(Some(Ok(bytes))) + } + + // delegate message for stream clone + let tx = self.demux.get(cap).ok_or(CapabilityNotConfigured)?; + tx.send(bytes).map_err(|_| SendIngressBytesFailed)?; + } + } +} + +impl Sink for MuxDemuxStream +where + S: Sink + CanDisconnect + Unpin, + MuxDemuxError: From, +{ + type Error = MuxDemuxError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready_unpin(cx).map_err(Into::into) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let item = self.mask_msg_id(item); + self.inner.start_send_unpin(item).map_err(|e| e.into()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush_unpin(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Ok(item) = self.mux.try_recv() { + self.inner.start_send_unpin(item)?; + } + _ = self.inner.poll_flush_unpin(cx)?; + + self.inner.poll_close_unpin(cx).map_err(Into::into) + } +} + +#[async_trait::async_trait] +impl CanDisconnect for MuxDemuxStream +where + S: Sink + CanDisconnect + Unpin + Send + Sync, + MuxDemuxError: From, +{ + async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), MuxDemuxError> { + if self.can_drop() { + return self.inner.disconnect(reason).await.map_err(Into::into) + } + Err(StreamInUse) + } +} + +/// More or less a weak clone of the stream wrapped in [`MuxDemuxer`] but the bytes belonging to +/// other capabilities have been filtered out. +#[derive(Debug)] +pub struct StreamClone { + // receive bytes from de-/muxer + stream: mpsc::UnboundedReceiver, + // send bytes to de-/muxer + sink: mpsc::UnboundedSender, + // message id offset for capability holding this clone + cap: SharedCapability, +} + +impl StreamClone { + fn mask_msg_id(&self, msg: Bytes) -> Bytes { + let mut masked_bytes = BytesMut::zeroed(msg.len()); + masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset(); + masked_bytes[1..].copy_from_slice(&msg[1..]); + + masked_bytes.freeze() + } +} + +impl Stream for StreamClone { + type Item = BytesMut; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.stream.poll_recv(cx) + } +} + +impl Sink for StreamClone { + type Error = MuxDemuxError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let item = self.mask_msg_id(item); + self.sink.send(item).map_err(|_| SendEgressBytesFailed) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[async_trait::async_trait] +impl CanDisconnect for StreamClone { + async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), MuxDemuxError> { + Err(CannotDisconnectP2PStream) + } +} + +#[cfg(test)] +mod test { + use std::{net::SocketAddr, pin::Pin}; + + use futures::{Future, SinkExt, StreamExt}; + use reth_ecies::util::pk2id; + use reth_primitives::{ + bytes::{BufMut, Bytes, BytesMut}, + ForkFilter, Hardfork, MAINNET, + }; + use secp256k1::{SecretKey, SECP256K1}; + use tokio::{ + net::{TcpListener, TcpStream}, + task::JoinHandle, + }; + use tokio_util::codec::{Decoder, Framed, LengthDelimitedCodec}; + + use crate::{ + capability::{Capability, SharedCapabilities}, + muxdemux::MuxDemuxStream, + protocol::Protocol, + EthVersion, HelloMessageWithProtocols, Status, StatusBuilder, StreamClone, + UnauthedEthStream, UnauthedP2PStream, + }; + + const ETH_68_CAP: Capability = Capability::eth(EthVersion::Eth68); + const ETH_68_PROTOCOL: Protocol = Protocol::new(ETH_68_CAP, 13); + const CUSTOM_CAP: Capability = Capability::new_static("snap", 1); + const CUSTOM_CAP_PROTOCOL: Protocol = Protocol::new(CUSTOM_CAP, 10); + // message IDs `0x00` and `0x01` are normalized for the custom protocol stream + const CUSTOM_REQUEST: [u8; 5] = [0x00, 0x00, 0x01, 0x0, 0xc0]; + const CUSTOM_RESPONSE: [u8; 5] = [0x01, 0x00, 0x01, 0x0, 0xc0]; + + fn shared_caps_eth68() -> SharedCapabilities { + let local_capabilities: Vec = vec![ETH_68_PROTOCOL]; + let peer_capabilities: Vec = vec![ETH_68_CAP]; + SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap() + } + + fn shared_caps_eth68_and_custom() -> SharedCapabilities { + let local_capabilities: Vec = vec![ETH_68_PROTOCOL, CUSTOM_CAP_PROTOCOL]; + let peer_capabilities: Vec = vec![ETH_68_CAP, CUSTOM_CAP]; + SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap() + } + + struct ConnectionBuilder { + local_addr: SocketAddr, + local_hello: HelloMessageWithProtocols, + status: Status, + fork_filter: ForkFilter, + } + + impl ConnectionBuilder { + fn new() -> Self { + let (_secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng()); + + let hello = HelloMessageWithProtocols::builder(pk2id(&pk)) + .protocol(ETH_68_PROTOCOL) + .protocol(CUSTOM_CAP_PROTOCOL) + .build(); + + let local_addr = "127.0.0.1:30303".parse().unwrap(); + + Self { + local_hello: hello, + local_addr, + status: StatusBuilder::default().build(), + fork_filter: MAINNET + .hardfork_fork_filter(Hardfork::Frontier) + .expect("The Frontier fork filter should exist on mainnet"), + } + } + + /// Connects a custom sub protocol stream and executes the given closure with that + /// established stream (main stream is eth). + fn with_connect_custom_protocol( + self, + f_local: F, + f_remote: G, + ) -> (JoinHandle, JoinHandle) + where + F: FnOnce(StreamClone) -> Pin + Send)>> + + Send + + Sync + + Send + + 'static, + G: FnOnce(StreamClone) -> Pin + Send)>> + + Send + + Sync + + Send + + 'static, + { + let local_addr = self.local_addr; + + let local_hello = self.local_hello.clone(); + let status = self.status; + let fork_filter = self.fork_filter.clone(); + + let local_handle = tokio::spawn(async move { + let local_listener = TcpListener::bind(local_addr).await.unwrap(); + let (incoming, _) = local_listener.accept().await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(incoming); + + let protocol_proxy = + connect_protocol(stream, local_hello, status, fork_filter).await; + + f_local(protocol_proxy).await + }); + + let remote_key = SecretKey::new(&mut rand::thread_rng()); + let remote_id = pk2id(&remote_key.public_key(SECP256K1)); + let mut remote_hello = self.local_hello.clone(); + remote_hello.id = remote_id; + let fork_filter = self.fork_filter.clone(); + + let remote_handle = tokio::spawn(async move { + let outgoing = TcpStream::connect(local_addr).await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(outgoing); + + let protocol_proxy = + connect_protocol(stream, remote_hello, status, fork_filter).await; + + f_remote(protocol_proxy).await + }); + + (local_handle, remote_handle) + } + } + + async fn connect_protocol( + stream: Framed, + hello: HelloMessageWithProtocols, + status: Status, + fork_filter: ForkFilter, + ) -> StreamClone { + let unauthed_stream = UnauthedP2PStream::new(stream); + let (p2p_stream, _) = unauthed_stream.handshake(hello).await.unwrap(); + + // ensure that the two share capabilities + assert_eq!(*p2p_stream.shared_capabilities(), shared_caps_eth68_and_custom(),); + + let shared_caps = p2p_stream.shared_capabilities().clone(); + let main_cap = shared_caps.eth().unwrap(); + let proxy_server = + MuxDemuxStream::try_new(p2p_stream, main_cap.capability().into_owned(), shared_caps) + .expect("should start mxdmx stream"); + + let (mut main_stream, _) = + UnauthedEthStream::new(proxy_server).handshake(status, fork_filter).await.unwrap(); + + let protocol_proxy = + main_stream.inner_mut().try_clone_stream(&CUSTOM_CAP).expect("should clone stream"); + + tokio::spawn(async move { + loop { + _ = main_stream.next().await.unwrap() + } + }); + + protocol_proxy + } + + #[test] + fn test_unmask_msg_id() { + let mut msg = BytesMut::with_capacity(1); + msg.put_u8(0x07); // eth msg id + + let mxdmx_stream = + MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth67), shared_caps_eth68()) + .unwrap(); + _ = mxdmx_stream.unmask_msg_id(&mut msg[0]).unwrap(); + + assert_eq!(msg.as_ref(), &[0x07]); + } + + #[test] + fn test_mask_msg_id() { + let mut msg = BytesMut::with_capacity(2); + msg.put_u8(0x10); // eth msg id + msg.put_u8(0x20); // some msg data + + let mxdmx_stream = + MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth66), shared_caps_eth68()) + .unwrap(); + let egress_bytes = mxdmx_stream.mask_msg_id(msg.freeze()); + + assert_eq!(egress_bytes.as_ref(), &[0x10, 0x20]); + } + + #[test] + fn test_unmask_msg_id_cap_not_in_shared_range() { + let mut msg = BytesMut::with_capacity(1); + msg.put_u8(0x11); + + let mxdmx_stream = + MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth68), shared_caps_eth68()) + .unwrap(); + + assert!(mxdmx_stream.unmask_msg_id(&mut msg[0]).is_err()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_mux_demux() { + let builder = ConnectionBuilder::new(); + + let request = Bytes::from(&CUSTOM_REQUEST[..]); + let response = Bytes::from(&CUSTOM_RESPONSE[..]); + let expected_request = request.clone(); + let expected_response = response.clone(); + + let (local_handle, remote_handle) = builder.with_connect_custom_protocol( + // send request from local addr + |mut protocol_proxy| { + Box::pin(async move { + protocol_proxy.send(request).await.unwrap(); + protocol_proxy.next().await.unwrap() + }) + }, + // respond from remote addr + |mut protocol_proxy| { + Box::pin(async move { + let request = protocol_proxy.next().await.unwrap(); + protocol_proxy.send(response).await.unwrap(); + request + }) + }, + ); + + let (local_res, remote_res) = tokio::join!(local_handle, remote_handle); + + // remote address receives request + assert_eq!(expected_request, remote_res.unwrap().freeze()); + // local address receives response + assert_eq!(expected_response, local_res.unwrap().freeze()); + } +} diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index 1d3f660643..3ac340add7 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -301,11 +301,6 @@ impl P2PStream { &self.shared_capabilities } - /// Returns `true` if the connection is about to disconnect. - pub fn is_disconnecting(&self) -> bool { - self.disconnecting - } - /// Returns `true` if the stream has outgoing capacity. fn has_outgoing_capacity(&self) -> bool { self.outgoing_messages.len() < self.outgoing_message_buffer_capacity @@ -326,7 +321,16 @@ impl P2PStream { ping.encode(&mut ping_bytes); self.outgoing_messages.push_back(ping_bytes.freeze()); } +} +pub trait DisconnectP2P { + /// Starts to gracefully disconnect. + fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>; + /// Returns `true` if the connection is about to disconnect. + fn is_disconnecting(&self) -> bool; +} + +impl DisconnectP2P for P2PStream { /// Starts to gracefully disconnect the connection by sending a Disconnect message and stop /// reading new messages. /// @@ -335,7 +339,7 @@ impl P2PStream { /// # Errors /// /// Returns an error only if the message fails to compress. - pub fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), snap::Error> { + fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> { // clear any buffered messages and queue in self.outgoing_messages.clear(); let disconnect = P2PMessage::Disconnect(reason); @@ -365,6 +369,10 @@ impl P2PStream { self.disconnecting = true; Ok(()) } + + fn is_disconnecting(&self) -> bool { + self.disconnecting + } } impl P2PStream diff --git a/crates/net/network/src/session/active.rs b/crates/net/network/src/session/active.rs index e737d6191f..25e53a194b 100644 --- a/crates/net/network/src/session/active.rs +++ b/crates/net/network/src/session/active.rs @@ -16,7 +16,7 @@ use reth_eth_wire::{ capability::Capabilities, errors::{EthHandshakeError, EthStreamError, P2PStreamError}, message::{EthBroadcastMessage, RequestPair}, - DisconnectReason, EthMessage, EthStream, P2PStream, + DisconnectP2P, DisconnectReason, EthMessage, EthStream, P2PStream, }; use reth_interfaces::p2p::error::RequestError; use reth_metrics::common::mpsc::MeteredPollSender;