feat: add incremental quin tree lib

This commit is contained in:
cedoor
2022-02-02 17:40:13 +01:00
parent 3bbc46135b
commit 9ca54897ec
12 changed files with 613 additions and 210 deletions

View File

@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;
import "./IncrementalBinaryTree.sol";
contract Test {
contract BinaryTreeTest {
using IncrementalBinaryTree for IncrementalTreeData;
event TreeCreated(bytes32 id, uint8 depth);
@@ -13,7 +14,7 @@ contract Test {
mapping(bytes32 => IncrementalTreeData) public trees;
function createTree(bytes32 _id, uint8 _depth) external {
require(trees[_id].depth == 0, "Test: tree already exists");
require(trees[_id].depth == 0, "BinaryTreeTest: tree already exists");
trees[_id].init(_depth, 0);
@@ -21,7 +22,7 @@ contract Test {
}
function insertLeaf(bytes32 _treeId, uint256 _leaf) external {
require(trees[_treeId].depth != 0, "Test: tree does not exist");
require(trees[_treeId].depth != 0, "BinaryTreeTest: tree does not exist");
trees[_treeId].insert(_leaf);
@@ -34,7 +35,7 @@ contract Test {
uint256[] memory _proofSiblings,
uint8[] memory _proofPathIndices
) external {
require(trees[_treeId].depth != 0, "Test: tree does not exist");
require(trees[_treeId].depth != 0, "BinaryTreeTest: tree does not exist");
trees[_treeId].remove(_leaf, _proofSiblings, _proofPathIndices);

View File

@@ -6,151 +6,135 @@ import {PoseidonT3} from "./Hashes.sol";
// Each incremental tree has certain properties and data that will
// be used to add new leaves.
struct IncrementalTreeData {
uint8 depth; // Depth of the tree (levels - 1).
uint256 root; // Root hash of the tree.
uint256 numberOfLeaves; // Number of leaves of the tree.
mapping(uint256 => uint256) zeroes; // Zero hashes used for empty nodes (level -> zero hash).
// The nodes of the subtrees used in the last addition of a leaf (level -> [left node, right node]).
mapping(uint256 => uint256[2]) lastSubtrees; // Caching these values is essential to efficient appends.
uint8 depth; // Depth of the tree (levels - 1).
uint256 root; // Root hash of the tree.
uint256 numberOfLeaves; // Number of leaves of the tree.
mapping(uint256 => uint256) zeroes; // Zero hashes used for empty nodes (level -> zero hash).
// The nodes of the subtrees used in the last addition of a leaf (level -> [left node, right node]).
mapping(uint256 => uint256[2]) lastSubtrees; // Caching these values is essential to efficient appends.
}
/// @title Incremental binary Merkle tree.
/// @dev The incremental tree allows to calculate the root hash each time a leaf is added, ensuring
/// the integrity of the tree.
library IncrementalBinaryTree {
uint8 internal constant MAX_DEPTH = 32;
uint256 internal constant SNARK_SCALAR_FIELD =
21888242871839275222246405745257275088548364400416034343698204186575808495617;
uint8 internal constant MAX_DEPTH = 32;
uint256 internal constant SNARK_SCALAR_FIELD =
21888242871839275222246405745257275088548364400416034343698204186575808495617;
/// @dev Initializes a tree.
/// @param self: Tree data.
/// @param depth: Depth of the tree.
/// @param zero: Zero value to be used.
function init(
IncrementalTreeData storage self,
uint8 depth,
uint256 zero
) public {
require(
depth > 0 && depth <= MAX_DEPTH,
"IncrementalBinaryTree: tree depth must be between 1 and 32"
);
/// @dev Initializes a tree.
/// @param self: Tree data.
/// @param depth: Depth of the tree.
/// @param zero: Zero value to be used.
function init(
IncrementalTreeData storage self,
uint8 depth,
uint256 zero
) public {
require(depth > 0 && depth <= MAX_DEPTH, "IncrementalBinaryTree: tree depth must be between 1 and 32");
self.depth = depth;
self.depth = depth;
for (uint8 i = 0; i < depth; i++) {
self.zeroes[i] = zero;
zero = PoseidonT3.poseidon([zero, zero]);
}
self.root = zero;
for (uint8 i = 0; i < depth; i++) {
self.zeroes[i] = zero;
zero = PoseidonT3.poseidon([zero, zero]);
}
/// @dev Inserts a leaf in the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be inserted.
function insert(IncrementalTreeData storage self, uint256 leaf) public {
require(
leaf < SNARK_SCALAR_FIELD,
"IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"
);
require(
self.numberOfLeaves < 2**self.depth,
"IncrementalBinaryTree: tree is full"
);
self.root = zero;
}
uint256 index = self.numberOfLeaves;
uint256 hash = leaf;
/// @dev Inserts a leaf in the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be inserted.
function insert(IncrementalTreeData storage self, uint256 leaf) public {
require(leaf < SNARK_SCALAR_FIELD, "IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD");
require(self.numberOfLeaves < 2**self.depth, "IncrementalBinaryTree: tree is full");
for (uint8 i = 0; i < self.depth; i++) {
if (index % 2 == 0) {
self.lastSubtrees[i] = [hash, self.zeroes[i]];
} else {
self.lastSubtrees[i][1] = hash;
}
uint256 index = self.numberOfLeaves;
uint256 hash = leaf;
hash = PoseidonT3.poseidon(self.lastSubtrees[i]);
index /= 2;
}
for (uint8 i = 0; i < self.depth; i++) {
if (index % 2 == 0) {
self.lastSubtrees[i] = [hash, self.zeroes[i]];
} else {
self.lastSubtrees[i][1] = hash;
}
self.root = hash;
self.numberOfLeaves += 1;
hash = PoseidonT3.poseidon(self.lastSubtrees[i]);
index /= 2;
}
/// @dev Removes a leaf from the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblingNodes: Array of the sibling nodes of the proof of membership.
/// @param proofPath: Path of the proof of membership.
function remove(
IncrementalTreeData storage self,
uint256 leaf,
uint256[] memory proofSiblingNodes,
uint8[] memory proofPath
) public {
require(
verify(self, leaf, proofSiblingNodes, proofPath),
"IncrementalBinaryTree: leaf is not part of the tree"
);
self.root = hash;
self.numberOfLeaves += 1;
}
uint256 hash = self.zeroes[0];
/// @dev Removes a leaf from the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblings: Array of the sibling nodes of the proof of membership.
/// @param proofPathIndices: Path of the proof of membership.
function remove(
IncrementalTreeData storage self,
uint256 leaf,
uint256[] memory proofSiblings,
uint8[] memory proofPathIndices
) public {
require(verify(self, leaf, proofSiblings, proofPathIndices), "IncrementalBinaryTree: leaf is not part of the tree");
for (uint8 i = 0; i < self.depth; i++) {
if (proofPath[i] == 0) {
if (proofSiblingNodes[i] == self.lastSubtrees[i][1]) {
self.lastSubtrees[i][0] = hash;
}
uint256 hash = self.zeroes[0];
hash = PoseidonT3.poseidon([hash, proofSiblingNodes[i]]);
} else {
if (proofSiblingNodes[i] == self.lastSubtrees[i][0]) {
self.lastSubtrees[i][1] = hash;
}
hash = PoseidonT3.poseidon([proofSiblingNodes[i], hash]);
}
for (uint8 i = 0; i < self.depth; i++) {
if (proofPathIndices[i] == 0) {
if (proofSiblings[i] == self.lastSubtrees[i][1]) {
self.lastSubtrees[i][0] = hash;
}
self.root = hash;
}
/// @dev Verify if the path is correct and the leaf is part of the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblingNodes: Array of the sibling nodes of the proof of membership.
/// @param proofPath: Path of the proof of membership.
/// @return True or false.
function verify(
IncrementalTreeData storage self,
uint256 leaf,
uint256[] memory proofSiblingNodes,
uint8[] memory proofPath
) private view returns (bool) {
require(
leaf < SNARK_SCALAR_FIELD,
"IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"
);
require(
proofPath.length == self.depth &&
proofSiblingNodes.length == self.depth,
"IncrementalBinaryTree: length of path is not correct"
);
uint256 hash = leaf;
for (uint8 i = 0; i < self.depth; i++) {
require(
proofSiblingNodes[i] < SNARK_SCALAR_FIELD,
"IncrementalBinaryTree: sibling node must be < SNARK_SCALAR_FIELD"
);
if (proofPath[i] == 0) {
hash = PoseidonT3.poseidon([hash, proofSiblingNodes[i]]);
} else {
hash = PoseidonT3.poseidon([proofSiblingNodes[i], hash]);
}
hash = PoseidonT3.poseidon([hash, proofSiblings[i]]);
} else {
if (proofSiblings[i] == self.lastSubtrees[i][0]) {
self.lastSubtrees[i][1] = hash;
}
return hash == self.root;
hash = PoseidonT3.poseidon([proofSiblings[i], hash]);
}
}
self.root = hash;
}
/// @dev Verify if the path is correct and the leaf is part of the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblings: Array of the sibling nodes of the proof of membership.
/// @param proofPathIndices: Path of the proof of membership.
/// @return True or false.
function verify(
IncrementalTreeData storage self,
uint256 leaf,
uint256[] memory proofSiblings,
uint8[] memory proofPathIndices
) private view returns (bool) {
require(leaf < SNARK_SCALAR_FIELD, "IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD");
require(
proofPathIndices.length == self.depth && proofSiblings.length == self.depth,
"IncrementalBinaryTree: length of path is not correct"
);
uint256 hash = leaf;
for (uint8 i = 0; i < self.depth; i++) {
require(
proofSiblings[i] < SNARK_SCALAR_FIELD,
"IncrementalBinaryTree: sibling node must be < SNARK_SCALAR_FIELD"
);
if (proofPathIndices[i] == 0) {
hash = PoseidonT3.poseidon([hash, proofSiblings[i]]);
} else {
hash = PoseidonT3.poseidon([proofSiblings[i], hash]);
}
}
return hash == self.root;
}
}

View File

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;
import {PoseidonT6} from "./Hashes.sol";
// Each incremental tree has certain properties and data that will
// be used to add new leaves.
struct IncrementalTreeData {
uint8 depth; // Depth of the tree (levels - 1).
uint256 root; // Root hash of the tree.
uint256 numberOfLeaves; // Number of leaves of the tree.
mapping(uint256 => uint256) zeroes; // Zero hashes used for empty nodes (level -> zero hash).
// The nodes of the subtrees used in the last addition of a leaf (level -> [nodes]).
mapping(uint256 => uint256[5]) lastSubtrees; // Caching these values is essential to efficient appends.
}
/// @title Incremental quin Merkle tree.
/// @dev The incremental tree allows to calculate the root hash each time a leaf is added, ensuring
/// the integrity of the tree.
library IncrementalQuinTree {
uint8 internal constant MAX_DEPTH = 32;
uint256 internal constant SNARK_SCALAR_FIELD =
21888242871839275222246405745257275088548364400416034343698204186575808495617;
/// @dev Initializes a tree.
/// @param self: Tree data.
/// @param depth: Depth of the tree.
/// @param zero: Zero value to be used.
function init(
IncrementalTreeData storage self,
uint8 depth,
uint256 zero
) public {
require(depth > 0 && depth <= MAX_DEPTH, "IncrementalQuinTree: tree depth must be between 1 and 32");
self.depth = depth;
for (uint8 i = 0; i < depth; i++) {
self.zeroes[i] = zero;
uint256[5] memory zeroChildren;
for (uint8 j = 0; j < 5; j++) {
zeroChildren[j] = zero;
}
zero = PoseidonT6.poseidon(zeroChildren);
}
self.root = zero;
}
/// @dev Inserts a leaf in the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be inserted.
function insert(IncrementalTreeData storage self, uint256 leaf) public {
require(leaf < SNARK_SCALAR_FIELD, "IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD");
require(self.numberOfLeaves < 5**self.depth, "IncrementalQuinTree: tree is full");
uint256 index = self.numberOfLeaves;
uint256 hash = leaf;
for (uint8 i = 0; i < self.depth; i++) {
uint8 position = uint8(index % 5);
self.lastSubtrees[i][position] = hash;
if (position == 0) {
for (uint8 j = 1; j < 5; j++) {
self.lastSubtrees[i][j] = self.zeroes[i];
}
}
hash = PoseidonT6.poseidon(self.lastSubtrees[i]);
index /= 5;
}
self.root = hash;
self.numberOfLeaves += 1;
}
/// @dev Removes a leaf from the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblings: Array of the sibling nodes of the proof of membership.
/// @param proofPathIndices: Path of the proof of membership.
function remove(
IncrementalTreeData storage self,
uint256 leaf,
uint256[4][] memory proofSiblings,
uint8[] memory proofPathIndices
) public {
require(verify(self, leaf, proofSiblings, proofPathIndices), "IncrementalQuinTree: leaf is not part of the tree");
uint256 hash = self.zeroes[0];
for (uint8 i = 0; i < self.depth; i++) {
uint256[5] memory nodes;
for (uint8 j = 0; j < 5; j++) {
if (j < proofPathIndices[i]) {
nodes[j] = proofSiblings[i][j];
} else if (j == proofPathIndices[i]) {
nodes[j] = hash;
} else {
nodes[j] = proofSiblings[i][j - 1];
}
}
if (nodes[0] == self.lastSubtrees[i][0] || nodes[4] == self.lastSubtrees[i][4]) {
self.lastSubtrees[i][proofPathIndices[i]] = hash;
}
hash = PoseidonT6.poseidon(nodes);
}
self.root = hash;
}
/// @dev Verify if the path is correct and the leaf is part of the tree.
/// @param self: Tree data.
/// @param leaf: Leaf to be removed.
/// @param proofSiblings: Array of the sibling nodes of the proof of membership.
/// @param proofPathIndices: Path of the proof of membership.
/// @return True or false.
function verify(
IncrementalTreeData storage self,
uint256 leaf,
uint256[4][] memory proofSiblings,
uint8[] memory proofPathIndices
) private view returns (bool) {
require(leaf < SNARK_SCALAR_FIELD, "IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD");
require(
proofPathIndices.length == self.depth && proofSiblings.length == self.depth,
"IncrementalQuinTree: length of path is not correct"
);
uint256 hash = leaf;
for (uint8 i = 0; i < self.depth; i++) {
uint256[5] memory nodes;
for (uint8 j = 0; j < 5; j++) {
if (j < proofPathIndices[i]) {
require(
proofSiblings[i][j] < SNARK_SCALAR_FIELD,
"IncrementalQuinTree: sibling node must be < SNARK_SCALAR_FIELD"
);
nodes[j] = proofSiblings[i][j];
} else if (j == proofPathIndices[i]) {
nodes[j] = hash;
} else {
require(
proofSiblings[i][j - 1] < SNARK_SCALAR_FIELD,
"IncrementalQuinTree: sibling node must be < SNARK_SCALAR_FIELD"
);
nodes[j] = proofSiblings[i][j - 1];
}
}
hash = PoseidonT6.poseidon(nodes);
}
return hash == self.root;
}
}

View File

@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;
import "./IncrementalQuinTree.sol";
contract QuinTreeTest {
using IncrementalQuinTree for IncrementalTreeData;
event TreeCreated(bytes32 id, uint8 depth);
event LeafInserted(bytes32 indexed treeId, uint256 leaf, uint256 root);
event LeafRemoved(bytes32 indexed treeId, uint256 leaf, uint256 root);
mapping(bytes32 => IncrementalTreeData) public trees;
function createTree(bytes32 _id, uint8 _depth) external {
require(trees[_id].depth == 0, "QuinTreeTest: tree already exists");
trees[_id].init(_depth, 0);
emit TreeCreated(_id, _depth);
}
function insertLeaf(bytes32 _treeId, uint256 _leaf) external {
require(trees[_treeId].depth != 0, "QuinTreeTest: tree does not exist");
trees[_treeId].insert(_leaf);
emit LeafInserted(_treeId, _leaf, trees[_treeId].root);
}
function removeLeaf(
bytes32 _treeId,
uint256 _leaf,
uint256[4][] memory _proofSiblings,
uint8[] memory _proofPathIndices
) external {
require(trees[_treeId].depth != 0, "QuinTreeTest: tree does not exist");
trees[_treeId].remove(_leaf, _proofSiblings, _proofPathIndices);
emit LeafRemoved(_treeId, _leaf, trees[_treeId].root);
}
}

View File

@@ -6,7 +6,8 @@ import { HardhatUserConfig } from "hardhat/config"
import { resolve } from "path"
import "solidity-coverage"
import { config } from "./package.json"
import "./tasks/deploy-test"
import "./tasks/deploy-binary-tree-test"
import "./tasks/deploy-quin-tree-test"
dotenvConfig({ path: resolve(__dirname, "./.env") })
@@ -16,7 +17,7 @@ const hardhatConfig: HardhatUserConfig = {
sources: config.paths.contracts,
tests: config.paths.tests,
cache: config.paths.cache,
artifacts: config.paths.build.contracts
artifacts: config.paths.build
},
gasReporter: {
currency: "USD",

View File

@@ -6,6 +6,8 @@
"files": [
"contracts/",
"!contracts/Test.sol",
"build/",
"!build/contracts/Test.sol",
"README.md"
],
"keywords": [
@@ -27,13 +29,15 @@
},
"scripts": {
"start": "hardhat node",
"compile": "hardhat compile",
"deploy:test": "hardhat deploy:test",
"build": "hardhat compile",
"deploy:binary-tree-test": "hardhat deploy:binary-tree-test",
"deploy:quin-tree-test": "hardhat deploy:quin-tree-test",
"test": "hardhat test",
"test:report-gas": "REPORT_GAS=true hardhat test",
"test:coverage": "hardhat coverage",
"test:prod": "yarn lint && yarn coverage",
"lint": "solhint 'contracts/**/*.sol'"
"lint": "solhint 'contracts/**/*.sol'",
"prepublishOnly": "yarn build"
},
"publishConfig": {
"access": "public"
@@ -64,14 +68,9 @@
},
"paths": {
"contracts": "./contracts",
"circuit": "./circuit",
"tests": "./test",
"cache": "./cache",
"build": {
"snark": "./build/snark",
"contracts": "./build/contracts",
"typechain": "./build/typechain"
}
"build": "./build"
}
}
}

View File

@@ -2,7 +2,7 @@ import { poseidon_gencontract as poseidonContract } from "circomlibjs"
import { Contract } from "ethers"
import { task, types } from "hardhat/config"
task("deploy:test", "Deploy a Test contract")
task("deploy:binary-tree-test", "Deploy a BinaryTreeTest contract")
.addOptionalParam<boolean>("logs", "Print the logs", true, types.boolean)
.setAction(async ({ logs }, { ethers }): Promise<Contract> => {
const poseidonT3ABI = poseidonContract.generateABI(2)
@@ -10,41 +10,29 @@ task("deploy:test", "Deploy a Test contract")
const [signer] = await ethers.getSigners()
const PoseidonLibT3Factory = new ethers.ContractFactory(
poseidonT3ABI,
poseidonT3Bytecode,
signer
)
const PoseidonLibT3Factory = new ethers.ContractFactory(poseidonT3ABI, poseidonT3Bytecode, signer)
const poseidonT3Lib = await PoseidonLibT3Factory.deploy()
await poseidonT3Lib.deployed()
if (logs) {
console.info(
`PoseidonT3 library has been deployed to: ${poseidonT3Lib.address}`
)
console.info(`PoseidonT3 library has been deployed to: ${poseidonT3Lib.address}`)
}
const IncrementalBinaryTreeLibFactory = await ethers.getContractFactory(
"IncrementalBinaryTree",
{
libraries: {
PoseidonT3: poseidonT3Lib.address
}
const IncrementalBinaryTreeLibFactory = await ethers.getContractFactory("IncrementalBinaryTree", {
libraries: {
PoseidonT3: poseidonT3Lib.address
}
)
const incrementalBinaryTreeLib =
await IncrementalBinaryTreeLibFactory.deploy()
})
const incrementalBinaryTreeLib = await IncrementalBinaryTreeLibFactory.deploy()
await incrementalBinaryTreeLib.deployed()
if (logs) {
console.info(
`IncrementalBinaryTree library has been deployed to: ${incrementalBinaryTreeLib.address}`
)
console.info(`IncrementalBinaryTree library has been deployed to: ${incrementalBinaryTreeLib.address}`)
}
const ContractFactory = await ethers.getContractFactory("Test", {
const ContractFactory = await ethers.getContractFactory("BinaryTreeTest", {
libraries: {
IncrementalBinaryTree: incrementalBinaryTreeLib.address
}

View File

@@ -0,0 +1,50 @@
import { poseidon_gencontract as poseidonContract } from "circomlibjs"
import { Contract } from "ethers"
import { task, types } from "hardhat/config"
task("deploy:quin-tree-test", "Deploy a QuinTreeTest contract")
.addOptionalParam<boolean>("logs", "Print the logs", true, types.boolean)
.setAction(async ({ logs }, { ethers }): Promise<Contract> => {
const poseidonT6ABI = poseidonContract.generateABI(5)
const poseidonT6Bytecode = poseidonContract.createCode(5)
const [signer] = await ethers.getSigners()
const PoseidonLibT6Factory = new ethers.ContractFactory(poseidonT6ABI, poseidonT6Bytecode, signer)
const poseidonT6Lib = await PoseidonLibT6Factory.deploy()
await poseidonT6Lib.deployed()
if (logs) {
console.info(`PoseidonT6 library has been deployed to: ${poseidonT6Lib.address}`)
}
const IncrementalQuinTreeLibFactory = await ethers.getContractFactory("IncrementalQuinTree", {
libraries: {
PoseidonT6: poseidonT6Lib.address
}
})
const incrementalQuinTreeLib = await IncrementalQuinTreeLibFactory.deploy()
await incrementalQuinTreeLib.deployed()
if (logs) {
console.info(`IncrementalQuinTree library has been deployed to: ${incrementalQuinTreeLib.address}`)
}
const ContractFactory = await ethers.getContractFactory("QuinTreeTest", {
libraries: {
IncrementalQuinTree: incrementalQuinTreeLib.address
}
})
const contract = await ContractFactory.deploy()
await contract.deployed()
if (logs) {
console.info(`Test contract has been deployed to: ${contract.address}`)
}
return contract
})

View File

@@ -4,7 +4,7 @@ import { ethers, run } from "hardhat"
import { createTree } from "./utils"
/* eslint-disable jest/valid-expect */
describe("Test", () => {
describe("BinaryTreeTest", () => {
let contract: Contract
const treeId = ethers.utils.formatBytes32String("treeId")
@@ -12,29 +12,25 @@ describe("Test", () => {
const depth = 16
before(async () => {
contract = await run("deploy:test", { logs: false })
contract = await run("deploy:binary-tree-test", { logs: false })
})
it("Should not create a tree with a depth > 32", async () => {
const transaction = contract.createTree(treeId, 33)
await expect(transaction).to.be.revertedWith(
"IncrementalBinaryTree: tree depth must be between 1 and 32"
)
await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: tree depth must be between 1 and 32")
})
it("Should create a tree", async () => {
const transaction = contract.createTree(treeId, depth)
await expect(transaction)
.to.emit(contract, "TreeCreated")
.withArgs(treeId, depth)
await expect(transaction).to.emit(contract, "TreeCreated").withArgs(treeId, depth)
})
it("Should not create a tree with an existing id", async () => {
const transaction = contract.createTree(treeId, depth)
await expect(transaction).to.be.revertedWith("Test: tree already exists")
await expect(transaction).to.be.revertedWith("BinaryTreeTest: tree already exists")
})
it("Should not insert a leaf if the tree does not exist", async () => {
@@ -42,31 +38,37 @@ describe("Test", () => {
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith("Test: tree does not exist")
await expect(transaction).to.be.revertedWith("BinaryTreeTest: tree does not exist")
})
it("Should not insert a leaf if its value is > SNARK_SCALAR_FIELD", async () => {
const leaf = BigInt(
"21888242871839275222246405745257275088548364400416034343698204186575808495618"
)
const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618")
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith(
"IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"
)
await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD")
})
it("Should insert a leaf in a tree", async () => {
const tree = createTree(depth, 1)
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction)
.to.emit(contract, "LeafInserted")
.withArgs(
treeId,
leaf,
"16211261537006706331557500769845541584780950636316907182067421710925347020533"
)
await expect(transaction).to.emit(contract, "LeafInserted").withArgs(treeId, leaf, tree.root)
})
it("Should insert 4 leaves in a tree", async () => {
const treeId = ethers.utils.formatBytes32String("tree2")
await contract.createTree(treeId, depth)
const tree = createTree(depth, 0)
for (let i = 0; i < 4; i += 1) {
tree.insert(BigInt(i + 1))
const transaction = contract.insertLeaf(treeId, BigInt(i + 1))
await expect(transaction)
.to.emit(contract, "LeafInserted")
.withArgs(treeId, BigInt(i + 1), tree.root)
}
})
it("Should not insert a leaf if the tree is full", async () => {
@@ -78,9 +80,7 @@ describe("Test", () => {
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith(
"IncrementalBinaryTree: tree is full"
)
await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: tree is full")
})
it("Should not remove a leaf if the tree does not exist", async () => {
@@ -88,19 +88,15 @@ describe("Test", () => {
const transaction = contract.removeLeaf(treeId, leaf, [0, 1], [0, 1])
await expect(transaction).to.be.revertedWith("Test: tree does not exist")
await expect(transaction).to.be.revertedWith("BinaryTreeTest: tree does not exist")
})
it("Should not remove a leaf if its value is > SNARK_SCALAR_FIELD", async () => {
const leaf = BigInt(
"21888242871839275222246405745257275088548364400416034343698204186575808495618"
)
const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618")
const transaction = contract.removeLeaf(treeId, leaf, [0, 1], [0, 1])
await expect(transaction).to.be.revertedWith(
"IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD"
)
await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: leaf must be < SNARK_SCALAR_FIELD")
})
it("Should remove a leaf", async () => {
@@ -123,9 +119,7 @@ describe("Test", () => {
pathIndices
)
await expect(transaction)
.to.emit(contract, "LeafRemoved")
.withArgs(treeId, BigInt(1), root)
await expect(transaction).to.emit(contract, "LeafRemoved").withArgs(treeId, BigInt(1), root)
})
it("Should remove another leaf", async () => {
@@ -144,9 +138,7 @@ describe("Test", () => {
pathIndices
)
await expect(transaction)
.to.emit(contract, "LeafRemoved")
.withArgs(treeId, BigInt(2), root)
await expect(transaction).to.emit(contract, "LeafRemoved").withArgs(treeId, BigInt(2), root)
})
it("Should not remove a leaf that does not exist", async () => {
@@ -165,9 +157,7 @@ describe("Test", () => {
pathIndices
)
await expect(transaction).to.be.revertedWith(
"IncrementalBinaryTree: leaf is not part of the tree"
)
await expect(transaction).to.be.revertedWith("IncrementalBinaryTree: leaf is not part of the tree")
})
it("Should insert a leaf in a tree after a removal", async () => {
@@ -179,9 +169,7 @@ describe("Test", () => {
const transaction = contract.insertLeaf(treeId, BigInt(4))
await expect(transaction)
.to.emit(contract, "LeafInserted")
.withArgs(treeId, BigInt(4), tree.root)
await expect(transaction).to.emit(contract, "LeafInserted").withArgs(treeId, BigInt(4), tree.root)
})
it("Should insert 4 leaves and remove them all", async () => {

View File

@@ -0,0 +1,185 @@
import { expect } from "chai"
import { Contract } from "ethers"
import { ethers, run } from "hardhat"
import { createTree } from "./utils"
/* eslint-disable jest/valid-expect */
describe("QuinTreeTest", () => {
let contract: Contract
const treeId = ethers.utils.formatBytes32String("treeId")
const leaf = BigInt(1)
const depth = 16
before(async () => {
contract = await run("deploy:quin-tree-test", { logs: false })
})
it("Should not create a tree with a depth > 32", async () => {
const transaction = contract.createTree(treeId, 33)
await expect(transaction).to.be.revertedWith("IncrementalQuinTree: tree depth must be between 1 and 32")
})
it("Should create a tree", async () => {
const transaction = contract.createTree(treeId, depth)
await expect(transaction).to.emit(contract, "TreeCreated").withArgs(treeId, depth)
})
it("Should not create a tree with an existing id", async () => {
const transaction = contract.createTree(treeId, depth)
await expect(transaction).to.be.revertedWith("QuinTreeTest: tree already exists")
})
it("Should not insert a leaf if the tree does not exist", async () => {
const treeId = ethers.utils.formatBytes32String("treeId2")
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith("QuinTreeTest: tree does not exist")
})
it("Should not insert a leaf if its value is > SNARK_SCALAR_FIELD", async () => {
const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618")
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith("IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD")
})
it("Should insert a leaf in a tree", async () => {
const tree = createTree(depth, 1, 5)
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.emit(contract, "LeafInserted").withArgs(treeId, leaf, tree.root)
})
it("Should insert 6 leaves in a tree", async () => {
const treeId = ethers.utils.formatBytes32String("tree2")
await contract.createTree(treeId, depth)
const tree = createTree(depth, 0, 5)
for (let i = 0; i < 6; i += 1) {
tree.insert(BigInt(i + 1))
const transaction = contract.insertLeaf(treeId, BigInt(i + 1))
await expect(transaction)
.to.emit(contract, "LeafInserted")
.withArgs(treeId, BigInt(i + 1), tree.root)
}
})
it("Should not insert a leaf if the tree is full", async () => {
const treeId = ethers.utils.formatBytes32String("tinyTree")
await contract.createTree(treeId, 1)
await contract.insertLeaf(treeId, leaf)
await contract.insertLeaf(treeId, leaf)
await contract.insertLeaf(treeId, leaf)
await contract.insertLeaf(treeId, leaf)
await contract.insertLeaf(treeId, leaf)
const transaction = contract.insertLeaf(treeId, leaf)
await expect(transaction).to.be.revertedWith("IncrementalQuinTree: tree is full")
})
it("Should not remove a leaf if the tree does not exist", async () => {
const treeId = ethers.utils.formatBytes32String("none")
const transaction = contract.removeLeaf(treeId, leaf, [[0, 1, 2, 3]], [0])
await expect(transaction).to.be.revertedWith("QuinTreeTest: tree does not exist")
})
it("Should not remove a leaf if its value is > SNARK_SCALAR_FIELD", async () => {
const leaf = BigInt("21888242871839275222246405745257275088548364400416034343698204186575808495618")
const transaction = contract.removeLeaf(treeId, leaf, [[0, 1, 2, 3]], [0])
await expect(transaction).to.be.revertedWith("IncrementalQuinTree: leaf must be < SNARK_SCALAR_FIELD")
})
it("Should remove a leaf", async () => {
const treeId = ethers.utils.formatBytes32String("hello")
const tree = createTree(depth, 3, 5)
tree.delete(0)
await contract.createTree(treeId, depth)
await contract.insertLeaf(treeId, BigInt(1))
await contract.insertLeaf(treeId, BigInt(2))
await contract.insertLeaf(treeId, BigInt(3))
const { siblings, pathIndices, root } = tree.createProof(0)
const transaction = contract.removeLeaf(treeId, BigInt(1), siblings, pathIndices)
await expect(transaction).to.emit(contract, "LeafRemoved").withArgs(treeId, BigInt(1), root)
})
it("Should remove another leaf", async () => {
const treeId = ethers.utils.formatBytes32String("hello")
const tree = createTree(depth, 3, 5)
tree.delete(0)
tree.delete(1)
const { siblings, pathIndices, root } = tree.createProof(1)
const transaction = contract.removeLeaf(treeId, BigInt(2), siblings, pathIndices)
await expect(transaction).to.emit(contract, "LeafRemoved").withArgs(treeId, BigInt(2), root)
})
it("Should not remove a leaf that does not exist", async () => {
const treeId = ethers.utils.formatBytes32String("hello")
const tree = createTree(depth, 3, 5)
tree.delete(0)
tree.delete(1)
const { siblings, pathIndices } = tree.createProof(0)
const transaction = contract.removeLeaf(treeId, BigInt(4), siblings, pathIndices)
await expect(transaction).to.be.revertedWith("IncrementalQuinTree: leaf is not part of the tree")
})
it("Should insert a leaf in a tree after a removal", async () => {
const treeId = ethers.utils.formatBytes32String("hello")
const tree = createTree(depth, 4, 5)
tree.delete(0)
tree.delete(1)
const transaction = contract.insertLeaf(treeId, BigInt(4))
await expect(transaction).to.emit(contract, "LeafInserted").withArgs(treeId, BigInt(4), tree.root)
})
it("Should insert 4 leaves and remove them all", async () => {
const treeId = ethers.utils.formatBytes32String("complex")
const tree = createTree(depth, 4, 5)
await contract.createTree(treeId, depth)
for (let i = 0; i < 4; i += 1) {
await contract.insertLeaf(treeId, BigInt(i + 1))
}
for (let i = 0; i < 4; i += 1) {
tree.delete(i)
const { siblings, pathIndices } = tree.createProof(i)
await contract.removeLeaf(treeId, BigInt(i + 1), siblings, pathIndices)
}
const { root } = await contract.trees(treeId)
expect(root).to.equal(tree.root)
})
})

View File

@@ -2,12 +2,7 @@ import { IncrementalMerkleTree } from "@zk-kit/incremental-merkle-tree"
import { poseidon } from "circomlibjs"
/* eslint-disable import/prefer-default-export */
export function createTree(
depth: number,
numberOfNodes = 0,
zeroValue = BigInt(0),
arity = 2
): IncrementalMerkleTree {
export function createTree(depth: number, numberOfNodes = 0, arity = 2, zeroValue = BigInt(0)): IncrementalMerkleTree {
const tree = new IncrementalMerkleTree(poseidon, depth, zeroValue, arity)
for (let i = 0; i < numberOfNodes; i += 1) {

View File

@@ -12,6 +12,7 @@ export default function verifyProof(proof: MerkleProof, hash: HashFunction): boo
let node = proof.leaf
for (let i = 0; i < proof.siblings.length; i += 1) {
// TODO: bug splice update the structure of the object
proof.siblings[i].splice(proof.pathIndices[i], 0, node)
node = hash(proof.siblings[i])