diff --git a/src/net/acceptor.rs b/src/net/acceptor.rs index 51429ad44..32628a8ed 100644 --- a/src/net/acceptor.rs +++ b/src/net/acceptor.rs @@ -4,7 +4,7 @@ use std::net::{SocketAddr, TcpListener}; use std::sync::Arc; use crate::net::error::{NetError, NetResult}; -use crate::net::{Channel, ChannelPtr, SettingsPtr}; +use crate::net::{Channel, ChannelPtr}; use crate::system::{StoppableTask, StoppableTaskPtr, Subscriber, SubscriberPtr, Subscription}; pub type AcceptorPtr = Arc; @@ -12,15 +12,13 @@ pub type AcceptorPtr = Arc; pub struct Acceptor { channel_subscriber: SubscriberPtr>, task: StoppableTaskPtr, - settings: SettingsPtr, } impl Acceptor { - pub fn new(settings: SettingsPtr) -> Arc { + pub fn new() -> Arc { Arc::new(Self { channel_subscriber: Subscriber::new(), task: StoppableTask::new(), - settings, }) } @@ -48,14 +46,14 @@ impl Acceptor { fn setup(accept_addr: SocketAddr) -> NetResult> { let listener = match Async::::bind(accept_addr) { - Ok(l) => l, + Ok(listener) => listener, Err(err) => { error!("Bind listener failed: {}", err); return Err(NetError::OperationFailed); } }; let local_addr = match listener.get_ref().local_addr() { - Ok(a) => a, + Ok(addr) => addr, Err(err) => { error!("Failed to get local address: {}", err); return Err(NetError::OperationFailed); @@ -104,7 +102,7 @@ impl Acceptor { }; info!("Accepted client: {}", peer_addr); - let channel = Channel::new(stream, peer_addr, self.settings.clone()).await; + let channel = Channel::new(stream, peer_addr).await; Ok(channel) } } diff --git a/src/net/channel.rs b/src/net/channel.rs index b49a75678..e7d2e72d9 100644 --- a/src/net/channel.rs +++ b/src/net/channel.rs @@ -13,7 +13,6 @@ use crate::error; use crate::net::error::{NetError, NetResult}; use crate::net::message_subscriber::{MessageSubscription, MessageSubsystem}; use crate::net::messages; -use crate::net::settings::SettingsPtr; use crate::system::{StoppableTask, StoppableTaskPtr, Subscriber, SubscriberPtr, Subscription}; pub type ChannelPtr = Arc; @@ -26,14 +25,12 @@ pub struct Channel { stop_subscriber: SubscriberPtr, receive_task: StoppableTaskPtr, stopped: AtomicBool, - settings: SettingsPtr, } impl Channel { pub async fn new( stream: Async, address: SocketAddr, - settings: SettingsPtr, ) -> Arc { let (reader, writer) = stream.split(); let reader = Mutex::new(reader); @@ -50,7 +47,6 @@ impl Channel { stop_subscriber: Subscriber::new(), receive_task: StoppableTask::new(), stopped: AtomicBool::new(false), - settings, }) } diff --git a/src/net/connector.rs b/src/net/connector.rs index d9481007e..8d2bac826 100644 --- a/src/net/connector.rs +++ b/src/net/connector.rs @@ -19,7 +19,7 @@ impl Connector { futures::select! { stream_result = Async::::connect(hostaddr).fuse() => { match stream_result { - Ok(stream) => Ok(Channel::new(stream, hostaddr, self.settings.clone()).await), + Ok(stream) => Ok(Channel::new(stream, hostaddr).await), Err(_) => Err(NetError::ConnectFailed) } } diff --git a/src/net/hosts.rs b/src/net/hosts.rs index f7d9cb746..7b7d12009 100644 --- a/src/net/hosts.rs +++ b/src/net/hosts.rs @@ -3,20 +3,16 @@ use rand::seq::SliceRandom; use std::net::SocketAddr; use std::sync::Arc; -use crate::net::SettingsPtr; - pub type HostsPtr = Arc; pub struct Hosts { addrs: Mutex>, - settings: SettingsPtr, } impl Hosts { - pub fn new(settings: SettingsPtr) -> Arc { + pub fn new() -> Arc { Arc::new(Self { addrs: Mutex::new(Vec::new()), - settings, }) } diff --git a/src/net/message_subscriber.rs b/src/net/message_subscriber.rs index 3f0792310..a20cbd984 100644 --- a/src/net/message_subscriber.rs +++ b/src/net/message_subscriber.rs @@ -207,7 +207,10 @@ impl MessageSubsystem { } } -pub async fn doteste() { +// This is a test function for the message subsystem code above +// Normall we would use the #[test] macro but cannot since it is async code +// Instead we call it using smol::block_on() in the unit test code after this func +async fn _do_message_subscriber_test() { struct MyVersionMessage { x: u32, } @@ -256,6 +259,7 @@ pub async fn doteste() { // receive // 1. do a get easy let msg2 = sub.receive().await.unwrap(); + assert_eq!(msg2.x, 110); println!("{}", msg2.x); subsystem.trigger_error(NetError::ChannelStopped).await; @@ -265,3 +269,14 @@ pub async fn doteste() { sub.unsubscribe().await; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_subscriber() { + smol::block_on(_do_message_subscriber_test()); + } +} + diff --git a/src/net/messages.rs b/src/net/messages.rs index b9e8f0ffb..dbbaea13d 100644 --- a/src/net/messages.rs +++ b/src/net/messages.rs @@ -4,7 +4,6 @@ use std::io; use std::net::SocketAddr; use crate::error::{Error, Result}; -pub use crate::net::AsyncTcpStream; use crate::serial::{Decodable, Encodable, VarInt}; const MAGIC_BYTES: [u8; 4] = [0xd9, 0xef, 0xb6, 0x7d]; diff --git a/src/net/mod.rs b/src/net/mod.rs index 801e373b7..481982999 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,11 +1,7 @@ -use smol::Async; -use std::net::TcpStream; - pub mod acceptor; pub mod channel; pub mod connector; pub mod error; -#[macro_use] pub mod message_subscriber; pub mod hosts; pub mod messages; @@ -15,8 +11,6 @@ pub mod sessions; pub mod settings; pub mod utility; -pub type AsyncTcpStream = async_dup::Arc>; - pub use acceptor::{Acceptor, AcceptorPtr}; pub use channel::{Channel, ChannelPtr}; pub use connector::Connector; diff --git a/src/net/p2p.rs b/src/net/p2p.rs index c5cf6f376..3fa5e6881 100644 --- a/src/net/p2p.rs +++ b/src/net/p2p.rs @@ -31,7 +31,7 @@ impl P2p { pending: Mutex::new(HashSet::new()), channels: Mutex::new(HashMap::new()), stop_subscriber: Subscriber::new(), - hosts: Hosts::new(settings.clone()), + hosts: Hosts::new(), settings, }) } diff --git a/src/net/protocols/protocol_address.rs b/src/net/protocols/protocol_address.rs index d54e146d4..6b4e734c9 100644 --- a/src/net/protocols/protocol_address.rs +++ b/src/net/protocols/protocol_address.rs @@ -6,7 +6,7 @@ use crate::net::error::NetResult; use crate::net::message_subscriber::MessageSubscription; use crate::net::messages; use crate::net::protocols::{ProtocolJobsManager, ProtocolJobsManagerPtr}; -use crate::net::{ChannelPtr, HostsPtr, SettingsPtr}; +use crate::net::{ChannelPtr, HostsPtr}; pub struct ProtocolAddress { channel: ChannelPtr, @@ -15,13 +15,12 @@ pub struct ProtocolAddress { get_addrs_sub: MessageSubscription, hosts: HostsPtr, - settings: SettingsPtr, jobsman: ProtocolJobsManagerPtr, } impl ProtocolAddress { - pub async fn new(channel: ChannelPtr, hosts: HostsPtr, settings: SettingsPtr) -> Arc { + pub async fn new(channel: ChannelPtr, hosts: HostsPtr) -> Arc { let addrs_sub = channel .clone() .subscribe_msg::() @@ -39,7 +38,6 @@ impl ProtocolAddress { addrs_sub, get_addrs_sub, hosts, - settings, jobsman: ProtocolJobsManager::new("ProtocolAddress", channel), }) } diff --git a/src/net/protocols/protocol_jobs_manager.rs b/src/net/protocols/protocol_jobs_manager.rs index d84c7d01b..5909cb572 100644 --- a/src/net/protocols/protocol_jobs_manager.rs +++ b/src/net/protocols/protocol_jobs_manager.rs @@ -29,6 +29,7 @@ impl ProtocolJobsManager { executor.spawn(self.handle_stop()).detach() } + /// Spawns a new task adding it to the internal queue pub async fn spawn<'a, F>(&self, future: F, executor: ExecutorPtr<'a>) where F: Future> + Send + 'a, @@ -36,6 +37,7 @@ impl ProtocolJobsManager { self.tasks.lock().await.push(executor.spawn(future)) } + /// This is run in start(). When the channel closes, we also stop all the tasks async fn handle_stop(self: Arc) { let stop_sub = self.channel.clone().subscribe_stop().await; @@ -52,8 +54,10 @@ impl ProtocolJobsManager { self.name, self.channel.address() ); + // Take all the tasks from our internal queue... let tasks = std::mem::take(&mut *self.tasks.lock().await); for task in tasks { + // ... and cancel them let _ = task.cancel().await; } } diff --git a/src/net/sessions/inbound_session.rs b/src/net/sessions/inbound_session.rs index 55446872b..9eb8c635d 100644 --- a/src/net/sessions/inbound_session.rs +++ b/src/net/sessions/inbound_session.rs @@ -18,12 +18,7 @@ pub struct InboundSession { impl InboundSession { pub fn new(p2p: Weak) -> Arc { - let settings = { - let p2p = p2p.upgrade().unwrap(); - p2p.settings() - }; - - let acceptor = Acceptor::new(settings); + let acceptor = Acceptor::new(); Arc::new(Self { p2p, @@ -73,6 +68,7 @@ impl InboundSession { result } + /// Wait for all new channels created by the acceptor and call setup_channel() on them. async fn channel_sub_loop(self: Arc, executor: Arc>) -> NetResult<()> { let channel_sub = self.acceptor.clone().subscribe().await; loop { @@ -108,7 +104,7 @@ impl InboundSession { let hosts = self.p2p().hosts().clone(); let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone()); - let protocol_addr = ProtocolAddress::new(channel, hosts, settings).await; + let protocol_addr = ProtocolAddress::new(channel, hosts).await; protocol_ping.start(executor.clone()).await; protocol_addr.start(executor).await; diff --git a/src/net/sessions/outbound_session.rs b/src/net/sessions/outbound_session.rs index 5a2295519..4c586e232 100644 --- a/src/net/sessions/outbound_session.rs +++ b/src/net/sessions/outbound_session.rs @@ -95,6 +95,10 @@ impl OutboundSession { } } + /// Load a valid address that we can connect to. + /// Valid means we aren't connecting (pending state) or connected (open channel) + /// in another slot, and it isn't our own inbound address. + /// Retry otherwise. async fn load_address(&self, slot_number: u32) -> NetResult { let p2p = self.p2p(); let hosts = p2p.hosts(); @@ -146,7 +150,7 @@ impl OutboundSession { let hosts = self.p2p().hosts().clone(); let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone()); - let protocol_addr = ProtocolAddress::new(channel, hosts, settings).await; + let protocol_addr = ProtocolAddress::new(channel, hosts).await; protocol_ping.start(executor.clone()).await; protocol_addr.start(executor).await; diff --git a/src/net/sessions/seed_session.rs b/src/net/sessions/seed_session.rs index ecd307b62..2c04b7bae 100644 --- a/src/net/sessions/seed_session.rs +++ b/src/net/sessions/seed_session.rs @@ -36,6 +36,9 @@ impl SeedSession { tasks.push(executor.spawn(self.clone().start_seed(i, seed.clone(), executor.clone()))); } + // This line loops through all the tasks and waits for them to finish. + // But if the seed_query_timeout_seconds times out before they are finished, + // then it will simply quit and the tasks will get dropped. futures::select! { _ = async move { for (i, task) in tasks.into_iter().enumerate() {