diff --git a/src/tree.rs b/src/tree.rs index 99d14c4..5739b78 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -193,8 +193,9 @@ where } /// Batch insertion, updates the tree in parallel. - pub fn batch_insert(&mut self, leaves: &[H::Fr]) -> Result<()> { - let end = self.next_index + leaves.len(); + pub fn batch_insert(&mut self, start: Option, leaves: &[H::Fr]) -> Result<()> { + let start = start.unwrap_or(self.next_index); + let end = start + leaves.len(); if end > self.capacity() { return Err(Box::new(Error( @@ -207,14 +208,7 @@ where let root_key = Key(0, 0); subtree.insert(root_key, self.root); - self.fill_nodes( - root_key, - self.next_index, - end, - &mut subtree, - leaves, - self.next_index, - )?; + self.fill_nodes(root_key, start, end, &mut subtree, leaves, start)?; let subtree = Arc::new(RwLock::new(subtree)); @@ -233,17 +227,24 @@ where .collect(), )?; - // Update root value and next_index in memory - self.root = root_val; - self.next_index = end; - // Update next_index value in db - self.db - .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?; + + if start + leaves.len() > self.next_index { + self.next_index = start + leaves.len(); + self.db + .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?; + } + + // Update root value in memory + self.root = root_val; Ok(()) } + pub fn set_range(&mut self, start: usize, leaves: &[H::Fr]) -> Result<()> { + self.batch_insert(Some(start), leaves) + } + // Fills hashmap subtree fn fill_nodes( &self, diff --git a/tests/memory_keccak.rs b/tests/memory_keccak.rs index 53bb623..6d3b5cc 100644 --- a/tests/memory_keccak.rs +++ b/tests/memory_keccak.rs @@ -120,7 +120,7 @@ fn batch_insertions() -> Result<()> { hex!("0000000000000000000000000000000000000000000000000000000000000004"), ]; - mt.batch_insert(&leaves)?; + mt.batch_insert(None, &leaves)?; assert_eq!( mt.root(), @@ -129,3 +129,22 @@ fn batch_insertions() -> Result<()> { Ok(()) } + +#[test] +fn set_range() -> Result<()> { + let mut mt = MerkleTree::::new(2, MemoryDBConfig)?; + + let leaves = [ + hex!("0000000000000000000000000000000000000000000000000000000000000001"), + hex!("0000000000000000000000000000000000000000000000000000000000000002"), + ]; + + mt.set_range(2, &leaves)?; + + assert_eq!( + mt.root(), + hex!("1e9f6c8d3fd5b7ae3a29792adb094c6d4cc6149d0c81c8c8e57cf06c161a92b8") + ); + + Ok(()) +} diff --git a/tests/sled_keccak.rs b/tests/sled_keccak.rs index 0f8a4eb..c26d28a 100644 --- a/tests/sled_keccak.rs +++ b/tests/sled_keccak.rs @@ -157,7 +157,7 @@ fn batch_insertions() -> Result<()> { hex!("0000000000000000000000000000000000000000000000000000000000000004"), ]; - mt.batch_insert(&leaves)?; + mt.batch_insert(None, &leaves)?; assert_eq!( mt.root(), @@ -168,3 +168,29 @@ fn batch_insertions() -> Result<()> { Ok(()) } + +#[test] +fn set_range() -> Result<()> { + let mut mt = MerkleTree::::new( + 2, + SledConfig { + path: String::from("abacabasab"), + }, + )?; + + let leaves = [ + hex!("0000000000000000000000000000000000000000000000000000000000000001"), + hex!("0000000000000000000000000000000000000000000000000000000000000002"), + ]; + + mt.set_range(2, &leaves)?; + + assert_eq!( + mt.root(), + hex!("1e9f6c8d3fd5b7ae3a29792adb094c6d4cc6149d0c81c8c8e57cf06c161a92b8") + ); + + fs::remove_dir_all("abacabasab").expect("Error removing db"); + + Ok(()) +}