From dbac7e1e4a7789243b3347bd238984018ec5518e Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Thu, 5 Feb 2026 18:20:04 -0800 Subject: [PATCH] feat(eth-wire): introduce ProtocolMessage::decode_status for handshake (#21797) Co-authored-by: Amp Co-authored-by: Matthias Seitz --- crates/net/eth-wire-types/src/message.rs | 76 ++++++++++ crates/net/eth-wire/src/handshake.rs | 169 ++++++++++------------- 2 files changed, 152 insertions(+), 93 deletions(-) diff --git a/crates/net/eth-wire-types/src/message.rs b/crates/net/eth-wire-types/src/message.rs index 5d29d960bf..ee36986db8 100644 --- a/crates/net/eth-wire-types/src/message.rs +++ b/crates/net/eth-wire-types/src/message.rs @@ -34,6 +34,9 @@ pub enum MessageError { /// Flags an unrecognized message ID for a given protocol version. #[error("message id {1:?} is invalid for version {0:?}")] Invalid(EthVersion, EthMessageID), + /// Expected a Status message but received a different message type. + #[error("expected status message but received {0:?}")] + ExpectedStatusMessage(EthMessageID), /// Thrown when rlp decoding a message failed. #[error("RLP error: {0}")] RlpError(#[from] alloy_rlp::Error), @@ -57,6 +60,29 @@ pub struct ProtocolMessage { } impl ProtocolMessage { + /// Decode only a Status message from RLP bytes. + /// + /// This is used during the eth handshake where only a Status message is a valid response. + /// Returns an error if the message is not a Status message. + pub fn decode_status( + version: EthVersion, + buf: &mut &[u8], + ) -> Result { + let message_type = EthMessageID::decode(buf)?; + + if message_type != EthMessageID::Status { + return Err(MessageError::ExpectedStatusMessage(message_type)) + } + + let status = if version < EthVersion::Eth69 { + StatusMessage::Legacy(Status::decode(buf)?) + } else { + StatusMessage::Eth69(StatusEth69::decode(buf)?) + }; + + Ok(status) + } + /// Create a new `ProtocolMessage` from a message type and message rlp bytes. /// /// This will enforce decoding according to the given [`EthVersion`] of the connection. @@ -881,4 +907,54 @@ mod tests { assert_eq!(protocol_message, decoded); } + + #[test] + fn decode_status_success() { + use crate::{Status, StatusMessage}; + use alloy_hardforks::{ForkHash, ForkId}; + use alloy_primitives::{B256, U256}; + + let status = Status { + version: EthVersion::Eth68, + chain: alloy_chains::Chain::mainnet(), + total_difficulty: U256::from(100u64), + blockhash: B256::random(), + genesis: B256::random(), + forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 }, + }; + + let protocol_message = ProtocolMessage::::from(EthMessage::Status( + StatusMessage::Legacy(status), + )); + let encoded = encode(protocol_message); + + let decoded = ProtocolMessage::::decode_status( + EthVersion::Eth68, + &mut &encoded[..], + ) + .unwrap(); + + assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status)); + } + + #[test] + fn decode_status_rejects_non_status() { + let msg = EthMessage::::GetBlockBodies(RequestPair { + request_id: 1, + message: crate::GetBlockBodies::default(), + }); + let protocol_message = + ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg }; + let encoded = encode(protocol_message); + + let result = ProtocolMessage::::decode_status( + EthVersion::Eth68, + &mut &encoded[..], + ); + + assert!(matches!( + result, + Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies)) + )); + } } diff --git a/crates/net/eth-wire/src/handshake.rs b/crates/net/eth-wire/src/handshake.rs index f604f1fca1..cb6d55aae2 100644 --- a/crates/net/eth-wire/src/handshake.rs +++ b/crates/net/eth-wire/src/handshake.rs @@ -119,117 +119,100 @@ where } let version = status.version(); - let msg = match ProtocolMessage::::decode_message( + let their_status_message = match ProtocolMessage::::decode_status( version, &mut their_msg.as_ref(), ) { - Ok(m) => m, + Ok(status) => status, Err(err) => { debug!("decode error in eth handshake: msg={their_msg:x}"); unauth - .disconnect(DisconnectReason::DisconnectRequested) + .disconnect(DisconnectReason::ProtocolBreach) .await .map_err(EthStreamError::from)?; return Err(EthStreamError::InvalidMessage(err)); } }; - // Validate peer response - match msg.message { - EthMessage::Status(their_status_message) => { - trace!("Validating incoming ETH status from peer"); + trace!("Validating incoming ETH status from peer"); - if status.genesis() != their_status_message.genesis() { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - return Err(EthHandshakeError::MismatchedGenesis( - GotExpected { - expected: status.genesis(), - got: their_status_message.genesis(), - } - .into(), - ) - .into()); - } + if status.genesis() != their_status_message.genesis() { + unauth + .disconnect(DisconnectReason::ProtocolBreach) + .await + .map_err(EthStreamError::from)?; + return Err(EthHandshakeError::MismatchedGenesis( + GotExpected { expected: status.genesis(), got: their_status_message.genesis() } + .into(), + ) + .into()); + } - if status.version() != their_status_message.version() { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected { - got: their_status_message.version(), - expected: status.version(), - }) - .into()); - } + if status.version() != their_status_message.version() { + unauth + .disconnect(DisconnectReason::ProtocolBreach) + .await + .map_err(EthStreamError::from)?; + return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected { + got: their_status_message.version(), + expected: status.version(), + }) + .into()); + } - if *status.chain() != *their_status_message.chain() { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - return Err(EthHandshakeError::MismatchedChain(GotExpected { - got: *their_status_message.chain(), - expected: *status.chain(), - }) - .into()); - } + if *status.chain() != *their_status_message.chain() { + unauth + .disconnect(DisconnectReason::ProtocolBreach) + .await + .map_err(EthStreamError::from)?; + return Err(EthHandshakeError::MismatchedChain(GotExpected { + got: *their_status_message.chain(), + expected: *status.chain(), + }) + .into()); + } - // Ensure peer's total difficulty is reasonable - if let StatusMessage::Legacy(s) = their_status_message && - s.total_difficulty.bit_len() > 160 - { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge { - got: s.total_difficulty.bit_len(), - maximum: 160, - } - .into()); - } - - // Fork validation - if let Err(err) = fork_filter - .validate(their_status_message.forkid()) - .map_err(EthHandshakeError::InvalidFork) - { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - return Err(err.into()); - } - - if let StatusMessage::Eth69(s) = their_status_message { - if s.earliest > s.latest { - return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock { - got: s.earliest, - latest: s.latest, - } - .into()); - } - - if s.blockhash.is_zero() { - return Err(EthHandshakeError::BlockhashZero.into()); - } - } - - Ok(UnifiedStatus::from_message(their_status_message)) + // Ensure peer's total difficulty is reasonable + if let StatusMessage::Legacy(s) = &their_status_message && + s.total_difficulty.bit_len() > 160 + { + unauth + .disconnect(DisconnectReason::ProtocolBreach) + .await + .map_err(EthStreamError::from)?; + return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge { + got: s.total_difficulty.bit_len(), + maximum: 160, } - _ => { - unauth - .disconnect(DisconnectReason::ProtocolBreach) - .await - .map_err(EthStreamError::from)?; - Err(EthStreamError::EthHandshakeError( - EthHandshakeError::NonStatusMessageInHandshake, - )) + .into()); + } + + // Fork validation + if let Err(err) = fork_filter + .validate(their_status_message.forkid()) + .map_err(EthHandshakeError::InvalidFork) + { + unauth + .disconnect(DisconnectReason::ProtocolBreach) + .await + .map_err(EthStreamError::from)?; + return Err(err.into()); + } + + if let StatusMessage::Eth69(s) = &their_status_message { + if s.earliest > s.latest { + return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock { + got: s.earliest, + latest: s.latest, + } + .into()); + } + + if s.blockhash.is_zero() { + return Err(EthHandshakeError::BlockhashZero.into()); } } + + Ok(UnifiedStatus::from_message(their_status_message)) } }