smal fixes for safe use of big ints

This commit is contained in:
Lucas Menendez
2025-04-02 13:07:16 +02:00
parent 10be9eebaa
commit a7ce325c13
4 changed files with 70 additions and 54 deletions

8
ff.go
View File

@@ -12,11 +12,5 @@ var (
// BigToFF function returns the finite field representation of the big.Int
// provided. It uses the curve scalar field to represent the provided number.
func BigToFF(baseField, iv *big.Int) *big.Int {
z := big.NewInt(0)
if c := iv.Cmp(baseField); c == 0 {
return z
} else if c != 1 && iv.Cmp(z) != -1 {
return iv
}
return z.Mod(iv, baseField)
return new(big.Int).Mod(iv, baseField)
}

View File

@@ -46,7 +46,7 @@ func (t *Tree) GenerateGnarkVerifierProof(k []byte) (*GnarkVerifierProof, error)
// initialize the GnarkVerifierProof
gp := GnarkVerifierProof{
Root: BytesToBigInt(root),
Key: leafKeyToBigInt(k),
Key: BytesToBigInt(k),
Value: BytesToBigInt(value),
Siblings: bigSiblings,
}

71
hash.go
View File

@@ -66,7 +66,8 @@ type HashFunction interface {
Type() []byte
Len() int
Hash(...[]byte) ([]byte, error)
SafeValue([]byte) ([]byte, error)
SafeValue([]byte) []byte
SafeBigInt(*big.Int) []byte
}
// HashSha256 implements the HashFunction interface for the Sha256 hash
@@ -92,8 +93,12 @@ func (f HashSha256) Hash(b ...[]byte) ([]byte, error) {
return h[:], nil
}
func (f HashSha256) SafeValue(b []byte) ([]byte, error) {
return b, nil
func (f HashSha256) SafeValue(b []byte) []byte {
return b
}
func (f HashSha256) SafeBigInt(b *big.Int) []byte {
return b.Bytes()
}
// HashPoseidon implements the HashFunction interface for the Poseidon hash
@@ -126,8 +131,12 @@ func (f HashPoseidon) Hash(b ...[]byte) ([]byte, error) {
return hB, nil
}
func (f HashPoseidon) SafeValue(b []byte) ([]byte, error) {
return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil
func (f HashPoseidon) SafeValue(b []byte) []byte {
return f.SafeBigInt(new(big.Int).SetBytes(b))
}
func (f HashPoseidon) SafeBigInt(b *big.Int) []byte {
return BigToFF(BN254BaseField, b).Bytes()
}
// HashMultiPoseidon implements the HashFunction interface for the MultiPoseidon hash
@@ -168,8 +177,12 @@ func (f HashMultiPoseidon) Hash(b ...[]byte) ([]byte, error) {
return BigIntToBytes(f.Len(), h), nil
}
func (f HashMultiPoseidon) SafeValue(b []byte) ([]byte, error) {
return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil
func (f HashMultiPoseidon) SafeValue(b []byte) []byte {
return f.SafeBigInt(new(big.Int).SetBytes(b))
}
func (f HashMultiPoseidon) SafeBigInt(b *big.Int) []byte {
return BigToFF(BN254BaseField, b).Bytes()
}
// HashBlake2b implements the HashFunction interface for the Blake2b hash
@@ -199,8 +212,12 @@ func (f HashBlake2b) Hash(b ...[]byte) ([]byte, error) {
return hasher.Sum(nil), nil
}
func (f HashBlake2b) SafeValue(b []byte) ([]byte, error) {
return b, nil
func (f HashBlake2b) SafeValue(b []byte) []byte {
return b
}
func (f HashBlake2b) SafeBigInt(b *big.Int) []byte {
return b.Bytes()
}
// HashMiMC_BLS12_377 implements the HashFunction interface for the MiMC hash
@@ -224,8 +241,12 @@ func (f HashMiMC_BLS12_377) Hash(b ...[]byte) ([]byte, error) {
return hashMiMCbyChunks(h, q, b...)
}
func (f HashMiMC_BLS12_377) SafeValue(b []byte) ([]byte, error) {
return BigToFF(BLS12377BaseField, new(big.Int).SetBytes(b)).Bytes(), nil
func (f HashMiMC_BLS12_377) SafeValue(b []byte) []byte {
return f.SafeBigInt(new(big.Int).SetBytes(b))
}
func (f HashMiMC_BLS12_377) SafeBigInt(b *big.Int) []byte {
return BigToFF(BLS12377BaseField, b).Bytes()
}
// HashMiMC_BN254 implements the HashFunction interface for the MiMC hash
@@ -248,16 +269,24 @@ func (f HashMiMC_BN254) Hash(b ...[]byte) ([]byte, error) {
// h := mimc_bn254.NewMiMC()
// return hashMiMCbyChunks(h, q, b...)
h := mimc_bn254.NewMiMC()
var fullBytes []byte
for _, input := range b {
if _, err := h.Write(input); err != nil {
return nil, err
}
fullBytes = append(fullBytes, input...)
}
for start := 0; start < len(fullBytes); start += h.BlockSize() {
end := min(start+h.BlockSize(), len(fullBytes))
chunk := fullBytes[start:end]
h.Write(chunk)
}
return h.Sum(nil), nil
}
func (f HashMiMC_BN254) SafeValue(b []byte) ([]byte, error) {
return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil
func (f HashMiMC_BN254) SafeValue(b []byte) []byte {
return f.SafeBigInt(new(big.Int).SetBytes(b))
}
func (f HashMiMC_BN254) SafeBigInt(b *big.Int) []byte {
return BigToFF(BN254BaseField, b).Bytes()
}
// hashMiMCbyChunks is a helper function to hash by chunks using the MiMC hash.
@@ -319,7 +348,7 @@ func (f HashMiMC7) Len() int {
func (f HashMiMC7) Hash(b ...[]byte) ([]byte, error) {
var toHash []*big.Int
for _, i := range b {
toHash = append(toHash, new(big.Int).SetBytes(SwapEndianness(i)))
toHash = append(toHash, BytesToBigInt(i))
}
h, err := mimc7.Hash(toHash, nil)
if err != nil {
@@ -328,6 +357,10 @@ func (f HashMiMC7) Hash(b ...[]byte) ([]byte, error) {
return BigIntToBytes(f.Len(), h), nil
}
func (f HashMiMC7) SafeValue(b []byte) ([]byte, error) {
return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil
func (f HashMiMC7) SafeValue(b []byte) []byte {
return BigToFF(BN254BaseField, BytesToBigInt(b)).Bytes()
}
func (f HashMiMC7) SafeBigInt(b *big.Int) []byte {
return f.SafeValue(b.Bytes())
}

View File

@@ -3,7 +3,6 @@ package arbo
import (
"bytes"
"fmt"
"log"
"math/big"
"slices"
)
@@ -23,7 +22,7 @@ func (t *Tree) AddBatchBigInt(k []*big.Int, v [][]*big.Int) ([]Invalid, error) {
bvs := make([][]byte, len(k))
fbvs := make([][]byte, len(k))
for i, ki := range k {
bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), t.maxKeyLen(), ki, v[i])
bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), ki, v[i])
if err != nil {
return nil, err
}
@@ -58,9 +57,8 @@ func (t *Tree) AddBigInt(k *big.Int, v ...*big.Int) error {
return fmt.Errorf("key cannot be nil")
}
// convert the big ints to bytes
bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.maxKeyLen(), k, v)
bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, v)
if err != nil {
log.Println(err, k, v)
return err
}
// add it to the tree
@@ -89,7 +87,7 @@ func (t *Tree) UpdateBigInt(k *big.Int, value ...*big.Int) error {
return fmt.Errorf("key cannot be nil")
}
// convert the big ints to bytes
bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.maxKeyLen(), k, value)
bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, value)
if err != nil {
return err
}
@@ -118,7 +116,8 @@ func (t *Tree) GetBigInt(k *big.Int) (*big.Int, []*big.Int, error) {
if k == nil {
return nil, nil, fmt.Errorf("key cannot be nil")
}
bk, bv, err := t.Get(bigIntToLeafKey(t.maxKeyLen(), k))
bk := t.HashFunction().SafeBigInt(k)
_, bv, err := t.Get(bk)
if err != nil {
return nil, nil, err
}
@@ -137,7 +136,7 @@ func (t *Tree) GenProofBigInts(k *big.Int) ([]byte, []byte, []byte, bool, error)
if k == nil {
return nil, nil, nil, false, fmt.Errorf("key cannot be nil")
}
return t.GenProof(bigIntToLeafKey(t.maxKeyLen(), k))
return t.GenProof(t.HashFunction().SafeBigInt(k))
}
// GenerateCircomVerifierProofBigInt generates a CircomVerifierProof for a key
@@ -148,7 +147,7 @@ func (t *Tree) GenerateCircomVerifierProofBigInt(k *big.Int) (*CircomVerifierPro
if k == nil {
return nil, fmt.Errorf("key cannot be nil")
}
return t.GenerateCircomVerifierProof(bigIntToLeafKey(t.maxKeyLen(), k))
return t.GenerateCircomVerifierProof(t.HashFunction().SafeBigInt(k))
}
// GenerateGnarkVerifierProofBigInt generates a GnarkVerifierProof for a key
@@ -159,12 +158,7 @@ func (t *Tree) GenerateGnarkVerifierProofBigInt(k *big.Int) (*GnarkVerifierProof
if k == nil {
return nil, fmt.Errorf("key cannot be nil")
}
return t.GenerateGnarkVerifierProof(bigIntToLeafKey(t.maxKeyLen(), k))
}
// maxKeyLen returns the maximum length of the key in bytes for a tree
func (t *Tree) maxKeyLen() int {
return keyLenByLevels(t.maxLevels)
return t.GenerateGnarkVerifierProof(t.HashFunction().SafeBigInt(k))
}
// leafToBigInts converts the bytes of the key and the value of a leaf node
@@ -194,17 +188,12 @@ func (t *Tree) leafToBigInts(key, value, fullValue []byte) (*big.Int, []*big.Int
return nil, nil, fmt.Errorf("LeafToBigInt: encodedValues != value")
}
// convert the bytes of the key to a big.Int
return BytesToBigInt(key), values, nil
}
// BigIntToBytes converts a big.Int into a byte slice of length keyLen
func bigIntToLeafKey(keyLen int, biKey *big.Int) []byte {
return BigIntToBytes(keyLen, biKey)
return leafKeyToBigInt(key), values, nil
}
// leafKeyToBigInt converts the bytes of a key into a big.Int
func leafKeyToBigInt(key []byte) *big.Int {
return BytesToBigInt(key)
return new(big.Int).SetBytes(key)
}
// valuesToFullValue converts a slice of big.Int values into the bytes of the
@@ -245,12 +234,12 @@ func fullValueToValues(fullValue []byte) []*big.Int {
// encodeBigIntData converts a big.Int key and a slice of big.Int values into the
// bytes of the key, the bytes of the value used to build the tree and the
// bytes of the full value encoded
func encodeBigIntData(hFn HashFunction, keyLen int, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) {
func encodeBigIntData(hFn HashFunction, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) {
if key == nil {
return nil, nil, nil, fmt.Errorf("key cannot be nil")
}
// calculate the bytes of the key
bKey := bigIntToLeafKey(keyLen, key)
bKey := hFn.SafeBigInt(key)
// calculate the bytes of the full values (should be reversible)
bFullValue, err := valuesToFullValue(values)
if err != nil {
@@ -271,11 +260,11 @@ func encodeBigIntValues(hFn HashFunction, values ...*big.Int) ([]byte, error) {
chunks := make([][]byte, len(values))
for _, v := range values {
// truncate the value if it exceeds the maximum chunk bytes
value, err := hFn.SafeValue(v.Bytes())
if err != nil {
return nil, err
value := hFn.SafeBigInt(v)
if value == nil {
return nil, fmt.Errorf("value cannot be nil")
}
chunks = append(chunks, value)
chunks = append(chunks, SwapEndianness(value))
}
return hFn.Hash(chunks...)
}