fix: add more safety checks to reveals of upper subtrie nodes (#21905)

This commit is contained in:
Arsenii Kulikov
2026-02-06 23:06:30 +04:00
committed by GitHub
parent 28f5a28a9a
commit 9aee291093

View File

@@ -2,7 +2,7 @@ use crate::LowerSparseSubtrie;
use alloc::borrow::Cow;
use alloy_primitives::{
map::{Entry, HashMap},
B256,
B256, U256,
};
use alloy_rlp::Decodable;
use alloy_trie::{BranchNodeCompact, TrieMask, EMPTY_ROOT_HASH};
@@ -214,13 +214,25 @@ impl SparseTrie for ParallelSparseTrie {
self.reveal_upper_node(node.path, &node.node, node.masks)?;
}
let reachable_subtries = self.reachable_subtries();
if !self.is_reveal_parallelism_enabled(lower_nodes.len()) {
for node in lower_nodes {
if let Some(subtrie) = self.lower_subtrie_for_path_mut(&node.path) {
subtrie.reveal_node(node.path, &node.node, node.masks)?;
} else {
panic!("upper subtrie node {node:?} found amongst lower nodes");
let idx = path_subtrie_index_unchecked(&node.path);
if !reachable_subtries.get(idx) {
trace!(
target: "trie::parallel_sparse",
reveal_path = ?node.path,
"Node's lower subtrie is not reachable, skipping",
);
continue;
}
self.lower_subtries[idx].reveal(&node.path);
self.subtrie_heat.mark_modified(idx);
self.lower_subtries[idx]
.as_revealed_mut()
.expect("just revealed")
.reveal_node(node.path, &node.node, node.masks)?;
}
return Ok(())
}
@@ -247,19 +259,29 @@ impl SparseTrie for ParallelSparseTrie {
// `zip` to be happy.
let lower_subtries: Vec<_> = node_groups
.iter()
.map(|nodes| {
.filter_map(|nodes| {
// NOTE: chunk_by won't produce empty groups
let node = &nodes[0];
let idx =
SparseSubtrieType::from_path(&node.path).lower_index().unwrap_or_else(
|| panic!("upper subtrie node {node:?} found amongst lower nodes"),
);
if !reachable_subtries.get(idx) {
trace!(
target: "trie::parallel_sparse",
nodes = ?nodes,
"Lower subtrie is not reachable, skipping reveal",
);
return None;
}
// due to the nodes being sorted secondarily on their path, and chunk_by keeping
// the first element of each group, the `path` here will necessarily be the
// shortest path being revealed for each subtrie. Therefore we can reveal the
// subtrie itself using this path and retain correct behavior.
self.lower_subtries[idx].reveal(&node.path);
(idx, self.lower_subtries[idx].take_revealed().expect("just revealed"))
Some((idx, self.lower_subtries[idx].take_revealed().expect("just revealed")))
})
.collect();
@@ -277,7 +299,7 @@ impl SparseTrie for ParallelSparseTrie {
// Reveal each node in the subtrie, returning early on any errors
let res = subtrie.reveal_node(node.path, &node.node, node.masks);
if res.is_err() {
return (subtrie_idx, subtrie, res)
return (subtrie_idx, subtrie, res.map(|_| ()))
}
}
(subtrie_idx, subtrie, Ok(()))
@@ -2067,9 +2089,16 @@ impl ParallelSparseTrie {
node: &TrieNode,
masks: Option<BranchNodeMasks>,
) -> SparseTrieResult<()> {
// If there is no subtrie for the path it means the path is UPPER_TRIE_MAX_DEPTH or less
// nibbles, and so belongs to the upper trie.
self.upper_subtrie.reveal_node(path, node, masks)?;
// Only reveal nodes that can be reached given the current state of the upper trie. If they
// can't be reached, it means that they were removed.
if !self.is_path_reachable_from_upper(&path) {
return Ok(())
}
// Exit early if the node was already revealed before.
if !self.upper_subtrie.reveal_node(path, node, masks)? {
return Ok(())
}
// The previous upper_trie.reveal_node call will not have revealed any child nodes via
// reveal_node_or_hash if the child node would be found on a lower subtrie. We handle that
@@ -2169,31 +2198,94 @@ impl ParallelSparseTrie {
size
}
/// Determines if the given path can be directly reached from the upper trie.
fn is_path_reachable_from_upper(&self, path: &Nibbles) -> bool {
let mut current = Nibbles::default();
while current.len() < path.len() {
let Some(node) = self.upper_subtrie.nodes.get(&current) else { return false };
match node {
SparseNode::Branch { state_mask, .. } => {
if !state_mask.is_bit_set(path.get_unchecked(current.len())) {
return false
}
current.push_unchecked(path.get_unchecked(current.len()));
}
SparseNode::Extension { key, .. } => {
if *key != path.slice(current.len()..current.len() + key.len()) {
return false
}
current.extend(key);
}
SparseNode::Hash(_) | SparseNode::Empty | SparseNode::Leaf { .. } => return false,
}
}
true
}
/// Returns a bitset of all subtries that are reachable from the upper trie. If subtrie is not
/// reachable it means that it does not exist.
fn reachable_subtries(&self) -> SubtriesBitmap {
let mut reachable = SubtriesBitmap::default();
let mut stack = Vec::new();
stack.push(Nibbles::default());
while let Some(current) = stack.pop() {
let Some(node) = self.upper_subtrie.nodes.get(&current) else { continue };
match node {
SparseNode::Branch { state_mask, .. } => {
for idx in state_mask.iter() {
let mut next = current;
next.push_unchecked(idx);
if next.len() >= UPPER_TRIE_MAX_DEPTH {
reachable.set(path_subtrie_index_unchecked(&next));
} else {
stack.push(next);
}
}
}
SparseNode::Extension { key, .. } => {
let mut next = current;
next.extend(key);
if next.len() >= UPPER_TRIE_MAX_DEPTH {
reachable.set(path_subtrie_index_unchecked(&next));
} else {
stack.push(next);
}
}
SparseNode::Hash(_) | SparseNode::Empty | SparseNode::Leaf { .. } => {}
};
}
reachable
}
}
/// Bitset tracking which of the 256 lower subtries were modified in the current cycle.
#[derive(Clone, Default, PartialEq, Eq, Debug)]
struct ModifiedSubtries([u64; 4]);
struct SubtriesBitmap(U256);
impl ModifiedSubtries {
impl SubtriesBitmap {
/// Marks a subtrie index as modified.
#[inline]
fn set(&mut self, idx: usize) {
debug_assert!(idx < NUM_LOWER_SUBTRIES);
self.0[idx >> 6] |= 1 << (idx & 63);
self.0.set_bit(idx, true);
}
/// Returns whether a subtrie index is marked as modified.
#[inline]
fn get(&self, idx: usize) -> bool {
debug_assert!(idx < NUM_LOWER_SUBTRIES);
(self.0[idx >> 6] & (1 << (idx & 63))) != 0
self.0.bit(idx)
}
/// Clears all modification flags.
#[inline]
const fn clear(&mut self) {
self.0 = [0; 4];
self.0 = U256::ZERO;
}
}
@@ -2210,12 +2302,12 @@ struct SubtrieModifications {
/// Heat level (0-255) for each of the 256 lower subtries.
heat: [u8; NUM_LOWER_SUBTRIES],
/// Tracks which subtries were modified in the current cycle.
modified: ModifiedSubtries,
modified: SubtriesBitmap,
}
impl Default for SubtrieModifications {
fn default() -> Self {
Self { heat: [0; NUM_LOWER_SUBTRIES], modified: ModifiedSubtries::default() }
Self { heat: [0; NUM_LOWER_SUBTRIES], modified: SubtriesBitmap::default() }
}
}
@@ -2606,12 +2698,12 @@ impl SparseSubtrie {
path: Nibbles,
node: &TrieNode,
masks: Option<BranchNodeMasks>,
) -> SparseTrieResult<()> {
) -> SparseTrieResult<bool> {
debug_assert!(path.starts_with(&self.path));
// If the node is already revealed and it's not a hash node, do nothing.
if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) {
return Ok(())
return Ok(false)
}
match node {
@@ -2650,17 +2742,7 @@ impl SparseSubtrie {
})),
});
}
// Branch node already exists, or an extension node was placed where a
// branch node was before.
SparseNode::Branch { .. } | SparseNode::Extension { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: *entry.key(),
node: Box::new(node.clone()),
}
.into())
}
_ => unreachable!("checked that node is either a hash or non-existent"),
},
Entry::Vacant(entry) => {
entry.insert(SparseNode::new_branch(branch.state_mask));
@@ -2684,17 +2766,7 @@ impl SparseSubtrie {
self.reveal_node_or_hash(child_path, &ext.child)?;
}
}
// Extension node already exists, or an extension node was placed where a branch
// node was before.
SparseNode::Extension { .. } | SparseNode::Branch { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty | SparseNode::Leaf { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: *entry.key(),
node: Box::new(node.clone()),
}
.into())
}
_ => unreachable!("checked that node is either a hash or non-existent"),
},
Entry::Vacant(entry) => {
let mut child_path = *entry.key();
@@ -2719,18 +2791,7 @@ impl SparseSubtrie {
hash: Some(*hash),
});
}
// Leaf node already exists.
SparseNode::Leaf { .. } => {}
// All other node types can't be handled.
node @ (SparseNode::Empty |
SparseNode::Extension { .. } |
SparseNode::Branch { .. }) => {
return Err(SparseTrieErrorKind::Reveal {
path: *entry.key(),
node: Box::new(node.clone()),
}
.into())
}
_ => unreachable!("checked that node is either a hash or non-existent"),
},
Entry::Vacant(entry) => {
let mut full = *entry.key();
@@ -2741,7 +2802,7 @@ impl SparseSubtrie {
},
}
Ok(())
Ok(true)
}
/// Reveals either a node or its hash placeholder based on the provided child data.
@@ -2784,7 +2845,9 @@ impl SparseSubtrie {
return Ok(())
}
self.reveal_node(path, &TrieNode::decode(&mut &child[..])?, None)
self.reveal_node(path, &TrieNode::decode(&mut &child[..])?, None)?;
Ok(())
}
/// Recalculates and updates the RLP hashes for the changed nodes in this subtrie.
@@ -4082,9 +4145,12 @@ mod tests {
#[test]
fn test_reveal_node_leaves() {
let mut trie = ParallelSparseTrie::default();
// Reveal leaf in the upper trie. A root branch with child 0x1 makes path [0x1]
// reachable for the subsequent reveal_nodes call.
let root_branch =
create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
// Reveal leaf in the upper trie
{
let path = Nibbles::from_nibbles([0x1]);
let node = create_leaf_node([0x2, 0x3], 42);
@@ -4094,7 +4160,7 @@ mod tests {
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Leaf { key, hash: None })
Some(SparseNode::Leaf { key, hash: Some(_) })
if key == &Nibbles::from_nibbles([0x2, 0x3])
);
@@ -4105,7 +4171,22 @@ mod tests {
);
}
// Reveal leaf in a lower trie
// Reveal leaf in a lower trie. A separate trie is needed because the structure at
// [0x1] conflicts: the upper trie test placed a leaf there, but reaching [0x1, 0x2]
// requires a branch at [0x1]. A root branch → branch at [0x1] with child 0x2
// makes path [0x1, 0x2] reachable.
let root_branch =
create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]);
let branch_at_1 =
create_branch_node_with_children(&[0x2], [RlpNode::word_rlp(&B256::repeat_byte(0xBB))]);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
trie.reveal_nodes(&mut [ProofTrieNode {
path: Nibbles::from_nibbles([0x1]),
node: branch_at_1,
masks: None,
}])
.unwrap();
{
let path = Nibbles::from_nibbles([0x1, 0x2]);
let node = create_leaf_node([0x3, 0x4], 42);
@@ -4123,7 +4204,7 @@ mod tests {
assert_matches!(
lower_subtrie.nodes.get(&path),
Some(SparseNode::Leaf { key, hash: None })
Some(SparseNode::Leaf { key, hash: Some(_) })
if key == &Nibbles::from_nibbles([0x3, 0x4])
);
}
@@ -4190,7 +4271,11 @@ mod tests {
#[test]
fn test_reveal_node_extension_cross_level_boundary() {
let mut trie = ParallelSparseTrie::default();
// Set up root branch with nibble 0x1 so path [0x1] is reachable.
let root_branch =
create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
let path = Nibbles::from_nibbles([0x1]);
let child_hash = B256::repeat_byte(0xcd);
let node = create_extension_node([0x2], child_hash);
@@ -4198,10 +4283,10 @@ mod tests {
trie.reveal_nodes(&mut [ProofTrieNode { path, node, masks }]).unwrap();
// Extension node should be in upper trie
// Extension node should be in upper trie, hash is memoized from the previous Hash node
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Extension { key, hash: None, .. })
Some(SparseNode::Extension { key, hash: Some(_), .. })
if key == &Nibbles::from_nibbles([0x2])
);
@@ -4248,7 +4333,11 @@ mod tests {
#[test]
fn test_reveal_node_branch_cross_level() {
let mut trie = ParallelSparseTrie::default();
// Set up root branch with nibble 0x1 so path [0x1] is reachable.
let root_branch =
create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
let path = Nibbles::from_nibbles([0x1]); // Exactly 1 nibbles - boundary case
let child_hashes = [
RlpNode::word_rlp(&B256::repeat_byte(0x33)),
@@ -4260,10 +4349,10 @@ mod tests {
trie.reveal_nodes(&mut [ProofTrieNode { path, node, masks }]).unwrap();
// Branch node should be in upper trie
// Branch node should be in upper trie, hash is memoized from the previous Hash node
assert_matches!(
trie.upper_subtrie.nodes.get(&path),
Some(SparseNode::Branch { state_mask, hash: None, .. })
Some(SparseNode::Branch { state_mask, hash: Some(_), .. })
if *state_mask == 0b1000000010000001.into()
);
@@ -4287,10 +4376,18 @@ mod tests {
#[test]
fn test_update_subtrie_hashes_prefix_set_matching() {
// Create a trie and reveal leaf nodes using reveal_nodes
let mut trie = ParallelSparseTrie::default();
// Create a trie with a root branch that makes paths [0x0, ...] and [0x3, ...]
// reachable from the upper trie.
let root_branch = create_branch_node_with_children(
&[0x0, 0x3],
[
RlpNode::word_rlp(&B256::repeat_byte(0xAA)),
RlpNode::word_rlp(&B256::repeat_byte(0xBB)),
],
);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
// Create dummy leaf nodes.
// Create leaf paths.
let leaf_1_full_path = Nibbles::from_nibbles([0; 64]);
let leaf_1_path = leaf_1_full_path.slice(..2);
let leaf_1_key = leaf_1_full_path.slice(2..);
@@ -4298,33 +4395,36 @@ mod tests {
let leaf_2_path = leaf_2_full_path.slice(..2);
let leaf_2_key = leaf_2_full_path.slice(2..);
let leaf_3_full_path = Nibbles::from_nibbles([vec![0, 2], vec![0; 62]].concat());
let leaf_3_path = leaf_3_full_path.slice(..2);
let leaf_3_key = leaf_3_full_path.slice(2..);
let leaf_1 = create_leaf_node(leaf_1_key.to_vec(), 1);
let leaf_2 = create_leaf_node(leaf_2_key.to_vec(), 2);
let leaf_3 = create_leaf_node(leaf_3_key.to_vec(), 3);
// Create branch node with hashes for each leaf.
// Create branch node at [0x0] with only children 0x0 and 0x1.
// Child 0x2 (leaf_3) will be inserted via update_leaf to create a fresh node
// with hash: None.
let child_hashes = [
RlpNode::word_rlp(&B256::repeat_byte(0x00)),
RlpNode::word_rlp(&B256::repeat_byte(0x11)),
// deliberately omit hash for leaf_3
];
let branch_path = Nibbles::from_nibbles([0x0]);
let branch_node = create_branch_node_with_children(&[0x0, 0x1, 0x2], child_hashes);
let branch_node = create_branch_node_with_children(&[0x0, 0x1], child_hashes);
// Reveal nodes using reveal_nodes
// Reveal the existing nodes
trie.reveal_nodes(&mut [
ProofTrieNode { path: branch_path, node: branch_node, masks: None },
ProofTrieNode { path: leaf_1_path, node: leaf_1, masks: None },
ProofTrieNode { path: leaf_2_path, node: leaf_2, masks: None },
ProofTrieNode { path: leaf_3_path, node: leaf_3, masks: None },
])
.unwrap();
// Insert leaf_3 via update_leaf. This modifies the branch at [0x0] to add child
// 0x2 and creates a fresh leaf node with hash: None in the lower subtrie.
let provider = MockTrieNodeProvider::new();
trie.update_leaf(leaf_3_full_path, encode_account_value(3), provider).unwrap();
// Calculate subtrie indexes
let subtrie_1_index = SparseSubtrieType::from_path(&leaf_1_path).lower_index().unwrap();
let subtrie_2_index = SparseSubtrieType::from_path(&leaf_2_path).lower_index().unwrap();
let leaf_3_path = leaf_3_full_path.slice(..2);
let subtrie_3_index = SparseSubtrieType::from_path(&leaf_3_path).lower_index().unwrap();
let mut unchanged_prefix_set = PrefixSetMut::from([
@@ -7798,7 +7898,19 @@ mod tests {
// This test demonstrates that get_leaf_value must look in the correct subtrie,
// not always in upper_subtrie.
let mut trie = ParallelSparseTrie::default();
// Set up a root branch pointing to nibble 0x1, and a branch at [0x1] pointing to
// nibble 0x2, so that the lower subtrie at [0x1, 0x2] is reachable.
let root_branch =
create_branch_node_with_children(&[0x1], [RlpNode::word_rlp(&B256::repeat_byte(0xAA))]);
let branch_at_1 =
create_branch_node_with_children(&[0x2], [RlpNode::word_rlp(&B256::repeat_byte(0xBB))]);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
trie.reveal_nodes(&mut [ProofTrieNode {
path: Nibbles::from_nibbles([0x1]),
node: branch_at_1,
masks: None,
}])
.unwrap();
// Create a leaf node with path >= 2 nibbles (will go to lower subtrie)
let leaf_path = Nibbles::from_nibbles([0x1, 0x2]);
@@ -8992,8 +9104,36 @@ mod tests {
// Should at least be the size of the struct itself
assert!(empty_size >= core::mem::size_of::<ParallelSparseTrie>());
// Create a trie with some data
let mut trie = ParallelSparseTrie::default();
// Create a trie with some data. Set up a root branch with children at 0x1 and
// 0x5, and branches at [0x1] and [0x5] pointing to 0x2 and 0x6 respectively,
// so the lower subtries at [0x1, 0x2] and [0x5, 0x6] are reachable.
let root_branch = create_branch_node_with_children(
&[0x1, 0x5],
[
RlpNode::word_rlp(&B256::repeat_byte(0xAA)),
RlpNode::word_rlp(&B256::repeat_byte(0xBB)),
],
);
let mut trie = ParallelSparseTrie::from_root(root_branch, None, false).unwrap();
let branch_at_1 =
create_branch_node_with_children(&[0x2], [RlpNode::word_rlp(&B256::repeat_byte(0xCC))]);
let branch_at_5 =
create_branch_node_with_children(&[0x6], [RlpNode::word_rlp(&B256::repeat_byte(0xDD))]);
trie.reveal_nodes(&mut [
ProofTrieNode {
path: Nibbles::from_nibbles_unchecked([0x1]),
node: branch_at_1,
masks: None,
},
ProofTrieNode {
path: Nibbles::from_nibbles_unchecked([0x5]),
node: branch_at_5,
masks: None,
},
])
.unwrap();
let mut nodes = vec![
ProofTrieNode {
path: Nibbles::from_nibbles_unchecked([0x1, 0x2]),