diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 00dbb9b95c..918e3948b0 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -2,7 +2,7 @@ use crate::{ db::Transaction, metrics::HeaderMetrics, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput, }; -use futures_util::StreamExt; +use futures_util::{StreamExt, TryStreamExt}; use reth_db::{ cursor::{DbCursorRO, DbCursorRW}, database::Database, @@ -73,63 +73,66 @@ impl, input: ExecInput, ) -> Result { - let stage_progress = input.stage_progress.unwrap_or_default(); - self.update_head::(tx, stage_progress).await?; + let current_progress = input.stage_progress.unwrap_or_default(); + self.update_head::(tx, current_progress).await?; // Lookup the head and tip of the sync range - let (head, tip) = self.get_head_and_tip(tx, stage_progress).await?; + let (head, tip) = self.get_head_and_tip(tx, current_progress).await?; debug!(target: "sync::stages::headers", ?tip, head = ?head.hash(), "Commencing sync"); - let mut current_progress = stage_progress; - let mut stream = - self.downloader.stream(head.clone(), tip).chunks(self.commit_threshold as usize); - // The stage relies on the downloader to return the headers - // in descending order starting from the tip down to - // the local head (latest block in db) - while let Some(headers) = stream.next().await { - match headers.into_iter().collect::, _>>() { - Ok(res) => { - info!(target: "sync::stages::headers", len = res.len(), "Received headers"); - self.metrics.headers_counter.increment(res.len() as u64); + // The downloader returns the headers in descending order starting from the tip + // down to the local head (latest block in db) + let downloaded_headers: Result, DownloadError> = self + .downloader + .stream(head.clone(), tip) + .take(self.commit_threshold as usize) // Only stream [self.commit_threshold] headers + .try_collect() + .await; - // Perform basic response validation - self.validate_header_response(&res)?; - let write_progress = - self.write_headers::(tx, res).await?.unwrap_or_default(); - current_progress = current_progress.max(write_progress); + match downloaded_headers { + Ok(res) => { + info!(target: "sync::stages::headers", len = res.len(), "Received headers"); + self.metrics.headers_counter.increment(res.len() as u64); + + // Perform basic response validation + self.validate_header_response(&res)?; + + // Write the headers to db + self.write_headers::(tx, res).await?.unwrap_or_default(); + + if self.is_stage_done(tx, current_progress).await? { + // Update total difficulty values after we have reached fork choice + debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty"); + self.write_td::(tx, &head)?; + let stage_progress = current_progress.max( + tx.cursor::()? + .last()? + .map(|(num, _)| num) + .unwrap_or_default(), + ); + Ok(ExecOutput { stage_progress, done: true }) + } else { + Ok(ExecOutput { stage_progress: current_progress, done: false }) } - Err(e) => { - self.metrics.update_headers_error_metrics(&e); - match e { - DownloadError::Timeout => { - warn!(target: "sync::stages::headers", "No response for header request"); - return Err(StageError::Recoverable(DownloadError::Timeout.into())) - } - DownloadError::HeaderValidation { hash, error } => { - error!(target: "sync::stages::headers", ?error, ?hash, "Validation error"); - return Err(StageError::Validation { block: stage_progress, error }) - } - error => { - error!(target: "sync::stages::headers", ?error, "Unexpected error"); - return Err(StageError::Recoverable(error.into())) - } + } + Err(e) => { + self.metrics.update_headers_error_metrics(&e); + match e { + DownloadError::Timeout => { + warn!(target: "sync::stages::headers", "No response for header request"); + return Err(StageError::Recoverable(DownloadError::Timeout.into())) + } + DownloadError::HeaderValidation { hash, error } => { + error!(target: "sync::stages::headers", ?error, ?hash, "Validation error"); + return Err(StageError::Validation { block: current_progress, error }) + } + error => { + error!(target: "sync::stages::headers", ?error, "Unexpected error"); + return Err(StageError::Recoverable(error.into())) } } } } - - // Write total difficulty values after all headers have been inserted - debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty"); - self.write_td::(tx, &head)?; - - let stage_progress = current_progress.max( - tx.cursor::()? - .last()? - .map(|(num, _)| num) - .unwrap_or_default(), - ); - - Ok(ExecOutput { stage_progress, done: true }) } /// Unwind the stage. @@ -166,6 +169,19 @@ impl Ok(()) } + async fn is_stage_done( + &self, + tx: &Transaction<'_, DB>, + stage_progress: u64, + ) -> Result { + let mut header_cursor = tx.cursor::()?; + let (head_num, _) = header_cursor + .seek_exact(stage_progress)? + .ok_or(DatabaseIntegrityError::CanonicalHeader { number: stage_progress })?; + // Check if the next entry is congruent + Ok(header_cursor.next()?.map(|(next_num, _)| head_num + 1 == next_num).unwrap_or_default()) + } + /// Get the head and tip of the range we need to sync async fn get_head_and_tip( &self, @@ -207,6 +223,7 @@ impl None => self.next_fork_choice_state(&head.hash()).await.head_block_hash, _ => return Err(StageError::StageProgress(stage_progress)), }; + Ok((head, tip)) } @@ -261,7 +278,6 @@ impl cursor_header.insert(key, header)?; cursor_canonical.insert(key.number(), key.hash())?; } - Ok(latest) } @@ -388,11 +404,7 @@ mod tests { runner.consensus.update_tip(tip.hash()); let result = rx.await.unwrap(); - assert_matches!( - result, - Ok(ExecOutput { done: true, stage_progress }) - if stage_progress == tip.number - ); + assert_matches!(result, Ok(ExecOutput { done: true, stage_progress }) if stage_progress == tip.number); assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } @@ -507,7 +519,7 @@ mod tests { client: self.client.clone(), downloader: self.downloader.clone(), network_handle: self.network_handle.clone(), - commit_threshold: 100, + commit_threshold: 500, metrics: HeaderMetrics::default(), } }