From c168ef4433c3fef02c63127fc0f963cd9177eae4 Mon Sep 17 00:00:00 2001 From: Dan Cline <6798349+Rjected@users.noreply.github.com> Date: Wed, 22 Feb 2023 06:18:12 -0500 Subject: [PATCH] feat: implement eth handshake disconnects (#1494) --- Cargo.lock | 1 + crates/net/eth-wire/Cargo.toml | 4 +- crates/net/eth-wire/src/disconnect.rs | 44 +++++++++++++++ crates/net/eth-wire/src/ethstream.rs | 68 ++++++++++++++++++------ crates/net/eth-wire/src/lib.rs | 2 +- crates/net/eth-wire/src/p2pstream.rs | 55 +++++++++++-------- crates/net/network/src/session/active.rs | 4 +- 7 files changed, 137 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c383b29a6..563084db04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4557,6 +4557,7 @@ name = "reth-eth-wire" version = "0.1.0" dependencies = [ "arbitrary", + "async-trait", "bytes", "ethers-core", "futures", diff --git a/crates/net/eth-wire/Cargo.toml b/crates/net/eth-wire/Cargo.toml index 264c3a449c..fb76a352b6 100644 --- a/crates/net/eth-wire/Cargo.toml +++ b/crates/net/eth-wire/Cargo.toml @@ -15,12 +15,14 @@ serde = { version = "1", optional = true } # reth reth-codecs = { path = "../../storage/codecs" } reth-primitives = { path = "../../primitives" } +reth-ecies = { path = "../ecies" } reth-rlp = { path = "../../rlp", features = ["alloc", "derive", "std", "ethereum-types", "smol_str"] } # used for Chain and builders ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false } tokio = { version = "1.21.2", features = ["full"] } +tokio-util = { version = "0.7.4", features = ["io", "codec"] } futures = "0.3.24" tokio-stream = "0.1.11" pin-project = "1.0" @@ -28,6 +30,7 @@ tracing = "0.1.37" snap = "1.0.5" smol_str = "0.1" metrics = "0.20.1" +async-trait = "0.1" # arbitrary utils arbitrary = { version = "1.1.7", features = ["derive"], optional = true } @@ -36,7 +39,6 @@ proptest-derive = { version = "0.3", optional = true } [dev-dependencies] reth-primitives = { path = "../../primitives", features = ["arbitrary"] } -reth-ecies = { path = "../ecies" } reth-tracing = { path = "../../tracing" } ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false } diff --git a/crates/net/eth-wire/src/disconnect.rs b/crates/net/eth-wire/src/disconnect.rs index 6ddf4367f9..b72d7bf9a2 100644 --- a/crates/net/eth-wire/src/disconnect.rs +++ b/crates/net/eth-wire/src/disconnect.rs @@ -1,10 +1,15 @@ //! Disconnect +use bytes::Bytes; +use futures::{Sink, SinkExt}; use reth_codecs::derive_arbitrary; +use reth_ecies::stream::ECIESStream; use reth_primitives::bytes::{Buf, BufMut}; use reth_rlp::{Decodable, DecodeError, Encodable, Header}; use std::fmt::Display; use thiserror::Error; +use tokio::io::AsyncWrite; +use tokio_util::codec::{Encoder, Framed}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -143,6 +148,45 @@ impl Decodable for DisconnectReason { } } +/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using +/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the +/// underlying stream supports it. +#[async_trait::async_trait] +pub trait CanDisconnect: Sink + Unpin + Sized { + /// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect + /// information if the stream implements a protocol that can carry the additional disconnect + /// metadata. + async fn disconnect( + &mut self, + reason: DisconnectReason, + ) -> Result<(), >::Error>; +} + +// basic impls for things like Framed +#[async_trait::async_trait] +impl CanDisconnect for Framed +where + T: AsyncWrite + Unpin + Send, + U: Encoder + Send, +{ + async fn disconnect( + &mut self, + _reason: DisconnectReason, + ) -> Result<(), >::Error> { + self.close().await + } +} + +#[async_trait::async_trait] +impl CanDisconnect for ECIESStream +where + S: AsyncWrite + Unpin + Send, +{ + async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> { + self.close().await + } +} + #[cfg(test)] mod tests { use crate::{p2pstream::P2PMessage, DisconnectReason}; diff --git a/crates/net/eth-wire/src/ethstream.rs b/crates/net/eth-wire/src/ethstream.rs index c61113c756..605b8e6a67 100644 --- a/crates/net/eth-wire/src/ethstream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -2,7 +2,7 @@ use crate::{ errors::{EthHandshakeError, EthStreamError}, message::{EthBroadcastMessage, ProtocolBroadcastMessage}, types::{EthMessage, ProtocolMessage, Status}, - EthVersion, + CanDisconnect, DisconnectReason, EthVersion, }; use futures::{ready, Sink, SinkExt, StreamExt}; use pin_project::pin_project; @@ -43,8 +43,8 @@ impl UnauthedEthStream { impl UnauthedEthStream where - S: Stream> + Sink + Unpin, - EthStreamError: From, + S: Stream> + CanDisconnect + Unpin, + EthStreamError: From + From<>::Error>, { /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status` /// handshake is completed successfully. This also returns the `Status` message sent by the @@ -67,13 +67,18 @@ where self.inner.send(our_status_bytes).await?; tracing::trace!("waiting for eth status from peer"); - let their_msg = self - .inner - .next() - .await - .ok_or(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))??; + let their_msg_res = self.inner.next().await; + + let their_msg = match their_msg_res { + Some(msg) => msg, + None => { + self.inner.disconnect(DisconnectReason::DisconnectRequested).await?; + return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse)) + } + }?; if their_msg.len() > MAX_MESSAGE_SIZE { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; return Err(EthStreamError::MessageTooBig(their_msg.len())) } @@ -82,6 +87,7 @@ where Ok(m) => m, Err(err) => { tracing::debug!("decode error in eth handshake: msg={their_msg:x}"); + self.inner.disconnect(DisconnectReason::DisconnectRequested).await?; return Err(err) } }; @@ -95,6 +101,7 @@ where "validating incoming eth status from peer" ); if status.genesis != resp.genesis { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; return Err(EthHandshakeError::MismatchedGenesis { expected: status.genesis, got: resp.genesis, @@ -103,6 +110,7 @@ where } if status.version != resp.version { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; return Err(EthHandshakeError::MismatchedProtocolVersion { expected: status.version, got: resp.version, @@ -111,6 +119,7 @@ where } if status.chain != resp.chain { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; return Err(EthHandshakeError::MismatchedChain { expected: status.chain, got: resp.chain, @@ -121,6 +130,7 @@ where // TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times // larger, it will still fit within 100 bits if status.total_difficulty.bit_len() > 100 { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge { maximum: 100, got: status.total_difficulty.bit_len(), @@ -128,7 +138,12 @@ where .into()) } - fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)?; + if let Err(err) = + fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork) + { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; + return Err(err.into()) + } // now we can create the `EthStream` because the peer has successfully completed // the handshake @@ -136,9 +151,12 @@ where Ok((stream, resp)) } - _ => Err(EthStreamError::EthHandshakeError( - EthHandshakeError::NonStatusMessageInHandshake, - )), + _ => { + self.inner.disconnect(DisconnectReason::ProtocolBreach).await?; + Err(EthStreamError::EthHandshakeError( + EthHandshakeError::NonStatusMessageInHandshake, + )) + } } } } @@ -239,10 +257,10 @@ where } } -impl Sink for EthStream +impl Sink for EthStream where - S: Sink + Unpin, - EthStreamError: From, + S: CanDisconnect + Unpin, + EthStreamError: From<>::Error>, { type Error = EthStreamError; @@ -252,6 +270,15 @@ where fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> { if matches!(item, EthMessage::Status(_)) { + // TODO: to disconnect here we would need to do something similar to P2PStream's + // start_disconnect, which would ideally be a part of the CanDisconnect trait, or at + // least similar. + // + // Other parts of reth do not need traits like CanDisconnect because they work + // exclusively with EthStream>, where the inner P2PStream is accessible, + // allowing for its start_disconnect method to be called. + // + // self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach); return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake)) } @@ -273,6 +300,17 @@ where } } +#[async_trait::async_trait] +impl CanDisconnect for EthStream +where + S: CanDisconnect + Send, + EthStreamError: From<>::Error>, +{ + async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> { + self.inner.disconnect(reason).await.map_err(Into::into) + } +} + #[cfg(test)] mod tests { use super::UnauthedEthStream; diff --git a/crates/net/eth-wire/src/lib.rs b/crates/net/eth-wire/src/lib.rs index 309f38a226..44715489b6 100644 --- a/crates/net/eth-wire/src/lib.rs +++ b/crates/net/eth-wire/src/lib.rs @@ -24,7 +24,7 @@ pub use tokio_util::codec::{ }; pub use crate::{ - disconnect::DisconnectReason, + disconnect::{CanDisconnect, DisconnectReason}, ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE}, hello::HelloMessage, p2pstream::{P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream}, diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index 6bbf31e231..25be3ae0d2 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -1,6 +1,7 @@ #![allow(dead_code, unreachable_pub, missing_docs, unused_variables)] use crate::{ capability::{Capability, SharedCapability}, + disconnect::CanDisconnect, errors::{P2PHandshakeError, P2PStreamError}, pinger::{Pinger, PingerEvent}, DisconnectReason, HelloMessage, @@ -72,25 +73,6 @@ impl UnauthedP2PStream { } } -impl UnauthedP2PStream -where - S: Sink + Unpin, -{ - /// Send a disconnect message during the handshake. This is sent without snappy compression. - pub async fn send_disconnect( - &mut self, - reason: DisconnectReason, - ) -> Result<(), P2PStreamError> { - let mut buf = BytesMut::new(); - P2PMessage::Disconnect(reason).encode(&mut buf); - tracing::trace!( - %reason, - "Sending disconnect message during the handshake", - ); - self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io) - } -} - impl UnauthedP2PStream where S: Stream> + Sink + Unpin, @@ -180,6 +162,35 @@ where } } +impl UnauthedP2PStream +where + S: Sink + Unpin, +{ + /// Send a disconnect message during the handshake. This is sent without snappy compression. + pub async fn send_disconnect( + &mut self, + reason: DisconnectReason, + ) -> Result<(), P2PStreamError> { + let mut buf = BytesMut::new(); + P2PMessage::Disconnect(reason).encode(&mut buf); + tracing::trace!( + %reason, + "Sending disconnect message during the handshake", + ); + self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io) + } +} + +#[async_trait::async_trait] +impl CanDisconnect for P2PStream +where + S: Sink + Unpin + Send + Sync, +{ + async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> { + self.disconnect(reason).await + } +} + /// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p` /// protocol messages. #[pin_project] @@ -284,13 +295,13 @@ impl P2PStream { impl P2PStream where - S: Sink + Unpin, + S: Sink + Unpin + Send, { /// Disconnects the connection by sending a disconnect message. /// /// This future resolves once the disconnect message has been sent and the stream has been /// closed. - pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> { + pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> { self.start_disconnect(reason)?; self.close().await } @@ -821,7 +832,7 @@ mod tests { let (server_hello, _) = eth_hello(); - let (p2p_stream, _) = + let (mut p2p_stream, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap(); p2p_stream.disconnect(expected_disconnect).await.unwrap(); diff --git a/crates/net/network/src/session/active.rs b/crates/net/network/src/session/active.rs index d1e99e6cd9..b1a715fafe 100644 --- a/crates/net/network/src/session/active.rs +++ b/crates/net/network/src/session/active.rs @@ -753,9 +753,9 @@ mod tests { &self, local_addr: SocketAddr, f: F, - ) -> Pin + Send + Sync>> + ) -> Pin + Send>> where - F: FnOnce(EthStream>>) -> O + Send + Sync + 'static, + F: FnOnce(EthStream>>) -> O + Send + 'static, O: Future + Send + Sync, { let status = self.status;