From ed4fad8f9e1f958a21204b8401bdc168f2470663 Mon Sep 17 00:00:00 2001 From: rymnc <43716372+rymnc@users.noreply.github.com> Date: Tue, 9 May 2023 22:29:44 +0530 Subject: [PATCH] fix: fn --- src/tree.rs | 35 +++++++++++++++++++++++++---------- tests/sled_keccak.rs | 4 ++-- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index daf792d..d3b5a59 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,7 +1,7 @@ use crate::*; use std::cmp::{max, min}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; // db[DEPTH_KEY] = depth @@ -216,14 +216,29 @@ where leaves: I, to_remove_indices: J, ) -> PmtreeResult<()> { - let leaves = leaves.into_iter().collect::>(); - let to_remove_indices = to_remove_indices.into_iter().collect::>(); + // this fn should first remove the elements at the given indices, then insert the new leaves + // this is done in parallel, and then the tree is recalculated from the first index + // operation must be atomic, so if one of the operations fails, the tree is not updated + + let leaves: Vec = leaves.into_iter().collect(); + let to_remove_indices: Vec = to_remove_indices.into_iter().collect(); let start = start.unwrap_or(self.next_index); let end = start + leaves.len(); - if end - to_remove_indices.len() > self.capacity() { - return Err(PmtreeErrorKind::TreeError(TreeErrorKind::MerkleTreeIsFull)); + // check if the leaves are in the correct range + if leaves.len() + start > self.capacity() { + return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds)); + } + + // check if the indices are in the correct range + if to_remove_indices.iter().any(|&i| i >= self.next_index) { + return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds)); + } + + // check if the indices are unique + if to_remove_indices.len() != to_remove_indices.iter().collect::>().len() { + return Err(PmtreeErrorKind::TreeError(TreeErrorKind::CustomError("Indices are not unique".to_string()))); } let mut subtree = HashMap::::new(); @@ -232,11 +247,12 @@ where subtree.insert(root_key, self.root); - for i in to_remove_indices { - subtree.insert(Key(self.depth, i), H::default_leaf()); - } self.fill_nodes(root_key, start, end, &mut subtree, &leaves, start)?; - + for i in to_remove_indices { + subtree.insert(Key(self.depth, i - leaves.len()), H::default_leaf()); + } + + let subtree = Arc::new(RwLock::new(subtree)); let root_val = rayon::ThreadPoolBuilder::new() @@ -261,7 +277,6 @@ where .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?; } - // Update root value in memory self.root = root_val; Ok(()) diff --git a/tests/sled_keccak.rs b/tests/sled_keccak.rs index 3d08e90..8737810 100644 --- a/tests/sled_keccak.rs +++ b/tests/sled_keccak.rs @@ -230,11 +230,11 @@ fn batch_operations() -> PmtreeResult<()> { hex!("a9bb8c3f1f12e9aa903a50c47f314b57610a3ab32f2d463293f58836def38d36") ); - mt.batch_operations(None, [], [0, 1, 2, 3])?; + mt.batch_operations(None, [], [3])?; assert_eq!( mt.root(), - hex!("b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30") + hex!("222ff5e0b5877792c2bc1670e2ccd0c2c97cd7bb1672a57d598db05092d3d72c") ); fs::remove_dir_all("batch_operations").expect("Error removing db");