diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs index 9eea58b6a6..37a33f4dd4 100644 --- a/crates/net/eth-wire/src/capability.rs +++ b/crates/net/eth-wire/src/capability.rs @@ -404,6 +404,13 @@ pub struct UnsupportedCapabilityError { capability: Capability, } +impl UnsupportedCapabilityError { + /// Creates a new error with the given capability + pub const fn new(capability: Capability) -> Self { + Self { capability } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/net/network/src/error.rs b/crates/net/network/src/error.rs index 8156392b22..f88d8bd815 100644 --- a/crates/net/network/src/error.rs +++ b/crates/net/network/src/error.rs @@ -216,7 +216,7 @@ impl SessionError for PendingSessionHandshakeError { ECIESErrorImpl::Secp256k1(_) | ECIESErrorImpl::InvalidHandshake { .. } ), - Self::Timeout => false, + Self::Timeout | Self::UnsupportedExtraCapability => false, } } @@ -235,6 +235,7 @@ impl SessionError for PendingSessionHandshakeError { ECIESErrorImpl::InvalidHandshake { .. } ), Self::Timeout => false, + Self::UnsupportedExtraCapability => true, } } @@ -243,6 +244,7 @@ impl SessionError for PendingSessionHandshakeError { Self::Eth(eth) => eth.should_backoff(), Self::Ecies(_) => Some(BackoffKind::Low), Self::Timeout => Some(BackoffKind::Medium), + Self::UnsupportedExtraCapability => Some(BackoffKind::High), } } } diff --git a/crates/net/network/src/protocol.rs b/crates/net/network/src/protocol.rs index aa0749c2c7..6479ac77e5 100644 --- a/crates/net/network/src/protocol.rs +++ b/crates/net/network/src/protocol.rs @@ -147,7 +147,7 @@ impl RlpxSubProtocols { /// A set of additional RLPx-based sub-protocol connection handlers. #[derive(Default)] -pub(crate) struct RlpxSubProtocolHandlers(Vec>); +pub(crate) struct RlpxSubProtocolHandlers(pub(crate) Vec>); impl RlpxSubProtocolHandlers { /// Returns all handlers. @@ -200,6 +200,13 @@ impl DynProtocolHandler for T { pub(crate) trait DynConnectionHandler: Send + Sync + 'static { fn protocol(&self) -> Protocol; + fn on_unsupported_by_peer( + self: Box, + supported: &SharedCapabilities, + direction: Direction, + peer_id: PeerId, + ) -> OnNotSupported; + fn into_connection( self: Box, direction: Direction, @@ -213,6 +220,15 @@ impl DynConnectionHandler for T { T::protocol(self) } + fn on_unsupported_by_peer( + self: Box, + supported: &SharedCapabilities, + direction: Direction, + peer_id: PeerId, + ) -> OnNotSupported { + T::on_unsupported_by_peer(*self, supported, direction, peer_id) + } + fn into_connection( self: Box, direction: Direction, diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 4b007fffc1..74be82e73a 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -26,7 +26,7 @@ use std::{ use crate::{ message::PeerMessage, metrics::SessionManagerMetrics, - protocol::{IntoRlpxSubProtocol, RlpxSubProtocolHandlers, RlpxSubProtocols}, + protocol::{IntoRlpxSubProtocol, OnNotSupported, RlpxSubProtocolHandlers, RlpxSubProtocols}, session::active::ActiveSession, }; use counter::SessionCounter; @@ -771,6 +771,9 @@ pub enum PendingSessionHandshakeError { /// Thrown when the authentication timed out #[error("authentication timed out")] Timeout, + /// Thrown when the remote lacks the required capability + #[error("Mandatory extra capability unsupported")] + UnsupportedExtraCapability, } impl PendingSessionHandshakeError { @@ -1013,6 +1016,32 @@ async fn authenticate_stream( } }; + // if we have extra handlers, check if it must be supported by the remote + if !extra_handlers.is_empty() { + // ensure that no extra handlers that aren't supported are not mandatory + while let Some(pos) = extra_handlers.iter().position(|handler| { + p2p_stream + .shared_capabilities() + .ensure_matching_capability(&handler.protocol().cap) + .is_err() + }) { + let handler = extra_handlers.remove(pos); + if handler.on_unsupported_by_peer( + p2p_stream.shared_capabilities(), + direction, + their_hello.id, + ) == OnNotSupported::Disconnect + { + return PendingSessionEvent::Disconnected { + remote_addr, + session_id, + direction, + error: Some(PendingSessionHandshakeError::UnsupportedExtraCapability), + }; + } + } + } + // Ensure we negotiated mandatory eth protocol let eth_version = match p2p_stream.shared_capabilities().eth_version() { Ok(version) => version, @@ -1027,6 +1056,7 @@ async fn authenticate_stream( }; let (conn, their_status) = if p2p_stream.shared_capabilities().len() == 1 { + // if the shared caps are 1, we know both support the eth version // if the hello handshake was successful we can try status handshake // // Before trying status handshake, set up the version to negotiated shared version @@ -1058,6 +1088,7 @@ async fn authenticate_stream( for handler in extra_handlers.into_iter() { let cap = handler.protocol().cap; let remote_peer_id = their_hello.id; + multiplex_stream .install_protocol(&cap, move |conn| { handler.into_connection(direction, remote_peer_id, conn) diff --git a/crates/net/network/tests/it/multiplex.rs b/crates/net/network/tests/it/multiplex.rs index 5747d20271..c0384c24a9 100644 --- a/crates/net/network/tests/it/multiplex.rs +++ b/crates/net/network/tests/it/multiplex.rs @@ -14,10 +14,12 @@ use reth_eth_wire::{ }; use reth_network::{ protocol::{ConnectionHandler, OnNotSupported, ProtocolHandler}, - test_utils::Testnet, + test_utils::{NetworkEventStream, Testnet}, + NetworkConfigBuilder, NetworkEventListenerProvider, NetworkManager, }; -use reth_network_api::{Direction, PeerId}; -use reth_provider::test_utils::MockEthProvider; +use reth_network_api::{Direction, NetworkInfo, PeerId, Peers}; +use reth_provider::{noop::NoopProvider, test_utils::MockEthProvider}; +use secp256k1::SecretKey; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -200,7 +202,7 @@ impl ConnectionHandler for PingPongConnectionHandler { _direction: Direction, _peer_id: PeerId, ) -> OnNotSupported { - OnNotSupported::KeepAlive + OnNotSupported::Disconnect } fn into_connection( @@ -275,6 +277,47 @@ impl Stream for PingPongProtoConnection { } } +#[tokio::test(flavor = "multi_thread")] +async fn test_connect_to_non_multiplex_peer() { + reth_tracing::init_test_tracing(); + + let net = Testnet::create(1).await; + + let secret_key = SecretKey::new(&mut rand::thread_rng()); + + let config = NetworkConfigBuilder::eth(secret_key) + .listener_port(0) + .disable_discovery() + .build(NoopProvider::default()); + + let mut network = NetworkManager::new(config).await.unwrap(); + + let (tx, _) = mpsc::unbounded_channel(); + network.add_rlpx_sub_protocol(PingPongProtoHandler { state: ProtocolState { events: tx } }); + + let handle = network.handle().clone(); + tokio::task::spawn(network); + + // create networkeventstream to get the next session event easily. + let events = handle.event_listener(); + let mut event_stream = NetworkEventStream::new(events); + + let mut handles = net.handles(); + let handle0 = handles.next().unwrap(); + drop(handles); + + let _handle = net.spawn(); + + handle.add_peer(*handle0.peer_id(), handle0.local_addr()); + + let added_peer_id = event_stream.peer_added().await.unwrap(); + assert_eq!(added_peer_id, *handle0.peer_id()); + + // peer with mismatched capability version should fail to connect and be removed. + let removed_peer_id = event_stream.peer_removed().await.unwrap(); + assert_eq!(removed_peer_id, *handle0.peer_id()); +} + #[tokio::test(flavor = "multi_thread")] async fn test_proto_multiplex() { reth_tracing::init_test_tracing();