From 2da828478c0bebe619069a0175f594edb5745020 Mon Sep 17 00:00:00 2001 From: Andrew Kirillov <20803092+akirillo@users.noreply.github.com> Date: Fri, 6 Jan 2023 12:43:13 -0800 Subject: [PATCH] feat(net): Bandwidth monitoring (#707) * WIP for draft PR * added basic test * using BandwidthMeterInner type & added TcpStream test * formatted * formatted w/ +nightly * using & for and * formatted * added default impl for BandwidthMeter * using _bandwidth_meter bc unused * removed redundant clone * addressed nits, renamed file * addressed nits, renamed file --- Cargo.lock | 3 + crates/net/common/Cargo.toml | 7 +- crates/net/common/src/bandwidth_meter.rs | 272 +++++++++++++++++++++++ crates/net/common/src/lib.rs | 1 + crates/net/network/src/manager.rs | 11 + crates/net/network/src/network.rs | 10 + crates/net/network/src/session/active.rs | 10 +- crates/net/network/src/session/handle.rs | 3 +- crates/net/network/src/session/mod.rs | 18 +- 9 files changed, 326 insertions(+), 9 deletions(-) create mode 100644 crates/net/common/src/bandwidth_meter.rs diff --git a/Cargo.lock b/Cargo.lock index 367592de9f..0d176407ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3786,7 +3786,10 @@ dependencies = [ name = "reth-net-common" version = "0.1.0" dependencies = [ + "pin-project", + "reth-ecies", "reth-primitives", + "tokio", ] [[package]] diff --git a/crates/net/common/Cargo.toml b/crates/net/common/Cargo.toml index 702ecccb46..3a734890b4 100644 --- a/crates/net/common/Cargo.toml +++ b/crates/net/common/Cargo.toml @@ -10,4 +10,9 @@ Types shared accross network code [dependencies] # reth -reth-primitives = { path = "../../primitives" } \ No newline at end of file +reth-primitives = { path = "../../primitives" } +reth-ecies = { path = "../ecies" } + +# async +pin-project = "1.0" +tokio = { version = "1.21.2", features = ["full"] } \ No newline at end of file diff --git a/crates/net/common/src/bandwidth_meter.rs b/crates/net/common/src/bandwidth_meter.rs new file mode 100644 index 0000000000..82247e4118 --- /dev/null +++ b/crates/net/common/src/bandwidth_meter.rs @@ -0,0 +1,272 @@ +//! Support for monitoring bandwidth. Takes heavy inspiration from https://github.com/libp2p/rust-libp2p/blob/master/src/bandwidth.rs + +// Copyright 2019 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use reth_ecies::stream::HasRemoteAddr; +use std::{ + convert::TryFrom as _, + io, + net::SocketAddr, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{ready, Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpStream, +}; + +/// Meters bandwidth usage of streams +#[derive(Debug)] +struct BandwidthMeterInner { + /// Measures the number of inbound packets + inbound: AtomicU64, + /// Measures the number of outbound packets + outbound: AtomicU64, +} + +/// Public shareable struct used for getting bandwidth metering info +#[derive(Clone, Debug)] +pub struct BandwidthMeter { + inner: Arc, +} + +impl BandwidthMeter { + /// Returns the total number of bytes that have been downloaded on all the streams. + /// + /// > **Note**: This method is by design subject to race conditions. The returned value should + /// > only ever be used for statistics purposes. + pub fn total_inbound(&self) -> u64 { + self.inner.inbound.load(Ordering::Relaxed) + } + + /// Returns the total number of bytes that have been uploaded on all the streams. + /// + /// > **Note**: This method is by design subject to race conditions. The returned value should + /// > only ever be used for statistics purposes. + pub fn total_outbound(&self) -> u64 { + self.inner.outbound.load(Ordering::Relaxed) + } +} + +impl Default for BandwidthMeter { + fn default() -> Self { + Self { + inner: Arc::new(BandwidthMeterInner { + inbound: AtomicU64::new(0), + outbound: AtomicU64::new(0), + }), + } + } +} + +/// Wraps around a single stream that implements [`AsyncRead`] + [`AsyncWrite`] and meters the +/// bandwidth through it +#[derive(Debug)] +#[pin_project::pin_project] +pub struct MeteredStream { + /// The stream this instruments + #[pin] + inner: S, + /// The [`BandwidthMeter`] struct this uses to monitor bandwidth + meter: BandwidthMeter, +} + +impl MeteredStream { + /// Creates a new [`MeteredStream`] wrapping around the provided stream, + /// along with a new [`BandwidthMeter`] + pub fn new(inner: S) -> Self { + Self { inner, meter: BandwidthMeter::default() } + } + + /// Creates a new [`MeteredStream`] wrapping around the provided stream, + /// attaching the provided [`BandwidthMeter`] + pub fn new_with_meter(inner: S, meter: BandwidthMeter) -> Self { + Self { inner, meter } + } + + /// Provides a reference to the [`BandwidthMeter`] attached to this [`MeteredStream`] + pub fn get_bandwidth_meter(&self) -> &BandwidthMeter { + &self.meter + } +} + +impl AsyncRead for MeteredStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.project(); + let num_bytes = { + let init_num_bytes = buf.filled().len(); + ready!(this.inner.poll_read(cx, buf))?; + buf.filled().len() - init_num_bytes + }; + this.meter + .inner + .inbound + .fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for MeteredStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + let num_bytes = ready!(this.inner.poll_write(cx, buf))?; + this.meter + .inner + .outbound + .fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + Poll::Ready(Ok(num_bytes)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.inner.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.inner.poll_shutdown(cx) + } +} + +impl HasRemoteAddr for MeteredStream { + fn remote_addr(&self) -> Option { + self.inner.remote_addr() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream}, + net::{TcpListener, TcpStream}, + }; + + async fn duplex_stream_ping_pong( + client: &mut MeteredStream, + server: &mut MeteredStream, + ) { + let mut buf = [0u8; 4]; + + client.write_all(b"ping").await.unwrap(); + server.read(&mut buf).await.unwrap(); + + server.write_all(b"pong").await.unwrap(); + client.read(&mut buf).await.unwrap(); + } + + fn assert_bandwidth_counts( + bandwidth_meter: &BandwidthMeter, + expected_inbound: u64, + expected_outbound: u64, + ) { + let actual_inbound = bandwidth_meter.total_inbound(); + assert_eq!( + actual_inbound, expected_inbound, + "Expected {} inbound bytes, but got {}", + expected_inbound, actual_inbound, + ); + + let actual_outbound = bandwidth_meter.total_outbound(); + assert_eq!( + actual_outbound, expected_outbound, + "Expected {} inbound bytes, but got {}", + expected_outbound, actual_outbound, + ); + } + + #[tokio::test] + async fn test_count_read_write() { + // Taken in large part from https://docs.rs/tokio/latest/tokio/io/struct.DuplexStream.html#example + + let (client, server) = duplex(64); + let mut monitored_client = MeteredStream::new(client); + let mut monitored_server = MeteredStream::new(server); + + duplex_stream_ping_pong(&mut monitored_client, &mut monitored_server).await; + + assert_bandwidth_counts(monitored_client.get_bandwidth_meter(), 4, 4); + assert_bandwidth_counts(monitored_server.get_bandwidth_meter(), 4, 4); + } + + #[tokio::test] + async fn test_read_equals_write_tcp() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + let client_stream = TcpStream::connect(server_addr).await.unwrap(); + let mut metered_client_stream = MeteredStream::new(client_stream); + + let client_meter = metered_client_stream.meter.clone(); + + let handle = tokio::spawn(async move { + let (server_stream, _) = listener.accept().await.unwrap(); + let mut metered_server_stream = MeteredStream::new(server_stream); + + let mut buf = [0u8; 4]; + + metered_server_stream.read(&mut buf).await.unwrap(); + + assert_eq!(metered_server_stream.meter.total_inbound(), client_meter.total_outbound()); + }); + + metered_client_stream.write_all(b"ping").await.unwrap(); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn test_multiple_streams_one_meter() { + let (client_1, server_1) = duplex(64); + let (client_2, server_2) = duplex(64); + + let shared_client_bandwidth_meter = BandwidthMeter::default(); + let shared_server_bandwidth_meter = BandwidthMeter::default(); + + let mut monitored_client_1 = + MeteredStream::new_with_meter(client_1, shared_client_bandwidth_meter.clone()); + let mut monitored_server_1 = + MeteredStream::new_with_meter(server_1, shared_server_bandwidth_meter.clone()); + + let mut monitored_client_2 = + MeteredStream::new_with_meter(client_2, shared_client_bandwidth_meter.clone()); + let mut monitored_server_2 = + MeteredStream::new_with_meter(server_2, shared_server_bandwidth_meter.clone()); + + duplex_stream_ping_pong(&mut monitored_client_1, &mut monitored_server_1).await; + duplex_stream_ping_pong(&mut monitored_client_2, &mut monitored_server_2).await; + + assert_bandwidth_counts(&shared_client_bandwidth_meter, 8, 8); + assert_bandwidth_counts(&shared_server_bandwidth_meter, 8, 8); + } +} diff --git a/crates/net/common/src/lib.rs b/crates/net/common/src/lib.rs index 4623ee0cab..2fbbc4ea08 100644 --- a/crates/net/common/src/lib.rs +++ b/crates/net/common/src/lib.rs @@ -8,3 +8,4 @@ //! Shared types across reth-net pub mod ban_list; +pub mod bandwidth_meter; diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index 6d02dc029e..244e6d8d46 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -38,6 +38,7 @@ use reth_eth_wire::{ capability::{Capabilities, CapabilityMessage}, DisconnectReason, Status, }; +use reth_net_common::bandwidth_meter::BandwidthMeter; use reth_primitives::{PeerId, H256}; use reth_provider::BlockProvider; use std::{ @@ -127,6 +128,12 @@ impl NetworkManager { pub fn handle(&self) -> &NetworkHandle { &self.handle } + + /// Returns a shareable reference to the [`BandwidthMeter`] stored on the [`NetworkInner`] + /// inside of the [`NetworkHandle`] + pub fn bandwidth_meter(&self) -> &BandwidthMeter { + self.handle.bandwidth_meter() + } } impl NetworkManager @@ -174,6 +181,8 @@ where // need to retrieve the addr here since provided port could be `0` let local_peer_id = discovery.local_id(); + let bandwidth_meter: BandwidthMeter = BandwidthMeter::default(); + let sessions = SessionManager::new( secret_key, sessions_config, @@ -181,6 +190,7 @@ where status, hello_message, fork_filter, + bandwidth_meter.clone(), ); let state = NetworkState::new(client, discovery, peers_manager, genesis_hash); @@ -196,6 +206,7 @@ where local_peer_id, peers_handle, network_mode, + bandwidth_meter, ); Ok(Self { diff --git a/crates/net/network/src/network.rs b/crates/net/network/src/network.rs index 8c3173e3e7..30e91514aa 100644 --- a/crates/net/network/src/network.rs +++ b/crates/net/network/src/network.rs @@ -12,6 +12,7 @@ use reth_interfaces::{ p2p::headers::client::StatusUpdater, sync::{SyncState, SyncStateProvider, SyncStateUpdater}, }; +use reth_net_common::bandwidth_meter::BandwidthMeter; use reth_primitives::{PeerId, TransactionSigned, TxHash, H256, U256}; use std::{ net::SocketAddr, @@ -43,6 +44,7 @@ impl NetworkHandle { local_peer_id: PeerId, peers: PeersHandle, network_mode: NetworkMode, + bandwidth_meter: BandwidthMeter, ) -> Self { let inner = NetworkInner { num_active_peers, @@ -51,6 +53,7 @@ impl NetworkHandle { local_peer_id, peers, network_mode, + bandwidth_meter, is_syncing: Arc::new(Default::default()), }; Self { inner: Arc::new(inner) } @@ -200,6 +203,11 @@ impl NetworkHandle { msg: SharedTransactions(msg), }) } + + /// Provides a shareable reference to the [`BandwidthMeter`] stored on the [`NetworkInner`] + pub fn bandwidth_meter(&self) -> &BandwidthMeter { + &self.inner.bandwidth_meter + } } impl StatusUpdater for NetworkHandle { @@ -236,6 +244,8 @@ struct NetworkInner { peers: PeersHandle, /// The mode of the network network_mode: NetworkMode, + /// Used to measure inbound & outbound bandwidth across network streams (currently unused) + bandwidth_meter: BandwidthMeter, /// Represents if the network is currently syncing. is_syncing: Arc, } diff --git a/crates/net/network/src/session/active.rs b/crates/net/network/src/session/active.rs index d7b47372e8..31d731fcea 100644 --- a/crates/net/network/src/session/active.rs +++ b/crates/net/network/src/session/active.rs @@ -17,6 +17,7 @@ use reth_eth_wire::{ DisconnectReason, EthMessage, EthStream, P2PStream, }; use reth_interfaces::p2p::error::RequestError; +use reth_net_common::bandwidth_meter::MeteredStream; use reth_primitives::PeerId; use std::{ collections::VecDeque, @@ -47,7 +48,7 @@ pub(crate) struct ActiveSession { /// Keeps track of request ids. pub(crate) next_id: u64, /// The underlying connection. - pub(crate) conn: EthStream>>, + pub(crate) conn: EthStream>>>, /// Identifier of the node we're connected to. pub(crate) remote_peer_id: PeerId, /// The address we're connected to. @@ -513,6 +514,7 @@ mod tests { EthVersion, HelloMessage, NewPooledTransactionHashes, ProtocolVersion, Status, StatusBuilder, UnauthedEthStream, UnauthedP2PStream, }; + use reth_net_common::bandwidth_meter::BandwidthMeter; use reth_primitives::{ForkFilter, Hardfork}; use secp256k1::{SecretKey, SECP256K1}; use std::time::Duration; @@ -540,6 +542,7 @@ mod tests { status: Status, fork_filter: ForkFilter, next_id: usize, + bandwidth_meter: BandwidthMeter, } impl SessionBuilder { @@ -584,11 +587,13 @@ mod tests { let session_id = self.next_id(); let (_disconnect_tx, disconnect_rx) = oneshot::channel(); let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1); + let metered_stream = + MeteredStream::new_with_meter(stream, self.bandwidth_meter.clone()); tokio::task::spawn(start_pending_incoming_session( disconnect_rx, session_id, - stream, + metered_stream, pending_sessions_tx, remote_addr, self.secret_key, @@ -655,6 +660,7 @@ mod tests { local_peer_id, status: StatusBuilder::default().build(), fork_filter: Hardfork::Frontier.fork_filter(), + bandwidth_meter: BandwidthMeter::default(), } } } diff --git a/crates/net/network/src/session/handle.rs b/crates/net/network/src/session/handle.rs index 38b6987107..6f770ef0a8 100644 --- a/crates/net/network/src/session/handle.rs +++ b/crates/net/network/src/session/handle.rs @@ -9,6 +9,7 @@ use reth_eth_wire::{ errors::EthStreamError, DisconnectReason, EthStream, P2PStream, Status, }; +use reth_net_common::bandwidth_meter::MeteredStream; use reth_primitives::PeerId; use std::{io, net::SocketAddr, sync::Arc, time::Instant}; use tokio::{ @@ -93,7 +94,7 @@ pub(crate) enum PendingSessionEvent { peer_id: PeerId, capabilities: Arc, status: Status, - conn: EthStream>>, + conn: EthStream>>>, direction: Direction, client_id: String, }, diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 99670b70bc..c51ecbd895 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -19,6 +19,7 @@ use reth_eth_wire::{ errors::EthStreamError, DisconnectReason, HelloMessage, Status, UnauthedEthStream, UnauthedP2PStream, }; +use reth_net_common::bandwidth_meter::{BandwidthMeter, MeteredStream}; use reth_primitives::{ForkFilter, ForkId, ForkTransition, PeerId, H256, U256}; use reth_tasks::TaskExecutor; use secp256k1::SecretKey; @@ -88,6 +89,8 @@ pub(crate) struct SessionManager { active_session_tx: mpsc::Sender, /// Receiver half that listens for [`ActiveSessionEvent`] produced by pending sessions. active_session_rx: ReceiverStream, + /// Used to measure inbound & outbound bandwidth across all managed streams + bandwidth_meter: BandwidthMeter, } // === impl SessionManager === @@ -101,6 +104,7 @@ impl SessionManager { status: Status, hello_message: HelloMessage, fork_filter: ForkFilter, + bandwidth_meter: BandwidthMeter, ) -> Self { let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer); let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer); @@ -121,6 +125,7 @@ impl SessionManager { pending_session_rx: ReceiverStream::new(pending_sessions_rx), active_session_tx, active_session_rx: ReceiverStream::new(active_session_rx), + bandwidth_meter, } } @@ -186,10 +191,11 @@ impl SessionManager { let (disconnect_tx, disconnect_rx) = oneshot::channel(); let pending_events = self.pending_sessions_tx.clone(); + let metered_stream = MeteredStream::new_with_meter(stream, self.bandwidth_meter.clone()); self.spawn(start_pending_incoming_session( disconnect_rx, session_id, - stream, + metered_stream, pending_events, remote_addr, self.secret_key, @@ -220,6 +226,7 @@ impl SessionManager { self.hello_message.clone(), self.status, self.fork_filter.clone(), + self.bandwidth_meter.clone(), )); let handle = PendingSessionHandle { @@ -606,7 +613,7 @@ pub struct ExceedsSessionLimit(pub(crate) u32); pub(crate) async fn start_pending_incoming_session( disconnect_rx: oneshot::Receiver<()>, session_id: SessionId, - stream: TcpStream, + stream: MeteredStream, events: mpsc::Sender, remote_addr: SocketAddr, secret_key: SecretKey, @@ -642,9 +649,10 @@ async fn start_pending_outbound_session( hello: HelloMessage, status: Status, fork_filter: ForkFilter, + bandwidth_meter: BandwidthMeter, ) { let stream = match TcpStream::connect(remote_addr).await { - Ok(stream) => stream, + Ok(stream) => MeteredStream::new_with_meter(stream, bandwidth_meter), Err(error) => { let _ = events .send(PendingSessionEvent::OutgoingConnectionError { @@ -677,7 +685,7 @@ async fn start_pending_outbound_session( async fn authenticate( disconnect_rx: oneshot::Receiver<()>, events: mpsc::Sender, - stream: TcpStream, + stream: MeteredStream, session_id: SessionId, remote_addr: SocketAddr, secret_key: SecretKey, @@ -753,7 +761,7 @@ async fn authenticate( /// On Success return the authenticated stream as [`PendingSessionEvent`] #[allow(clippy::too_many_arguments)] async fn authenticate_stream( - stream: UnauthedP2PStream>, + stream: UnauthedP2PStream>>, session_id: SessionId, remote_addr: SocketAddr, direction: Direction,