From 90621de27cb413c09f641af676c5dc47672bfa22 Mon Sep 17 00:00:00 2001 From: Forostovec Date: Mon, 17 Nov 2025 17:43:47 +0200 Subject: [PATCH] fix(prune): avoid extra iterator consumption (#19758) Co-authored-by: Alexey Shekhirin <5773434+shekhirin@users.noreply.github.com> --- crates/prune/prune/src/db_ext.rs | 163 ++++++++++++++++++++++++++++++- 1 file changed, 160 insertions(+), 3 deletions(-) diff --git a/crates/prune/prune/src/db_ext.rs b/crates/prune/prune/src/db_ext.rs index 63ab87c446..09224a795f 100644 --- a/crates/prune/prune/src/db_ext.rs +++ b/crates/prune/prune/src/db_ext.rs @@ -19,11 +19,12 @@ pub(crate) trait DbTxPruneExt: DbTxMut { mut delete_callback: impl FnMut(TableRow), ) -> Result<(usize, bool), DatabaseError> { let mut cursor = self.cursor_write::()?; - let mut keys = keys.into_iter(); + let mut keys = keys.into_iter().peekable(); let mut deleted_entries = 0; - for key in &mut keys { + let mut done = true; + while keys.peek().is_some() { if limiter.is_limit_reached() { debug!( target: "providers::db", @@ -33,9 +34,11 @@ pub(crate) trait DbTxPruneExt: DbTxMut { table = %T::NAME, "Pruning limit reached" ); + done = false; break } + let key = keys.next().expect("peek() said Some"); let row = cursor.seek_exact(key)?; if let Some(row) = row { cursor.delete_current()?; @@ -45,7 +48,6 @@ pub(crate) trait DbTxPruneExt: DbTxMut { } } - let done = keys.next().is_none(); Ok((deleted_entries, done)) } @@ -124,3 +126,158 @@ pub(crate) trait DbTxPruneExt: DbTxMut { } impl DbTxPruneExt for Tx where Tx: DbTxMut {} + +#[cfg(test)] +mod tests { + use super::DbTxPruneExt; + use crate::PruneLimiter; + use reth_db_api::tables; + use reth_primitives_traits::SignerRecoverable; + use reth_provider::{DBProvider, DatabaseProviderFactory}; + use reth_stages::test_utils::{StorageKind, TestStageDB}; + use reth_testing_utils::generators::{self, random_block_range, BlockRangeParams}; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + struct CountingIter { + data: Vec, + calls: Arc, + } + + impl CountingIter { + fn new(data: Vec, calls: Arc) -> Self { + Self { data, calls } + } + } + + struct CountingIntoIter { + inner: std::vec::IntoIter, + calls: Arc, + } + + impl Iterator for CountingIntoIter { + type Item = u64; + fn next(&mut self) -> Option { + let res = self.inner.next(); + self.calls.fetch_add(1, Ordering::SeqCst); + res + } + } + + impl IntoIterator for CountingIter { + type Item = u64; + type IntoIter = CountingIntoIter; + fn into_iter(self) -> Self::IntoIter { + CountingIntoIter { inner: self.data.into_iter(), calls: self.calls } + } + } + + #[test] + fn prune_table_with_iterator_early_exit_does_not_overconsume() { + let db = TestStageDB::default(); + let mut rng = generators::rng(); + + let blocks = random_block_range( + &mut rng, + 1..=3, + BlockRangeParams { + parent: Some(alloy_primitives::B256::ZERO), + tx_count: 2..3, + ..Default::default() + }, + ); + db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks"); + + let mut tx_senders = Vec::new(); + for block in &blocks { + tx_senders.reserve_exact(block.transaction_count()); + for transaction in &block.body().transactions { + tx_senders.push(( + tx_senders.len() as u64, + transaction.recover_signer().expect("recover signer"), + )); + } + } + let total = tx_senders.len(); + db.insert_transaction_senders(tx_senders).expect("insert transaction senders"); + + let provider = db.factory.database_provider_rw().unwrap(); + + let calls = Arc::new(AtomicUsize::new(0)); + let keys: Vec = (0..total as u64).collect(); + let counting_iter = CountingIter::new(keys, calls.clone()); + + let mut limiter = PruneLimiter::default().set_deleted_entries_limit(2); + + let (pruned, done) = provider + .tx_ref() + .prune_table_with_iterator::( + counting_iter, + &mut limiter, + |_| {}, + ) + .expect("prune"); + + assert_eq!(pruned, 2); + assert!(!done); + assert_eq!(calls.load(Ordering::SeqCst), pruned + 1); + + provider.commit().expect("commit"); + assert_eq!(db.table::().unwrap().len(), total - 2); + } + + #[test] + fn prune_table_with_iterator_consumes_to_end_reports_done() { + let db = TestStageDB::default(); + let mut rng = generators::rng(); + + let blocks = random_block_range( + &mut rng, + 1..=2, + BlockRangeParams { + parent: Some(alloy_primitives::B256::ZERO), + tx_count: 1..2, + ..Default::default() + }, + ); + db.insert_blocks(blocks.iter(), StorageKind::Database(None)).expect("insert blocks"); + + let mut tx_senders = Vec::new(); + for block in &blocks { + for transaction in &block.body().transactions { + tx_senders.push(( + tx_senders.len() as u64, + transaction.recover_signer().expect("recover signer"), + )); + } + } + let total = tx_senders.len(); + db.insert_transaction_senders(tx_senders).expect("insert transaction senders"); + + let provider = db.factory.database_provider_rw().unwrap(); + + let calls = Arc::new(AtomicUsize::new(0)); + let keys: Vec = (0..total as u64).collect(); + let counting_iter = CountingIter::new(keys, calls.clone()); + + let mut limiter = PruneLimiter::default().set_deleted_entries_limit(usize::MAX); + + let (pruned, done) = provider + .tx_ref() + .prune_table_with_iterator::( + counting_iter, + &mut limiter, + |_| {}, + ) + .expect("prune"); + + assert_eq!(pruned, total); + assert!(done); + assert_eq!(calls.load(Ordering::SeqCst), total + 1); + + provider.commit().expect("commit"); + assert_eq!(db.table::().unwrap().len(), 0); + } +}