feat(net): add shutdown network signal (#1011)

Co-authored-by: lambdaclass-user <github@lambdaclass.com>
This commit is contained in:
Mariano A. Nicolini
2023-02-04 22:28:13 +02:00
committed by GitHub
parent db5410b84b
commit dab1f4f497
7 changed files with 133 additions and 8 deletions

View File

@@ -28,7 +28,7 @@ use crate::{
peers::{PeersHandle, PeersManager},
session::SessionManager,
state::NetworkState,
swarm::{Swarm, SwarmEvent},
swarm::{NetworkConnectionState, Swarm, SwarmEvent},
transactions::NetworkTransactionEvent,
FetchClient, NetworkBuilder,
};
@@ -206,7 +206,7 @@ where
Arc::clone(&num_active_peers),
);
let swarm = Swarm::new(incoming, sessions, state);
let swarm = Swarm::new(incoming, sessions, state, NetworkConnectionState::default());
let (to_manager_tx, from_handle_rx) = mpsc::unbounded_channel();
@@ -513,6 +513,17 @@ where
NetworkHandleMessage::DisconnectPeer(peer_id, reason) => {
self.swarm.sessions_mut().disconnect(peer_id, reason);
}
NetworkHandleMessage::Shutdown(tx) => {
// Set the connection state to `ShuttingDown`. This stops the node from
// accepting new incoming connections and from sending connection requests
// to newly discovered nodes.
self.swarm.on_shutdown_requested();
// Disconnect all active connections
self.swarm.sessions_mut().disconnect_all(Some(DisconnectReason::ClientQuitting));
// drop pending connections
self.swarm.sessions_mut().disconnect_all_pending();
let _ = tx.send(());
}
NetworkHandleMessage::ReputationChange(peer_id, kind) => {
self.swarm.state_mut().peers_mut().apply_reputation_change(&peer_id, kind);
}

View File

@@ -160,6 +160,16 @@ impl NetworkHandle {
pub fn bandwidth_meter(&self) -> &BandwidthMeter {
&self.inner.bandwidth_meter
}
/// Requests a graceful shutdown of the network and waits for the acknowledgement.
///
/// A `NetworkHandleMessage::Shutdown` message carrying a oneshot sender is
/// handed to `send_message`; the returned future resolves once a reply arrives
/// on the matching receiver. Handling the message disconnects all active and
/// pending sessions and prevents new connections from being established.
///
/// # Errors
///
/// Returns the `RecvError` produced by awaiting the oneshot receiver.
pub async fn shutdown(&self) -> Result<(), oneshot::error::RecvError> {
    let (ack_tx, ack_rx) = oneshot::channel();
    self.send_message(NetworkHandleMessage::Shutdown(ack_tx));
    ack_rx.await
}
}
// === API Implementations ===
@@ -302,4 +312,6 @@ pub(crate) enum NetworkHandleMessage {
GetPeerInfo(oneshot::Sender<Vec<PeerInfo>>),
/// Get PeerInfo for a specific peer
GetPeerInfoById(PeerId, oneshot::Sender<Option<PeerInfo>>),
/// Gracefully shutdown network
Shutdown(oneshot::Sender<()>),
}

View File

