diff --git a/crates/net/eth-wire-types/Cargo.toml b/crates/net/eth-wire-types/Cargo.toml index 7a43efae43..ea8ddd21d5 100644 --- a/crates/net/eth-wire-types/Cargo.toml +++ b/crates/net/eth-wire-types/Cargo.toml @@ -27,8 +27,8 @@ alloy-consensus.workspace = true bytes.workspace = true derive_more.workspace = true -thiserror.workspace = true serde = { workspace = true, optional = true } +thiserror.workspace = true # arbitrary utils arbitrary = { workspace = true, features = ["derive"], optional = true } diff --git a/crates/net/eth-wire-types/src/snap.rs b/crates/net/eth-wire-types/src/snap.rs index 8d6b446ed9..e20786b48b 100644 --- a/crates/net/eth-wire-types/src/snap.rs +++ b/crates/net/eth-wire-types/src/snap.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use alloy_primitives::{Bytes, B256}; -use alloy_rlp::{RlpDecodable, RlpEncodable}; +use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable}; use reth_codecs_derive::add_arbitrary_tests; /// Message IDs for the snap sync protocol @@ -224,4 +224,200 @@ impl SnapProtocolMessage { Self::TrieNodes(_) => SnapMessageId::TrieNodes, } } + + /// Encode the message to bytes + pub fn encode(&self) -> Bytes { + let mut buf = Vec::new(); + // Add message ID as first byte + buf.push(self.message_id() as u8); + + // Encode the message body based on its type + match self { + Self::GetAccountRange(msg) => msg.encode(&mut buf), + Self::AccountRange(msg) => msg.encode(&mut buf), + Self::GetStorageRanges(msg) => msg.encode(&mut buf), + Self::StorageRanges(msg) => msg.encode(&mut buf), + Self::GetByteCodes(msg) => msg.encode(&mut buf), + Self::ByteCodes(msg) => msg.encode(&mut buf), + Self::GetTrieNodes(msg) => msg.encode(&mut buf), + Self::TrieNodes(msg) => msg.encode(&mut buf), + } + + Bytes::from(buf) + } + + /// Decodes a SNAP protocol message from its message ID and RLP-encoded body. + pub fn decode(message_id: u8, buf: &mut &[u8]) -> Result { + // Decoding protocol message variants based on message ID + macro_rules! decode_snap_message_variant { + ($message_id:expr, $buf:expr, $id:expr, $variant:ident, $msg_type:ty) => { + if $message_id == $id as u8 { + return Ok(Self::$variant(<$msg_type>::decode($buf)?)); + } + }; + } + + // Try to decode each message type based on the message ID + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::GetAccountRange, + GetAccountRange, + GetAccountRangeMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::AccountRange, + AccountRange, + AccountRangeMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::GetStorageRanges, + GetStorageRanges, + GetStorageRangesMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::StorageRanges, + StorageRanges, + StorageRangesMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::GetByteCodes, + GetByteCodes, + GetByteCodesMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::ByteCodes, + ByteCodes, + ByteCodesMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::GetTrieNodes, + GetTrieNodes, + GetTrieNodesMessage + ); + decode_snap_message_variant!( + message_id, + buf, + SnapMessageId::TrieNodes, + TrieNodes, + TrieNodesMessage + ); + + Err(alloy_rlp::Error::Custom("Unknown message ID")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Helper function to create a B256 from a u64 for testing + fn b256_from_u64(value: u64) -> B256 { + B256::left_padding_from(&value.to_be_bytes()) + } + + // Helper function to test roundtrip encoding/decoding + fn test_roundtrip(original: SnapProtocolMessage) { + let encoded = original.encode(); + + // Verify the first byte matches the expected message ID + assert_eq!(encoded[0], original.message_id() as u8); + + let mut buf = &encoded[1..]; + let decoded = SnapProtocolMessage::decode(encoded[0], &mut buf).unwrap(); + + // Verify the match + assert_eq!(decoded, original); + } + + #[test] + fn test_all_message_roundtrips() { + test_roundtrip(SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage { + request_id: 42, + root_hash: b256_from_u64(123), + starting_hash: b256_from_u64(456), + limit_hash: b256_from_u64(789), + response_bytes: 1024, + })); + + test_roundtrip(SnapProtocolMessage::AccountRange(AccountRangeMessage { + request_id: 42, + accounts: vec![AccountData { + hash: b256_from_u64(123), + body: Bytes::from(vec![1, 2, 3]), + }], + proof: vec![Bytes::from(vec![4, 5, 6])], + })); + + test_roundtrip(SnapProtocolMessage::GetStorageRanges(GetStorageRangesMessage { + request_id: 42, + root_hash: b256_from_u64(123), + account_hashes: vec![b256_from_u64(456)], + starting_hash: b256_from_u64(789), + limit_hash: b256_from_u64(101112), + response_bytes: 2048, + })); + + test_roundtrip(SnapProtocolMessage::StorageRanges(StorageRangesMessage { + request_id: 42, + slots: vec![vec![StorageData { + hash: b256_from_u64(123), + data: Bytes::from(vec![1, 2, 3]), + }]], + proof: vec![Bytes::from(vec![4, 5, 6])], + })); + + test_roundtrip(SnapProtocolMessage::GetByteCodes(GetByteCodesMessage { + request_id: 42, + hashes: vec![b256_from_u64(123)], + response_bytes: 1024, + })); + + test_roundtrip(SnapProtocolMessage::ByteCodes(ByteCodesMessage { + request_id: 42, + codes: vec![Bytes::from(vec![1, 2, 3])], + })); + + test_roundtrip(SnapProtocolMessage::GetTrieNodes(GetTrieNodesMessage { + request_id: 42, + root_hash: b256_from_u64(123), + paths: vec![TriePath { + account_path: Bytes::from(vec![1, 2, 3]), + slot_paths: vec![Bytes::from(vec![4, 5, 6])], + }], + response_bytes: 1024, + })); + + test_roundtrip(SnapProtocolMessage::TrieNodes(TrieNodesMessage { + request_id: 42, + nodes: vec![Bytes::from(vec![1, 2, 3])], + })); + } + + #[test] + fn test_unknown_message_id() { + // Create some random data + let data = Bytes::from(vec![1, 2, 3, 4]); + let mut buf = data.as_ref(); + + // Try to decode with an invalid message ID + let result = SnapProtocolMessage::decode(255, &mut buf); + + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e.to_string(), "Unknown message ID"); + } + } } diff --git a/crates/net/eth-wire/src/eth_snap_stream.rs b/crates/net/eth-wire/src/eth_snap_stream.rs new file mode 100644 index 0000000000..a1e6d88468 --- /dev/null +++ b/crates/net/eth-wire/src/eth_snap_stream.rs @@ -0,0 +1,421 @@ +//! Ethereum and snap combined protocol stream implementation. +//! +//! A stream type for handling both eth and snap protocol messages over a single `RLPx` connection. +//! Provides message encoding/decoding, ID multiplexing, and protocol message processing. + +use super::message::MAX_MESSAGE_SIZE; +use crate::{ + message::{EthBroadcastMessage, ProtocolBroadcastMessage}, + EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage, + RawCapabilityMessage, SnapMessageId, SnapProtocolMessage, +}; +use alloy_rlp::{Bytes, BytesMut, Encodable}; +use core::fmt::Debug; +use futures::{Sink, SinkExt}; +use pin_project::pin_project; +use std::{ + marker::PhantomData, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio_stream::Stream; + +/// Error type for the eth and snap stream +#[derive(thiserror::Error, Debug)] +pub enum EthSnapStreamError { + /// Invalid message for protocol version + #[error("invalid message for version {0:?}: {1}")] + InvalidMessage(EthVersion, String), + + /// Unknown message ID + #[error("unknown message id: {0}")] + UnknownMessageId(u8), + + /// Message too large + #[error("message too large: {0} > {1}")] + MessageTooLarge(usize, usize), + + /// RLP decoding error + #[error("rlp error: {0}")] + Rlp(#[from] alloy_rlp::Error), + + /// Status message received outside handshake + #[error("status message received outside handshake")] + StatusNotInHandshake, +} + +/// Combined message type that include either eth or snao protocol messages +#[derive(Debug)] +pub enum EthSnapMessage { + /// An Ethereum protocol message + Eth(EthMessage), + /// A snap protocol message + Snap(SnapProtocolMessage), +} + +/// A stream implementation that can handle both eth and snap protocol messages +/// over a single connection. +#[pin_project] +#[derive(Debug, Clone)] +pub struct EthSnapStream { + /// Protocol logic + eth_snap: EthSnapStreamInner, + /// Inner byte stream + #[pin] + inner: S, +} + +impl EthSnapStream +where + N: NetworkPrimitives, +{ + /// Create a new eth and snap protocol stream + pub fn new(stream: S, eth_version: EthVersion) -> Self { + Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream } + } + + /// Returns the eth version + #[inline] + pub const fn eth_version(&self) -> EthVersion { + self.eth_snap.eth_version() + } + + /// Returns the underlying stream + #[inline] + pub const fn inner(&self) -> &S { + &self.inner + } + + /// Returns mutable access to the underlying stream + #[inline] + pub fn inner_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consumes this type and returns the wrapped stream + #[inline] + pub fn into_inner(self) -> S { + self.inner + } +} + +impl EthSnapStream +where + S: Sink + Unpin, + EthSnapStreamError: From, + N: NetworkPrimitives, +{ + /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead. + pub fn start_send_broadcast( + &mut self, + item: EthBroadcastMessage, + ) -> Result<(), EthSnapStreamError> { + self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode( + ProtocolBroadcastMessage::from(item), + )))?; + + Ok(()) + } + + /// Sends a raw capability message directly over the stream + pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> { + let mut bytes = Vec::with_capacity(msg.payload.len() + 1); + msg.id.encode(&mut bytes); + bytes.extend_from_slice(&msg.payload); + + self.inner.start_send_unpin(bytes.into())?; + Ok(()) + } +} + +impl Stream for EthSnapStream +where + S: Stream> + Unpin, + EthSnapStreamError: From, + N: NetworkPrimitives, +{ + type Item = Result, EthSnapStreamError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let res = ready!(this.inner.poll_next(cx)); + + match res { + Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))), + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), + } + } +} + +impl Sink> for EthSnapStream +where + S: Sink + Unpin, + EthSnapStreamError: From, + N: NetworkPrimitives, +{ + type Error = EthSnapStreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx).map_err(Into::into) + } + + fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage) -> Result<(), Self::Error> { + let mut this = self.as_mut().project(); + + let bytes = match item { + EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?, + EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg), + }; + + this.inner.start_send_unpin(bytes)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx).map_err(Into::into) + } +} + +/// Stream handling combined eth and snap protocol logic +/// Snap version is not critical to specify yet, +/// Only one version, snap/1, does exist. +#[derive(Debug, Clone)] +struct EthSnapStreamInner { + /// Eth protocol version + eth_version: EthVersion, + /// Type marker + _pd: PhantomData, +} + +impl EthSnapStreamInner +where + N: NetworkPrimitives, +{ + /// Create a new eth and snap protocol stream + const fn new(eth_version: EthVersion) -> Self { + Self { eth_version, _pd: PhantomData } + } + + #[inline] + const fn eth_version(&self) -> EthVersion { + self.eth_version + } + + /// Decode a message from the stream + fn decode_message(&self, bytes: BytesMut) -> Result, EthSnapStreamError> { + if bytes.len() > MAX_MESSAGE_SIZE { + return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), MAX_MESSAGE_SIZE)); + } + + if bytes.is_empty() { + return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort)); + } + + let message_id = bytes[0]; + + // This check works because capabilities are sorted lexicographically + // if "eth" before "snap", giving eth messages lower IDs than snap messages, + // and eth message IDs are <= [`EthMessageID::max()`], + // snap message IDs are > [`EthMessageID::max()`]. + // See also . + if message_id <= EthMessageID::max() { + let mut buf = bytes.as_ref(); + match ProtocolMessage::decode_message(self.eth_version, &mut buf) { + Ok(protocol_msg) => { + if matches!(protocol_msg.message, EthMessage::Status(_)) { + return Err(EthSnapStreamError::StatusNotInHandshake); + } + Ok(EthSnapMessage::Eth(protocol_msg.message)) + } + Err(err) => { + Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string())) + } + } + } else if message_id > EthMessageID::max() && + message_id <= EthMessageID::max() + 1 + SnapMessageId::TrieNodes as u8 + { + // Checks for multiplexed snap message IDs : + // - message_id > EthMessageID::max() : ensures it's not an eth message + // - message_id <= EthMessageID::max() + 1 + snap_max : ensures it's within valid snap + // range + // Message IDs are assigned lexicographically during capability negotiation + // So real_snap_id = multiplexed_id - num_eth_messages + let adjusted_message_id = message_id - (EthMessageID::max() + 1); + let mut buf = &bytes[1..]; + + match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) { + Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)), + Err(err) => Err(EthSnapStreamError::Rlp(err)), + } + } else { + Err(EthSnapStreamError::UnknownMessageId(message_id)) + } + } + + /// Encode an eth message + fn encode_eth_message(&self, item: EthMessage) -> Result { + if matches!(item, EthMessage::Status(_)) { + return Err(EthSnapStreamError::StatusNotInHandshake); + } + + let protocol_msg = ProtocolMessage::from(item); + let mut buf = Vec::new(); + protocol_msg.encode(&mut buf); + Ok(Bytes::from(buf)) + } + + /// Encode a snap protocol message, adjusting the message ID to follow eth message IDs + /// for proper multiplexing. + fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes { + let encoded = message.encode(); + + let message_id = encoded[0]; + let adjusted_id = message_id + EthMessageID::max() + 1; + + let mut adjusted = Vec::with_capacity(encoded.len()); + adjusted.push(adjusted_id); + adjusted.extend_from_slice(&encoded[1..]); + + Bytes::from(adjusted) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{EthMessage, SnapProtocolMessage}; + use alloy_eips::BlockHashOrNumber; + use alloy_primitives::B256; + use alloy_rlp::Encodable; + use reth_eth_wire_types::{ + message::RequestPair, GetAccountRangeMessage, GetBlockHeaders, HeadersDirection, + }; + + // Helper to create eth message and its bytes + fn create_eth_message() -> (EthMessage, BytesMut) { + let eth_msg = EthMessage::::GetBlockHeaders(RequestPair { + request_id: 1, + message: GetBlockHeaders { + start_block: BlockHashOrNumber::Number(1), + limit: 10, + skip: 0, + direction: HeadersDirection::Rising, + }, + }); + + let protocol_msg = ProtocolMessage::from(eth_msg.clone()); + let mut buf = Vec::new(); + protocol_msg.encode(&mut buf); + + (eth_msg, BytesMut::from(&buf[..])) + } + + // Helper to create snap message and its bytes + fn create_snap_message() -> (SnapProtocolMessage, BytesMut) { + let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage { + request_id: 1, + root_hash: B256::default(), + starting_hash: B256::default(), + limit_hash: B256::default(), + response_bytes: 1000, + }); + + let inner = EthSnapStreamInner::::new(EthVersion::Eth67); + let encoded = inner.encode_snap_message(snap_msg.clone()); + + (snap_msg, BytesMut::from(&encoded[..])) + } + + #[test] + fn test_eth_message_roundtrip() { + let inner = EthSnapStreamInner::::new(EthVersion::Eth67); + let (eth_msg, eth_bytes) = create_eth_message(); + + // Verify encoding + let encoded_result = inner.encode_eth_message(eth_msg.clone()); + assert!(encoded_result.is_ok()); + + // Verify decoding + let decoded_result = inner.decode_message(eth_bytes.clone()); + assert!(matches!(decoded_result, Ok(EthSnapMessage::Eth(_)))); + + // round trip + if let Ok(EthSnapMessage::Eth(decoded_msg)) = inner.decode_message(eth_bytes) { + assert_eq!(decoded_msg, eth_msg); + + let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap(); + let re_encoded_bytes = BytesMut::from(&re_encoded[..]); + let re_decoded = inner.decode_message(re_encoded_bytes); + + assert!(matches!(re_decoded, Ok(EthSnapMessage::Eth(_)))); + if let Ok(EthSnapMessage::Eth(final_msg)) = re_decoded { + assert_eq!(final_msg, decoded_msg); + } + } + } + + #[test] + fn test_snap_protocol() { + let inner = EthSnapStreamInner::::new(EthVersion::Eth67); + let (snap_msg, snap_bytes) = create_snap_message(); + + // Verify encoding + let encoded_bytes = inner.encode_snap_message(snap_msg.clone()); + assert!(!encoded_bytes.is_empty()); + + // Verify decoding + let decoded_result = inner.decode_message(snap_bytes.clone()); + assert!(matches!(decoded_result, Ok(EthSnapMessage::Snap(_)))); + + // round trip + if let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) { + assert_eq!(decoded_msg, snap_msg); + + // re-encode message + let encoded = inner.encode_snap_message(decoded_msg.clone()); + + let re_encoded_bytes = BytesMut::from(&encoded[..]); + + // decode with properly adjusted ID + let re_decoded = inner.decode_message(re_encoded_bytes); + + assert!(matches!(re_decoded, Ok(EthSnapMessage::Snap(_)))); + if let Ok(EthSnapMessage::Snap(final_msg)) = re_decoded { + assert_eq!(final_msg, decoded_msg); + } + } + } + + #[test] + fn test_message_id_boundaries() { + let inner = EthSnapStreamInner::::new(EthVersion::Eth67); + + // Create a bytes buffer with eth message ID at the max boundary with minimal content + let eth_max_id = EthMessageID::max(); + let mut eth_boundary_bytes = BytesMut::new(); + eth_boundary_bytes.extend_from_slice(&[eth_max_id]); + eth_boundary_bytes.extend_from_slice(&[0, 0]); + + // This should be decoded as eth message + let eth_boundary_result = inner.decode_message(eth_boundary_bytes); + assert!( + eth_boundary_result.is_err() || + matches!(eth_boundary_result, Ok(EthSnapMessage::Eth(_))) + ); + + // Create a bytes buffer with message ID just above eth max, it should be snap min + let snap_min_id = eth_max_id + 1; + let mut snap_boundary_bytes = BytesMut::new(); + snap_boundary_bytes.extend_from_slice(&[snap_min_id]); + snap_boundary_bytes.extend_from_slice(&[0, 0]); + + // Not a valid snap message yet, only snap id --> error + let snap_boundary_result = inner.decode_message(snap_boundary_bytes); + assert!(snap_boundary_result.is_err()); + } +} diff --git a/crates/net/eth-wire/src/lib.rs b/crates/net/eth-wire/src/lib.rs index e8603a90ff..a2cb35ae7f 100644 --- a/crates/net/eth-wire/src/lib.rs +++ b/crates/net/eth-wire/src/lib.rs @@ -16,6 +16,7 @@ pub mod capability; mod disconnect; pub mod errors; +pub mod eth_snap_stream; mod ethstream; mod hello; pub mod multiplex;