From 3af393cf78efba460d79c35bc4ef7a82e7ef9265 Mon Sep 17 00:00:00 2001 From: narodnik Date: Sun, 20 Dec 2020 17:58:33 +0100 Subject: [PATCH] migrate client_protocol to Arc model. --- src/bin/dfi.rs | 17 ++-- src/net/net.rs | 2 +- src/net/protocol/client_protocol.rs | 115 ++++++++++------------------ src/net/protocol/seed_protocol.rs | 4 +- src/net/protocol/server_protocol.rs | 6 +- src/utility.rs | 8 +- 6 files changed, 59 insertions(+), 93 deletions(-) diff --git a/src/bin/dfi.rs b/src/bin/dfi.rs index f7148d3b6..fbf325e7f 100644 --- a/src/bin/dfi.rs +++ b/src/bin/dfi.rs @@ -1,7 +1,7 @@ #[macro_use] extern crate clap; use async_channel::unbounded; -use async_dup::Arc; +use std::sync::Arc; use async_executor::Executor; use async_std::sync::Mutex; use easy_parallel::Parallel; @@ -60,13 +60,13 @@ async fn start(executor: Arc>, options: ProgramOptions) -> Result<( let accept_addr = options.accept_addr.clone(); - let mut client_slots: Vec = vec![]; + let mut client_slots = vec![]; for i in 0..options.connection_slots { debug!("Starting connection slot {}", i); - let mut client = ClientProtocol::new(connections.clone()); - client - .start(accept_addr.clone(), stored_addrs.clone(), executor.clone()) + let mut client = ClientProtocol::new(connections.clone(), accept_addr.clone(), stored_addrs.clone()); + client.clone() + .start(executor.clone()) .await; client_slots.push(client); } @@ -74,12 +74,11 @@ async fn start(executor: Arc>, options: ProgramOptions) -> Result<( for remote_addr in options.manual_connects { debug!("Starting connection (manual) to {}", remote_addr); - let mut client = ClientProtocol::new(connections.clone()); - client + let mut client = ClientProtocol::new(connections.clone(), accept_addr.clone(), + stored_addrs.clone()); + client.clone() .start_manual( remote_addr, - accept_addr.clone(), - stored_addrs.clone(), executor.clone(), ) .await; diff --git a/src/net/net.rs b/src/net/net.rs index a5b8c1e75..43a5c51dc 100644 --- a/src/net/net.rs +++ b/src/net/net.rs @@ -1,4 +1,4 @@ -use async_dup::Arc; +use std::sync::Arc; use futures::prelude::*; use log::*; use num_enum::{IntoPrimitive, TryFromPrimitive}; diff --git a/src/net/protocol/client_protocol.rs b/src/net/protocol/client_protocol.rs index 435604ec1..f19b3651c 100644 --- a/src/net/protocol/client_protocol.rs +++ b/src/net/protocol/client_protocol.rs @@ -1,4 +1,5 @@ -use async_dup::Arc; +use async_std::sync::Mutex; +use std::sync::Arc; use log::*; use rand::seq::SliceRandom; use smol::{Async, Executor}; @@ -14,18 +15,25 @@ pub struct ClientProtocol { send_sx: async_channel::Sender, send_rx: async_channel::Receiver, connections: ConnectionsMap, - main_process: Option>, + main_process: Mutex>>, + + accept_addr: Option, + stored_addrs: AddrsStorage, } impl ClientProtocol { - pub fn new(connections: ConnectionsMap) -> Self { + pub fn new(connections: ConnectionsMap, accept_addr: Option, + stored_addrs: AddrsStorage, + ) -> Arc { let (send_sx, send_rx) = async_channel::unbounded::(); - Self { + Arc::new(Self { send_sx, send_rx, connections, - main_process: None, - } + main_process: Mutex::new(None), + accept_addr, + stored_addrs + }) } pub fn get_send_pipe(&self) -> async_channel::Sender { @@ -33,6 +41,7 @@ impl ClientProtocol { } async fn fetch_random_addr( + self: Arc, accept_addr: &Option, stored_addrs: &AddrsStorage, connections: &ConnectionsMap, @@ -58,19 +67,15 @@ impl ClientProtocol { } pub async fn start( - &mut self, - accept_addr: Option, - stored_addrs: AddrsStorage, + self: Arc, executor: Arc>, ) { - let connections = self.connections.clone(); - let (send_sx, send_rx) = (self.send_sx.clone(), self.send_rx.clone()); - let executor2 = executor.clone(); + let self2 = self.clone(); - self.main_process = Some(executor.spawn(async move { + *self2.main_process.lock().await = Some(executor.spawn(async move { loop { - let addr = match stored_addrs.lock().await.choose(&mut rand_core::OsRng) { + let addr = match self.stored_addrs.lock().await.choose(&mut rand_core::OsRng) { Some(addr) => addr.clone(), None => { debug!("No addresses in store. Sleeping for 2 secs before retrying..."); @@ -78,10 +83,10 @@ impl ClientProtocol { continue; } }; - if connections.lock().await.contains_key(&addr) { + if self.connections.lock().await.contains_key(&addr) { continue; } - if let Some(accept_addr) = accept_addr { + if let Some(accept_addr) = self.accept_addr { if addr == accept_addr { continue; } @@ -89,12 +94,8 @@ impl ClientProtocol { debug!("Attempting connect to {}", addr); - Self::try_connect_process( + self.try_connect_process( addr, - connections.clone(), - accept_addr.clone(), - stored_addrs.clone(), - (send_sx.clone(), send_rx.clone()), executor2.clone(), ) .await; @@ -106,28 +107,20 @@ impl ClientProtocol { } pub async fn start_manual( - &mut self, + self: Arc, remote_addr: SocketAddr, - accept_addr: Option, - stored_addrs: AddrsStorage, executor: Arc>, ) { - let connections = self.connections.clone(); - let (send_sx, send_rx) = (self.send_sx.clone(), self.send_rx.clone()); - let executor2 = executor.clone(); + let self2 = self.clone(); - self.main_process = Some(executor.spawn(async move { + *self2.main_process.lock().await = Some(executor.spawn(async move { loop { for _ in 0..4 { debug!("Attempting connect to {}", remote_addr); - Self::try_connect_process( + self.try_connect_process( remote_addr, - connections.clone(), - accept_addr.clone(), - stored_addrs.clone(), - (send_sx.clone(), send_rx.clone()), executor2.clone(), ) .await; @@ -138,25 +131,15 @@ impl ClientProtocol { } pub async fn try_connect_process( + &self, address: SocketAddr, - connections: ConnectionsMap, - accept_addr: Option, - stored_addrs: AddrsStorage, - (send_sx, send_rx): ( - async_channel::Sender, - async_channel::Receiver, - ), executor: Arc>, ) { match Async::::connect(address.clone()).await { Ok(stream) => { - let _ = Self::handle_connect( + let _ = self.handle_connect( stream, - stored_addrs.clone(), - connections, address, - (send_sx.clone(), send_rx.clone()), - accept_addr, executor, ) .await; @@ -168,32 +151,22 @@ impl ClientProtocol { } async fn handle_connect( + &self, stream: Async, - stored_addrs: AddrsStorage, - connections: ConnectionsMap, address: SocketAddr, - (send_sx, send_rx): ( - async_channel::Sender, - async_channel::Receiver, - ), - accept_addr: Option, executor: Arc>, ) -> Result<()> { debug!("Connected to {}", address); let stream = async_dup::Arc::new(stream); - connections + self.connections .lock() .await - .insert(address.clone(), send_sx.clone()); + .insert(address.clone(), self.send_sx.clone()); // Run event loop - match Self::event_loop_process( + match self.event_loop_process( stream, - stored_addrs, - (send_sx, send_rx), - accept_addr, - connections.clone(), executor, ) .await @@ -205,7 +178,7 @@ impl ClientProtocol { warn!("Server disconnected: {}", err); } } - connections.lock().await.remove(&address); + self.connections.lock().await.remove(&address); Ok(()) } @@ -225,30 +198,24 @@ impl ClientProtocol { } pub async fn event_loop_process( + &self, mut stream: net::AsyncTcpStream, - stored_addrs: AddrsStorage, - (send_sx, send_rx): ( - async_channel::Sender, - async_channel::Receiver, - ), - accept_addr: Option, - connections: ConnectionsMap, executor: Arc>, ) -> Result<()> { let inactivity_timer = net::InactivityTimer::new(executor.clone()); let clock = Arc::new(AtomicU64::new(0)); - let send_sx2 = send_sx.clone(); + let send_sx2 = self.send_sx.clone(); let clock2 = clock.clone(); let ping_task = executor.spawn(protocol_base::repeat_ping(send_sx2, clock2)); let mut send_addr_task = None; - if let Some(accept_addr) = accept_addr { - send_addr_task = Some(executor.spawn(Self::send_addr(send_sx.clone(), accept_addr))); + if let Some(accept_addr) = self.accept_addr { + send_addr_task = Some(executor.spawn(Self::send_addr(self.send_sx.clone(), accept_addr.clone()))); } loop { - let event = net::select_event(&mut stream, &send_rx, &inactivity_timer).await?; + let event = net::select_event(&mut stream, &self.send_rx, &inactivity_timer).await?; match event { net::Event::Send(message) => { @@ -258,10 +225,10 @@ impl ClientProtocol { inactivity_timer.reset().await?; protocol_base::protocol( message, - &stored_addrs, - &send_sx, + &self.stored_addrs, + &self.send_sx, Some(&clock), - connections.clone(), + self.connections.clone(), ) .await?; } diff --git a/src/net/protocol/seed_protocol.rs b/src/net/protocol/seed_protocol.rs index 075f90ecb..cc7752c22 100644 --- a/src/net/protocol/seed_protocol.rs +++ b/src/net/protocol/seed_protocol.rs @@ -1,4 +1,4 @@ -use async_dup::Arc; +use std::sync::Arc; use log::*; use smol::{Async, Executor}; use std::net::{SocketAddr, TcpStream}; @@ -92,7 +92,7 @@ impl SeedProtocol { .send(net::Message::GetAddrs(net::GetAddrsMessage {})) .await?; - let stream = Arc::new(stream); + let stream = async_dup::Arc::new(stream); // Run event loop match Self::event_loop_process(stream, stored_addrs.clone(), (send_sx, send_rx), executor) diff --git a/src/net/protocol/server_protocol.rs b/src/net/protocol/server_protocol.rs index e5f076a30..b29bc8879 100644 --- a/src/net/protocol/server_protocol.rs +++ b/src/net/protocol/server_protocol.rs @@ -1,4 +1,4 @@ -use async_dup::Arc; +use std::sync::Arc; use log::*; use smol::{Async, Executor}; use std::net::{SocketAddr, TcpListener}; @@ -33,7 +33,7 @@ impl ServerProtocol { &mut self, address: SocketAddr, stored_addrs: AddrsStorage, - executor: async_dup::Arc>, + executor: std::sync::Arc>, ) -> Result<()> { let listener = Async::::bind(address)?; info!("Listening on {}", listener.get_ref().local_addr()?); @@ -41,7 +41,7 @@ impl ServerProtocol { loop { let (stream, peer_addr) = listener.accept().await?; info!("Accepted client: {}", peer_addr); - let stream = Arc::new(stream); + let stream = async_dup::Arc::new(stream); let (send_sx, send_rx) = (self.send_sx.clone(), self.send_rx.clone()); diff --git a/src/utility.rs b/src/utility.rs index ce5a576a2..07a3d2979 100644 --- a/src/utility.rs +++ b/src/utility.rs @@ -1,4 +1,4 @@ -use async_dup::Arc; +use std::sync::Arc; use std::collections::HashMap; use std::fs::OpenOptions; use std::io::prelude::*; @@ -12,13 +12,13 @@ use smol::{Executor, Task}; //use crate::{net, serial, Channel, ClientProtocol, Result, SlabsManagerSafe}; use crate::{net::net, serial, Result}; -pub type ConnectionsMap = async_dup::Arc< +pub type ConnectionsMap = std::sync::Arc< async_std::sync::Mutex>>, >; -pub type AddrsStorage = async_dup::Arc>>; +pub type AddrsStorage = std::sync::Arc>>; -pub type Clock = async_dup::Arc; +pub type Clock = std::sync::Arc; pub fn get_current_time() -> u64 { let start = SystemTime::now();