fix(eth-wire): send p2p handshake disconnects (#1047)

This commit is contained in:
Dan Cline
2023-01-26 09:01:53 -05:00
committed by GitHub
parent bd540c70ce
commit 1d5cce1092

View File

@@ -67,6 +67,25 @@ impl<S> UnauthedP2PStream<S> {
}
}
impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}
impl<S> UnauthedP2PStream<S>
where
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
@@ -128,20 +147,29 @@ where
"validating incoming p2p hello from peer"
);
// TODO: explicitly document that we only support v5.
if their_hello.protocol_version != ProtocolVersion::V5 {
// TODO: do we want to send a `Disconnect` message here?
if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
// send a disconnect message notifying the peer of the protocol version mismatch
self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
return Err(P2PStreamError::MismatchedProtocolVersion {
expected: ProtocolVersion::V5 as u8,
expected: hello.protocol_version as u8,
got: their_hello.protocol_version as u8,
})
}
// determine shared capabilities (currently returns only one capability)
let capability =
set_capability_offsets(hello.capabilities, their_hello.capabilities.clone())?;
let capability_res =
set_capability_offsets(hello.capabilities, their_hello.capabilities.clone());
let stream = P2PStream::new(self.inner, capability);
let shared_capability = match capability_res {
Err(err) => {
// we don't share any capabilities, send a disconnect message
self.send_disconnect(DisconnectReason::UselessPeer).await?;
Err(err)
}
Ok(cap) => Ok(cap),
}?;
let stream = P2PStream::new(self.inner, shared_capability);
Ok((stream, their_hello))
}
@@ -542,8 +570,6 @@ pub fn set_capability_offsets(
// disconnect if we don't share any capabilities
if shared_capabilities.is_empty() {
// TODO: send a disconnect message? if we want to do this, this will need to be a member
// method of `UnauthedP2PStream` so it can access the inner stream
return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
}
@@ -857,6 +883,62 @@ mod tests {
handle.await.unwrap();
}
#[tokio::test]
async fn test_handshake_disconnect() {
// create a p2p stream and server, then confirm that the two are authed
// create tcpstream
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let (server_hello, _) = eth_hello();
let unauthed_stream = UnauthedP2PStream::new(stream);
match unauthed_stream.handshake(server_hello.clone()).await {
Ok((_, hello)) => panic!(
"expected handshake to fail, instead got a successful Hello: {:?}",
hello
),
Err(P2PStreamError::MismatchedProtocolVersion { expected, got }) => {
assert_eq!(expected, server_hello.protocol_version as u8);
assert_ne!(expected, got);
}
Err(other_err) => {
panic!("expected mismatched protocol version error, got {:?}", other_err)
}
}
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = crate::PassthroughCodec::default().framed(outgoing);
let (mut client_hello, _) = eth_hello();
// modify the hello to include an incompatible p2p protocol version
client_hello.protocol_version = ProtocolVersion::V4;
let unauthed_stream = UnauthedP2PStream::new(sink);
match unauthed_stream.handshake(client_hello.clone()).await {
Ok((_, hello)) => {
panic!("expected handshake to fail, instead got a successful Hello: {:?}", hello)
}
Err(P2PStreamError::MismatchedProtocolVersion { expected, got }) => {
assert_eq!(expected, client_hello.protocol_version as u8);
assert_ne!(expected, got);
}
Err(other_err) => {
panic!("expected mismatched protocol version error, got {:?}", other_err)
}
}
// make sure the server receives the message and asserts before ending the test
handle.await.unwrap();
}
#[test]
fn test_peer_lower_capability_version() {
let local_capabilities: Vec<Capability> =