diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index ec0d1dd784..9905f395aa 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -577,7 +577,7 @@ where ); } else { // Gracefully disconnected - this.swarm.state_mut().peers_mut().on_disconnected(&peer_id); + this.swarm.state_mut().peers_mut().on_disconnected(peer_id); } this.event_listeners.send(NetworkEvent::SessionClosed { peer_id }); diff --git a/crates/net/network/src/peers/manager.rs b/crates/net/network/src/peers/manager.rs index 003553e102..5c1d532df9 100644 --- a/crates/net/network/src/peers/manager.rs +++ b/crates/net/network/src/peers/manager.rs @@ -17,13 +17,14 @@ use std::{ task::{Context, Poll}, time::Duration, }; + use thiserror::Error; use tokio::{ sync::{mpsc, oneshot}, time::{Instant, Interval}, }; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::trace; +use tracing::{debug, trace}; /// A communication channel to the [`PeersManager`] to apply manual changes to the peer set. #[derive(Clone, Debug)] @@ -227,10 +228,20 @@ impl PeersManager { } /// Gracefully disconnected - pub(crate) fn on_disconnected(&mut self, peer_id: &PeerId) { - if let Some(mut peer) = self.peers.get_mut(peer_id) { - self.connection_info.decr_state(peer.state); - peer.state = PeerConnectionState::Idle; + pub(crate) fn on_disconnected(&mut self, peer_id: PeerId) { + match self.peers.entry(peer_id) { + Entry::Occupied(mut entry) => { + self.connection_info.decr_state(entry.get().state); + + if entry.get().remove_after_disconnect { + // this peer should be removed from the set + entry.remove(); + } else { + entry.get_mut().state = PeerConnectionState::Idle; + return + } + } + Entry::Vacant(_) => return, } self.fill_outbound_slots(); @@ -321,12 +332,22 @@ impl PeersManager { /// Removes the tracked node from the set. pub(crate) fn remove_discovered_node(&mut self, peer_id: PeerId) { - if let Some(entry) = self.peers.remove(&peer_id) { + if let Some(mut peer) = self.peers.remove(&peer_id) { trace!(target : "net::peers", ?peer_id, "remove discovered node"); - if entry.state.is_connected() { - // TODO(mattsse): is this right to disconnect peers? - self.connection_info.decr_state(entry.state); - self.queued_actions.push_back(PeerAction::Disconnect { peer_id, reason: None }) + + if peer.state.is_connected() { + debug!(target : "net::peers", ?peer_id, "disconnecting on remove from discovery"); + // we terminate the active session here, but only remove the peer after the session + // was disconnected, this prevents the case where the session is scheduled for + // disconnect but the node is immediately rediscovered, See also + // [`Self::on_disconnected()`] + peer.remove_after_disconnect = true; + peer.state.disconnect(); + self.peers.insert(peer_id, peer); + self.queued_actions.push_back(PeerAction::Disconnect { + peer_id, + reason: Some(DisconnectReason::DisconnectRequested), + }) } } } @@ -512,6 +533,8 @@ pub struct Peer { state: PeerConnectionState, /// The [`ForkId`] that the peer announced via discovery. fork_id: Option, + /// Whether the entry should be removed after an existing session was terminated. + remove_after_disconnect: bool, } // === impl Peer === @@ -522,7 +545,13 @@ impl Peer { } fn with_state(addr: SocketAddr, state: PeerConnectionState) -> Self { - Self { addr, state, reputation: DEFAULT_REPUTATION, fork_id: None } + Self { + addr, + state, + reputation: DEFAULT_REPUTATION, + fork_id: None, + remove_after_disconnect: false, + } } /// Applies a reputation change to the peer and returns what action should be taken. @@ -869,7 +898,7 @@ mod test { assert_eq!(p.state, PeerConnectionState::DisconnectingOut); assert!(p.is_banned()); - peers.on_disconnected(&peer); + peers.on_disconnected(peer); let p = peers.peers.get(&peer).unwrap(); assert_eq!(p.state, PeerConnectionState::Idle); @@ -883,6 +912,44 @@ mod test { } } + #[tokio::test] + async fn test_remove_discovered_active() { + let peer = PeerId::random(); + let socket_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 2)), 8008); + let mut peers = PeersManager::default(); + peers.add_discovered_node(peer, socket_addr); + + match event!(peers) { + PeerAction::Connect { peer_id, remote_addr } => { + assert_eq!(peer_id, peer); + assert_eq!(remote_addr, socket_addr); + } + _ => unreachable!(), + } + + let p = peers.peers.get(&peer).unwrap(); + assert_eq!(p.state, PeerConnectionState::Out); + + peers.remove_discovered_node(peer); + + match event!(peers) { + PeerAction::Disconnect { peer_id, .. } => { + assert_eq!(peer_id, peer); + } + _ => unreachable!(), + } + + let p = peers.peers.get(&peer).unwrap(); + assert_eq!(p.state, PeerConnectionState::DisconnectingOut); + + peers.add_discovered_node(peer, socket_addr); + let p = peers.peers.get(&peer).unwrap(); + assert_eq!(p.state, PeerConnectionState::DisconnectingOut); + + peers.on_disconnected(peer); + assert!(peers.peers.get(&peer).is_none()); + } + #[tokio::test] async fn test_discovery_ban_list() { let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 1, 2));