integrate new message subsystem

This commit is contained in:
narodnik
2021-03-06 09:28:50 +01:00
parent 1431e17a2d
commit 4e5b844ba9
12 changed files with 154 additions and 196 deletions

View File

@@ -138,8 +138,8 @@ impl RpcInterface {
}
async fn start(executor: Arc<Executor<'_>>, options: ProgramOptions) -> Result<()> {
sapvi::net::message_subscriber::doteste().await;
return Ok(());
//sapvi::net::message_subscriber::doteste().await;
//return Ok(());
let p2p = net::P2p::new(options.network_settings);

View File

@@ -39,6 +39,7 @@ pub enum Error {
ChannelStopped,
ChannelTimeout,
ServiceStopped,
Utf8Error,
}
impl std::error::Error for Error {}
@@ -80,6 +81,7 @@ impl fmt::Display for Error {
Error::ChannelStopped => f.write_str("Channel stopped"),
Error::ChannelTimeout => f.write_str("Channel timed out"),
Error::ServiceStopped => f.write_str("Service stopped"),
Error::Utf8Error => f.write_str("Malformed UTF8"),
}
}
}
@@ -138,3 +140,9 @@ impl From<NetError> for Error {
}
}
}
impl From<std::string::FromUtf8Error> for Error {
fn from(err: std::string::FromUtf8Error) -> Error {
Error::Utf8Error
}
}

View File

@@ -104,7 +104,7 @@ impl Acceptor {
};
info!("Accepted client: {}", peer_addr);
let channel = Channel::new(stream, peer_addr, self.settings.clone());
let channel = Channel::new(stream, peer_addr, self.settings.clone()).await;
Ok(channel)
}
}

View File

