diff --git a/crates/primitives/src/checkpoints.rs b/crates/primitives/src/checkpoints.rs index ee9baad7c0..ba460e493a 100644 --- a/crates/primitives/src/checkpoints.rs +++ b/crates/primitives/src/checkpoints.rs @@ -1,6 +1,6 @@ use crate::{ trie::{hash_builder::HashBuilderState, StoredSubNode}, - Address, H256, + Address, BlockNumber, H256, }; use bytes::Buf; use reth_codecs::{main_codec, Compact}; @@ -8,7 +8,8 @@ use reth_codecs::{main_codec, Compact}; /// Saves the progress of Merkle stage. #[derive(Default, Debug, Clone, PartialEq)] pub struct MerkleCheckpoint { - // TODO: target block? + /// The target block number. + pub target_block: BlockNumber, /// The last hashed account key processed. pub last_account_key: H256, /// The last walker key processed. @@ -19,6 +20,19 @@ pub struct MerkleCheckpoint { pub state: HashBuilderState, } +impl MerkleCheckpoint { + /// Creates a new Merkle checkpoint. + pub fn new( + target_block: BlockNumber, + last_account_key: H256, + last_walker_key: Vec, + walker_stack: Vec, + state: HashBuilderState, + ) -> Self { + Self { target_block, last_account_key, last_walker_key, walker_stack, state } + } +} + impl Compact for MerkleCheckpoint { fn to_compact(self, buf: &mut B) -> usize where @@ -26,6 +40,9 @@ impl Compact for MerkleCheckpoint { { let mut len = 0; + buf.put_u64(self.target_block); + len += 8; + buf.put_slice(self.last_account_key.as_slice()); len += self.last_account_key.len(); @@ -47,6 +64,8 @@ impl Compact for MerkleCheckpoint { where Self: Sized, { + let target_block = buf.get_u64(); + let last_account_key = H256::from_slice(&buf[..32]); buf.advance(32); @@ -63,7 +82,16 @@ impl Compact for MerkleCheckpoint { } let (state, buf) = HashBuilderState::from_compact(buf, 0); - (MerkleCheckpoint { last_account_key, last_walker_key, walker_stack, state }, buf) + ( + MerkleCheckpoint { + target_block, + last_account_key, + last_walker_key, + walker_stack, + state, + }, + buf, + ) } } @@ -92,3 +120,30 @@ pub struct StorageHashingCheckpoint { /// Last transition id pub to: u64, } + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + + #[test] + fn merkle_checkpoint_roundtrip() { + let mut rng = rand::thread_rng(); + let checkpoint = MerkleCheckpoint { + target_block: rng.gen(), + last_account_key: H256::from_low_u64_be(rng.gen()), + last_walker_key: H256::from_low_u64_be(rng.gen()).to_vec(), + walker_stack: Vec::from([StoredSubNode { + key: H256::from_low_u64_be(rng.gen()).to_vec(), + nibble: Some(rng.gen()), + node: None, + }]), + state: HashBuilderState::default(), + }; + + let mut buf = Vec::new(); + let encoded = checkpoint.clone().to_compact(&mut buf); + let (decoded, _) = MerkleCheckpoint::from_compact(&buf, encoded); + assert_eq!(decoded, checkpoint); + } +} diff --git a/crates/stages/src/stages/merkle.rs b/crates/stages/src/stages/merkle.rs index ece8da5e9b..d050c48e1a 100644 --- a/crates/stages/src/stages/merkle.rs +++ b/crates/stages/src/stages/merkle.rs @@ -6,7 +6,7 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, }; use reth_interfaces::consensus; -use reth_primitives::{hex, BlockNumber, MerkleCheckpoint, H256}; +use reth_primitives::{hex, trie::StoredSubNode, BlockNumber, MerkleCheckpoint, H256}; use reth_provider::Transaction; use reth_trie::{IntermediateStateRootState, StateRoot, StateRootProgress}; use std::{fmt::Debug, ops::DerefMut}; @@ -168,7 +168,7 @@ impl Stage for MerkleStage { block_root } else if to_block - from_block > threshold || from_block == 1 { // if there are more blocks than threshold it is faster to rebuild the trie - if let Some(checkpoint) = &checkpoint { + if let Some(checkpoint) = checkpoint.as_ref().filter(|c| c.target_block == to_block) { debug!( target: "sync::stages::merkle::exec", current = ?current_block, @@ -182,8 +182,11 @@ impl Stage for MerkleStage { target: "sync::stages::merkle::exec", current = ?current_block, target = ?to_block, + previous_checkpoint = ?checkpoint, "Rebuilding trie" ); + // Reset the checkpoint and clear trie tables + self.save_execution_checkpoint(tx, None)?; tx.clear::()?; tx.clear::()?; } @@ -195,7 +198,14 @@ impl Stage for MerkleStage { match progress { StateRootProgress::Progress(state, updates) => { updates.flush(tx.deref_mut())?; - self.save_execution_checkpoint(tx, Some((*state).into()))?; + let checkpoint = MerkleCheckpoint::new( + to_block, + state.last_account_key, + state.last_walker_key.hex_data, + state.walker_stack.into_iter().map(StoredSubNode::from).collect(), + state.hash_builder.into(), + ); + self.save_execution_checkpoint(tx, Some(checkpoint))?; return Ok(ExecOutput { stage_progress: input.stage_progress(), done: false }) } StateRootProgress::Complete(root, updates) => { diff --git a/crates/trie/src/progress.rs b/crates/trie/src/progress.rs index 0164d60178..3ef851e546 100644 --- a/crates/trie/src/progress.rs +++ b/crates/trie/src/progress.rs @@ -1,6 +1,6 @@ use crate::{trie_cursor::CursorSubNode, updates::TrieUpdates}; use reth_primitives::{ - trie::{hash_builder::HashBuilder, Nibbles, StoredSubNode}, + trie::{hash_builder::HashBuilder, Nibbles}, MerkleCheckpoint, H256, }; @@ -27,17 +27,6 @@ pub struct IntermediateStateRootState { pub last_walker_key: Nibbles, } -impl From for MerkleCheckpoint { - fn from(value: IntermediateStateRootState) -> Self { - Self { - last_account_key: value.last_account_key, - last_walker_key: value.last_walker_key.hex_data, - walker_stack: value.walker_stack.into_iter().map(StoredSubNode::from).collect(), - state: value.hash_builder.into(), - } - } -} - impl From for IntermediateStateRootState { fn from(value: MerkleCheckpoint) -> Self { Self {