mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-08 23:17:58 -05:00
Feat Setup checks equality between KZG Vk from circuit Vk and from SRS (#325)
* feat test vk consistency * test write-read setup * feat readfromfile-writetofile --------- Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
This commit is contained in:
@@ -3,18 +3,17 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/frontend/cs/scs"
|
||||
v1 "github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v1"
|
||||
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const maxNbConstraints = 1 << 27
|
||||
|
||||
func nbConstraints(blobSize int) int {
|
||||
|
||||
fmt.Printf("*********************\nfor blob of size %dB or %.2fKB:\n", blobSize, float32(blobSize)/1024)
|
||||
c := v1.Circuit{
|
||||
BlobBytes: make([]frontend.Variable, 32*4096),
|
||||
@@ -22,6 +21,7 @@ func nbConstraints(blobSize int) int {
|
||||
MaxBlobPayloadNbBytes: blobSize,
|
||||
UseGkrMiMC: true,
|
||||
}
|
||||
runtime.GC()
|
||||
if cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &c, frontend.WithCapacity(maxNbConstraints*6/5)); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
|
||||
@@ -2,10 +2,10 @@ package circuits
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
@@ -192,23 +192,37 @@ func LoadSetup(cfg *config.Config, circuitID CircuitID) (Setup, error) {
|
||||
return Setup{}, fmt.Errorf("fetching SRS: %w", err)
|
||||
}
|
||||
pk := plonk.NewProvingKey(curveID)
|
||||
var kzgVkFromVk, kzgVkFromSrs io.WriterTo
|
||||
switch pk := pk.(type) {
|
||||
case *plonk_bn254.ProvingKey:
|
||||
pk.Vk = vk.(*plonk_bn254.VerifyingKey)
|
||||
pk.Kzg = srsCanonical.(*kzg254.SRS).Pk
|
||||
srsC := srsCanonical.(*kzg254.SRS)
|
||||
pk.Kzg = srsC.Pk
|
||||
pk.KzgLagrange = srsLagrange.(*kzg254.SRS).Pk
|
||||
kzgVkFromVk = &pk.Vk.Kzg
|
||||
kzgVkFromSrs = &srsC.Vk
|
||||
case *plonk_bls12377.ProvingKey:
|
||||
pk.Vk = vk.(*plonk_bls12377.VerifyingKey)
|
||||
pk.Kzg = srsCanonical.(*kzg377.SRS).Pk
|
||||
srsC := srsCanonical.(*kzg377.SRS)
|
||||
pk.Kzg = srsC.Pk
|
||||
pk.KzgLagrange = srsLagrange.(*kzg377.SRS).Pk
|
||||
kzgVkFromVk = &pk.Vk.Kzg
|
||||
kzgVkFromSrs = &srsC.Vk
|
||||
case *plonk_bw6761.ProvingKey:
|
||||
pk.Vk = vk.(*plonk_bw6761.VerifyingKey)
|
||||
pk.Kzg = srsCanonical.(*kzgbw6.SRS).Pk
|
||||
srsC := srsCanonical.(*kzgbw6.SRS)
|
||||
pk.Kzg = srsC.Pk
|
||||
pk.KzgLagrange = srsLagrange.(*kzgbw6.SRS).Pk
|
||||
kzgVkFromVk = &pk.Vk.Kzg
|
||||
kzgVkFromSrs = &srsC.Vk
|
||||
default:
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
if err = utils.WriterstoEqual(kzgVkFromSrs, kzgVkFromVk); err != nil {
|
||||
return Setup{}, fmt.Errorf("verifying key <> SRS mismatch: %w", err)
|
||||
}
|
||||
|
||||
return Setup{
|
||||
Manifest: *manifest,
|
||||
Circuit: circuit,
|
||||
@@ -272,12 +286,13 @@ func readFromFile(path string, into any) error {
|
||||
panic(fmt.Sprintf("unsupported type %T", into))
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening %q: %w", path, err)
|
||||
}
|
||||
|
||||
if _, err = rFunc(bytes.NewReader(data)); err != nil {
|
||||
_, err = rFunc(f)
|
||||
if err = errors.Join(err, f.Close()); err != nil {
|
||||
return fmt.Errorf("reading %q from disk: %w", path, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
)
|
||||
|
||||
// SetupManifest is the human readable manifest of the assets generated by the prover setup command
|
||||
// SetupManifest is the human-readable manifest of the assets generated by the prover setup command
|
||||
type SetupManifest struct {
|
||||
CircuitName string `json:"circuitName"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
74
prover/circuits/setup_test.go
Normal file
74
prover/circuits/setup_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark-crypto/kzg"
|
||||
"github.com/consensys/gnark/backend/plonk"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/frontend/cs/scs"
|
||||
"github.com/consensys/gnark/test/unsafekzg"
|
||||
"github.com/consensys/linea-monorepo/prover/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadSetup(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := config.Config{
|
||||
AssetsDir: dir,
|
||||
}
|
||||
cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit{make([]frontend.Variable, 1)})
|
||||
require.NoError(t, err)
|
||||
|
||||
canonicalSize, lagrangeSize := plonk.SRSSize(cs)
|
||||
canonical, lagrange, err := unsafekzg.NewSRS(cs)
|
||||
require.NoError(t, err)
|
||||
const srsFilenameTemplate = "kzg_srs_%s_%d_bn254_aleo.memdump"
|
||||
srsDir := filepath.Join(dir, "kzgsrs")
|
||||
require.NoError(t, os.Mkdir(srsDir, 0700))
|
||||
dumpToFile(t, canonical, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "canonical", canonicalSize)))
|
||||
dumpToFile(t, lagrange, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "lagrange", lagrangeSize)))
|
||||
|
||||
const circuitName = "test"
|
||||
|
||||
srsProvider := NewUnsafeSRSProvider()
|
||||
setup, err := MakeSetup(context.TODO(), circuitName, cs, srsProvider, map[string]any{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, setup.WriteTo(filepath.Join(dir, circuitName)))
|
||||
|
||||
_, err = LoadSetup(&cfg, circuitName)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type circuit struct {
|
||||
// this is a dummy circuit that does nothing
|
||||
// it is used to generate the SRS
|
||||
Input []frontend.Variable `gnark:",public"`
|
||||
}
|
||||
|
||||
func (circuit *circuit) Define(api frontend.API) error {
|
||||
c, err := api.(frontend.Committer).Commit(circuit.Input...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prod := frontend.Variable(1)
|
||||
for _, x := range circuit.Input {
|
||||
prod = api.Mul(prod, x)
|
||||
}
|
||||
api.AssertIsDifferent(c, prod)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dumpToFile(t *testing.T, o kzg.Serializable, path string) {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = errors.Join(o.WriteDump(f), f.Close())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -52,100 +51,12 @@ func RandIntSliceN(length, n int) []int {
|
||||
return res
|
||||
}
|
||||
|
||||
type BytesEqualError struct {
|
||||
Index int
|
||||
error string
|
||||
}
|
||||
|
||||
func (e *BytesEqualError) Error() string {
|
||||
return e.error
|
||||
}
|
||||
|
||||
func LoadJson(t *testing.T, path string, v any) {
|
||||
in, err := os.Open(path)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, json.NewDecoder(in).Decode(v))
|
||||
}
|
||||
|
||||
// BytesEqual between byte slices a,b
|
||||
// a readable error message would show in case of inequality
|
||||
// TODO error options: block size, check forwards or backwards etc
|
||||
func BytesEqual(expected, actual []byte) error {
|
||||
l := min(len(expected), len(actual))
|
||||
|
||||
failure := 0
|
||||
for failure < l {
|
||||
if expected[failure] != actual[failure] {
|
||||
break
|
||||
}
|
||||
failure++
|
||||
}
|
||||
|
||||
if len(expected) == len(actual) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// there is a mismatch
|
||||
var sb strings.Builder
|
||||
|
||||
const (
|
||||
radius = 40
|
||||
blockSize = 32
|
||||
)
|
||||
|
||||
printCentered := func(b []byte) {
|
||||
|
||||
for i := max(failure-radius, 0); i <= failure+radius; i++ {
|
||||
if i%blockSize == 0 && i != failure-radius {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
if i >= 0 && i < len(b) {
|
||||
sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once
|
||||
} else {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure))
|
||||
|
||||
sb.WriteString("expected: ")
|
||||
printCentered(expected)
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString("actual: ")
|
||||
printCentered(actual)
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString(" ")
|
||||
for i := max(failure-radius, 0); i <= failure+radius; {
|
||||
if i%blockSize == 0 && i != failure-radius {
|
||||
s := strconv.Itoa(i)
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(s)
|
||||
i += len(s) / 2
|
||||
if len(s)%2 != 0 {
|
||||
sb.WriteString(" ")
|
||||
i++
|
||||
}
|
||||
} else {
|
||||
if i == failure {
|
||||
sb.WriteString("^^")
|
||||
} else {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
return &BytesEqualError{
|
||||
Index: failure,
|
||||
error: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
func SlicesEqual[T any](expected, actual []T) error {
|
||||
if l1, l2 := len(expected), len(actual); l1 != l2 {
|
||||
return fmt.Errorf("length mismatch %d≠%d", l1, l2)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"golang.org/x/exp/constraints"
|
||||
@@ -317,3 +320,126 @@ func WriteToJSON(path string, v interface{}) error {
|
||||
defer f.Close()
|
||||
return json.NewEncoder(f).Encode(v)
|
||||
}
|
||||
|
||||
func WriterstoEqual(expected, actual io.WriterTo) error {
|
||||
var bb bytes.Buffer
|
||||
if _, err := expected.WriteTo(&bb); err != nil {
|
||||
return err
|
||||
}
|
||||
ab := bb.Bytes()
|
||||
bb.Reset()
|
||||
if _, err := actual.WriteTo(&bb); err != nil {
|
||||
return err
|
||||
}
|
||||
return BytesEqual(ab, bb.Bytes())
|
||||
}
|
||||
|
||||
// BytesEqual between byte slices a,b
|
||||
// a readable error message would show in case of inequality
|
||||
// TODO error options: block size, check forwards or backwards etc
|
||||
func BytesEqual(expected, actual []byte) error {
|
||||
if bytes.Equal(expected, actual) {
|
||||
return nil // equality fast path
|
||||
}
|
||||
|
||||
l := min(len(expected), len(actual))
|
||||
|
||||
failure := 0
|
||||
for failure < l {
|
||||
if expected[failure] != actual[failure] {
|
||||
break
|
||||
}
|
||||
failure++
|
||||
}
|
||||
|
||||
if len(expected) == len(actual) && failure == l {
|
||||
panic("bytes.Equal returned false, but could not find a mismatch")
|
||||
}
|
||||
|
||||
// there is a mismatch
|
||||
var sb strings.Builder
|
||||
|
||||
const (
|
||||
radius = 40
|
||||
blockSize = 32
|
||||
)
|
||||
|
||||
printCentered := func(b []byte) {
|
||||
|
||||
for i := max(failure-radius, 0); i <= failure+radius; i++ {
|
||||
if i%blockSize == 0 && i != failure-radius {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
if i >= 0 && i < len(b) {
|
||||
sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once
|
||||
} else {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure))
|
||||
|
||||
sb.WriteString("expected: ")
|
||||
printCentered(expected)
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString("actual: ")
|
||||
printCentered(actual)
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString(" ")
|
||||
for i := max(failure-radius, 0); i <= failure+radius; {
|
||||
if i%blockSize == 0 && i != failure-radius {
|
||||
s := strconv.Itoa(i)
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(s)
|
||||
i += len(s) / 2
|
||||
if len(s)%2 != 0 {
|
||||
sb.WriteString(" ")
|
||||
i++
|
||||
}
|
||||
} else {
|
||||
if i == failure {
|
||||
sb.WriteString("^^")
|
||||
} else {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
return &BytesEqualError{
|
||||
Index: failure,
|
||||
error: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
type BytesEqualError struct {
|
||||
Index int
|
||||
error string
|
||||
}
|
||||
|
||||
func (e *BytesEqualError) Error() string {
|
||||
return e.error
|
||||
}
|
||||
|
||||
func ReadFromFile(path string, to io.ReaderFrom) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = to.ReadFrom(f)
|
||||
return errors.Join(err, f.Close())
|
||||
}
|
||||
|
||||
func WriteToFile(path string, from io.WriterTo) error {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) // TODO @Tabaie option for permissions?
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = from.WriteTo(f)
|
||||
return errors.Join(err, f.Close())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user