feat(net): Bandwidth monitoring (#707)

* WIP for draft PR

* added basic test

* using BandwidthMeterInner type & added TcpStream test

* formatted

* formatted w/ +nightly

* using  &  for  and

* formatted

* added default impl for BandwidthMeter

* using _bandwidth_meter bc unused

* removed redundant clone

* addressed nits, renamed file

* addressed nits, renamed file
This commit is contained in:
Andrew Kirillov
2023-01-06 12:43:13 -08:00
committed by GitHub
parent 1d2e0526a8
commit 2da828478c
9 changed files with 326 additions and 9 deletions

3
Cargo.lock generated
View File

@@ -3786,7 +3786,10 @@ dependencies = [
name = "reth-net-common"
version = "0.1.0"
dependencies = [
"pin-project",
"reth-ecies",
"reth-primitives",
"tokio",
]
[[package]]

View File

@@ -10,4 +10,9 @@ Types shared accross network code
[dependencies]
# reth
reth-primitives = { path = "../../primitives" }
reth-primitives = { path = "../../primitives" }
reth-ecies = { path = "../ecies" }
# async
pin-project = "1.0"
tokio = { version = "1.21.2", features = ["full"] }

View File

@@ -0,0 +1,272 @@
//! Support for monitoring bandwidth. Takes heavy inspiration from https://github.com/libp2p/rust-libp2p/blob/master/src/bandwidth.rs
// Copyright 2019 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use reth_ecies::stream::HasRemoteAddr;
use std::{
convert::TryFrom as _,
io,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{ready, Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};
/// Meters bandwidth usage of streams
#[derive(Debug)]
struct BandwidthMeterInner {
/// Measures the number of inbound packets
inbound: AtomicU64,
/// Measures the number of outbound packets
outbound: AtomicU64,
}
/// Public shareable struct used for getting bandwidth metering info
#[derive(Clone, Debug)]
pub struct BandwidthMeter {
inner: Arc<BandwidthMeterInner>,
}
impl BandwidthMeter {
/// Returns the total number of bytes that have been downloaded on all the streams.
///
/// > **Note**: This method is by design subject to race conditions. The returned value should
/// > only ever be used for statistics purposes.
pub fn total_inbound(&self) -> u64 {
self.inner.inbound.load(Ordering::Relaxed)
}
/// Returns the total number of bytes that have been uploaded on all the streams.
///
/// > **Note**: This method is by design subject to race conditions. The returned value should
/// > only ever be used for statistics purposes.
pub fn total_outbound(&self) -> u64 {
self.inner.outbound.load(Ordering::Relaxed)
}
}
impl Default for BandwidthMeter {
fn default() -> Self {
Self {
inner: Arc::new(BandwidthMeterInner {
inbound: AtomicU64::new(0),
outbound: AtomicU64::new(0),
}),
}
}
}
/// Wraps around a single stream that implements [`AsyncRead`] + [`AsyncWrite`] and meters the
/// bandwidth through it
#[derive(Debug)]
#[pin_project::pin_project]
pub struct MeteredStream<S> {
/// The stream this instruments
#[pin]
inner: S,
/// The [`BandwidthMeter`] struct this uses to monitor bandwidth
meter: BandwidthMeter,
}
impl<S> MeteredStream<S> {
/// Creates a new [`MeteredStream`] wrapping around the provided stream,
/// along with a new [`BandwidthMeter`]
pub fn new(inner: S) -> Self {
Self { inner, meter: BandwidthMeter::default() }
}
/// Creates a new [`MeteredStream`] wrapping around the provided stream,
/// attaching the provided [`BandwidthMeter`]
pub fn new_with_meter(inner: S, meter: BandwidthMeter) -> Self {
Self { inner, meter }
}
/// Provides a reference to the [`BandwidthMeter`] attached to this [`MeteredStream`]
pub fn get_bandwidth_meter(&self) -> &BandwidthMeter {
&self.meter
}
}
impl<Stream: AsyncRead> AsyncRead for MeteredStream<Stream> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
let num_bytes = {
let init_num_bytes = buf.filled().len();
ready!(this.inner.poll_read(cx, buf))?;
buf.filled().len() - init_num_bytes
};
this.meter
.inner
.inbound
.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
impl<Stream: AsyncWrite> AsyncWrite for MeteredStream<Stream> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
this.meter
.inner
.outbound
.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed);
Poll::Ready(Ok(num_bytes))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
this.inner.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
this.inner.poll_shutdown(cx)
}
}
impl HasRemoteAddr for MeteredStream<TcpStream> {
fn remote_addr(&self) -> Option<SocketAddr> {
self.inner.remote_addr()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream},
net::{TcpListener, TcpStream},
};
async fn duplex_stream_ping_pong(
client: &mut MeteredStream<DuplexStream>,
server: &mut MeteredStream<DuplexStream>,
) {
let mut buf = [0u8; 4];
client.write_all(b"ping").await.unwrap();
server.read(&mut buf).await.unwrap();
server.write_all(b"pong").await.unwrap();
client.read(&mut buf).await.unwrap();
}
fn assert_bandwidth_counts(
bandwidth_meter: &BandwidthMeter,
expected_inbound: u64,
expected_outbound: u64,
) {
let actual_inbound = bandwidth_meter.total_inbound();
assert_eq!(
actual_inbound, expected_inbound,
"Expected {} inbound bytes, but got {}",
expected_inbound, actual_inbound,
);
let actual_outbound = bandwidth_meter.total_outbound();
assert_eq!(
actual_outbound, expected_outbound,
"Expected {} inbound bytes, but got {}",
expected_outbound, actual_outbound,
);
}
#[tokio::test]
async fn test_count_read_write() {
// Taken in large part from https://docs.rs/tokio/latest/tokio/io/struct.DuplexStream.html#example
let (client, server) = duplex(64);
let mut monitored_client = MeteredStream::new(client);
let mut monitored_server = MeteredStream::new(server);
duplex_stream_ping_pong(&mut monitored_client, &mut monitored_server).await;
assert_bandwidth_counts(monitored_client.get_bandwidth_meter(), 4, 4);
assert_bandwidth_counts(monitored_server.get_bandwidth_meter(), 4, 4);
}
#[tokio::test]
async fn test_read_equals_write_tcp() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let server_addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(server_addr).await.unwrap();
let mut metered_client_stream = MeteredStream::new(client_stream);
let client_meter = metered_client_stream.meter.clone();
let handle = tokio::spawn(async move {
let (server_stream, _) = listener.accept().await.unwrap();
let mut metered_server_stream = MeteredStream::new(server_stream);
let mut buf = [0u8; 4];
metered_server_stream.read(&mut buf).await.unwrap();
assert_eq!(metered_server_stream.meter.total_inbound(), client_meter.total_outbound());
});
metered_client_stream.write_all(b"ping").await.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn test_multiple_streams_one_meter() {
let (client_1, server_1) = duplex(64);
let (client_2, server_2) = duplex(64);
let shared_client_bandwidth_meter = BandwidthMeter::default();
let shared_server_bandwidth_meter = BandwidthMeter::default();
let mut monitored_client_1 =
MeteredStream::new_with_meter(client_1, shared_client_bandwidth_meter.clone());
let mut monitored_server_1 =
MeteredStream::new_with_meter(server_1, shared_server_bandwidth_meter.clone());
let mut monitored_client_2 =
MeteredStream::new_with_meter(client_2, shared_client_bandwidth_meter.clone());
let mut monitored_server_2 =
MeteredStream::new_with_meter(server_2, shared_server_bandwidth_meter.clone());
duplex_stream_ping_pong(&mut monitored_client_1, &mut monitored_server_1).await;
duplex_stream_ping_pong(&mut monitored_client_2, &mut monitored_server_2).await;
assert_bandwidth_counts(&shared_client_bandwidth_meter, 8, 8);
assert_bandwidth_counts(&shared_server_bandwidth_meter, 8, 8);
}
}

