Files
linea-monorepo/prover/crypto/ringsis/ringsis_test.go
Arya Tabaie fdd84f24e5 refactor: Use new GKR API (#1064)
* use new gkr API

* fix Table pointers

* refactor: remove removed test engine option

* chore: don't initialize struct for interface assertion

* refactor: plonk-in-wizard hardcoded over U64 for now

* refactor: use new gnark-crypto stateless RSis API

* test: disable incompatible tests

* chore: go mod update to PR tip

* chore: dependency against gnark master

* chore: cherry-pick 43141fc13d

* test: cherry pick test from 407d2e25ecfc32f5ed702ab42e5b829d7cabd483

* chore: remove magic values

* chore: update go version in Docker builder to match go.mod

---------

Co-authored-by: Ivo Kubjas <ivo.kubjas@consensys.net>
2025-06-09 14:17:34 +02:00

414 lines
11 KiB
Go

package ringsis
import (
"fmt"
"testing"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
"github.com/consensys/linea-monorepo/prover/maths/common/poly"
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/common/vector"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testCasesKey lists the (Size, Params) combinations shared by the tests of
// this file. Params is embedded, so its fields (e.g. LogTwoBound) are reachable
// directly on each entry.
var testCasesKey = []struct {
	Size int
	Params
}{
	{Size: 2, Params: StdParams},
	{Size: 100, Params: StdParams},
	{Size: 32, Params: StdParams},
	{Size: 5, Params: StdParams},
	{Size: 576, Params: StdParams},
	// Explicitly spelled-out parameter set instead of StdParams.
	{Size: 43, Params: Params{LogTwoBound: 8, LogTwoDegree: 6}},
}
// TestKeyMaxNumFieldHashable checks, for every entry of testCasesKey, that the
// generated key reports a MaxNumFieldHashable() that is at least the requested
// size and exceeds it by strictly less than 8*field.Bytes/LogTwoBound.
func TestKeyMaxNumFieldHashable(t *testing.T) {
	for _, tc := range testCasesKey {
		name := fmt.Sprintf("case-%++v", tc)

		t.Run(name, func(t *testing.T) {
			var (
				// largest allowed gap between the requested size and the
				// capacity reported by the key
				slack       = 8*field.Bytes/tc.LogTwoBound - 1
				key         = GenerateKey(tc.Params, tc.Size)
				maxHashable = key.MaxNumFieldHashable()
			)

			assert.LessOrEqual(t, tc.Size, maxHashable)
			assert.LessOrEqual(t, maxHashable, tc.Size+slack)
		})
	}
}
// TestHashModXnMinusOne compares key.HashModXnMinus1 against a recomputation
// built from poly.Mul and poly.Add over the flattened key, for several limb
// patterns (constant, random) at the maximal number of limbs and at one limb
// fewer.
func TestHashModXnMinusOne(t *testing.T) {

	// check recomputes the hash of limbs from the flattened key and requires
	// it to equal key.HashModXnMinus1(limbs).
	check := func(t *testing.T, key *Key, limbs []field.Element) {

		var (
			expected = key.HashModXnMinus1(limbs)
			acc      = []field.Element{}
			flatKey  = key.FlattenedKey() // accounts for the Montgomery skip
			n        = key.OutputSize()
		)

		// Accumulate the product of each key polynomial with its chunk of
		// limbs. The chunk buffer is freshly allocated, so a short trailing
		// chunk is implicitly zero-padded.
		for polyID := range key.gnarkInternal.A {
			lo, hi := polyID*n, (polyID+1)*n
			chunk := make([]field.Element, n)
			copy(chunk, limbs[lo:])
			acc = poly.Add(acc, poly.Mul(flatKey[lo:hi], chunk))
		}

		// Fold every coefficient of index >= n onto the coefficient n
		// positions below it, then keep the first n coefficients.
		for hi := n; hi < len(acc); hi++ {
			lo := &acc[hi-n]
			lo.Add(lo, &acc[hi])
		}
		acc = acc[:n]

		require.Equal(t, expected, acc)
	}

	// Limb generators, each producing a vector of the requested size.
	ones := func(size int) []field.Element {
		return vector.Repeat(field.One(), size)
	}
	zeroes := func(size int) []field.Element {
		return vector.Repeat(field.Zero(), size)
	}
	randConstant := func(size int) []field.Element {
		var r field.Element
		r.SetRandom()
		return vector.Repeat(r, size)
	}
	fullRand := func(size int) []field.Element {
		return vector.Rand(size)
	}

	// Subtests, in execution order. shortenBy is subtracted from the maximal
	// number of hashable limbs to exercise shorter inputs.
	variants := []struct {
		nameFormat string
		shortenBy  int
		limbs      func(size int) []field.Element
	}{
		{"case-%++v/all-ones", 0, ones},
		{"case-%++v/all-zeroes", 0, zeroes},
		{"case-%++v/rand-constant", 0, randConstant},
		{"case-%++v/full-rand", 0, fullRand},
		// ==== passing shorter vectors
		{"case-%++v/all-ones-shorter", 1, ones},
		{"case-%++v/all-zeroes-shorter", 1, zeroes},
		{"case-%++v/rand-constant-shorter", 1, randConstant},
		{"case-%++v/full-rand-shorter", 1, fullRand},
	}

	for i, tc := range testCasesKey {
		key := GenerateKey(tc.Params, tc.Size)

		for _, variant := range variants {
			t.Run(fmt.Sprintf(variant.nameFormat, i), func(t *testing.T) {
				check(t, &key, variant.limbs(key.maxNumLimbsHashable()-variant.shortenBy))
			})
		}
	}
}
// TestLimbSplit checks that the limbs returned by key.LimbSplit recompose to
// the original field elements: evaluating the limbs of each element as a
// polynomial at 2^LogTwoBound must give back that element.
func TestLimbSplit(t *testing.T) {

	// Element derived from mimc.BlockCompression of two zero elements, then
	// inverted in place.
	randX := mimc.BlockCompression(field.Zero(), field.Zero())
	randX.Inverse(&randX)

	inputs := [][]field.Element{
		{field.One()},
		{field.One(), field.One()},
		{field.NewElement(0), field.NewElement(1), field.NewElement(2), field.NewElement(3)},
		{field.NewElement(0xffffffffffffffff), field.NewFromString("-1"), field.NewElement(8)},
		{randX, field.One(), field.NewFromString("-1")},
	}

	// checkRecomposition splits v into limbs with key and asserts that each
	// group of key.NumLimbs() limbs evaluates back to the matching entry of v.
	// Elements are compared through their string representation.
	checkRecomposition := func(subT *testing.T, v []field.Element, key Key) {
		var (
			base     = field.NewElement(1 << key.LogTwoBound)
			limbs    = key.LimbSplit(v)
			perField = key.NumLimbs()
		)

		for i, want := range v {
			got := poly.EvalUnivariate(limbs[i*perField:(i+1)*perField], base)
			assert.Equal(subT, want.String(), got.String())
		}
	}

	for inputID, input := range inputs {
		for keyID, tc := range testCasesKey {
			key := GenerateKey(tc.Params, tc.Size)
			name := fmt.Sprintf("array-#%v-key-#%v", inputID, keyID)

			t.Run(name, func(t *testing.T) {
				checkRecomposition(t, input, key)
			})
		}
	}
}
// TestHashFromLimbs checks that hashing field elements with key.Hash gives the
// same result as splitting them with key.LimbSplit and hashing the resulting
// limbs with the test-only key.hashFromLimbs. Every parameter set of
// testCasesKey is combined with every test vector; the key is generated for
// the length of the vector at hand.
func TestHashFromLimbs(t *testing.T) {

	// coreTest hashes v both ways with key and asserts the results are equal.
	coreTest := func(t *testing.T, key *Key, v []field.Element) {
		hashed := key.Hash(v)
		limbs := key.LimbSplit(v)
		hashedFromLimbs := key.hashFromLimbs(limbs)
		assert.Equal(t, hashed, hashedFromLimbs)
	}

	testCaseVecs := [][]field.Element{
		{field.One()},
		{field.One(), field.One()},
		{field.NewElement(0), field.NewElement(1), field.NewElement(2), field.NewElement(3)},
		{field.NewElement(0xffffffffffffffff), field.NewFromString("-1"), field.NewElement(8)},
		vector.Rand(4),
		vector.Rand(16),
		vector.Rand(5),
		vector.Rand(17),
	}

	for paramsID, tcParams := range testCasesKey {
		for vecID, tcVec := range testCaseVecs {
			key := GenerateKey(tcParams.Params, len(tcVec))

			// Both indices are printed as "#<index>" so that the subtest name
			// identifies the vector unambiguously (a stray "1" used to be
			// printed in front of the vector index).
			t.Run(fmt.Sprintf("params-#%v-vec-#%v", paramsID, vecID), func(t *testing.T) {
				coreTest(t, &key, tcVec)
			})
		}
	}
}
func TestTransveralHashFromLimbs(t *testing.T) {
testCaseDimensions := []struct {
NumRows, NumCols int
}{
{
NumRows: 4,
NumCols: 16,
},
{
NumRows: 5,
NumCols: 4,
},
{
NumRows: 4,
NumCols: 5,
},
{
NumRows: 4,
NumCols: 18,
},
{
NumRows: 5,
NumCols: 18,
},
{
NumRows: 128,
NumCols: 512,
},
{
NumRows: 128,
NumCols: 512 - 1,
},
{
NumRows: 128,
NumCols: 512 + 1,
},
{
NumRows: 128 - 1,
NumCols: 512,
},
{
NumRows: 128 - 1,
NumCols: 512 - 1,
},
{
NumRows: 128 - 1,
NumCols: 512 + 1,
},
{
NumRows: 128 + 1,
NumCols: 512,
},
{
NumRows: 128 + 1,
NumCols: 512 - 1,
},
{
NumRows: 128 + 1,
NumCols: 512 + 1,
},
}
for _, tcKeyParams := range testCasesKey {
for _, tcDim := range testCaseDimensions {
t.Logf("params-%v-numRow=%v-nCols=%v", tcKeyParams.Size, tcDim.NumRows, tcDim.NumCols)
// t.Run(
// fmt.Sprintf("params-%v-numRow=%v-nCols=%v", pId, tcDim.NumRows, tcDim.NumCols),
// func(t *testing.T) {
assert := require.New(t)
key := GenerateKey(tcKeyParams.Params, tcDim.NumRows)
inputs := make([]smartvectors.SmartVector, 4)
for i := range inputs {
inputs[i] = smartvectors.Rand(16)
}
transposed := make([][]field.Element, 16)
for i := range transposed {
transposed[i] = make([]fr.Element, 4)
for j := range transposed[i] {
transposed[i][j] = inputs[j].Get(i)
}
}
res := key.TransversalHash(inputs)
for i := range transposed {
baseline := key.Hash(transposed[i])
for j := range baseline {
assert.Equal(baseline[j], res[i*key.OutputSize()+j], "transversal hash does not match col hash at %d %d", i, j)
}
// assert.Equal(baseline, res[i*key.OutputSize():(i+1)*key.OutputSize()])
}
// t.FailNow()
// },
// )
}
}
}
// TestHashLimbsFromSlice checks that the two test-only limb hashers agree:
// key.hashFromLimbs and hashLimbsWithSlice (fed with the flattened key) must
// return the same result on the limbs of a random input vector.
func TestHashLimbsFromSlice(t *testing.T) {
	for caseID, tc := range testCasesKey {
		name := fmt.Sprintf("params-%v", caseID)

		t.Run(name, func(t *testing.T) {
			key := GenerateKey(tc.Params, tc.Size)
			inputs := vector.Rand(tc.Size)
			flatKey := key.FlattenedKey()
			limbs := key.LimbSplit(inputs)

			want := key.hashFromLimbs(limbs)
			got := hashLimbsWithSlice(flatKey, limbs, key.gnarkInternal.Domain, key.OutputSize())

			// Compared through vector.Prettify.
			require.Equal(t, vector.Prettify(want), vector.Prettify(got))
		})
	}
}
// hashLimbsWithSlice hashes a vector of limbs against a key given as a flat
// slice of coefficients, `degree` coefficients per key polynomial. It is only
// there for testing purposes.
//
// For each chunk of `degree` limbs (the last chunk being zero-padded when
// incomplete), the chunk and the matching key polynomial are both sent through
// domain.FFT on the coset and multiplied pointwise; the products are summed
// and the sum is brought back with domain.FFTInverse.
//
// The function panics if there are more limbs than key coefficients. Neither
// keySlice nor limbs is modified: the FFTs run on private buffers.
func hashLimbsWithSlice(keySlice []field.Element, limbs []field.Element, domain *fft.Domain, degree int) []field.Element {

	nbPolyUsed := utils.DivCeil(len(limbs), degree)

	if len(limbs) > len(keySlice) {
		// The maximum is the number of key coefficients, and what we have is
		// the number of limbs.
		utils.Panic("Too many inputs max is %v but has %v", len(keySlice), len(limbs))
	}

	var (
		// res accumulates the pointwise products, in the FFT domain
		res = make([]field.Element, degree)
		// scratch buffers, reused for every polynomial since the FFT operates
		// in place
		limbsChunk = make([]field.Element, degree)
		keyChunk   = make([]field.Element, degree)
		tmp        field.Element
	)

	for i := 0; i < nbPolyUsed; i++ {

		// Copy the i-th key polynomial so that the in-place FFT does not
		// overwrite the caller's key.
		copy(keyChunk, keySlice[i*degree:(i+1)*degree])
		domain.FFT(keyChunk, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1))

		// Copy the i-th chunk of limbs. The last chunk may be incomplete, in
		// which case the tail of the buffer still holds the previous
		// iteration's data and must be zeroed. The loop is skipped when the
		// chunk is complete.
		copied := copy(limbsChunk, limbs[i*degree:])
		for j := copied; j < degree; j++ {
			limbsChunk[j].SetZero()
		}
		domain.FFT(limbsChunk, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1))

		for j := range res {
			tmp.Mul(&keyChunk[j], &limbsChunk[j])
			res[j].Add(&res[j], &tmp)
		}
	}

	domain.FFTInverse(res, fft.DIT, fft.OnCoset(), fft.WithNbTasks(1)) // -> reduces mod Xᵈ+1
	return res
}
// hashFromLimbs hashes a vector of limbs given in Montgomery form, using the
// precomputed key.gnarkInternal.Ag. Unoptimized, only there for testing
// purposes.
//
// It panics when the limbs span more polynomials than the key holds.
func (key *Key) hashFromLimbs(limbs []field.Element) []field.Element {

	var (
		outSize = key.OutputSize()
		degree  = key.modulusDegree()
		ag      = key.gnarkInternal.Ag
		domain  = key.gnarkInternal.Domain
	)

	numPolys := utils.DivCeil(len(limbs), outSize)
	if numPolys > len(ag) {
		utils.Panic("Too many inputs max is %v but has %v", len(ag)*outSize, len(limbs))
	}

	var (
		// acc accumulates the hash, in the FFT domain, one polynomial at a
		// time
		acc = make([]field.Element, outSize)
		// buf holds a private copy of the current chunk of limbs; it is
		// allocated once and reused for every chunk because the FFT runs in
		// place
		buf  = make([]field.Element, degree)
		prod field.Element
	)

	for polyID := 0; polyID < numPolys; polyID++ {

		// The last chunk may hold fewer than `degree` limbs. Slicing past the
		// data is fine, but the unused tail of the buffer then still contains
		// stale values and has to be zeroed by hand.
		copied := copy(buf, limbs[polyID*degree:])
		for pad := copied; pad < degree; pad++ {
			buf[pad].SetZero()
		}

		domain.FFT(buf, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1))

		for j := range acc {
			prod.Mul(&buf[j], &ag[polyID][j])
			acc[j].Add(&acc[j], &prod)
		}
	}

	// Since the Ag are normally assumed to work with non-Montgomery limbs
	// (when doing normal hashing), every coordinate goes through
	// field.MulRInv.
	for j := range acc {
		acc[j] = field.MulRInv(acc[j])
	}

	domain.FFTInverse(acc, fft.DIT, fft.OnCoset(), fft.WithNbTasks(1)) // -> reduces mod Xᵈ+1
	return acc
}