fix(prune): avoid extra iterator consumption (#19758)

Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com>
This commit is contained in:
Forostovec
2025-11-17 17:43:47 +02:00
committed by GitHub
parent adbc68c66c
commit 90621de27c

View File

@@ -19,11 +19,12 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
mut delete_callback: impl FnMut(TableRow<T>),
) -> Result<(usize, bool), DatabaseError> {
let mut cursor = self.cursor_write::<T>()?;
let mut keys = keys.into_iter();
let mut keys = keys.into_iter().peekable();
let mut deleted_entries = 0;
for key in &mut keys {
let mut done = true;
while keys.peek().is_some() {
if limiter.is_limit_reached() {
debug!(
target: "providers::db",
@@ -33,9 +34,11 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
table = %T::NAME,
"Pruning limit reached"
);
done = false;
break
}
let key = keys.next().expect("peek() said Some");
let row = cursor.seek_exact(key)?;
if let Some(row) = row {
cursor.delete_current()?;
@@ -45,7 +48,6 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
}
}
let done = keys.next().is_none();
Ok((deleted_entries, done))
}
@@ -124,3 +126,158 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
}
// Blanket implementation: every type that implements `DbTxMut` automatically
// gets the methods provided by `DbTxPruneExt`.
impl<Tx: DbTxMut> DbTxPruneExt for Tx {}
#[cfg(test)]
mod tests {
use super::DbTxPruneExt;
use crate::PruneLimiter;
use reth_db_api::tables;
use reth_primitives_traits::SignerRecoverable;
use reth_provider::{DBProvider, DatabaseProviderFactory};
use reth_stages::test_utils::{StorageKind, TestStageDB};
use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
/// Test helper: an owned list of keys that, once turned into an iterator,
/// records how many times `next()` is invoked on it.
struct CountingIter {
    /// Keys that the resulting iterator will yield, in order.
    data: Vec<u64>,
    /// Shared counter bumped on every `next()` call of the produced iterator.
    calls: Arc<AtomicUsize>,
}

impl CountingIter {
    /// Bundles the keys with the shared call counter.
    fn new(data: Vec<u64>, calls: Arc<AtomicUsize>) -> Self {
        Self { data, calls }
    }
}

/// Iterator produced by [`CountingIter`]; counts every `next()` call,
/// including the ones that return `None`.
struct CountingIntoIter {
    /// The underlying key iterator.
    inner: std::vec::IntoIter<u64>,
    /// Shared counter of `next()` invocations.
    calls: Arc<AtomicUsize>,
}

impl Iterator for CountingIntoIter {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        // Count the call regardless of whether an item is still available.
        self.calls.fetch_add(1, Ordering::SeqCst);
        self.inner.next()
    }
}

impl IntoIterator for CountingIter {
    type Item = u64;
    type IntoIter = CountingIntoIter;

    fn into_iter(self) -> CountingIntoIter {
        let Self { data, calls } = self;
        CountingIntoIter { inner: data.into_iter(), calls }
    }
}
/// When the deleted-entries limit stops pruning early, the key iterator must
/// not be advanced further than needed: `done` is `false` and the number of
/// `next()` calls on the source iterator is exactly `pruned + 1`.
#[test]
fn prune_table_with_iterator_early_exit_does_not_overconsume() {
    let db = TestStageDB::default();
    let mut rng = generators::rng();

    // Blocks 1..=3 with `tx_count: 2..3` (presumably two transactions per block).
    let params = BlockRangeParams {
        parent: Some(alloy_primitives::B256::ZERO),
        tx_count: 2..3,
        ..Default::default()
    };
    let blocks = random_block_range(&mut rng, 1..=3, params);
    db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks");

    // One `(tx_number, sender)` row per transaction, numbered sequentially from 0.
    let sender_rows: Vec<_> = blocks
        .iter()
        .flat_map(|block| block.body().transactions.iter())
        .enumerate()
        .map(|(tx_number, transaction)| {
            (tx_number as u64, transaction.recover_signer().expect("recover signer"))
        })
        .collect();
    let total = sender_rows.len();
    db.insert_transaction_senders(sender_rows).expect("insert transaction senders");

    let provider = db.factory.database_provider_rw().unwrap();

    // Keys 0..total, wrapped so every `next()` on the source iterator is counted.
    let next_calls = Arc::new(AtomicUsize::new(0));
    let counted_keys = CountingIter::new((0..total as u64).collect(), Arc::clone(&next_calls));

    // Allow only two deletions so that pruning has to stop before the keys run out.
    let mut limiter = PruneLimiter::default().set_deleted_entries_limit(2);

    let (pruned, done) = provider
        .tx_ref()
        .prune_table_with_iterator::<tables::TransactionSenders>(
            counted_keys,
            &mut limiter,
            |_| {},
        )
        .expect("prune");

    assert_eq!(pruned, 2);
    assert!(!done);
    // Exactly one `next()` call beyond the entries that were actually pruned.
    assert_eq!(next_calls.load(Ordering::SeqCst), pruned + 1);

    provider.commit().expect("commit");
    assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), total - 2);
}
/// With an effectively unlimited deleted-entries limit, every key is pruned,
/// `done` is `true`, and the source iterator sees `total + 1` `next()` calls
/// (one per key plus the final call that returns `None`).
#[test]
fn prune_table_with_iterator_consumes_to_end_reports_done() {
    let db = TestStageDB::default();
    let mut rng = generators::rng();

    // Blocks 1..=2 with `tx_count: 1..2` (presumably one transaction per block).
    let params = BlockRangeParams {
        parent: Some(alloy_primitives::B256::ZERO),
        tx_count: 1..2,
        ..Default::default()
    };
    let blocks = random_block_range(&mut rng, 1..=2, params);
    db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks");

    // One `(tx_number, sender)` row per transaction, numbered sequentially from 0.
    let sender_rows: Vec<_> = blocks
        .iter()
        .flat_map(|block| block.body().transactions.iter())
        .enumerate()
        .map(|(tx_number, transaction)| {
            (tx_number as u64, transaction.recover_signer().expect("recover signer"))
        })
        .collect();
    let total = sender_rows.len();
    db.insert_transaction_senders(sender_rows).expect("insert transaction senders");

    let provider = db.factory.database_provider_rw().unwrap();

    // Keys 0..total, wrapped so every `next()` on the source iterator is counted.
    let next_calls = Arc::new(AtomicUsize::new(0));
    let counted_keys = CountingIter::new((0..total as u64).collect(), Arc::clone(&next_calls));

    // A limit of `usize::MAX` so the limiter never cuts the run short.
    let mut limiter = PruneLimiter::default().set_deleted_entries_limit(usize::MAX);

    let (pruned, done) = provider
        .tx_ref()
        .prune_table_with_iterator::<tables::TransactionSenders>(
            counted_keys,
            &mut limiter,
            |_| {},
        )
        .expect("prune");

    assert_eq!(pruned, total);
    assert!(done);
    assert_eq!(next_calls.load(Ordering::SeqCst), total + 1);

    provider.commit().expect("commit");
    assert_eq!(db.table::<tables::TransactionSenders>().unwrap().len(), 0);
}
}