fix(trie): PST: Fix update_leaf atomicity, remove update_leaves revealed tracking, fix callback calling (#21573)

This commit is contained in:
Brian Picciano
2026-01-29 17:18:42 +01:00
committed by GitHub
parent 70bfdafd26
commit 320f2a6015
3 changed files with 120 additions and 39 deletions

View File

@@ -123,9 +123,6 @@ pub struct ParallelSparseTrie {
update_actions_buffers: Vec<Vec<SparseTrieUpdatesAction>>,
/// Thresholds controlling when parallelism is enabled for different operations.
parallelism_thresholds: ParallelismThresholds,
/// Tracks proof targets already requested via `update_leaves` to avoid duplicate callbacks
/// across retry calls. Key is (`leaf_path`, `min_depth`).
requested_proof_targets: alloy_primitives::map::HashSet<(Nibbles, u8)>,
/// Metrics for the parallel sparse trie.
#[cfg(feature = "metrics")]
metrics: crate::metrics::ParallelSparseTrieMetrics,
@@ -144,7 +141,6 @@ impl Default for ParallelSparseTrie {
branch_node_masks: BranchNodeMasksMap::default(),
update_actions_buffers: Vec::default(),
parallelism_thresholds: Default::default(),
requested_proof_targets: Default::default(),
#[cfg(feature = "metrics")]
metrics: Default::default(),
}
@@ -1182,7 +1178,7 @@ impl SparseTrieExt for ParallelSparseTrie {
fn update_leaves(
&mut self,
updates: &mut alloy_primitives::map::B256Map<reth_trie_sparse::LeafUpdate>,
mut proof_required_fn: impl FnMut(Nibbles, u8),
mut proof_required_fn: impl FnMut(B256, u8),
) -> SparseTrieResult<()> {
use reth_trie_sparse::{provider::NoRevealProvider, LeafUpdate};
@@ -1204,10 +1200,9 @@ impl SparseTrieExt for ParallelSparseTrie {
Ok(()) => {}
Err(e) => {
if let Some(path) = Self::get_retriable_path(&e) {
let target_key = Self::nibbles_to_padded_b256(&path);
let min_len = (path.len() as u8).min(64);
if self.requested_proof_targets.insert((full_path, min_len)) {
proof_required_fn(full_path, min_len);
}
proof_required_fn(target_key, min_len);
updates.insert(key, LeafUpdate::Changed(value));
} else {
return Err(e);
@@ -1219,10 +1214,9 @@ impl SparseTrieExt for ParallelSparseTrie {
if let Err(e) = self.update_leaf(full_path, value.clone(), NoRevealProvider)
{
if let Some(path) = Self::get_retriable_path(&e) {
let target_key = Self::nibbles_to_padded_b256(&path);
let min_len = (path.len() as u8).min(64);
if self.requested_proof_targets.insert((full_path, min_len)) {
proof_required_fn(full_path, min_len);
}
proof_required_fn(target_key, min_len);
updates.insert(key, LeafUpdate::Changed(value));
} else {
return Err(e);
@@ -1234,10 +1228,9 @@ impl SparseTrieExt for ParallelSparseTrie {
// Touched is read-only: check if path is accessible, request proof if blinded.
match self.find_leaf(&full_path, None) {
Err(LeafLookupError::BlindedNode { path, .. }) => {
let target_key = Self::nibbles_to_padded_b256(&path);
let min_len = (path.len() as u8).min(64);
if self.requested_proof_targets.insert((full_path, min_len)) {
proof_required_fn(full_path, min_len);
}
proof_required_fn(target_key, min_len);
updates.insert(key, LeafUpdate::Touched);
}
// Path is fully revealed (exists or proven non-existent), no action needed.
@@ -1263,14 +1256,6 @@ impl ParallelSparseTrie {
self.updates.is_some()
}
/// Clears the set of already-requested proof targets.
///
/// Call this when reusing the trie for a new payload to ensure proof callbacks
/// are emitted fresh.
pub fn clear_requested_proof_targets(&mut self) {
self.requested_proof_targets.clear();
}
/// Returns true if parallelism should be enabled for revealing the given number of nodes.
/// Will always return false in nostd builds.
const fn is_reveal_parallelism_enabled(&self, num_nodes: usize) -> bool {
@@ -1303,6 +1288,14 @@ impl ParallelSparseTrie {
}
}
/// Converts a nibbles path to a B256, right-padding with zeros to 64 nibbles.
fn nibbles_to_padded_b256(path: &Nibbles) -> B256 {
let packed = path.pack();
let mut bytes = [0u8; 32];
bytes[..packed.len()].copy_from_slice(&packed);
B256::from(bytes)
}
/// Rolls back a partial update by removing the value, removing any inserted nodes,
/// and restoring any modified original node.
/// This ensures `update_leaf` is atomic - either it succeeds completely or leaves the trie
@@ -2110,6 +2103,9 @@ impl SparseSubtrie {
///
/// If an update requires revealing a blinded node, an error is returned if the blinded
/// provider returns an error.
///
/// This method is atomic: if an error occurs during structural changes, all modifications
/// are rolled back and the trie state is unchanged.
pub fn update_leaf(
&mut self,
full_path: Nibbles,
@@ -2118,21 +2114,46 @@ impl SparseSubtrie {
retain_updates: bool,
) -> SparseTrieResult<Option<(Nibbles, BranchNodeMasks)>> {
debug_assert!(full_path.starts_with(&self.path));
let existing = self.inner.values.insert(full_path, value);
if existing.is_some() {
// trie structure unchanged, return immediately
// Check if value already exists - if so, just update it (no structural changes needed)
if let Entry::Occupied(mut e) = self.inner.values.entry(full_path) {
e.insert(value);
return Ok(None)
}
// Here we are starting at the root of the subtrie, and traversing from there.
let mut current = Some(self.path);
let mut revealed = None;
// Track inserted nodes and modified original for rollback on error
let mut inserted_nodes: Vec<Nibbles> = Vec::new();
let mut modified_original: Option<(Nibbles, SparseNode)> = None;
while let Some(current_path) = current {
match self.update_next_node(current_path, &full_path, retain_updates)? {
// Save original node for potential rollback (only if not already saved)
if modified_original.is_none() &&
let Some(node) = self.nodes.get(&current_path)
{
modified_original = Some((current_path, node.clone()));
}
let step_result = self.update_next_node(current_path, &full_path, retain_updates);
// Handle errors from update_next_node - rollback and propagate
if let Err(e) = step_result {
self.rollback_leaf_insert(&full_path, &inserted_nodes, modified_original.take());
return Err(e);
}
match step_result? {
LeafUpdateStep::Continue { next_node } => {
current = Some(next_node);
// Clear modified_original since we haven't actually modified anything yet
modified_original = None;
}
LeafUpdateStep::Complete { reveal_path, .. } => {
LeafUpdateStep::Complete { inserted_nodes: new_inserted, reveal_path } => {
inserted_nodes.extend(new_inserted);
if let Some(reveal_path) = reveal_path &&
self.nodes.get(&reveal_path).expect("node must exist").is_hash()
{
@@ -2142,10 +2163,29 @@ impl SparseSubtrie {
leaf_full_path = ?full_path,
"Extension node child not revealed in update_leaf, falling back to db",
);
if let Some(RevealedNode { node, tree_mask, hash_mask }) =
provider.trie_node(&reveal_path)?
{
let decoded = TrieNode::decode(&mut &node[..])?;
let revealed_node = match provider.trie_node(&reveal_path) {
Ok(node) => node,
Err(e) => {
self.rollback_leaf_insert(
&full_path,
&inserted_nodes,
modified_original.take(),
);
return Err(e);
}
};
if let Some(RevealedNode { node, tree_mask, hash_mask }) = revealed_node {
let decoded = match TrieNode::decode(&mut &node[..]) {
Ok(d) => d,
Err(e) => {
self.rollback_leaf_insert(
&full_path,
&inserted_nodes,
modified_original.take(),
);
return Err(e.into());
}
};
trace!(
target: "trie::parallel_sparse",
?reveal_path,
@@ -2155,7 +2195,14 @@ impl SparseSubtrie {
"Revealing child (from lower)",
);
let masks = BranchNodeMasks::from_optional(hash_mask, tree_mask);
self.reveal_node(reveal_path, &decoded, masks)?;
if let Err(e) = self.reveal_node(reveal_path, &decoded, masks) {
self.rollback_leaf_insert(
&full_path,
&inserted_nodes,
modified_original.take(),
);
return Err(e);
}
debug_assert_eq!(
revealed, None,
@@ -2163,6 +2210,11 @@ impl SparseSubtrie {
);
revealed = masks.map(|masks| (reveal_path, masks));
} else {
self.rollback_leaf_insert(
&full_path,
&inserted_nodes,
modified_original.take(),
);
return Err(SparseTrieErrorKind::NodeNotFoundInProvider {
path: reveal_path,
}
@@ -2178,9 +2230,36 @@ impl SparseSubtrie {
}
}
// Only insert the value after all structural changes succeed
self.inner.values.insert(full_path, value);
Ok(revealed)
}
/// Rollback structural changes made during a failed leaf insert.
///
/// This removes any nodes that were inserted and restores the original node
/// that was modified, ensuring atomicity of `update_leaf`.
fn rollback_leaf_insert(
&mut self,
full_path: &Nibbles,
inserted_nodes: &[Nibbles],
modified_original: Option<(Nibbles, SparseNode)>,
) {
// Remove any values that may have been inserted
self.inner.values.remove(full_path);
// Remove all inserted nodes
for node_path in inserted_nodes {
self.nodes.remove(node_path);
}
// Restore the original node that was modified
if let Some((path, original_node)) = modified_original {
self.nodes.insert(path, original_node);
}
}
/// Processes the current node, returning what to do next in the leaf update process.
///
/// This will add or update any nodes in the trie as necessary.

View File

@@ -281,15 +281,18 @@ pub trait SparseTrieExt: SparseTrie {
/// Once that proof is calculated and revealed via [`SparseTrie::reveal_nodes`], the same
/// `updates` map can be reused to retry the update.
///
/// Proof targets are deduplicated by `(full_path, min_len)` across all calls to this method.
/// The callback will only be invoked once per unique target, even across retry loops.
/// A deeper blinded node (higher `min_len`) for the same path is considered a new target.
/// The callback receives `(key, min_len)` where `key` is the full 32-byte hashed key
/// (right-padded with zeros from the blinded path) and `min_len` is the minimum depth
/// at which proof nodes should be returned.
///
/// The callback may be invoked multiple times for the same target across retry loops.
/// Callers should deduplicate if needed.
///
/// [`LeafUpdate::Touched`] behaves identically except it does not modify the leaf value.
fn update_leaves(
&mut self,
updates: &mut B256Map<LeafUpdate>,
proof_required_fn: impl FnMut(Nibbles, u8),
proof_required_fn: impl FnMut(B256, u8),
) -> SparseTrieResult<()>;
}

View File

@@ -301,14 +301,13 @@ impl<T: SparseTrieExt + Default> RevealableSparseTrie<T> {
pub fn update_leaves(
&mut self,
updates: &mut B256Map<LeafUpdate>,
mut proof_required_fn: impl FnMut(Nibbles, u8),
mut proof_required_fn: impl FnMut(B256, u8),
) -> SparseTrieResult<()> {
match self {
Self::Blind(_) => {
// Nothing is revealed - emit proof targets for all keys with min_len = 0
for key in updates.keys() {
let full_path = Nibbles::unpack(*key);
proof_required_fn(full_path, 0);
proof_required_fn(*key, 0);
}
// All updates remain in the map for retry after proofs are fetched
Ok(())