diff --git a/hash.go b/hash.go index 53bac04..66fb7de 100644 --- a/hash.go +++ b/hash.go @@ -196,7 +196,6 @@ func (f HashMultiPoseidon) SafeBigInt(b *big.Int) []byte { safeBigInt := BigToFF(BN254BaseField, b) arboBytes := BigIntToBytes(f.Len(), safeBigInt) return ExplicitZero(arboBytes) - // return ExplicitZero(BigToFF(BN254BaseField, b).Bytes()) } // HashBlake2b implements the HashFunction interface for the Blake2b hash @@ -372,9 +371,11 @@ func (f HashMiMC7) Hash(b ...[]byte) ([]byte, error) { } func (f HashMiMC7) SafeValue(b []byte) []byte { - return BigToFF(BN254BaseField, BytesToBigInt(b)).Bytes() + return f.SafeBigInt(BytesToBigInt(b)) } func (f HashMiMC7) SafeBigInt(b *big.Int) []byte { - return f.SafeValue(b.Bytes()) + safeBigInt := BigToFF(BN254BaseField, b) + arboBytes := BigIntToBytes(f.Len(), safeBigInt) + return ExplicitZero(arboBytes) } diff --git a/testvectors/gnark/gnark_test.go b/testvectors/gnark/gnark_test.go index 89af430..c31d6d0 100644 --- a/testvectors/gnark/gnark_test.go +++ b/testvectors/gnark/gnark_test.go @@ -14,26 +14,27 @@ import ( "github.com/vocdoni/arbo/memdb" "github.com/vocdoni/gnark-crypto-primitives/hash/bn254/poseidon" garbo "github.com/vocdoni/gnark-crypto-primitives/tree/arbo" + gsmt "github.com/vocdoni/gnark-crypto-primitives/tree/smt" ) const nLevels = 160 -type testCircuit struct { +type testCircuitArbo struct { Root frontend.Variable Key frontend.Variable Value frontend.Variable Siblings [nLevels]frontend.Variable } -func (circuit *testCircuit) Define(api frontend.API) error { +func (circuit *testCircuitArbo) Define(api frontend.API) error { return garbo.CheckInclusionProof(api, poseidon.MultiHash, circuit.Key, circuit.Value, circuit.Root, circuit.Siblings[:]) } -func TestGnarkSMTVerifier(t *testing.T) { +func TestGnarkArboVerifier(t *testing.T) { c := qt.New(t) tree, err := arbo.NewTree(arbo.Config{ Database: memdb.New(), - MaxLevels: 255, + MaxLevels: nLevels, HashFunction: arbo.HashFunctionMultiPoseidon, }) c.Assert(err, qt.IsNil) @@ -66,7 +67,64 @@ func TestGnarkSMTVerifier(t *testing.T) { } assert := test.NewAssert(t) - assert.SolvingSucceeded(&testCircuit{}, &testCircuit{ + assert.SolvingSucceeded(&testCircuitArbo{}, &testCircuitArbo{ + Root: proof.Root, + Key: proof.Key, + Value: proof.Value, + Siblings: [160]frontend.Variable(paddedSiblings), + }, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) +} + +type testCircuitSMT struct { + Root frontend.Variable + Key frontend.Variable + Value frontend.Variable + Siblings [nLevels]frontend.Variable +} + +func (circuit *testCircuitSMT) Define(api frontend.API) error { + gsmt.InclusionVerifier(api, poseidon.MultiHash, circuit.Root, circuit.Siblings[:], circuit.Key, circuit.Value) + return nil +} + +func TestGnarkSMTVerifier(t *testing.T) { + c := qt.New(t) + tree, err := arbo.NewTree(arbo.Config{ + Database: memdb.New(), + MaxLevels: nLevels, + HashFunction: arbo.HashFunctionMultiPoseidon, + }) + c.Assert(err, qt.IsNil) + + var ( + keys []*big.Int + values [][]*big.Int + ) + max, _ := new(big.Int).SetString("10000000000000000000000000", 10) + for range 100 { + k, err := rand.Int(rand.Reader, max) + qt.Assert(t, err, qt.IsNil) + v := new(big.Int).Mul(k, big.NewInt(2)) + keys = append(keys, k) + values = append(values, []*big.Int{v}) + } + _, err = tree.AddBatchBigInt(keys, values) + c.Assert(err, qt.IsNil) + + proof, err := tree.GenerateGnarkVerifierProofBigInt(keys[0]) + c.Assert(err, qt.IsNil) + + var paddedSiblings [nLevels]frontend.Variable + for i := range paddedSiblings { + if i < len(proof.Siblings) { + paddedSiblings[i] = proof.Siblings[i] + continue + } + paddedSiblings[i] = 0 + } + + assert := test.NewAssert(t) + assert.SolvingSucceeded(&testCircuitSMT{}, &testCircuitSMT{ Root: proof.Root, Key: proof.Key, Value: proof.Value, diff --git a/tree_big.go b/tree_big.go index a75d382..aae4f85 100644 --- a/tree_big.go +++ b/tree_big.go @@ -7,6 +7,14 @@ import ( "slices" ) +func MaxKeyLen(levels, hashLen int) int { + return min(int(math.Ceil(float64(levels)/float64(8))), hashLen) +} + +func (t *Tree) MaxKeyLen() int { + return MaxKeyLen(t.maxLevels, t.HashFunction().Len()) +} + // AddBatchBigInt adds a batch of key-value pairs to the tree, it converts the // big.Int keys and the slices of big.Int values into bytes and adds them to // the tree. It locks the tree to prevent concurrent writes to the valuesdb and @@ -22,7 +30,7 @@ func (t *Tree) AddBatchBigInt(k []*big.Int, v [][]*big.Int) ([]Invalid, error) { bvs := make([][]byte, len(k)) fbvs := make([][]byte, len(k)) for i, ki := range k { - bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), ki, v[i]) + bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), t.MaxKeyLen(), ki, v[i]) if err != nil { return nil, err } @@ -57,7 +65,7 @@ func (t *Tree) AddBigInt(k *big.Int, v ...*big.Int) error { return fmt.Errorf("key cannot be nil") } // convert the big ints to bytes - bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, v) + bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.MaxKeyLen(), k, v) if err != nil { return err } @@ -87,7 +95,7 @@ func (t *Tree) UpdateBigInt(k *big.Int, value ...*big.Int) error { return fmt.Errorf("key cannot be nil") } // convert the big ints to bytes - bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, value) + bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.MaxKeyLen(), k, value) if err != nil { return err } @@ -116,7 +124,7 @@ func (t *Tree) GetBigInt(k *big.Int) (*big.Int, []*big.Int, error) { if k == nil { return nil, nil, fmt.Errorf("key cannot be nil") } - bk := t.HashFunction().SafeBigInt(k) + bk := bigIntToLeafKey(k, t.MaxKeyLen()) _, bv, err := t.Get(bk) if err != nil { return nil, nil, err @@ -136,7 +144,9 @@ func (t *Tree) GenProofBigInts(k *big.Int) ([]byte, []byte, []byte, bool, error) if k == nil { return nil, nil, nil, false, fmt.Errorf("key cannot be nil") } - return t.GenProof(t.HashFunction().SafeBigInt(k)) + + bk := bigIntToLeafKey(k, t.MaxKeyLen()) + return t.GenProof(bk) } // GenerateCircomVerifierProofBigInt generates a CircomVerifierProof for a key @@ -147,7 +157,8 @@ func (t *Tree) GenerateCircomVerifierProofBigInt(k *big.Int) (*CircomVerifierPro if k == nil { return nil, fmt.Errorf("key cannot be nil") } - return t.GenerateCircomVerifierProof(t.HashFunction().SafeBigInt(k)) + bk := bigIntToLeafKey(k, t.MaxKeyLen()) + return t.GenerateCircomVerifierProof(bk) } // GenerateGnarkVerifierProofBigInt generates a GnarkVerifierProof for a key @@ -158,7 +169,8 @@ func (t *Tree) GenerateGnarkVerifierProofBigInt(k *big.Int) (*GnarkVerifierProof if k == nil { return nil, fmt.Errorf("key cannot be nil") } - return t.GenerateGnarkVerifierProof(t.HashFunction().SafeBigInt(k)) + bk := bigIntToLeafKey(k, t.MaxKeyLen()) + return t.GenerateGnarkVerifierProof(bk) } // leafToBigInts converts the bytes of the key and the value of a leaf node @@ -167,9 +179,9 @@ func (t *Tree) GenerateGnarkVerifierProofBigInt(k *big.Int) (*GnarkVerifierProof // returns the original key and values or an error if the values don't match. func (t *Tree) leafToBigInts(key, value, fullValue []byte) (*big.Int, []*big.Int, error) { // reverse the process of values encoding - values := fullValueToValues(fullValue) + values := FullValueToValues(fullValue) // recalculate the value to check if it matches the stored value - expectedFullValue, err := valuesToFullValue(values) + expectedFullValue, err := ValuesToFullValue(values) if err != nil { return nil, nil, err } @@ -178,7 +190,7 @@ func (t *Tree) leafToBigInts(key, value, fullValue []byte) (*big.Int, []*big.Int return nil, nil, fmt.Errorf("LeafToBigInt: expectedFullValue != value") } // reencode the leaf value of the tree to check if it matches the value - encodedValues, err := encodeBigIntValues(t.HashFunction(), values...) + encodedValues, err := EncodeBigIntValues(t.HashFunction(), values...) if err != nil { return nil, nil, err } @@ -191,15 +203,24 @@ func (t *Tree) leafToBigInts(key, value, fullValue []byte) (*big.Int, []*big.Int return leafKeyToBigInt(key), values, nil } -// leafKeyToBigInt converts the bytes of a key into a big.Int +// leafKeyToBigInt converts the bytes of a key into a big.Int. It returns the +// big.Int value of the key in Big-Endian format, assuming the key is encoded +// in Little-Endian format. func leafKeyToBigInt(key []byte) *big.Int { - return new(big.Int).SetBytes(key) + return BytesToBigInt(key) } -// valuesToFullValue converts a slice of big.Int values into the bytes of the +// bigIntToLeafKey converts a big.Int key into the bytes of the key. It +// encodes the key in Little-Endian format and pads it to the maximum length +// of the key. It returns the bytes of the key. +func bigIntToLeafKey(key *big.Int, maxLen int) []byte { + return BigIntToBytes(maxLen, key) +} + +// ValuesToFullValue converts a slice of big.Int values into the bytes of the // full value encoded in a reversible way. It concatenates the bytes of the // values with the length of each value at the beginning of each value. -func valuesToFullValue(values []*big.Int) ([]byte, error) { +func ValuesToFullValue(values []*big.Int) ([]byte, error) { // calculate the bytes of the full values (should be reversible) bFullValue := []byte{} for _, v := range values { @@ -216,11 +237,11 @@ func valuesToFullValue(values []*big.Int) ([]byte, error) { return bFullValue, nil } -// fullValueToValues converts the bytes of the full value encoded into a slice +// FullValueToValues converts the bytes of the full value encoded into a slice // of big.Int values. It iterates over the bytes of the full value and extracts // the length of each value and the bytes of the value to build the big.Int // values. -func fullValueToValues(fullValue []byte) []*big.Int { +func FullValueToValues(fullValue []byte) []*big.Int { values := []*big.Int{} iter := slices.Clone(fullValue) for len(iter) > 0 { @@ -234,29 +255,30 @@ func fullValueToValues(fullValue []byte) []*big.Int { // encodeBigIntData converts a big.Int key and a slice of big.Int values into the // bytes of the key, the bytes of the value used to build the tree and the // bytes of the full value encoded -func encodeBigIntData(hFn HashFunction, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) { +func encodeBigIntData(hFn HashFunction, keyLen int, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) { if key == nil { return nil, nil, nil, fmt.Errorf("key cannot be nil") } // calculate the bytes of the key - bKey := hFn.SafeBigInt(key) + // bKey := hFn.SafeBigInt(key) + bKey := bigIntToLeafKey(key, keyLen) // calculate the bytes of the full values (should be reversible) - bFullValue, err := valuesToFullValue(values) + bFullValue, err := ValuesToFullValue(values) if err != nil { return nil, nil, nil, err } // calculate the value used to build the tree - bValue, err := encodeBigIntValues(hFn, values...) + bValue, err := EncodeBigIntValues(hFn, values...) if err != nil { return nil, nil, nil, err } return bKey, bValue, bFullValue, nil } -// encodeBigIntValues converts a slice of big.Int values into the bytes of the +// EncodeBigIntValues converts a slice of big.Int values into the bytes of the // value used to build the tree. It hashes the bytes of the big.Int values // using the hash function of the tree. -func encodeBigIntValues(hFn HashFunction, values ...*big.Int) ([]byte, error) { +func EncodeBigIntValues(hFn HashFunction, values ...*big.Int) ([]byte, error) { chunks := make([][]byte, len(values)) for _, v := range values { value := hFn.SafeBigInt(v)