From 396c2ec342fbd0a776e494ed46de1e4f27d8f93e Mon Sep 17 00:00:00 2001 From: rymnc <43716372+rymnc@users.noreply.github.com> Date: Tue, 1 Aug 2023 00:23:28 +0530 Subject: [PATCH] test: new batching mechanism --- rln/src/pm_tree_adapter.rs | 87 ++++++++----------- rln/src/public.rs | 6 ++ utils/benches/merkle_tree_benchmark.rs | 22 ++++- utils/src/merkle_tree/full_merkle_tree.rs | 89 +++++++++----------- utils/src/merkle_tree/merkle_tree.rs | 19 +++-- utils/src/merkle_tree/optimal_merkle_tree.rs | 41 ++------- utils/tests/merkle_tree.rs | 36 ++++---- 7 files changed, 135 insertions(+), 165 deletions(-) diff --git a/rln/src/pm_tree_adapter.rs b/rln/src/pm_tree_adapter.rs index 9492270..3d3e407 100644 --- a/rln/src/pm_tree_adapter.rs +++ b/rln/src/pm_tree_adapter.rs @@ -3,11 +3,11 @@ use crate::hashers::{poseidon_hash, PoseidonHash}; use crate::utils::{bytes_le_to_fr, fr_to_bytes_le}; use color_eyre::{Report, Result}; use serde_json::Value; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::path::PathBuf; use std::str::FromStr; -use utils::pmtree::{Database, Hasher}; +use utils::pmtree::{DBKey, Database, Hasher}; use utils::*; const METADATA_KEY: [u8; 8] = *b"metadata"; @@ -192,57 +192,42 @@ impl ZerokitMerkleTree for PmTree { indices: J, ) -> Result<()> { let leaves = leaves.into_iter().collect::>(); - let indices = indices.into_iter().collect::>(); - let end = start + leaves.len() + indices.len(); + let indices = indices.into_iter().collect::>(); + let mut subtree = HashMap::::new(); + let leaves_len = leaves.len(); + let leaves_set = self.tree.leaves_set(); - // handle each case appropriately - - // case 1: both leaves and indices to be removed are passed in - // case 2: only leaves are passed in - // case 3: only indices are passed in - // case 4: neither leaves nor indices are passed in - match (leaves.len(), indices.len()) { - (0, 0) => Err(Report::msg("no leaves or indices to be removed")), - (0, _) => { - // case 3 - // remove indices - let mut new_leaves = Vec::new(); - let start = start + indices[0]; - let end = start + indices.len(); - for _ in start..end { - // Insert 0 - new_leaves.push(Self::Hasher::default_leaf()); - } - self.tree - .set_range(start, new_leaves) - .map_err(|e| Report::msg(e.to_string())) - } - (_, 0) => { - // case 2 - // insert leaves - self.tree - .set_range(start, leaves) - .map_err(|e| Report::msg(e.to_string())) - } - (_, _) => { - // case 1 - // remove indices - let mut new_leaves = Vec::new(); - let indices = indices.into_iter().collect::>(); - let new_start = start + leaves.len(); - for i in new_start..=end { - if indices.contains(&i) { - // Insert 0 - new_leaves.push(Self::Hasher::default_leaf()); - } else if let Some(leaf) = leaves.get(i - new_start) { - // Insert leaf - new_leaves.push(*leaf); - } - } - self.tree - .set_range(start, new_leaves) - .map_err(|e| Report::msg(e.to_string())) - } + dbg!(self.tree.root()); + + // insert the old leaves + for i in 0..leaves_set { + let leaf = self.tree.get(i)?; + subtree.insert(i, leaf); } + + // zero out the leaves to be removed + for index in indices { + if index >= leaves_set { + return Err(Report::msg(format!( + "Index {} is out of bounds, leaves_set: {}", + index, leaves_set + ))); + } + subtree.insert(index, Self::Hasher::default_leaf()); + } + // insert the new leaves from start + for i in start..(start + leaves_len) { + let leaf = leaves[i - start]; + subtree.insert(i, leaf); + } + + // Use set_range with the new_leaves buffer to update the tree. + let res = self + .tree + .set_range(0, subtree.into_iter().map(|(_, v)| v)) + .map_err(|e| Report::msg(e.to_string())); + dbg!(self.tree.root()); + return res; } fn update_next(&mut self, leaf: FrOf) -> Result<()> { diff --git a/rln/src/public.rs b/rln/src/public.rs index 499711c..db49a6f 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -1438,6 +1438,12 @@ mod test { rln.get_root(&mut buffer).unwrap(); let (root_after_noop, _) = bytes_le_to_fr(&buffer.into_inner()); + let mut output_buffer = Cursor::new(Vec::::new()); + rln.get_leaf(last_leaf_index, &mut output_buffer).unwrap(); + let (received_leaf, _) = bytes_le_to_fr(output_buffer.into_inner().as_ref()); + + assert_eq!(received_leaf, last_leaf[0]); + assert_eq!(root_after_insertion, root_after_noop); } diff --git a/utils/benches/merkle_tree_benchmark.rs b/utils/benches/merkle_tree_benchmark.rs index 57e2663..dea163b 100644 --- a/utils/benches/merkle_tree_benchmark.rs +++ b/utils/benches/merkle_tree_benchmark.rs @@ -3,7 +3,7 @@ use hex_literal::hex; use tiny_keccak::{Hasher as _, Keccak}; use zerokit_utils::{ FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, - ZerokitMerkleTree, + ZerokitMerkleTree, BatchOf, }; #[derive(Clone, Copy, Eq, PartialEq)] @@ -50,9 +50,16 @@ pub fn optimal_merkle_tree_benchmark(c: &mut Criterion) { }) }); - c.bench_function("OptimalMerkleTree::override_range", |b| { + c.bench_function("OptimalMerkleTree::set_range", |b| { b.iter(|| { - tree.override_range(0, leaves, [0, 1, 2, 3]).unwrap(); + let mut batch = BatchOf::>::new(); + for i in 0..leaves.len() { + batch.insert(i, leaves[i]); + } + for i in [0, 1, 2, 3] { + batch.remove(&i); + } + tree.set_range(&batch).unwrap(); }) }); @@ -94,7 +101,14 @@ pub fn full_merkle_tree_benchmark(c: &mut Criterion) { c.bench_function("FullMerkleTree::override_range", |b| { b.iter(|| { - tree.override_range(0, leaves, [0, 1, 2, 3]).unwrap(); + let mut batch = BatchOf::>::new(); + for i in 0..leaves.len() { + batch.insert(i, leaves[i]); + } + for i in [0, 1, 2, 3] { + batch.remove(&i); + } + tree.set_range(&batch).unwrap(); }) }); diff --git a/utils/src/merkle_tree/full_merkle_tree.rs b/utils/src/merkle_tree/full_merkle_tree.rs index 15ef422..26dd3c5 100644 --- a/utils/src/merkle_tree/full_merkle_tree.rs +++ b/utils/src/merkle_tree/full_merkle_tree.rs @@ -1,10 +1,10 @@ -use crate::merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree}; +use crate::{merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree}, merkle_tree::Batch}; use color_eyre::{Report, Result}; use std::{ cmp::max, fmt::Debug, - iter::{once, repeat, successors}, - str::FromStr, + iter::{repeat, successors}, + str::FromStr, collections::HashMap, }; //////////////////////////////////////////////////////////// @@ -59,6 +59,30 @@ impl FromStr for FullMerkleConfig { } } +impl Batch for HashMap> +where + H: Hasher, +{ + type Key = usize; + + fn insert(&mut self, key: usize, value: FrOf) { + self.insert(key, value); + } + + fn remove(&mut self, key: usize) { + self.remove(&key); + } + + fn max_index(&self) -> usize { + *self.keys().max().unwrap_or(&0) + } + + fn min_index(&self) -> usize { + *self.keys().min().unwrap_or(&0) + } +} + + /// Implementations impl ZerokitMerkleTree for FullMerkleTree where @@ -67,6 +91,7 @@ where type Proof = FullMerkleProof; type Hasher = H; type Config = FullMerkleConfig; + type Batch = HashMap>; fn default(depth: usize) -> Result { FullMerkleTree::::new(depth, Self::Hasher::default_leaf(), Self::Config::default()) @@ -128,7 +153,12 @@ where // Sets a leaf at the specified tree index fn set(&mut self, leaf: usize, hash: FrOf) -> Result<()> { - self.set_range(leaf, once(hash))?; + if leaf >= self.capacity() { + return Err(Report::msg("leaf index out of bounds")); + } + let capacity = self.capacity(); + self.nodes[capacity + leaf - 1] = hash; + self.update_nodes(capacity + leaf - 1, capacity + leaf - 1)?; self.next_index = max(self.next_index, leaf + 1); Ok(()) } @@ -143,59 +173,18 @@ where // Sets tree nodes, starting from start index // Function proper of FullMerkleTree implementation - fn set_range>>( + fn set_range( &mut self, - start: usize, - hashes: I, + batch: &Self::Batch, ) -> Result<()> { - let index = self.capacity() + start - 1; - let mut count = 0; // first count number of hashes, and check that they fit in the tree // then insert into the tree - let hashes = hashes.into_iter().collect::>(); - if hashes.len() + start > self.capacity() { - return Err(Report::msg("provided hashes do not fit in the tree")); - } - hashes.into_iter().for_each(|hash| { - self.nodes[index + count] = hash; - count += 1; - }); - if count != 0 { - self.update_nodes(index, index + (count - 1))?; - self.next_index = max(self.next_index, start + count); - } - Ok(()) - } - - fn override_range(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()> - where - I: IntoIterator>, - J: IntoIterator, - { - let index = self.capacity() + start - 1; - let mut count = 0; - let leaves = leaves.into_iter().collect::>(); - let to_remove_indices = to_remove_indices.into_iter().collect::>(); - // first count number of hashes, and check that they fit in the tree - // then insert into the tree - if leaves.len() + start - to_remove_indices.len() > self.capacity() { + if batch.len() > self.capacity() { return Err(Report::msg("provided hashes do not fit in the tree")); } - // remove leaves - for i in &to_remove_indices { - self.delete(*i)?; - } - - // insert new leaves - for hash in leaves { - self.nodes[index + count] = hash; - count += 1; - } - - if count != 0 { - self.update_nodes(index, index + (count - 1))?; - self.next_index = max(self.next_index, start + count - to_remove_indices.len()); + for (key, value) in batch { + self.set(*key, *value)?; } Ok(()) } diff --git a/utils/src/merkle_tree/merkle_tree.rs b/utils/src/merkle_tree/merkle_tree.rs index 5ee6bdf..d9708ce 100644 --- a/utils/src/merkle_tree/merkle_tree.rs +++ b/utils/src/merkle_tree/merkle_tree.rs @@ -30,7 +30,17 @@ pub trait Hasher { fn hash(input: &[Self::Fr]) -> Self::Fr; } +pub trait Batch where H:Hasher { + type Key; + + fn insert(&mut self, key: usize, value: H::Fr); + fn remove(&mut self, key: usize); + fn max_index(&self) -> usize; + fn min_index(&self) -> usize; +} + pub type FrOf = ::Fr; +pub type BatchOf = ::Batch; /// In the ZerokitMerkleTree trait we define the methods that are required to be implemented by a Merkle tree /// Including, OptimalMerkleTree, FullMerkleTree @@ -38,6 +48,7 @@ pub trait ZerokitMerkleTree { type Proof: ZerokitMerkleProof; type Hasher: Hasher; type Config: Default + FromStr; + type Batch: Batch; fn default(depth: usize) -> Result where @@ -51,14 +62,8 @@ pub trait ZerokitMerkleTree { fn root(&self) -> FrOf; fn compute_root(&mut self) -> Result>; fn set(&mut self, index: usize, leaf: FrOf) -> Result<()>; - fn set_range(&mut self, start: usize, leaves: I) -> Result<()> - where - I: IntoIterator>; + fn set_range(&mut self, batch: &Self::Batch) -> Result<()>; fn get(&self, index: usize) -> Result>; - fn override_range(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()> - where - I: IntoIterator>, - J: IntoIterator; fn update_next(&mut self, leaf: FrOf) -> Result<()>; fn delete(&mut self, index: usize) -> Result<()>; fn proof(&self, index: usize) -> Result; diff --git a/utils/src/merkle_tree/optimal_merkle_tree.rs b/utils/src/merkle_tree/optimal_merkle_tree.rs index 48470e7..9735690 100644 --- a/utils/src/merkle_tree/optimal_merkle_tree.rs +++ b/utils/src/merkle_tree/optimal_merkle_tree.rs @@ -60,6 +60,7 @@ where type Proof = OptimalMerkleProof; type Hasher = H; type Config = OptimalMerkleConfig; + type Batch = HashMap>; fn default(depth: usize) -> Result { OptimalMerkleTree::::new(depth, H::default_leaf(), Self::Config::default()) @@ -128,47 +129,15 @@ where } // Sets multiple leaves from the specified tree index - fn set_range>(&mut self, start: usize, leaves: I) -> Result<()> { - let leaves = leaves.into_iter().collect::>(); + fn set_range(&mut self, batch: &Self::Batch) -> Result<()> { // check if the range is valid - if start + leaves.len() > self.capacity() { - return Err(Report::msg("provided range exceeds set size")); - } - for (i, leaf) in leaves.iter().enumerate() { - self.nodes.insert((self.depth, start + i), *leaf); - self.recalculate_from(start + i)?; - } - self.next_index = max(self.next_index, start + leaves.len()); - Ok(()) - } - - fn override_range(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()> - where - I: IntoIterator>, - J: IntoIterator, - { - let leaves = leaves.into_iter().collect::>(); - let to_remove_indices = to_remove_indices.into_iter().collect::>(); - // check if the range is valid - if leaves.len() + start - to_remove_indices.len() > self.capacity() { + if batch.len() > self.capacity() { return Err(Report::msg("provided range exceeds set size")); } - // remove leaves - for i in &to_remove_indices { - self.delete(*i)?; + for (key, value) in batch { + self.set(*key, *value)?; } - - // add leaves - for (i, leaf) in leaves.iter().enumerate() { - self.nodes.insert((self.depth, start + i), *leaf); - self.recalculate_from(start + i)?; - } - - self.next_index = max( - self.next_index, - start + leaves.len() - to_remove_indices.len(), - ); Ok(()) } diff --git a/utils/tests/merkle_tree.rs b/utils/tests/merkle_tree.rs index 2abcaeb..d33b89d 100644 --- a/utils/tests/merkle_tree.rs +++ b/utils/tests/merkle_tree.rs @@ -1,11 +1,13 @@ // Tests adapted from https://github.com/worldcoin/semaphore-rs/blob/d462a4372f1fd9c27610f2acfe4841fab1d396aa/src/merkle_tree.rs #[cfg(test)] mod test { + use std::collections::HashMap; + use hex_literal::hex; use tiny_keccak::{Hasher as _, Keccak}; use zerokit_utils::{ FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, - ZerokitMerkleProof, ZerokitMerkleTree, + ZerokitMerkleProof, ZerokitMerkleTree, BatchOf, }; #[derive(Clone, Copy, Eq, PartialEq)] struct Keccak256; @@ -139,27 +141,27 @@ mod test { OptimalMerkleTree::::new(2, [0; 32], OptimalMerkleConfig::default()) .unwrap(); - // We set the leaves - tree.set_range(0, initial_leaves.iter().cloned()).unwrap(); + // We set the leaves in a batch + // Batch = Hashmap + let batch = initial_leaves + .iter() + .enumerate() + .map(|(i, leaf)| (i, *leaf)) + .collect::>(); + tree.set_range(&batch).unwrap(); - let new_leaves = [ - hex!("0000000000000000000000000000000000000000000000000000000000000005"), - hex!("0000000000000000000000000000000000000000000000000000000000000006"), - ]; - - let to_delete_indices: [usize; 2] = [0, 1]; + let mut new_batch = BatchOf::>::new(); + new_batch.remove(&0); + new_batch.remove(&1); + new_batch.insert(tree.leaves_set() - 2, hex!("0000000000000000000000000000000000000000000000000000000000000005")); + new_batch.insert(tree.leaves_set() - 1, hex!("0000000000000000000000000000000000000000000000000000000000000006")); // We override the leaves - tree.override_range( - 0, // start from the end of the initial leaves - new_leaves.iter().cloned(), - to_delete_indices.iter().cloned(), - ) - .unwrap(); + tree.set_range(&new_batch).unwrap(); // ensure that the leaves are set correctly - for i in 0..new_leaves.len() { - assert_eq!(tree.get_leaf(i), new_leaves[i]); + for (i, leaf) in new_batch { + assert_eq!(tree.get_leaf(i), leaf); } } }