View File

@@ -8,3 +8,4 @@
//! Shared types across reth-net
pub mod ban_list;
pub mod bandwidth_meter;

View File

@@ -38,6 +38,7 @@ use reth_eth_wire::{
capability::{Capabilities, CapabilityMessage},
DisconnectReason, Status,
};
use reth_net_common::bandwidth_meter::BandwidthMeter;
use reth_primitives::{PeerId, H256};
use reth_provider::BlockProvider;
use std::{
@@ -127,6 +128,12 @@ impl<C> NetworkManager<C> {
pub fn handle(&self) -> &NetworkHandle {
&self.handle
}
/// Returns a shareable reference to the [`BandwidthMeter`] stored on the [`NetworkInner`]
/// inside of the [`NetworkHandle`]
pub fn bandwidth_meter(&self) -> &BandwidthMeter {
self.handle.bandwidth_meter()
}
}
impl<C> NetworkManager<C>
@@ -174,6 +181,8 @@ where
// need to retrieve the addr here since provided port could be `0`
let local_peer_id = discovery.local_id();
let bandwidth_meter: BandwidthMeter = BandwidthMeter::default();
let sessions = SessionManager::new(
secret_key,
sessions_config,
@@ -181,6 +190,7 @@ where
status,
hello_message,
fork_filter,
bandwidth_meter.clone(),
);
let state = NetworkState::new(client, discovery, peers_manager, genesis_hash);
@@ -196,6 +206,7 @@ where
local_peer_id,
peers_handle,
network_mode,
bandwidth_meter,
);
Ok(Self {

View File

@@ -12,6 +12,7 @@ use reth_interfaces::{
p2p::headers::client::StatusUpdater,
sync::{SyncState, SyncStateProvider, SyncStateUpdater},
};
use reth_net_common::bandwidth_meter::BandwidthMeter;
use reth_primitives::{PeerId, TransactionSigned, TxHash, H256, U256};
use std::{
net::SocketAddr,
@@ -43,6 +44,7 @@ impl NetworkHandle {
local_peer_id: PeerId,
peers: PeersHandle,
network_mode: NetworkMode,
bandwidth_meter: BandwidthMeter,
) -> Self {
let inner = NetworkInner {
num_active_peers,
@@ -51,6 +53,7 @@ impl NetworkHandle {
local_peer_id,
peers,
network_mode,
bandwidth_meter,
is_syncing: Arc::new(Default::default()),
};
Self { inner: Arc::new(inner) }
@@ -200,6 +203,11 @@ impl NetworkHandle {
msg: SharedTransactions(msg),
})
}
/// Provides a shareable reference to the [`BandwidthMeter`] stored on the [`NetworkInner`]
pub fn bandwidth_meter(&self) -> &BandwidthMeter {
&self.inner.bandwidth_meter
}
}
impl StatusUpdater for NetworkHandle {
@@ -236,6 +244,8 @@ struct NetworkInner {
peers: PeersHandle,
/// The mode of the network
network_mode: NetworkMode,
/// Used to measure inbound & outbound bandwidth across network streams (currently unused)
bandwidth_meter: BandwidthMeter,
/// Represents if the network is currently syncing.
is_syncing: Arc<AtomicBool>,
}

View File

@@ -17,6 +17,7 @@ use reth_eth_wire::{
DisconnectReason, EthMessage, EthStream, P2PStream,
};
use reth_interfaces::p2p::error::RequestError;
use reth_net_common::bandwidth_meter::MeteredStream;
use reth_primitives::PeerId;
use std::{
collections::VecDeque,
@@ -47,7 +48,7 @@ pub(crate) struct ActiveSession {
/// Keeps track of request ids.
pub(crate) next_id: u64,
/// The underlying connection.
pub(crate) conn: EthStream<P2PStream<ECIESStream<TcpStream>>>,
pub(crate) conn: EthStream<P2PStream<ECIESStream<MeteredStream<TcpStream>>>>,
/// Identifier of the node we're connected to.
pub(crate) remote_peer_id: PeerId,
/// The address we're connected to.
@@ -513,6 +514,7 @@ mod tests {
EthVersion, HelloMessage, NewPooledTransactionHashes, ProtocolVersion, Status,
StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
};
use reth_net_common::bandwidth_meter::BandwidthMeter;
use reth_primitives::{ForkFilter, Hardfork};
use secp256k1::{SecretKey, SECP256K1};
use std::time::Duration;
@@ -540,6 +542,7 @@ mod tests {
status: Status,
fork_filter: ForkFilter,
next_id: usize,
bandwidth_meter: BandwidthMeter,
}
impl SessionBuilder {
@@ -584,11 +587,13 @@ mod tests {
let session_id = self.next_id();
let (_disconnect_tx, disconnect_rx) = oneshot::channel();
let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
let metered_stream =
MeteredStream::new_with_meter(stream, self.bandwidth_meter.clone());
tokio::task::spawn(start_pending_incoming_session(
disconnect_rx,
session_id,
stream,
metered_stream,
pending_sessions_tx,
remote_addr,
self.secret_key,
@@ -655,6 +660,7 @@ mod tests {
local_peer_id,
status: StatusBuilder::default().build(),
fork_filter: Hardfork::Frontier.fork_filter(),
bandwidth_meter: BandwidthMeter::default(),
}
}
}

View File

@@ -9,6 +9,7 @@ use reth_eth_wire::{
errors::EthStreamError,
DisconnectReason, EthStream, P2PStream, Status,
};
use reth_net_common::bandwidth_meter::MeteredStream;
use reth_primitives::PeerId;
use std::{io, net::SocketAddr, sync::Arc, time::Instant};
use tokio::{
@@ -93,7 +94,7 @@ pub(crate) enum PendingSessionEvent {
peer_id: PeerId,
capabilities: Arc<Capabilities>,
status: Status,
conn: EthStream<P2PStream<ECIESStream<TcpStream>>>,
conn: EthStream<P2PStream<ECIESStream<MeteredStream<TcpStream>>>>,
direction: Direction,
client_id: String,
},

View File

@@ -19,6 +19,7 @@ use reth_eth_wire::{
errors::EthStreamError,
DisconnectReason, HelloMessage, Status, UnauthedEthStream, UnauthedP2PStream,
};
use reth_net_common::bandwidth_meter::{BandwidthMeter, MeteredStream};
use reth_primitives::{ForkFilter, ForkId, ForkTransition, PeerId, H256, U256};
use reth_tasks::TaskExecutor;
use secp256k1::SecretKey;
@@ -88,6 +89,8 @@ pub(crate) struct SessionManager {
active_session_tx: mpsc::Sender<ActiveSessionMessage>,
/// Receiver half that listens for [`ActiveSessionEvent`] produced by pending sessions.
active_session_rx: ReceiverStream<ActiveSessionMessage>,
/// Used to measure inbound & outbound bandwidth across all managed streams
bandwidth_meter: BandwidthMeter,
}
// === impl SessionManager ===
@@ -101,6 +104,7 @@ impl SessionManager {
status: Status,
hello_message: HelloMessage,
fork_filter: ForkFilter,
bandwidth_meter: BandwidthMeter,
) -> Self {
let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer);
let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer);
@@ -121,6 +125,7 @@ impl SessionManager {
pending_session_rx: ReceiverStream::new(pending_sessions_rx),
active_session_tx,
active_session_rx: ReceiverStream::new(active_session_rx),
bandwidth_meter,
}
}
@@ -186,10 +191,11 @@ impl SessionManager {
let (disconnect_tx, disconnect_rx) = oneshot::channel();
let pending_events = self.pending_sessions_tx.clone();
let metered_stream = MeteredStream::new_with_meter(stream, self.bandwidth_meter.clone());
self.spawn(start_pending_incoming_session(
disconnect_rx,
session_id,
stream,
metered_stream,
pending_events,
remote_addr,
self.secret_key,
@@ -220,6 +226,7 @@ impl SessionManager {
self.hello_message.clone(),
self.status,
self.fork_filter.clone(),
self.bandwidth_meter.clone(),
));
let handle = PendingSessionHandle {
@@ -606,7 +613,7 @@ pub struct ExceedsSessionLimit(pub(crate) u32);
pub(crate) async fn start_pending_incoming_session(
disconnect_rx: oneshot::Receiver<()>,
session_id: SessionId,
stream: TcpStream,
stream: MeteredStream<TcpStream>,
events: mpsc::Sender<PendingSessionEvent>,
remote_addr: SocketAddr,
secret_key: SecretKey,
@@ -642,9 +649,10 @@ async fn start_pending_outbound_session(
hello: HelloMessage,
status: Status,
fork_filter: ForkFilter,
bandwidth_meter: BandwidthMeter,
) {
let stream = match TcpStream::connect(remote_addr).await {
Ok(stream) => stream,
Ok(stream) => MeteredStream::new_with_meter(stream, bandwidth_meter),
Err(error) => {
let _ = events
.send(PendingSessionEvent::OutgoingConnectionError {
@@ -677,7 +685,7 @@ async fn start_pending_outbound_session(
async fn authenticate(
disconnect_rx: oneshot::Receiver<()>,
events: mpsc::Sender<PendingSessionEvent>,
stream: TcpStream,
stream: MeteredStream<TcpStream>,
session_id: SessionId,
remote_addr: SocketAddr,
secret_key: SecretKey,
@@ -753,7 +761,7 @@ async fn authenticate(
/// On Success return the authenticated stream as [`PendingSessionEvent`]
#[allow(clippy::too_many_arguments)]
async fn authenticate_stream(
stream: UnauthedP2PStream<ECIESStream<TcpStream>>,
stream: UnauthedP2PStream<ECIESStream<MeteredStream<TcpStream>>>,
session_id: SessionId,
remote_addr: SocketAddr,
direction: Direction,