@@ -12,7 +12,7 @@ use std::sync::Arc;
use crate::error;
use crate::net::error::{NetError, NetResult};
use crate::net::message_subscriber::{
MessageSubscriber, MessageSubscriberPtr, MessageSubscription,
MessageSubsystem, MessageSubscription, Message2
};
use crate::net::messages;
use crate::net::settings::SettingsPtr;
@@ -24,7 +24,7 @@ pub struct Channel {
reader: Mutex<ReadHalf<Async<TcpStream>>>,
writer: Mutex<WriteHalf<Async<TcpStream>>>,
address: SocketAddr,
message_subscriber: MessageSubscriberPtr,
message_subsystem: MessageSubsystem,
stop_subscriber: SubscriberPtr<NetError>,
receive_task: StoppableTaskPtr,
stopped: AtomicBool,
@@ -32,15 +32,19 @@ pub struct Channel {
}
impl Channel {
pub fn new(stream: Async<TcpStream>, address: SocketAddr, settings: SettingsPtr) -> Arc<Self> {
pub async fn new(stream: Async<TcpStream>, address: SocketAddr, settings: SettingsPtr) -> Arc<Self> {
let (reader, writer) = stream.split();
let reader = Mutex::new(reader);
let writer = Mutex::new(writer);
let message_subsystem = MessageSubsystem::new();
Self::setup_dispatchers(&message_subsystem).await;
Arc::new(Self {
reader,
writer,
address,
message_subscriber: MessageSubscriber::new(),
message_subsystem,
stop_subscriber: Subscriber::new(),
receive_task: StoppableTask::new(),
stopped: AtomicBool::new(false),
@@ -52,7 +56,7 @@ impl Channel {
debug!(target: "net", "Channel::start() [START, address={}]", self.address());
let self2 = self.clone();
self.receive_task.clone().start(
self.clone().receive_loop(),
self.clone().main_receive_loop(),
// Ignore stop handler
|result| self2.handle_stop(result),
NetError::ServiceStopped,
@@ -114,19 +118,18 @@ impl Channel {
result
}
pub async fn subscribe_msg(
self: Arc<Self>,
packet_type: messages::PacketType,
) -> MessageSubscription {
pub async fn subscribe_msg<M: Message2>(
self: Arc<Self>
) -> NetResult<MessageSubscription<M>> {
debug!(target: "net",
"Channel::subscribe_msg() [START, pkt_type={:?}, address={}]",
packet_type,
"Channel::subscribe_msg() [START, command={:?}, address={}]",
M::name(),
self.address()
);
let sub = self.message_subscriber.clone().subscribe(packet_type).await;
let sub = self.message_subsystem.subscribe::<M>().await;
debug!(target: "net",
"Channel::subscribe_msg() [END, pkt_type={:?}, address={}]",
packet_type,
"Channel::subscribe_msg() [END, command={:?}, address={}]",
M::name(),
self.address()
);
sub
@@ -143,17 +146,26 @@ impl Channel {
}
}
async fn receive_loop(self: Arc<Self>) -> NetResult<()> {
async fn setup_dispatchers(message_subsystem: &MessageSubsystem) {
message_subsystem.add_dispatch::<messages::VersionMessage>().await;
message_subsystem.add_dispatch::<messages::VerackMessage>().await;
message_subsystem.add_dispatch::<messages::PingMessage>().await;
message_subsystem.add_dispatch::<messages::PongMessage>().await;
message_subsystem.add_dispatch::<messages::GetAddrsMessage>().await;
message_subsystem.add_dispatch::<messages::AddrsMessage>().await;
}
async fn main_receive_loop(self: Arc<Self>) -> NetResult<()> {
debug!(target: "net",
"Channel::receive_loop() [START, address={}]",
self.address()
);
let reader = &mut *self.reader.lock().await;
loop {
let message_result = messages::receive_message(reader).await;
let message = match message_result {
Ok(message) => Arc::new(message),
let packet = match messages::read_packet(reader).await {
Ok(packet) => packet,
Err(err) => {
if Self::is_eof_error(&err) {
info!("Channel {} disconnected", self.address());
@@ -170,7 +182,7 @@ impl Channel {
};
// Send result to our subscribers
self.message_subscriber.notify(Ok(message)).await;
self.message_subsystem.notify(&packet.command2, packet.payload).await;
}
}
@@ -180,8 +192,7 @@ impl Channel {
Ok(()) => panic!("Channel task should never complete without error status"),
Err(err) => {
// Send this error to all channel subscribers
let result = Err(err);
self.message_subscriber.notify(result).await;
self.message_subsystem.trigger_error(err).await;
}
}
debug!(target: "net", "Channel::handle_stop() [END, address={}]", self.address());

View File

@@ -19,7 +19,7 @@ impl Connector {
futures::select! {
stream_result = Async::<TcpStream>::connect(hostaddr).fuse() => {
match stream_result {
Ok(stream) => Ok(Channel::new(stream, hostaddr, self.settings.clone())),
Ok(stream) => Ok(Channel::new(stream, hostaddr, self.settings.clone()).await),
Err(_) => Err(NetError::ConnectFailed)
}
}

View File

@@ -14,148 +14,22 @@ use crate::net::messages::{Message, PacketType};
use crate::serial::Decodable;
use crate::serial::Encodable;
pub type MessageSubscriberPtr = Arc<MessageSubscriber>;
pub type MessageResult = NetResult<Arc<Message>>;
pub type MessageSubscriptionID = u64;
macro_rules! receive_message {
($sub:expr, $message_type:path) => {{
let wrapped_message = owning_ref::OwningRef::new($sub.receive().await?);
wrapped_message.map(|msg| match msg {
$message_type(msg_detail) => msg_detail,
_ => {
panic!("Filter for receive sub invalid!");
}
})
}};
}
pub struct MessageSubscription {
id: MessageSubscriptionID,
filter: PacketType,
recv_queue: async_channel::Receiver<MessageResult>,
parent: Arc<MessageSubscriber>,
}
impl MessageSubscription {
fn is_relevant_message(&self, message_result: &MessageResult) -> bool {
match message_result {
Ok(message) => {
let packet_type = message.packet_type();
// Apply the filter
packet_type == self.filter
}
Err(_) => {
// Propagate all errors
true
}
}
}
pub async fn receive(&self) -> MessageResult {
loop {
let message_result = self.recv_queue.recv().await;
match message_result {
Ok(message_result) => {
if self.clone().is_relevant_message(&message_result) {
return message_result;
}
}
Err(err) => {
panic!("MessageSubscription::receive() recv_queue failed! {}", err);
}
}
}
}
// Must be called manually since async Drop is not possible in Rust
pub async fn unsubscribe(&self) {
self.parent.clone().unsubscribe(self.id).await
}
}
pub struct MessageSubscriber {
subs: Mutex<HashMap<MessageSubscriptionID, async_channel::Sender<MessageResult>>>,
}
impl MessageSubscriber {
pub fn new() -> Arc<Self> {
Arc::new(Self {
subs: Mutex::new(HashMap::new()),
})
}
pub fn random_id() -> MessageSubscriptionID {
let mut rng = rand::thread_rng();
rng.gen()
}
pub async fn subscribe(self: Arc<Self>, packet_type: PacketType) -> MessageSubscription {
let (sender, recvr) = async_channel::unbounded();
let sub_id = Self::random_id();
self.subs.lock().await.insert(sub_id, sender);
MessageSubscription {
id: sub_id,
filter: packet_type,
recv_queue: recvr,
parent: self.clone(),
}
}
async fn unsubscribe(self: Arc<Self>, sub_id: MessageSubscriptionID) {
self.subs.lock().await.remove(&sub_id);
}
pub async fn notify(&self, message_result: NetResult<Arc<Message>>) {
let mut garbage_ids = Vec::new();
for (sub_id, sub) in &*self.subs.lock().await {
match sub.send(message_result.clone()).await {
Ok(()) => {}
Err(_err) => {
// Automatically clean out closed channels
garbage_ids.push(*sub_id);
//panic!("Error returned sending message in notify() call! {}", err);
}
}
}
self.collect_garbage(garbage_ids).await;
}
async fn collect_garbage(&self, ids: Vec<MessageSubscriptionID>) {
let mut subs = self.subs.lock().await;
for id in &ids {
subs.remove(id);
}
}
}
//
//
type MessageResult<M> = NetResult<Arc<M>>;
pub trait Message2: 'static + Decodable + Send + Sync {
fn name() -> &'static str;
fn deserialize();
fn serialize();
}
pub struct MessageSubscription2<M: Message2> {
pub struct MessageSubscription<M: Message2> {
id: MessageSubscriptionID,
recv_queue: async_channel::Receiver<NetResult<Arc<M>>>,
recv_queue: async_channel::Receiver<MessageResult<M>>,
parent: Arc<MessageDispatcher<M>>,
}
impl<M: Message2> MessageSubscription2<M> {
pub async fn receive(&self) -> NetResult<Arc<M>> {
impl<M: Message2> MessageSubscription<M> {
pub async fn receive(&self) -> MessageResult<M> {
match self.recv_queue.recv().await {
Ok(message) => message,
Err(err) => {
@@ -171,7 +45,7 @@ impl<M: Message2> MessageSubscription2<M> {
}
#[async_trait]
trait MessageDispatcherInterface: Sync {
trait MessageDispatcherInterface: Send + Sync {
async fn trigger(&self, payload: Vec<u8>);
async fn trigger_error(&self, err: NetError);
@@ -180,7 +54,7 @@ trait MessageDispatcherInterface: Sync {
}
struct MessageDispatcher<M: Message2> {
subs: Mutex<HashMap<MessageSubscriptionID, async_channel::Sender<NetResult<Arc<M>>>>>,
subs: Mutex<HashMap<MessageSubscriptionID, async_channel::Sender<MessageResult<M>>>>,
}
impl<M: Message2> MessageDispatcher<M> {
@@ -195,12 +69,12 @@ impl<M: Message2> MessageDispatcher<M> {
rng.gen()
}
pub async fn subscribe(self: Arc<Self>) -> MessageSubscription2<M> {
pub async fn subscribe(self: Arc<Self>) -> MessageSubscription<M> {
let (sender, recvr) = async_channel::unbounded();
let sub_id = Self::random_id();
self.subs.lock().await.insert(sub_id, sender);
MessageSubscription2 {
MessageSubscription {
id: sub_id,
recv_queue: recvr,
parent: self,
@@ -211,7 +85,7 @@ impl<M: Message2> MessageDispatcher<M> {
self.subs.lock().await.remove(&sub_id);
}
async fn trigger_all(&self, message: NetResult<Arc<M>>) {
async fn trigger_all(&self, message: MessageResult<M>) {
let mut garbage_ids = Vec::new();
for (sub_id, sub) in &*self.subs.lock().await {
@@ -262,6 +136,44 @@ impl<M: Message2> MessageDispatcherInterface for MessageDispatcher<M> {
}
}
use crate::net::messages::{PingMessage, PongMessage, GetAddrsMessage, AddrsMessage, VersionMessage, VerackMessage};
impl Message2 for PingMessage {
fn name() -> &'static str {
"ping"
}
}
impl Message2 for PongMessage {
fn name() -> &'static str {
"pong"
}
}
impl Message2 for GetAddrsMessage {
fn name() -> &'static str {
"getaddr"
}
}
impl Message2 for AddrsMessage {
fn name() -> &'static str {
"addr"
}
}
impl Message2 for VersionMessage {
fn name() -> &'static str {
"version"
}
}
impl Message2 for VerackMessage {
fn name() -> &'static str {
"verack"
}
}
struct MyVersionMessage {
x: u32,
}
@@ -270,9 +182,6 @@ impl Message2 for MyVersionMessage {
fn name() -> &'static str {
"verver"
}
fn deserialize() {}
fn serialize() {}
}
impl Encodable for MyVersionMessage {
@@ -291,7 +200,7 @@ impl Decodable for MyVersionMessage {
}
}
struct MessageSubsystem {
pub struct MessageSubsystem {
dispatchers: Mutex<HashMap<&'static str, Arc<dyn MessageDispatcherInterface>>>,
}
@@ -309,12 +218,12 @@ impl MessageSubsystem {
.insert(M::name(), Arc::new(MessageDispatcher::<M>::new()));
}
pub async fn subscribe<M: Message2>(&self) -> NetResult<MessageSubscription2<M>> {
pub async fn subscribe<M: Message2>(&self) -> NetResult<MessageSubscription<M>> {
let dispatcher = self
.dispatchers
.lock()
.await
.get(MyVersionMessage::name())
.get(M::name())
.cloned();
let sub = match dispatcher {

View File

@@ -237,6 +237,7 @@ impl Message {
message.encode(&mut payload)?;
Ok(Packet {
command: PacketType::Ping,
command2: String::from(self.name()),
payload,
})
}
@@ -245,6 +246,7 @@ impl Message {
message.encode(&mut payload)?;
Ok(Packet {
command: PacketType::Pong,
command2: String::from(self.name()),
payload,
})
}
@@ -253,6 +255,7 @@ impl Message {
message.encode(&mut payload)?;
Ok(Packet {
command: PacketType::GetAddrs,
command2: String::from(self.name()),
payload,
})
}
@@ -261,6 +264,7 @@ impl Message {
message.encode(Cursor::new(&mut payload))?;
Ok(Packet {
command: PacketType::Addrs,
command2: String::from(self.name()),
payload,
})
}
@@ -268,6 +272,7 @@ impl Message {
let payload = serialize(message);
Ok(Packet {
command: PacketType::Inv,
command2: String::from(self.name()),
payload,
})
}
@@ -275,6 +280,7 @@ impl Message {
let payload = serialize(message);
Ok(Packet {
command: PacketType::GetSlabs,
command2: String::from(self.name()),
payload,
})
}
@@ -282,6 +288,7 @@ impl Message {
let payload = serialize(message);
Ok(Packet {
command: PacketType::Slab,
command2: String::from(self.name()),
payload,
})
}
@@ -289,6 +296,7 @@ impl Message {
let payload = serialize(message);
Ok(Packet {
command: PacketType::Version,
command2: String::from(self.name()),
payload,
})
}
@@ -296,6 +304,7 @@ impl Message {
let payload = serialize(message);
Ok(Packet {
command: PacketType::Verack,
command2: String::from(self.name()),
payload,
})
}
@@ -336,6 +345,7 @@ impl Message {
// These are converted to messages and passed to event loop
pub struct Packet {
pub command: PacketType,
pub command2: String,
pub payload: Vec<u8>,
}
@@ -351,9 +361,16 @@ pub async fn read_packet<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Packet>
}
// The type of the message
let command = AsyncReadExt::read_u8(stream).await?;
//let command = AsyncReadExt::read_u8(stream).await?;
//debug!(target: "net", "read command: {}", command);
//let command = PacketType::try_from(command).map_err(|_| Error::MalformedPacket)?;
let command_len = VarInt::decode_async(stream).await?.0 as usize;
let mut command = vec![0u8; command_len];
if command_len > 0 {
stream.read_exact(&mut command).await?;
}
let command = String::from_utf8(command)?;
debug!(target: "net", "read command: {}", command);
let command = PacketType::try_from(command).map_err(|_| Error::MalformedPacket)?;
let payload_len = VarInt::decode_async(stream).await?.0 as usize;
@@ -364,7 +381,7 @@ pub async fn read_packet<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Packet>
}
debug!(target: "net", "read payload {} bytes", payload_len);
Ok(Packet { command, payload })
Ok(Packet { command: PacketType::Verack, command2: command, payload })
}
pub async fn send_packet<W: AsyncWrite + Unpin>(stream: &mut W, packet: Packet) -> Result<()> {
@@ -372,8 +389,15 @@ pub async fn send_packet<W: AsyncWrite + Unpin>(stream: &mut W, packet: Packet)
stream.write_all(&MAGIC_BYTES).await?;
debug!(target: "net", "sent magic...");
AsyncWriteExt::write_u8(stream, packet.command as u8).await?;
debug!(target: "net", "sent command: {}", packet.command as u8);
//AsyncWriteExt::write_u8(stream, packet.command as u8).await?;
//debug!(target: "net", "sent command: {}", packet.command as u8);
VarInt(packet.command2.len() as u64)
.encode_async(stream)
.await?;
assert!(!packet.command2.is_empty());
stream.write_all(&packet.command2.as_bytes()).await?;
debug!(target: "net", "sent command: {}", packet.command2);
assert_eq!(std::mem::size_of::<usize>(), std::mem::size_of::<u64>());
VarInt(packet.payload.len() as u64)

View File

@@ -21,6 +21,5 @@ pub use acceptor::{Acceptor, AcceptorPtr};
pub use channel::{Channel, ChannelPtr};
pub use connector::Connector;
pub use hosts::{Hosts, HostsPtr};
pub use message_subscriber::{MessageSubscriber, MessageSubscription};
pub use p2p::P2p;
pub use settings::{Settings, SettingsPtr};

View File

@@ -11,8 +11,8 @@ use crate::net::{ChannelPtr, HostsPtr, SettingsPtr};
pub struct ProtocolAddress {
channel: ChannelPtr,
addrs_sub: MessageSubscription,
get_addrs_sub: MessageSubscription,
addrs_sub: MessageSubscription<messages::AddrsMessage>,
get_addrs_sub: MessageSubscription<messages::GetAddrsMessage>,
hosts: HostsPtr,
settings: SettingsPtr,
@@ -24,13 +24,15 @@ impl ProtocolAddress {
pub async fn new(channel: ChannelPtr, hosts: HostsPtr, settings: SettingsPtr) -> Arc<Self> {
let addrs_sub = channel
.clone()
.subscribe_msg(messages::PacketType::Addrs)
.await;
.subscribe_msg::<messages::AddrsMessage>()
.await
.expect("Missing addrs dispatcher!");
let get_addrs_sub = channel
.clone()
.subscribe_msg(messages::PacketType::GetAddrs)
.await;
.subscribe_msg::<messages::GetAddrsMessage>()
.await
.expect("Missing getaddrs dispatcher!");
Arc::new(Self {
channel: channel.clone(),
@@ -63,7 +65,7 @@ impl ProtocolAddress {
async fn handle_receive_addrs(self: Arc<Self>) -> NetResult<()> {
debug!(target: "net", "ProtocolAddress::handle_receive_addrs() [START]");
loop {
let addrs_msg = receive_message!(self.addrs_sub, messages::Message::Addrs);
let addrs_msg = self.addrs_sub.receive().await?;
debug!(target: "net", "ProtocolAddress::handle_receive_addrs() storing address in hosts");
self.hosts.store(addrs_msg.addrs.clone()).await;
@@ -73,7 +75,7 @@ impl ProtocolAddress {
async fn handle_receive_get_addrs(self: Arc<Self>) -> NetResult<()> {
debug!(target: "net", "ProtocolAddress::handle_receive_get_addrs() [START]");
loop {
let _get_addrs = receive_message!(self.get_addrs_sub, messages::Message::GetAddrs);
let _get_addrs = self.get_addrs_sub.receive().await?;
debug!(target: "net", "ProtocolAddress::handle_receive_get_addrs() received GetAddrs message");

View File

@@ -45,8 +45,9 @@ impl ProtocolPing {
let pong_sub = self
.channel
.clone()
.subscribe_msg(messages::PacketType::Pong)
.await;
.subscribe_msg::<messages::PongMessage>()
.await
.expect("Missing pong dispatcher!");
loop {
// Wait channel_heartbeat amount of time
@@ -63,7 +64,7 @@ impl ProtocolPing {
let start = Instant::now();
// Wait for pong, check nonce matches
let pong_msg = receive_message!(pong_sub, messages::Message::Pong);
let pong_msg = pong_sub.receive().await?;
if pong_msg.nonce != nonce {
error!("Wrong nonce for ping reply. Disconnecting from channel.");
self.channel.stop().await;
@@ -79,12 +80,13 @@ impl ProtocolPing {
let ping_sub = self
.channel
.clone()
.subscribe_msg(messages::PacketType::Ping)
.await;
.subscribe_msg::<messages::PingMessage>()
.await
.expect("Missing ping dispatcher!");
loop {
// Wait for ping, reply with pong that has a matching nonce
let ping = receive_message!(ping_sub, messages::Message::Ping);
let ping = ping_sub.receive().await?;
debug!(target: "net", "ProtocolPing::reply_to_ping() received Ping message");
// Send ping message

View File

@@ -26,8 +26,9 @@ impl ProtocolSeed {
let addr_sub = self
.channel
.clone()
.subscribe_msg(messages::PacketType::Addrs)
.await;
.subscribe_msg::<messages::AddrsMessage>()
.await
.expect("Missing addrs dispatcher!");
// Send own address to the seed server
self.send_own_address().await?;
@@ -37,7 +38,7 @@ impl ProtocolSeed {
self.channel.clone().send(get_addr).await?;
// Receive addresses
let addrs_msg = receive_message!(addr_sub, messages::Message::Addrs);
let addrs_msg = addr_sub.receive().await?;
self.hosts.store(addrs_msg.addrs.clone()).await;
debug!(target: "net", "ProtocolSeed::start() [END]");

View File

@@ -11,8 +11,8 @@ use crate::net::{ChannelPtr, SettingsPtr};
pub struct ProtocolVersion {
channel: ChannelPtr,
version_sub: MessageSubscription,
verack_sub: MessageSubscription,
version_sub: MessageSubscription<messages::VersionMessage>,
verack_sub: MessageSubscription<messages::VerackMessage>,
settings: SettingsPtr,
}
@@ -20,13 +20,15 @@ impl ProtocolVersion {
pub async fn new(channel: ChannelPtr, settings: SettingsPtr) -> Arc<Self> {
let version_sub = channel
.clone()
.subscribe_msg(messages::PacketType::Version)
.await;
.subscribe_msg::<messages::VersionMessage>()
.await
.expect("Missing version dispatcher!");
let verack_sub = channel
.clone()
.subscribe_msg(messages::PacketType::Verack)
.await;
.subscribe_msg::<messages::VerackMessage>()
.await
.expect("Missing verack dispatcher!");
Arc::new(Self {
channel,