diff --git a/ff.go b/ff.go index 444e985..504352a 100644 --- a/ff.go +++ b/ff.go @@ -12,11 +12,5 @@ var ( // BigToFF function returns the finite field representation of the big.Int // provided. It uses the curve scalar field to represent the provided number. func BigToFF(baseField, iv *big.Int) *big.Int { - z := big.NewInt(0) - if c := iv.Cmp(baseField); c == 0 { - return z - } else if c != 1 && iv.Cmp(z) != -1 { - return iv - } - return z.Mod(iv, baseField) + return new(big.Int).Mod(iv, baseField) } diff --git a/gnark.go b/gnark.go index 8d8cefb..9bd35af 100644 --- a/gnark.go +++ b/gnark.go @@ -46,7 +46,7 @@ func (t *Tree) GenerateGnarkVerifierProof(k []byte) (*GnarkVerifierProof, error) // initialize the GnarkVerifierProof gp := GnarkVerifierProof{ Root: BytesToBigInt(root), - Key: leafKeyToBigInt(k), + Key: BytesToBigInt(k), Value: BytesToBigInt(value), Siblings: bigSiblings, } diff --git a/hash.go b/hash.go index efaa583..d3c8cd1 100644 --- a/hash.go +++ b/hash.go @@ -66,7 +66,8 @@ type HashFunction interface { Type() []byte Len() int Hash(...[]byte) ([]byte, error) - SafeValue([]byte) ([]byte, error) + SafeValue([]byte) []byte + SafeBigInt(*big.Int) []byte } // HashSha256 implements the HashFunction interface for the Sha256 hash @@ -92,8 +93,12 @@ func (f HashSha256) Hash(b ...[]byte) ([]byte, error) { return h[:], nil } -func (f HashSha256) SafeValue(b []byte) ([]byte, error) { - return b, nil +func (f HashSha256) SafeValue(b []byte) []byte { + return b +} + +func (f HashSha256) SafeBigInt(b *big.Int) []byte { + return b.Bytes() } // HashPoseidon implements the HashFunction interface for the Poseidon hash @@ -126,8 +131,12 @@ func (f HashPoseidon) Hash(b ...[]byte) ([]byte, error) { return hB, nil } -func (f HashPoseidon) SafeValue(b []byte) ([]byte, error) { - return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil +func (f HashPoseidon) SafeValue(b []byte) []byte { + return f.SafeBigInt(new(big.Int).SetBytes(b)) +} + +func (f HashPoseidon) SafeBigInt(b *big.Int) []byte { + return BigToFF(BN254BaseField, b).Bytes() } // HashMultiPoseidon implements the HashFunction interface for the MultiPoseidon hash @@ -168,8 +177,12 @@ func (f HashMultiPoseidon) Hash(b ...[]byte) ([]byte, error) { return BigIntToBytes(f.Len(), h), nil } -func (f HashMultiPoseidon) SafeValue(b []byte) ([]byte, error) { - return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil +func (f HashMultiPoseidon) SafeValue(b []byte) []byte { + return f.SafeBigInt(new(big.Int).SetBytes(b)) +} + +func (f HashMultiPoseidon) SafeBigInt(b *big.Int) []byte { + return BigToFF(BN254BaseField, b).Bytes() } // HashBlake2b implements the HashFunction interface for the Blake2b hash @@ -199,8 +212,12 @@ func (f HashBlake2b) Hash(b ...[]byte) ([]byte, error) { return hasher.Sum(nil), nil } -func (f HashBlake2b) SafeValue(b []byte) ([]byte, error) { - return b, nil +func (f HashBlake2b) SafeValue(b []byte) []byte { + return b +} + +func (f HashBlake2b) SafeBigInt(b *big.Int) []byte { + return b.Bytes() } // HashMiMC_BLS12_377 implements the HashFunction interface for the MiMC hash @@ -224,8 +241,12 @@ func (f HashMiMC_BLS12_377) Hash(b ...[]byte) ([]byte, error) { return hashMiMCbyChunks(h, q, b...) } -func (f HashMiMC_BLS12_377) SafeValue(b []byte) ([]byte, error) { - return BigToFF(BLS12377BaseField, new(big.Int).SetBytes(b)).Bytes(), nil +func (f HashMiMC_BLS12_377) SafeValue(b []byte) []byte { + return f.SafeBigInt(new(big.Int).SetBytes(b)) +} + +func (f HashMiMC_BLS12_377) SafeBigInt(b *big.Int) []byte { + return BigToFF(BLS12377BaseField, b).Bytes() } // HashMiMC_BN254 implements the HashFunction interface for the MiMC hash @@ -248,16 +269,24 @@ func (f HashMiMC_BN254) Hash(b ...[]byte) ([]byte, error) { // h := mimc_bn254.NewMiMC() // return hashMiMCbyChunks(h, q, b...) h := mimc_bn254.NewMiMC() + var fullBytes []byte for _, input := range b { - if _, err := h.Write(input); err != nil { - return nil, err - } + fullBytes = append(fullBytes, input...) + } + for start := 0; start < len(fullBytes); start += h.BlockSize() { + end := min(start+h.BlockSize(), len(fullBytes)) + chunk := fullBytes[start:end] + h.Write(chunk) } return h.Sum(nil), nil } -func (f HashMiMC_BN254) SafeValue(b []byte) ([]byte, error) { - return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil +func (f HashMiMC_BN254) SafeValue(b []byte) []byte { + return f.SafeBigInt(new(big.Int).SetBytes(b)) +} + +func (f HashMiMC_BN254) SafeBigInt(b *big.Int) []byte { + return BigToFF(BN254BaseField, b).Bytes() } // hashMiMCbyChunks is a helper function to hash by chunks using the MiMC hash. @@ -319,7 +348,7 @@ func (f HashMiMC7) Len() int { func (f HashMiMC7) Hash(b ...[]byte) ([]byte, error) { var toHash []*big.Int for _, i := range b { - toHash = append(toHash, new(big.Int).SetBytes(SwapEndianness(i))) + toHash = append(toHash, BytesToBigInt(i)) } h, err := mimc7.Hash(toHash, nil) if err != nil { @@ -328,6 +357,10 @@ func (f HashMiMC7) Hash(b ...[]byte) ([]byte, error) { return BigIntToBytes(f.Len(), h), nil } -func (f HashMiMC7) SafeValue(b []byte) ([]byte, error) { - return BigToFF(BN254BaseField, new(big.Int).SetBytes(b)).Bytes(), nil +func (f HashMiMC7) SafeValue(b []byte) []byte { + return BigToFF(BN254BaseField, BytesToBigInt(b)).Bytes() +} + +func (f HashMiMC7) SafeBigInt(b *big.Int) []byte { + return f.SafeValue(b.Bytes()) } diff --git a/tree_big.go b/tree_big.go index d27483a..ac9648a 100644 --- a/tree_big.go +++ b/tree_big.go @@ -3,7 +3,6 @@ package arbo import ( "bytes" "fmt" - "log" "math/big" "slices" ) @@ -23,7 +22,7 @@ func (t *Tree) AddBatchBigInt(k []*big.Int, v [][]*big.Int) ([]Invalid, error) { bvs := make([][]byte, len(k)) fbvs := make([][]byte, len(k)) for i, ki := range k { - bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), t.maxKeyLen(), ki, v[i]) + bks[i], bvs[i], fbvs[i], err = encodeBigIntData(t.HashFunction(), ki, v[i]) if err != nil { return nil, err } @@ -58,9 +57,8 @@ func (t *Tree) AddBigInt(k *big.Int, v ...*big.Int) error { return fmt.Errorf("key cannot be nil") } // convert the big ints to bytes - bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.maxKeyLen(), k, v) + bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, v) if err != nil { - log.Println(err, k, v) return err } // add it to the tree @@ -89,7 +87,7 @@ func (t *Tree) UpdateBigInt(k *big.Int, value ...*big.Int) error { return fmt.Errorf("key cannot be nil") } // convert the big ints to bytes - bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), t.maxKeyLen(), k, value) + bk, bv, fbv, err := encodeBigIntData(t.HashFunction(), k, value) if err != nil { return err } @@ -118,7 +116,8 @@ func (t *Tree) GetBigInt(k *big.Int) (*big.Int, []*big.Int, error) { if k == nil { return nil, nil, fmt.Errorf("key cannot be nil") } - bk, bv, err := t.Get(bigIntToLeafKey(t.maxKeyLen(), k)) + bk := t.HashFunction().SafeBigInt(k) + _, bv, err := t.Get(bk) if err != nil { return nil, nil, err } @@ -137,7 +136,7 @@ func (t *Tree) GenProofBigInts(k *big.Int) ([]byte, []byte, []byte, bool, error) if k == nil { return nil, nil, nil, false, fmt.Errorf("key cannot be nil") } - return t.GenProof(bigIntToLeafKey(t.maxKeyLen(), k)) + return t.GenProof(t.HashFunction().SafeBigInt(k)) } // GenerateCircomVerifierProofBigInt generates a CircomVerifierProof for a key @@ -148,7 +147,7 @@ func (t *Tree) GenerateCircomVerifierProofBigInt(k *big.Int) (*CircomVerifierPro if k == nil { return nil, fmt.Errorf("key cannot be nil") } - return t.GenerateCircomVerifierProof(bigIntToLeafKey(t.maxKeyLen(), k)) + return t.GenerateCircomVerifierProof(t.HashFunction().SafeBigInt(k)) } // GenerateGnarkVerifierProofBigInt generates a GnarkVerifierProof for a key @@ -159,12 +158,7 @@ func (t *Tree) GenerateGnarkVerifierProofBigInt(k *big.Int) (*GnarkVerifierProof if k == nil { return nil, fmt.Errorf("key cannot be nil") } - return t.GenerateGnarkVerifierProof(bigIntToLeafKey(t.maxKeyLen(), k)) -} - -// maxKeyLen returns the maximum length of the key in bytes for a tree -func (t *Tree) maxKeyLen() int { - return keyLenByLevels(t.maxLevels) + return t.GenerateGnarkVerifierProof(t.HashFunction().SafeBigInt(k)) } // leafToBigInts converts the bytes of the key and the value of a leaf node @@ -194,17 +188,12 @@ func (t *Tree) leafToBigInts(key, value, fullValue []byte) (*big.Int, []*big.Int return nil, nil, fmt.Errorf("LeafToBigInt: encodedValues != value") } // convert the bytes of the key to a big.Int - return BytesToBigInt(key), values, nil -} - -// BigIntToBytes converts a big.Int into a byte slice of length keyLen -func bigIntToLeafKey(keyLen int, biKey *big.Int) []byte { - return BigIntToBytes(keyLen, biKey) + return leafKeyToBigInt(key), values, nil } // leafKeyToBigInt converts the bytes of a key into a big.Int func leafKeyToBigInt(key []byte) *big.Int { - return BytesToBigInt(key) + return new(big.Int).SetBytes(key) } // valuesToFullValue converts a slice of big.Int values into the bytes of the @@ -245,12 +234,12 @@ func fullValueToValues(fullValue []byte) []*big.Int { // encodeBigIntData converts a big.Int key and a slice of big.Int values into the // bytes of the key, the bytes of the value used to build the tree and the // bytes of the full value encoded -func encodeBigIntData(hFn HashFunction, keyLen int, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) { +func encodeBigIntData(hFn HashFunction, key *big.Int, values []*big.Int) ([]byte, []byte, []byte, error) { if key == nil { return nil, nil, nil, fmt.Errorf("key cannot be nil") } // calculate the bytes of the key - bKey := bigIntToLeafKey(keyLen, key) + bKey := hFn.SafeBigInt(key) // calculate the bytes of the full values (should be reversible) bFullValue, err := valuesToFullValue(values) if err != nil { @@ -271,11 +260,11 @@ func encodeBigIntValues(hFn HashFunction, values ...*big.Int) ([]byte, error) { chunks := make([][]byte, len(values)) for _, v := range values { // truncate the value if it exceeds the maximum chunk bytes - value, err := hFn.SafeValue(v.Bytes()) - if err != nil { - return nil, err + value := hFn.SafeBigInt(v) + if value == nil { + return nil, fmt.Errorf("value cannot be nil") } - chunks = append(chunks, value) + chunks = append(chunks, SwapEndianness(value)) } return hFn.Hash(chunks...) }