mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-01-11 00:08:13 -05:00
feat(net): implement support of subprotocols (#18080)
Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
@@ -13,15 +13,17 @@ use std::{
|
||||
future::Future,
|
||||
io,
|
||||
pin::{pin, Pin},
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
|
||||
errors::{EthStreamError, P2PStreamError},
|
||||
handshake::EthRlpxHandshake,
|
||||
p2pstream::DisconnectP2P,
|
||||
CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnauthedEthStream,
|
||||
UnifiedStatus,
|
||||
CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
|
||||
HANDSHAKE_TIMEOUT,
|
||||
};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
|
||||
@@ -135,7 +137,7 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
/// This accepts a closure that does a handshake with the remote peer and returns a tuple of the
|
||||
/// primary stream and extra data.
|
||||
///
|
||||
/// See also [`UnauthedEthStream::handshake`]
|
||||
/// See also [`UnauthedEthStream::handshake`](crate::UnauthedEthStream)
|
||||
pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
|
||||
mut self,
|
||||
cap: &Capability,
|
||||
@@ -167,6 +169,7 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
// complete
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
Some(Ok(msg)) = self.inner.conn.next() => {
|
||||
// Ensure the message belongs to the primary protocol
|
||||
let Some(offset) = msg.first().copied()
|
||||
@@ -188,6 +191,10 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
Some(msg) = from_primary.recv() => {
|
||||
self.inner.conn.send(msg).await.map_err(Into::into)?;
|
||||
}
|
||||
// Poll all subprotocols for new messages
|
||||
msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
|
||||
self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
|
||||
}
|
||||
res = &mut f => {
|
||||
let (st, extra) = res?;
|
||||
return Ok((RlpxSatelliteStream {
|
||||
@@ -205,22 +212,28 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
}
|
||||
|
||||
/// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
|
||||
/// primary protocol.
|
||||
/// primary protocol and the handshake implementation.
|
||||
pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
|
||||
self,
|
||||
status: UnifiedStatus,
|
||||
fork_filter: ForkFilter,
|
||||
handshake: Arc<dyn EthRlpxHandshake>,
|
||||
) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
|
||||
where
|
||||
St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
|
||||
self.into_satellite_stream_with_tuple_handshake(
|
||||
&Capability::eth(eth_cap),
|
||||
move |proxy| async move {
|
||||
UnauthedEthStream::new(proxy).handshake(status, fork_filter).await
|
||||
},
|
||||
)
|
||||
self.into_satellite_stream_with_tuple_handshake(&Capability::eth(eth_cap), move |proxy| {
|
||||
let handshake = handshake.clone();
|
||||
async move {
|
||||
let mut unauth = UnauthProxy { inner: proxy };
|
||||
let their_status = handshake
|
||||
.handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
|
||||
.await?;
|
||||
let eth_stream = EthStream::new(eth_cap, unauth.into_inner());
|
||||
Ok((eth_stream, their_status))
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -377,6 +390,57 @@ impl CanDisconnect<Bytes> for ProtocolProxy {
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapter so the injected `EthRlpxHandshake` can run over a multiplexed `ProtocolProxy`
|
||||
/// using the same error type expectations (`P2PStreamError`).
|
||||
#[derive(Debug)]
|
||||
struct UnauthProxy {
|
||||
inner: ProtocolProxy,
|
||||
}
|
||||
|
||||
impl UnauthProxy {
|
||||
fn into_inner(self) -> ProtocolProxy {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for UnauthProxy {
|
||||
type Item = Result<BytesMut, P2PStreamError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Bytes> for UnauthProxy {
|
||||
type Error = P2PStreamError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||
self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl CanDisconnect<Bytes> for UnauthProxy {
|
||||
fn disconnect(
|
||||
&mut self,
|
||||
reason: DisconnectReason,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
|
||||
let fut = self.inner.disconnect(reason);
|
||||
Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
|
||||
}
|
||||
}
|
||||
|
||||
/// A connection channel to receive _`non_empty`_ messages for the negotiated protocol.
|
||||
///
|
||||
/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
|
||||
@@ -666,15 +730,56 @@ impl fmt::Debug for ProtocolStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to poll multiple protocol streams in a `tokio::select`! branch
|
||||
struct ProtocolsPoller<'a> {
|
||||
protocols: &'a mut Vec<ProtocolStream>,
|
||||
}
|
||||
|
||||
impl<'a> ProtocolsPoller<'a> {
|
||||
const fn new(protocols: &'a mut Vec<ProtocolStream>) -> Self {
|
||||
Self { protocols }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Future for ProtocolsPoller<'a> {
|
||||
type Output = Result<Bytes, P2PStreamError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
// Process protocols in reverse order, like the existing pattern
|
||||
for idx in (0..self.protocols.len()).rev() {
|
||||
let mut proto = self.protocols.swap_remove(idx);
|
||||
match proto.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
self.protocols.push(proto);
|
||||
return Poll::Ready(Err(P2PStreamError::from(err)))
|
||||
}
|
||||
Poll::Ready(Some(Ok(msg))) => {
|
||||
// Got a message, put protocol back and return the message
|
||||
self.protocols.push(proto);
|
||||
return Poll::Ready(Ok(msg));
|
||||
}
|
||||
_ => {
|
||||
// push it back because we still want to complete the handshake first
|
||||
self.protocols.push(proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All protocols processed, nothing ready
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
handshake::EthHandshake,
|
||||
test_utils::{
|
||||
connect_passthrough, eth_handshake, eth_hello,
|
||||
proto::{test_hello, TestProtoMessage},
|
||||
},
|
||||
UnauthedP2PStream,
|
||||
UnauthedEthStream, UnauthedP2PStream,
|
||||
};
|
||||
use reth_eth_wire_types::EthNetworkPrimitives;
|
||||
use tokio::{net::TcpListener, sync::oneshot};
|
||||
@@ -736,7 +841,11 @@ mod tests {
|
||||
let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
|
||||
|
||||
let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
|
||||
.into_eth_satellite_stream::<EthNetworkPrimitives>(other_status, other_fork_filter)
|
||||
.into_eth_satellite_stream::<EthNetworkPrimitives>(
|
||||
other_status,
|
||||
other_fork_filter,
|
||||
Arc::new(EthHandshake::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -767,7 +876,11 @@ mod tests {
|
||||
|
||||
let conn = connect_passthrough(local_addr, test_hello().0).await;
|
||||
let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
|
||||
.into_eth_satellite_stream::<EthNetworkPrimitives>(status, fork_filter)
|
||||
.into_eth_satellite_stream::<EthNetworkPrimitives>(
|
||||
status,
|
||||
fork_filter,
|
||||
Arc::new(EthHandshake::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -1150,18 +1150,20 @@ async fn authenticate_stream<N: NetworkPrimitives>(
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (multiplex_stream, their_status) =
|
||||
match multiplex_stream.into_eth_satellite_stream(status, fork_filter).await {
|
||||
Ok((multiplex_stream, their_status)) => (multiplex_stream, their_status),
|
||||
Err(err) => {
|
||||
return PendingSessionEvent::Disconnected {
|
||||
remote_addr,
|
||||
session_id,
|
||||
direction,
|
||||
error: Some(PendingSessionHandshakeError::Eth(err)),
|
||||
}
|
||||
let (multiplex_stream, their_status) = match multiplex_stream
|
||||
.into_eth_satellite_stream(status, fork_filter, handshake)
|
||||
.await
|
||||
{
|
||||
Ok((multiplex_stream, their_status)) => (multiplex_stream, their_status),
|
||||
Err(err) => {
|
||||
return PendingSessionEvent::Disconnected {
|
||||
remote_addr,
|
||||
session_id,
|
||||
direction,
|
||||
error: Some(PendingSessionHandshakeError::Eth(err)),
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
(multiplex_stream.into(), their_status)
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user