new internal dataset to store large and multiple values encoded, new methods to handle big ints as keys and multiple big ints as values of tree leaves

This commit is contained in:
Lucas Menendez
2025-03-26 12:49:13 +01:00
parent b775477907
commit 72ab6fc585
9 changed files with 773 additions and 126 deletions

View File

@@ -76,7 +76,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
var keys, values [][]byte
@@ -100,7 +100,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
tree2.dbgInit()
start = time.Now()
@@ -128,7 +128,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < nLeafs; i++ {
@@ -144,7 +144,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
@@ -177,14 +177,14 @@ func TestAddBatchTestVector1(t *testing.T) {
tree1, err := NewTree(Config{database1, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
// leafs in 2nd level subtrees: [ 6, 0, 1, 1]
testvectorKeys := []string{
@@ -219,14 +219,14 @@ func TestAddBatchTestVector1(t *testing.T) {
tree1, err = NewTree(Config{database1, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err = pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err = NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
testvectorKeys = []string{
"1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
@@ -269,14 +269,14 @@ func TestAddBatchTestVector2(t *testing.T) {
tree1, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
bLen := tree1.HashFunction().Len()
var keys, values [][]byte
@@ -316,14 +316,14 @@ func TestAddBatchTestVector3(t *testing.T) {
tree1, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
bLen := tree1.HashFunction().Len()
var keys, values [][]byte
@@ -367,14 +367,14 @@ func TestAddBatchTreeEmptyRandomKeys(t *testing.T) {
tree1, err := NewTree(Config{database1, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
@@ -719,7 +719,7 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) {
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
tree2.dbgInit()
var keys, values [][]byte
@@ -797,7 +797,7 @@ func benchAdd(t *testing.T, ks, vs [][]byte) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
start := time.Now()
for i := 0; i < len(ks); i++ {
@@ -818,7 +818,7 @@ func benchAddBatch(t *testing.T, ks, vs [][]byte) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
tree.dbgInit()
@@ -852,7 +852,7 @@ func TestDbgStats(t *testing.T) {
tree1, err := NewTree(Config{database1, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
tree1.dbgInit()
@@ -867,7 +867,7 @@ func TestDbgStats(t *testing.T) {
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
tree2.dbgInit()
@@ -881,7 +881,7 @@ func TestDbgStats(t *testing.T) {
tree3, err := NewTree(Config{database3, 256, DefaultThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree3.db.Close() //nolint:errcheck
defer tree3.treedb.Close() //nolint:errcheck
tree3.dbgInit()
@@ -916,7 +916,7 @@ func TestLoadVT(t *testing.T) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
@@ -951,7 +951,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) {
tree, err := NewTree(Config{database, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
var keys, values [][]byte
@@ -973,7 +973,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) {
tree2, err := NewTree(Config{database2, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
tree2.dbgInit()
invalids, err := tree2.AddBatch(keys, values)
@@ -988,7 +988,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) {
tree3, err := NewTree(Config{database3, 256, DefaultThresholdNLeafs,
HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree3.db.Close() //nolint:errcheck
defer tree3.treedb.Close() //nolint:errcheck
invalids, err = tree3.AddBatch(keys, nil)
c.Assert(err, qt.IsNil)
@@ -1038,21 +1038,21 @@ func TestAddBatchThresholdInDisk(t *testing.T) {
tree1, err := NewTree(Config{database1, 256, testThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
database2, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(Config{database2, 256, testThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
database3, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree3, err := NewTree(Config{database3, 256, testThresholdNLeafs,
HashFunctionBlake2b})
c.Assert(err, qt.IsNil)
defer tree3.db.Close() //nolint:errcheck
defer tree3.treedb.Close() //nolint:errcheck
var keys, values [][]byte
for i := 0; i < 3*testThresholdNLeafs; i++ {
@@ -1078,7 +1078,7 @@ func TestAddBatchThresholdInDisk(t *testing.T) {
checkRoots(c, tree1, tree2)
// call directly the tree3.addBatchInDisk to ensure that is tested
wTx := tree3.db.WriteTx()
wTx := tree3.treedb.WriteTx()
defer wTx.Discard()
invalids, err = tree3.addBatchInDisk(wTx, keys, values)
c.Assert(err, qt.IsNil)
@@ -1127,7 +1127,7 @@ func testUpFromSubRoots(c *qt.C, tree1, tree2 *Tree, preSubRoots [][]byte) {
root1, err := tree1.Root()
c.Assert(err, qt.IsNil)
wTx := tree2.db.WriteTx()
wTx := tree2.treedb.WriteTx()
subRoots := make([][]byte, len(preSubRoots))
for i := 0; i < len(preSubRoots); i++ {
if preSubRoots[i] == nil || bytes.Equal(preSubRoots[i], tree1.emptyHash) {
@@ -1159,8 +1159,8 @@ func testUpFromSubRoots(c *qt.C, tree1, tree2 *Tree, preSubRoots [][]byte) {
func testUpFromSubRootsWithEmpties(c *qt.C, preSubRoots [][]byte, indexEmpties []int) {
tree1, tree2 := initTestUpFromSubRoots(c)
defer tree1.db.Close() //nolint:errcheck
defer tree2.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
testPreSubRoots := make([][]byte, len(preSubRoots))
copy(testPreSubRoots[:], preSubRoots[:])

View File

@@ -20,7 +20,7 @@ type CircomVerifierProof struct {
// MarshalJSON implements the JSON marshaler
func (cvp CircomVerifierProof) MarshalJSON() ([]byte, error) {
m := make(map[string]interface{})
m := make(map[string]any)
m["root"] = BytesToBigInt(cvp.Root).String()
m["siblings"] = siblingsToStringArray(cvp.Siblings)

View File

@@ -17,7 +17,7 @@ func TestCircomVerifierProof(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 4,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
testVector := [][]int64{
{1, 11},

View File

@@ -1,6 +1,7 @@
package main
import (
"crypto/rand"
"encoding/json"
"math/big"
"os"
@@ -8,16 +9,16 @@ import (
qt "github.com/frankban/quicktest"
"github.com/vocdoni/arbo"
"go.vocdoni.io/dvote/db"
"go.vocdoni.io/dvote/db/pebbledb"
"github.com/vocdoni/arbo/memdb"
)
func TestGenerator(t *testing.T) {
c := qt.New(t)
database, err := pebbledb.New(db.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := arbo.NewTree(arbo.Config{Database: database, MaxLevels: 4,
HashFunction: arbo.HashFunctionPoseidon})
tree, err := arbo.NewTree(arbo.Config{
Database: memdb.New(),
MaxLevels: 160,
HashFunction: arbo.HashFunctionPoseidon,
})
c.Assert(err, qt.IsNil)
testVector := [][]int64{
@@ -27,7 +28,7 @@ func TestGenerator(t *testing.T) {
{4, 44},
}
bLen := 1
for i := 0; i < len(testVector); i++ {
for i := range testVector {
k := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][0]))
v := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][1]))
if err := tree.Add(k, v); err != nil {
@@ -54,4 +55,32 @@ func TestGenerator(t *testing.T) {
// store the data into a file that will be used at the circom test
err = os.WriteFile("go-smt-verifier-non-existence-inputs.json", jCvp, 0600)
c.Assert(err, qt.IsNil)
// create a new tree with big.Int keys
bigtree, err := arbo.NewTree(arbo.Config{
Database: memdb.New(),
MaxLevels: 160,
HashFunction: arbo.HashFunctionPoseidon,
})
c.Assert(err, qt.IsNil)
// add 100 elements to the tree
var bk *big.Int
for i := range 100 {
k, err := rand.Int(rand.Reader, big.NewInt(100_000_000_000))
c.Assert(err, qt.IsNil)
v := new(big.Int).Mul(k, big.NewInt(2))
c.Assert(bigtree.AddBigInt(k, v), qt.IsNil)
if i == 0 {
bk = k
}
}
// generate a proof of existence for the first key
cvp, err = tree.GenerateCircomVerifierProofBigInt(bk)
c.Assert(err, qt.IsNil)
jCvp, err = json.Marshal(cvp)
c.Assert(err, qt.IsNil)
// store the data into a file that will be used at the circom test
err = os.WriteFile("go-smt-verifier-big-inputs.json", jCvp, 0600)
c.Assert(err, qt.IsNil)
}

View File

@@ -14,7 +14,7 @@ describe("merkletreetree circom-proof-verifier", function () {
before( async() => {
const circuitCode = `
include "smt-proof-verifier_test.circom";
component main = SMTVerifierTest(4);
component main = SMTVerifierTest(160);
`;
fs.writeFileSync(circuitPath, circuitCode, "utf8");
@@ -80,6 +80,17 @@ describe("merkletreetree circom-proof-verifier", function () {
const witness = await circuit.calculateWitness(inputsVerifier);
await circuit.checkConstraints(witness);
});
it("Test smt-verifier proof of existence go big inputs", async () => {
// fromGo is a json CircomVerifierProof generated from Go code using
// https://github.com/vocdoni/arbo
let rawdata = fs.readFileSync('go-data-generator/go-smt-verifier-big-inputs.json');
let fromGo = JSON.parse(rawdata);
inputsVerifier=fromGo;
// console.log("smtverifier js inputs:\n", inputsVerifier);
const witness = await circuit.calculateWitness(inputsVerifier);
await circuit.checkConstraints(witness);
});
it("Test smt-verifier proof of non-existence go inputs", async () => {
// fromGo is a json CircomVerifierProof generated from Go code using
// https://github.com/vocdoni/arbo

124
tree.go
View File

@@ -22,6 +22,8 @@ import (
"sync"
"go.vocdoni.io/dvote/db"
"go.vocdoni.io/dvote/db/prefixeddb"
"slices"
)
const (
@@ -55,6 +57,9 @@ var (
// in disk.
DefaultThresholdNLeafs = 65536
dbTreePrefix = []byte("treedb")
dbValuesPrefix = []byte("valuesdb")
dbKeyRoot = []byte("root")
dbKeyNLeafs = []byte("nleafs")
emptyValue = []byte{0}
@@ -88,7 +93,8 @@ var (
type Tree struct {
sync.Mutex
db db.Database
treedb db.Database
valuesdb db.Database
maxLevels int
// thresholdNLeafs defines the threshold number of leafs in the tree
// that determines if AddBatch will work in memory or in disk. It is
@@ -116,7 +122,7 @@ type Config struct {
// NewTree returns a new Tree, if there is a Tree still in the given database, it
// will load it.
func NewTree(cfg Config) (*Tree, error) {
wTx := cfg.Database.WriteTx()
wTx := prefixeddb.NewPrefixedWriteTx(cfg.Database.WriteTx(), dbTreePrefix)
defer wTx.Discard()
t, err := NewTreeWithTx(wTx, cfg)
@@ -138,12 +144,16 @@ func NewTreeWithTx(wTx db.WriteTx, cfg Config) (*Tree, error) {
if cfg.ThresholdNLeafs == 0 {
cfg.ThresholdNLeafs = DefaultThresholdNLeafs
}
t := Tree{db: cfg.Database, maxLevels: cfg.MaxLevels,
thresholdNLeafs: cfg.ThresholdNLeafs, hashFunction: cfg.HashFunction}
t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
t := Tree{
treedb: prefixeddb.NewPrefixedDatabase(cfg.Database, dbTreePrefix),
valuesdb: prefixeddb.NewPrefixedDatabase(cfg.Database, dbValuesPrefix),
maxLevels: cfg.MaxLevels,
thresholdNLeafs: cfg.ThresholdNLeafs,
hashFunction: cfg.HashFunction,
emptyHash: make([]byte, cfg.HashFunction.Len()), // empty
}
_, err := wTx.Get(dbKeyRoot)
if err == db.ErrKeyNotFound {
if _, err := wTx.Get(dbKeyRoot); err == db.ErrKeyNotFound {
// store new root 0 (empty)
if err = wTx.Set(dbKeyRoot, t.emptyHash); err != nil {
return nil, err
@@ -160,7 +170,7 @@ func NewTreeWithTx(wTx db.WriteTx, cfg Config) (*Tree, error) {
// Root returns the root of the Tree
func (t *Tree) Root() ([]byte, error) {
return t.RootWithTx(t.db)
return t.RootWithTx(t.treedb)
}
// RootWithTx returns the root of the Tree using the given db.ReadTx
@@ -201,7 +211,7 @@ type Invalid struct {
// the indexes of the keys failed to add. Supports empty values as input
// parameters, which is equivalent to 0 valued byte array.
func (t *Tree) AddBatch(keys, values [][]byte) ([]Invalid, error) {
wTx := t.db.WriteTx()
wTx := t.treedb.WriteTx()
defer wTx.Discard()
invalids, err := t.AddBatchWithTx(wTx, keys, values)
@@ -221,7 +231,6 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
if !t.editable() {
return nil, ErrSnapshotNotEditable
}
e := []byte{}
// equal the number of keys & values
if len(keys) > len(values) {
@@ -233,7 +242,6 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
// crop extra values
values = values[:len(keys)]
}
nLeafs, err := t.GetNLeafsWithTx(wTx)
if err != nil {
return nil, err
@@ -248,26 +256,25 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
nCPU := flp2(runtime.NumCPU())
if nCPU == 1 || len(keys) < nCPU {
var invalids []Invalid
for i := 0; i < len(keys); i++ {
for i := range keys {
if err := t.addWithTx(wTx, keys[i], values[i]); err != nil {
invalids = append(invalids, Invalid{i, err})
}
}
return invalids, nil
}
// split keys and values in buckets to add them in parallel by CPU
kvs, invalids, err := keysValuesToKvs(t.maxLevels, keys, values)
if err != nil {
return nil, err
}
buckets := splitInBuckets(kvs, nCPU)
// get the root to start adding the keys
root, err := t.RootWithTx(wTx)
if err != nil {
return nil, err
}
// get the subRoots at level l+1
l := int(math.Log2(float64(nCPU)))
subRoots, err := t.getSubRootsAtLevel(wTx, root, l+1)
if err != nil {
@@ -277,12 +284,12 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
// Already populated Tree but Unbalanced.
// add one key at each bucket, and then continue with the flow
for i := 0; i < len(buckets); i++ {
for i := range buckets {
// add one leaf of the bucket, if there is an error when
// adding the k-v, try to add the next one of the bucket
// (until one is added)
inserted := -1
for j := 0; j < len(buckets[i]); j++ {
for j := range buckets[i] {
if newRoot, err := t.add(wTx, root, 0,
buckets[i][j].k, buckets[i][j].v); err == nil {
inserted = j
@@ -290,10 +297,9 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
break
}
}
// remove the inserted element from buckets[i]
if inserted != -1 {
buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
buckets[i] = slices.Delete(buckets[i], inserted, inserted+1)
}
}
subRoots, err = t.getSubRootsAtLevel(wTx, root, l+1)
@@ -311,8 +317,8 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
invalidsInBucket := make([][]Invalid, nCPU)
txs := make([]db.WriteTx, nCPU)
for i := 0; i < nCPU; i++ {
txs[i] = t.db.WriteTx()
for i := range nCPU {
txs[i] = t.treedb.WriteTx()
err := txs[i].Apply(wTx)
if err != nil {
return nil, err
@@ -321,12 +327,12 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
var wg sync.WaitGroup
wg.Add(nCPU)
for i := 0; i < nCPU; i++ {
for i := range nCPU {
go func(cpu int) {
// use different wTx for each cpu, after once all
// are done, iter over the cpuWTxs and copy their
// content into the main wTx
for j := 0; j < len(buckets[cpu]); j++ {
for j := range buckets[cpu] {
newSubRoot, err := t.add(txs[cpu], subRoots[cpu],
l, buckets[cpu][j].k, buckets[cpu][j].v)
if err != nil {
@@ -341,33 +347,27 @@ func (t *Tree) addBatchInDisk(wTx db.WriteTx, keys, values [][]byte) ([]Invalid,
}(i)
}
wg.Wait()
for i := 0; i < nCPU; i++ {
for i := range nCPU {
if err := wTx.Apply(txs[i]); err != nil {
return nil, err
}
txs[i].Discard()
}
for i := 0; i < len(invalidsInBucket); i++ {
for i := range invalidsInBucket {
invalids = append(invalids, invalidsInBucket[i]...)
}
newRoot, err := t.upFromSubRoots(wTx, subRoots)
if err != nil {
return nil, err
}
// update dbKeyNLeafs
if err := t.SetRootWithTx(wTx, newRoot); err != nil {
return nil, err
}
// update nLeafs
if err := t.incNLeafs(wTx, len(keys)-len(invalids)); err != nil {
return nil, err
}
return invalids, nil
}
@@ -412,7 +412,6 @@ func (t *Tree) upFromSubRoots(wTx db.WriteTx, subRoots [][]byte) ([]byte, error)
newSubRoots = append(newSubRoots, subRoots[i+1])
continue
}
k, v, err := t.newIntermediate(subRoots[i], subRoots[i+1])
if err != nil {
return nil, err
@@ -430,7 +429,6 @@ func (t *Tree) upFromSubRoots(wTx db.WriteTx, subRoots [][]byte) ([]byte, error)
func (t *Tree) getSubRootsAtLevel(rTx db.Reader, root []byte, l int) ([][]byte, error) {
// go at level l and return each node key, where each node key is the
// subRoot of the subTree that starts there
var subRoots [][]byte
err := t.iterWithStop(rTx, root, 0, func(currLvl int, k, v []byte) bool {
if currLvl == l && !bytes.Equal(k, t.emptyHash) {
@@ -441,7 +439,6 @@ func (t *Tree) getSubRootsAtLevel(rTx db.Reader, root []byte, l int) ([][]byte,
}
return false
})
return subRoots, err
}
@@ -450,12 +447,10 @@ func (t *Tree) addBatchInMemory(wTx db.WriteTx, keys, values [][]byte) ([]Invali
if err != nil {
return nil, err
}
invalids, err := vt.addBatch(keys, values)
if err != nil {
return nil, err
}
// once the VirtualTree is build, compute the hashes
pairs, err := vt.computeHashes()
if err != nil {
@@ -464,26 +459,22 @@ func (t *Tree) addBatchInMemory(wTx db.WriteTx, keys, values [][]byte) ([]Invali
// nothing stored in the db and the error is returned
return nil, err
}
// store pairs in db
for i := 0; i < len(pairs); i++ {
for i := range pairs {
if err := wTx.Set(pairs[i][0], pairs[i][1]); err != nil {
return nil, err
}
}
// store root (from the vt) to db
if vt.root != nil {
if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil {
return nil, err
}
}
// update nLeafs
if err := t.incNLeafs(wTx, len(keys)-len(invalids)); err != nil {
return nil, err
}
return invalids, nil
}
@@ -493,7 +484,7 @@ func (t *Tree) loadVT() (vt, error) {
vt := newVT(t.maxLevels, t.hashFunction)
vt.params.dbg = t.dbg
var callbackErr error
err := t.IterateWithStopWithTx(t.db, nil, func(_ int, k, v []byte) bool {
err := t.IterateWithStopWithTx(t.treedb, nil, func(_ int, k, v []byte) bool {
if v[0] != PrefixValueLeaf {
return false
}
@@ -515,7 +506,7 @@ func (t *Tree) loadVT() (vt, error) {
// *big.Int, is expected that are represented by a Little-Endian byte array
// (for circom compatibility).
func (t *Tree) Add(k, v []byte) error {
wTx := t.db.WriteTx()
wTx := t.treedb.WriteTx()
defer wTx.Discard()
if err := t.AddWithTx(wTx, k, v); err != nil {
@@ -561,13 +552,19 @@ func (t *Tree) addWithTx(wTx db.WriteTx, k, v []byte) error {
return nil
}
// keyLenByLevels returns the key length in bytes that can be used in the tree
// with maxLevels levels. The key length is calculated as the ceil(maxLevels/8).
func keyLenByLevels(maxLevels int) int {
return int(math.Ceil(float64(maxLevels) / float64(8)))
}
// keyPathFromKey returns the keyPath and checks that the key is not bigger
// than maximum key length for the tree maxLevels size.
// This is because if the key bits length is bigger than the maxLevels of the
// tree, two different keys that their difference is at the end, will collision
// in the same leaf of the tree (at the max depth).
func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) {
maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
maxKeyLen := keyLenByLevels(maxLevels) //nolint:gomnd
if len(k) > maxKeyLen {
return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+
@@ -718,10 +715,8 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
if currLvl > t.maxLevels-1 {
return nil, ErrMaxVirtualLevel
}
if oldPath[currLvl] == newPath[currLvl] {
siblings = append(siblings, t.emptyHash)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
if err != nil {
return nil, err
@@ -730,7 +725,6 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
}
// reached the divergence
siblings = append(siblings, oldKey)
return siblings, nil
}
@@ -754,12 +748,10 @@ func (t *Tree) up(wTx db.WriteTx, key []byte, siblings [][]byte, path []bool,
if err = wTx.Set(k, v); err != nil {
return nil, err
}
if currLvl == 0 {
// reached the root
return k, nil
}
return t.up(wTx, k, siblings, path, currLvl-1, toLvl)
}
@@ -795,7 +787,6 @@ func ReadLeafValue(b []byte) ([]byte, []byte) {
if len(b) < PrefixValueLen {
return []byte{}, []byte{}
}
kLen := b[1]
if len(b) < PrefixValueLen+int(kLen) {
return []byte{}, []byte{}
@@ -825,12 +816,10 @@ func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error)
b[1] = byte(len(l))
copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
copy(b[PrefixValueLen+hashFunc.Len():], r)
key, err := hashFunc.Hash(l, r)
if err != nil {
return nil, nil, err
}
return key, b, nil
}
@@ -839,7 +828,6 @@ func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
if len(b) < PrefixValueLen {
return []byte{}, []byte{}
}
lLen := b[1]
if len(b) < PrefixValueLen+int(lLen) {
return []byte{}, []byte{}
@@ -851,7 +839,7 @@ func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
func getPath(numLevels int, k []byte) []bool {
path := make([]bool, numLevels)
for n := 0; n < numLevels; n++ {
for n := range numLevels {
path[n] = k[n/8]&(1<<(n%8)) != 0
}
return path
@@ -860,9 +848,8 @@ func getPath(numLevels int, k []byte) []bool {
// Update updates the value for a given existing key. If the given key does not
// exist, returns an error.
func (t *Tree) Update(k, v []byte) error {
wTx := t.db.WriteTx()
wTx := t.treedb.WriteTx()
defer wTx.Discard()
if err := t.UpdateWithTx(wTx, k, v); err != nil {
return err
}
@@ -875,7 +862,6 @@ func (t *Tree) Update(k, v []byte) error {
func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
t.Lock()
defer t.Unlock()
if !t.editable() {
return ErrSnapshotNotEditable
}
@@ -930,7 +916,7 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
// returned, together with the packed siblings of the proof, and a boolean
// parameter that indicates if the proof is of existence (true) or not (false).
func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
return t.GenProofWithTx(t.db, k)
return t.GenProofWithTx(t.treedb, k)
}
// GenProofWithTx does the same than the GenProof method, but allowing to pass
@@ -1062,7 +1048,7 @@ func bytesToBitmap(b []byte) []bool {
// will be placed the data found in the tree in the leaf that was on the path
// going to the input key.
func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
return t.GetWithTx(t.db, k)
return t.GetWithTx(t.treedb, k)
}
// GetWithTx does the same than the Get method, but allowing to pass the
@@ -1151,7 +1137,7 @@ func (t *Tree) setNLeafs(wTx db.WriteTx, nLeafs int) error {
// GetNLeafs returns the number of Leafs of the Tree.
func (t *Tree) GetNLeafs() (int, error) {
return t.GetNLeafsWithTx(t.db)
return t.GetNLeafsWithTx(t.treedb)
}
// GetNLeafsWithTx does the same than the GetNLeafs method, but allowing to
@@ -1167,7 +1153,7 @@ func (t *Tree) GetNLeafsWithTx(rTx db.Reader) (int, error) {
// SetRoot sets the root to the given root
func (t *Tree) SetRoot(root []byte) error {
wTx := t.db.WriteTx()
wTx := t.treedb.WriteTx()
defer wTx.Discard()
if err := t.SetRootWithTx(wTx, root); err != nil {
@@ -1209,7 +1195,7 @@ func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
return nil, err
}
}
rTx := t.db
rTx := t.treedb
// check that the root exists in the db
if !bytes.Equal(fromRoot, t.emptyHash) {
if _, err := rTx.Get(fromRoot); err == ErrKeyNotFound {
@@ -1222,7 +1208,7 @@ func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
}
return &Tree{
db: t.db,
treedb: t.treedb,
maxLevels: t.maxLevels,
snapshotRoot: fromRoot,
emptyHash: t.emptyHash,
@@ -1234,7 +1220,7 @@ func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
// Iterate iterates through the full Tree, executing the given function on each
// node of the Tree.
func (t *Tree) Iterate(fromRoot []byte, f func([]byte, []byte)) error {
return t.IterateWithTx(t.db, fromRoot, f)
return t.IterateWithTx(t.treedb, fromRoot, f)
}
// IterateWithTx does the same than the Iterate method, but allowing to pass
@@ -1258,12 +1244,12 @@ func (t *Tree) IterateWithStop(fromRoot []byte, f func(int, []byte, []byte) bool
// allow to define which root to use
if fromRoot == nil {
var err error
fromRoot, err = t.RootWithTx(t.db)
fromRoot, err = t.RootWithTx(t.treedb)
if err != nil {
return err
}
}
return t.iterWithStop(t.db, fromRoot, 0, f)
return t.iterWithStop(t.treedb, fromRoot, 0, f)
}
// IterateWithStopWithTx does the same than the IterateWithStop method, but
@@ -1470,14 +1456,14 @@ node [fontname=Monospace,fontsize=10,shape=box]
}
if fromRoot == nil {
var err error
fromRoot, err = t.RootWithTx(t.db)
fromRoot, err = t.RootWithTx(t.treedb)
if err != nil {
return err
}
}
nEmpties := 0
err := t.iterWithStop(t.db, fromRoot, 0, func(currLvl int, k, v []byte) bool {
err := t.iterWithStop(t.treedb, fromRoot, 0, func(currLvl int, k, v []byte) bool {
if currLvl == untilLvl {
return true // to stop the iter from going down
}

302
tree_big.go Normal file
View File

@@ -0,0 +1,302 @@
package arbo
import (
"bytes"
"fmt"
"log"
"math/big"
"runtime"
"slices"
"sync"
"go.vocdoni.io/dvote/db"
)
// AddBatchBigInt adds a batch of key-value pairs to the tree, it converts the
// big.Int keys and the slices of big.Int values into bytes and adds them to
// the tree. It locks the tree to prevent concurrent writes to the valuesdb and
// creates a transaction to store the full values in the valuesdb. It returns
// a slice of Invalid items and an error if something fails.
// AddBatchBigInt adds a batch of key-value pairs to the tree. It converts the
// big.Int keys and the slices of big.Int values into bytes and adds them to
// the tree, then locks the tree to prevent concurrent writes to the valuesdb
// and stores the full encoded values there in a single transaction. It returns
// a slice of Invalid items (indexes of entries that failed to be written to
// the valuesdb) and an error if something fails.
func (t *Tree) AddBatchBigInt(k []*big.Int, v [][]*big.Int) ([]Invalid, error) {
	if len(k) != len(v) {
		return nil, fmt.Errorf("the number of keys and values mismatch")
	}
	// convert each key-value tuple into bytes: bks are the encoded keys, bvs
	// the leaf values stored in the tree, and fbvs the full encoded values
	// stored apart in the valuesdb
	var err error
	bks := make([][]byte, len(k))
	bvs := make([][]byte, len(k))
	fbvs := make([][]byte, len(k))
	for i, ki := range k {
		bks[i], bvs[i], fbvs[i], err = bigIntLeaf(t.HashFunction(), t.maxKeyLen(), ki, v[i])
		if err != nil {
			return nil, err
		}
	}
	// add the keys and leaf values in batch
	if invalids, err := t.AddBatch(bks, bvs); err != nil {
		return invalids, err
	}
	// NOTE(review): if AddBatch reports invalids with a nil error, the full
	// values of those invalid keys are still written below — confirm this is
	// intended
	// lock the tree to prevent concurrent writes to the valuesdb
	t.Lock()
	defer t.Unlock()
	// store the full values in a single valuesdb transaction, collecting the
	// indexes of the entries that failed to be written
	var fullInvalids []Invalid
	wTx := t.valuesdb.WriteTx()
	defer wTx.Discard()
	for i := range bks {
		if err := wTx.Set(bks[i], fbvs[i]); err != nil {
			fullInvalids = append(fullInvalids, Invalid{i, err})
		}
	}
	return fullInvalids, wTx.Commit()
}
// addBatchBigIntByCPU stores the already-encoded keys (bks) and full values
// (fbvs) in the valuesdb, splitting the work across CPUs when there is enough
// of it. It returns the Invalid entries (with indexes referring to the input
// slices) and an error if something fails.
//
//nolint:unused
func (t *Tree) addBatchBigIntByCPU(bks, fbvs [][]byte) ([]Invalid, error) {
	// lock the tree to prevent concurrent writes to the valuesdb
	t.Lock()
	defer t.Unlock()
	// split keys and full values in groups to add them in parallel by CPU
	nCPU := flp2(runtime.NumCPU())
	groupsOfKeys := splitInGroups(bks, nCPU)
	groupOfFullValues := splitInGroups(fbvs, nCPU)
	// create the main transaction; it is committed on success and discarded
	// (via defer) on any error path
	var fullInvalids []Invalid
	wTx := t.valuesdb.WriteTx()
	defer wTx.Discard()
	// if there is only one CPU or the number of groups is less than the
	// number of CPUs, add the full values sequentially and commit the
	// transaction
	if nCPU == 1 || len(groupsOfKeys) < nCPU {
		for i, bk := range bks {
			if err := wTx.Set(bk, fbvs[i]); err != nil {
				fullInvalids = append(fullInvalids, Invalid{i, err})
			}
		}
		return fullInvalids, wTx.Commit()
	}
	// compute the global index offset of each group so that failed entries
	// can be reported with their index in the original input slices
	offsets := make([]int, nCPU)
	for i := 1; i < nCPU; i++ {
		offsets[i] = offsets[i-1] + len(groupsOfKeys[i-1])
	}
	// add the full values in parallel, collecting the invalid entries of each
	// goroutine in its own slice to avoid a data race on a shared one
	invalidsPerCPU := make([][]Invalid, nCPU)
	var wg sync.WaitGroup
	wg.Add(nCPU)
	txs := make([]db.WriteTx, nCPU)
	for i := range nCPU {
		// create a transaction for each CPU
		txs[i] = t.valuesdb.WriteTx()
		if err := txs[i].Apply(wTx); err != nil {
			log.Println(err)
			return fullInvalids, err
		}
		// add each group of full values in a goroutine
		go func(cpu int) {
			defer wg.Done()
			for j := range groupsOfKeys[cpu] {
				if err := txs[cpu].Set(groupsOfKeys[cpu][j], groupOfFullValues[cpu][j]); err != nil {
					invalidsPerCPU[cpu] = append(invalidsPerCPU[cpu], Invalid{offsets[cpu] + j, err})
				}
			}
		}(i)
	}
	// wait for all the goroutines to finish, apply their transactions into
	// the main one, merge the per-CPU invalids and commit
	wg.Wait()
	for i := range nCPU {
		if err := wTx.Apply(txs[i]); err != nil {
			return fullInvalids, err
		}
		txs[i].Discard()
	}
	for i := range invalidsPerCPU {
		fullInvalids = append(fullInvalids, invalidsPerCPU[i]...)
	}
	return fullInvalids, wTx.Commit()
}
// AddBigInt adds a key-value pair to the tree, it converts the big.Int key
// and the slice of big.Int values into bytes and adds them to the tree. It
// locks the tree to prevent concurrent writes to the valuesdb and creates a
// transaction to store the full value in the valuesdb. It returns an error if
// something fails.
func (t *Tree) AddBigInt(k *big.Int, v ...*big.Int) error {
	if k == nil {
		return fmt.Errorf("key cannot be nil")
	}
	// encode the key, the hashed leaf value and the reversible full value
	bk, bv, fbv, err := bigIntLeaf(t.HashFunction(), t.maxKeyLen(), k, v)
	if err != nil {
		return err
	}
	// insert the leaf into the tree before touching the valuesdb
	if err = t.Add(bk, bv); err != nil {
		return err
	}
	// serialize writers of the valuesdb while the full value is stored
	t.Lock()
	defer t.Unlock()
	wTx := t.valuesdb.WriteTx()
	defer wTx.Discard()
	if err = wTx.Set(bk, fbv); err != nil {
		return err
	}
	return wTx.Commit()
}
// UpdateBigInt updates the value of a key as a big.Int and the values of the
// leaf node as a slice of big.Ints. It encodes the key as bytes and updates
// the leaf node in the tree, then it stores the full value in the valuesdb. It
// returns an error if something fails.
func (t *Tree) UpdateBigInt(k *big.Int, value ...*big.Int) error {
	if k == nil {
		return fmt.Errorf("key cannot be nil")
	}
	// encode the key, the hashed leaf value and the reversible full value
	bk, bv, fbv, err := bigIntLeaf(t.HashFunction(), t.maxKeyLen(), k, value)
	if err != nil {
		return err
	}
	// replace the leaf in the tree before touching the valuesdb
	if err = t.Update(bk, bv); err != nil {
		return err
	}
	// serialize writers of the valuesdb while the full value is replaced
	t.Lock()
	defer t.Unlock()
	wTx := t.valuesdb.WriteTx()
	defer wTx.Discard()
	if err = wTx.Set(bk, fbv); err != nil {
		return err
	}
	return wTx.Commit()
}
// GetBigInt gets the value of a key as a big.Int and the values of the leaf
// node as a slice of big.Ints. It encodes the key as bytes and gets the leaf
// node from the tree, then it decodes the full value of the leaf node and
// returns the key and the values or an error if something fails.
func (t *Tree) GetBigInt(k *big.Int) (*big.Int, []*big.Int, error) {
	if k == nil {
		return nil, nil, fmt.Errorf("key cannot be nil")
	}
	// fetch the leaf from the tree using the fixed-size key encoding
	bKey := bigIntToKey(t.maxKeyLen(), k)
	bk, bv, err := t.Get(bKey)
	if err != nil {
		return nil, nil, err
	}
	// decode the stored full value back into the original big.Ints
	return t.leafToBigInts(bk, bv)
}
// GenProofBigInts generates a proof for the given big.Int key. It encodes the
// key to its fixed-length byte representation and delegates to GenProof,
// returning the leaf key bytes, the leaf value bytes, the packed siblings and
// whether the key exists in the tree. It returns an error if the key is nil
// or the proof generation fails.
func (t *Tree) GenProofBigInts(k *big.Int) ([]byte, []byte, []byte, bool, error) {
	if k == nil {
		return nil, nil, nil, false, fmt.Errorf("key cannot be nil")
	}
	return t.GenProof(bigIntToKey(t.maxKeyLen(), k))
}
// GenerateCircomVerifierProofBigInt generates a CircomVerifierProof for the
// given big.Int key. It encodes the key to its fixed-length byte
// representation and delegates to GenerateCircomVerifierProof. It returns an
// error if the key is nil or the proof generation fails.
func (t *Tree) GenerateCircomVerifierProofBigInt(k *big.Int) (*CircomVerifierProof, error) {
	if k == nil {
		return nil, fmt.Errorf("key cannot be nil")
	}
	return t.GenerateCircomVerifierProof(bigIntToKey(t.maxKeyLen(), k))
}
// maxKeyLen returns the maximum length of a tree key in bytes, derived from
// the number of levels of the tree.
func (t *Tree) maxKeyLen() int {
	return keyLenByLevels(t.maxLevels)
}
// leafToBigInts converts the bytes of the key and the value of a leaf node
// into a big.Int key and a slice of big.Int values. It gets the full value
// from the valuesdb and checks that, once re-hashed, it matches the value of
// the leaf node. It returns the original key and values, or an error if the
// values don't match or the stored encoding is malformed.
func (t *Tree) leafToBigInts(key, value []byte) (*big.Int, []*big.Int, error) {
	bFullValue, err := t.valuesdb.Get(key)
	if err != nil {
		return nil, nil, err
	}
	// recalculate the value to check if it matches the stored value
	expectedFullValue, err := bigIntToLeafValue(t.HashFunction(), bFullValue)
	if err != nil {
		return nil, nil, err
	}
	if !bytes.Equal(expectedFullValue, value) {
		return nil, nil, fmt.Errorf("leafToBigInts: expectedFullValue != value")
	}
	// reverse the length-prefixed encoding produced by bigIntLeaf: each value
	// is stored as one length byte followed by that many bytes
	values := []*big.Int{}
	iter := slices.Clone(bFullValue)
	for len(iter) > 0 {
		lenV := int(iter[0])
		if lenV > len(iter)-1 {
			// corrupted encoding: the declared length exceeds the remaining
			// bytes; return an error instead of panicking on the slice below
			return nil, nil, fmt.Errorf("leafToBigInts: malformed full value encoding")
		}
		values = append(values, new(big.Int).SetBytes(iter[1:1+lenV]))
		iter = iter[1+lenV:]
	}
	return BytesToBigInt(key), values, nil
}
// bigIntToKey converts a big.Int into a byte slice of length keyLen, using
// the package-level BigIntToBytes encoding.
func bigIntToKey(keyLen int, b *big.Int) []byte {
	return BigIntToBytes(keyLen, b)
}
// bigIntLeaf converts a big.Int key and a slice of big.Int values into the
// bytes of the key, the bytes of the value used to build the tree and the
// bytes of the full value encoded. The full value encoding is reversible:
// every value is stored as one length byte followed by its bytes, so each
// value may be at most 255 bytes long.
func bigIntLeaf(hFn HashFunction, keyLen int, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) {
	if key == nil {
		return nil, nil, nil, fmt.Errorf("key cannot be nil")
	}
	// encode the key with the fixed length expected by the tree
	bKey := bigIntToKey(keyLen, key)
	// encode the values as length-prefixed byte runs so they can be decoded
	// back into the original big.Ints
	bFullValue := []byte{}
	for _, v := range values {
		if v == nil {
			return nil, nil, nil, fmt.Errorf("value cannot be nil")
		}
		vBytes := v.Bytes()
		if len(vBytes) > 255 {
			return nil, nil, nil, fmt.Errorf("value byte length cannot exceed 255")
		}
		bFullValue = append(bFullValue, byte(len(vBytes)))
		bFullValue = append(bFullValue, vBytes...)
	}
	// hash the encoded full value to obtain the leaf value stored in the tree
	bValue, err := bigIntToLeafValue(hFn, bFullValue)
	if err != nil {
		return nil, nil, nil, err
	}
	return bKey, bValue, bFullValue, nil
}
// bigIntToLeafValue hashes the full value of a leaf node by splitting it in
// chunks of the size of the hash function output and hashing them. The last
// chunk may be shorter when the length is not a multiple of the chunk size.
func bigIntToLeafValue(hFn HashFunction, bFullValue []byte) ([]byte, error) {
	chunkSize := hFn.Len()
	// slice the full value into chunkSize-byte windows
	chunks := [][]byte{}
	for start := 0; start < len(bFullValue); start += chunkSize {
		end := start + chunkSize
		if end > len(bFullValue) {
			end = len(bFullValue)
		}
		chunks = append(chunks, bFullValue[start:end])
	}
	// hash the chunks together to obtain the leaf value
	return hFn.Hash(chunks...)
}
// splitInGroups splits the items in nGroups groups, distributing them
// round-robin: item i goes to group i modulo nGroups. Empty groups are
// returned as nil slices.
func splitInGroups[T any](items []T, nGroups int) [][]T {
	groups := make([][]T, nGroups)
	for idx := 0; idx < len(items); idx++ {
		g := idx % nGroups
		groups[g] = append(groups[g], items[idx])
	}
	return groups
}

319
tree_big_test.go Normal file
View File

@@ -0,0 +1,319 @@
package arbo
import (
"crypto/rand"
"math/big"
"testing"
qt "github.com/frankban/quicktest"
"github.com/vocdoni/arbo/memdb"
)
// TestGenCheckProofBigInt adds 1000 leaves with random big.Int keys (and
// value = key*2) to a Poseidon tree, then generates membership proofs for a
// random sample of 20 keys and checks they verify against the current root.
func TestGenCheckProofBigInt(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(Config{
		Database:     memdb.New(),
		MaxLevels:    256,
		HashFunction: HashFunctionPoseidon,
	})
	c.Assert(err, qt.IsNil)
	defer tree.treedb.Close()   //nolint:errcheck
	defer tree.valuesdb.Close() //nolint:errcheck
	keys := []*big.Int{}
	for range 1000 {
		k, err := rand.Int(rand.Reader, big.NewInt(100_000_000_000))
		c.Assert(err, qt.IsNil)
		v := new(big.Int).Mul(k, big.NewInt(2))
		c.Assert(tree.AddBigInt(k, v), qt.IsNil)
		keys = append(keys, k)
	}
	// validate 20 random keys
	for range 20 {
		// pick a random index into the inserted keys
		i, err := rand.Int(rand.Reader, big.NewInt(int64(len(keys))))
		c.Assert(err, qt.IsNil)
		k := keys[i.Int64()]
		kAux, vAux, siblings, existence, err := tree.GenProofBigInts(k)
		c.Assert(err, qt.IsNil)
		c.Assert(existence, qt.IsTrue)
		root, err := tree.Root()
		c.Assert(err, qt.IsNil)
		// the generated proof must verify against the current root
		verif, err := CheckProof(tree.hashFunction, kAux, vAux, root, siblings)
		c.Assert(err, qt.IsNil)
		c.Check(verif, qt.IsTrue)
	}
}
// TestAddGetBigInt adds 100 leaves with random big.Int keys and two random
// values each, verifying every leaf right after inserting it. It also covers
// getting a non-existent key, getting a nil key and re-adding an existing
// key, which must all fail.
func TestAddGetBigInt(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(Config{
		Database:     memdb.New(),
		MaxLevels:    256,
		HashFunction: HashFunctionPoseidon,
	})
	c.Assert(err, qt.IsNil)
	defer tree.treedb.Close()   //nolint:errcheck
	defer tree.valuesdb.Close() //nolint:errcheck
	// helper that draws a random big.Int below 2^25
	randBig := func() *big.Int {
		r, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		return r
	}
	// add multiple key-value pairs with large random big ints, verifying the
	// retrieval of each one right after inserting it
	keys := make([]*big.Int, 100)
	values := make([][]*big.Int, 100)
	for i := range 100 {
		k := randBig()
		keys[i] = k
		v1, v2 := randBig(), randBig()
		values[i] = []*big.Int{v1, v2}
		c.Assert(tree.AddBigInt(k, v1, v2), qt.IsNil)
		gotK, gotVs, err := tree.GetBigInt(k)
		c.Assert(err, qt.IsNil)
		c.Check(gotK.Cmp(k), qt.Equals, 0)
		c.Assert(len(gotVs), qt.Equals, 2)
		c.Check(gotVs[0].Cmp(v1), qt.Equals, 0)
		c.Check(gotVs[1].Cmp(v2), qt.Equals, 0)
	}
	// getting a key that was never added must fail
	_, _, err = tree.GetBigInt(randBig())
	c.Check(err, qt.IsNotNil)
	// a nil key must be rejected
	_, _, err = tree.GetBigInt(nil)
	c.Check(err, qt.IsNotNil)
	// adding an already existing key must fail
	c.Check(tree.AddBigInt(keys[0], values[0]...), qt.IsNotNil)
}
// TestUpdateBigInt adds 50 leaves with random big.Int keys and two random
// values each, updates the first 25 of them with fresh values and checks the
// updates took effect. It also covers updating a non-existent key and a nil
// key, which must both fail.
func TestUpdateBigInt(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(Config{
		Database:     memdb.New(),
		MaxLevels:    256,
		HashFunction: HashFunctionPoseidon,
	})
	c.Assert(err, qt.IsNil)
	defer tree.treedb.Close()   //nolint:errcheck
	defer tree.valuesdb.Close() //nolint:errcheck
	// Store keys for later updates
	keys := make([]*big.Int, 50)
	values := make([][]*big.Int, 50)
	// Add entries with large random big ints (below 2^25)
	for i := range 50 {
		k, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		keys[i] = k
		v1, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		v2, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		values[i] = []*big.Int{v1, v2}
		c.Assert(tree.AddBigInt(k, v1, v2), qt.IsNil)
	}
	// Update the first 25 entries with new random values
	for i := range 25 {
		k := keys[i]
		newV1, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		newV2, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		c.Assert(tree.UpdateBigInt(k, newV1, newV2), qt.IsNil)
		// Verify the update by reading the leaf back
		_, retrievedVs, err := tree.GetBigInt(k)
		c.Assert(err, qt.IsNil)
		c.Assert(len(retrievedVs), qt.Equals, 2)
		c.Check(retrievedVs[0].Cmp(newV1), qt.Equals, 0)
		c.Check(retrievedVs[1].Cmp(newV2), qt.Equals, 0)
	}
	// Test updating non-existent key
	nonExistentKey, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
	c.Assert(err, qt.IsNil)
	err = tree.UpdateBigInt(nonExistentKey, big.NewInt(1))
	c.Check(err, qt.IsNotNil)
	// Test updating with nil key
	err = tree.UpdateBigInt(nil, big.NewInt(1))
	c.Check(err, qt.IsNotNil)
}
// TestAddBatchBigInt adds 1000 leaves in a single batch, verifies a sample of
// them, and checks AddBatchBigInt's handling of mismatched input lengths,
// empty batches and nil inputs.
func TestAddBatchBigInt(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(Config{
		Database:     memdb.New(),
		MaxLevels:    256,
		HashFunction: HashFunctionPoseidon,
	})
	c.Assert(err, qt.IsNil)
	defer tree.treedb.Close()   //nolint:errcheck
	defer tree.valuesdb.Close() //nolint:errcheck
	// Prepare batch data with large random big ints (below 2^25)
	batchSize := 1000
	keys := make([]*big.Int, batchSize)
	values := make([][]*big.Int, batchSize)
	for i := range batchSize {
		k, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		keys[i] = k
		// Create multiple random values for each key
		v1, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		v2, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		values[i] = []*big.Int{v1, v2}
	}
	// Add the batch; no entry should be reported as invalid
	invalids, err := tree.AddBatchBigInt(keys, values)
	c.Assert(err, qt.IsNil)
	c.Check(len(invalids), qt.Equals, 0)
	// Verify a sample of entries (indices 0..49)
	for i := range 50 {
		idx := i % batchSize
		_, retrievedVs, err := tree.GetBigInt(keys[idx])
		c.Assert(err, qt.IsNil)
		c.Assert(len(retrievedVs), qt.Equals, 2)
		c.Check(retrievedVs[0].Cmp(values[idx][0]), qt.Equals, 0)
		c.Check(retrievedVs[1].Cmp(values[idx][1]), qt.Equals, 0)
	}
	// Mismatched key/value slice lengths must be rejected
	_, err = tree.AddBatchBigInt(keys[:10], values[:5])
	c.Check(err, qt.IsNotNil)
	// An empty batch is a no-op without error
	invalids, err = tree.AddBatchBigInt([]*big.Int{}, [][]*big.Int{})
	c.Assert(err, qt.IsNil)
	c.Check(len(invalids), qt.Equals, 0)
	// Nil slices behave like an empty batch
	invalids, err = tree.AddBatchBigInt(nil, nil)
	c.Assert(err, qt.IsNil)
	c.Check(len(invalids), qt.Equals, 0)
}
// TestGenerateCircomVerifierProofBigInt adds 100 leaves with random big.Int
// keys and two values each, generates circom verifier proofs for a sample of
// keys and checks the proof structure is populated. It also checks that a
// non-existent key yields a proof without error, and that a nil key fails.
func TestGenerateCircomVerifierProofBigInt(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(Config{
		Database:     memdb.New(),
		MaxLevels:    256,
		HashFunction: HashFunctionPoseidon,
	})
	c.Assert(err, qt.IsNil)
	defer tree.treedb.Close()   //nolint:errcheck
	defer tree.valuesdb.Close() //nolint:errcheck
	// Store keys for later proof generation
	keys := make([]*big.Int, 100)
	values := make([][]*big.Int, 100)
	// Add entries with large random big ints (below 2^25)
	for i := range 100 {
		k, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		keys[i] = k
		v1, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		v2, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		c.Assert(err, qt.IsNil)
		values[i] = []*big.Int{v1, v2}
		c.Assert(tree.AddBigInt(k, v1, v2), qt.IsNil)
	}
	// Generate proofs for the first 10 keys
	for i := range 10 {
		idx := i % len(keys)
		proof, err := tree.GenerateCircomVerifierProofBigInt(keys[idx])
		c.Assert(err, qt.IsNil)
		c.Assert(proof, qt.IsNotNil)
		// Verify the proof structure is fully populated
		c.Assert(proof.Root, qt.IsNotNil)
		c.Assert(proof.Key, qt.IsNotNil)
		c.Assert(proof.Value, qt.IsNotNil)
		c.Assert(proof.Siblings, qt.IsNotNil)
	}
	// A non-existent key still returns a proof without error (presumably a
	// proof of non-membership)
	nonExistentKey, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
	c.Assert(err, qt.IsNil)
	nonExistentProof, err := tree.GenerateCircomVerifierProofBigInt(nonExistentKey)
	c.Check(err, qt.IsNil)
	c.Check(nonExistentProof, qt.IsNotNil)
	// Test nil key
	_, err = tree.GenerateCircomVerifierProofBigInt(nil)
	c.Check(err, qt.IsNotNil)
}
// BenchmarkAddBatchBigInt benchmarks AddBatchBigInt over a fixed batch of
// 1000 random keys with two random values each, for both the Poseidon and
// the Sha256 hash functions.
func BenchmarkAddBatchBigInt(b *testing.B) {
	// Prepare batch data with large random big ints; errors from crypto/rand
	// are deliberately ignored here, a zero value is still a usable input
	batchSize := 1000
	keys := make([]*big.Int, batchSize)
	values := make([][]*big.Int, batchSize)
	for i := range batchSize {
		keys[i], _ = rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		v1, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		v2, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 25))
		values[i] = []*big.Int{v1, v2}
	}
	b.Run("Poseidon", func(b *testing.B) {
		benchmarkAddBatchBigInt(b, HashFunctionPoseidon, keys, values)
	})
	b.Run("Sha256", func(b *testing.B) {
		benchmarkAddBatchBigInt(b, HashFunctionSha256, keys, values)
	})
}
// benchmarkAddBatchBigInt measures AddBatchBigInt for the given hash function
// over a fixed batch of keys and values. A fresh in-memory tree is built for
// every iteration; tree setup and teardown are excluded from the measured
// time via StopTimer/StartTimer so only the batch insertion is benchmarked.
func benchmarkAddBatchBigInt(b *testing.B, hashFunc HashFunction, keys []*big.Int, values [][]*big.Int) {
	c := qt.New(b)
	b.ResetTimer()
	for range b.N {
		// tree construction is setup, not the operation under test
		b.StopTimer()
		tree, err := NewTree(Config{
			Database:     memdb.New(),
			MaxLevels:    140,
			HashFunction: hashFunc,
		})
		c.Assert(err, qt.IsNil)
		b.StartTimer()
		_, err = tree.AddBatchBigInt(keys, values)
		if err != nil {
			b.Fatal(err)
		}
		// closing the databases is teardown, keep it out of the timing too
		b.StopTimer()
		tree.treedb.Close()   //nolint:errcheck
		tree.valuesdb.Close() //nolint:errcheck
		b.StartTimer()
	}
}

View File

@@ -76,7 +76,7 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: hashFunc})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
root, err := tree.Root()
c.Assert(err, qt.IsNil)
@@ -109,7 +109,7 @@ func TestAddBatch(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < 1000; i++ {
@@ -128,7 +128,7 @@ func TestAddBatch(t *testing.T) {
tree2, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
var keys, values [][]byte
for i := 0; i < 1000; i++ {
@@ -152,7 +152,7 @@ func TestAddDifferentOrder(t *testing.T) {
tree1, err := NewTree(Config{Database: database1, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < 16; i++ {
@@ -168,7 +168,7 @@ func TestAddDifferentOrder(t *testing.T) {
tree2, err := NewTree(Config{Database: database2, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
for i := 16 - 1; i >= 0; i-- {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
@@ -194,7 +194,7 @@ func TestAddRepeatedIndex(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
k := BigIntToBytes(bLen, big.NewInt(int64(3)))
@@ -213,7 +213,7 @@ func TestUpdate(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
k := BigIntToBytes(bLen, big.NewInt(int64(20)))
@@ -267,7 +267,7 @@ func TestAux(t *testing.T) { // TODO split in proper tests
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
k := BigIntToBytes(bLen, big.NewInt(int64(1)))
@@ -307,7 +307,7 @@ func TestGet(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < 10; i++ {
@@ -432,10 +432,10 @@ func TestGenProofAndVerify(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < 10; i++ {
for i := range 1000 {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
v := BigIntToBytes(bLen, big.NewInt(int64(i*2)))
if err := tree.Add(k, v); err != nil {
@@ -473,7 +473,7 @@ func testDumpAndImportDump(t *testing.T, inFile bool) {
tree1, err := NewTree(Config{Database: database1, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
defer tree1.treedb.Close() //nolint:errcheck
bLen := 32
for i := 0; i < 16; i++ {
@@ -502,7 +502,7 @@ func testDumpAndImportDump(t *testing.T, inFile bool) {
tree2, err := NewTree(Config{Database: database2, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
defer tree2.treedb.Close() //nolint:errcheck
if inFile {
f, err := os.Open(filepath.Clean(fileName))
@@ -530,7 +530,7 @@ func TestRWMutex(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
var keys, values [][]byte
@@ -783,7 +783,7 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) {
tree, err := NewTree(Config{Database: database, MaxLevels: 256,
HashFunction: HashFunctionPoseidon})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
bLen := 32
k := BigIntToBytes(bLen, big.NewInt(int64(3)))
@@ -953,7 +953,7 @@ func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
tree, err := NewTree(Config{Database: database, MaxLevels: 140,
HashFunction: hashFunc})
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck
defer tree.treedb.Close() //nolint:errcheck
for i := 0; i < len(ks); i++ {
if err := tree.Add(ks[i], vs[i]); err != nil {