diff --git a/crates/net/eth-wire/src/types/broadcast.rs b/crates/net/eth-wire/src/types/broadcast.rs index a789b81362..b46248e3ee 100644 --- a/crates/net/eth-wire/src/types/broadcast.rs +++ b/crates/net/eth-wire/src/types/broadcast.rs @@ -11,7 +11,11 @@ use reth_primitives::{ Block, Bytes, PooledTransactionsElement, TransactionSigned, TxHash, B256, U128, }; -use std::{collections::HashMap, mem, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + mem, + sync::Arc, +}; #[cfg(feature = "arbitrary")] use proptest::{collection::vec, prelude::*}; @@ -653,7 +657,7 @@ impl ValidAnnouncementData { /// Destructs returning only the valid hashes and the announcement message version. Caution! If /// this is [`Eth68`](EthVersion::Eth68) announcement data, this drops the metadata. pub fn into_request_hashes(self) -> (RequestTxHashes, EthVersion) { - let hashes = self.data.into_keys().collect::>(); + let hashes = self.data.into_keys().collect::>(); (RequestTxHashes::new(hashes), self.version) } @@ -687,30 +691,49 @@ pub struct RequestTxHashes { #[deref] #[deref_mut] #[into_iterator(owned, ref)] - hashes: Vec, + hashes: HashSet, } impl RequestTxHashes { /// Returns a new [`RequestTxHashes`] with given capacity for hashes. Caution! Make sure to - /// call [`Vec::shrink_to_fit`] on [`RequestTxHashes`] when full, especially where it will be - /// stored in its entirety like in the future waiting for a + /// call [`HashSet::shrink_to_fit`] on [`RequestTxHashes`] when full, especially where it will + /// be stored in its entirety like in the future waiting for a /// [`GetPooledTransactions`](crate::GetPooledTransactions) request to resolve. pub fn with_capacity(capacity: usize) -> Self { - Self::new(Vec::with_capacity(capacity)) + Self::new(HashSet::with_capacity(capacity)) + } + + /// Returns an new empty instance. + fn empty() -> Self { + Self::new(HashSet::new()) + } + + /// Retains the given number of elements, returning and iterator over the rest. + pub fn retain_count(&mut self, count: usize) -> Self { + let rest_capacity = self.hashes.len().saturating_sub(count); + if rest_capacity == 0 { + return Self::empty() + } + let mut rest = Self::with_capacity(rest_capacity); + + let mut i = 0; + self.hashes.retain(|hash| { + if i >= count { + rest.insert(*hash); + return false + } + i += 1; + + true + }); + + rest } } impl FromIterator<(TxHash, Eth68TxMetadata)> for RequestTxHashes { fn from_iter>(iter: I) -> Self { - let mut hashes = Vec::with_capacity(32); - - for (hash, _) in iter { - hashes.push(hash); - } - - hashes.shrink_to_fit(); - - RequestTxHashes::new(hashes) + RequestTxHashes::new(iter.into_iter().map(|(hash, _)| hash).collect::>()) } } @@ -718,7 +741,7 @@ impl FromIterator<(TxHash, Eth68TxMetadata)> for RequestTxHashes { mod tests { use super::*; use bytes::BytesMut; - use reth_primitives::hex; + use reth_primitives::{b256, hex}; use std::str::FromStr; /// Takes as input a struct / encoded hex message pair, ensuring that we encode to the exact hex @@ -868,4 +891,63 @@ mod tests { test_encoding_vector(vector); } } + + #[test] + fn request_hashes_retain_count_keep_subset() { + let mut hashes = RequestTxHashes::new( + [ + b256!("0000000000000000000000000000000000000000000000000000000000000001"), + b256!("0000000000000000000000000000000000000000000000000000000000000002"), + b256!("0000000000000000000000000000000000000000000000000000000000000003"), + b256!("0000000000000000000000000000000000000000000000000000000000000004"), + b256!("0000000000000000000000000000000000000000000000000000000000000005"), + ] + .into_iter() + .collect::>(), + ); + + let rest = hashes.retain_count(3); + + assert_eq!(3, hashes.len()); + assert_eq!(2, rest.len()); + } + + #[test] + fn request_hashes_retain_count_keep_all() { + let mut hashes = RequestTxHashes::new( + [ + b256!("0000000000000000000000000000000000000000000000000000000000000001"), + b256!("0000000000000000000000000000000000000000000000000000000000000002"), + b256!("0000000000000000000000000000000000000000000000000000000000000003"), + b256!("0000000000000000000000000000000000000000000000000000000000000004"), + b256!("0000000000000000000000000000000000000000000000000000000000000005"), + ] + .into_iter() + .collect::>(), + ); + + let _ = hashes.retain_count(6); + + assert_eq!(5, hashes.len()); + } + + #[test] + fn split_request_hashes_keep_none() { + let mut hashes = RequestTxHashes::new( + [ + b256!("0000000000000000000000000000000000000000000000000000000000000001"), + b256!("0000000000000000000000000000000000000000000000000000000000000002"), + b256!("0000000000000000000000000000000000000000000000000000000000000003"), + b256!("0000000000000000000000000000000000000000000000000000000000000004"), + b256!("0000000000000000000000000000000000000000000000000000000000000005"), + ] + .into_iter() + .collect::>(), + ); + + let rest = hashes.retain_count(0); + + assert_eq!(0, hashes.len()); + assert_eq!(5, rest.len()); + } } diff --git a/crates/net/network/src/transactions/fetcher.rs b/crates/net/network/src/transactions/fetcher.rs index bffa6fd29d..403c8b03ee 100644 --- a/crates/net/network/src/transactions/fetcher.rs +++ b/crates/net/network/src/transactions/fetcher.rs @@ -149,7 +149,7 @@ impl TransactionFetcher { let idle_peer = self.get_idle_peer_for(hash, &is_session_active); if idle_peer.is_some() { - hashes_to_request.push(hash); + hashes_to_request.insert(hash); break idle_peer.copied() } @@ -160,7 +160,7 @@ impl TransactionFetcher { } } }; - let hash = hashes_to_request.first()?; + let hash = hashes_to_request.iter().next()?; // pop hash that is loaded in request buffer from cache of hashes pending fetch drop(hashes_pending_fetch_iter); @@ -206,7 +206,7 @@ impl TransactionFetcher { let mut hashes_from_announcement_iter = hashes_from_announcement.into_iter(); if let Some((hash, Some((_ty, size)))) = hashes_from_announcement_iter.next() { - hashes_to_request.push(hash); + hashes_to_request.insert(hash); // tx is really big, pack request with single tx if size >= self.info.soft_limit_byte_size_pooled_transactions_response_on_pack_request { @@ -235,9 +235,9 @@ impl TransactionFetcher { // only update accumulated size of tx response if tx will fit in without exceeding // soft limit acc_size_response = next_acc_size; - hashes_to_request.push(hash) + _ = hashes_to_request.insert(hash) } else { - surplus_hashes.push(hash) + _ = surplus_hashes.insert(hash) } let free_space = @@ -275,11 +275,11 @@ impl TransactionFetcher { RequestTxHashes::default() } else { let surplus_hashes = - hashes.split_off(SOFT_LIMIT_COUNT_HASHES_IN_GET_POOLED_TRANSACTIONS_REQUEST - 1); + hashes.retain_count(SOFT_LIMIT_COUNT_HASHES_IN_GET_POOLED_TRANSACTIONS_REQUEST); *hashes_to_request = hashes; hashes_to_request.shrink_to_fit(); - RequestTxHashes::new(surplus_hashes) + surplus_hashes } } @@ -604,7 +604,9 @@ impl TransactionFetcher { let (response, rx) = oneshot::channel(); let req: PeerRequest = PeerRequest::GetPooledTransactions { - request: GetPooledTransactions(new_announced_hashes.clone()), + request: GetPooledTransactions( + new_announced_hashes.iter().copied().collect::>(), + ), response, }; @@ -612,12 +614,9 @@ impl TransactionFetcher { if let Err(err) = peer.request_tx.try_send(req) { // peer channel is full match err { - TrySendError::Full(req) | TrySendError::Closed(req) => { - // need to do some cleanup so - let req = req.into_get_pooled_transactions().expect("is get pooled tx"); - + TrySendError::Full(_) | TrySendError::Closed(_) => { metrics_increment_egress_peer_channel_full(); - return Some(RequestTxHashes::new(req.0)) + return Some(new_announced_hashes) } } } else { @@ -657,7 +656,7 @@ impl TransactionFetcher { seen_hashes: &LruCache, mut budget_fill_request: Option, // check max `budget` lru pending hashes ) { - let Some(hash) = hashes_to_request.first() else { return }; + let Some(hash) = hashes_to_request.iter().next() else { return }; let mut acc_size_response = self .hashes_fetch_inflight_and_pending_fetch @@ -681,7 +680,7 @@ impl TransactionFetcher { }; // 2. Optimistically include the hash in the request. - hashes_to_request.push(*hash); + hashes_to_request.insert(*hash); // 3. Accumulate expected total response size. let size = self @@ -1141,7 +1140,7 @@ impl DedupPayload for VerifiedPooledTransactions { trait VerifyPooledTransactionsResponse { fn verify( self, - requested_hashes: &[TxHash], + requested_hashes: &RequestTxHashes, peer_id: &PeerId, ) -> (VerificationOutcome, VerifiedPooledTransactions); } @@ -1149,7 +1148,7 @@ trait VerifyPooledTransactionsResponse { impl VerifyPooledTransactionsResponse for UnverifiedPooledTransactions { fn verify( self, - requested_hashes: &[TxHash], + requested_hashes: &RequestTxHashes, _peer_id: &PeerId, ) -> (VerificationOutcome, VerifiedPooledTransactions) { let mut verification_outcome = VerificationOutcome::Ok; @@ -1291,9 +1290,11 @@ mod test { 0, ]; - let expected_request_hashes = [eth68_hashes[0], eth68_hashes[2]]; + let expected_request_hashes = + [eth68_hashes[0], eth68_hashes[2]].into_iter().collect::>(); - let expected_surplus_hashes = [eth68_hashes[1], eth68_hashes[3], eth68_hashes[4]]; + let expected_surplus_hashes = + [eth68_hashes[1], eth68_hashes[3], eth68_hashes[4]].into_iter().collect::>(); let mut eth68_hashes_to_request = RequestTxHashes::with_capacity(3); @@ -1310,11 +1311,11 @@ mod test { let surplus_eth68_hashes = tx_fetcher.pack_request_eth68(&mut eth68_hashes_to_request, valid_announcement_data); - let eth68_hashes_to_request = eth68_hashes_to_request.into_iter().collect::>(); - let surplus_eth68_hashes = surplus_eth68_hashes.into_iter().collect::>(); + let eth68_hashes_to_request = eth68_hashes_to_request.into_iter().collect::>(); + let surplus_eth68_hashes = surplus_eth68_hashes.into_iter().collect::>(); - assert_eq!(expected_request_hashes.to_vec(), eth68_hashes_to_request); - assert_eq!(expected_surplus_hashes.to_vec(), surplus_eth68_hashes); + assert_eq!(expected_request_hashes, eth68_hashes_to_request); + assert_eq!(expected_surplus_hashes, surplus_eth68_hashes); } #[tokio::test] @@ -1437,7 +1438,8 @@ mod test { assert_ne!(hash, signed_tx_2.hash()) } - let request_hashes = RequestTxHashes::new(request_hashes.clone().to_vec()); + let request_hashes = + RequestTxHashes::new(request_hashes.into_iter().collect::>()); // but response contains tx 1 + another tx let response_txns = PooledTransactions(vec![signed_tx_1.clone(), signed_tx_2.clone()]); diff --git a/crates/net/network/src/transactions/mod.rs b/crates/net/network/src/transactions/mod.rs index a69aac19fe..8efcb24d19 100644 --- a/crates/net/network/src/transactions/mod.rs +++ b/crates/net/network/src/transactions/mod.rs @@ -2071,7 +2071,9 @@ mod tests { let PeerRequest::GetPooledTransactions { request, response } = req else { unreachable!() }; let GetPooledTransactions(hashes) = request; - assert_eq!(hashes, seen_hashes); + let hashes = hashes.into_iter().collect::>(); + + assert_eq!(hashes, seen_hashes.into_iter().collect::>()); // fail request to peer_1 response