diff --git a/crates/transaction-pool/src/identifier.rs b/crates/transaction-pool/src/identifier.rs index 876fa00dca..37a3450e06 100644 --- a/crates/transaction-pool/src/identifier.rs +++ b/crates/transaction-pool/src/identifier.rs @@ -5,15 +5,15 @@ use std::collections::HashMap; /// An internal mapping of addresses. /// -/// This assigns a _unique_ `SenderId` for a new `Address`. +/// This assigns a _unique_ [`SenderId`] for a new [`Address`]. /// It has capacity for 2^64 unique addresses. #[derive(Debug, Default)] pub struct SenderIdentifiers { /// The identifier to use next. id: u64, - /// Assigned `SenderId` for an `Address`. + /// Assigned [`SenderId`] for an [`Address`]. address_to_id: HashMap, - /// Reverse mapping of `SenderId` to `Address`. + /// Reverse mapping of [`SenderId`] to [`Address`]. sender_to_address: FxHashMap, } @@ -24,12 +24,12 @@ impl SenderIdentifiers { self.sender_to_address.get(id) } - /// Returns the `SenderId` that belongs to the given address, if it exists + /// Returns the [`SenderId`] that belongs to the given address, if it exists pub fn sender_id(&self, addr: &Address) -> Option { self.address_to_id.get(addr).copied() } - /// Returns the existing `SendId` or assigns a new one if it's missing + /// Returns the existing [`SenderId`] or assigns a new one if it's missing pub fn sender_id_or_create(&mut self, addr: Address) -> SenderId { self.sender_id(&addr).unwrap_or_else(|| { let id = self.next_id(); @@ -39,11 +39,11 @@ impl SenderIdentifiers { }) } - /// Returns a new address + /// Returns the current identifier and increments the counter. fn next_id(&mut self) -> SenderId { let id = self.id; self.id = self.id.wrapping_add(1); - SenderId(id) + id.into() } } @@ -54,10 +54,8 @@ impl SenderIdentifiers { #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct SenderId(u64); -// === impl SenderId === - impl SenderId { - /// Returns a `Bound` for `TransactionId` starting with nonce `0` + /// Returns a `Bound` for [`TransactionId`] starting with nonce `0` pub const fn start_bound(self) -> std::ops::Bound { std::ops::Bound::Included(TransactionId::new(self, 0)) } @@ -81,34 +79,29 @@ pub struct TransactionId { pub nonce: u64, } -// === impl TransactionId === - impl TransactionId { /// Create a new identifier pair pub const fn new(sender: SenderId, nonce: u64) -> Self { Self { sender, nonce } } - /// Returns the `TransactionId` this transaction depends on. + /// Returns the [`TransactionId`] this transaction depends on. /// /// This returns `transaction_nonce - 1` if `transaction_nonce` is higher than the /// `on_chain_nonce` pub fn ancestor(transaction_nonce: u64, on_chain_nonce: u64, sender: SenderId) -> Option { - if transaction_nonce == on_chain_nonce { - return None - } - let prev_nonce = transaction_nonce.saturating_sub(1); - (on_chain_nonce <= prev_nonce).then(|| Self::new(sender, prev_nonce)) + (transaction_nonce > on_chain_nonce) + .then(|| Self::new(sender, transaction_nonce.saturating_sub(1))) } - /// Returns the `TransactionId` that would come before this transaction. + /// Returns the [`TransactionId`] that would come before this transaction. pub fn unchecked_ancestor(&self) -> Option { (self.nonce != 0).then(|| Self::new(self.sender, self.nonce - 1)) } - /// Returns the `TransactionId` that directly follows this transaction: `self.nonce + 1` + /// Returns the [`TransactionId`] that directly follows this transaction: `self.nonce + 1` pub const fn descendant(&self) -> Self { - Self::new(self.sender, self.nonce + 1) + Self::new(self.sender, self.next_nonce()) } /// Returns the nonce that follows immediately after this one. @@ -123,6 +116,67 @@ mod tests { use super::*; use std::collections::BTreeSet; + #[test] + fn test_transaction_id_new() { + let sender = SenderId(1); + let tx_id = TransactionId::new(sender, 5); + assert_eq!(tx_id.sender, sender); + assert_eq!(tx_id.nonce, 5); + } + + #[test] + fn test_transaction_id_ancestor() { + let sender = SenderId(1); + + // Special case with nonce 0 and higher on-chain nonce + let tx_id = TransactionId::ancestor(0, 1, sender); + assert_eq!(tx_id, None); + + // Special case with nonce 0 and same on-chain nonce + let tx_id = TransactionId::ancestor(0, 0, sender); + assert_eq!(tx_id, None); + + // Ancestor is the previous nonce if the transaction nonce is higher than the on-chain nonce + let tx_id = TransactionId::ancestor(5, 0, sender); + assert_eq!(tx_id, Some(TransactionId::new(sender, 4))); + + // No ancestor if the transaction nonce is the same as the on-chain nonce + let tx_id = TransactionId::ancestor(5, 5, sender); + assert_eq!(tx_id, None); + + // No ancestor if the transaction nonce is lower than the on-chain nonce + let tx_id = TransactionId::ancestor(5, 15, sender); + assert_eq!(tx_id, None); + } + + #[test] + fn test_transaction_id_unchecked_ancestor() { + let sender = SenderId(1); + + // Ancestor is the previous nonce if transaction nonce is higher than 0 + let tx_id = TransactionId::new(sender, 5); + assert_eq!(tx_id.unchecked_ancestor(), Some(TransactionId::new(sender, 4))); + + // No ancestor if transaction nonce is 0 + let tx_id = TransactionId::new(sender, 0); + assert_eq!(tx_id.unchecked_ancestor(), None); + } + + #[test] + fn test_transaction_id_descendant() { + let sender = SenderId(1); + let tx_id = TransactionId::new(sender, 5); + let descendant = tx_id.descendant(); + assert_eq!(descendant, TransactionId::new(sender, 6)); + } + + #[test] + fn test_transaction_id_next_nonce() { + let sender = SenderId(1); + let tx_id = TransactionId::new(sender, 5); + assert_eq!(tx_id.next_nonce(), 6); + } + #[test] fn test_transaction_id_ord_eq_sender() { let tx1 = TransactionId::new(100u64.into(), 0u64); @@ -140,4 +194,67 @@ mod tests { let set = BTreeSet::from([tx1, tx2]); assert_eq!(set.into_iter().collect::>(), vec![tx1, tx2]); } + + #[test] + fn test_address_retrieval() { + let mut identifiers = SenderIdentifiers::default(); + let address = Address::new([1; 20]); + let id = identifiers.sender_id_or_create(address); + assert_eq!(identifiers.address(&id), Some(&address)); + } + + #[test] + fn test_sender_id_retrieval() { + let mut identifiers = SenderIdentifiers::default(); + let address = Address::new([1; 20]); + let id = identifiers.sender_id_or_create(address); + assert_eq!(identifiers.sender_id(&address), Some(id)); + } + + #[test] + fn test_sender_id_or_create_existing() { + let mut identifiers = SenderIdentifiers::default(); + let address = Address::new([1; 20]); + let id1 = identifiers.sender_id_or_create(address); + let id2 = identifiers.sender_id_or_create(address); + assert_eq!(id1, id2); + } + + #[test] + fn test_sender_id_or_create_new() { + let mut identifiers = SenderIdentifiers::default(); + let address1 = Address::new([1; 20]); + let address2 = Address::new([2; 20]); + let id1 = identifiers.sender_id_or_create(address1); + let id2 = identifiers.sender_id_or_create(address2); + assert_ne!(id1, id2); + } + + #[test] + fn test_next_id_wrapping() { + let mut identifiers = SenderIdentifiers { id: u64::MAX, ..Default::default() }; + + // The current ID is `u64::MAX`, the next ID should wrap around to 0. + let id1 = identifiers.next_id(); + assert_eq!(id1, SenderId(u64::MAX)); + + // The next ID should now be 0 because of wrapping. + let id2 = identifiers.next_id(); + assert_eq!(id2, SenderId(0)); + + // And then 1, continuing incrementing. + let id3 = identifiers.next_id(); + assert_eq!(id3, SenderId(1)); + } + + #[test] + fn test_sender_id_start_bound() { + let sender = SenderId(1); + let start_bound = sender.start_bound(); + if let std::ops::Bound::Included(tx_id) = start_bound { + assert_eq!(tx_id, TransactionId::new(sender, 0)); + } else { + panic!("Expected included bound"); + } + } } diff --git a/crates/transaction-pool/src/pool/mod.rs b/crates/transaction-pool/src/pool/mod.rs index a00d6522d1..1878242eee 100644 --- a/crates/transaction-pool/src/pool/mod.rs +++ b/crates/transaction-pool/src/pool/mod.rs @@ -194,7 +194,7 @@ where self.pool.write().set_block_info(info) } - /// Returns the internal `SenderId` for this address + /// Returns the internal [`SenderId`] for this address pub(crate) fn get_sender_id(&self, addr: Address) -> SenderId { self.identifiers.write().sender_id_or_create(addr) }