diff --git a/hash.go b/hash.go index bcd5585..1698b40 100644 --- a/hash.go +++ b/hash.go @@ -1,7 +1,9 @@ package arbo import ( + "bytes" "crypto/sha256" + "fmt" "math/big" fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -387,55 +389,54 @@ func (f HashMiMC7) SafeBigInt(b *big.Int) []byte { } // ──────────────────────────────────────────────────────────────────────────────── -// Poseidon2 implementation (BN254, Merkle–Damgård, 32-byte digest) +// Poseidon-2 (BN254) – Merkle–Damgård, width-2, 32-byte digest // ──────────────────────────────────────────────────────────────────────────────── -// HashPoseidon2 implements the HashFunction interface using the -// gnark-crypto Poseidon2 permutation in Merkle-Damgård mode -// (width = 2, rF = 6, rP = 50). type HashPoseidon2 struct{} -// Type returns the function label. func (HashPoseidon2) Type() []byte { return TypeHashPoseidon2 } +func (HashPoseidon2) Len() int { return 32 } -// Len returns the byte length of the digest (one BN254 field element). -func (HashPoseidon2) Len() int { return 32 } +var perm2 = poseidon2.NewPermutation(2 /*t*/, 6 /*rF*/, 50 /*rP*/) -// Hash concatenates the inputs and hashes them with Poseidon2. -// Each Write call handles its own padding; no chunking is required here. -// Normalizes the order of inputs to ensure consistent hashing regardless -// of the order in which values are provided. -func (f HashPoseidon2) Hash(b ...[]byte) ([]byte, error) { - // For Merkle intermediate nodes, we have exactly 2 inputs - if len(b) == 2 { - // Convert both inputs to big.Int for deterministic comparison - bigA := new(big.Int).SetBytes(b[0]) - bigB := new(big.Int).SetBytes(b[1]) - - // Compare and consistently order the inputs - if bigA.Cmp(bigB) > 0 { - // Swap if a > b to ensure consistent ordering - b[0], b[1] = b[1], b[0] - } +func (h HashPoseidon2) Hash(inp ...[]byte) ([]byte, error) { + // accept 2 (internal) or 3 (leaf) elements + if len(inp) != 2 && len(inp) != 3 { + return nil, fmt.Errorf("Poseidon-2: need 2 or 3 inputs, got %d", len(inp)) + } + // canonicalise each limb + safe := make([][]byte, len(inp)) + for i, b := range inp { + safe[i] = h.SafeBigInt(new(big.Int).SetBytes(b)) } - h := poseidon2.NewMerkleDamgardHasher() // implements hash.Hash - for _, in := range b { - if _, err := h.Write(f.SafeValue(in)); err != nil { + // internal node ⇒ order min‖max + if len(inp) == 2 && bytes.Compare(safe[0], safe[1]) > 0 { + safe[0], safe[1] = safe[1], safe[0] + } + + md := hash.NewMerkleDamgardHasher(perm2, make([]byte, 32)) // IV = 0 + for _, b := range safe { + if _, err := md.Write(b); err != nil { return nil, err } } - return h.Sum(nil), nil // 32-byte digest + return md.Sum(nil), nil } -// SafeValue converts an arbitrary little-endian byte slice to a -// canonical BN254 field-element encoding. -func (f HashPoseidon2) SafeValue(b []byte) []byte { - return f.SafeBigInt(new(big.Int).SetBytes(b)) +// helpers ---------------------------------------------------------------------- +func (HashPoseidon2) SafeValue(x []byte) []byte { + return BigIntToFFwithPadding(new(big.Int).SetBytes(x), BN254BaseField) } -// SafeBigInt converts an arbitrary big.Int into a byte slice that is a -// valid BN254 base-field element in little-endian form. -func (HashPoseidon2) SafeBigInt(b *big.Int) []byte { - return ExplicitZero(BigToFF(BN254BaseField, b).Bytes()) +func (HashPoseidon2) SafeBigInt(x *big.Int) []byte { + return BigIntToFFwithPadding(x, BN254BaseField) +} + +func BigIntToFFwithPadding(x, modulus *big.Int) []byte { + b := BigToFF(modulus, x).Bytes() + for len(b) < 32 { + b = append([]byte{0}, b...) + } + return b } diff --git a/hash_test.go b/hash_test.go index 43aec20..b3a98a3 100644 --- a/hash_test.go +++ b/hash_test.go @@ -67,3 +67,16 @@ func TestHashMiMC(t *testing.T) { qt.Equals, "f881f34991492d823e02565c778b824bac5eacef6340b70ee90a8966a2e63900") } + +func TestHashPoseidon2(t *testing.T) { + // Poseidon hash + hashFunc := &HashPoseidon2{} + bLen := hashFunc.Len() + h, err := hashFunc.Hash( + BigIntToBytes(bLen, big.NewInt(1)), + BigIntToBytes(bLen, big.NewInt(2))) + if err != nil { + t.Fatal(err) + } + t.Logf("hash: %x", h) +}