From c057c610d439d7a2298545c3c7042eac25d39cab Mon Sep 17 00:00:00 2001 From: ghassmo Date: Wed, 4 May 2022 01:41:43 +0300 Subject: [PATCH] net3: new redesign for acceptor and connector and use more general form for transport protocols --- src/net3/acceptor.rs | 143 ++++++++++++-------------------------- src/net3/connector.rs | 86 +++++++++++------------ src/net3/mod.rs | 4 +- src/net3/transport.rs | 40 ++++++++++- src/net3/transport/tcp.rs | 48 +++++++++++-- 5 files changed, 170 insertions(+), 151 deletions(-) diff --git a/src/net3/acceptor.rs b/src/net3/acceptor.rs index 121787cdf..24fbe5078 100644 --- a/src/net3/acceptor.rs +++ b/src/net3/acceptor.rs @@ -1,8 +1,6 @@ -use async_std::{stream::StreamExt, sync::Arc}; -use std::net::SocketAddr; +use async_std::sync::Arc; -use futures_rustls::TlsStream; -use log::{error, info}; +use log::error; use smol::Executor; use url::Url; @@ -11,13 +9,7 @@ use crate::{ Error, Result, }; -use super::{Channel, ChannelPtr, TcpTransport, Transport}; - -/// A helper function to convert peer addr to Url and add scheme -fn peer_addr_to_url(addr: SocketAddr, scheme: &str) -> Result { - let url = Url::parse(&format!("{}://{}", scheme, addr))?; - Ok(url) -} +use super::{Channel, ChannelPtr, TcpTransport, Transport, TransportListener, TransportName}; /// Atomic pointer to Acceptor class. pub type AcceptorPtr = Arc; @@ -38,39 +30,12 @@ impl Acceptor { /// thread, erroring if a connection problem occurs. pub async fn start( self: Arc, - accept_addr: Url, + accept_url: Url, executor: Arc>, ) -> Result<()> { - self.accept(accept_addr, executor); - Ok(()) - } - - /// Stop accepting inbound socket connections. - pub async fn stop(&self) { - // Send stop signal - self.task.stop().await; - } - - /// Start receiving network messages. - pub async fn subscribe(self: Arc) -> Subscription> { - self.channel_subscriber.clone().subscribe().await - } - - /// Run the accept loop in a new thread and error if a connection problem - /// occurs. - fn accept(self: Arc, accept_addr: Url, executor: Arc>) { - self.task.clone().start( - self.clone().run_accept_loop(accept_addr), - |result| self.handle_stop(result), - Error::ServiceStopped, - executor, - ); - } - - /// Run the accept loop. - async fn run_accept_loop(self: Arc, accept_url: Url) -> Result<()> { - match accept_url.scheme() { - "tcp" => { + let transport_name = TransportName::try_from(accept_url.clone())?; + match transport_name { + TransportName::Tcp(upgrade) => { let transport = TcpTransport::new(None, 1024); let listener = transport.listen_on(accept_url); @@ -87,71 +52,51 @@ impl Acceptor { } let listener = listener?; - let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { - let result: Result<()> = { - let stream = stream?; - let peer_addr = peer_addr_to_url(stream.peer_addr()?, "tcp")?; - info!("Accepted client: {}", peer_addr); - let channel = Channel::new(Box::new(stream), peer_addr).await; - self.channel_subscriber.notify(Ok(channel)).await; - Ok(()) - }; - - if let Err(err) = result { - error!("Error listening for connections: {}", err); - return Err(Error::ServiceStopped) + match upgrade { + None => { + self.accept(Box::new(listener), executor); } + Some(u) if u == "tls" => { + let tls_listener = transport.upgrade_listener(listener)?.await?; + self.accept(Box::new(tls_listener), executor); + } + // TODO hanle unsupported upgrade + Some(_) => todo!(), } } - "tcp+tls" => { - let transport = TcpTransport::new(None, 1024); + TransportName::Tor(_upgrade) => todo!(), + } + Ok(()) + } - let listener = transport.listen_on(accept_url); + /// Stop accepting inbound socket connections. + pub async fn stop(&self) { + // Send stop signal + self.task.stop().await; + } - if let Err(err) = listener { - error!("Setup failed: {}", err); - return Err(Error::OperationFailed) - } + /// Start receiving network messages. + pub async fn subscribe(self: Arc) -> Subscription> { + self.channel_subscriber.clone().subscribe().await + } - let listener = listener?.await; + /// Run the accept loop in a new thread and error if a connection problem + /// occurs. + fn accept(self: Arc, listener: Box, executor: Arc>) { + self.task.clone().start( + self.clone().run_accept_loop(listener), + |result| self.handle_stop(result), + Error::ServiceStopped, + executor, + ); + } - if let Err(err) = listener { - error!("Bind listener failed: {}", err); - return Err(Error::OperationFailed) - } - - let (acceptor, listener) = transport.upgrade_listener(listener?)?.await?; - - let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { - let result: Result<()> = { - let stream = stream?; - let peer_addr = peer_addr_to_url(stream.peer_addr()?, "tls")?; - info!("Accepted client: {}", peer_addr); - let stream = acceptor.accept(stream).await; - - if let Err(err) = stream { - error!("Error wraping the connection with tls: {}", err); - return Err(Error::ServiceStopped) - } - - let stream = stream?; - let channel = - Channel::new(Box::new(TlsStream::Server(stream)), peer_addr).await; - self.channel_subscriber.notify(Ok(channel)).await; - Ok(()) - }; - - if let Err(err) = result { - error!("Error listening for connections: {}", err); - return Err(Error::ServiceStopped) - } - } - } - "tor" => todo!(), - _ => unimplemented!(), + /// Run the accept loop. + async fn run_accept_loop(self: Arc, listener: Box) -> Result<()> { + while let Ok((stream, peer_addr)) = listener.next().await { + let channel = Channel::new(stream, peer_addr).await; + self.channel_subscriber.notify(Ok(channel)).await; } Ok(()) } diff --git a/src/net3/connector.rs b/src/net3/connector.rs index 566322f8d..40340741d 100644 --- a/src/net3/connector.rs +++ b/src/net3/connector.rs @@ -1,4 +1,4 @@ -use async_std::future::timeout; +use async_std::{future::timeout, sync::Arc}; use std::time::Duration; use log::error; @@ -6,7 +6,7 @@ use url::Url; use crate::{Error, Result}; -use super::{Channel, ChannelPtr, SettingsPtr, TcpTransport, Transport}; +use super::{Channel, ChannelPtr, SettingsPtr, TcpTransport, Transport, TransportName}; /// Create outbound socket connections. pub struct Connector { @@ -21,52 +21,50 @@ impl Connector { /// Establish an outbound connection. pub async fn connect(&self, connect_url: Url) -> Result { + let transport_name = TransportName::try_from(connect_url.clone())?; let result = timeout(Duration::from_secs(self.settings.connect_timeout_seconds.into()), async { - match connect_url.scheme() { - "tcp" => { - let transport = TcpTransport::new(None, 1024); - let stream = transport.dial(connect_url.clone()); - - if let Err(err) = stream { - error!("Setup failed: {}", err); - return Err(Error::ConnectFailed) - } - - let stream = stream?.await; - - if let Err(err) = stream { - error!("Connection failed: {}", err); - return Err(Error::ConnectFailed) - } - - Ok(Channel::new(Box::new(stream?), connect_url).await) - } - "tcp+tls" => { - let transport = TcpTransport::new(None, 1024); - let stream = transport.dial(connect_url.clone()); - - if let Err(err) = stream { - error!("Setup failed: {}", err); - return Err(Error::ConnectFailed) - } - - let stream = stream?.await; - - if let Err(err) = stream { - error!("Connection failed: {}", err); - return Err(Error::ConnectFailed) - } - - let stream = transport.upgrade_dialer(stream?)?.await; - - Ok(Channel::new(Box::new(stream?), connect_url).await) - } - "tor" => todo!(), - _ => unimplemented!(), - } + self.connect_channel(connect_url, transport_name).await }) .await?; result } + + async fn connect_channel( + &self, + connect_url: Url, + transport_name: TransportName, + ) -> Result> { + match transport_name { + TransportName::Tcp(upgrade) => { + let transport = TcpTransport::new(None, 1024); + let stream = transport.dial(connect_url.clone()); + + if let Err(err) = stream { + error!("Setup failed: {}", err); + return Err(Error::ConnectFailed) + } + + let stream = stream?.await; + + if let Err(err) = stream { + error!("Connection failed: {}", err); + return Err(Error::ConnectFailed) + } + + let channel = match upgrade { + None => Channel::new(Box::new(stream?), connect_url.clone()).await, + Some(u) if u == "tls" => { + let stream = transport.upgrade_dialer(stream?)?.await; + Channel::new(Box::new(stream?), connect_url).await + } + // TODO hanle unsupported upgrade + Some(_) => todo!(), + }; + + Ok(channel) + } + TransportName::Tor(_upgrade) => todo!(), + } + } } diff --git a/src/net3/mod.rs b/src/net3/mod.rs index 065a57b3e..95cc49304 100644 --- a/src/net3/mod.rs +++ b/src/net3/mod.rs @@ -98,4 +98,6 @@ pub use p2p::{P2p, P2pPtr}; pub use protocol::{ProtocolBase, ProtocolBasePtr, ProtocolJobsManager, ProtocolJobsManagerPtr}; pub use session::{SESSION_ALL, SESSION_INBOUND, SESSION_MANUAL, SESSION_OUTBOUND, SESSION_SEED}; pub use settings::{Settings, SettingsPtr}; -pub use transport::{TcpTransport, TorTransport, Transport, TransportStream}; +pub use transport::{ + TcpTransport, TorTransport, Transport, TransportListener, TransportName, TransportStream, +}; diff --git a/src/net3/transport.rs b/src/net3/transport.rs index 450a1bba9..2eeb402ea 100644 --- a/src/net3/transport.rs +++ b/src/net3/transport.rs @@ -1,5 +1,7 @@ -use async_trait::async_trait; +use std::net::SocketAddr; +use async_trait::async_trait; +// TODO remove * use futures::prelude::*; use futures_rustls::{TlsAcceptor, TlsStream}; use url::Url; @@ -15,13 +17,45 @@ pub use tcp::TcpTransport; mod tor; pub use tor::TorTransport; -/// This used as wrapper for stream return by dial function inside Transport trait +/// A helper function to convert SocketAddr to Url and add scheme +pub(crate) fn socket_addr_to_url(addr: SocketAddr, scheme: &str) -> Result { + let url = Url::parse(&format!("{}://{}", scheme, addr))?; + Ok(url) +} + +/// Used as wrapper for stream used by Transport trait pub trait TransportStream: AsyncWrite + AsyncRead + Unpin + Send + Sync {} +/// Used as wrapper for listener used by Transport trait +#[async_trait] +pub trait TransportListener: Send + Sync + Unpin { + async fn next(&self) -> Result<(Box, Url)>; +} + +#[derive(Clone)] +pub enum TransportName { + Tcp(Option), + Tor(Option), +} + +impl TryFrom for TransportName { + type Error = crate::Error; + + fn try_from(url: Url) -> Result { + let transport_name = match url.scheme() { + "tcp" => Self::Tcp(None), + "tcp+tls" | "tls" => Self::Tcp(Some("tls".into())), + "tor" => Self::Tor(None), + "tor+tls" => Self::Tcp(Some("tls".into())), + n => return Err(crate::Error::UnsupportedTransport(n.into())), + }; + Ok(transport_name) + } +} + /// The `Transport` trait serves as a base for implementing transport protocols. /// Base transports can optionally be upgraded with TLS in order to support encryption. /// The implementation of our TLS authentication can be found in the [`upgrade_tls`] module. -#[async_trait] pub trait Transport { type Acceptor; type Connector; diff --git a/src/net3/transport/tcp.rs b/src/net3/transport/tcp.rs index 8f9e99dc7..29f4a4599 100644 --- a/src/net3/transport/tcp.rs +++ b/src/net3/transport/tcp.rs @@ -1,18 +1,58 @@ use async_std::net::{TcpListener, TcpStream}; use std::{io, net::SocketAddr, pin::Pin}; +use async_trait::async_trait; use futures::prelude::*; use futures_rustls::{TlsAcceptor, TlsStream}; -use log::debug; +use log::{debug, error}; use socket2::{Domain, Socket, Type}; use url::Url; -use super::{TlsUpgrade, Transport, TransportStream}; +use super::{socket_addr_to_url, TlsUpgrade, Transport, TransportListener, TransportStream}; use crate::{Error, Result}; impl TransportStream for TcpStream {} impl TransportStream for TlsStream {} +#[async_trait] +impl TransportListener for TcpListener { + async fn next(&self) -> Result<(Box, Url)> { + let (stream, peer_addr) = match self.accept().await { + Ok((s, a)) => (s, a), + Err(err) => { + error!("Error listening for connections: {}", err); + return Err(Error::ServiceStopped) + } + }; + let url = socket_addr_to_url(peer_addr, "tcp")?; + Ok((Box::new(stream), url)) + } +} + +#[async_trait] +impl TransportListener for (TlsAcceptor, TcpListener) { + async fn next(&self) -> Result<(Box, Url)> { + let (stream, peer_addr) = match self.1.accept().await { + Ok((s, a)) => (s, a), + Err(err) => { + error!("Error listening for connections: {}", err); + return Err(Error::ServiceStopped) + } + }; + + let stream = self.0.accept(stream).await; + + if let Err(err) = stream { + error!("Error wraping the connection with tls: {}", err); + return Err(Error::ServiceStopped) + } + + let url = socket_addr_to_url(peer_addr, "tcp+tls")?; + + Ok((Box::new(TlsStream::Server(stream?)), url)) + } +} + #[derive(Copy, Clone)] pub struct TcpTransport { /// TTL to set for opened sockets, or `None` for default @@ -33,7 +73,7 @@ impl Transport for TcpTransport { fn listen_on(self, url: Url) -> Result { match url.scheme() { - "tcp" | "tcp+tls" => {} + "tcp" | "tcp+tls" | "tls" => {} x => return Err(Error::UnsupportedTransport(x.to_string())), } @@ -49,7 +89,7 @@ impl Transport for TcpTransport { fn dial(self, url: Url) -> Result { match url.scheme() { - "tcp" | "tcp+tls" => {} + "tcp" | "tcp+tls" | "tls" => {} x => return Err(Error::UnsupportedTransport(x.to_string())), }