net: Implement Unix socket transport (both dialer and listener)

This commit is contained in:
parazyd
2023-08-17 18:13:09 +02:00
parent 827ec53f63
commit b57d0c1c9b
4 changed files with 184 additions and 9 deletions

View File

@@ -133,6 +133,7 @@ prettytable-rs = "0.10.0"
# -----BEGIN LIBRARY FEATURES-----
[features]
p2p-transport-unix = []
p2p-transport-tcp = []
p2p-transport-tor = ["arti-client", "tor-hscrypto"]
p2p-transport-nym = []

View File

@@ -39,6 +39,10 @@ pub(crate) mod tor;
/// Nym transport
pub(crate) mod nym;
#[cfg(feature = "p2p-transport-unix")]
/// Unix socket transport
pub(crate) mod unix;
/// Dialer variants
#[derive(Debug, Clone)]
pub enum DialerVariant {
@@ -65,6 +69,10 @@ pub enum DialerVariant {
#[cfg(feature = "p2p-transport-nym")]
/// Nym with TLS
NymTls(nym::NymDialer),
#[cfg(feature = "p2p-transport-unix")]
/// Unix socket
Unix(unix::UnixDialer),
}
/// Listener variants
@@ -77,6 +85,10 @@ pub enum ListenerVariant {
#[cfg(feature = "p2p-transport-tcp")]
/// TCP with TLS
TcpTls(tcp::TcpListener),
#[cfg(feature = "p2p-transport-unix")]
/// Unix socket
Unix(unix::UnixListener),
}
/// A dialer that is able to transparently operate over arbitrary transports.
@@ -87,18 +99,34 @@ pub struct Dialer {
variant: DialerVariant,
}
impl Dialer {
/// Instantiate a new [`Dialer`] with the given [`Url`].
/// Must contain a scheme, host string, and a port.
pub async fn new(endpoint: Url) -> Result<Self> {
if endpoint.host_str().is_none() || endpoint.port().is_none() {
macro_rules! enforce_hostport {
($endpoint:ident) => {
if $endpoint.host_str().is_none() || $endpoint.port().is_none() {
return Err(Error::InvalidDialerScheme)
}
};
}
macro_rules! enforce_abspath {
($endpoint:ident) => {
if $endpoint.host_str().is_some() || $endpoint.port().is_some() {
return Err(Error::InvalidDialerScheme)
}
if $endpoint.to_file_path().is_err() {
return Err(Error::InvalidDialerScheme)
}
};
}
impl Dialer {
/// Instantiate a new [`Dialer`] with the given [`Url`].
pub async fn new(endpoint: Url) -> Result<Self> {
match endpoint.scheme().to_lowercase().as_str() {
#[cfg(feature = "p2p-transport-tcp")]
"tcp" => {
// Build a TCP dialer
enforce_hostport!(endpoint);
let variant = tcp::TcpDialer::new(None).await?;
let variant = DialerVariant::Tcp(variant);
Ok(Self { endpoint, variant })
@@ -107,6 +135,7 @@ impl Dialer {
#[cfg(feature = "p2p-transport-tcp")]
"tcp+tls" => {
// Build a TCP dialer wrapped with TLS
enforce_hostport!(endpoint);
let variant = tcp::TcpDialer::new(None).await?;
let variant = DialerVariant::TcpTls(variant);
Ok(Self { endpoint, variant })
@@ -115,6 +144,7 @@ impl Dialer {
#[cfg(feature = "p2p-transport-tor")]
"tor" => {
// Build a Tor dialer
enforce_hostport!(endpoint);
let variant = tor::TorDialer::new().await?;
let variant = DialerVariant::Tor(variant);
Ok(Self { endpoint, variant })
@@ -123,6 +153,7 @@ impl Dialer {
#[cfg(feature = "p2p-transport-tor")]
"tor+tls" => {
// Build a Tor dialer wrapped with TLS
enforce_hostport!(endpoint);
let variant = tor::TorDialer::new().await?;
let variant = DialerVariant::TorTls(variant);
Ok(Self { endpoint, variant })
@@ -131,6 +162,7 @@ impl Dialer {
#[cfg(feature = "p2p-transport-nym")]
"nym" => {
// Build a Nym dialer
enforce_hostport!(endpoint);
let variant = nym::NymDialer::new().await?;
let variant = DialerVariant::Nym(variant);
Ok(Self { endpoint, variant })
@@ -139,11 +171,21 @@ impl Dialer {
#[cfg(feature = "p2p-transport-nym")]
"nym+tls" => {
// Build a Nym dialer wrapped with TLS
enforce_hostport!(endpoint);
let variant = nym::NymDialer::new().await?;
let variant = DialerVariant::NymTls(variant);
Ok(Self { endpoint, variant })
}
#[cfg(feature = "p2p-transport-unix")]
"unix" => {
enforce_abspath!(endpoint);
// Build a Unix socket dialer
let variant = unix::UnixDialer::new().await?;
let variant = DialerVariant::Unix(variant);
Ok(Self { endpoint, variant })
}
x => Err(Error::UnsupportedTransport(x.to_string())),
}
}
@@ -195,6 +237,13 @@ impl Dialer {
DialerVariant::NymTls(_dialer) => {
todo!();
}
#[cfg(feature = "p2p-transport-unix")]
DialerVariant::Unix(dialer) => {
let path = self.endpoint.to_file_path()?;
let stream = dialer.do_dial(path).await?;
Ok(Box::new(stream))
}
}
}
@@ -216,14 +265,11 @@ impl Listener {
/// Instantiate a new [`Listener`] with the given [`Url`].
/// Must contain a scheme, host string, and a port.
pub async fn new(endpoint: Url) -> Result<Self> {
if endpoint.host_str().is_none() || endpoint.port().is_none() {
return Err(Error::InvalidListenerScheme)
}
match endpoint.scheme().to_lowercase().as_str() {
#[cfg(feature = "p2p-transport-tcp")]
"tcp" => {
// Build a TCP listener
enforce_hostport!(endpoint);
let variant = tcp::TcpListener::new(1024).await?;
let variant = ListenerVariant::Tcp(variant);
Ok(Self { endpoint, variant })
@@ -232,11 +278,20 @@ impl Listener {
#[cfg(feature = "p2p-transport-tcp")]
"tcp+tls" => {
// Build a TCP listener wrapped with TLS
enforce_hostport!(endpoint);
let variant = tcp::TcpListener::new(1024).await?;
let variant = ListenerVariant::TcpTls(variant);
Ok(Self { endpoint, variant })
}
#[cfg(feature = "p2p-transport-unix")]
"unix" => {
enforce_abspath!(endpoint);
let variant = unix::UnixListener::new().await?;
let variant = ListenerVariant::Unix(variant);
Ok(Self { endpoint, variant })
}
x => Err(Error::UnsupportedTransport(x.to_string())),
}
}
@@ -260,6 +315,13 @@ impl Listener {
let l = tlsupgrade.upgrade_listener_tcp_tls(l).await?;
Ok(Box::new(l))
}
#[cfg(feature = "p2p-transport-unix")]
ListenerVariant::Unix(listener) => {
let path = self.endpoint.to_file_path()?;
let l = listener.do_listen(&path.into()).await?;
Ok(Box::new(l))
}
}
}
@@ -283,6 +345,9 @@ impl PtStream for arti_client::DataStream {}
#[cfg(feature = "p2p-transport-tor")]
impl PtStream for async_rustls::TlsStream<arti_client::DataStream> {}
#[cfg(feature = "p2p-transport-unix")]
impl PtStream for async_std::os::unix::net::UnixStream {}
/// Wrapper trait for async listeners
#[async_trait]
pub trait PtListener: Send + Sync + Unpin {

83
src/net/transport/unix.rs Normal file
View File

@@ -0,0 +1,83 @@
/* This file is part of DarkFi (https://dark.fi)
*
* Copyright (C) 2020-2023 Dyne.org foundation
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use async_std::{
fs,
os::unix::net::{UnixListener as AsyncStdUnixListener, UnixStream},
path::{Path, PathBuf},
};
use async_trait::async_trait;
use log::debug;
use url::Url;
use super::{PtListener, PtStream};
use crate::Result;
/// Unix Dialer implementation
#[derive(Debug, Clone)]
pub struct UnixDialer;
impl UnixDialer {
/// Instantiate a new [`UnixDialer`] object
pub(crate) async fn new() -> Result<Self> {
Ok(Self {})
}
/// Internal dial function
pub(crate) async fn do_dial(
&self,
path: impl AsRef<Path> + core::fmt::Debug,
) -> Result<UnixStream> {
debug!(target: "net::unix::do_dial", "Dialing {:?} Unix socket...", path);
let stream = UnixStream::connect(path).await?;
Ok(stream)
}
}
/// Unix Listener implementation
#[derive(Debug, Clone)]
pub struct UnixListener;
impl UnixListener {
/// Instantiate a new [`UnixListener`] object
pub(crate) async fn new() -> Result<Self> {
Ok(Self {})
}
/// Internal listen function
pub(crate) async fn do_listen(&self, path: &PathBuf) -> Result<AsyncStdUnixListener> {
// This rm is a bit aggressive, but c'est la vie.
let _ = fs::remove_file(path).await;
let listener = AsyncStdUnixListener::bind(path).await?;
Ok(listener)
}
}
#[async_trait]
impl PtListener for AsyncStdUnixListener {
async fn next(&self) -> Result<(Box<dyn PtStream>, Url)> {
let (stream, _peer_addr) = match self.accept().await {
Ok((s, a)) => (s, a),
Err(e) => return Err(e.into()),
};
let addr = self.local_addr().unwrap();
let url = Url::parse(&format!("unix://{}", addr.as_pathname().unwrap().to_str().unwrap()))?;
Ok((Box::new(stream), url))
}
}

View File

@@ -66,3 +66,29 @@ async fn tcp_tls_transport() {
assert_eq!(buf, payload);
}
#[async_std::test]
async fn unix_transport() {
let tmpdir = std::env::temp_dir();
let url = Url::parse(&format!(
"unix://{}/darkfi_unix_plain.sock",
tmpdir.as_os_str().to_str().unwrap()
))
.unwrap();
let listener = Listener::new(url.clone()).await.unwrap().listen().await.unwrap();
task::spawn(async move {
let (stream, _) = listener.next().await.unwrap();
let (mut reader, mut writer) = smol::io::split(stream);
io::copy(&mut reader, &mut writer).await.unwrap();
});
let payload = b"ohai unix";
let dialer = Dialer::new(url).await.unwrap();
let mut client = dialer.dial(None).await.unwrap();
client.write_all(payload).await.unwrap();
let mut buf = vec![0u8; 9];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, payload);
}