Start to better define the protocol subsystem through ProtocolJobsManager

This commit is contained in:
narodnik
2021-01-25 17:02:45 +01:00
parent 295bfed30f
commit bd2353620e
13 changed files with 281 additions and 84 deletions

29
src/net/error.rs Normal file
View File

@@ -0,0 +1,29 @@
use std::fmt;
pub type NetResult<T> = std::result::Result<T, NetError>;
#[derive(Debug, Copy, Clone)]
pub enum NetError {
OperationFailed,
ConnectFailed,
ConnectTimeout,
ChannelStopped,
ChannelTimeout,
ServiceStopped,
}
impl std::error::Error for NetError {}
impl fmt::Display for NetError {
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
match *self {
NetError::OperationFailed => f.write_str("Operation failed"),
NetError::ConnectFailed => f.write_str("Connection failed"),
NetError::ConnectTimeout => f.write_str("Connection timed out"),
NetError::ChannelStopped => f.write_str("Channel stopped"),
NetError::ChannelTimeout => f.write_str("Channel timed out"),
NetError::ServiceStopped => f.write_str("Service stopped"),
}
}
}

View File

@@ -14,7 +14,7 @@ pub type MessageSubscriptionID = u64;
macro_rules! receive_message {
($sub:expr, $message_type:path) => {{
let wrapped_message = OwningRef::new($sub.receive().await?);
let wrapped_message = owning_ref::OwningRef::new($sub.receive().await?);
wrapped_message.map(|msg| match msg {
$message_type(msg_detail) => msg_detail,

View File

@@ -29,7 +29,6 @@ pub enum PacketType {
Pong = 1,
GetAddrs = 2,
Addrs = 3,
Sync = 4,
Inv = 5,
GetSlabs = 6,
Slab = 7,
@@ -38,11 +37,10 @@ pub enum PacketType {
}
pub enum Message {
Ping,
Pong,
Ping(PingMessage),
Pong(PongMessage),
GetAddrs(GetAddrsMessage),
Addrs(AddrsMessage),
Sync,
Inv(InvMessage),
GetSlabs(GetSlabsMessage),
Slab(SlabMessage),
@@ -50,6 +48,14 @@ pub enum Message {
Verack(VerackMessage),
}
pub struct PingMessage {
pub nonce: u32
}
pub struct PongMessage {
pub nonce: u32
}
pub struct GetAddrsMessage {}
pub struct GetSlabsMessage {
@@ -74,6 +80,38 @@ pub struct VersionMessage {}
pub struct VerackMessage {}
impl Encodable for PingMessage {
fn encode<S: io::Write>(&self, mut s: S) -> Result<usize> {
let mut len = 0;
len += self.nonce.encode(&mut s)?;
Ok(len)
}
}
impl Decodable for PingMessage {
fn decode<D: io::Read>(mut d: D) -> Result<Self> {
Ok(Self {
nonce: Decodable::decode(&mut d)?,
})
}
}
impl Encodable for PongMessage {
fn encode<S: io::Write>(&self, mut s: S) -> Result<usize> {
let mut len = 0;
len += self.nonce.encode(&mut s)?;
Ok(len)
}
}
impl Decodable for PongMessage {
fn decode<D: io::Read>(mut d: D) -> Result<Self> {
Ok(Self {
nonce: Decodable::decode(&mut d)?,
})
}
}
impl Encodable for GetSlabsMessage {
fn encode<S: io::Write>(&self, mut s: S) -> Result<usize> {
let mut len = 0;
@@ -180,11 +218,10 @@ impl Decodable for VerackMessage {
impl Message {
pub fn packet_type(&self) -> PacketType {
match self {
Message::Ping => PacketType::Ping,
Message::Pong => PacketType::Pong,
Message::Ping(message) => PacketType::Ping,
Message::Pong(message) => PacketType::Pong,
Message::GetAddrs(message) => PacketType::GetAddrs,
Message::Addrs(message) => PacketType::Addrs,
Message::Sync => PacketType::Sync,
Message::Inv(message) => PacketType::Inv,
Message::GetSlabs(message) => PacketType::GetSlabs,
Message::Slab(message) => PacketType::Slab,
@@ -195,14 +232,22 @@ impl Message {
pub fn pack(&self) -> Result<Packet> {
match self {
Message::Ping => Ok(Packet {
command: PacketType::Ping,
payload: Vec::new(),
}),
Message::Pong => Ok(Packet {
command: PacketType::Pong,
payload: Vec::new(),
}),
Message::Ping(message) => {
let mut payload = Vec::new();
message.encode(&mut payload)?;
Ok(Packet {
command: PacketType::Ping,
payload: Vec::new(),
})
}
Message::Pong(message) => {
let mut payload = Vec::new();
message.encode(&mut payload)?;
Ok(Packet {
command: PacketType::Pong,
payload: Vec::new(),
})
}
Message::GetAddrs(message) => {
let mut payload = Vec::new();
message.encode(&mut payload)?;
@@ -219,13 +264,6 @@ impl Message {
payload,
})
}
Message::Sync => {
let payload = Vec::new();
Ok(Packet {
command: PacketType::Sync,
payload,
})
}
Message::Inv(message) => {
let payload = serialize(message);
Ok(Packet {
@@ -267,11 +305,10 @@ impl Message {
pub fn unpack(packet: Packet) -> Result<Self> {
let cursor = Cursor::new(packet.payload.clone());
match packet.command {
PacketType::Ping => Ok(Self::Ping),
PacketType::Pong => Ok(Self::Pong),
PacketType::Ping => Ok(Self::Ping(PingMessage::decode(cursor)?)),
PacketType::Pong => Ok(Self::Pong(PongMessage::decode(cursor)?)),
PacketType::GetAddrs => Ok(Self::GetAddrs(GetAddrsMessage::decode(cursor)?)),
PacketType::Addrs => Ok(Self::Addrs(AddrsMessage::decode(cursor)?)),
PacketType::Sync => Ok(Self::Sync),
PacketType::Inv => Ok(Self::Inv(InvMessage::decode(cursor)?)),
PacketType::GetSlabs => Ok(Self::GetSlabs(GetSlabsMessage::decode(cursor)?)),
PacketType::Slab => Ok(Self::Slab(SlabMessage::decode(cursor)?)),
@@ -282,11 +319,10 @@ impl Message {
pub fn name(&self) -> &'static str {
match self {
Message::Ping => "Ping",
Message::Pong => "Pong",
Message::Ping(_) => "Ping",
Message::Pong(_) => "Pong",
Message::GetAddrs(_) => "GetAddrs",
Message::Addrs(_) => "Addrs",
Message::Sync => "Sync",
Message::Inv(_) => "Inv",
Message::GetSlabs(_) => "GetSlabs",
Message::Slab(_) => "Slab",

View File

@@ -1,7 +1,11 @@
pub mod protocol_address;
pub mod protocol_ping;
pub mod protocol_seed;
pub mod protocol_version;
pub mod protocol_jobs_manager;
pub use protocol_address::ProtocolAddress;
pub use protocol_ping::ProtocolPing;
pub use protocol_seed::ProtocolSeed;
pub use protocol_version::ProtocolVersion;
pub use protocol_jobs_manager::{ProtocolJobsManager, ProtocolJobsManagerPtr};

View File

@@ -0,0 +1,28 @@
use futures::FutureExt;
use log::*;
use rand::Rng;
use smol::{Executor, Task};
use std::sync::Arc;
use crate::net::error::{NetError, NetResult};
use crate::net::messages;
use crate::net::utility::sleep;
use crate::net::{ChannelPtr, SettingsPtr};
use crate::net::protocols::{ProtocolJobsManager, ProtocolJobsManagerPtr};
pub struct ProtocolAddress {
channel: ChannelPtr,
settings: SettingsPtr,
jobsman: ProtocolJobsManagerPtr
}
impl ProtocolAddress {
pub fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc<Self> {
Arc::new(Self { channel: channel.clone(), settings, jobsman: ProtocolJobsManager::new(channel) })
}
pub async fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) {
}
}

View File

@@ -0,0 +1,53 @@
use std::sync::Arc;
use smol::Task;
use futures::Future;
use async_std::sync::Mutex;
use crate::net::error::NetResult;
use crate::net::ChannelPtr;
use crate::system::ExecutorPtr;
pub type ProtocolJobsManagerPtr = Arc<ProtocolJobsManager>;
pub struct ProtocolJobsManager {
channel: ChannelPtr,
tasks: Mutex<Vec<Task<NetResult<()>>>>
}
impl ProtocolJobsManager {
pub fn new(channel: ChannelPtr) -> Arc<Self> {
Arc::new(Self {
channel,
tasks: Mutex::new(Vec::new())
})
}
pub fn start(self: Arc<Self>, executor: ExecutorPtr<'_>) {
executor.spawn(self.handle_stop()).detach()
}
pub async fn spawn<'a, F>(&self, future: F, executor: ExecutorPtr<'a>)
where
F: Future<Output=NetResult<()>> + Send + 'a
{
self.tasks.lock().await.push(executor.spawn(future))
}
async fn handle_stop(self: Arc<Self>) {
let stop_sub = self.channel.clone().subscribe_stop().await;
// Wait for the stop signal
// Not interested in the exact error
let _ = stop_sub.receive().await;
self.close_all_tasks().await
}
async fn close_all_tasks(self: Arc<Self>) {
let tasks = std::mem::take(&mut *self.tasks.lock().await);
for task in tasks {
let _ = task.cancel().await;
}
}
}

View File

@@ -1,25 +1,31 @@
use futures::FutureExt;
use log::*;
use rand::Rng;
use smol::{Executor, Task};
use std::sync::Arc;
use crate::net::error::{NetError, NetResult};
use crate::net::messages;
use crate::net::utility::{clone_net_error, sleep};
use crate::net::utility::sleep;
use crate::net::{ChannelPtr, SettingsPtr};
use crate::net::protocols::{ProtocolJobsManager, ProtocolJobsManagerPtr};
pub struct ProtocolPing {
channel: ChannelPtr,
settings: SettingsPtr,
jobsman: ProtocolJobsManagerPtr
}
impl ProtocolPing {
pub fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc<Self> {
Arc::new(Self { channel, settings })
Arc::new(Self { channel: channel.clone(), settings, jobsman: ProtocolJobsManager::new(channel) })
}
pub fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) -> Task<NetResult<()>> {
executor.spawn(self.run_ping_pong())
pub async fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) {
self.jobsman.clone().start(executor.clone());
self.jobsman.clone().spawn(self.clone().run_ping_pong(), executor.clone()).await;
self.jobsman.clone().spawn(self.reply_to_ping(), executor).await;
}
async fn run_ping_pong(self: Arc<Self>) -> NetResult<()> {
@@ -34,18 +40,40 @@ impl ProtocolPing {
sleep(self.settings.channel_heartbeat_seconds).await;
// Create a random nonce
let _nonce = Self::random_nonce();
// TODO: add the nonce after delete other crappy network code
let nonce = Self::random_nonce();
// Send ping message
let ping = messages::Message::Ping;
let ping = messages::Message::Ping(messages::PingMessage {
nonce
});
self.channel.clone().send(ping).await?;
// Wait for pong, check nonce matches
let _pong_msg = pong_sub.receive().await?;
// TODO: fix pong enum
//let _pong_msg = receive_message!(pong_sub, messages::Message::Pong);
// TODO: add nonce check here
let pong_msg = receive_message!(pong_sub, messages::Message::Pong);
if pong_msg.nonce != nonce {
error!("Wrong nonce for ping reply. Disconnecting from channel.");
self.channel.stop().await;
return Err(NetError::ChannelStopped);
}
}
}
async fn reply_to_ping(self: Arc<Self>) -> NetResult<()> {
let ping_sub = self
.channel
.clone()
.subscribe_msg(messages::PacketType::Ping)
.await;
loop {
// Wait for ping, reply with pong that has a matching nonce
let ping = receive_message!(ping_sub, messages::Message::Ping);
// Send ping message
let pong = messages::Message::Pong(messages::PongMessage {
nonce: ping.nonce
});
self.channel.clone().send(pong).await?;
}
}

View File

@@ -4,11 +4,10 @@ use std::net::SocketAddr;
use std::sync::{Arc, Weak};
use crate::net::error::{NetError, NetResult};
use crate::net::protocols::{ProtocolPing, ProtocolSeed};
use crate::net::protocols::{ProtocolPing, ProtocolAddress, ProtocolSeed};
use crate::net::sessions::Session;
use crate::net::{Acceptor, AcceptorPtr};
use crate::net::{ChannelPtr, Connector, HostsPtr, P2p, SettingsPtr};
use crate::net::utility::clone_net_error;
use crate::system::{StoppableTask, StoppableTaskPtr};
pub struct InboundSession {
@@ -41,7 +40,7 @@ impl InboundSession {
}
self.accept_task.clone().start(
self.clone().channel_sub_loop(),
self.clone().channel_sub_loop(executor.clone()),
// Ignore stop handler
|_| { async {} },
NetError::ServiceStopped,
@@ -60,23 +59,50 @@ impl InboundSession {
executor: Arc<Executor<'_>>,
) -> NetResult<()> {
info!("Starting inbound session on {}", accept_addr);
match self.acceptor.clone().start(accept_addr, executor) {
Ok(()) => {
}
Err(err) => {
error!("Error starting listener: {}", err);
return Err(err);
}
let result = self.acceptor.clone().start(accept_addr, executor);
if let Err(err) = result {
error!("Error starting listener: {}", err);
}
Ok(())
result
}
async fn channel_sub_loop(self: Arc<Self>) -> NetResult<()> {
async fn channel_sub_loop(self: Arc<Self>, executor: Arc<Executor<'_>>) -> NetResult<()> {
let channel_sub = self.acceptor.clone().subscribe().await;
loop {
//let channel = (*channel_sub.receive().await)?;
let channel = (*channel_sub.receive().await).clone()?;
// Spawn a detached task to process the channel
// This will just perform the channel setup then exit.
executor.spawn(self.clone().setup_channel(channel, executor.clone())).detach();
}
}
async fn setup_channel(self: Arc<Self>, channel: ChannelPtr, executor: Arc<Executor<'_>>) -> NetResult<()> {
info!("Connected inbound [{}]", channel.address());
self.clone()
.register_channel(channel.clone(), executor.clone())
.await?;
let settings = self.p2p.upgrade().unwrap().settings();
self.attach_protocols(channel, settings, executor)
.await
}
async fn attach_protocols(
self: Arc<Self>,
channel: ChannelPtr,
settings: SettingsPtr,
executor: Arc<Executor<'_>>,
) -> NetResult<()> {
let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone());
protocol_ping.start(executor.clone()).await;
let protocol_addr = ProtocolAddress::new(channel, settings);
protocol_addr.start(executor).await;
Ok(())
}
}
impl Session for InboundSession {

View File

@@ -79,19 +79,6 @@ impl SeedSession {
}
}
async fn register_channel(
self: Arc<Self>,
channel: ChannelPtr,
executor: Arc<Executor<'_>>,
) -> NetResult<()> {
let handshake_task = self.perform_handshake_protocols(channel.clone(), executor.clone());
// start channel
channel.start(executor);
handshake_task.await
}
async fn attach_protocols(
self: Arc<Self>,
channel: ChannelPtr,
@@ -100,15 +87,11 @@ impl SeedSession {
executor: Arc<Executor<'_>>,
) -> NetResult<()> {
let protocol_ping = ProtocolPing::new(channel.clone(), settings.clone());
let ping_task = protocol_ping.start(executor.clone());
protocol_ping.start(executor.clone()).await;
let protocol_seed = ProtocolSeed::new(channel, hosts, settings.clone());
protocol_seed.start(executor.clone()).await?;
// Close the ping task now we finished.
// TODO: channel drop should trigger this automatically anyway via the stop signal
ping_task.cancel().await;
Ok(())
}
}

View File

@@ -17,7 +17,20 @@ async fn remove_sub_on_stop(p2p: P2pPtr, channel: ChannelPtr) {
}
#[async_trait]
pub trait Session {
pub trait Session: Sync {
async fn register_channel(
self: Arc<Self>,
channel: ChannelPtr,
executor: Arc<Executor<'_>>,
) -> NetResult<()> {
let handshake_task = self.perform_handshake_protocols(channel.clone(), executor.clone());
// start channel
channel.start(executor);
handshake_task.await
}
async fn perform_handshake_protocols(
&self,
channel: ChannelPtr,

View File

@@ -1,18 +1,7 @@
use smol::{Async, Executor, Timer};
use smol::Timer;
use std::time::Duration;
use crate::error::Error;
pub async fn sleep(seconds: u32) {
Timer::after(Duration::from_secs(seconds.into())).await;
}
pub fn clone_net_error(error: &Error) -> Error {
match error {
Error::ConnectFailed => Error::ConnectFailed,
Error::ConnectTimeout => Error::ConnectTimeout,
Error::ChannelStopped => Error::ChannelStopped,
Error::ChannelTimeout => Error::ChannelTimeout,
_ => Error::OperationFailed,
}
}

View File

@@ -1,5 +1,8 @@
pub mod stoppable_task;
pub mod subscriber;
pub mod types;
pub use stoppable_task::{StoppableTask, StoppableTaskPtr};
pub use subscriber::{Subscriber, SubscriberPtr, Subscription};
pub use types::ExecutorPtr;

5
src/system/types.rs Normal file
View File

@@ -0,0 +1,5 @@
use smol::Executor;
use std::sync::Arc;
pub type ExecutorPtr<'a> = Arc<Executor<'a>>;