From f37ef730ce1f0ced2691e53aecf895191dab9438 Mon Sep 17 00:00:00 2001
From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com>
Date: Sun, 28 Jul 2024 12:27:58 +0200
Subject: [PATCH] tx-pool: add unit tests and refactor identifiers (#9855)
---
crates/transaction-pool/src/identifier.rs | 159 +++++++++++++++++++---
crates/transaction-pool/src/pool/mod.rs | 2 +-
2 files changed, 139 insertions(+), 22 deletions(-)
diff --git a/crates/transaction-pool/src/identifier.rs b/crates/transaction-pool/src/identifier.rs
index 876fa00dca..37a3450e06 100644
--- a/crates/transaction-pool/src/identifier.rs
+++ b/crates/transaction-pool/src/identifier.rs
@@ -5,15 +5,15 @@ use std::collections::HashMap;
/// An internal mapping of addresses.
///
-/// This assigns a _unique_ `SenderId` for a new `Address`.
+/// This assigns a _unique_ [`SenderId`] for a new [`Address`].
/// It has capacity for 2^64 unique addresses.
#[derive(Debug, Default)]
pub struct SenderIdentifiers {
/// The identifier to use next.
id: u64,
- /// Assigned `SenderId` for an `Address`.
+ /// Assigned [`SenderId`] for an [`Address`].
address_to_id: HashMap
,
- /// Reverse mapping of `SenderId` to `Address`.
+ /// Reverse mapping of [`SenderId`] to [`Address`].
sender_to_address: FxHashMap,
}
@@ -24,12 +24,12 @@ impl SenderIdentifiers {
self.sender_to_address.get(id)
}
- /// Returns the `SenderId` that belongs to the given address, if it exists
+ /// Returns the [`SenderId`] that belongs to the given address, if it exists
pub fn sender_id(&self, addr: &Address) -> Option {
self.address_to_id.get(addr).copied()
}
- /// Returns the existing `SendId` or assigns a new one if it's missing
+ /// Returns the existing [`SenderId`] or assigns a new one if it's missing
pub fn sender_id_or_create(&mut self, addr: Address) -> SenderId {
self.sender_id(&addr).unwrap_or_else(|| {
let id = self.next_id();
@@ -39,11 +39,11 @@ impl SenderIdentifiers {
})
}
- /// Returns a new address
+ /// Returns the current identifier and increments the counter.
fn next_id(&mut self) -> SenderId {
let id = self.id;
self.id = self.id.wrapping_add(1);
- SenderId(id)
+ id.into()
}
}
@@ -54,10 +54,8 @@ impl SenderIdentifiers {
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SenderId(u64);
-// === impl SenderId ===
-
impl SenderId {
- /// Returns a `Bound` for `TransactionId` starting with nonce `0`
+ /// Returns a `Bound` for [`TransactionId`] starting with nonce `0`
pub const fn start_bound(self) -> std::ops::Bound {
std::ops::Bound::Included(TransactionId::new(self, 0))
}
@@ -81,34 +79,29 @@ pub struct TransactionId {
pub nonce: u64,
}
-// === impl TransactionId ===
-
impl TransactionId {
/// Create a new identifier pair
pub const fn new(sender: SenderId, nonce: u64) -> Self {
Self { sender, nonce }
}
- /// Returns the `TransactionId` this transaction depends on.
+ /// Returns the [`TransactionId`] this transaction depends on.
///
/// This returns `transaction_nonce - 1` if `transaction_nonce` is higher than the
/// `on_chain_nonce`
pub fn ancestor(transaction_nonce: u64, on_chain_nonce: u64, sender: SenderId) -> Option {
- if transaction_nonce == on_chain_nonce {
- return None
- }
- let prev_nonce = transaction_nonce.saturating_sub(1);
- (on_chain_nonce <= prev_nonce).then(|| Self::new(sender, prev_nonce))
+ (transaction_nonce > on_chain_nonce)
+ .then(|| Self::new(sender, transaction_nonce.saturating_sub(1)))
}
- /// Returns the `TransactionId` that would come before this transaction.
+ /// Returns the [`TransactionId`] that would come before this transaction.
pub fn unchecked_ancestor(&self) -> Option {
(self.nonce != 0).then(|| Self::new(self.sender, self.nonce - 1))
}
- /// Returns the `TransactionId` that directly follows this transaction: `self.nonce + 1`
+ /// Returns the [`TransactionId`] that directly follows this transaction: `self.nonce + 1`
pub const fn descendant(&self) -> Self {
- Self::new(self.sender, self.nonce + 1)
+ Self::new(self.sender, self.next_nonce())
}
/// Returns the nonce that follows immediately after this one.
@@ -123,6 +116,67 @@ mod tests {
use super::*;
use std::collections::BTreeSet;
+ #[test]
+ fn test_transaction_id_new() {
+ let sender = SenderId(1);
+ let tx_id = TransactionId::new(sender, 5);
+ assert_eq!(tx_id.sender, sender);
+ assert_eq!(tx_id.nonce, 5);
+ }
+
+ #[test]
+ fn test_transaction_id_ancestor() {
+ let sender = SenderId(1);
+
+ // Special case with nonce 0 and higher on-chain nonce
+ let tx_id = TransactionId::ancestor(0, 1, sender);
+ assert_eq!(tx_id, None);
+
+ // Special case with nonce 0 and same on-chain nonce
+ let tx_id = TransactionId::ancestor(0, 0, sender);
+ assert_eq!(tx_id, None);
+
+ // Ancestor is the previous nonce if the transaction nonce is higher than the on-chain nonce
+ let tx_id = TransactionId::ancestor(5, 0, sender);
+ assert_eq!(tx_id, Some(TransactionId::new(sender, 4)));
+
+ // No ancestor if the transaction nonce is the same as the on-chain nonce
+ let tx_id = TransactionId::ancestor(5, 5, sender);
+ assert_eq!(tx_id, None);
+
+ // No ancestor if the transaction nonce is lower than the on-chain nonce
+ let tx_id = TransactionId::ancestor(5, 15, sender);
+ assert_eq!(tx_id, None);
+ }
+
+ #[test]
+ fn test_transaction_id_unchecked_ancestor() {
+ let sender = SenderId(1);
+
+ // Ancestor is the previous nonce if transaction nonce is higher than 0
+ let tx_id = TransactionId::new(sender, 5);
+ assert_eq!(tx_id.unchecked_ancestor(), Some(TransactionId::new(sender, 4)));
+
+ // No ancestor if transaction nonce is 0
+ let tx_id = TransactionId::new(sender, 0);
+ assert_eq!(tx_id.unchecked_ancestor(), None);
+ }
+
+ #[test]
+ fn test_transaction_id_descendant() {
+ let sender = SenderId(1);
+ let tx_id = TransactionId::new(sender, 5);
+ let descendant = tx_id.descendant();
+ assert_eq!(descendant, TransactionId::new(sender, 6));
+ }
+
+ #[test]
+ fn test_transaction_id_next_nonce() {
+ let sender = SenderId(1);
+ let tx_id = TransactionId::new(sender, 5);
+ assert_eq!(tx_id.next_nonce(), 6);
+ }
+
#[test]
fn test_transaction_id_ord_eq_sender() {
let tx1 = TransactionId::new(100u64.into(), 0u64);
@@ -140,4 +194,67 @@ mod tests {
let set = BTreeSet::from([tx1, tx2]);
assert_eq!(set.into_iter().collect::>(), vec![tx1, tx2]);
}
+
+ #[test]
+ fn test_address_retrieval() {
+ let mut identifiers = SenderIdentifiers::default();
+ let address = Address::new([1; 20]);
+ let id = identifiers.sender_id_or_create(address);
+ assert_eq!(identifiers.address(&id), Some(&address));
+ }
+
+ #[test]
+ fn test_sender_id_retrieval() {
+ let mut identifiers = SenderIdentifiers::default();
+ let address = Address::new([1; 20]);
+ let id = identifiers.sender_id_or_create(address);
+ assert_eq!(identifiers.sender_id(&address), Some(id));
+ }
+
+ #[test]
+ fn test_sender_id_or_create_existing() {
+ let mut identifiers = SenderIdentifiers::default();
+ let address = Address::new([1; 20]);
+ let id1 = identifiers.sender_id_or_create(address);
+ let id2 = identifiers.sender_id_or_create(address);
+ assert_eq!(id1, id2);
+ }
+
+ #[test]
+ fn test_sender_id_or_create_new() {
+ let mut identifiers = SenderIdentifiers::default();
+ let address1 = Address::new([1; 20]);
+ let address2 = Address::new([2; 20]);
+ let id1 = identifiers.sender_id_or_create(address1);
+ let id2 = identifiers.sender_id_or_create(address2);
+ assert_ne!(id1, id2);
+ }
+
+ #[test]
+ fn test_next_id_wrapping() {
+ let mut identifiers = SenderIdentifiers { id: u64::MAX, ..Default::default() };
+
+ // The current ID is `u64::MAX`, the next ID should wrap around to 0.
+ let id1 = identifiers.next_id();
+ assert_eq!(id1, SenderId(u64::MAX));
+
+ // The next ID should now be 0 because of wrapping.
+ let id2 = identifiers.next_id();
+ assert_eq!(id2, SenderId(0));
+
+ // And then 1, continuing incrementing.
+ let id3 = identifiers.next_id();
+ assert_eq!(id3, SenderId(1));
+ }
+
+ #[test]
+ fn test_sender_id_start_bound() {
+ let sender = SenderId(1);
+ let start_bound = sender.start_bound();
+ if let std::ops::Bound::Included(tx_id) = start_bound {
+ assert_eq!(tx_id, TransactionId::new(sender, 0));
+ } else {
+ panic!("Expected included bound");
+ }
+ }
}
diff --git a/crates/transaction-pool/src/pool/mod.rs b/crates/transaction-pool/src/pool/mod.rs
index a00d6522d1..1878242eee 100644
--- a/crates/transaction-pool/src/pool/mod.rs
+++ b/crates/transaction-pool/src/pool/mod.rs
@@ -194,7 +194,7 @@ where
self.pool.write().set_block_info(info)
}
- /// Returns the internal `SenderId` for this address
+ /// Returns the internal [`SenderId`] for this address
pub(crate) fn get_sender_id(&self, addr: Address) -> SenderId {
self.identifiers.write().sender_id_or_create(addr)
}