From 4e5b844ba98e6656b39d85d529e335c90ba148c9 Mon Sep 17 00:00:00 2001 From: narodnik Date: Sat, 6 Mar 2021 09:28:50 +0100 Subject: [PATCH] integrate new message subsystem --- src/bin/dfi.rs | 4 +- src/error.rs | 8 ++ src/net/acceptor.rs | 2 +- src/net/channel.rs | 53 ++++--- src/net/connector.rs | 2 +- src/net/message_subscriber.rs | 193 +++++++------------------- src/net/messages.rs | 34 ++++- src/net/mod.rs | 1 - src/net/protocols/protocol_address.rs | 18 +-- src/net/protocols/protocol_ping.rs | 14 +- src/net/protocols/protocol_seed.rs | 7 +- src/net/protocols/protocol_version.rs | 14 +- 12 files changed, 154 insertions(+), 196 deletions(-) diff --git a/src/bin/dfi.rs b/src/bin/dfi.rs index 4915e17cc..c245a5332 100644 --- a/src/bin/dfi.rs +++ b/src/bin/dfi.rs @@ -138,8 +138,8 @@ impl RpcInterface { } async fn start(executor: Arc>, options: ProgramOptions) -> Result<()> { - sapvi::net::message_subscriber::doteste().await; - return Ok(()); + //sapvi::net::message_subscriber::doteste().await; + //return Ok(()); let p2p = net::P2p::new(options.network_settings); diff --git a/src/error.rs b/src/error.rs index 260aacd0f..fe9faf488 100644 --- a/src/error.rs +++ b/src/error.rs @@ -39,6 +39,7 @@ pub enum Error { ChannelStopped, ChannelTimeout, ServiceStopped, + Utf8Error, } impl std::error::Error for Error {} @@ -80,6 +81,7 @@ impl fmt::Display for Error { Error::ChannelStopped => f.write_str("Channel stopped"), Error::ChannelTimeout => f.write_str("Channel timed out"), Error::ServiceStopped => f.write_str("Service stopped"), + Error::Utf8Error => f.write_str("Malformed UTF8"), } } } @@ -138,3 +140,9 @@ impl From for Error { } } } + +impl From for Error { + fn from(err: std::string::FromUtf8Error) -> Error { + Error::Utf8Error + } +} diff --git a/src/net/acceptor.rs b/src/net/acceptor.rs index a81f38d3b..51429ad44 100644 --- a/src/net/acceptor.rs +++ b/src/net/acceptor.rs @@ -104,7 +104,7 @@ impl Acceptor { }; info!("Accepted client: {}", peer_addr); - let channel = Channel::new(stream, peer_addr, self.settings.clone()); + let channel = Channel::new(stream, peer_addr, self.settings.clone()).await; Ok(channel) } } diff --git a/src/net/channel.rs b/src/net/channel.rs index e97b713cf..110ff1210 100644 --- a/src/net/channel.rs +++ b/src/net/channel.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use crate::error; use crate::net::error::{NetError, NetResult}; use crate::net::message_subscriber::{ - MessageSubscriber, MessageSubscriberPtr, MessageSubscription, + MessageSubsystem, MessageSubscription, Message2 }; use crate::net::messages; use crate::net::settings::SettingsPtr; @@ -24,7 +24,7 @@ pub struct Channel { reader: Mutex>>, writer: Mutex>>, address: SocketAddr, - message_subscriber: MessageSubscriberPtr, + message_subsystem: MessageSubsystem, stop_subscriber: SubscriberPtr, receive_task: StoppableTaskPtr, stopped: AtomicBool, @@ -32,15 +32,19 @@ pub struct Channel { } impl Channel { - pub fn new(stream: Async, address: SocketAddr, settings: SettingsPtr) -> Arc { + pub async fn new(stream: Async, address: SocketAddr, settings: SettingsPtr) -> Arc { let (reader, writer) = stream.split(); let reader = Mutex::new(reader); let writer = Mutex::new(writer); + + let message_subsystem = MessageSubsystem::new(); + Self::setup_dispatchers(&message_subsystem).await; + Arc::new(Self { reader, writer, address, - message_subscriber: MessageSubscriber::new(), + message_subsystem, stop_subscriber: Subscriber::new(), receive_task: StoppableTask::new(), stopped: AtomicBool::new(false), @@ -52,7 +56,7 @@ impl Channel { debug!(target: "net", "Channel::start() [START, address={}]", self.address()); let self2 = self.clone(); self.receive_task.clone().start( - self.clone().receive_loop(), + self.clone().main_receive_loop(), // Ignore stop handler |result| self2.handle_stop(result), NetError::ServiceStopped, @@ -114,19 +118,18 @@ impl Channel { result } - pub async fn subscribe_msg( - self: Arc, - packet_type: messages::PacketType, - ) -> MessageSubscription { + pub async fn subscribe_msg( + self: Arc + ) -> NetResult> { debug!(target: "net", - "Channel::subscribe_msg() [START, pkt_type={:?}, address={}]", - packet_type, + "Channel::subscribe_msg() [START, command={:?}, address={}]", + M::name(), self.address() ); - let sub = self.message_subscriber.clone().subscribe(packet_type).await; + let sub = self.message_subsystem.subscribe::().await; debug!(target: "net", - "Channel::subscribe_msg() [END, pkt_type={:?}, address={}]", - packet_type, + "Channel::subscribe_msg() [END, command={:?}, address={}]", + M::name(), self.address() ); sub @@ -143,17 +146,26 @@ impl Channel { } } - async fn receive_loop(self: Arc) -> NetResult<()> { + async fn setup_dispatchers(message_subsystem: &MessageSubsystem) { + message_subsystem.add_dispatch::().await; + message_subsystem.add_dispatch::().await; + message_subsystem.add_dispatch::().await; + message_subsystem.add_dispatch::().await; + message_subsystem.add_dispatch::().await; + message_subsystem.add_dispatch::().await; + } + + async fn main_receive_loop(self: Arc) -> NetResult<()> { debug!(target: "net", "Channel::receive_loop() [START, address={}]", self.address() ); + let reader = &mut *self.reader.lock().await; loop { - let message_result = messages::receive_message(reader).await; - let message = match message_result { - Ok(message) => Arc::new(message), + let packet = match messages::read_packet(reader).await { + Ok(packet) => packet, Err(err) => { if Self::is_eof_error(&err) { info!("Channel {} disconnected", self.address()); @@ -170,7 +182,7 @@ impl Channel { }; // Send result to our subscribers - self.message_subscriber.notify(Ok(message)).await; + self.message_subsystem.notify(&packet.command2, packet.payload).await; } } @@ -180,8 +192,7 @@ impl Channel { Ok(()) => panic!("Channel task should never complete without error status"), Err(err) => { // Send this error to all channel subscribers - let result = Err(err); - self.message_subscriber.notify(result).await; + self.message_subsystem.trigger_error(err).await; } } debug!(target: "net", "Channel::handle_stop() [END, address={}]", self.address()); diff --git a/src/net/connector.rs b/src/net/connector.rs index 1da495212..d9481007e 100644 --- a/src/net/connector.rs +++ b/src/net/connector.rs @@ -19,7 +19,7 @@ impl Connector { futures::select! { stream_result = Async::::connect(hostaddr).fuse() => { match stream_result { - Ok(stream) => Ok(Channel::new(stream, hostaddr, self.settings.clone())), + Ok(stream) => Ok(Channel::new(stream, hostaddr, self.settings.clone()).await), Err(_) => Err(NetError::ConnectFailed) } } diff --git a/src/net/message_subscriber.rs b/src/net/message_subscriber.rs index 3b3940cbd..756f26414 100644 --- a/src/net/message_subscriber.rs +++ b/src/net/message_subscriber.rs @@ -14,148 +14,22 @@ use crate::net::messages::{Message, PacketType}; use crate::serial::Decodable; use crate::serial::Encodable; -pub type MessageSubscriberPtr = Arc; -pub type MessageResult = NetResult>; pub type MessageSubscriptionID = u64; - -macro_rules! receive_message { - ($sub:expr, $message_type:path) => {{ - let wrapped_message = owning_ref::OwningRef::new($sub.receive().await?); - - wrapped_message.map(|msg| match msg { - $message_type(msg_detail) => msg_detail, - _ => { - panic!("Filter for receive sub invalid!"); - } - }) - }}; -} - -pub struct MessageSubscription { - id: MessageSubscriptionID, - filter: PacketType, - recv_queue: async_channel::Receiver, - parent: Arc, -} - -impl MessageSubscription { - fn is_relevant_message(&self, message_result: &MessageResult) -> bool { - match message_result { - Ok(message) => { - let packet_type = message.packet_type(); - - // Apply the filter - packet_type == self.filter - } - Err(_) => { - // Propagate all errors - true - } - } - } - - pub async fn receive(&self) -> MessageResult { - loop { - let message_result = self.recv_queue.recv().await; - - match message_result { - Ok(message_result) => { - if self.clone().is_relevant_message(&message_result) { - return message_result; - } - } - Err(err) => { - panic!("MessageSubscription::receive() recv_queue failed! {}", err); - } - } - } - } - - // Must be called manually since async Drop is not possible in Rust - pub async fn unsubscribe(&self) { - self.parent.clone().unsubscribe(self.id).await - } -} - -pub struct MessageSubscriber { - subs: Mutex>>, -} - -impl MessageSubscriber { - pub fn new() -> Arc { - Arc::new(Self { - subs: Mutex::new(HashMap::new()), - }) - } - - pub fn random_id() -> MessageSubscriptionID { - let mut rng = rand::thread_rng(); - rng.gen() - } - - pub async fn subscribe(self: Arc, packet_type: PacketType) -> MessageSubscription { - let (sender, recvr) = async_channel::unbounded(); - - let sub_id = Self::random_id(); - - self.subs.lock().await.insert(sub_id, sender); - - MessageSubscription { - id: sub_id, - filter: packet_type, - recv_queue: recvr, - parent: self.clone(), - } - } - - async fn unsubscribe(self: Arc, sub_id: MessageSubscriptionID) { - self.subs.lock().await.remove(&sub_id); - } - - pub async fn notify(&self, message_result: NetResult>) { - let mut garbage_ids = Vec::new(); - - for (sub_id, sub) in &*self.subs.lock().await { - match sub.send(message_result.clone()).await { - Ok(()) => {} - Err(_err) => { - // Automatically clean out closed channels - garbage_ids.push(*sub_id); - //panic!("Error returned sending message in notify() call! {}", err); - } - } - } - - self.collect_garbage(garbage_ids).await; - } - - async fn collect_garbage(&self, ids: Vec) { - let mut subs = self.subs.lock().await; - for id in &ids { - subs.remove(id); - } - } -} - -// -// +type MessageResult = NetResult>; pub trait Message2: 'static + Decodable + Send + Sync { fn name() -> &'static str; - - fn deserialize(); - fn serialize(); } -pub struct MessageSubscription2 { +pub struct MessageSubscription { id: MessageSubscriptionID, - recv_queue: async_channel::Receiver>>, + recv_queue: async_channel::Receiver>, parent: Arc>, } -impl MessageSubscription2 { - pub async fn receive(&self) -> NetResult> { +impl MessageSubscription { + pub async fn receive(&self) -> MessageResult { match self.recv_queue.recv().await { Ok(message) => message, Err(err) => { @@ -171,7 +45,7 @@ impl MessageSubscription2 { } #[async_trait] -trait MessageDispatcherInterface: Sync { +trait MessageDispatcherInterface: Send + Sync { async fn trigger(&self, payload: Vec); async fn trigger_error(&self, err: NetError); @@ -180,7 +54,7 @@ trait MessageDispatcherInterface: Sync { } struct MessageDispatcher { - subs: Mutex>>>>, + subs: Mutex>>>, } impl MessageDispatcher { @@ -195,12 +69,12 @@ impl MessageDispatcher { rng.gen() } - pub async fn subscribe(self: Arc) -> MessageSubscription2 { + pub async fn subscribe(self: Arc) -> MessageSubscription { let (sender, recvr) = async_channel::unbounded(); let sub_id = Self::random_id(); self.subs.lock().await.insert(sub_id, sender); - MessageSubscription2 { + MessageSubscription { id: sub_id, recv_queue: recvr, parent: self, @@ -211,7 +85,7 @@ impl MessageDispatcher { self.subs.lock().await.remove(&sub_id); } - async fn trigger_all(&self, message: NetResult>) { + async fn trigger_all(&self, message: MessageResult) { let mut garbage_ids = Vec::new(); for (sub_id, sub) in &*self.subs.lock().await { @@ -262,6 +136,44 @@ impl MessageDispatcherInterface for MessageDispatcher { } } +use crate::net::messages::{PingMessage, PongMessage, GetAddrsMessage, AddrsMessage, VersionMessage, VerackMessage}; + +impl Message2 for PingMessage { + fn name() -> &'static str { + "ping" + } +} + +impl Message2 for PongMessage { + fn name() -> &'static str { + "pong" + } +} + +impl Message2 for GetAddrsMessage { + fn name() -> &'static str { + "getaddr" + } +} + +impl Message2 for AddrsMessage { + fn name() -> &'static str { + "addr" + } +} + +impl Message2 for VersionMessage { + fn name() -> &'static str { + "version" + } +} + +impl Message2 for VerackMessage { + fn name() -> &'static str { + "verack" + } +} + struct MyVersionMessage { x: u32, } @@ -270,9 +182,6 @@ impl Message2 for MyVersionMessage { fn name() -> &'static str { "verver" } - - fn deserialize() {} - fn serialize() {} } impl Encodable for MyVersionMessage { @@ -291,7 +200,7 @@ impl Decodable for MyVersionMessage { } } -struct MessageSubsystem { +pub struct MessageSubsystem { dispatchers: Mutex>>, } @@ -309,12 +218,12 @@ impl MessageSubsystem { .insert(M::name(), Arc::new(MessageDispatcher::::new())); } - pub async fn subscribe(&self) -> NetResult> { + pub async fn subscribe(&self) -> NetResult> { let dispatcher = self .dispatchers .lock() .await - .get(MyVersionMessage::name()) + .get(M::name()) .cloned(); let sub = match dispatcher { diff --git a/src/net/messages.rs b/src/net/messages.rs index 357b3b635..43c1804b5 100644 --- a/src/net/messages.rs +++ b/src/net/messages.rs @@ -237,6 +237,7 @@ impl Message { message.encode(&mut payload)?; Ok(Packet { command: PacketType::Ping, + command2: String::from(self.name()), payload, }) } @@ -245,6 +246,7 @@ impl Message { message.encode(&mut payload)?; Ok(Packet { command: PacketType::Pong, + command2: String::from(self.name()), payload, }) } @@ -253,6 +255,7 @@ impl Message { message.encode(&mut payload)?; Ok(Packet { command: PacketType::GetAddrs, + command2: String::from(self.name()), payload, }) } @@ -261,6 +264,7 @@ impl Message { message.encode(Cursor::new(&mut payload))?; Ok(Packet { command: PacketType::Addrs, + command2: String::from(self.name()), payload, }) } @@ -268,6 +272,7 @@ impl Message { let payload = serialize(message); Ok(Packet { command: PacketType::Inv, + command2: String::from(self.name()), payload, }) } @@ -275,6 +280,7 @@ impl Message { let payload = serialize(message); Ok(Packet { command: PacketType::GetSlabs, + command2: String::from(self.name()), payload, }) } @@ -282,6 +288,7 @@ impl Message { let payload = serialize(message); Ok(Packet { command: PacketType::Slab, + command2: String::from(self.name()), payload, }) } @@ -289,6 +296,7 @@ impl Message { let payload = serialize(message); Ok(Packet { command: PacketType::Version, + command2: String::from(self.name()), payload, }) } @@ -296,6 +304,7 @@ impl Message { let payload = serialize(message); Ok(Packet { command: PacketType::Verack, + command2: String::from(self.name()), payload, }) } @@ -336,6 +345,7 @@ impl Message { // These are converted to messages and passed to event loop pub struct Packet { pub command: PacketType, + pub command2: String, pub payload: Vec, } @@ -351,9 +361,16 @@ pub async fn read_packet(stream: &mut R) -> Result } // The type of the message - let command = AsyncReadExt::read_u8(stream).await?; + //let command = AsyncReadExt::read_u8(stream).await?; + //debug!(target: "net", "read command: {}", command); + //let command = PacketType::try_from(command).map_err(|_| Error::MalformedPacket)?; + let command_len = VarInt::decode_async(stream).await?.0 as usize; + let mut command = vec![0u8; command_len]; + if command_len > 0 { + stream.read_exact(&mut command).await?; + } + let command = String::from_utf8(command)?; debug!(target: "net", "read command: {}", command); - let command = PacketType::try_from(command).map_err(|_| Error::MalformedPacket)?; let payload_len = VarInt::decode_async(stream).await?.0 as usize; @@ -364,7 +381,7 @@ pub async fn read_packet(stream: &mut R) -> Result } debug!(target: "net", "read payload {} bytes", payload_len); - Ok(Packet { command, payload }) + Ok(Packet { command: PacketType::Verack, command2: command, payload }) } pub async fn send_packet(stream: &mut W, packet: Packet) -> Result<()> { @@ -372,8 +389,15 @@ pub async fn send_packet(stream: &mut W, packet: Packet) stream.write_all(&MAGIC_BYTES).await?; debug!(target: "net", "sent magic..."); - AsyncWriteExt::write_u8(stream, packet.command as u8).await?; - debug!(target: "net", "sent command: {}", packet.command as u8); + //AsyncWriteExt::write_u8(stream, packet.command as u8).await?; + //debug!(target: "net", "sent command: {}", packet.command as u8); + + VarInt(packet.command2.len() as u64) + .encode_async(stream) + .await?; + assert!(!packet.command2.is_empty()); + stream.write_all(&packet.command2.as_bytes()).await?; + debug!(target: "net", "sent command: {}", packet.command2); assert_eq!(std::mem::size_of::(), std::mem::size_of::()); VarInt(packet.payload.len() as u64) diff --git a/src/net/mod.rs b/src/net/mod.rs index 1dba05156..801e373b7 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -21,6 +21,5 @@ pub use acceptor::{Acceptor, AcceptorPtr}; pub use channel::{Channel, ChannelPtr}; pub use connector::Connector; pub use hosts::{Hosts, HostsPtr}; -pub use message_subscriber::{MessageSubscriber, MessageSubscription}; pub use p2p::P2p; pub use settings::{Settings, SettingsPtr}; diff --git a/src/net/protocols/protocol_address.rs b/src/net/protocols/protocol_address.rs index 8c2614030..5b777481e 100644 --- a/src/net/protocols/protocol_address.rs +++ b/src/net/protocols/protocol_address.rs @@ -11,8 +11,8 @@ use crate::net::{ChannelPtr, HostsPtr, SettingsPtr}; pub struct ProtocolAddress { channel: ChannelPtr, - addrs_sub: MessageSubscription, - get_addrs_sub: MessageSubscription, + addrs_sub: MessageSubscription, + get_addrs_sub: MessageSubscription, hosts: HostsPtr, settings: SettingsPtr, @@ -24,13 +24,15 @@ impl ProtocolAddress { pub async fn new(channel: ChannelPtr, hosts: HostsPtr, settings: SettingsPtr) -> Arc { let addrs_sub = channel .clone() - .subscribe_msg(messages::PacketType::Addrs) - .await; + .subscribe_msg::() + .await + .expect("Missing addrs dispatcher!"); let get_addrs_sub = channel .clone() - .subscribe_msg(messages::PacketType::GetAddrs) - .await; + .subscribe_msg::() + .await + .expect("Missing getaddrs dispatcher!"); Arc::new(Self { channel: channel.clone(), @@ -63,7 +65,7 @@ impl ProtocolAddress { async fn handle_receive_addrs(self: Arc) -> NetResult<()> { debug!(target: "net", "ProtocolAddress::handle_receive_addrs() [START]"); loop { - let addrs_msg = receive_message!(self.addrs_sub, messages::Message::Addrs); + let addrs_msg = self.addrs_sub.receive().await?; debug!(target: "net", "ProtocolAddress::handle_receive_addrs() storing address in hosts"); self.hosts.store(addrs_msg.addrs.clone()).await; @@ -73,7 +75,7 @@ impl ProtocolAddress { async fn handle_receive_get_addrs(self: Arc) -> NetResult<()> { debug!(target: "net", "ProtocolAddress::handle_receive_get_addrs() [START]"); loop { - let _get_addrs = receive_message!(self.get_addrs_sub, messages::Message::GetAddrs); + let _get_addrs = self.get_addrs_sub.receive().await?; debug!(target: "net", "ProtocolAddress::handle_receive_get_addrs() received GetAddrs message"); diff --git a/src/net/protocols/protocol_ping.rs b/src/net/protocols/protocol_ping.rs index 628f7c8b5..b36b210ab 100644 --- a/src/net/protocols/protocol_ping.rs +++ b/src/net/protocols/protocol_ping.rs @@ -45,8 +45,9 @@ impl ProtocolPing { let pong_sub = self .channel .clone() - .subscribe_msg(messages::PacketType::Pong) - .await; + .subscribe_msg::() + .await + .expect("Missing pong dispatcher!"); loop { // Wait channel_heartbeat amount of time @@ -63,7 +64,7 @@ impl ProtocolPing { let start = Instant::now(); // Wait for pong, check nonce matches - let pong_msg = receive_message!(pong_sub, messages::Message::Pong); + let pong_msg = pong_sub.receive().await?; if pong_msg.nonce != nonce { error!("Wrong nonce for ping reply. Disconnecting from channel."); self.channel.stop().await; @@ -79,12 +80,13 @@ impl ProtocolPing { let ping_sub = self .channel .clone() - .subscribe_msg(messages::PacketType::Ping) - .await; + .subscribe_msg::() + .await + .expect("Missing ping dispatcher!"); loop { // Wait for ping, reply with pong that has a matching nonce - let ping = receive_message!(ping_sub, messages::Message::Ping); + let ping = ping_sub.receive().await?; debug!(target: "net", "ProtocolPing::reply_to_ping() received Ping message"); // Send ping message diff --git a/src/net/protocols/protocol_seed.rs b/src/net/protocols/protocol_seed.rs index 05e3cbd35..9824a9f93 100644 --- a/src/net/protocols/protocol_seed.rs +++ b/src/net/protocols/protocol_seed.rs @@ -26,8 +26,9 @@ impl ProtocolSeed { let addr_sub = self .channel .clone() - .subscribe_msg(messages::PacketType::Addrs) - .await; + .subscribe_msg::() + .await + .expect("Missing addrs dispatcher!"); // Send own address to the seed server self.send_own_address().await?; @@ -37,7 +38,7 @@ impl ProtocolSeed { self.channel.clone().send(get_addr).await?; // Receive addresses - let addrs_msg = receive_message!(addr_sub, messages::Message::Addrs); + let addrs_msg = addr_sub.receive().await?; self.hosts.store(addrs_msg.addrs.clone()).await; debug!(target: "net", "ProtocolSeed::start() [END]"); diff --git a/src/net/protocols/protocol_version.rs b/src/net/protocols/protocol_version.rs index 6e096cfd5..928869334 100644 --- a/src/net/protocols/protocol_version.rs +++ b/src/net/protocols/protocol_version.rs @@ -11,8 +11,8 @@ use crate::net::{ChannelPtr, SettingsPtr}; pub struct ProtocolVersion { channel: ChannelPtr, - version_sub: MessageSubscription, - verack_sub: MessageSubscription, + version_sub: MessageSubscription, + verack_sub: MessageSubscription, settings: SettingsPtr, } @@ -20,13 +20,15 @@ impl ProtocolVersion { pub async fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc { let version_sub = channel .clone() - .subscribe_msg(messages::PacketType::Version) - .await; + .subscribe_msg::() + .await + .expect("Missing version dispatcher!"); let verack_sub = channel .clone() - .subscribe_msg(messages::PacketType::Verack) - .await; + .subscribe_msg::() + .await + .expect("Missing verack dispatcher!"); Arc::new(Self { channel,