diff --git a/packages/incremental-merkle-tree.sol/contracts/IncrementalBinaryTree.sol b/packages/incremental-merkle-tree.sol/contracts/IncrementalBinaryTree.sol index 3b5052e..1e1c867 100644 --- a/packages/incremental-merkle-tree.sol/contracts/IncrementalBinaryTree.sol +++ b/packages/incremental-merkle-tree.sol/contracts/IncrementalBinaryTree.sol @@ -71,18 +71,22 @@ library IncrementalBinaryTree { /// @dev Updates a leaf in the tree. /// @param self: Tree data. - /// @param leaf: Leaf to be updated. + /// @param leaf: [0] = existing leaf to replace, [1] = new leaf to insert /// @param proofSiblings: Array of the sibling nodes of the proof of membership. /// @param proofPathIndices: Path of the proof of membership. function update( IncrementalTreeData storage self, - uint256 leaf, + uint256[2] calldata leaf, uint256[] calldata proofSiblings, uint8[] calldata proofPathIndices ) public { - require(leaf < SNARK_SCALAR_FIELD, "IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"); + require( + verify(self, leaf[0], proofSiblings, proofPathIndices), + "IncrementalBinaryTree: provided current leaf not found" + ); + require(leaf[1] < SNARK_SCALAR_FIELD, "IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"); - uint256 hash = leaf; + uint256 hash = leaf[1]; for (uint8 i = 0; i < self.depth; i++) { if (proofPathIndices[i] == 0) { if (proofSiblings[i] == self.lastSubtrees[i][1]) { diff --git a/packages/incremental-merkle-tree.sol/contracts/IncrementalQuinTree.sol b/packages/incremental-merkle-tree.sol/contracts/IncrementalQuinTree.sol index 5d9109c..c85a9b8 100644 --- a/packages/incremental-merkle-tree.sol/contracts/IncrementalQuinTree.sol +++ b/packages/incremental-merkle-tree.sol/contracts/IncrementalQuinTree.sol @@ -81,18 +81,22 @@ library IncrementalQuinTree { /// @dev Updates a leaf in the tree. /// @param self: Tree data. - /// @param leaf: Leaf to be updated. + /// @param leaf: [0] = existing leaf to replace, [1] = new leaf to insert /// @param proofSiblings: Array of the sibling nodes of the proof of membership. /// @param proofPathIndices: Path of the proof of membership. function update( IncrementalTreeData storage self, - uint256 leaf, + uint256[2] calldata leaf, uint256[4][] calldata proofSiblings, uint8[] calldata proofPathIndices ) public { - require(leaf < SNARK_SCALAR_FIELD, "IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD"); + require( + verify(self, leaf[0], proofSiblings, proofPathIndices), + "IncrementalQuinTree: provided current leaf not found" + ); + require(leaf[1] < SNARK_SCALAR_FIELD, "IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD"); - uint256 hash = leaf; + uint256 hash = leaf[1]; for (uint8 i = 0; i < self.depth; i++) { uint256[5] memory nodes; diff --git a/packages/incremental-merkle-tree.sol/contracts/test/IncrementalBinaryTreeTest.sol b/packages/incremental-merkle-tree.sol/contracts/test/IncrementalBinaryTreeTest.sol index 85e54b6..d3c12d6 100644 --- a/packages/incremental-merkle-tree.sol/contracts/test/IncrementalBinaryTreeTest.sol +++ b/packages/incremental-merkle-tree.sol/contracts/test/IncrementalBinaryTreeTest.sol @@ -32,7 +32,7 @@ contract IncrementalBinaryTreeTest { function updateLeaf( bytes32 _treeId, - uint256 _leaf, + uint256[2] calldata _leaf, uint256[] calldata _proofSiblings, uint8[] calldata _proofPathIndices ) external { @@ -40,7 +40,7 @@ contract IncrementalBinaryTreeTest { trees[_treeId].update(_leaf, _proofSiblings, _proofPathIndices); - emit LeafUpdated(_treeId, _leaf, trees[_treeId].root); + emit LeafUpdated(_treeId, _leaf[1], trees[_treeId].root); } function removeLeaf( diff --git a/packages/incremental-merkle-tree.sol/contracts/test/IncrementalQuinTreeTest.sol b/packages/incremental-merkle-tree.sol/contracts/test/IncrementalQuinTreeTest.sol index 090af54..1cf1dfc 100644 --- a/packages/incremental-merkle-tree.sol/contracts/test/IncrementalQuinTreeTest.sol +++ b/packages/incremental-merkle-tree.sol/contracts/test/IncrementalQuinTreeTest.sol @@ -32,7 +32,7 @@ contract IncrementalQuinTreeTest { function updateLeaf( bytes32 _treeId, - uint256 _leaf, + uint256[2] calldata _leaf, uint256[4][] calldata _proofSiblings, uint8[] calldata _proofPathIndices ) external { @@ -40,7 +40,7 @@ contract IncrementalQuinTreeTest { trees[_treeId].update(_leaf, _proofSiblings, _proofPathIndices); - emit LeafUpdated(_treeId, _leaf, trees[_treeId].root); + emit LeafUpdated(_treeId, _leaf[1], trees[_treeId].root); } function removeLeaf( diff --git a/packages/incremental-merkle-tree.sol/test/IncrementalBinaryTreeTest.ts b/packages/incremental-merkle-tree.sol/test/IncrementalBinaryTreeTest.ts index b0ad131..0c2072f 100644 --- a/packages/incremental-merkle-tree.sol/test/IncrementalBinaryTreeTest.ts +++ b/packages/incremental-merkle-tree.sol/test/IncrementalBinaryTreeTest.ts @@ -86,7 +86,7 @@ describe("IncrementalBinaryTreeTest", () => { it("Should not update a leaf if the tree does not exist", async () => { const treeId = ethers.utils.formatBytes32String("none") - const transaction = contract.updateLeaf(treeId, leaf, [0, 1], [0, 1]) + const transaction = contract.updateLeaf(treeId, [leaf, leaf], [0, 1], [0, 1]) await expect(transaction).to.be.revertedWith("BinaryTreeTest: tree does not exist") }) @@ -94,11 +94,29 @@ describe("IncrementalBinaryTreeTest", () => { it("Should not update a leaf if its value is > SNARK_SCALAR_FIELD", async () => { const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618") - const transaction = contract.updateLeaf(treeId, leaf, [0, 1], [0, 1]) + const transaction = contract.updateLeaf(treeId, [leaf, leaf], [0, 1], [0, 1]) await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD") }) + it("Should not update a leaf if wrong current leaf is given", async () => { + const treeId = ethers.utils.formatBytes32String("tree2") + const tree = createTree(depth, 0) + for (let i = 0; i < 4; i += 1) tree.insert(BigInt(i + 1)) + + const leaf = BigInt(1337) + tree.update(2, leaf) + const { pathIndices, siblings } = tree.createProof(2) + const transaction = contract.updateLeaf( + treeId, + [leaf, leaf], + siblings.map((s) => s[0]), + pathIndices + ) + + await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: provided current leaf not found") + }) + it("Should update a leaf", async () => { const treeId = ethers.utils.formatBytes32String("tree2") const tree = createTree(depth, 0) @@ -109,7 +127,7 @@ describe("IncrementalBinaryTreeTest", () => { const { root, pathIndices, siblings } = tree.createProof(2) const transaction = contract.updateLeaf( treeId, - leaf, + [BigInt(3), leaf], siblings.map((s) => s[0]), pathIndices ) diff --git a/packages/incremental-merkle-tree.sol/test/IncrementalQuinTreeTest.ts b/packages/incremental-merkle-tree.sol/test/IncrementalQuinTreeTest.ts index 97bb78c..74dcede 100644 --- a/packages/incremental-merkle-tree.sol/test/IncrementalQuinTreeTest.ts +++ b/packages/incremental-merkle-tree.sol/test/IncrementalQuinTreeTest.ts @@ -89,7 +89,7 @@ describe("IncrementalQuinTreeTest", () => { it("Should not update a leaf if the tree does not exist", async () => { const treeId = ethers.utils.formatBytes32String("none") - const transaction = contract.updateLeaf(treeId, leaf, [[0, 1, 2, 3]], [0]) + const transaction = contract.updateLeaf(treeId, [leaf, leaf], [[0, 1, 2, 3]], [0]) await expect(transaction).to.be.revertedWith("QuinTreeTest: tree does not exist") }) @@ -97,11 +97,24 @@ describe("IncrementalQuinTreeTest", () => { it("Should not update a leaf if its value is > SNARK_SCALAR_FIELD", async () => { const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618") - const transaction = contract.updateLeaf(treeId, leaf, [[0, 1, 2, 3]], [0]) + const transaction = contract.updateLeaf(treeId, [leaf, leaf], [[0, 1, 2, 3]], [0]) await expect(transaction).to.be.revertedWith("IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD") }) + it("Should not update a leaf if wrong current leaf is given", async () => { + const treeId = ethers.utils.formatBytes32String("tree2") + const tree = createTree(depth, 0, 5) + for (let i = 0; i < 6; i += 1) tree.insert(BigInt(i + 1)) + + const leaf = BigInt(1337) + tree.update(2, leaf) + const { pathIndices, siblings } = tree.createProof(2) + const transaction = contract.updateLeaf(treeId, [leaf, leaf], siblings, pathIndices) + + await expect(transaction).to.be.revertedWith("IncrementalQuinTree: provided current leaf not found") + }) + it("Should update a leaf", async () => { const treeId = ethers.utils.formatBytes32String("tree2") const tree = createTree(depth, 0, 5) @@ -110,7 +123,7 @@ describe("IncrementalQuinTreeTest", () => { const leaf = BigInt(1337) tree.update(2, leaf) const { pathIndices, siblings, root } = tree.createProof(2) - const transaction = contract.updateLeaf(treeId, leaf, siblings, pathIndices) + const transaction = contract.updateLeaf(treeId, [BigInt(3), leaf], siblings, pathIndices) await expect(transaction).to.emit(contract, "LeafUpdated").withArgs(treeId, leaf, root) })