mirror of
https://github.com/paradigmxyz/reth.git
synced 2026-01-29 17:18:08 -05:00
bug(cursor_walk): handle empty range (#2057)
This commit is contained in:
@@ -239,8 +239,6 @@ pub struct RangeWalker<'cursor, 'tx, T: Table, CURSOR: DbCursorRO<'tx, T>> {
|
||||
cursor: &'cursor mut CURSOR,
|
||||
/// `(key, value)` where to start the walk.
|
||||
start: IterPairResult<T>,
|
||||
/// `key` where to start the walk.
|
||||
start_key: Bound<T::Key>,
|
||||
/// `key` where to stop the walk.
|
||||
end_key: Bound<T::Key>,
|
||||
/// flag whether is ended
|
||||
@@ -258,12 +256,9 @@ impl<'cursor, 'tx, T: Table, CURSOR: DbCursorRO<'tx, T>> std::iter::Iterator
|
||||
return None
|
||||
}
|
||||
|
||||
let start = self.start.take();
|
||||
if start.is_some() && matches!(self.start_key, Bound::Included(_) | Bound::Unbounded) {
|
||||
return start
|
||||
}
|
||||
let next_item = self.start.take().or_else(|| self.cursor.next().transpose());
|
||||
|
||||
match self.cursor.next().transpose() {
|
||||
match next_item {
|
||||
Some(Ok((key, value))) => match &self.end_key {
|
||||
Bound::Included(end_key) if &key <= end_key => Some(Ok((key, value))),
|
||||
Bound::Excluded(end_key) if &key < end_key => Some(Ok((key, value))),
|
||||
@@ -288,17 +283,19 @@ impl<'cursor, 'tx, T: Table, CURSOR: DbCursorRO<'tx, T>> RangeWalker<'cursor, 't
|
||||
pub fn new(
|
||||
cursor: &'cursor mut CURSOR,
|
||||
start: IterPairResult<T>,
|
||||
start_key: Bound<T::Key>,
|
||||
end_key: Bound<T::Key>,
|
||||
) -> Self {
|
||||
Self {
|
||||
cursor,
|
||||
start,
|
||||
start_key,
|
||||
end_key,
|
||||
is_done: false,
|
||||
_tx_phantom: std::marker::PhantomData,
|
||||
}
|
||||
// mark done if range is empty.
|
||||
let is_done = match start {
|
||||
Some(Ok((ref start_key, _))) => match &end_key {
|
||||
Bound::Included(end_key) if start_key > end_key => true,
|
||||
Bound::Excluded(end_key) if start_key >= end_key => true,
|
||||
_ => false,
|
||||
},
|
||||
None => true,
|
||||
_ => false,
|
||||
};
|
||||
Self { cursor, start, end_key, is_done, _tx_phantom: std::marker::PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -94,12 +94,7 @@ impl<'tx, K: TransactionKind, T: Table> DbCursorRO<'tx, T> for Cursor<'tx, K, T>
|
||||
Self: Sized,
|
||||
{
|
||||
let start = match range.start_bound().cloned() {
|
||||
Bound::Included(key) => {
|
||||
if matches!(range.end_bound().cloned(), Bound::Included(end_key) | Bound::Excluded(end_key) if end_key < key) {
|
||||
return Err(Error::Read(2))
|
||||
}
|
||||
self.inner.set_range(key.encode().as_ref())
|
||||
}
|
||||
Bound::Included(key) => self.inner.set_range(key.encode().as_ref()),
|
||||
Bound::Excluded(_key) => {
|
||||
unreachable!("Rust doesn't allow for Bound::Excluded in starting bounds");
|
||||
}
|
||||
@@ -108,7 +103,7 @@ impl<'tx, K: TransactionKind, T: Table> DbCursorRO<'tx, T> for Cursor<'tx, K, T>
|
||||
.map_err(|e| Error::Read(e.into()))?
|
||||
.map(decoder::<T>);
|
||||
|
||||
Ok(RangeWalker::new(self, start, range.start_bound().cloned(), range.end_bound().cloned()))
|
||||
Ok(RangeWalker::new(self, start, range.end_bound().cloned()))
|
||||
}
|
||||
|
||||
fn walk_back<'cursor>(
|
||||
|
||||
@@ -296,12 +296,16 @@ mod tests {
|
||||
let mut cursor = tx.cursor_read::<CanonicalHeaders>().unwrap();
|
||||
|
||||
// start bound greater than end bound
|
||||
let res = cursor.walk_range(3..1);
|
||||
assert!(matches!(res, Err(Error::Read(2))));
|
||||
let mut res = cursor.walk_range(3..1).unwrap();
|
||||
assert_eq!(res.next(), None);
|
||||
|
||||
// start bound greater than end bound
|
||||
let res = cursor.walk_range(15..=2);
|
||||
assert!(matches!(res, Err(Error::Read(2))));
|
||||
let mut res = cursor.walk_range(15..=2).unwrap();
|
||||
assert_eq!(res.next(), None);
|
||||
|
||||
// returning nothing
|
||||
let mut walker = cursor.walk_range(1..1).unwrap();
|
||||
assert_eq!(walker.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1621,7 +1621,7 @@ mod test {
|
||||
tx.insert_hashes(
|
||||
block1.number,
|
||||
exec_res1.transitions_count() as TransitionId,
|
||||
exec_res2.transitions_count() as TransitionId,
|
||||
(exec_res1.transitions_count() + exec_res2.transitions_count()) as TransitionId,
|
||||
2,
|
||||
block2.hash,
|
||||
block2.state_root,
|
||||
|
||||
Reference in New Issue
Block a user