feat(eth-wire): introduce ProtocolMessage::decode_status for handshake (#21797)

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
Georgios Konstantopoulos
2026-02-05 18:20:04 -08:00
committed by GitHub
parent cb999b2a2d
commit dbac7e1e4a
2 changed files with 152 additions and 93 deletions

View File

@@ -34,6 +34,9 @@ pub enum MessageError {
/// Flags an unrecognized message ID for a given protocol version.
#[error("message id {1:?} is invalid for version {0:?}")]
Invalid(EthVersion, EthMessageID),
/// Expected a Status message but received a different message type.
#[error("expected status message but received {0:?}")]
ExpectedStatusMessage(EthMessageID),
/// Thrown when rlp decoding a message failed.
#[error("RLP error: {0}")]
RlpError(#[from] alloy_rlp::Error),
@@ -57,6 +60,29 @@ pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
}
impl<N: NetworkPrimitives> ProtocolMessage<N> {
/// Decode only a Status message from RLP bytes.
///
/// This is used during the eth handshake where only a Status message is a valid response.
/// Returns an error if the message is not a Status message.
pub fn decode_status(
version: EthVersion,
buf: &mut &[u8],
) -> Result<StatusMessage, MessageError> {
let message_type = EthMessageID::decode(buf)?;
if message_type != EthMessageID::Status {
return Err(MessageError::ExpectedStatusMessage(message_type))
}
let status = if version < EthVersion::Eth69 {
StatusMessage::Legacy(Status::decode(buf)?)
} else {
StatusMessage::Eth69(StatusEth69::decode(buf)?)
};
Ok(status)
}
/// Create a new `ProtocolMessage` from a message type and message rlp bytes.
///
/// This will enforce decoding according to the given [`EthVersion`] of the connection.
@@ -881,4 +907,54 @@ mod tests {
assert_eq!(protocol_message, decoded);
}
#[test]
fn decode_status_success() {
use crate::{Status, StatusMessage};
use alloy_hardforks::{ForkHash, ForkId};
use alloy_primitives::{B256, U256};
let status = Status {
version: EthVersion::Eth68,
chain: alloy_chains::Chain::mainnet(),
total_difficulty: U256::from(100u64),
blockhash: B256::random(),
genesis: B256::random(),
forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 },
};
let protocol_message = ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::Status(
StatusMessage::Legacy(status),
));
let encoded = encode(protocol_message);
let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
EthVersion::Eth68,
&mut &encoded[..],
)
.unwrap();
assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status));
}
#[test]
fn decode_status_rejects_non_status() {
let msg = EthMessage::<EthNetworkPrimitives>::GetBlockBodies(RequestPair {
request_id: 1,
message: crate::GetBlockBodies::default(),
});
let protocol_message =
ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg };
let encoded = encode(protocol_message);
let result = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
EthVersion::Eth68,
&mut &encoded[..],
);
assert!(matches!(
result,
Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies))
));
}
}

View File

@@ -119,117 +119,100 @@ where
}
let version = status.version();
let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
let their_status_message = match ProtocolMessage::<EthNetworkPrimitives>::decode_status(
version,
&mut their_msg.as_ref(),
) {
Ok(m) => m,
Ok(status) => status,
Err(err) => {
debug!("decode error in eth handshake: msg={their_msg:x}");
unauth
.disconnect(DisconnectReason::DisconnectRequested)
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthStreamError::InvalidMessage(err));
}
};
// Validate peer response
match msg.message {
EthMessage::Status(their_status_message) => {
trace!("Validating incoming ETH status from peer");
trace!("Validating incoming ETH status from peer");
if status.genesis() != their_status_message.genesis() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedGenesis(
GotExpected {
expected: status.genesis(),
got: their_status_message.genesis(),
}
.into(),
)
.into());
}
if status.genesis() != their_status_message.genesis() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedGenesis(
GotExpected { expected: status.genesis(), got: their_status_message.genesis() }
.into(),
)
.into());
}
if status.version() != their_status_message.version() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
got: their_status_message.version(),
expected: status.version(),
})
.into());
}
if status.version() != their_status_message.version() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
got: their_status_message.version(),
expected: status.version(),
})
.into());
}
if *status.chain() != *their_status_message.chain() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedChain(GotExpected {
got: *their_status_message.chain(),
expected: *status.chain(),
})
.into());
}
if *status.chain() != *their_status_message.chain() {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::MismatchedChain(GotExpected {
got: *their_status_message.chain(),
expected: *status.chain(),
})
.into());
}
// Ensure peer's total difficulty is reasonable
if let StatusMessage::Legacy(s) = their_status_message &&
s.total_difficulty.bit_len() > 160
{
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
got: s.total_difficulty.bit_len(),
maximum: 160,
}
.into());
}
// Fork validation
if let Err(err) = fork_filter
.validate(their_status_message.forkid())
.map_err(EthHandshakeError::InvalidFork)
{
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(err.into());
}
if let StatusMessage::Eth69(s) = their_status_message {
if s.earliest > s.latest {
return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock {
got: s.earliest,
latest: s.latest,
}
.into());
}
if s.blockhash.is_zero() {
return Err(EthHandshakeError::BlockhashZero.into());
}
}
Ok(UnifiedStatus::from_message(their_status_message))
// Ensure peer's total difficulty is reasonable
if let StatusMessage::Legacy(s) = &their_status_message &&
s.total_difficulty.bit_len() > 160
{
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
got: s.total_difficulty.bit_len(),
maximum: 160,
}
_ => {
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
))
.into());
}
// Fork validation
if let Err(err) = fork_filter
.validate(their_status_message.forkid())
.map_err(EthHandshakeError::InvalidFork)
{
unauth
.disconnect(DisconnectReason::ProtocolBreach)
.await
.map_err(EthStreamError::from)?;
return Err(err.into());
}
if let StatusMessage::Eth69(s) = &their_status_message {
if s.earliest > s.latest {
return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock {
got: s.earliest,
latest: s.latest,
}
.into());
}
if s.blockhash.is_zero() {
return Err(EthHandshakeError::BlockhashZero.into());
}
}
Ok(UnifiedStatus::from_message(their_status_message))
}
}