@@ -163,7 +163,6 @@ impl PeersManager {
if !self.connection_info.has_in_capacity() {
return Err(InboundConnectionError::ExceedsLimit(self.connection_info.max_inbound))
}
// keep track of new connection
self.connection_info.inc_in();
Ok(())

View File

@@ -24,11 +24,22 @@ use tokio::{
#[derive(Debug)]
pub(crate) struct PendingSessionHandle {
/// Can be used to tell the session to disconnect the connection/abort the handshake process.
pub(crate) _disconnect_tx: oneshot::Sender<()>,
pub(crate) disconnect_tx: Option<oneshot::Sender<()>>,
/// The direction of the session
pub(crate) direction: Direction,
}
// === impl PendingSessionHandle ===
impl PendingSessionHandle {
    /// Signals the pending session to disconnect.
    ///
    /// The sender is taken out of the handle, so only the first call sends the
    /// signal; any later call finds `None` and does nothing.
    pub(crate) fn disconnect(&mut self) {
        let Some(signal) = self.disconnect_tx.take() else { return };
        // The send result is deliberately ignored.
        let _ = signal.send(());
    }
}
/// An established session with a remote peer.
///
/// Within an active session that supports the `Ethereum Wire Protocol`, three high-level tasks can

View File

@@ -222,8 +222,10 @@ impl SessionManager {
self.fork_filter.clone(),
));
let handle =
PendingSessionHandle { _disconnect_tx: disconnect_tx, direction: Direction::Incoming };
let handle = PendingSessionHandle {
disconnect_tx: Some(disconnect_tx),
direction: Direction::Incoming,
};
self.pending_sessions.insert(session_id, handle);
self.counter.inc_pending_inbound();
Ok(session_id)
@@ -248,7 +250,7 @@ impl SessionManager {
));
let handle = PendingSessionHandle {
_disconnect_tx: disconnect_tx,
disconnect_tx: Some(disconnect_tx),
direction: Direction::Outgoing(remote_peer_id),
};
self.pending_sessions.insert(session_id, handle);
@@ -265,6 +267,23 @@ impl SessionManager {
}
}
/// Starts the shutdown of every active session.
///
/// Each active session handle is asked to disconnect with the given optional
/// `reason`, so that the session tasks can terminate gracefully. The outcome is
/// picked up by the receiver.
pub(crate) fn disconnect_all(&self, reason: Option<DisconnectReason>) {
    for active in self.active_sessions.values() {
        active.disconnect(reason);
    }
}
/// Disconnects all pending sessions.
pub(crate) fn disconnect_all_pending(&mut self) {
for (_, session) in self.pending_sessions.iter_mut() {
session.disconnect();
}
}
/// Sends a message to the peer's session
pub(crate) fn send_message(&mut self, peer_id: &PeerId, msg: PeerMessage) {
if let Some(session) = self.active_sessions.get_mut(peer_id) {

View File

@@ -69,6 +69,8 @@ pub(crate) struct Swarm<C> {
sessions: SessionManager,
/// Tracks the entire state of the network and handles events received from the sessions.
state: NetworkState<C>,
/// Tracks the connection state of the node
net_connection_state: NetworkConnectionState,
}
// === impl Swarm ===
@@ -82,8 +84,9 @@ where
incoming: ConnectionListener,
sessions: SessionManager,
state: NetworkState<C>,
net_connection_state: NetworkConnectionState,
) -> Self {
Self { incoming, sessions, state }
Self { incoming, sessions, state, net_connection_state }
}
/// Access to the state.
@@ -189,6 +192,10 @@ where
return Some(SwarmEvent::TcpListenerClosed { remote_addr: address })
}
ListenerEvent::Incoming { stream, remote_addr } => {
// Reject incoming connection if node is shutting down.
if self.is_shutting_down() {
return None
}
// ensure we can handle an incoming connection from this address
if let Err(err) =
self.state_mut().peers_mut().on_incoming_pending_session(remote_addr.ip())
@@ -244,6 +251,10 @@ where
StateAction::PeerAdded(peer_id) => return Some(SwarmEvent::PeerAdded(peer_id)),
StateAction::PeerRemoved(peer_id) => return Some(SwarmEvent::PeerRemoved(peer_id)),
StateAction::DiscoveredNode { peer_id, socket_addr, fork_id } => {
// Don't try to connect to peer if node is shutting down
if self.is_shutting_down() {
return None
}
// Insert peer only if no fork id or a valid fork id
if fork_id.map_or_else(|| true, |f| self.sessions.is_valid_fork_id(f)) {
self.state_mut().peers_mut().add_peer(peer_id, socket_addr, fork_id);
@@ -259,6 +270,17 @@ where
}
None
}
/// Switches the network connection state to `ShuttingDown`.
///
/// After this call `is_shutting_down` reports `true`, which makes the swarm
/// reject new incoming connections and skip newly discovered nodes.
pub(crate) fn on_shutdown_requested(&mut self) {
self.net_connection_state = NetworkConnectionState::ShuttingDown;
}
/// Returns `true` when the tracked network connection state is `ShuttingDown`.
#[inline]
fn is_shutting_down(&self) -> bool {
    // Exhaustive on purpose: a newly added state must decide explicitly
    // whether it counts as shutting down.
    match self.net_connection_state {
        NetworkConnectionState::ShuttingDown => true,
        NetworkConnectionState::Active => false,
    }
}
}
impl<C> Stream for Swarm<C>
@@ -394,3 +416,12 @@ pub(crate) enum SwarmEvent {
/// Failed to establish a tcp stream to the given address/node
OutgoingConnectionError { remote_addr: SocketAddr, peer_id: PeerId, error: io::Error },
}
/// Represents the connection state of the node.
///
/// While the node is [`ShuttingDown`](Self::ShuttingDown), new connections
/// won't be established. The default state is [`Active`](Self::Active).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub(crate) enum NetworkConnectionState {
    /// The node is running normally.
    #[default]
    Active,
    /// A shutdown was requested; new connections won't be established.
    ShuttingDown,
}

View File

@@ -509,3 +509,45 @@ async fn test_geth_disconnect() {
.await
.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_shutdown() {
    let testnet = Testnet::create(3).await;

    // Take the three handles; the handle iterator is dropped at the end of the
    // block, before the testnet is spawned.
    let (handle0, handle1, handle2) = {
        let mut handle_iter = testnet.handles();
        (handle_iter.next().unwrap(), handle_iter.next().unwrap(), handle_iter.next().unwrap())
    };

    let _net_task = testnet.spawn();

    // Register the peers: 0 -> 1, 0 -> 2 and 1 -> 2.
    handle0.add_peer(*handle1.peer_id(), handle1.local_addr());
    handle0.add_peer(*handle2.peer_id(), handle2.local_addr());
    handle1.add_peer(*handle2.peer_id(), handle2.local_addr());

    let mut expected_peers = HashSet::from([*handle1.peer_id(), *handle2.peer_id()]);
    let mut events0 = NetworkEventStream::new(handle0.event_listener());

    // Prior to the shutdown, node 0 establishes sessions with both other peers.
    let first_established = events0.next_session_established().await.unwrap();
    let second_established = events0.next_session_established().await.unwrap();
    assert!(expected_peers.contains(&first_established));
    assert!(expected_peers.contains(&second_established));
    assert_eq!(handle0.num_connected_peers(), 2);

    handle0.shutdown().await.unwrap();

    // The shutdown closes every session of node 0.
    let (first_closed, _reason) = events0.next_session_closed().await.unwrap();
    let (second_closed, _reason) = events0.next_session_closed().await.unwrap();
    assert!(expected_peers.remove(&first_closed));
    assert!(expected_peers.remove(&second_closed));
    assert_eq!(handle0.num_connected_peers(), 0);

    // A connection attempted after the shutdown gets rejected.
    handle0.add_peer(*handle1.peer_id(), handle1.local_addr());
    let (_peer, close_reason) = events0.next_session_closed().await.unwrap();
    assert_eq!(close_reason, Some(DisconnectReason::DisconnectRequested));
}