mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-02-19 03:04:27 -05:00
feat(eth-wire): introduce ProtocolMessage::decode_status for handshake (#21797)
Co-authored-by: Amp <amp@ampcode.com> Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
committed by
GitHub
parent
cb999b2a2d
commit
dbac7e1e4a
@@ -34,6 +34,9 @@ pub enum MessageError {
|
||||
/// Flags an unrecognized message ID for a given protocol version.
|
||||
#[error("message id {1:?} is invalid for version {0:?}")]
|
||||
Invalid(EthVersion, EthMessageID),
|
||||
/// Expected a Status message but received a different message type.
|
||||
#[error("expected status message but received {0:?}")]
|
||||
ExpectedStatusMessage(EthMessageID),
|
||||
/// Thrown when rlp decoding a message failed.
|
||||
#[error("RLP error: {0}")]
|
||||
RlpError(#[from] alloy_rlp::Error),
|
||||
@@ -57,6 +60,29 @@ pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
|
||||
}
|
||||
|
||||
impl<N: NetworkPrimitives> ProtocolMessage<N> {
|
||||
/// Decode only a Status message from RLP bytes.
|
||||
///
|
||||
/// This is used during the eth handshake where only a Status message is a valid response.
|
||||
/// Returns an error if the message is not a Status message.
|
||||
pub fn decode_status(
|
||||
version: EthVersion,
|
||||
buf: &mut &[u8],
|
||||
) -> Result<StatusMessage, MessageError> {
|
||||
let message_type = EthMessageID::decode(buf)?;
|
||||
|
||||
if message_type != EthMessageID::Status {
|
||||
return Err(MessageError::ExpectedStatusMessage(message_type))
|
||||
}
|
||||
|
||||
let status = if version < EthVersion::Eth69 {
|
||||
StatusMessage::Legacy(Status::decode(buf)?)
|
||||
} else {
|
||||
StatusMessage::Eth69(StatusEth69::decode(buf)?)
|
||||
};
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
/// Create a new `ProtocolMessage` from a message type and message rlp bytes.
|
||||
///
|
||||
/// This will enforce decoding according to the given [`EthVersion`] of the connection.
|
||||
@@ -881,4 +907,54 @@ mod tests {
|
||||
|
||||
assert_eq!(protocol_message, decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_status_success() {
|
||||
use crate::{Status, StatusMessage};
|
||||
use alloy_hardforks::{ForkHash, ForkId};
|
||||
use alloy_primitives::{B256, U256};
|
||||
|
||||
let status = Status {
|
||||
version: EthVersion::Eth68,
|
||||
chain: alloy_chains::Chain::mainnet(),
|
||||
total_difficulty: U256::from(100u64),
|
||||
blockhash: B256::random(),
|
||||
genesis: B256::random(),
|
||||
forkid: ForkId { hash: ForkHash([0xb7, 0x15, 0x07, 0x7d]), next: 0 },
|
||||
};
|
||||
|
||||
let protocol_message = ProtocolMessage::<EthNetworkPrimitives>::from(EthMessage::Status(
|
||||
StatusMessage::Legacy(status),
|
||||
));
|
||||
let encoded = encode(protocol_message);
|
||||
|
||||
let decoded = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
|
||||
EthVersion::Eth68,
|
||||
&mut &encoded[..],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(decoded, StatusMessage::Legacy(s) if s == status));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_status_rejects_non_status() {
|
||||
let msg = EthMessage::<EthNetworkPrimitives>::GetBlockBodies(RequestPair {
|
||||
request_id: 1,
|
||||
message: crate::GetBlockBodies::default(),
|
||||
});
|
||||
let protocol_message =
|
||||
ProtocolMessage { message_type: EthMessageID::GetBlockBodies, message: msg };
|
||||
let encoded = encode(protocol_message);
|
||||
|
||||
let result = ProtocolMessage::<EthNetworkPrimitives>::decode_status(
|
||||
EthVersion::Eth68,
|
||||
&mut &encoded[..],
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(MessageError::ExpectedStatusMessage(EthMessageID::GetBlockBodies))
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,117 +119,100 @@ where
|
||||
}
|
||||
|
||||
let version = status.version();
|
||||
let msg = match ProtocolMessage::<EthNetworkPrimitives>::decode_message(
|
||||
let their_status_message = match ProtocolMessage::<EthNetworkPrimitives>::decode_status(
|
||||
version,
|
||||
&mut their_msg.as_ref(),
|
||||
) {
|
||||
Ok(m) => m,
|
||||
Ok(status) => status,
|
||||
Err(err) => {
|
||||
debug!("decode error in eth handshake: msg={their_msg:x}");
|
||||
unauth
|
||||
.disconnect(DisconnectReason::DisconnectRequested)
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthStreamError::InvalidMessage(err));
|
||||
}
|
||||
};
|
||||
|
||||
// Validate peer response
|
||||
match msg.message {
|
||||
EthMessage::Status(their_status_message) => {
|
||||
trace!("Validating incoming ETH status from peer");
|
||||
trace!("Validating incoming ETH status from peer");
|
||||
|
||||
if status.genesis() != their_status_message.genesis() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedGenesis(
|
||||
GotExpected {
|
||||
expected: status.genesis(),
|
||||
got: their_status_message.genesis(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
if status.genesis() != their_status_message.genesis() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedGenesis(
|
||||
GotExpected { expected: status.genesis(), got: their_status_message.genesis() }
|
||||
.into(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
if status.version() != their_status_message.version() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
|
||||
got: their_status_message.version(),
|
||||
expected: status.version(),
|
||||
})
|
||||
.into());
|
||||
}
|
||||
if status.version() != their_status_message.version() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
|
||||
got: their_status_message.version(),
|
||||
expected: status.version(),
|
||||
})
|
||||
.into());
|
||||
}
|
||||
|
||||
if *status.chain() != *their_status_message.chain() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedChain(GotExpected {
|
||||
got: *their_status_message.chain(),
|
||||
expected: *status.chain(),
|
||||
})
|
||||
.into());
|
||||
}
|
||||
if *status.chain() != *their_status_message.chain() {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::MismatchedChain(GotExpected {
|
||||
got: *their_status_message.chain(),
|
||||
expected: *status.chain(),
|
||||
})
|
||||
.into());
|
||||
}
|
||||
|
||||
// Ensure peer's total difficulty is reasonable
|
||||
if let StatusMessage::Legacy(s) = their_status_message &&
|
||||
s.total_difficulty.bit_len() > 160
|
||||
{
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
|
||||
got: s.total_difficulty.bit_len(),
|
||||
maximum: 160,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
// Fork validation
|
||||
if let Err(err) = fork_filter
|
||||
.validate(their_status_message.forkid())
|
||||
.map_err(EthHandshakeError::InvalidFork)
|
||||
{
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(err.into());
|
||||
}
|
||||
|
||||
if let StatusMessage::Eth69(s) = their_status_message {
|
||||
if s.earliest > s.latest {
|
||||
return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock {
|
||||
got: s.earliest,
|
||||
latest: s.latest,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if s.blockhash.is_zero() {
|
||||
return Err(EthHandshakeError::BlockhashZero.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UnifiedStatus::from_message(their_status_message))
|
||||
// Ensure peer's total difficulty is reasonable
|
||||
if let StatusMessage::Legacy(s) = &their_status_message &&
|
||||
s.total_difficulty.bit_len() > 160
|
||||
{
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
|
||||
got: s.total_difficulty.bit_len(),
|
||||
maximum: 160,
|
||||
}
|
||||
_ => {
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
Err(EthStreamError::EthHandshakeError(
|
||||
EthHandshakeError::NonStatusMessageInHandshake,
|
||||
))
|
||||
.into());
|
||||
}
|
||||
|
||||
// Fork validation
|
||||
if let Err(err) = fork_filter
|
||||
.validate(their_status_message.forkid())
|
||||
.map_err(EthHandshakeError::InvalidFork)
|
||||
{
|
||||
unauth
|
||||
.disconnect(DisconnectReason::ProtocolBreach)
|
||||
.await
|
||||
.map_err(EthStreamError::from)?;
|
||||
return Err(err.into());
|
||||
}
|
||||
|
||||
if let StatusMessage::Eth69(s) = &their_status_message {
|
||||
if s.earliest > s.latest {
|
||||
return Err(EthHandshakeError::EarliestBlockGreaterThanLatestBlock {
|
||||
got: s.earliest,
|
||||
latest: s.latest,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if s.blockhash.is_zero() {
|
||||
return Err(EthHandshakeError::BlockhashZero.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UnifiedStatus::from_message(their_status_message))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user