mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-01-29 00:58:11 -05:00
feat(net): add shutdown network signal (#1011)
Co-authored-by: lambdaclass-user <github@lambdaclass.com>
This commit is contained in:
committed by
GitHub
parent
db5410b84b
commit
dab1f4f497
@@ -28,7 +28,7 @@ use crate::{
|
||||
peers::{PeersHandle, PeersManager},
|
||||
session::SessionManager,
|
||||
state::NetworkState,
|
||||
swarm::{Swarm, SwarmEvent},
|
||||
swarm::{NetworkConnectionState, Swarm, SwarmEvent},
|
||||
transactions::NetworkTransactionEvent,
|
||||
FetchClient, NetworkBuilder,
|
||||
};
|
||||
@@ -206,7 +206,7 @@ where
|
||||
Arc::clone(&num_active_peers),
|
||||
);
|
||||
|
||||
let swarm = Swarm::new(incoming, sessions, state);
|
||||
let swarm = Swarm::new(incoming, sessions, state, NetworkConnectionState::default());
|
||||
|
||||
let (to_manager_tx, from_handle_rx) = mpsc::unbounded_channel();
|
||||
|
||||
@@ -513,6 +513,17 @@ where
|
||||
NetworkHandleMessage::DisconnectPeer(peer_id, reason) => {
|
||||
self.swarm.sessions_mut().disconnect(peer_id, reason);
|
||||
}
|
||||
NetworkHandleMessage::Shutdown(tx) => {
|
||||
// Set connection status to `Shutdown`. Stops node to accept
|
||||
// new incoming connections as well as sending connection requests to newly
|
||||
// discovered nodes.
|
||||
self.swarm.on_shutdown_requested();
|
||||
// Disconnect all active connections
|
||||
self.swarm.sessions_mut().disconnect_all(Some(DisconnectReason::ClientQuitting));
|
||||
// drop pending connections
|
||||
self.swarm.sessions_mut().disconnect_all_pending();
|
||||
let _ = tx.send(());
|
||||
}
|
||||
NetworkHandleMessage::ReputationChange(peer_id, kind) => {
|
||||
self.swarm.state_mut().peers_mut().apply_reputation_change(&peer_id, kind);
|
||||
}
|
||||
|
||||
@@ -160,6 +160,16 @@ impl NetworkHandle {
|
||||
pub fn bandwidth_meter(&self) -> &BandwidthMeter {
|
||||
&self.inner.bandwidth_meter
|
||||
}
|
||||
|
||||
/// Send message to gracefully shutdown node.
|
||||
///
|
||||
/// This will disconnect all active and pending sessions and prevent
|
||||
/// new connections to be established.
|
||||
pub async fn shutdown(&self) -> Result<(), oneshot::error::RecvError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.send_message(NetworkHandleMessage::Shutdown(tx));
|
||||
rx.await
|
||||
}
|
||||
}
|
||||
|
||||
// === API Implementations ===
|
||||
@@ -302,4 +312,6 @@ pub(crate) enum NetworkHandleMessage {
|
||||
GetPeerInfo(oneshot::Sender<Vec<PeerInfo>>),
|
||||
/// Get PeerInfo for a specific peer
|
||||
GetPeerInfoById(PeerId, oneshot::Sender<Option<PeerInfo>>),
|
||||
/// Gracefully shutdown network
|
||||
Shutdown(oneshot::Sender<()>),
|
||||
}
|
||||
|
||||
@@ -163,7 +163,6 @@ impl PeersManager {
|
||||
if !self.connection_info.has_in_capacity() {
|
||||
return Err(InboundConnectionError::ExceedsLimit(self.connection_info.max_inbound))
|
||||
}
|
||||
|
||||
// keep track of new connection
|
||||
self.connection_info.inc_in();
|
||||
Ok(())
|
||||
|
||||
@@ -24,11 +24,22 @@ use tokio::{
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PendingSessionHandle {
|
||||
/// Can be used to tell the session to disconnect the connection/abort the handshake process.
|
||||
pub(crate) _disconnect_tx: oneshot::Sender<()>,
|
||||
pub(crate) disconnect_tx: Option<oneshot::Sender<()>>,
|
||||
/// The direction of the session
|
||||
pub(crate) direction: Direction,
|
||||
}
|
||||
|
||||
// === impl PendingSessionHandle ===
|
||||
|
||||
impl PendingSessionHandle {
|
||||
/// Sends a disconnect command to the pending session.
|
||||
pub(crate) fn disconnect(&mut self) {
|
||||
if let Some(tx) = self.disconnect_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An established session with a remote peer.
|
||||
///
|
||||
/// Within an active session that supports the `Ethereum Wire Protocol `, three high-level tasks can
|
||||
|
||||
@@ -222,8 +222,10 @@ impl SessionManager {
|
||||
self.fork_filter.clone(),
|
||||
));
|
||||
|
||||
let handle =
|
||||
PendingSessionHandle { _disconnect_tx: disconnect_tx, direction: Direction::Incoming };
|
||||
let handle = PendingSessionHandle {
|
||||
disconnect_tx: Some(disconnect_tx),
|
||||
direction: Direction::Incoming,
|
||||
};
|
||||
self.pending_sessions.insert(session_id, handle);
|
||||
self.counter.inc_pending_inbound();
|
||||
Ok(session_id)
|
||||
@@ -248,7 +250,7 @@ impl SessionManager {
|
||||
));
|
||||
|
||||
let handle = PendingSessionHandle {
|
||||
_disconnect_tx: disconnect_tx,
|
||||
disconnect_tx: Some(disconnect_tx),
|
||||
direction: Direction::Outgoing(remote_peer_id),
|
||||
};
|
||||
self.pending_sessions.insert(session_id, handle);
|
||||
@@ -265,6 +267,23 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initiates a shutdown of all sessions.
|
||||
///
|
||||
/// It will trigger the disconnect on all the session tasks to gracefully terminate. The result
|
||||
/// will be picked by the receiver.
|
||||
pub(crate) fn disconnect_all(&self, reason: Option<DisconnectReason>) {
|
||||
for (_, session) in self.active_sessions.iter() {
|
||||
session.disconnect(reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// Disconnects all pending sessions.
|
||||
pub(crate) fn disconnect_all_pending(&mut self) {
|
||||
for (_, session) in self.pending_sessions.iter_mut() {
|
||||
session.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a message to the peer's session
|
||||
pub(crate) fn send_message(&mut self, peer_id: &PeerId, msg: PeerMessage) {
|
||||
if let Some(session) = self.active_sessions.get_mut(peer_id) {
|
||||
|
||||
@@ -69,6 +69,8 @@ pub(crate) struct Swarm<C> {
|
||||
sessions: SessionManager,
|
||||
/// Tracks the entire state of the network and handles events received from the sessions.
|
||||
state: NetworkState<C>,
|
||||
/// Tracks the connection state of the node
|
||||
net_connection_state: NetworkConnectionState,
|
||||
}
|
||||
|
||||
// === impl Swarm ===
|
||||
@@ -82,8 +84,9 @@ where
|
||||
incoming: ConnectionListener,
|
||||
sessions: SessionManager,
|
||||
state: NetworkState<C>,
|
||||
net_connection_state: NetworkConnectionState,
|
||||
) -> Self {
|
||||
Self { incoming, sessions, state }
|
||||
Self { incoming, sessions, state, net_connection_state }
|
||||
}
|
||||
|
||||
/// Access to the state.
|
||||
@@ -189,6 +192,10 @@ where
|
||||
return Some(SwarmEvent::TcpListenerClosed { remote_addr: address })
|
||||
}
|
||||
ListenerEvent::Incoming { stream, remote_addr } => {
|
||||
// Reject incoming connection if node is shutting down.
|
||||
if self.is_shutting_down() {
|
||||
return None
|
||||
}
|
||||
// ensure we can handle an incoming connection from this address
|
||||
if let Err(err) =
|
||||
self.state_mut().peers_mut().on_incoming_pending_session(remote_addr.ip())
|
||||
@@ -244,6 +251,10 @@ where
|
||||
StateAction::PeerAdded(peer_id) => return Some(SwarmEvent::PeerAdded(peer_id)),
|
||||
StateAction::PeerRemoved(peer_id) => return Some(SwarmEvent::PeerRemoved(peer_id)),
|
||||
StateAction::DiscoveredNode { peer_id, socket_addr, fork_id } => {
|
||||
// Don't try to connect to peer if node is shutting down
|
||||
if self.is_shutting_down() {
|
||||
return None
|
||||
}
|
||||
// Insert peer only if no fork id or a valid fork id
|
||||
if fork_id.map_or_else(|| true, |f| self.sessions.is_valid_fork_id(f)) {
|
||||
self.state_mut().peers_mut().add_peer(peer_id, socket_addr, fork_id);
|
||||
@@ -259,6 +270,17 @@ where
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Set network connection state to `ShuttingDown`
|
||||
pub(crate) fn on_shutdown_requested(&mut self) {
|
||||
self.net_connection_state = NetworkConnectionState::ShuttingDown;
|
||||
}
|
||||
|
||||
/// Checks if the node's network connection state is 'ShuttingDown'
|
||||
#[inline]
|
||||
fn is_shutting_down(&self) -> bool {
|
||||
matches!(self.net_connection_state, NetworkConnectionState::ShuttingDown)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for Swarm<C>
|
||||
@@ -394,3 +416,12 @@ pub(crate) enum SwarmEvent {
|
||||
/// Failed to establish a tcp stream to the given address/node
|
||||
OutgoingConnectionError { remote_addr: SocketAddr, peer_id: PeerId, error: io::Error },
|
||||
}
|
||||
|
||||
/// Represents the state of the connection of the node. If shutting down,
|
||||
/// new connections won't be established.
|
||||
#[derive(Default)]
|
||||
pub(crate) enum NetworkConnectionState {
|
||||
#[default]
|
||||
Active,
|
||||
ShuttingDown,
|
||||
}
|
||||
|
||||
@@ -509,3 +509,45 @@ async fn test_geth_disconnect() {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_shutdown() {
|
||||
let net = Testnet::create(3).await;
|
||||
|
||||
let mut handles = net.handles();
|
||||
let handle0 = handles.next().unwrap();
|
||||
let handle1 = handles.next().unwrap();
|
||||
let handle2 = handles.next().unwrap();
|
||||
|
||||
drop(handles);
|
||||
let _handle = net.spawn();
|
||||
|
||||
handle0.add_peer(*handle1.peer_id(), handle1.local_addr());
|
||||
handle0.add_peer(*handle2.peer_id(), handle2.local_addr());
|
||||
handle1.add_peer(*handle2.peer_id(), handle2.local_addr());
|
||||
|
||||
let mut expected_connections = HashSet::from([*handle1.peer_id(), *handle2.peer_id()]);
|
||||
|
||||
let mut listener0 = NetworkEventStream::new(handle0.event_listener());
|
||||
|
||||
// Before shutting down, we have two connected peers
|
||||
let peer1 = listener0.next_session_established().await.unwrap();
|
||||
let peer2 = listener0.next_session_established().await.unwrap();
|
||||
assert!(expected_connections.contains(&peer1));
|
||||
assert!(expected_connections.contains(&peer2));
|
||||
assert_eq!(handle0.num_connected_peers(), 2);
|
||||
|
||||
handle0.shutdown().await.unwrap();
|
||||
|
||||
// All sessions get disconnected
|
||||
let (peer1, _reason) = listener0.next_session_closed().await.unwrap();
|
||||
let (peer2, _reason) = listener0.next_session_closed().await.unwrap();
|
||||
assert!(expected_connections.remove(&peer1));
|
||||
assert!(expected_connections.remove(&peer2));
|
||||
assert_eq!(handle0.num_connected_peers(), 0);
|
||||
|
||||
// New connections are rejected
|
||||
handle0.add_peer(*handle1.peer_id(), handle1.local_addr());
|
||||
let (_peer, reason) = listener0.next_session_closed().await.unwrap();
|
||||
assert_eq!(reason, Some(DisconnectReason::DisconnectRequested));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user