From b9341a7b4734eb606cd498ef1047801383ce9346 Mon Sep 17 00:00:00 2001 From: Federico Gimenez Date: Tue, 1 Oct 2024 20:57:32 +0200 Subject: [PATCH] fix(tree): use in-memory data first to query total difficulty (#11382) --- .../src/providers/blockchain_provider.rs | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/crates/storage/provider/src/providers/blockchain_provider.rs b/crates/storage/provider/src/providers/blockchain_provider.rs index 491a66d2fd..4c03621507 100644 --- a/crates/storage/provider/src/providers/blockchain_provider.rs +++ b/crates/storage/provider/src/providers/blockchain_provider.rs @@ -377,22 +377,25 @@ impl HeaderProvider for BlockchainProvider2 { } fn header_td_by_number(&self, number: BlockNumber) -> ProviderResult> { - // If the TD is recorded on disk, we can just return that - if let Some(td) = self.database.header_td_by_number(number)? { - Ok(Some(td)) - } else if self.canonical_in_memory_state.hash_by_number(number).is_some() { - // Otherwise, if the block exists in memory, we should return a TD for it. + let number = if self.canonical_in_memory_state.hash_by_number(number).is_some() { + // If the block exists in memory, we should return a TD for it. // // The canonical in memory state should only store post-merge blocks. Post-merge blocks // have zero difficulty. This means we can use the total difficulty for the last - // persisted block number. - let last_persisted_block_number = self.database.last_block_number()?; - self.database.header_td_by_number(last_persisted_block_number) + // finalized block number if present (so that we are not affected by reorgs), if not the + // last number in the database will be used. + if let Some(last_finalized_num_hash) = + self.canonical_in_memory_state.get_finalized_num_hash() + { + last_finalized_num_hash.number + } else { + self.database.last_block_number()? + } } else { - // If the block does not exist in memory, and does not exist on-disk, we should not - // return a TD for it. - Ok(None) - } + // Otherwise, return what we have on disk for the input block + number + }; + self.database.header_td_by_number(number) } fn headers_range(&self, range: impl RangeBounds) -> ProviderResult> { @@ -2650,6 +2653,10 @@ mod tests { let database_block = database_blocks.first().unwrap().clone(); let in_memory_block = in_memory_blocks.last().unwrap().clone(); + // make sure that the finalized block is on db + let finalized_block = database_blocks.get(database_blocks.len() - 3).unwrap(); + provider.set_finalized(finalized_block.header.clone()); + let blocks = [database_blocks, in_memory_blocks].concat(); assert_eq!(provider.header(&database_block.hash())?, Some(database_block.header().clone()));