diff --git a/Cargo.lock b/Cargo.lock index fa170f8721..85d6a6d2e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5317,8 +5317,13 @@ name = "reth-trie" version = "0.1.0" dependencies = [ "hex", + "proptest", "reth-primitives", "reth-rlp", + "tokio", + "tokio-stream", + "tracing", + "triehash", ] [[package]] diff --git a/crates/primitives/src/trie/branch_node.rs b/crates/primitives/src/trie/branch_node.rs index 40e29642da..e8fc6d6ada 100644 --- a/crates/primitives/src/trie/branch_node.rs +++ b/crates/primitives/src/trie/branch_node.rs @@ -53,7 +53,7 @@ impl BranchNodeCompact { } /// Returns the hash associated with the given nibble. - pub fn hash_for_nibble(&self, nibble: i8) -> H256 { + pub fn hash_for_nibble(&self, nibble: u8) -> H256 { let mask = *TrieMask::from_nibble(nibble) - 1; let index = (*self.hash_mask & mask).count_ones(); self.hashes[index as usize] diff --git a/crates/primitives/src/trie/mask.rs b/crates/primitives/src/trie/mask.rs index 79f1eaa85b..1eb3fe67ff 100644 --- a/crates/primitives/src/trie/mask.rs +++ b/crates/primitives/src/trie/mask.rs @@ -1,10 +1,14 @@ -use derive_more::{BitAnd, Deref, From}; +use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Deref, From, Not}; use reth_codecs::Compact; use serde::{Deserialize, Serialize}; /// A struct representing a mask of 16 bits, used for Ethereum trie operations. +/// +/// Masks in a trie are used to efficiently represent and manage information about the presence or +/// absence of certain elements, such as child nodes, within a trie. Masks are usually implemented +/// as bit vectors, where each bit represents the presence (1) or absence (0) of a corresponding +/// element. #[derive( - Debug, Default, Clone, Copy, @@ -17,12 +21,21 @@ use serde::{Deserialize, Serialize}; Deref, From, BitAnd, + BitAndAssign, + BitOr, + BitOrAssign, + Not, )] pub struct TrieMask(u16); impl TrieMask { + /// Creates a new `TrieMask` from the given inner value. + pub fn new(inner: u16) -> Self { + Self(inner) + } + /// Creates a new `TrieMask` from the given nibble. - pub fn from_nibble(nibble: i8) -> Self { + pub fn from_nibble(nibble: u8) -> Self { Self(1u16 << nibble) } @@ -30,6 +43,22 @@ impl TrieMask { pub fn is_subset_of(&self, other: &Self) -> bool { *self & *other == *self } + + /// Returns `true` if a given bit is set in a mask. + pub fn is_bit_set(&self, index: i32) -> bool { + self.0 & (1u16 << index) != 0 + } + + /// Returns `true` if the mask is empty. + pub fn is_empty(&self) -> bool { + self.0 == 0 + } +} + +impl std::fmt::Debug for TrieMask { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TrieMask({:016b})", self.0) + } } impl Compact for TrieMask { diff --git a/crates/primitives/src/trie/nibbles.rs b/crates/primitives/src/trie/nibbles.rs index 944bad72f5..82b634fae0 100644 --- a/crates/primitives/src/trie/nibbles.rs +++ b/crates/primitives/src/trie/nibbles.rs @@ -30,6 +30,12 @@ pub struct Nibbles { pub hex_data: Vec, } +impl From<&[u8]> for Nibbles { + fn from(slice: &[u8]) -> Self { + Nibbles::from_hex(slice.to_vec()) + } +} + impl std::fmt::Debug for Nibbles { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Nibbles").field("hex_data", &hex::encode(&self.hex_data)).finish() @@ -187,7 +193,7 @@ impl Nibbles { } /// Slice the current nibbles from the given start index to the end. - pub fn slice_at(&self, index: usize) -> Nibbles { + pub fn slice_from(&self, index: usize) -> Nibbles { self.slice(index, self.hex_data.len()) } diff --git a/crates/trie/Cargo.toml b/crates/trie/Cargo.toml index 5d09bd5374..ceefac27dd 100644 --- a/crates/trie/Cargo.toml +++ b/crates/trie/Cargo.toml @@ -14,9 +14,23 @@ Merkle trie implementation reth-primitives = { path = "../primitives" } reth-rlp = { path = "../rlp" } +# tokio +tokio = { version = "1.21.2", default-features = false, features = ["sync"] } + +# tracing +tracing = "0.1" + # misc hex = "0.4" [dev-dependencies] # reth -reth-primitives = { path = "../primitives", features = ["test-utils"] } \ No newline at end of file +reth-primitives = { path = "../primitives", features = ["test-utils", "arbitrary"] } + +# trie +triehash = "0.8" + +# misc +proptest = "1.0" +tokio = { version = "1.21.2", default-features = false, features = ["sync", "rt", "macros"] } +tokio-stream = "0.1.10" diff --git a/crates/trie/src/hash_builder/mod.rs b/crates/trie/src/hash_builder/mod.rs new file mode 100644 index 0000000000..bf44da6c4d --- /dev/null +++ b/crates/trie/src/hash_builder/mod.rs @@ -0,0 +1,570 @@ +use crate::nodes::{rlp_hash, BranchNode, ExtensionNode, LeafNode}; +use reth_primitives::{ + keccak256, + proofs::EMPTY_ROOT, + trie::{BranchNodeCompact, Nibbles, TrieMask}, + H256, +}; +use std::fmt::Debug; +use tokio::sync::mpsc; + +mod value; +use value::HashBuilderValue; + +/// A type alias for a sender of branch nodes. +pub type BranchNodeSender = mpsc::UnboundedSender<(Nibbles, BranchNodeCompact)>; + +/// A component used to construct the root hash of the trie. The primary purpose of a Hash Builder +/// is to build the Merkle proof that is essential for verifying the integrity and authenticity of +/// the trie's contents. It achieves this by constructing the root hash from the hashes of child +/// nodes according to specific rules, depending on the type of the node (branch, extension, or +/// leaf). +/// +/// Here's an overview of how the Hash Builder works for each type of node: +/// * Branch Node: The Hash Builder combines the hashes of all the child nodes of the branch node, +/// using a cryptographic hash function like SHA-256. The child nodes' hashes are concatenated +/// and hashed, and the result is considered the hash of the branch node. The process is repeated +/// recursively until the root hash is obtained. +/// * Extension Node: In the case of an extension node, the Hash Builder first encodes the node's +/// shared nibble path, followed by the hash of the next child node. It concatenates these values +/// and then computes the hash of the resulting data, which represents the hash of the extension +/// node. +/// * Leaf Node: For a leaf node, the Hash Builder first encodes the key-path and the value of the +/// leaf node. It then concatenates theĀ encoded key-path and value, and computes the hash of this +/// concatenated data, which represents the hash of the leaf node. +/// +/// The Hash Builder operates recursively, starting from the bottom of the trie and working its way +/// up, combining the hashes of child nodes and ultimately generating the root hash. The root hash +/// can then be used to verify the integrity and authenticity of the trie's data by constructing and +/// verifying Merkle proofs. +#[derive(Clone, Debug, Default)] +pub struct HashBuilder { + key: Nibbles, + stack: Vec>, + value: HashBuilderValue, + + groups: Vec, + tree_masks: Vec, + hash_masks: Vec, + + stored_in_database: bool, + + branch_node_sender: Option, +} + +impl HashBuilder { + /// Creates a new instance of the Hash Builder. + pub fn new(store_tx: Option) -> Self { + Self { branch_node_sender: store_tx, ..Default::default() } + } + + /// Set a branch node sender on the Hash Builder instance. + pub fn with_branch_node_sender(mut self, tx: BranchNodeSender) -> Self { + self.branch_node_sender = Some(tx); + self + } + + /// Print the current stack of the Hash Builder. + pub fn print_stack(&self) { + println!("============ STACK ==============="); + for item in &self.stack { + println!("{}", hex::encode(item)); + } + println!("============ END STACK ==============="); + } + + /// Adds a new leaf element & its value to the trie hash builder. + pub fn add_leaf(&mut self, key: Nibbles, value: &[u8]) { + assert!(key > self.key); + if !self.key.is_empty() { + self.update(&key); + } + self.set_key_value(key, value); + } + + /// Adds a new branch element & its hash to the trie hash builder. + pub fn add_branch(&mut self, key: Nibbles, value: H256, stored_in_database: bool) { + assert!(key > self.key || (self.key.is_empty() && key.is_empty())); + if !self.key.is_empty() { + self.update(&key); + } else if key.is_empty() { + self.stack.push(rlp_hash(value)); + } + self.set_key_value(key, value); + self.stored_in_database = stored_in_database; + } + + fn set_key_value>(&mut self, key: Nibbles, value: T) { + tracing::trace!(target: "trie::hash_builder", key = ?self.key, value = ?self.value, "old key/value"); + self.key = key; + self.value = value.into(); + tracing::trace!(target: "trie::hash_builder", key = ?self.key, value = ?self.value, "new key/value"); + } + + /// Returns the current root hash of the trie builder. + pub fn root(&mut self) -> H256 { + // Clears the internal state + if !self.key.is_empty() { + self.update(&Nibbles::default()); + self.key.clear(); + self.value = HashBuilderValue::Bytes(vec![]); + } + self.current_root() + } + + fn current_root(&self) -> H256 { + if let Some(node_ref) = self.stack.last() { + if node_ref.len() == H256::len_bytes() + 1 { + H256::from_slice(&node_ref[1..]) + } else { + keccak256(node_ref) + } + } else { + EMPTY_ROOT + } + } + + /// Given a new element, it appends it to the stack and proceeds to loop through the stack state + /// and convert the nodes it can into branch / extension nodes and hash them. This ensures + /// that the top of the stack always contains the merkle root corresponding to the trie + /// built so far. + fn update(&mut self, succeeding: &Nibbles) { + let mut build_extensions = false; + // current / self.key is always the latest added element in the trie + let mut current = self.key.clone(); + + tracing::debug!(target: "trie::hash_builder", ?current, ?succeeding, "updating merkle tree"); + + let mut i = 0; + loop { + let span = tracing::span!( + target: "trie::hash_builder", + tracing::Level::TRACE, + "loop", + i, + current = hex::encode(¤t.hex_data), + ?build_extensions + ); + let _enter = span.enter(); + + let preceding_exists = !self.groups.is_empty(); + let preceding_len: usize = self.groups.len().saturating_sub(1); + + let common_prefix_len = succeeding.common_prefix_length(¤t); + let len = std::cmp::max(preceding_len, common_prefix_len); + assert!(len < current.len()); + + tracing::trace!( + target: "trie::hash_builder", + ?len, + ?common_prefix_len, + ?preceding_len, + preceding_exists, + "prefix lengths after comparing keys" + ); + + // Adjust the state masks for branch calculation + let extra_digit = current[len]; + if self.groups.len() <= len { + let new_len = len + 1; + tracing::trace!(target: "trie::hash_builder", new_len, old_len = self.groups.len(), "scaling state masks to fit"); + self.groups.resize(new_len, TrieMask::default()); + } + self.groups[len] |= TrieMask::from_nibble(extra_digit); + tracing::trace!( + target: "trie::hash_builder", + ?extra_digit, + groups = self.groups.iter().map(|x| format!("{x:?}")).collect::>().join(","), + ); + + // Adjust the tree masks for exporting to the DB + if self.tree_masks.len() < current.len() { + self.resize_masks(current.len()); + } + + let mut len_from = len; + if !succeeding.is_empty() || preceding_exists { + len_from += 1; + } + tracing::trace!(target: "trie::hash_builder", "skipping {} nibbles", len_from); + + // The key without the common prefix + let short_node_key = current.slice_from(len_from); + tracing::trace!(target: "trie::hash_builder", ?short_node_key); + + // Concatenate the 2 nodes together + if !build_extensions { + match &self.value { + HashBuilderValue::Bytes(leaf_value) => { + let leaf_node = LeafNode::new(&short_node_key, leaf_value); + tracing::debug!(target: "trie::hash_builder", ?leaf_node, "pushing leaf node"); + tracing::trace!(target: "trie::hash_builder", rlp = hex::encode(&leaf_node.rlp()), "leaf node rlp"); + self.stack.push(leaf_node.rlp()); + } + HashBuilderValue::Hash(hash) => { + tracing::debug!(target: "trie::hash_builder", ?hash, "pushing branch node hash"); + self.stack.push(rlp_hash(*hash)); + + if self.stored_in_database { + self.tree_masks[current.len() - 1] |= + TrieMask::from_nibble(current.last().unwrap()); + } + self.hash_masks[current.len() - 1] |= + TrieMask::from_nibble(current.last().unwrap()); + + build_extensions = true; + } + } + } + + if build_extensions && !short_node_key.is_empty() { + self.update_masks(¤t, len_from); + let stack_last = + self.stack.pop().expect("there should be at least one stack item; qed"); + let extension_node = ExtensionNode::new(&short_node_key, &stack_last); + tracing::debug!(target: "trie::hash_builder", ?extension_node, "pushing extension node"); + tracing::trace!(target: "trie::hash_builder", rlp = hex::encode(&extension_node.rlp()), "extension node rlp"); + self.stack.push(extension_node.rlp()); + self.resize_masks(len_from); + } + + if preceding_len <= common_prefix_len && !succeeding.is_empty() { + tracing::trace!(target: "trie::hash_builder", "no common prefix to create branch nodes from, returning"); + return + } + + // Insert branch nodes in the stack + if !succeeding.is_empty() || preceding_exists { + // Pushes the corresponding branch node to the stack + let children = self.push_branch_node(len); + // Need to store the branch node in an efficient format + // outside of the hash builder + self.store_branch_node(¤t, len, children); + } + + self.groups.resize(len, TrieMask::default()); + self.resize_masks(len); + + if preceding_len == 0 { + tracing::trace!(target: "trie::hash_builder", "0 or 1 state masks means we have no more elements to process"); + return + } + + current.truncate(preceding_len); + tracing::trace!(target: "trie::hash_builder", ?current, "truncated nibbles to {} bytes", preceding_len); + + tracing::trace!(target: "trie::hash_builder", groups = ?self.groups, "popping empty state masks"); + while self.groups.last() == Some(&TrieMask::default()) { + self.groups.pop(); + } + + build_extensions = true; + + i += 1; + } + } + + /// Given the size of the longest common prefix, it proceeds to create a branch node + /// from the state mask and existing stack state, and store its RLP to the top of the stack, + /// after popping all the relevant elements from the stack. + fn push_branch_node(&mut self, len: usize) -> Vec { + let state_mask = self.groups[len]; + let hash_mask = self.hash_masks[len]; + let branch_node = BranchNode::new(&self.stack); + let children = branch_node.children(state_mask, hash_mask).collect(); + let rlp = branch_node.rlp(state_mask); + + // Clears the stack from the branch node elements + let first_child_idx = self.stack.len() - state_mask.count_ones() as usize; + tracing::debug!( + target: "trie::hash_builder", + new_len = first_child_idx, + old_len = self.stack.len(), + "resizing stack to prepare branch node" + ); + self.stack.resize(first_child_idx, vec![]); + + tracing::debug!(target: "trie::hash_builder", "pushing branch node with {:?} mask from stack", state_mask); + tracing::trace!(target: "trie::hash_builder", rlp = hex::encode(&rlp), "branch node rlp"); + self.stack.push(rlp); + children + } + + /// Given the current nibble prefix and the highest common prefix length, proceeds + /// to update the masks for the next level and store the branch node and the + /// masks in the database. We will use that when consuming the intermediate nodes + /// from the database to efficiently build the trie. + fn store_branch_node(&mut self, current: &Nibbles, len: usize, children: Vec) { + if len > 0 { + let parent_index = len - 1; + self.hash_masks[parent_index] |= TrieMask::from_nibble(current[parent_index]); + } + + let store_in_db_trie = !self.tree_masks[len].is_empty() || !self.hash_masks[len].is_empty(); + if store_in_db_trie { + if len > 0 { + let parent_index = len - 1; + self.tree_masks[parent_index] |= TrieMask::from_nibble(current[parent_index]); + } + + let mut n = BranchNodeCompact::new( + self.groups[len], + self.tree_masks[len], + self.hash_masks[len], + children, + None, + ); + + if len == 0 { + n.root_hash = Some(self.current_root()); + } + + // Send it over to the provided channel which will handle it on the + // other side of the HashBuilder + tracing::debug!(target: "trie::hash_builder", node = ?n, "intermediate node"); + let common_prefix = current.slice(0, len); + if let Some(tx) = &self.branch_node_sender { + let _ = tx.send((common_prefix, n)); + } + } + } + + fn update_masks(&mut self, current: &Nibbles, len_from: usize) { + if len_from > 0 { + let flag = TrieMask::from_nibble(current[len_from - 1]); + + self.hash_masks[len_from - 1] &= !flag; + + if !self.tree_masks[current.len() - 1].is_empty() { + self.tree_masks[len_from - 1] |= flag; + } + } + } + + fn resize_masks(&mut self, new_len: usize) { + tracing::trace!( + target: "trie::hash_builder", + new_len, + old_tree_mask_len = self.tree_masks.len(), + old_hash_mask_len = self.hash_masks.len(), + "resizing tree/hash masks" + ); + self.tree_masks.resize(new_len, TrieMask::default()); + self.hash_masks.resize(new_len, TrieMask::default()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + use reth_primitives::{hex_literal::hex, proofs::KeccakHasher, trie::Nibbles, H256, U256}; + use std::collections::{BTreeMap, HashMap}; + use tokio::sync::mpsc::unbounded_channel; + use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; + + fn trie_root(iter: I) -> H256 + where + I: IntoIterator, + K: AsRef<[u8]> + Ord, + V: AsRef<[u8]>, + { + // We use `trie_root` instead of `sec_trie_root` because we assume + // the incoming keys are already hashed, which makes sense given + // we're going to be using the Hashed tables & pre-hash the data + // on the way in. + triehash::trie_root::(iter) + } + + // Hashes the keys, RLP encodes the values, compares the trie builder with the upstream root. + fn assert_hashed_trie_root<'a, I, K>(iter: I) + where + I: Iterator, + K: AsRef<[u8]> + Ord, + { + let hashed = iter + .map(|(k, v)| (keccak256(k.as_ref()), reth_rlp::encode_fixed_size(v).to_vec())) + // Collect into a btree map to sort the data + .collect::>(); + + let mut hb = HashBuilder::default(); + + hashed.iter().for_each(|(key, val)| { + let nibbles = Nibbles::unpack(key); + hb.add_leaf(nibbles, &val); + }); + + assert_eq!(hb.root(), trie_root(&hashed)); + } + + // No hashing involved + fn assert_trie_root(iter: I) + where + I: Iterator, + K: AsRef<[u8]> + Ord, + V: AsRef<[u8]>, + { + let mut hb = HashBuilder::default(); + + let data = iter.collect::>(); + data.iter().for_each(|(key, val)| { + let nibbles = Nibbles::unpack(key); + hb.add_leaf(nibbles, val.as_ref()); + }); + assert_eq!(hb.root(), trie_root(data)); + } + + #[test] + fn empty() { + assert_eq!(HashBuilder::default().root(), EMPTY_ROOT); + } + + #[test] + fn arbitrary_hashed_root() { + proptest!(|(state: BTreeMap)| { + assert_hashed_trie_root(state.iter()); + }); + } + + #[test] + fn arbitrary_root() { + proptest!(|(state: BTreeMap, Vec>)| { + // filter non-nibbled keys + let state = state.into_iter().filter(|(k, _)| !k.is_empty() && k.len() % 2 == 0).collect::>(); + assert_trie_root(state.into_iter()); + }); + } + + #[tokio::test] + async fn test_generates_branch_node() { + let (sender, recv) = unbounded_channel(); + let mut hb = HashBuilder::new(Some(sender)); + + // We have 1 branch node update to be stored at 0x01, indicated by the first nibble. + // That branch root node has 2 branch node children present at 0x1 and 0x2. + // - 0x1 branch: It has the 2 empty items, at `0` and `1`. + // - 0x2 branch: It has the 2 empty items, at `0` and `2`. + // This is enough information to construct the intermediate node value: + // 1. State Mask: 0b111. The children of the branch + the branch value at `0`, `1` and `2`. + // 2. Hash Mask: 0b110. Of the above items, `1` and `2` correspond to sub-branch nodes. + // 3. Tree Mask: 0b000. + // 4. Hashes: The 2 sub-branch roots, at `1` and `2`, calculated by hashing + // the 0th and 1st element for the 0x1 branch (according to the 3rd nibble), + // and the 0th and 2nd element for the 0x2 branch (according to the 3rd nibble). + // This basically means that every BranchNodeCompact is capable of storing up to 2 levels + // deep of nodes (?). + let data = BTreeMap::from([ + ( + hex!("1000000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1100000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1110000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1200000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1220000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + // unrelated leaf + hex!("1320000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ]); + data.iter().for_each(|(key, val)| { + let nibbles = Nibbles::unpack(key); + hb.add_leaf(nibbles, val.as_ref()); + }); + let root = hb.root(); + drop(hb); + + let receiver = UnboundedReceiverStream::new(recv); + let updates = receiver.collect::>().await; + + let updates = updates.iter().cloned().collect::>(); + let update = updates.get(&Nibbles::from(hex!("01").as_slice())).unwrap(); + assert_eq!(update.state_mask, TrieMask::new(0b1111)); // 1st nibble: 0, 1, 2, 3 + assert_eq!(update.tree_mask, TrieMask::new(0)); + assert_eq!(update.hash_mask, TrieMask::new(6)); // in the 1st nibble, the ones with 1 and 2 are branches with `hashes` + assert_eq!(update.hashes.len(), 2); // calculated while the builder is running + + assert_eq!(root, trie_root(data)); + } + + #[test] + fn test_root_raw_data() { + let data = vec![ + (hex!("646f").to_vec(), hex!("76657262").to_vec()), + (hex!("676f6f64").to_vec(), hex!("7075707079").to_vec()), + (hex!("676f6b32").to_vec(), hex!("7075707079").to_vec()), + (hex!("676f6b34").to_vec(), hex!("7075707079").to_vec()), + ]; + assert_trie_root(data.into_iter()); + } + + #[test] + fn test_root_rlp_hashed_data() { + let data = HashMap::from([ + (H256::from_low_u64_le(1), U256::from(2)), + (H256::from_low_u64_be(3), U256::from(4)), + ]); + assert_hashed_trie_root(data.iter()); + } + + #[test] + fn test_root_known_hash() { + let root_hash = H256::random(); + let mut hb = HashBuilder::default(); + hb.add_branch(Nibbles::default(), root_hash, false); + assert_eq!(hb.root(), root_hash); + } + + #[test] + fn manual_branch_node_ok() { + let raw_input = vec![ + (hex!("646f").to_vec(), hex!("76657262").to_vec()), + (hex!("676f6f64").to_vec(), hex!("7075707079").to_vec()), + ]; + let input = + raw_input.iter().map(|(key, value)| (Nibbles::unpack(key), value)).collect::>(); + + // We create the hash builder and add the leaves + let mut hb = HashBuilder::default(); + for (key, val) in input.iter() { + hb.add_leaf(key.clone(), val.as_slice()); + } + + // Manually create the branch node that should be there after the first 2 leaves are added. + // Skip the 0th element given in this example they have a common prefix and will + // collapse to a Branch node. + use reth_primitives::bytes::BytesMut; + use reth_rlp::Encodable; + let leaf1 = LeafNode::new(&Nibbles::unpack(&raw_input[0].0[1..]), input[0].1); + let leaf2 = LeafNode::new(&Nibbles::unpack(&raw_input[1].0[1..]), input[1].1); + let mut branch: [&dyn Encodable; 17] = [b""; 17]; + // We set this to `4` and `7` because that mathces the 2nd element of the corresponding + // leaves. We set this to `7` because the 2nd element of Leaf 1 is `7`. + branch[4] = &leaf1; + branch[7] = &leaf2; + let mut branch_node_rlp = BytesMut::new(); + reth_rlp::encode_list::(&branch, &mut branch_node_rlp); + let branch_node_hash = keccak256(branch_node_rlp); + + let mut hb2 = HashBuilder::default(); + // Insert the branch with the `0x6` shared prefix. + hb2.add_branch(Nibbles::from_hex(vec![0x6]), branch_node_hash, false); + + let expected = trie_root(raw_input.clone()); + assert_eq!(hb.root(), expected); + assert_eq!(hb2.root(), expected); + } +} diff --git a/crates/trie/src/hash_builder/value.rs b/crates/trie/src/hash_builder/value.rs new file mode 100644 index 0000000000..71acfdf133 --- /dev/null +++ b/crates/trie/src/hash_builder/value.rs @@ -0,0 +1,40 @@ +use reth_primitives::H256; + +#[derive(Clone)] +pub(crate) enum HashBuilderValue { + Bytes(Vec), + Hash(H256), +} + +impl std::fmt::Debug for HashBuilderValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bytes(bytes) => write!(f, "Bytes({:?})", hex::encode(bytes)), + Self::Hash(hash) => write!(f, "Hash({:?})", hash), + } + } +} + +impl From> for HashBuilderValue { + fn from(value: Vec) -> Self { + Self::Bytes(value) + } +} + +impl From<&[u8]> for HashBuilderValue { + fn from(value: &[u8]) -> Self { + Self::Bytes(value.to_vec()) + } +} + +impl From for HashBuilderValue { + fn from(value: H256) -> Self { + Self::Hash(value) + } +} + +impl Default for HashBuilderValue { + fn default() -> Self { + Self::Bytes(vec![]) + } +} diff --git a/crates/trie/src/lib.rs b/crates/trie/src/lib.rs index 68aaddfaa7..f0ae16d486 100644 --- a/crates/trie/src/lib.rs +++ b/crates/trie/src/lib.rs @@ -11,3 +11,6 @@ /// Various branch nodes producde by the hash builder. pub mod nodes; + +/// The implementation of hash builder. +pub mod hash_builder; diff --git a/crates/trie/src/nodes/branch.rs b/crates/trie/src/nodes/branch.rs index c533c45924..be3f34e52c 100644 --- a/crates/trie/src/nodes/branch.rs +++ b/crates/trie/src/nodes/branch.rs @@ -1,5 +1,5 @@ -use super::{matches_mask, rlp_node}; -use reth_primitives::{bytes::BytesMut, H256}; +use super::rlp_node; +use reth_primitives::{bytes::BytesMut, trie::TrieMask, H256}; use reth_rlp::{BufMut, EMPTY_STRING_CODE}; /// A Branch node is only a pointer to the stack of nodes and is used to @@ -19,12 +19,16 @@ impl<'a> BranchNode<'a> { /// Given the hash and state mask of children present, return an iterator over the stack items /// that match the mask. - pub fn children(&self, state_mask: u16, hash_mask: u16) -> impl Iterator + '_ { + pub fn children( + &self, + state_mask: TrieMask, + hash_mask: TrieMask, + ) -> impl Iterator + '_ { let mut index = self.stack.len() - state_mask.count_ones() as usize; (0..16).filter_map(move |digit| { let mut child = None; - if matches_mask(state_mask, digit) { - if matches_mask(hash_mask, digit) { + if state_mask.is_bit_set(digit) { + if hash_mask.is_bit_set(digit) { child = Some(&self.stack[index]); } index += 1; @@ -34,7 +38,7 @@ impl<'a> BranchNode<'a> { } /// Returns the RLP encoding of the branch node given the state mask of children present. - pub fn rlp(&self, state_mask: u16) -> Vec { + pub fn rlp(&self, state_mask: TrieMask) -> Vec { let first_child_idx = self.stack.len() - state_mask.count_ones() as usize; let mut buf = BytesMut::new(); @@ -43,7 +47,7 @@ impl<'a> BranchNode<'a> { let header = (0..16).fold( reth_rlp::Header { list: true, payload_length: 1 }, |mut header, digit| { - if matches_mask(state_mask, digit) { + if state_mask.is_bit_set(digit) { header.payload_length += self.stack[i].len(); i += 1; } else { @@ -57,7 +61,7 @@ impl<'a> BranchNode<'a> { // Extend the RLP buffer with the present children let mut i = first_child_idx; (0..16).for_each(|idx| { - if matches_mask(state_mask, idx) { + if state_mask.is_bit_set(idx) { buf.extend_from_slice(&self.stack[i]); i += 1; } else { diff --git a/crates/trie/src/nodes/mod.rs b/crates/trie/src/nodes/mod.rs index b7e3efc29d..347f0f722a 100644 --- a/crates/trie/src/nodes/mod.rs +++ b/crates/trie/src/nodes/mod.rs @@ -2,13 +2,10 @@ use reth_primitives::{keccak256, H256}; use reth_rlp::EMPTY_STRING_CODE; mod branch; -pub use branch::BranchNode; - mod extension; -pub use extension::ExtensionNode; - mod leaf; -pub use leaf::LeafNode; + +pub use self::{branch::BranchNode, extension::ExtensionNode, leaf::LeafNode}; /// Given an RLP encoded node, returns either RLP(Node) or RLP(keccak(RLP(node))) fn rlp_node(rlp: &[u8]) -> Vec { @@ -23,7 +20,3 @@ fn rlp_node(rlp: &[u8]) -> Vec { pub fn rlp_hash(hash: H256) -> Vec { [[EMPTY_STRING_CODE + H256::len_bytes() as u8].as_slice(), hash.0.as_slice()].concat() } - -fn matches_mask(mask: u16, idx: i32) -> bool { - mask & (1u16 << idx) != 0 -}