diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index 27d831dc21..4c810b9ff9 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -28,7 +28,7 @@ use crate::{ peers::{PeersHandle, PeersManager}, session::SessionManager, state::NetworkState, - swarm::{Swarm, SwarmEvent}, + swarm::{NetworkConnectionState, Swarm, SwarmEvent}, transactions::NetworkTransactionEvent, FetchClient, NetworkBuilder, }; @@ -206,7 +206,7 @@ where Arc::clone(&num_active_peers), ); - let swarm = Swarm::new(incoming, sessions, state); + let swarm = Swarm::new(incoming, sessions, state, NetworkConnectionState::default()); let (to_manager_tx, from_handle_rx) = mpsc::unbounded_channel(); @@ -513,6 +513,17 @@ where NetworkHandleMessage::DisconnectPeer(peer_id, reason) => { self.swarm.sessions_mut().disconnect(peer_id, reason); } + NetworkHandleMessage::Shutdown(tx) => { + // Set connection status to `Shutdown`. Stops node to accept + // new incoming connections as well as sending connection requests to newly + // discovered nodes. + self.swarm.on_shutdown_requested(); + // Disconnect all active connections + self.swarm.sessions_mut().disconnect_all(Some(DisconnectReason::ClientQuitting)); + // drop pending connections + self.swarm.sessions_mut().disconnect_all_pending(); + let _ = tx.send(()); + } NetworkHandleMessage::ReputationChange(peer_id, kind) => { self.swarm.state_mut().peers_mut().apply_reputation_change(&peer_id, kind); } diff --git a/crates/net/network/src/network.rs b/crates/net/network/src/network.rs index 84bada0480..c901c36e5c 100644 --- a/crates/net/network/src/network.rs +++ b/crates/net/network/src/network.rs @@ -160,6 +160,16 @@ impl NetworkHandle { pub fn bandwidth_meter(&self) -> &BandwidthMeter { &self.inner.bandwidth_meter } + + /// Send message to gracefully shutdown node. + /// + /// This will disconnect all active and pending sessions and prevent + /// new connections to be established. + pub async fn shutdown(&self) -> Result<(), oneshot::error::RecvError> { + let (tx, rx) = oneshot::channel(); + self.send_message(NetworkHandleMessage::Shutdown(tx)); + rx.await + } } // === API Implementations === @@ -302,4 +312,6 @@ pub(crate) enum NetworkHandleMessage { GetPeerInfo(oneshot::Sender>), /// Get PeerInfo for a specific peer GetPeerInfoById(PeerId, oneshot::Sender>), + /// Gracefully shutdown network + Shutdown(oneshot::Sender<()>), } diff --git a/crates/net/network/src/peers/manager.rs b/crates/net/network/src/peers/manager.rs index 7e7d617d09..40f35ad6cf 100644 --- a/crates/net/network/src/peers/manager.rs +++ b/crates/net/network/src/peers/manager.rs @@ -163,7 +163,6 @@ impl PeersManager { if !self.connection_info.has_in_capacity() { return Err(InboundConnectionError::ExceedsLimit(self.connection_info.max_inbound)) } - // keep track of new connection self.connection_info.inc_in(); Ok(()) diff --git a/crates/net/network/src/session/handle.rs b/crates/net/network/src/session/handle.rs index dbadfa9fa1..30925507ad 100644 --- a/crates/net/network/src/session/handle.rs +++ b/crates/net/network/src/session/handle.rs @@ -24,11 +24,22 @@ use tokio::{ #[derive(Debug)] pub(crate) struct PendingSessionHandle { /// Can be used to tell the session to disconnect the connection/abort the handshake process. - pub(crate) _disconnect_tx: oneshot::Sender<()>, + pub(crate) disconnect_tx: Option>, /// The direction of the session pub(crate) direction: Direction, } +// === impl PendingSessionHandle === + +impl PendingSessionHandle { + /// Sends a disconnect command to the pending session. + pub(crate) fn disconnect(&mut self) { + if let Some(tx) = self.disconnect_tx.take() { + let _ = tx.send(()); + } + } +} + /// An established session with a remote peer. /// /// Within an active session that supports the `Ethereum Wire Protocol `, three high-level tasks can diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index eb179f09cf..cdcf95b5f7 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -222,8 +222,10 @@ impl SessionManager { self.fork_filter.clone(), )); - let handle = - PendingSessionHandle { _disconnect_tx: disconnect_tx, direction: Direction::Incoming }; + let handle = PendingSessionHandle { + disconnect_tx: Some(disconnect_tx), + direction: Direction::Incoming, + }; self.pending_sessions.insert(session_id, handle); self.counter.inc_pending_inbound(); Ok(session_id) @@ -248,7 +250,7 @@ impl SessionManager { )); let handle = PendingSessionHandle { - _disconnect_tx: disconnect_tx, + disconnect_tx: Some(disconnect_tx), direction: Direction::Outgoing(remote_peer_id), }; self.pending_sessions.insert(session_id, handle); @@ -265,6 +267,23 @@ impl SessionManager { } } + /// Initiates a shutdown of all sessions. + /// + /// It will trigger the disconnect on all the session tasks to gracefully terminate. The result + /// will be picked by the receiver. + pub(crate) fn disconnect_all(&self, reason: Option) { + for (_, session) in self.active_sessions.iter() { + session.disconnect(reason); + } + } + + /// Disconnects all pending sessions. + pub(crate) fn disconnect_all_pending(&mut self) { + for (_, session) in self.pending_sessions.iter_mut() { + session.disconnect(); + } + } + /// Sends a message to the peer's session pub(crate) fn send_message(&mut self, peer_id: &PeerId, msg: PeerMessage) { if let Some(session) = self.active_sessions.get_mut(peer_id) { diff --git a/crates/net/network/src/swarm.rs b/crates/net/network/src/swarm.rs index f8e31fcbf4..949abac9bd 100644 --- a/crates/net/network/src/swarm.rs +++ b/crates/net/network/src/swarm.rs @@ -69,6 +69,8 @@ pub(crate) struct Swarm { sessions: SessionManager, /// Tracks the entire state of the network and handles events received from the sessions. state: NetworkState, + /// Tracks the connection state of the node + net_connection_state: NetworkConnectionState, } // === impl Swarm === @@ -82,8 +84,9 @@ where incoming: ConnectionListener, sessions: SessionManager, state: NetworkState, + net_connection_state: NetworkConnectionState, ) -> Self { - Self { incoming, sessions, state } + Self { incoming, sessions, state, net_connection_state } } /// Access to the state. @@ -189,6 +192,10 @@ where return Some(SwarmEvent::TcpListenerClosed { remote_addr: address }) } ListenerEvent::Incoming { stream, remote_addr } => { + // Reject incoming connection if node is shutting down. + if self.is_shutting_down() { + return None + } // ensure we can handle an incoming connection from this address if let Err(err) = self.state_mut().peers_mut().on_incoming_pending_session(remote_addr.ip()) @@ -244,6 +251,10 @@ where StateAction::PeerAdded(peer_id) => return Some(SwarmEvent::PeerAdded(peer_id)), StateAction::PeerRemoved(peer_id) => return Some(SwarmEvent::PeerRemoved(peer_id)), StateAction::DiscoveredNode { peer_id, socket_addr, fork_id } => { + // Don't try to connect to peer if node is shutting down + if self.is_shutting_down() { + return None + } // Insert peer only if no fork id or a valid fork id if fork_id.map_or_else(|| true, |f| self.sessions.is_valid_fork_id(f)) { self.state_mut().peers_mut().add_peer(peer_id, socket_addr, fork_id); @@ -259,6 +270,17 @@ where } None } + + /// Set network connection state to `ShuttingDown` + pub(crate) fn on_shutdown_requested(&mut self) { + self.net_connection_state = NetworkConnectionState::ShuttingDown; + } + + /// Checks if the node's network connection state is 'ShuttingDown' + #[inline] + fn is_shutting_down(&self) -> bool { + matches!(self.net_connection_state, NetworkConnectionState::ShuttingDown) + } } impl Stream for Swarm @@ -394,3 +416,12 @@ pub(crate) enum SwarmEvent { /// Failed to establish a tcp stream to the given address/node OutgoingConnectionError { remote_addr: SocketAddr, peer_id: PeerId, error: io::Error }, } + +/// Represents the state of the connection of the node. If shutting down, +/// new connections won't be established. +#[derive(Default)] +pub(crate) enum NetworkConnectionState { + #[default] + Active, + ShuttingDown, +} diff --git a/crates/net/network/tests/it/connect.rs b/crates/net/network/tests/it/connect.rs index 6e2baf3748..243c807a25 100644 --- a/crates/net/network/tests/it/connect.rs +++ b/crates/net/network/tests/it/connect.rs @@ -509,3 +509,45 @@ async fn test_geth_disconnect() { .await .unwrap(); } + +#[tokio::test(flavor = "multi_thread")] +async fn test_shutdown() { + let net = Testnet::create(3).await; + + let mut handles = net.handles(); + let handle0 = handles.next().unwrap(); + let handle1 = handles.next().unwrap(); + let handle2 = handles.next().unwrap(); + + drop(handles); + let _handle = net.spawn(); + + handle0.add_peer(*handle1.peer_id(), handle1.local_addr()); + handle0.add_peer(*handle2.peer_id(), handle2.local_addr()); + handle1.add_peer(*handle2.peer_id(), handle2.local_addr()); + + let mut expected_connections = HashSet::from([*handle1.peer_id(), *handle2.peer_id()]); + + let mut listener0 = NetworkEventStream::new(handle0.event_listener()); + + // Before shutting down, we have two connected peers + let peer1 = listener0.next_session_established().await.unwrap(); + let peer2 = listener0.next_session_established().await.unwrap(); + assert!(expected_connections.contains(&peer1)); + assert!(expected_connections.contains(&peer2)); + assert_eq!(handle0.num_connected_peers(), 2); + + handle0.shutdown().await.unwrap(); + + // All sessions get disconnected + let (peer1, _reason) = listener0.next_session_closed().await.unwrap(); + let (peer2, _reason) = listener0.next_session_closed().await.unwrap(); + assert!(expected_connections.remove(&peer1)); + assert!(expected_connections.remove(&peer2)); + assert_eq!(handle0.num_connected_peers(), 0); + + // New connections are rejected + handle0.add_peer(*handle1.peer_id(), handle1.local_addr()); + let (_peer, reason) = listener0.next_session_closed().await.unwrap(); + assert_eq!(reason, Some(DisconnectReason::DisconnectRequested)); +}