diff --git a/rln/benches/pmtree_benchmark.rs b/rln/benches/pmtree_benchmark.rs index 6fa3d87..c4a9341 100644 --- a/rln/benches/pmtree_benchmark.rs +++ b/rln/benches/pmtree_benchmark.rs @@ -21,7 +21,7 @@ pub fn pmtree_benchmark(c: &mut Criterion) { c.bench_function("Pmtree::override_range", |b| { b.iter(|| { - tree.override_range(0, leaves.clone(), [0, 1, 2, 3]) + tree.override_range(0, leaves.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); }) }); diff --git a/rln/src/pm_tree_adapter.rs b/rln/src/pm_tree_adapter.rs index 7901f4f..c7e80ee 100644 --- a/rln/src/pm_tree_adapter.rs +++ b/rln/src/pm_tree_adapter.rs @@ -244,7 +244,7 @@ impl ZerokitMerkleTree for PmTree { (0, 0) => Err(Report::msg("no leaves or indices to be removed")), (1, 0) => self.set(start, leaves[0]), (0, 1) => self.delete(indices[0]), - (_, 0) => self.set_range(start, leaves), + (_, 0) => self.set_range(start, leaves.into_iter()), (0, _) => self.remove_indices(&indices), (_, _) => self.remove_indices_and_set_leaves(start, leaves, &indices), } diff --git a/rln/src/public.rs b/rln/src/public.rs index 2fa534a..a8195e6 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -385,7 +385,7 @@ impl RLN { // We set the leaves self.tree - .override_range(index, leaves, []) + .override_range(index, leaves.into_iter(), [].into_iter()) .map_err(|_| Report::msg("Could not set leaves"))?; Ok(()) } @@ -468,7 +468,7 @@ impl RLN { // We set the leaves self.tree - .override_range(index, leaves, indices) + .override_range(index, leaves.into_iter(), indices.into_iter()) .map_err(|e| Report::msg(format!("Could not perform the batch operation: {e}")))?; Ok(()) } diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index fa1a2b7..8032cb2 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -156,6 +156,7 @@ mod test { // random number between 0..no_of_leaves let mut rng = thread_rng(); let set_index = rng.gen_range(0..NO_OF_LEAVES) as usize; + println!("set_index: {}", set_index); // We add leaves in a batch into the tree set_leaves_init(rln_pointer, &leaves); @@ -176,7 +177,10 @@ mod test { // We get the root of the tree obtained adding leaves in batch let root_batch_with_custom_index = get_tree_root(rln_pointer); - assert_eq!(root_batch_with_init, root_batch_with_custom_index); + assert_eq!( + root_batch_with_init, root_batch_with_custom_index, + "root batch !=" + ); // We reset the tree to default let success = set_tree(rln_pointer, TEST_TREE_HEIGHT); @@ -192,7 +196,10 @@ mod test { // We get the root of the tree obtained adding leaves using the internal index let root_single_additions = get_tree_root(rln_pointer); - assert_eq!(root_batch_with_init, root_single_additions); + assert_eq!( + root_batch_with_init, root_single_additions, + "root single additions !=" + ); } #[test] diff --git a/rln/tests/poseidon_tree.rs b/rln/tests/poseidon_tree.rs index 69749b1..08bf612 100644 --- a/rln/tests/poseidon_tree.rs +++ b/rln/tests/poseidon_tree.rs @@ -23,6 +23,7 @@ mod test { assert_eq!(proof.leaf_index(), i); tree_opt.set(i, leaves[i]).unwrap(); + assert_eq!(tree_opt.root(), tree_full.root()); let proof = tree_opt.proof(i).expect("index should be set"); assert_eq!(proof.leaf_index(), i); } @@ -37,11 +38,11 @@ mod test { #[test] fn test_subtree_root() { const DEPTH: usize = 3; - const LEAVES_LEN: usize = 6; + const LEAVES_LEN: usize = 8; let mut tree = PoseidonTree::default(DEPTH).unwrap(); let leaves: Vec = (0..LEAVES_LEN).map(|s| Fr::from(s as i32)).collect(); - let _ = tree.set_range(0, leaves); + let _ = tree.set_range(0, leaves.into_iter()); for i in 0..LEAVES_LEN { // check leaves @@ -78,7 +79,7 @@ mod test { let leaves: Vec = (0..nof_leaves).map(|s| Fr::from(s as i32)).collect(); // check set_range - let _ = tree.set_range(0, leaves.clone()); + let _ = tree.set_range(0, leaves.clone().into_iter()); assert!(tree.get_empty_leaves_indices().is_empty()); let mut vec_idxs = Vec::new(); @@ -98,26 +99,28 @@ mod test { // check remove_indices_and_set_leaves inside override_range function assert!(tree.get_empty_leaves_indices().is_empty()); let leaves_2: Vec = (0..2).map(|s| Fr::from(s as i32)).collect(); - tree.override_range(0, leaves_2.clone(), [0, 1, 2, 3]) + tree.override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree.get_empty_leaves_indices(), vec![2, 3]); // check remove_indices inside override_range function - tree.override_range(0, [], [0, 1]).unwrap(); + tree.override_range(0, [].into_iter(), [0, 1].into_iter()) + .unwrap(); assert_eq!(tree.get_empty_leaves_indices(), vec![0, 1, 2, 3]); // check set_range inside override_range function - tree.override_range(0, leaves_2.clone(), []).unwrap(); + tree.override_range(0, leaves_2.clone().into_iter(), [].into_iter()) + .unwrap(); assert_eq!(tree.get_empty_leaves_indices(), vec![2, 3]); let leaves_4: Vec = (0..4).map(|s| Fr::from(s as i32)).collect(); // check if the indexes for write and delete are the same - tree.override_range(0, leaves_4.clone(), [0, 1, 2, 3]) + tree.override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert!(tree.get_empty_leaves_indices().is_empty()); // check if indexes for deletion are before indexes for overwriting - tree.override_range(4, leaves_4.clone(), [0, 1, 2, 3]) + tree.override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); // The result will be like this, because in the set_range function in pmtree // the next_index value is increased not by the number of elements to insert, @@ -128,7 +131,7 @@ mod test { ); // check if the indices for write and delete do not overlap completely - tree.override_range(2, leaves_4.clone(), [0, 1, 2, 3]) + tree.override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); // The result will be like this, because in the set_range function in pmtree // the next_index value is increased not by the number of elements to insert, diff --git a/utils/benches/merkle_tree_benchmark.rs b/utils/benches/merkle_tree_benchmark.rs index b1adc90..89ebbe6 100644 --- a/utils/benches/merkle_tree_benchmark.rs +++ b/utils/benches/merkle_tree_benchmark.rs @@ -75,7 +75,8 @@ pub fn optimal_merkle_tree_benchmark(c: &mut Criterion) { c.bench_function("OptimalMerkleTree::override_range", |b| { b.iter(|| { - tree.override_range(0, *LEAVES, [0, 1, 2, 3]).unwrap(); + tree.override_range(0, LEAVES.into_iter(), [0, 1, 2, 3].into_iter()) + .unwrap(); }) }); @@ -123,7 +124,8 @@ pub fn full_merkle_tree_benchmark(c: &mut Criterion) { c.bench_function("FullMerkleTree::override_range", |b| { b.iter(|| { - tree.override_range(0, *LEAVES, [0, 1, 2, 3]).unwrap(); + tree.override_range(0, LEAVES.into_iter(), [0, 1, 2, 3].into_iter()) + .unwrap(); }) }); diff --git a/utils/src/merkle_tree/full_merkle_tree.rs b/utils/src/merkle_tree/full_merkle_tree.rs index 55aa0d5..98e6463 100644 --- a/utils/src/merkle_tree/full_merkle_tree.rs +++ b/utils/src/merkle_tree/full_merkle_tree.rs @@ -236,7 +236,7 @@ where self.cached_leaves_indices[i] = 0; } - self.set_range(start, set_values) + self.set_range(start, set_values.into_iter()) .map_err(|e| Report::msg(e.to_string())) } diff --git a/utils/src/merkle_tree/merkle_tree.rs b/utils/src/merkle_tree/merkle_tree.rs index 65ad1a7..337f133 100644 --- a/utils/src/merkle_tree/merkle_tree.rs +++ b/utils/src/merkle_tree/merkle_tree.rs @@ -54,13 +54,13 @@ pub trait ZerokitMerkleTree { fn set(&mut self, index: usize, leaf: FrOf) -> Result<()>; fn set_range(&mut self, start: usize, leaves: I) -> Result<()> where - I: IntoIterator>; + I: ExactSizeIterator>; fn get(&self, index: usize) -> Result>; fn get_empty_leaves_indices(&self) -> Vec; fn override_range(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()> where - I: IntoIterator>, - J: IntoIterator; + I: ExactSizeIterator>, + J: ExactSizeIterator; fn update_next(&mut self, leaf: FrOf) -> Result<()>; fn delete(&mut self, index: usize) -> Result<()>; fn proof(&self, index: usize) -> Result; diff --git a/utils/src/merkle_tree/optimal_merkle_tree.rs b/utils/src/merkle_tree/optimal_merkle_tree.rs index a61095a..ab0f9fc 100644 --- a/utils/src/merkle_tree/optimal_merkle_tree.rs +++ b/utils/src/merkle_tree/optimal_merkle_tree.rs @@ -1,10 +1,10 @@ use crate::merkle_tree::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree}; use crate::FrOf; use color_eyre::{Report, Result}; +use std::cmp::min; use std::collections::HashMap; use std::str::FromStr; use std::{cmp::max, fmt::Debug}; - //////////////////////////////////////////////////////////// ///// Optimal Merkle Tree Implementation //////////////////////////////////////////////////////////// @@ -81,9 +81,9 @@ where } cached_nodes.reverse(); Ok(OptimalMerkleTree { - cached_nodes: cached_nodes.clone(), + cached_nodes, depth, - nodes: HashMap::new(), + nodes: HashMap::with_capacity(1 << depth), cached_leaves_indices: vec![0; 1 << depth], next_index: 0, metadata: Vec::new(), @@ -136,7 +136,7 @@ where return Err(Report::msg("index exceeds set size")); } self.nodes.insert((self.depth, index), leaf); - self.recalculate_from(index)?; + self.update_hashes(index, 1)?; self.next_index = max(self.next_index, index + 1); self.cached_leaves_indices[index] = 1; Ok(()) @@ -161,25 +161,29 @@ where } // Sets multiple leaves from the specified tree index - fn set_range>(&mut self, start: usize, leaves: I) -> Result<()> { - let leaves = leaves.into_iter().collect::>(); + fn set_range>( + &mut self, + start: usize, + leaves: I, + ) -> Result<()> { // check if the range is valid - if start + leaves.len() > self.capacity() { + let leaves_len = leaves.len(); + if start + leaves_len > self.capacity() { return Err(Report::msg("provided range exceeds set size")); } - for (i, leaf) in leaves.iter().enumerate() { - self.nodes.insert((self.depth, start + i), *leaf); + for (i, leaf) in leaves.enumerate() { + self.nodes.insert((self.depth, start + i), leaf); self.cached_leaves_indices[start + i] = 1; - self.recalculate_from(start + i)?; } - self.next_index = max(self.next_index, start + leaves.len()); + self.update_hashes(start, leaves_len)?; + self.next_index = max(self.next_index, start + leaves_len); Ok(()) } fn override_range(&mut self, start: usize, leaves: I, indices: J) -> Result<()> where - I: IntoIterator>, - J: IntoIterator, + I: ExactSizeIterator>, + J: ExactSizeIterator, { let indices = indices.into_iter().collect::>(); let min_index = *indices.first().unwrap(); @@ -204,7 +208,7 @@ where self.cached_leaves_indices[i] = 0; } - self.set_range(start, set_values) + self.set_range(start, set_values.into_iter()) .map_err(|e| Report::msg(e.to_string())) } @@ -316,6 +320,60 @@ where } Ok(()) } + + /// Update hashes after some leaves have been set or updated + /// index - first leaf index (which has been set or updated) + /// length - number of elements set or updated + fn update_hashes(&mut self, index: usize, length: usize) -> Result<()> { + // parent depth & index (used to store in the tree) + let mut parent_depth = self.depth - 1; // tree depth (or leaves depth) - 1 + let mut parent_index = index >> 1; + let mut parent_index_bak = parent_index; + // maximum index at this depth + let parent_max_index_0 = (1 << parent_depth) / 2; + // Based on given length (number of elements we will update) + // we could restrict the parent_max_index + let current_index_max = if (index + length) % 2 == 0 { + index + length + 2 + } else { + index + length + 1 + }; + let mut parent_max_index = min(current_index_max >> 1, parent_max_index_0); + + // current depth & index (used to compute the hash) + // current depth initially == tree depth (or leaves depth) + let mut current_depth = self.depth; + let mut current_index = if index % 2 == 0 { index } else { index - 1 }; + let mut current_index_bak = current_index; + + loop { + // Hash 2 values at (current depth, current_index) & (current_depth, current_index + 1) + let n_hash = self.hash_couple(current_depth, current_index); + // Insert this hash at (parent_depth, parent_index) + self.nodes.insert((parent_depth, parent_index), n_hash); + + if parent_depth == 0 { + // We just set the root hash of the tree - nothing to do anymore + break; + } + // Incr parent index + parent_index += 1; + // Incr current index (+2 because we've just hashed current index & current_index + 1) + current_index += 2; + if parent_index >= parent_max_index { + // reset (aka decr depth & reset indexes) + parent_depth -= 1; + parent_index = parent_index_bak >> 1; + parent_index_bak = parent_index; + parent_max_index >>= 1; + current_depth -= 1; + current_index = current_index_bak >> 1; + current_index_bak = current_index; + } + } + + Ok(()) + } } impl ZerokitMerkleProof for OptimalMerkleProof diff --git a/utils/src/poseidon/poseidon_hash.rs b/utils/src/poseidon/poseidon_hash.rs index 802c5db..975bf14 100644 --- a/utils/src/poseidon/poseidon_hash.rs +++ b/utils/src/poseidon/poseidon_hash.rs @@ -7,7 +7,7 @@ use crate::poseidon_constants::find_poseidon_ark_and_mds; use ark_ff::PrimeField; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct RoundParamenters { +pub struct RoundParameters { pub t: usize, pub n_rounds_f: usize, pub n_rounds_p: usize, @@ -17,16 +17,16 @@ pub struct RoundParamenters { } pub struct Poseidon { - round_params: Vec>, + round_params: Vec>, } impl Poseidon { // Loads round parameters and generates round constants // poseidon_params is a vector containing tuples (t, RF, RP, skip_matrices) - // where: t is the rate (input lenght + 1), RF is the number of full rounds, RP is the number of partial rounds + // where: t is the rate (input length + 1), RF is the number of full rounds, RP is the number of partial rounds // and skip_matrices is a (temporary) parameter used to generate secure MDS matrices (see comments in the description of find_poseidon_ark_and_mds) // TODO: implement automatic generation of round parameters pub fn from(poseidon_params: &[(usize, usize, usize, usize)]) -> Self { - let mut read_params = Vec::>::new(); + let mut read_params = Vec::>::with_capacity(poseidon_params.len()); for &(t, n_rounds_f, n_rounds_p, skip_matrices) in poseidon_params { let (ark, mds) = find_poseidon_ark_and_mds::( @@ -38,7 +38,7 @@ impl Poseidon { n_rounds_p as u64, skip_matrices, ); - let rp = RoundParamenters { + let rp = RoundParameters { t, n_rounds_p, n_rounds_f, @@ -54,7 +54,7 @@ impl Poseidon { } } - pub fn get_parameters(&self) -> Vec> { + pub fn get_parameters(&self) -> Vec> { self.round_params.clone() } @@ -80,21 +80,19 @@ impl Poseidon { } } - pub fn mix(&self, state: &[F], m: &[Vec]) -> Vec { - let mut new_state: Vec = Vec::new(); + pub fn mix_2(&self, state: &[F], m: &[Vec], state_2: &mut [F]) { for i in 0..state.len() { - new_state.push(F::zero()); + state_2[i] = F::ZERO; for (j, state_item) in state.iter().enumerate() { let mut mij = m[i][j]; mij *= state_item; - new_state[i] += mij; + state_2[i] += mij; } } - new_state.clone() } pub fn hash(&self, inp: Vec) -> Result { - // Note that the rate t becomes input lenght + 1, hence for lenght N we pick parameters with T = N + 1 + // Note that the rate t becomes input length + 1, hence for length N we pick parameters with T = N + 1 let t = inp.len() + 1; // We seek the index (Poseidon's round_params is an ordered vector) for the parameters corresponding to t @@ -106,8 +104,9 @@ impl Poseidon { let param_index = param_index.unwrap(); - let mut state = vec![F::zero(); t]; + let mut state = vec![F::ZERO; t]; state[1..].clone_from_slice(&inp); + let mut state_2 = vec![F::ZERO; state.len()]; for i in 0..(self.round_params[param_index].n_rounds_f + self.round_params[param_index].n_rounds_p) @@ -123,7 +122,8 @@ impl Poseidon { &mut state, i, ); - state = self.mix(&state, &self.round_params[param_index].m); + self.mix_2(&state, &self.round_params[param_index].m, &mut state_2); + std::mem::swap(&mut state, &mut state_2); } Ok(state[0]) diff --git a/utils/tests/merkle_tree.rs b/utils/tests/merkle_tree.rs index 00f46a6..0c531c1 100644 --- a/utils/tests/merkle_tree.rs +++ b/utils/tests/merkle_tree.rs @@ -107,7 +107,7 @@ pub mod test { let leaves_4: Vec = (0u32..4).map(TestFr::from).collect(); let mut tree_full = default_full_merkle_tree(depth); - let _ = tree_full.set_range(0, leaves.clone()); + let _ = tree_full.set_range(0, leaves.clone().into_iter()); assert!(tree_full.get_empty_leaves_indices().is_empty()); let mut vec_idxs = Vec::new(); @@ -125,31 +125,31 @@ pub mod test { // Check situation when the number of items to insert is less than the number of items to delete tree_full - .override_range(0, leaves_2.clone(), [0, 1, 2, 3]) + .override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); // check if the indexes for write and delete are the same tree_full - .override_range(0, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), vec![]); // check if indexes for deletion are before indexes for overwriting tree_full - .override_range(4, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1, 2, 3]); // check if the indices for write and delete do not overlap completely tree_full - .override_range(2, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1]); //// Optimal Merkle Tree Trest let mut tree_opt = default_optimal_merkle_tree(depth); - let _ = tree_opt.set_range(0, leaves.clone()); + let _ = tree_opt.set_range(0, leaves.clone().into_iter()); assert!(tree_opt.get_empty_leaves_indices().is_empty()); let mut vec_idxs = Vec::new(); @@ -166,24 +166,24 @@ pub mod test { // Check situation when the number of items to insert is less than the number of items to delete tree_opt - .override_range(0, leaves_2.clone(), [0, 1, 2, 3]) + .override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); // check if the indexes for write and delete are the same tree_opt - .override_range(0, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), vec![]); // check if indexes for deletion are before indexes for overwriting tree_opt - .override_range(4, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1, 2, 3]); // check if the indices for write and delete do not overlap completely tree_opt - .override_range(2, leaves_4.clone(), [0, 1, 2, 3]) + .override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1]); } @@ -191,7 +191,7 @@ pub mod test { #[test] fn test_subtree_root() { let depth = 3; - let nof_leaves: usize = 6; + let nof_leaves: usize = 4; let leaves: Vec = (0..nof_leaves as u32).map(TestFr::from).collect(); let mut tree_full = default_optimal_merkle_tree(depth);