diff --git a/crates/db/src/kv/mod.rs b/crates/db/src/kv/mod.rs index 87bd92c7f0..b4f3910768 100644 --- a/crates/db/src/kv/mod.rs +++ b/crates/db/src/kv/mod.rs @@ -141,7 +141,9 @@ mod tests { use reth_interfaces::{ db::{ models::ShardedKey, - tables::{AccountHistory, Headers, PlainAccountState, PlainStorageState}, + tables::{ + AccountHistory, CanonicalHeaders, Headers, PlainAccountState, PlainStorageState, + }, Database, DbCursorRO, DbDupCursorRO, DbTx, DbTxMut, }, provider::{ProviderImpl, StateProviderFactory}, @@ -208,6 +210,32 @@ mod tests { assert_eq!(first.1, value, "First next should be put value"); } + #[test] + fn db_cursor_seek_exact_or_previous_key() { + let db: Arc> = test_utils::create_test_db(EnvKind::RW); + + // PUT + let tx = db.tx_mut().expect(ERROR_INIT_TX); + vec![0, 1, 3] + .into_iter() + .try_for_each(|key| tx.put::(key, H256::zero())) + .expect(ERROR_PUT); + tx.commit().expect(ERROR_COMMIT); + + // Cursor + let missing_key = 2; + let tx = db.tx().expect(ERROR_INIT_TX); + let mut cursor = tx.cursor::().unwrap(); + assert_eq!(cursor.current(), Ok(None)); + + // Seek exact + let exact = cursor.seek_exact(missing_key).unwrap(); + assert_eq!(exact, None); + assert_eq!(cursor.current(), Ok(Some((missing_key + 1, H256::zero())))); + assert_eq!(cursor.prev(), Ok(Some((missing_key - 1, H256::zero())))); + assert_eq!(cursor.prev(), Ok(Some((missing_key - 2, H256::zero())))); + } + #[test] fn db_closure_put_get() { let path = TempDir::new().expect(test_utils::ERROR_TEMPDIR).into_path(); diff --git a/crates/stages/src/db.rs b/crates/stages/src/db.rs index 5ded66f334..e348ab731f 100644 --- a/crates/stages/src/db.rs +++ b/crates/stages/src/db.rs @@ -3,6 +3,7 @@ use std::{ ops::{Deref, DerefMut}, }; +use reth_db::kv::cursor::PairResult; use reth_interfaces::db::{ models::{BlockNumHash, NumTransactions}, tables, Database, DatabaseGAT, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table, @@ -89,6 +90,12 @@ where self.tx.take(); } + /// Get exact or previous value from the database + pub(crate) fn get_exact_or_prev(&self, key: T::Key) -> PairResult { + let mut cursor = self.cursor::()?; + Ok(cursor.seek_exact(key)?.or(cursor.prev()?)) + } + /// Query [tables::CanonicalHeaders] table for block hash by block number pub(crate) fn get_block_hash(&self, number: BlockNumber) -> Result { let hash = self diff --git a/crates/stages/src/stages/senders.rs b/crates/stages/src/stages/senders.rs index 42fc2737bc..fe45063ab6 100644 --- a/crates/stages/src/stages/senders.rs +++ b/crates/stages/src/stages/senders.rs @@ -94,10 +94,17 @@ impl Stage for SendersStage { input: UnwindInput, ) -> Result> { // Look up the hash of the unwind block - if let Some(unwind_hash) = db.get::(input.unwind_to)? { + if let Some((_, unwind_hash)) = + db.get_exact_or_prev::(input.unwind_to)? + { // Look up the cumulative tx count at unwind block - let latest_tx = db.get_tx_count((input.unwind_to, unwind_hash).into())?; - db.unwind_table_by_num::(latest_tx - 1)?; + let key = (input.unwind_to, unwind_hash).into(); + if let Some((_, unwind_tx_count)) = + db.get_exact_or_prev::(key)? + { + // The last remaining tx_id should be at `cum_tx_count - 1` + db.unwind_table_by_num::(unwind_tx_count - 1)?; + } } Ok(UnwindOutput { stage_progress: input.unwind_to })