diff --git a/src/net/error.rs b/src/net/error.rs new file mode 100644 index 000000000..0ac237586 --- /dev/null +++ b/src/net/error.rs @@ -0,0 +1,29 @@ +use std::fmt; + +pub type NetResult = std::result::Result; + +#[derive(Debug, Copy, Clone)] +pub enum NetError { + OperationFailed, + ConnectFailed, + ConnectTimeout, + ChannelStopped, + ChannelTimeout, + ServiceStopped, +} + +impl std::error::Error for NetError {} + +impl fmt::Display for NetError { + fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result { + match *self { + NetError::OperationFailed => f.write_str("Operation failed"), + NetError::ConnectFailed => f.write_str("Connection failed"), + NetError::ConnectTimeout => f.write_str("Connection timed out"), + NetError::ChannelStopped => f.write_str("Channel stopped"), + NetError::ChannelTimeout => f.write_str("Channel timed out"), + NetError::ServiceStopped => f.write_str("Service stopped"), + } + } +} + diff --git a/src/net/message_subscriber.rs b/src/net/message_subscriber.rs index bbc4ac479..e1c5415c6 100644 --- a/src/net/message_subscriber.rs +++ b/src/net/message_subscriber.rs @@ -14,7 +14,7 @@ pub type MessageSubscriptionID = u64; macro_rules! receive_message { ($sub:expr, $message_type:path) => {{ - let wrapped_message = OwningRef::new($sub.receive().await?); + let wrapped_message = owning_ref::OwningRef::new($sub.receive().await?); wrapped_message.map(|msg| match msg { $message_type(msg_detail) => msg_detail, diff --git a/src/net/messages.rs b/src/net/messages.rs index c11c758e3..35fad2a25 100644 --- a/src/net/messages.rs +++ b/src/net/messages.rs @@ -29,7 +29,6 @@ pub enum PacketType { Pong = 1, GetAddrs = 2, Addrs = 3, - Sync = 4, Inv = 5, GetSlabs = 6, Slab = 7, @@ -38,11 +37,10 @@ pub enum PacketType { } pub enum Message { - Ping, - Pong, + Ping(PingMessage), + Pong(PongMessage), GetAddrs(GetAddrsMessage), Addrs(AddrsMessage), - Sync, Inv(InvMessage), GetSlabs(GetSlabsMessage), Slab(SlabMessage), @@ -50,6 +48,14 @@ pub enum Message { Verack(VerackMessage), } +pub struct PingMessage { + pub nonce: u32 +} + +pub struct PongMessage { + pub nonce: u32 +} + pub struct GetAddrsMessage {} pub struct GetSlabsMessage { @@ -74,6 +80,38 @@ pub struct VersionMessage {} pub struct VerackMessage {} +impl Encodable for PingMessage { + fn encode(&self, mut s: S) -> Result { + let mut len = 0; + len += self.nonce.encode(&mut s)?; + Ok(len) + } +} + +impl Decodable for PingMessage { + fn decode(mut d: D) -> Result { + Ok(Self { + nonce: Decodable::decode(&mut d)?, + }) + } +} + +impl Encodable for PongMessage { + fn encode(&self, mut s: S) -> Result { + let mut len = 0; + len += self.nonce.encode(&mut s)?; + Ok(len) + } +} + +impl Decodable for PongMessage { + fn decode(mut d: D) -> Result { + Ok(Self { + nonce: Decodable::decode(&mut d)?, + }) + } +} + impl Encodable for GetSlabsMessage { fn encode(&self, mut s: S) -> Result { let mut len = 0; @@ -180,11 +218,10 @@ impl Decodable for VerackMessage { impl Message { pub fn packet_type(&self) -> PacketType { match self { - Message::Ping => PacketType::Ping, - Message::Pong => PacketType::Pong, + Message::Ping(message) => PacketType::Ping, + Message::Pong(message) => PacketType::Pong, Message::GetAddrs(message) => PacketType::GetAddrs, Message::Addrs(message) => PacketType::Addrs, - Message::Sync => PacketType::Sync, Message::Inv(message) => PacketType::Inv, Message::GetSlabs(message) => PacketType::GetSlabs, Message::Slab(message) => PacketType::Slab, @@ -195,14 +232,22 @@ impl Message { pub fn pack(&self) -> Result { match self { - Message::Ping => Ok(Packet { - command: PacketType::Ping, - payload: Vec::new(), - }), - Message::Pong => Ok(Packet { - command: PacketType::Pong, - payload: Vec::new(), - }), + Message::Ping(message) => { + let mut payload = Vec::new(); + message.encode(&mut payload)?; + Ok(Packet { + command: PacketType::Ping, + payload: Vec::new(), + }) + } + Message::Pong(message) => { + let mut payload = Vec::new(); + message.encode(&mut payload)?; + Ok(Packet { + command: PacketType::Pong, + payload: Vec::new(), + }) + } Message::GetAddrs(message) => { let mut payload = Vec::new(); message.encode(&mut payload)?; @@ -219,13 +264,6 @@ impl Message { payload, }) } - Message::Sync => { - let payload = Vec::new(); - Ok(Packet { - command: PacketType::Sync, - payload, - }) - } Message::Inv(message) => { let payload = serialize(message); Ok(Packet { @@ -267,11 +305,10 @@ impl Message { pub fn unpack(packet: Packet) -> Result { let cursor = Cursor::new(packet.payload.clone()); match packet.command { - PacketType::Ping => Ok(Self::Ping), - PacketType::Pong => Ok(Self::Pong), + PacketType::Ping => Ok(Self::Ping(PingMessage::decode(cursor)?)), + PacketType::Pong => Ok(Self::Pong(PongMessage::decode(cursor)?)), PacketType::GetAddrs => Ok(Self::GetAddrs(GetAddrsMessage::decode(cursor)?)), PacketType::Addrs => Ok(Self::Addrs(AddrsMessage::decode(cursor)?)), - PacketType::Sync => Ok(Self::Sync), PacketType::Inv => Ok(Self::Inv(InvMessage::decode(cursor)?)), PacketType::GetSlabs => Ok(Self::GetSlabs(GetSlabsMessage::decode(cursor)?)), PacketType::Slab => Ok(Self::Slab(SlabMessage::decode(cursor)?)), @@ -282,11 +319,10 @@ impl Message { pub fn name(&self) -> &'static str { match self { - Message::Ping => "Ping", - Message::Pong => "Pong", + Message::Ping(_) => "Ping", + Message::Pong(_) => "Pong", Message::GetAddrs(_) => "GetAddrs", Message::Addrs(_) => "Addrs", - Message::Sync => "Sync", Message::Inv(_) => "Inv", Message::GetSlabs(_) => "GetSlabs", Message::Slab(_) => "Slab", diff --git a/src/net/protocols/mod.rs b/src/net/protocols/mod.rs index c53fe8dd7..5cdeb7a0f 100644 --- a/src/net/protocols/mod.rs +++ b/src/net/protocols/mod.rs @@ -1,7 +1,11 @@ +pub mod protocol_address; pub mod protocol_ping; pub mod protocol_seed; pub mod protocol_version; +pub mod protocol_jobs_manager; +pub use protocol_address::ProtocolAddress; pub use protocol_ping::ProtocolPing; pub use protocol_seed::ProtocolSeed; pub use protocol_version::ProtocolVersion; +pub use protocol_jobs_manager::{ProtocolJobsManager, ProtocolJobsManagerPtr}; diff --git a/src/net/protocols/protocol_address.rs b/src/net/protocols/protocol_address.rs new file mode 100644 index 000000000..37371be84 --- /dev/null +++ b/src/net/protocols/protocol_address.rs @@ -0,0 +1,28 @@ +use futures::FutureExt; +use log::*; +use rand::Rng; +use smol::{Executor, Task}; +use std::sync::Arc; + +use crate::net::error::{NetError, NetResult}; +use crate::net::messages; +use crate::net::utility::sleep; +use crate::net::{ChannelPtr, SettingsPtr}; +use crate::net::protocols::{ProtocolJobsManager, ProtocolJobsManagerPtr}; + +pub struct ProtocolAddress { + channel: ChannelPtr, + settings: SettingsPtr, + + jobsman: ProtocolJobsManagerPtr +} + +impl ProtocolAddress { + pub fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc { + Arc::new(Self { channel: channel.clone(), settings, jobsman: ProtocolJobsManager::new(channel) }) + } + + pub async fn start(self: Arc, executor: Arc>) { + } +} + diff --git a/src/net/protocols/protocol_jobs_manager.rs b/src/net/protocols/protocol_jobs_manager.rs new file mode 100644 index 000000000..62f52cbca --- /dev/null +++ b/src/net/protocols/protocol_jobs_manager.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; +use smol::Task; +use futures::Future; +use async_std::sync::Mutex; + +use crate::net::error::NetResult; +use crate::net::ChannelPtr; +use crate::system::ExecutorPtr; + +pub type ProtocolJobsManagerPtr = Arc; + +pub struct ProtocolJobsManager { + channel: ChannelPtr, + tasks: Mutex>>> +} + +impl ProtocolJobsManager { + pub fn new(channel: ChannelPtr) -> Arc { + Arc::new(Self { + channel, + tasks: Mutex::new(Vec::new()) + }) + } + + pub fn start(self: Arc, executor: ExecutorPtr<'_>) { + executor.spawn(self.handle_stop()).detach() + } + + pub async fn spawn<'a, F>(&self, future: F, executor: ExecutorPtr<'a>) + where + F: Future> + Send + 'a + { + self.tasks.lock().await.push(executor.spawn(future)) + } + + async fn handle_stop(self: Arc) { + let stop_sub = self.channel.clone().subscribe_stop().await; + + // Wait for the stop signal + // Not interested in the exact error + let _ = stop_sub.receive().await; + + self.close_all_tasks().await + } + + async fn close_all_tasks(self: Arc) { + let tasks = std::mem::take(&mut *self.tasks.lock().await); + for task in tasks { + let _ = task.cancel().await; + } + } +} + diff --git a/src/net/protocols/protocol_ping.rs b/src/net/protocols/protocol_ping.rs index 9192a7ae5..1bf670ba2 100644 --- a/src/net/protocols/protocol_ping.rs +++ b/src/net/protocols/protocol_ping.rs @@ -1,25 +1,31 @@ use futures::FutureExt; +use log::*; use rand::Rng; use smol::{Executor, Task}; use std::sync::Arc; use crate::net::error::{NetError, NetResult}; use crate::net::messages; -use crate::net::utility::{clone_net_error, sleep}; +use crate::net::utility::sleep; use crate::net::{ChannelPtr, SettingsPtr}; +use crate::net::protocols::{ProtocolJobsManager, ProtocolJobsManagerPtr}; pub struct ProtocolPing { channel: ChannelPtr, settings: SettingsPtr, + + jobsman: ProtocolJobsManagerPtr } impl ProtocolPing { pub fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc { - Arc::new(Self { channel, settings }) + Arc::new(Self { channel: channel.clone(), settings, jobsman: ProtocolJobsManager::new(channel) }) } - pub fn start(self: Arc, executor: Arc>) -> Task> { - executor.spawn(self.run_ping_pong()) + pub async fn start(self: Arc, executor: Arc>) { + self.jobsman.clone().start(executor.clone()); + self.jobsman.clone().spawn(self.clone().run_ping_pong(), executor.clone()).await; + self.jobsman.clone().spawn(self.reply_to_ping(), executor).await; } async fn run_ping_pong(self: Arc) -> NetResult<()> { @@ -34,18 +40,40 @@ impl ProtocolPing { sleep(self.settings.channel_heartbeat_seconds).await; // Create a random nonce - let _nonce = Self::random_nonce(); - // TODO: add the nonce after delete other crappy network code + let nonce = Self::random_nonce(); // Send ping message - let ping = messages::Message::Ping; + let ping = messages::Message::Ping(messages::PingMessage { + nonce + }); self.channel.clone().send(ping).await?; // Wait for pong, check nonce matches - let _pong_msg = pong_sub.receive().await?; - // TODO: fix pong enum - //let _pong_msg = receive_message!(pong_sub, messages::Message::Pong); - // TODO: add nonce check here + let pong_msg = receive_message!(pong_sub, messages::Message::Pong); + if pong_msg.nonce != nonce { + error!("Wrong nonce for ping reply. Disconnecting from channel."); + self.channel.stop().await; + return Err(NetError::ChannelStopped); + } + } + } + + async fn reply_to_ping(self: Arc) -> NetResult<()> { + let ping_sub = self + .channel + .clone() + .subscribe_msg(messages::PacketType::Ping) + .await; + + loop { + // Wait for ping, reply with pong that has a matching nonce + let ping = receive_message!(ping_sub, messages::Message::Ping); + + // Send ping message + let pong = messages::Message::Pong(messages::PongMessage { + nonce: ping.nonce + }); + self.channel.clone().send(pong).await?; } } diff --git a/src/net/sessions/inbound_session.rs b/src/net/sessions/inbound_session.rs index d286e692e..5c9ae6389 100644 --- a/src/net/sessions/inbound_session.rs +++ b/src/net/sessions/inbound_session.rs @@ -4,11 +4,10 @@ use std::net::SocketAddr; use std::sync::{Arc, Weak}; use crate::net::error::{NetError, NetResult}; -use crate::net::protocols::{ProtocolPing, ProtocolSeed}; +use crate::net::protocols::{ProtocolPing, ProtocolAddress, ProtocolSeed}; use crate::net::sessions::Session; use crate::net::{Acceptor, AcceptorPtr}; use crate::net::{ChannelPtr, Connector, HostsPtr, P2p, SettingsPtr}; -use crate::net::utility::clone_net_error; use crate::system::{StoppableTask, StoppableTaskPtr}; pub struct InboundSession { @@ -41,7 +40,7 @@ impl InboundSession { } self.accept_task.clone().start( - self.clone().channel_sub_loop(), + self.clone().channel_sub_loop(executor.clone()), // Ignore stop handler |_| { async {} }, NetError::ServiceStopped, @@ -60,23 +59,50 @@ impl InboundSession { executor: Arc>, ) -> NetResult<()> { info!("Starting inbound session on {}", accept_addr); - match self.acceptor.clone().start(accept_addr, executor) { - Ok(()) => { - } - Err(err) => { - error!("Error starting listener: {}", err); - return Err(err); - } + let result = self.acceptor.clone().start(accept_addr, executor); + if let Err(err) = result { + error!("Error starting listener: {}", err); } - Ok(()) + result } - async fn channel_sub_loop(self: Arc) -> NetResult<()> { + async fn channel_sub_loop(self: Arc, executor: Arc>) -> NetResult<()> { let channel_sub = self.acceptor.clone().subscribe().await; loop { - //let channel = (*channel_sub.receive().await)?; + let channel = (*channel_sub.receive().await).clone()?; + // Spawn a detached task to process the channel + // This will just perform the channel setup then exit. + executor.spawn(self.clone().setup_channel(channel, executor.clone())).detach(); } } + + async fn setup_channel(self: Arc, channel: ChannelPtr, executor: Arc>) -> NetResult<()> { + info!("Connected inbound [{}]", channel.address()); + + self.clone() + .register_channel(channel.clone(), executor.clone()) + .await?; + + let settings = self.p2p.upgrade().unwrap().settings(); + + self.attach_protocols(channel, settings, executor) + .await + } + + async fn attach_protocols( + self: Arc, + channel: ChannelPtr, + settings: SettingsPtr, + executor: Arc>, + ) -> NetResult<()> { + let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone()); + protocol_ping.start(executor.clone()).await; + + let protocol_addr = ProtocolAddress::new(channel, settings); + protocol_addr.start(executor).await; + + Ok(()) + } } impl Session for InboundSession { diff --git a/src/net/sessions/seed_session.rs b/src/net/sessions/seed_session.rs index f5371ad96..16fa4181a 100644 --- a/src/net/sessions/seed_session.rs +++ b/src/net/sessions/seed_session.rs @@ -79,19 +79,6 @@ impl SeedSession { } } - async fn register_channel( - self: Arc, - channel: ChannelPtr, - executor: Arc>, - ) -> NetResult<()> { - let handshake_task = self.perform_handshake_protocols(channel.clone(), executor.clone()); - - // start channel - channel.start(executor); - - handshake_task.await - } - async fn attach_protocols( self: Arc, channel: ChannelPtr, @@ -100,15 +87,11 @@ impl SeedSession { executor: Arc>, ) -> NetResult<()> { let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone()); - let ping_task = protocol_ping.start(executor.clone()); + protocol_ping.start(executor.clone()).await; let protocol_seed = ProtocolSeed::new(channel, hosts, settings.clone()); protocol_seed.start(executor.clone()).await?; - // Close the ping task now we finished. - // TODO: channel drop should trigger this automatically anyway via the stop signal - ping_task.cancel().await; - Ok(()) } } diff --git a/src/net/sessions/session.rs b/src/net/sessions/session.rs index 9f918c65d..3385ebcd1 100644 --- a/src/net/sessions/session.rs +++ b/src/net/sessions/session.rs @@ -17,7 +17,20 @@ async fn remove_sub_on_stop(p2p: P2pPtr, channel: ChannelPtr) { } #[async_trait] -pub trait Session { +pub trait Session: Sync { + async fn register_channel( + self: Arc, + channel: ChannelPtr, + executor: Arc>, + ) -> NetResult<()> { + let handshake_task = self.perform_handshake_protocols(channel.clone(), executor.clone()); + + // start channel + channel.start(executor); + + handshake_task.await + } + async fn perform_handshake_protocols( &self, channel: ChannelPtr, diff --git a/src/net/utility.rs b/src/net/utility.rs index 8d5b28687..6dc3fed7d 100644 --- a/src/net/utility.rs +++ b/src/net/utility.rs @@ -1,18 +1,7 @@ -use smol::{Async, Executor, Timer}; +use smol::Timer; use std::time::Duration; -use crate::error::Error; - pub async fn sleep(seconds: u32) { Timer::after(Duration::from_secs(seconds.into())).await; } -pub fn clone_net_error(error: &Error) -> Error { - match error { - Error::ConnectFailed => Error::ConnectFailed, - Error::ConnectTimeout => Error::ConnectTimeout, - Error::ChannelStopped => Error::ChannelStopped, - Error::ChannelTimeout => Error::ChannelTimeout, - _ => Error::OperationFailed, - } -} diff --git a/src/system/mod.rs b/src/system/mod.rs index d01ec21bc..73657bbe6 100644 --- a/src/system/mod.rs +++ b/src/system/mod.rs @@ -1,5 +1,8 @@ pub mod stoppable_task; pub mod subscriber; +pub mod types; pub use stoppable_task::{StoppableTask, StoppableTaskPtr}; pub use subscriber::{Subscriber, SubscriberPtr, Subscription}; +pub use types::ExecutorPtr; + diff --git a/src/system/types.rs b/src/system/types.rs new file mode 100644 index 000000000..e86f0cb87 --- /dev/null +++ b/src/system/types.rs @@ -0,0 +1,5 @@ +use smol::Executor; +use std::sync::Arc; + +pub type ExecutorPtr<'a> = Arc>; +