diff --git a/Cargo.lock b/Cargo.lock index a040a58b12..5066715dfb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,11 +123,12 @@ checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" [[package]] name = "async-lock" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e97a171d191782fba31bb902b14ad94e24a68145032b7eedf871ab0bc0d077b6" +checksum = "c8101efe8695a6c17e02911402145357e718ac92d3ff88ae8419e84b1707b685" dependencies = [ "event-listener", + "futures-lite", ] [[package]] @@ -390,15 +391,16 @@ dependencies = [ [[package]] name = "cargo_metadata" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3abb7553d5b9b8421c6de7cb02606ff15e0c6eea7d8eadd75ef013fd636bec36" +checksum = "406c859255d568f4f742b3146d51851f3bfd49f734a2c289d9107c4395ee0062" dependencies = [ "camino", "cargo-platform", "semver 1.0.14", "serde", "serde_json", + "thiserror", ] [[package]] @@ -497,9 +499,9 @@ dependencies = [ [[package]] name = "clap" -version = "3.2.22" +version = "3.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" +checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5" dependencies = [ "bitflags", "clap_lex 0.2.4", @@ -638,7 +640,7 @@ dependencies = [ "atty", "cast", "ciborium", - "clap 3.2.22", + "clap 3.2.23", "criterion-plot", "itertools 0.10.5", "lazy_static", @@ -790,12 +792,12 @@ dependencies = [ [[package]] name = "darling" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4529658bdda7fd6769b8614be250cdcfc3aeb0ee72fe66f9e41e5e5eb73eac02" +checksum = "b0dd3cd20dc6b5a876612a6e5accfe7f3dd883db6d07acfbf14c128f61550dfa" dependencies = [ - "darling_core 0.14.1", - "darling_macro 0.14.1", + "darling_core 0.14.2", + "darling_macro 0.14.2", ] [[package]] @@ -814,9 +816,9 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "649c91bc01e8b1eac09fb91e8dbc7d517684ca6be8ebc75bb9cafc894f9fdb6f" +checksum = "a784d2ccaf7c98501746bf0be29b2022ba41fd62a2e622af997a03e9f972859f" dependencies = [ "fnv", "ident_case", @@ -839,11 +841,11 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddfc69c5bfcbd2fc09a0f38451d2daf0e372e367986a83906d1b0dbc88134fb5" +checksum = "7618812407e9402654622dd402b0a89dff9ba93badd6540781526117b92aab7e" dependencies = [ - "darling_core 0.14.1", + "darling_core 0.14.2", "quote", "syn", ] @@ -959,13 +961,13 @@ dependencies = [ [[package]] name = "discv5" version = "0.1.0" -source = "git+https://github.com/sigp/discv5#7d8c1ce072de384a472beebaf03d36fb463b9f7a" +source = "git+https://github.com/sigp/discv5#517eb3f0c5e5b347d8fe6c2973e1698f89e83524" dependencies = [ "aes 0.7.5", "aes-gcm", "arrayvec", "delay_map", - "enr 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", + "enr 0.6.2", "fnv", "futures", "hashlink", @@ -1111,8 +1113,8 @@ dependencies = [ [[package]] name = "enr" -version = "0.6.2" -source = "git+https://github.com/sigp/enr#fba51d4473f1b6fcc66161cd593352b70995e702" +version = "0.7.0" +source = "git+https://github.com/sigp/enr#f27b94eafad20dc04d47c97a0d75d32f2c5e72e9" dependencies = [ "base64", "bs58", @@ -1204,18 +1206,18 @@ dependencies = [ [[package]] name = "ethers-core" -version = "0.17.0" -source = "git+https://github.com/gakonst/ethers-rs#a9dd53da810d8eff82aa77e0f9297b4a453028e6" +version = "1.0.0" +source = "git+https://github.com/gakonst/ethers-rs#def99318bb0d65257ea68c93fcc269cdf90d0284" dependencies = [ "arrayvec", "bytes", "chrono", "elliptic-curve", "ethabi", - "fastrlp", "generic-array", "hex", "k256", + "open-fastrlp", "rand 0.8.5", "rlp", "rlp-derive", @@ -1268,20 +1270,6 @@ dependencies = [ "arrayvec", "auto_impl", "bytes", - "ethereum-types", - "fastrlp-derive", -] - -[[package]] -name = "fastrlp-derive" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9e9158c1d8f0a7a716c9191562eaabba70268ba64972ef4871ce8d66fd08872" -dependencies = [ - "bytes", - "proc-macro2", - "quote", - "syn", ] [[package]] @@ -1375,6 +1363,21 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" +[[package]] +name = "futures-lite" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694489acd39452c77daa48516b894c153f192c3578d5a839b62c58099fcbf48" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-macro" version = "0.3.25" @@ -1545,9 +1548,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be" +checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" dependencies = [ "bytes", "fnv", @@ -2135,9 +2138,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.135" +version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c" +checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" [[package]] name = "libloading" @@ -2258,14 +2261,14 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "mio" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" +checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.36.1", + "windows-sys 0.42.0", ] [[package]] @@ -2452,6 +2455,31 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "open-fastrlp" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "131de184f045153e72c537ef4f1d57babddf2a897ca19e67bdff697aebba7f3d" +dependencies = [ + "arrayvec", + "auto_impl", + "bytes", + "ethereum-types", + "open-fastrlp-derive", +] + +[[package]] +name = "open-fastrlp-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "003b2be5c6c53c1cfeb0a238b8a1c3915cd410feb684457a36c10038f764bb1c" +dependencies = [ + "bytes", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "openssl-probe" version = "0.1.5" @@ -2507,6 +2535,12 @@ dependencies = [ "syn", ] +[[package]] +name = "parking" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72" + [[package]] name = "parking_lot" version = "0.11.2" @@ -3115,11 +3149,13 @@ dependencies = [ "hex-literal", "maplit", "pin-project", + "pin-utils", "rand 0.8.5", "reth-ecies", "reth-primitives", "reth-rlp", "secp256k1", + "snap", "thiserror", "tokio", "tokio-stream", @@ -3218,7 +3254,7 @@ dependencies = [ name = "reth-p2p" version = "0.1.0" dependencies = [ - "enr 0.6.2 (git+https://github.com/sigp/enr)", + "enr 0.7.0", "secp256k1", "serde", "serde_derive", @@ -3636,9 +3672,9 @@ dependencies = [ [[package]] name = "secp256k1" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7649a0b3ffb32636e60c7ce0d70511eda9c52c658cd0634e194d5a19943aeff" +checksum = "ff55dc09d460954e9ef2fa8a7ced735a964be9981fd50e870b2b3b0705e14964" dependencies = [ "rand 0.8.5", "secp256k1-sys", @@ -3894,6 +3930,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7475118a28b7e3a2e157ce0131ba8c5526ea96e90ee601d9f6bb2e286a35ab44" +[[package]] +name = "snap" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" + [[package]] name = "socket2" version = "0.4.7" @@ -4082,9 +4124,9 @@ dependencies = [ [[package]] name = "test-fuzz" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "125df852011c4f8f31df5620f4aea38ecddb5dfb4d9bc569b30485b15ffc3d4e" +checksum = "cd4a3a7f00909d5a1d1f83b86b65d91e4c94f80b0c2d0ae37e2ef44da7b7a0a0" dependencies = [ "serde", "test-fuzz-internal", @@ -4094,9 +4136,9 @@ dependencies = [ [[package]] name = "test-fuzz-internal" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58071dc2471840e9f374eeb0f6e405a31bccb3cc5d59bb4598f02cafc274b5c4" +checksum = "a9186daca5c58cb307d09731e0ba06b13fd6c036c90672b9bfc31cecf76cf689" dependencies = [ "cargo_metadata", "proc-macro2", @@ -4107,11 +4149,11 @@ dependencies = [ [[package]] name = "test-fuzz-macro" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856bbca0314c328004691b9c0639fb198ca764d1ce0e20d4dd8b78f2697c2a6f" +checksum = "57d187b450bfb5b7939f82f9747dc1ebb15a7a9c4a93cd304a41aece7149608b" dependencies = [ - "darling 0.14.1", + "darling 0.14.2", "if_chain", "lazy_static", "proc-macro2", @@ -4125,9 +4167,9 @@ dependencies = [ [[package]] name = "test-fuzz-runtime" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "303774eb17994c2ddb59c460369f4c3a55496f013380278d78eeebd2deb896ac" +checksum = "1a0d69068569b9b7311095823fe0e49eedfd05ad4277eb64fc334cf1a5bc5116" dependencies = [ "bincode", "hex", @@ -4139,9 +4181,9 @@ dependencies = [ [[package]] name = "textwrap" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16" +checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" @@ -4635,6 +4677,12 @@ dependencies = [ "libc", ] +[[package]] +name = "waker-fn" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + [[package]] name = "walkdir" version = "2.3.2" diff --git a/crates/net/ecies/src/util.rs b/crates/net/ecies/src/util.rs index ea4891d7f3..2bdfbc93e7 100644 --- a/crates/net/ecies/src/util.rs +++ b/crates/net/ecies/src/util.rs @@ -31,6 +31,8 @@ pub fn pk2id(pk: &PublicKey) -> PeerId { /// Converts a [PeerId] to a [secp256k1::PublicKey] by prepending the [PeerId] bytes with the /// SECP256K1_TAG_PUBKEY_UNCOMPRESSED tag. pub(crate) fn id2pk(id: PeerId) -> Result { + // NOTE: H512 is used as a PeerId not because it represents a hash, but because 512 bits is + // enough to represent an uncompressed public key. let mut s = [0_u8; 65]; // SECP256K1_TAG_PUBKEY_UNCOMPRESSED = 0x04 // see: https://github.com/bitcoin-core/secp256k1/blob/master/include/secp256k1.h#L211 diff --git a/crates/net/eth-wire/Cargo.toml b/crates/net/eth-wire/Cargo.toml index 1a358ec8cb..3882176a38 100644 --- a/crates/net/eth-wire/Cargo.toml +++ b/crates/net/eth-wire/Cargo.toml @@ -26,7 +26,9 @@ tokio-stream = "0.1.11" secp256k1 = { version = "0.24.0", features = ["global-context", "rand-std", "recovery"] } tokio-util = { version = "0.7.4", features = ["io"] } pin-project = "1.0" +pin-utils = "0.1.0" tracing = "0.1.37" +snap = "1.0.5" [dev-dependencies] hex-literal = "0.3" diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs new file mode 100644 index 0000000000..f67a384f99 --- /dev/null +++ b/crates/net/eth-wire/src/capability.rs @@ -0,0 +1,87 @@ +use crate::{version::ParseVersionError, EthVersion}; + +/// This represents a shared capability, its version, and its offset. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SharedCapability { + /// The `eth` capability. + Eth { version: EthVersion, offset: u8 }, + + /// An unknown capability. + UnknownCapability { name: String, version: u8, offset: u8 }, +} + +impl SharedCapability { + /// Creates a new [`SharedCapability`] based on the given name, offset, and version. + pub(crate) fn new(name: &str, version: u8, offset: u8) -> Result { + match name { + "eth" => Ok(Self::Eth { version: EthVersion::try_from(version)?, offset }), + _ => Ok(Self::UnknownCapability { name: name.to_string(), version, offset }), + } + } + + /// Returns the name of the capability. + pub(crate) fn name(&self) -> &str { + match self { + SharedCapability::Eth { .. } => "eth", + SharedCapability::UnknownCapability { name, .. } => name, + } + } + + /// Returns the version of the capability. + pub(crate) fn version(&self) -> u8 { + match self { + SharedCapability::Eth { version, .. } => *version as u8, + SharedCapability::UnknownCapability { version, .. } => *version, + } + } + + /// Returns the message ID offset of the current capability. + pub(crate) fn offset(&self) -> u8 { + match self { + SharedCapability::Eth { offset, .. } => *offset, + SharedCapability::UnknownCapability { offset, .. } => *offset, + } + } + + /// Returns the number of protocol messages supported by this capability. + pub(crate) fn num_messages(&self) -> Result { + match self { + SharedCapability::Eth { version, .. } => Ok(version.total_messages()), + _ => Err(SharedCapabilityError::UnknownCapability), + } + } +} + +/// An error that may occur while creating a [`SharedCapability`]. +#[derive(Debug, thiserror::Error)] +pub enum SharedCapabilityError { + /// Unsupported `eth` version. + #[error(transparent)] + UnsupportedVersion(#[from] ParseVersionError), + /// Cannot determine the number of messages for unknown capabilities. + #[error("cannot determine the number of messages for unknown capabilities")] + UnknownCapability, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_eth_67() { + let capability = SharedCapability::new("eth", 67, 0).unwrap(); + + assert_eq!(capability.name(), "eth"); + assert_eq!(capability.version(), 67); + assert_eq!(capability, SharedCapability::Eth { version: EthVersion::Eth67, offset: 0 }); + } + + #[test] + fn from_eth_66() { + let capability = SharedCapability::new("eth", 66, 0).unwrap(); + + assert_eq!(capability.name(), "eth"); + assert_eq!(capability.version(), 66); + assert_eq!(capability, SharedCapability::Eth { version: EthVersion::Eth66, offset: 0 }); + } +} diff --git a/crates/net/eth-wire/src/error.rs b/crates/net/eth-wire/src/error.rs index 0ee68fc196..e3f360e1dd 100644 --- a/crates/net/eth-wire/src/error.rs +++ b/crates/net/eth-wire/src/error.rs @@ -3,7 +3,7 @@ use std::io; use reth_primitives::{Chain, H256}; -use crate::types::forkid::ValidationError; +use crate::{capability::SharedCapabilityError, types::forkid::ValidationError}; /// Errors when sending/receiving messages #[derive(thiserror::Error, Debug)] @@ -14,6 +14,8 @@ pub enum EthStreamError { #[error(transparent)] Rlp(#[from] reth_rlp::DecodeError), #[error(transparent)] + P2PStreamError(#[from] P2PStreamError), + #[error(transparent)] HandshakeError(#[from] HandshakeError), #[error("message size ({0}) exceeds max length (10MB)")] MessageTooBig(usize), @@ -37,3 +39,66 @@ pub enum HandshakeError { #[error("mismatched chain in Status message. expected: {expected:?}, got: {got:?}")] MismatchedChain { expected: Chain, got: Chain }, } + +/// Errors when sending/receiving p2p messages. These should result in kicking the peer. +#[derive(thiserror::Error, Debug)] +pub enum P2PStreamError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Rlp(#[from] reth_rlp::DecodeError), + #[error(transparent)] + Snap(#[from] snap::Error), + #[error(transparent)] + HandshakeError(#[from] P2PHandshakeError), + #[error("message size ({message_size}) exceeds max length ({max_size})")] + MessageTooBig { message_size: usize, max_size: usize }, + #[error("unknown reserved p2p message id: {0}")] + UnknownReservedMessageId(u8), + #[error("empty protocol message received")] + EmptyProtocolMessage, + #[error(transparent)] + PingerError(#[from] PingerError), + #[error("ping timed out with {0} retries")] + PingTimeout(u8), + #[error(transparent)] + ParseVersionError(#[from] SharedCapabilityError), + #[error("mismatched protocol version in Hello message. expected: {expected:?}, got: {got:?}")] + MismatchedProtocolVersion { expected: u8, got: u8 }, + #[error("started ping task before the handshake completed")] + PingBeforeHandshake, + // TODO: remove / reconsider + #[error("disconnected")] + Disconnected, +} + +/// Errors when conducting a p2p handshake +#[derive(thiserror::Error, Debug)] +pub enum P2PHandshakeError { + #[error("hello message can only be recv/sent in handshake")] + HelloNotInHandshake, + #[error("received non-hello message when trying to handshake")] + NonHelloMessageInHandshake, + #[error("no capabilities shared with peer")] + NoSharedCapabilities, + #[error("no response received when sending out handshake")] + NoResponse, + #[error("handshake timed out")] + Timeout, +} + +/// An error that can occur when interacting with a [`Pinger`]. +#[derive(Debug, thiserror::Error)] +pub enum PingerError { + /// A ping was sent while the pinger was in the `TimedOut` state. + #[error("ping sent while timed out")] + PingWhileTimedOut, + + /// A pong was received while the pinger was in the `Ready` state. + #[error("pong received while ready")] + PongWhileReady, + + /// A pong was received while the pinger was in the `TimedOut` state. + #[error("pong received while timed out")] + PongWhileTimedOut, +} diff --git a/crates/net/eth-wire/src/stream.rs b/crates/net/eth-wire/src/ethstream.rs similarity index 71% rename from crates/net/eth-wire/src/stream.rs rename to crates/net/eth-wire/src/ethstream.rs index 95407078a6..44ae90157c 100644 --- a/crates/net/eth-wire/src/stream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -2,12 +2,11 @@ use crate::{ error::{EthStreamError, HandshakeError}, types::{forkid::ForkFilter, EthMessage, ProtocolMessage, Status}, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::{ready, Sink, SinkExt, StreamExt}; use pin_project::pin_project; use reth_rlp::{Decodable, Encodable}; use std::{ - io, pin::Pin, task::{Context, Poll}, }; @@ -35,11 +34,10 @@ impl EthStream { } } -impl EthStream +impl EthStream where - S: Stream> - + Sink - + Unpin, + S: Stream> + Sink + Unpin, + EthStreamError: From, { /// Given an instantiated transport layer, it proceeds to return an [`EthStream`] /// after performing a [`Status`] message handshake as specified in @@ -105,9 +103,10 @@ where } } -impl Stream for EthStream +impl Stream for EthStream where - S: Stream> + Unpin, + S: Stream> + Unpin, + EthStreamError: From, { type Item = Result; @@ -139,9 +138,10 @@ where } } -impl Sink for EthStream +impl Sink for EthStream where - S: Sink + Unpin, + S: Sink + Unpin, + EthStreamError: From, { type Error = EthStreamError; @@ -175,6 +175,7 @@ where #[cfg(test)] mod tests { use crate::{ + p2pstream::{CapabilityMessage, HelloMessage, ProtocolVersion, UnauthedP2PStream}, types::{broadcast::BlockHashNumber, forkid::ForkFilter, EthMessage, Status}, EthStream, PassthroughCodec, }; @@ -298,4 +299,91 @@ mod tests { // make sure the server receives the message and asserts before ending the test handle.await.unwrap(); } + + #[tokio::test] + async fn ethstream_over_p2p() { + // create a p2p stream and server, then confirm that the two are authed + // create tcpstream + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let server_key = SecretKey::new(&mut rand::thread_rng()); + let test_msg = EthMessage::NewBlockHashes( + vec![ + BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 }, + BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 }, + ] + .into(), + ); + + let genesis = H256::random(); + let fork_filter = ForkFilter::new(0, genesis, vec![]); + + let status = Status { + version: EthVersion::Eth67 as u8, + chain: Chain::Mainnet.into(), + total_difficulty: U256::from(0), + blockhash: H256::random(), + genesis, + // Pass the current fork id. + forkid: fork_filter.current(), + }; + + let status_copy = status; + let fork_filter_clone = fork_filter.clone(); + let test_msg_clone = test_msg.clone(); + let handle = tokio::spawn(async move { + // roughly based off of the design of tokio::net::TcpListener + let (incoming, _) = listener.accept().await.unwrap(); + let stream = ECIESStream::incoming(incoming, server_key).await.unwrap(); + + let server_hello = HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "bitcoind/1.0.0".to_string(), + capabilities: vec![CapabilityMessage::new( + "eth".to_string(), + EthVersion::Eth67 as usize, + )], + port: 30303, + id: pk2id(&server_key.public_key(SECP256K1)), + }; + + let unauthed_stream = UnauthedP2PStream::new(stream); + let p2p_stream = unauthed_stream.handshake(server_hello).await.unwrap(); + let mut eth_stream = EthStream::new(p2p_stream); + eth_stream.handshake(status_copy, fork_filter_clone).await.unwrap(); + + // use the stream to get the next message + let message = eth_stream.next().await.unwrap().unwrap(); + assert_eq!(message, test_msg_clone); + }); + + // create the server pubkey + let server_id = pk2id(&server_key.public_key(SECP256K1)); + + let client_key = SecretKey::new(&mut rand::thread_rng()); + + let outgoing = TcpStream::connect(local_addr).await.unwrap(); + let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap(); + + let client_hello = HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "bitcoind/1.0.0".to_string(), + capabilities: vec![CapabilityMessage::new( + "eth".to_string(), + EthVersion::Eth67 as usize, + )], + port: 30303, + id: pk2id(&client_key.public_key(SECP256K1)), + }; + + let unauthed_stream = UnauthedP2PStream::new(sink); + let p2p_stream = unauthed_stream.handshake(client_hello).await.unwrap(); + let mut client_stream = EthStream::new(p2p_stream); + client_stream.handshake(status, fork_filter).await.unwrap(); + + client_stream.send(test_msg).await.unwrap(); + + // make sure the server receives the message and asserts before ending the test + handle.await.unwrap(); + } } diff --git a/crates/net/eth-wire/src/lib.rs b/crates/net/eth-wire/src/lib.rs index 9615eb3aee..c308407b4d 100644 --- a/crates/net/eth-wire/src/lib.rs +++ b/crates/net/eth-wire/src/lib.rs @@ -9,9 +9,12 @@ pub use tokio_util::codec::{ LengthDelimitedCodec as PassthroughCodec, LengthDelimitedCodecError as PassthroughCodecError, }; +mod capability; pub mod error; -mod stream; +mod ethstream; +mod p2pstream; +mod pinger; pub mod types; pub use types::*; -pub use stream::EthStream; +pub use ethstream::EthStream; diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs new file mode 100644 index 0000000000..82f7fed52c --- /dev/null +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -0,0 +1,977 @@ +#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)] +use bytes::{Buf, Bytes, BytesMut}; +use futures::{ready, FutureExt, Sink, SinkExt, StreamExt}; +use pin_project::pin_project; +use reth_primitives::H512 as PeerId; +use reth_rlp::{Decodable, DecodeError, Encodable, RlpDecodable, RlpEncodable}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio_stream::Stream; + +use crate::{ + capability::SharedCapability, + error::{P2PHandshakeError, P2PStreamError}, + pinger::{IntervalTimeoutPinger, PingerEvent}, +}; + +/// [`MAX_PAYLOAD_SIZE`] is the maximum size of an uncompressed message payload. +/// This is defined in [EIP-706](https://eips.ethereum.org/EIPS/eip-706). +const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024; + +/// [`MAX_RESERVED_MESSAGE_ID`] is the maximum message ID reserved for the `p2p` subprotocol. If +/// there are any incoming messages with an ID greater than this, they are subprotocol messages. +const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f; + +/// [`MAX_P2P_MESSAGE_ID`] is the maximum message ID in use for the `p2p` subprotocol. +const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8; + +/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` +/// handshake has timed out. +const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); + +/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has +/// timed out. +const PING_TIMEOUT: Duration = Duration::from_secs(15); + +/// [`PING_INTERVAL`] determines the amount of time to wait between sending `p2p` ping messages +/// when the peer is responsive. +const PING_INTERVAL: Duration = Duration::from_secs(60); + +/// [`GRACE_PERIOD`] determines the amount of time to wait for a peer to disconnect after sending a +/// [`P2PMessage::Disconnect`] message. +const GRACE_PERIOD: Duration = Duration::from_secs(2); + +/// [`MAX_FAILED_PINGS`] determines the maximum number of failed ping attempts before disconnecting +/// from a peer. +const MAX_FAILED_PINGS: u8 = 3; + +/// An un-authenticated `P2PStream`. This is consumed and returns a [`P2PStream`] after the `Hello` +/// handshake is completed. +#[pin_project] +pub struct UnauthedP2PStream { + #[pin] + inner: S, +} + +impl UnauthedP2PStream { + /// Create a new `UnauthedP2PStream` from a `Stream` of bytes. + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl UnauthedP2PStream +where + S: Stream> + Sink + Unpin, +{ + /// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is + /// completed. + pub async fn handshake(mut self, hello: HelloMessage) -> Result, P2PStreamError> { + tracing::trace!("sending p2p hello ..."); + + // send our hello message with the Sink + let mut raw_hello_bytes = BytesMut::new(); + P2PMessage::Hello(hello.clone()).encode(&mut raw_hello_bytes); + self.inner.send(raw_hello_bytes.into()).await?; + + tracing::trace!("waiting for p2p hello from peer ..."); + + let hello_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next()) + .await + .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))? + .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??; + + // let's check the compressed length first, we will need to check again once confirming + // that it contains snappy-compressed data (this will be the case for all non-p2p messages). + if hello_bytes.len() > MAX_PAYLOAD_SIZE { + return Err(P2PStreamError::MessageTooBig { + message_size: hello_bytes.len(), + max_size: MAX_PAYLOAD_SIZE, + }) + } + + // get the message id + let id = *hello_bytes.first().ok_or_else(|| P2PStreamError::EmptyProtocolMessage)?; + + // the first message sent MUST be the hello message + if id != P2PMessageID::Hello as u8 { + return Err(P2PStreamError::HandshakeError( + P2PHandshakeError::NonHelloMessageInHandshake, + )) + } + + let their_hello = match P2PMessage::decode(&mut &hello_bytes[..])? { + P2PMessage::Hello(hello) => Ok(hello), + _ => { + // TODO: this should never occur due to the id check + Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake)) + } + }?; + + // TODO: explicitly document that we only support v5. + if their_hello.protocol_version != ProtocolVersion::V5 { + // TODO: do we want to send a `Disconnect` message here? + return Err(P2PStreamError::MismatchedProtocolVersion { + expected: ProtocolVersion::V5 as u8, + got: their_hello.protocol_version as u8, + }) + } + + // determine shared capabilities (currently returns only one capability) + let capability = set_capability_offsets(hello.capabilities, their_hello.capabilities)?; + + let stream = P2PStream::new(self.inner, capability); + + Ok(stream) + } +} + +/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p` +/// protocol messages. +#[pin_project] +pub struct P2PStream { + #[pin] + inner: S, + + /// The snappy encoder used for compressing outgoing messages + encoder: snap::raw::Encoder, + + /// The snappy decoder used for decompressing incoming messages + decoder: snap::raw::Decoder, + + /// The state machine used for keeping track of the peer's ping status. + pinger: IntervalTimeoutPinger, + + /// The supported capability for this stream. + shared_capability: SharedCapability, +} + +impl P2PStream { + /// Create a new unauthed [`P2PStream`] from the provided stream. You will need to manually + /// handshake with a peer. + pub fn new(inner: S, capability: SharedCapability) -> Self { + Self { + inner, + encoder: snap::raw::Encoder::new(), + decoder: snap::raw::Decoder::new(), + pinger: IntervalTimeoutPinger::new(MAX_FAILED_PINGS, PING_INTERVAL, PING_TIMEOUT), + shared_capability: capability, + } + } +} + +// S must also be `Sink` because we need to be able to respond with ping messages to follow the +// protocol +impl Stream for P2PStream +where + S: Stream> + Sink + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // poll the pinger to determine if we should send a ping + let pinger_res = ready!(Pin::new(&mut this.pinger).poll_next(cx)); + match pinger_res { + Some(Ok(PingerEvent::Ping)) => { + // encode the ping message + let mut ping_bytes = BytesMut::new(); + P2PMessage::Ping.encode(&mut ping_bytes); + + // TODO: fix use of Sink API + let send_res = Pin::new(&mut this.inner).send(ping_bytes.into()).poll_unpin(cx)?; + ready!(send_res) + } + // either None (stream ended) or Some(PingEvent::Timeout) or Err(err) + _ => { + // encode the disconnect message + let mut disconnect_bytes = BytesMut::new(); + P2PMessage::Disconnect(DisconnectReason::PingTimeout).encode(&mut disconnect_bytes); + + // TODO: fix use of Sink API + let send_res = + Pin::new(&mut this.inner).send(disconnect_bytes.into()).poll_unpin(cx)?; + ready!(send_res); + + // since the ping stream has timed out, let's send a None + return Poll::Ready(None) + } + }; + + // we should loop here to ensure we don't return Poll::Pending if we have a message to + // return behind any pings we need to respond to + while let Poll::Ready(res) = this.inner.as_mut().poll_next(cx) { + let bytes = match res { + Some(Ok(bytes)) => bytes, + Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + None => return Poll::Ready(None), + }; + + let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?; + if id == P2PMessageID::Ping as u8 { + // TODO: do we need to decode the ping? + // we have received a ping, so we will send a pong + let mut pong_bytes = BytesMut::new(); + P2PMessage::Pong.encode(&mut pong_bytes); + + // TODO: fix use of Sink API + let send_res = Pin::new(&mut this.inner).send(pong_bytes.into()).poll_unpin(cx)?; + ready!(send_res) + + // continue to the next message if there is one + } else if id == P2PMessageID::Disconnect as u8 { + let reason = DisconnectReason::decode(&mut &bytes[1..])?; + // TODO: do something with the reason + return Poll::Ready(Some(Err(P2PStreamError::Disconnected))) + } else if id == P2PMessageID::Hello as u8 { + // we have received a hello message outside of the handshake, so we will return an + // error + return Poll::Ready(Some(Err(P2PStreamError::HandshakeError( + P2PHandshakeError::HelloNotInHandshake, + )))) + } else if id == P2PMessageID::Pong as u8 { + // TODO: do we need to decode the pong? + // if we were waiting for a pong, this will reset the pinger state + this.pinger.pong_received()? + } else if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID { + // we have received an unknown reserved message + return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id)))) + } else { + // first check that the compressed message length does not exceed the max message + // size + let decompressed_len = snap::raw::decompress_len(&bytes[1..])?; + if decompressed_len > MAX_PAYLOAD_SIZE { + return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig { + message_size: decompressed_len, + max_size: MAX_PAYLOAD_SIZE, + }))) + } + + // then decompress the message + let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1); + + // we have a subprotocol message that needs to be sent in the stream. + // first, switch the message id based on offset so the next layer can decode it + // without being aware of the p2p stream's state (shared capabilities / the message + // id offset) + decompress_buf[0] = bytes[0] - this.shared_capability.offset(); + this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?; + + return Poll::Ready(Some(Ok(decompress_buf))) + } + } + + Poll::Pending + } +} + +impl Sink for P2PStream +where + S: Sink + Unpin, +{ + type Error = P2PStreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx).map_err(Into::into) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let this = self.project(); + + let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1)); + + // all messages sent in this stream are subprotocol messages, so we need to switch the + // message id based on the offset + compressed[0] = item[0] + this.shared_capability.offset(); + let compressed_size = this.encoder.compress(&item[1..], &mut compressed[1..])?; + + // truncate the compressed buffer to the actual compressed size (plus one for the message + // id) + compressed.truncate(compressed_size + 1); + + this.inner.start_send(compressed.freeze())?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx).map_err(Into::into) + } +} + +/// Determines the offsets for each shared capability between the input list of peer +/// capabilities and the input list of locally supported capabilities. +/// +/// Currently only `eth` versions 66 and 67 are supported. +pub fn set_capability_offsets( + local_capabilities: Vec, + peer_capabilities: Vec, +) -> Result { + // find intersection of capabilities + let our_capabilities_map = + local_capabilities.into_iter().map(|c| (c.name, c.version)).collect::>(); + + // map of capability name to version + let mut shared_capabilities = HashMap::new(); + + // sorted list of capability names + // TODO: the Ord implementation for strings says the following: + // > Strings are ordered lexicographically by their byte values. This orders Unicode code + // points based on their positions in the code charts. This is not necessarily the same as + // “alphabetical” order. + // We need to implement a case-sensitive alphabetical sort + let mut shared_capability_names = BTreeSet::new(); + + // find highest shared version of each shared capability + for capability in peer_capabilities { + // if this is Some, we share this capability + if let Some(version) = our_capabilities_map.get(&capability.name) { + // If multiple versions are shared of the same (equal name) capability, the numerically + // highest wins, others are ignored + if capability.version <= *version { + shared_capabilities.insert(capability.name.clone(), capability.version); + shared_capability_names.insert(capability.name); + } + } + } + + // disconnect if we don't share any capabilities + if shared_capabilities.is_empty() { + // TODO: send a disconnect message? if we want to do this, this will need to be a member + // method of `UnauthedP2PStream` so it can access the inner stream + return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities)) + } + + // order versions based on capability name (alphabetical) and select offsets based on + // BASE_OFFSET + prev_total_message + let mut shared_with_offsets = Vec::new(); + + // Message IDs are assumed to be compact from ID 0x10 onwards (0x00-0x0f is reserved for the + // "p2p" capability) and given to each shared (equal-version, equal-name) capability in + // alphabetic order. + let mut offset = MAX_RESERVED_MESSAGE_ID + 1; + for name in shared_capability_names { + let version = shared_capabilities.get(&name).unwrap(); + + let shared_capability = SharedCapability::new(&name, *version as u8, offset)?; + + match shared_capability { + SharedCapability::UnknownCapability { .. } => { + // Capabilities which are not shared are ignored + tracing::warn!("unknown capability: name={:?}, version={}", name, version,); + } + SharedCapability::Eth { .. } => { + shared_with_offsets.push(shared_capability.clone()); + + // increment the offset if the capability is known + offset += shared_capability.num_messages()?; + } + } + } + + // TODO: support multiple capabilities - we would need a new Stream type to go on top of + // `P2PStream` containing its capability. `P2PStream` would still send pings and handle + // pongs, but instead contain a map of capabilities to their respective stream / channel. + // Each channel would be responsible for containing the offset for that stream and would + // only increment / decrement message IDs. + // NOTE: since the `P2PStream` currently only supports one capability, we set the + // capability with the lowest offset. + Ok(shared_with_offsets + .first() + .ok_or_else(|| P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))? + .clone()) +} + +/// This represents only the reserved `p2p` subprotocol messages. +#[derive(Debug, PartialEq, Eq)] +pub enum P2PMessage { + /// The first packet sent over the connection, and sent once by both sides. + Hello(HelloMessage), + + /// Inform the peer that a disconnection is imminent; if received, a peer should disconnect + /// immediately. + Disconnect(DisconnectReason), + + /// Requests an immediate reply of [`Pong`] from the peer. + Ping, + + /// Reply to the peer's [`Ping`] packet. + Pong, +} + +impl P2PMessage { + /// Gets the [`P2PMessageID`] for the given message. + pub fn message_id(&self) -> P2PMessageID { + match self { + P2PMessage::Hello(_) => P2PMessageID::Hello, + P2PMessage::Disconnect(_) => P2PMessageID::Disconnect, + P2PMessage::Ping => P2PMessageID::Ping, + P2PMessage::Pong => P2PMessageID::Pong, + } + } +} + +impl Encodable for P2PMessage { + fn length(&self) -> usize { + let payload_len = match self { + P2PMessage::Hello(msg) => msg.length(), + P2PMessage::Disconnect(msg) => msg.length(), + P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3 + P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3 + }; + payload_len + 1 // (1 for length of p2p message id) + } + + fn encode(&self, out: &mut dyn bytes::BufMut) { + out.put_u8(self.message_id() as u8); + match self { + P2PMessage::Hello(msg) => msg.encode(out), + P2PMessage::Disconnect(msg) => msg.encode(out), + P2PMessage::Ping => { + out.put_u8(0x01); + out.put_u8(0x00); + out.put_u8(0x80); + } + P2PMessage::Pong => { + out.put_u8(0x01); + out.put_u8(0x00); + out.put_u8(0x80); + } + } + } +} + +impl Decodable for P2PMessage { + fn decode(buf: &mut &[u8]) -> Result { + let first = buf.first().expect("cannot decode empty p2p message"); + let id = P2PMessageID::try_from(*first) + .or(Err(DecodeError::Custom("unknown p2p message id")))?; + buf.advance(1); + match id { + P2PMessageID::Hello => Ok(P2PMessage::Hello(HelloMessage::decode(buf)?)), + P2PMessageID::Disconnect => Ok(P2PMessage::Disconnect(DisconnectReason::decode(buf)?)), + P2PMessageID::Ping => { + // len([0x01, 0x00, 0x80]) = 3 + buf.advance(3); + Ok(P2PMessage::Ping) + } + P2PMessageID::Pong => { + // len([0x01, 0x00, 0x80]) = 3 + buf.advance(3); + Ok(P2PMessage::Pong) + } + } + } +} + +/// Message IDs for `p2p` subprotocol messages. +pub enum P2PMessageID { + /// Message ID for the [`P2PMessage::Hello`] message. + Hello = 0x00, + + /// Message ID for the [`P2PMessage::Disconnect`] message. + Disconnect = 0x01, + + /// Message ID for the [`P2PMessage::Ping`] message. + Ping = 0x02, + + /// Message ID for the [`P2PMessage::Pong`] message. + Pong = 0x03, +} + +impl From for P2PMessageID { + fn from(msg: P2PMessage) -> Self { + match msg { + P2PMessage::Hello(_) => P2PMessageID::Hello, + P2PMessage::Disconnect(_) => P2PMessageID::Disconnect, + P2PMessage::Ping => P2PMessageID::Ping, + P2PMessage::Pong => P2PMessageID::Pong, + } + } +} + +impl TryFrom for P2PMessageID { + type Error = P2PStreamError; + + fn try_from(id: u8) -> Result { + match id { + 0x00 => Ok(P2PMessageID::Hello), + 0x01 => Ok(P2PMessageID::Disconnect), + 0x02 => Ok(P2PMessageID::Ping), + 0x03 => Ok(P2PMessageID::Pong), + _ => Err(P2PStreamError::UnknownReservedMessageId(id)), + } + } +} + +/// A message indicating a supported capability and capability version. +#[derive(Clone, Debug, PartialEq, Eq, RlpEncodable, RlpDecodable)] +pub struct CapabilityMessage { + /// The name of the subprotocol + pub name: String, + /// The version of the subprotocol + pub version: usize, +} + +impl CapabilityMessage { + /// Create a new `CapabilityMessage` with the given name and version. + pub fn new(name: String, version: usize) -> Self { + Self { name, version } + } +} + +// TODO: determine if we should allow for the extra fields at the end like EIP-706 suggests +/// Message used in the `p2p` handshake, containing information about the supported RLPx protocol +/// version and capabilities. +#[derive(Clone, Debug, PartialEq, Eq, RlpEncodable, RlpDecodable)] +pub struct HelloMessage { + /// The version of the `p2p` protocol. + pub protocol_version: ProtocolVersion, + /// Specifies the client software identity, as a human-readable string (e.g. + /// "Ethereum(++)/1.0.0"). + pub client_version: String, + /// The list of supported capabilities and their versions. + pub capabilities: Vec, + /// The port that the client is listening on, zero indicates the client is not listening. + pub port: u16, + /// The secp256k1 public key corresponding to the node's private key. + pub id: PeerId, +} + +/// RLPx `p2p` protocol version +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ProtocolVersion { + /// `p2p` version 4 + V4 = 4, + /// `p2p` version 5 + V5 = 5, +} + +impl Encodable for ProtocolVersion { + fn length(&self) -> usize { + // the version should be a single byte + (*self as u8).length() + } + fn encode(&self, out: &mut dyn bytes::BufMut) { + (*self as u8).encode(out) + } +} + +impl Decodable for ProtocolVersion { + fn decode(buf: &mut &[u8]) -> Result { + let version = u8::decode(buf)?; + match version { + 4 => Ok(ProtocolVersion::V4), + 5 => Ok(ProtocolVersion::V5), + _ => Err(DecodeError::Custom("unknown p2p protocol version")), + } + } +} + +/// RLPx disconnect reason. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DisconnectReason { + /// Disconnect requested by the local node or remote peer. + DisconnectRequested = 0x00, + /// TCP related error + TcpSubsystemError = 0x01, + /// Breach of protocol at the transport or p2p level + ProtocolBreach = 0x02, + /// Node has no matching protocols. + UselessPeer = 0x03, + /// Either the remote or local node has too many peers. + TooManyPeers = 0x04, + /// Already connected to the peer. + AlreadyConnected = 0x05, + /// `p2p` protocol version is incompatible + IncompatibleP2PProtocolVersion = 0x06, + NullNodeIdentity = 0x07, + ClientQuitting = 0x08, + UnexpectedHandshakeIdentity = 0x09, + /// The node is connected to itself + ConnectedToSelf = 0x0a, + /// Peer or local node did not respond to a ping in time. + PingTimeout = 0x0b, + /// Peer or local node violated a subprotocol-specific rule. + SubprotocolSpecific = 0x10, +} + +impl Display for DisconnectReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let message = match self { + DisconnectReason::DisconnectRequested => "Disconnect requested", + DisconnectReason::TcpSubsystemError => "TCP sub-system error", + DisconnectReason::ProtocolBreach => { + "Breach of protocol, e.g. a malformed message, bad RLP, ..." + } + DisconnectReason::UselessPeer => "Useless peer", + DisconnectReason::TooManyPeers => "Too many peers", + DisconnectReason::AlreadyConnected => "Already connected", + DisconnectReason::IncompatibleP2PProtocolVersion => "Incompatible P2P protocol version", + DisconnectReason::NullNodeIdentity => { + "Null node identity received - this is automatically invalid" + } + DisconnectReason::ClientQuitting => "Client quitting", + DisconnectReason::UnexpectedHandshakeIdentity => "Unexpected identity in handshake", + DisconnectReason::ConnectedToSelf => { + "Identity is the same as this node (i.e. connected to itself)" + } + DisconnectReason::PingTimeout => "Ping timeout", + DisconnectReason::SubprotocolSpecific => "Some other reason specific to a subprotocol", + }; + + write!(f, "{}", message) + } +} + +/// This represents an unknown disconnect reason with the given code. +#[derive(Debug, Clone)] +pub struct UnknownDisconnectReason(u8); + +impl TryFrom for DisconnectReason { + // This error type should not be used to crash the node, but rather to log the error and + // disconnect the peer. + type Error = UnknownDisconnectReason; + + fn try_from(value: u8) -> Result { + match value { + 0x00 => Ok(DisconnectReason::DisconnectRequested), + 0x01 => Ok(DisconnectReason::TcpSubsystemError), + 0x02 => Ok(DisconnectReason::ProtocolBreach), + 0x03 => Ok(DisconnectReason::UselessPeer), + 0x04 => Ok(DisconnectReason::TooManyPeers), + 0x05 => Ok(DisconnectReason::AlreadyConnected), + 0x06 => Ok(DisconnectReason::IncompatibleP2PProtocolVersion), + 0x07 => Ok(DisconnectReason::NullNodeIdentity), + 0x08 => Ok(DisconnectReason::ClientQuitting), + 0x09 => Ok(DisconnectReason::UnexpectedHandshakeIdentity), + 0x0a => Ok(DisconnectReason::ConnectedToSelf), + 0x0b => Ok(DisconnectReason::PingTimeout), + 0x10 => Ok(DisconnectReason::SubprotocolSpecific), + _ => Err(UnknownDisconnectReason(value)), + } + } +} + +impl Encodable for DisconnectReason { + fn length(&self) -> usize { + // disconnect reasons are snappy encoded as follows: + // [0x01, 0x00, reason as u8] + // this is 3 bytes + 3 + } + fn encode(&self, out: &mut dyn bytes::BufMut) { + // disconnect reasons are snappy encoded as follows: + // [0x01, 0x00, reason as u8] + // this is 3 bytes + out.put_u8(0x01); + out.put_u8(0x00); + out.put_u8(*self as u8); + } +} + +impl Decodable for DisconnectReason { + fn decode(buf: &mut &[u8]) -> Result { + let first = *buf.first().expect("disconnect reason should have at least 1 byte"); + buf.advance(1); + if first != 0x01 { + return Err(DecodeError::Custom("invalid disconnect reason - invalid snappy header")) + } + + let second = *buf.first().expect("disconnect reason should have at least 2 bytes"); + buf.advance(1); + if second != 0x00 { + // TODO: make sure this error message is correct + return Err(DecodeError::Custom("invalid disconnect reason - invalid snappy header")) + } + + let reason = *buf.first().expect("disconnect reason should have 3 bytes"); + buf.advance(1); + DisconnectReason::try_from(reason) + .map_err(|_| DecodeError::Custom("unknown disconnect reason")) + } +} + +#[cfg(test)] +mod tests { + use reth_ecies::util::pk2id; + use reth_rlp::EMPTY_STRING_CODE; + use secp256k1::{SecretKey, SECP256K1}; + use tokio::net::{TcpListener, TcpStream}; + use tokio_util::codec::Decoder; + + use crate::EthVersion; + + use super::*; + + #[tokio::test] + async fn test_handshake_passthrough() { + // create a p2p stream and server, then confirm that the two are authed + // create tcpstream + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + // roughly based off of the design of tokio::net::TcpListener + let (incoming, _) = listener.accept().await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(incoming); + + let server_key = SecretKey::new(&mut rand::thread_rng()); + let server_hello = HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "bitcoind/1.0.0".to_string(), + capabilities: vec![EthVersion::Eth67.into()], + port: 30303, + id: pk2id(&server_key.public_key(SECP256K1)), + }; + + let unauthed_stream = UnauthedP2PStream::new(stream); + let p2p_stream = unauthed_stream.handshake(server_hello).await.unwrap(); + + // ensure that the two share a single capability, eth67 + assert_eq!( + p2p_stream.shared_capability, + SharedCapability::Eth { + version: EthVersion::Eth67, + offset: MAX_RESERVED_MESSAGE_ID + 1 + } + ); + }); + + let client_key = SecretKey::new(&mut rand::thread_rng()); + + let outgoing = TcpStream::connect(local_addr).await.unwrap(); + let sink = crate::PassthroughCodec::default().framed(outgoing); + + let client_hello = HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "bitcoind/1.0.0".to_string(), + capabilities: vec![EthVersion::Eth67.into()], + port: 30303, + id: pk2id(&client_key.public_key(SECP256K1)), + }; + + let unauthed_stream = UnauthedP2PStream::new(sink); + let p2p_stream = unauthed_stream.handshake(client_hello).await.unwrap(); + + // ensure that the two share a single capability, eth67 + assert_eq!( + p2p_stream.shared_capability, + SharedCapability::Eth { + version: EthVersion::Eth67, + offset: MAX_RESERVED_MESSAGE_ID + 1 + } + ); + + // make sure the server receives the message and asserts before ending the test + handle.await.unwrap(); + } + + #[test] + fn test_ping_snappy_encoding_parity() { + // encode ping using our `Encodable` implementation + let ping = P2PMessage::Ping; + let mut ping_encoded = Vec::new(); + ping.encode(&mut ping_encoded); + + // the definition of ping is 0x80 (an empty rlp string) + let ping_raw = vec![EMPTY_STRING_CODE]; + let mut snappy_encoder = snap::raw::Encoder::new(); + let ping_compressed = snappy_encoder.compress_vec(&ping_raw).unwrap(); + let mut ping_expected = vec![P2PMessageID::Ping as u8]; + ping_expected.extend(&ping_compressed); + + // ensure that the two encodings are equal + assert_eq!( + ping_expected, ping_encoded, + "left: {:#x?}, right: {:#x?}", + ping_expected, ping_encoded + ); + + // also ensure that the length is correct + assert_eq!(ping_expected.len(), P2PMessage::Ping.length()); + + // try to decode using Decodable + let p2p_message = P2PMessage::decode(&mut &ping_expected[..]).unwrap(); + assert_eq!(p2p_message, P2PMessage::Ping); + + // finally decode the encoded message with snappy + let mut snappy_decoder = snap::raw::Decoder::new(); + + // the message id is not compressed, only compress the latest bits + let decompressed = snappy_decoder.decompress_vec(&ping_encoded[1..]).unwrap(); + + assert_eq!(decompressed, ping_raw); + } + + #[test] + fn test_pong_snappy_encoding_parity() { + // encode pong using our `Encodable` implementation + let pong = P2PMessage::Pong; + let mut pong_encoded = Vec::new(); + pong.encode(&mut pong_encoded); + + // the definition of pong is 0x80 (an empty rlp string) + let pong_raw = vec![EMPTY_STRING_CODE]; + let mut snappy_encoder = snap::raw::Encoder::new(); + let pong_compressed = snappy_encoder.compress_vec(&pong_raw).unwrap(); + let mut pong_expected = vec![P2PMessageID::Pong as u8]; + pong_expected.extend(&pong_compressed); + + // ensure that the two encodings are equal + assert_eq!( + pong_expected, pong_encoded, + "left: {:#x?}, right: {:#x?}", + pong_expected, pong_encoded + ); + + // also ensure that the length is correct + assert_eq!(pong_expected.len(), P2PMessage::Pong.length()); + + // try to decode using Decodable + let p2p_message = P2PMessage::decode(&mut &pong_expected[..]).unwrap(); + assert_eq!(p2p_message, P2PMessage::Pong); + + // finally decode the encoded message with snappy + let mut snappy_decoder = snap::raw::Decoder::new(); + + // the message id is not compressed, only compress the latest bits + let decompressed = snappy_decoder.decompress_vec(&pong_encoded[1..]).unwrap(); + + assert_eq!(decompressed, pong_raw); + } + + #[test] + fn test_hello_encoding_round_trip() { + let secret_key = SecretKey::new(&mut rand::thread_rng()); + let id = pk2id(&secret_key.public_key(SECP256K1)); + let hello = P2PMessage::Hello(HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "reth/0.1.0".to_string(), + capabilities: vec![CapabilityMessage::new( + "eth".to_string(), + EthVersion::Eth67 as usize, + )], + port: 30303, + id, + }); + + let mut hello_encoded = Vec::new(); + hello.encode(&mut hello_encoded); + + let hello_decoded = P2PMessage::decode(&mut &hello_encoded[..]).unwrap(); + + assert_eq!(hello, hello_decoded); + } + + #[test] + fn hello_encoding_length() { + let secret_key = SecretKey::new(&mut rand::thread_rng()); + let id = pk2id(&secret_key.public_key(SECP256K1)); + let hello = P2PMessage::Hello(HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "reth/0.1.0".to_string(), + capabilities: vec![CapabilityMessage::new( + "eth".to_string(), + EthVersion::Eth67 as usize, + )], + port: 30303, + id, + }); + + let mut hello_encoded = Vec::new(); + hello.encode(&mut hello_encoded); + + assert_eq!(hello_encoded.len(), hello.length()); + } + + #[test] + fn hello_message_id_prefix() { + // ensure that the hello message id is prefixed + let secret_key = SecretKey::new(&mut rand::thread_rng()); + let id = pk2id(&secret_key.public_key(SECP256K1)); + let hello = P2PMessage::Hello(HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "reth/0.1.0".to_string(), + capabilities: vec![CapabilityMessage::new( + "eth".to_string(), + EthVersion::Eth67 as usize, + )], + port: 30303, + id, + }); + + let mut hello_encoded = Vec::new(); + hello.encode(&mut hello_encoded); + + assert_eq!(hello_encoded[0], P2PMessageID::Hello as u8); + } + + #[test] + fn disconnect_round_trip() { + let all_reasons = vec![ + DisconnectReason::DisconnectRequested, + DisconnectReason::TcpSubsystemError, + DisconnectReason::ProtocolBreach, + DisconnectReason::UselessPeer, + DisconnectReason::TooManyPeers, + DisconnectReason::AlreadyConnected, + DisconnectReason::IncompatibleP2PProtocolVersion, + DisconnectReason::NullNodeIdentity, + DisconnectReason::ClientQuitting, + DisconnectReason::UnexpectedHandshakeIdentity, + DisconnectReason::ConnectedToSelf, + DisconnectReason::PingTimeout, + DisconnectReason::SubprotocolSpecific, + ]; + + for reason in all_reasons { + let disconnect = P2PMessage::Disconnect(reason); + + let mut disconnect_encoded = Vec::new(); + disconnect.encode(&mut disconnect_encoded); + + let disconnect_decoded = P2PMessage::decode(&mut &disconnect_encoded[..]).unwrap(); + + assert_eq!(disconnect, disconnect_decoded); + } + } + + #[test] + fn disconnect_encoding_length() { + let all_reasons = vec![ + DisconnectReason::DisconnectRequested, + DisconnectReason::TcpSubsystemError, + DisconnectReason::ProtocolBreach, + DisconnectReason::UselessPeer, + DisconnectReason::TooManyPeers, + DisconnectReason::AlreadyConnected, + DisconnectReason::IncompatibleP2PProtocolVersion, + DisconnectReason::NullNodeIdentity, + DisconnectReason::ClientQuitting, + DisconnectReason::UnexpectedHandshakeIdentity, + DisconnectReason::ConnectedToSelf, + DisconnectReason::PingTimeout, + DisconnectReason::SubprotocolSpecific, + ]; + + for reason in all_reasons { + let disconnect = P2PMessage::Disconnect(reason); + + let mut disconnect_encoded = Vec::new(); + disconnect.encode(&mut disconnect_encoded); + + assert_eq!(disconnect_encoded.len(), disconnect.length()); + } + } +} diff --git a/crates/net/eth-wire/src/pinger.rs b/crates/net/eth-wire/src/pinger.rs new file mode 100644 index 0000000000..3003432fc3 --- /dev/null +++ b/crates/net/eth-wire/src/pinger.rs @@ -0,0 +1,581 @@ +use futures::{ready, StreamExt}; +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::interval; +use tokio_stream::{wrappers::IntervalStream, Stream}; + +use crate::error::PingerError; + +/// This represents the possible states of the pinger. +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub(crate) enum PingState { + /// There are no pings in flight, or all pings have been responded to and we are ready to send + /// a ping at a later point. + Ready, + + /// We have sent a ping and are waiting for a pong, but the peer has missed n pongs. + WaitingForPong(u8), + + /// The peer has missed n pongs and is considered timed out. + TimedOut(u8), +} + +/// The pinger is a state machine that is created with a maximum number of pongs that can be +/// missed. +#[derive(Debug, Clone)] +pub(crate) struct Pinger { + /// The maximum number of pongs that can be missed. + max_missed: u8, + + /// The current state of the pinger. + state: PingState, +} + +impl Pinger { + /// Create a new pinger with the given maximum number of pongs that can be missed. + pub(crate) fn new(max_missed: u8) -> Self { + Self { max_missed, state: PingState::Ready } + } + + /// Return the current state of the pinger. + pub(crate) fn state(&self) -> &PingState { + &self.state + } + + /// Check if the pinger is in the `Ready` state. + pub(crate) fn is_ready(&self) -> bool { + matches!(self.state, PingState::Ready) + } + + /// Check if the pinger is in the `WaitingForPong` state. + pub(crate) fn is_waiting_for_pong(&self) -> bool { + matches!(self.state, PingState::WaitingForPong(_)) + } + + /// Check if the pinger is in the `TimedOut` state. + pub(crate) fn is_timed_out(&self) -> bool { + matches!(self.state, PingState::TimedOut(_)) + } + + /// Transition the pinger to the `WaitingForPong` state if it was in the `Ready` state. + /// + /// If the pinger is in the `WaitingForPong` state, the number of missed pongs will be + /// incremented. If the number of missed pongs exceeds the maximum missed pongs allowed, the + /// pinger will be transitioned to the `TimedOut` state. + /// + /// If the pinger is in the `TimedOut` state, this method will return an error. + pub(crate) fn next_state(&mut self) -> Result<(), PingerError> { + match self.state { + PingState::Ready => { + self.state = PingState::WaitingForPong(0); + Ok(()) + } + PingState::WaitingForPong(missed) => { + if missed + 1 >= self.max_missed { + self.state = PingState::TimedOut(missed + 1); + Ok(()) + } else { + self.state = PingState::WaitingForPong(missed + 1); + Ok(()) + } + } + PingState::TimedOut(_) => Err(PingerError::PingWhileTimedOut), + } + } + + /// Mark a pong as received, and transition the pinger to the `Ready` state if it was in the + /// `WaitingForPong` state. + /// + /// If the pinger is in the `Ready` or `TimedOut` state, this method will return an error. + pub(crate) fn pong_received(&mut self) -> Result<(), PingerError> { + match self.state { + PingState::Ready => Err(PingerError::PongWhileReady), + PingState::WaitingForPong(_) => { + self.state = PingState::Ready; + Ok(()) + } + PingState::TimedOut(_) => Err(PingerError::PongWhileTimedOut), + } + } +} + +/// A Pinger that can be used as a `Stream`, which will emit +#[derive(Debug, Clone)] +pub(crate) struct PingerStream { + /// The pinger. + pinger: Pinger, + + /// Whether a `Timeout` event has already been sent. + timeout_sent: bool, +} + +impl PingerStream { + /// Poll the [`Pinger`] for a [`Option`], which can be either a [`PingEvent::Ping`] + /// or a final [`PingEvent::Timeout`] event, after which the stream will end and return + /// None. + pub(crate) fn poll(&mut self) -> Option> { + // the stream has already sent a timeout event, so we return None + if self.timeout_sent { + return None + } + + match self.pinger.state { + PingState::Ready => { + // the pinger is ready, send a ping + match self.pinger.next_state() { + Ok(()) => Some(Ok(PingerEvent::Ping)), + Err(e) => Some(Err(e)), + } + } + PingState::WaitingForPong(_) => { + // the peer has not timed out (yet), send another ping if the pinger does + // not exceed the maximum number of missed pongs + match self.pinger.next_state() { + Ok(()) => { + match self.pinger.state() { + PingState::TimedOut(_) => { + // the pinger has timed out, send a timeout event and end the + // stream + self.timeout_sent = true; + Some(Ok(PingerEvent::Timeout)) + } + _ => { + // the pinger is still waiting for a pong, send another ping + Some(Ok(PingerEvent::Ping)) + } + } + } + Err(e) => Some(Err(e)), + } + } + PingState::TimedOut(_) => { + self.timeout_sent = true; + Some(Ok(PingerEvent::Timeout)) + } + } + } +} + +impl Stream for PingerStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.timeout_sent { + return Poll::Ready(None) + } + + match self.pinger.state { + PingState::Ready => { + // the pinger is ready, send a ping + self.pinger.next_state()?; + Poll::Ready(Some(Ok(PingerEvent::Ping))) + } + PingState::WaitingForPong(_) => { + // the peer has not timed out (yet), send another ping if the pinger does + // not exceed the maximum number of missed pongs + self.pinger.next_state()?; + match self.pinger.state() { + PingState::TimedOut(_) => { + // the pinger has timed out, send a timeout event + Poll::Ready(Some(Ok(PingerEvent::Timeout))) + } + _ => { + // the pinger is still waiting for a pong, send another ping + Poll::Ready(Some(Ok(PingerEvent::Ping))) + } + } + } + PingState::TimedOut(_) => { + self.timeout_sent = true; + Poll::Ready(Some(Ok(PingerEvent::Timeout))) + } + } + } +} + +/// The element type produced by a [`IntervalPingerStream`], representing either a new [`Ping`] +/// message to send, or an indication that the peer should be timed out. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum PingerEvent { + /// A new [`Ping`] message should be sent. + Ping, + + /// The peer should be timed out. + Timeout, +} + +/// A type of [`Pinger`] that uses an interval and a timeout to determine when to send a ping and +/// when to consider the peer timed out. +#[derive(Debug)] +pub(crate) struct IntervalTimeoutPinger { + /// The interval pinger stream. + interval_stream: IntervalStream, + + /// The pinger stream we are using. + pinger_stream: PingerStream, + + /// The timeout duration for each ping. + timeout: Duration, + + /// The Interval that determines when to timeout the peer and send another ping. + sleep: Option, +} + +impl IntervalTimeoutPinger { + /// Creates a new [`IntervalTimeoutPinger`] with the given max missed pongs, interval duration, + /// and timeout duration. + pub(crate) fn new( + max_missed: u8, + interval_duration: Duration, + timeout_duration: Duration, + ) -> Self { + Self { + interval_stream: IntervalStream::new(interval(interval_duration)), + pinger_stream: PingerStream { pinger: Pinger::new(max_missed), timeout_sent: false }, + timeout: timeout_duration, + sleep: None, + } + } + + /// Mark a pong as received, and transition the pinger to the `Ready` state if it was in the + /// `WaitingForPong` state. Unsets the sleep timer. + pub(crate) fn pong_received(&mut self) -> Result<(), PingerError> { + self.interval_stream.as_mut().reset(); + self.pinger_stream.pinger.pong_received()?; + self.sleep = None; + Ok(()) + } + + /// Waits until the pinger sends a timeout event by exhausting the stream. + pub(crate) async fn wait_for_timeout(&mut self) { + while let Some(Ok(PingerEvent::Ping)) = self.next().await {} + } + + /// Returns the current state of the pinger. + pub(crate) fn state(&self) -> &PingState { + self.pinger_stream.pinger.state() + } +} + +impl Stream for IntervalTimeoutPinger { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // if the pinger state is None, we should also return None regardless of the sleep or + // interval state + + // if we have a sleep timer, prefer that over the interval stream + if let Some(inner_sleep) = this.sleep.as_mut() { + // if the sleep is pending, we should return pending (we are waiting for a timeout) + let pinned_sleep = Pin::new(inner_sleep); + ready!(pinned_sleep.poll_next(cx)); + + // let's reset the interval, because the first one returns immediately when created + // using `interval` + let mut interval = interval(this.timeout); + interval.reset(); + + // the sleep has elapsed, create a new sleep for the next timeout interval, then send a + // new ping + this.sleep = Some(IntervalStream::new(interval)); + + Pin::new(&mut this.pinger_stream).poll_next(cx) + } else { + // first poll the interval stream, if it is ready, send a ping + let res = ready!(this.interval_stream.poll_next_unpin(cx)); + if res.is_none() { + // this should never happen (the Stream impl of IntervalStream never is always Some) + return Poll::Ready(None) + } + + let pinned_stream = Pin::new(&mut this.pinger_stream); + let stream_res = ready!(pinned_stream.poll_next(cx)); + + // let's reset the interval, because the first one returns immediately when created + // using `interval` + let mut interval = interval(this.timeout); + interval.reset(); + + this.sleep = Some(IntervalStream::new(interval)); + Poll::Ready(stream_res) + } + } +} + +#[cfg(test)] +mod tests { + use tokio::select; + + use super::*; + + #[test] + fn send_many_pings() { + // tests the simple pinger by sending many pings without pongs + let mut pinger = Pinger::new(3); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::WaitingForPong(0)); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::WaitingForPong(1)); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::WaitingForPong(2)); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::TimedOut(3)); + } + + #[test] + fn send_many_pings_with_pongs() { + // tests the simple pinger by sending many pings with pongs + let mut pinger = Pinger::new(3); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::WaitingForPong(0)); + + pinger.pong_received().unwrap(); + assert_eq!(*pinger.state(), PingState::Ready); + + pinger.next_state().unwrap(); + assert_eq!(*pinger.state(), PingState::WaitingForPong(0)); + + pinger.pong_received().unwrap(); + assert_eq!(*pinger.state(), PingState::Ready); + } + + #[test] + fn send_many_pings_stream() { + let mut pinger_stream = PingerStream { pinger: Pinger::new(3), timeout_sent: false }; + + assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Timeout); + } + + #[tokio::test] + async fn send_many_pings_interval_timeout() { + // we should wait for the interval to elapse, just like the interval-only version + // TODO: should the timeout ever be less than the interval? + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Timeout); + } + + #[tokio::test] + async fn send_many_pings_interval_timeout_with_pongs() { + // we should wait for the interval to elapse and receive a pong before the timeout elapses + + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + + pinger.pong_received().unwrap(); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Timeout); + } + + #[tokio::test] + async fn check_timing_over_interval() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait for the interval to elapse, and compare it to the interval ping + // to avoid flakiness let's do 25? + let sleep = tokio::time::sleep(Duration::from_millis(25)); + let wait_for_timeout = pinger.next(); + + select! { + _ = sleep => panic!("interval should have elapsed"), + _ = wait_for_timeout => {} + } + } + + #[tokio::test] + async fn check_timing_under_interval() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait for the interval to elapse, and compare it to the interval ping + // to avoid flakiness let's do 15? + let sleep = tokio::time::sleep(Duration::from_millis(15)); + let next_ping = pinger.next(); + + select! { + _ = sleep => {} + _ = next_ping => panic!("sleep should have elapsed first") + } + } + + #[tokio::test] + async fn check_timing_before_timeout() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait ~20ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // ensure that a <10ms sleep completes first + let sleep = tokio::time::sleep(Duration::from_millis(5)); + let next_ping = pinger.next(); + + select! { + _ = sleep => {} + _ = next_ping => panic!("sleep should have before re-sending a ping") + } + + // check that we are in the WaitingForPong(0) state (we should not have timed out the first + // ping yet) + let curr_state = *pinger.state(); + assert_eq!(curr_state, PingState::WaitingForPong(0)); + } + + #[tokio::test] + async fn check_timing_after_timeout() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait ~20ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // ensure that the ping completes before a >10ms sleep + let sleep = tokio::time::sleep(Duration::from_millis(15)); + let next_ping = pinger.next(); + + select! { + _ = sleep => panic!("ping retry should have completed before sleep"), + _ = next_ping => {} + } + + // check that we are in the WaitingForPong(1) state (we should have timed out the first + // ping) + let curr_state = *pinger.state(); + assert_eq!(curr_state, PingState::WaitingForPong(1)); + } + + #[tokio::test] + async fn check_timing_after_second_timeout() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait ~20ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // wait another ~10ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // ensure that the ping completes before a >10ms sleep + let sleep = tokio::time::sleep(Duration::from_millis(15)); + let next_ping = pinger.next(); + + select! { + _ = sleep => panic!("ping retry should have completed before sleep"), + _ = next_ping => {} + } + + // check that we are in the WaitingForPong(2) state (we should have timed out the second + // ping) + let curr_state = *pinger.state(); + assert_eq!(curr_state, PingState::WaitingForPong(2)); + } + + #[tokio::test] + async fn check_timing_after_last_timeout() { + // send pongs after a ping event, timing the interval between the two + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + pinger.pong_received().unwrap(); + + // wait ~20ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // wait another ~10ms for the next ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // wait another ~10ms for the last ping + let next_ping = pinger.next().await.unwrap().unwrap(); + assert_eq!(next_ping, PingerEvent::Ping); + + // ensure that the ping completes before a >10ms sleep + let sleep = tokio::time::sleep(Duration::from_millis(15)); + let next_ping = pinger.next(); + + let ping_res = select! { + _ = sleep => panic!("ping retry should have completed before sleep"), + res = next_ping => { + res.expect("stream should not be empty yet") + } + }; + + assert_eq!(ping_res.unwrap(), PingerEvent::Timeout); + + // check that we are in the TimedOut(3) state (we should have timed out after the last ping) + let curr_state = *pinger.state(); + assert_eq!(curr_state, PingState::TimedOut(3)); + } + + #[tokio::test] + async fn timeout_with_pongs() { + // we should wait for the interval to elapse and receive a pong before the timeout elapses + let mut pinger = + IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10)); + + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping); + + pinger.pong_received().unwrap(); + + // let's wait for the timeout to elapse (3 ping timeouts + interval + 10ms for flake + // protection) + let sleep = tokio::time::sleep(Duration::from_millis(60)); + let wait_for_timeout = pinger.wait_for_timeout(); + + select! { + _ = sleep => panic!("timeout should have elapsed by now"), + _ = wait_for_timeout => (), + } + } +} diff --git a/crates/net/eth-wire/src/types/mod.rs b/crates/net/eth-wire/src/types/mod.rs index 83bb844c5b..6e84030022 100644 --- a/crates/net/eth-wire/src/types/mod.rs +++ b/crates/net/eth-wire/src/types/mod.rs @@ -3,7 +3,7 @@ mod status; pub use status::Status; -mod version; +pub mod version; pub use version::EthVersion; pub mod forkid; diff --git a/crates/net/eth-wire/src/types/version.rs b/crates/net/eth-wire/src/types/version.rs index cd29c49d8f..ec99307534 100644 --- a/crates/net/eth-wire/src/types/version.rs +++ b/crates/net/eth-wire/src/types/version.rs @@ -1,6 +1,8 @@ use std::str::FromStr; use thiserror::Error; +use crate::p2pstream::CapabilityMessage; + #[derive(Debug, Clone, PartialEq, Eq, Error)] #[error("Unknown eth protocol version: {0}")] pub struct ParseVersionError(String); @@ -16,6 +18,19 @@ pub enum EthVersion { Eth67 = 67, } +impl EthVersion { + /// Returns the total number of messages the protocol version supports. + pub fn total_messages(&self) -> u8 { + match self { + EthVersion::Eth66 => 15, + EthVersion::Eth67 => { + // eth/67 is eth/66 minus GetNodeData and NodeData messages + 13 + } + } + } +} + /// Allow for converting from a `&str` to an `EthVersion`. /// /// # Example @@ -86,6 +101,13 @@ impl From for &'static str { } } +impl From for CapabilityMessage { + #[inline] + fn from(v: EthVersion) -> CapabilityMessage { + CapabilityMessage { name: String::from("eth"), version: v as usize } + } +} + #[cfg(test)] mod test { use super::{EthVersion, ParseVersionError};