Feat Setup checks equality between KZG Vk from circuit Vk and from SRS (#325)

* feat test vk consistency
* test write-read setup
* feat readfromfile-writetofile

---------

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
This commit is contained in:
Arya Tabaie
2024-12-02 07:45:15 -06:00
committed by GitHub
parent 8be665e11c
commit c4c3bfd739
6 changed files with 224 additions and 98 deletions

View File

@@ -3,18 +3,17 @@ package main
import (
"flag"
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/scs"
v1 "github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v1"
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"runtime"
)
const maxNbConstraints = 1 << 27
func nbConstraints(blobSize int) int {
fmt.Printf("*********************\nfor blob of size %dB or %.2fKB:\n", blobSize, float32(blobSize)/1024)
c := v1.Circuit{
BlobBytes: make([]frontend.Variable, 32*4096),
@@ -22,6 +21,7 @@ func nbConstraints(blobSize int) int {
MaxBlobPayloadNbBytes: blobSize,
UseGkrMiMC: true,
}
runtime.GC()
if cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &c, frontend.WithCapacity(maxNbConstraints*6/5)); err != nil {
panic(err)
} else {

View File

@@ -2,10 +2,10 @@ package circuits
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
@@ -192,23 +192,37 @@ func LoadSetup(cfg *config.Config, circuitID CircuitID) (Setup, error) {
return Setup{}, fmt.Errorf("fetching SRS: %w", err)
}
pk := plonk.NewProvingKey(curveID)
var kzgVkFromVk, kzgVkFromSrs io.WriterTo
switch pk := pk.(type) {
case *plonk_bn254.ProvingKey:
pk.Vk = vk.(*plonk_bn254.VerifyingKey)
pk.Kzg = srsCanonical.(*kzg254.SRS).Pk
srsC := srsCanonical.(*kzg254.SRS)
pk.Kzg = srsC.Pk
pk.KzgLagrange = srsLagrange.(*kzg254.SRS).Pk
kzgVkFromVk = &pk.Vk.Kzg
kzgVkFromSrs = &srsC.Vk
case *plonk_bls12377.ProvingKey:
pk.Vk = vk.(*plonk_bls12377.VerifyingKey)
pk.Kzg = srsCanonical.(*kzg377.SRS).Pk
srsC := srsCanonical.(*kzg377.SRS)
pk.Kzg = srsC.Pk
pk.KzgLagrange = srsLagrange.(*kzg377.SRS).Pk
kzgVkFromVk = &pk.Vk.Kzg
kzgVkFromSrs = &srsC.Vk
case *plonk_bw6761.ProvingKey:
pk.Vk = vk.(*plonk_bw6761.VerifyingKey)
pk.Kzg = srsCanonical.(*kzgbw6.SRS).Pk
srsC := srsCanonical.(*kzgbw6.SRS)
pk.Kzg = srsC.Pk
pk.KzgLagrange = srsLagrange.(*kzgbw6.SRS).Pk
kzgVkFromVk = &pk.Vk.Kzg
kzgVkFromSrs = &srsC.Vk
default:
panic("not implemented")
}
if err = utils.WriterstoEqual(kzgVkFromSrs, kzgVkFromVk); err != nil {
return Setup{}, fmt.Errorf("verifying key <> SRS mismatch: %w", err)
}
return Setup{
Manifest: *manifest,
Circuit: circuit,
@@ -272,12 +286,13 @@ func readFromFile(path string, into any) error {
panic(fmt.Sprintf("unsupported type %T", into))
}
data, err := os.ReadFile(path)
f, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening %q: %w", path, err)
}
if _, err = rFunc(bytes.NewReader(data)); err != nil {
_, err = rFunc(f)
if err = errors.Join(err, f.Close()); err != nil {
return fmt.Errorf("reading %q from disk: %w", path, err)
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/consensys/gnark-crypto/ecc"
)
// SetupManifest is the human readable manifest of the assets generated by the prover setup command
// SetupManifest is the human-readable manifest of the assets generated by the prover setup command
type SetupManifest struct {
CircuitName string `json:"circuitName"`
Timestamp time.Time `json:"timestamp"`

View File

@@ -0,0 +1,74 @@
package circuits
import (
"context"
"errors"
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/kzg"
"github.com/consensys/gnark/backend/plonk"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/test/unsafekzg"
"github.com/consensys/linea-monorepo/prover/config"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"testing"
)
func TestLoadSetup(t *testing.T) {
dir := t.TempDir()
cfg := config.Config{
AssetsDir: dir,
}
cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit{make([]frontend.Variable, 1)})
require.NoError(t, err)
canonicalSize, lagrangeSize := plonk.SRSSize(cs)
canonical, lagrange, err := unsafekzg.NewSRS(cs)
require.NoError(t, err)
const srsFilenameTemplate = "kzg_srs_%s_%d_bn254_aleo.memdump"
srsDir := filepath.Join(dir, "kzgsrs")
require.NoError(t, os.Mkdir(srsDir, 0700))
dumpToFile(t, canonical, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "canonical", canonicalSize)))
dumpToFile(t, lagrange, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "lagrange", lagrangeSize)))
const circuitName = "test"
srsProvider := NewUnsafeSRSProvider()
setup, err := MakeSetup(context.TODO(), circuitName, cs, srsProvider, map[string]any{})
require.NoError(t, err)
require.NoError(t, setup.WriteTo(filepath.Join(dir, circuitName)))
_, err = LoadSetup(&cfg, circuitName)
require.NoError(t, err)
}
type circuit struct {
// this is a dummy circuit that does nothing
// it is used to generate the SRS
Input []frontend.Variable `gnark:",public"`
}
func (circuit *circuit) Define(api frontend.API) error {
c, err := api.(frontend.Committer).Commit(circuit.Input...)
if err != nil {
return err
}
prod := frontend.Variable(1)
for _, x := range circuit.Input {
prod = api.Mul(prod, x)
}
api.AssertIsDifferent(c, prod)
return nil
}
func dumpToFile(t *testing.T, o kzg.Serializable, path string) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
require.NoError(t, err)
err = errors.Join(o.WriteDump(f), f.Close())
require.NoError(t, err)
}

View File

@@ -12,7 +12,6 @@ import (
"math"
"os"
"reflect"
"strconv"
"strings"
"testing"
@@ -52,100 +51,12 @@ func RandIntSliceN(length, n int) []int {
return res
}
type BytesEqualError struct {
Index int
error string
}
func (e *BytesEqualError) Error() string {
return e.error
}
func LoadJson(t *testing.T, path string, v any) {
in, err := os.Open(path)
require.NoError(t, err)
require.NoError(t, json.NewDecoder(in).Decode(v))
}
// BytesEqual between byte slices a,b
// a readable error message would show in case of inequality
// TODO error options: block size, check forwards or backwards etc
func BytesEqual(expected, actual []byte) error {
l := min(len(expected), len(actual))
failure := 0
for failure < l {
if expected[failure] != actual[failure] {
break
}
failure++
}
if len(expected) == len(actual) {
return nil
}
// there is a mismatch
var sb strings.Builder
const (
radius = 40
blockSize = 32
)
printCentered := func(b []byte) {
for i := max(failure-radius, 0); i <= failure+radius; i++ {
if i%blockSize == 0 && i != failure-radius {
sb.WriteString(" ")
}
if i >= 0 && i < len(b) {
sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once
} else {
sb.WriteString(" ")
}
}
}
sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure))
sb.WriteString("expected: ")
printCentered(expected)
sb.WriteString("\n")
sb.WriteString("actual: ")
printCentered(actual)
sb.WriteString("\n")
sb.WriteString(" ")
for i := max(failure-radius, 0); i <= failure+radius; {
if i%blockSize == 0 && i != failure-radius {
s := strconv.Itoa(i)
sb.WriteString(" ")
sb.WriteString(s)
i += len(s) / 2
if len(s)%2 != 0 {
sb.WriteString(" ")
i++
}
} else {
if i == failure {
sb.WriteString("^^")
} else {
sb.WriteString(" ")
}
i++
}
}
sb.WriteString("\n")
return &BytesEqualError{
Index: failure,
error: sb.String(),
}
}
func SlicesEqual[T any](expected, actual []T) error {
if l1, l2 := len(expected), len(actual); l1 != l2 {
return fmt.Errorf("length mismatch %d≠%d", l1, l2)

View File

@@ -1,9 +1,11 @@
package utils
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math"
@@ -11,6 +13,7 @@ import (
"os"
"reflect"
"strconv"
"strings"
"github.com/consensys/gnark/frontend"
"golang.org/x/exp/constraints"
@@ -317,3 +320,126 @@ func WriteToJSON(path string, v interface{}) error {
defer f.Close()
return json.NewEncoder(f).Encode(v)
}
func WriterstoEqual(expected, actual io.WriterTo) error {
var bb bytes.Buffer
if _, err := expected.WriteTo(&bb); err != nil {
return err
}
ab := bb.Bytes()
bb.Reset()
if _, err := actual.WriteTo(&bb); err != nil {
return err
}
return BytesEqual(ab, bb.Bytes())
}
// BytesEqual between byte slices a,b
// a readable error message would show in case of inequality
// TODO error options: block size, check forwards or backwards etc
func BytesEqual(expected, actual []byte) error {
if bytes.Equal(expected, actual) {
return nil // equality fast path
}
l := min(len(expected), len(actual))
failure := 0
for failure < l {
if expected[failure] != actual[failure] {
break
}
failure++
}
if len(expected) == len(actual) && failure == l {
panic("bytes.Equal returned false, but could not find a mismatch")
}
// there is a mismatch
var sb strings.Builder
const (
radius = 40
blockSize = 32
)
printCentered := func(b []byte) {
for i := max(failure-radius, 0); i <= failure+radius; i++ {
if i%blockSize == 0 && i != failure-radius {
sb.WriteString(" ")
}
if i >= 0 && i < len(b) {
sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once
} else {
sb.WriteString(" ")
}
}
}
sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure))
sb.WriteString("expected: ")
printCentered(expected)
sb.WriteString("\n")
sb.WriteString("actual: ")
printCentered(actual)
sb.WriteString("\n")
sb.WriteString(" ")
for i := max(failure-radius, 0); i <= failure+radius; {
if i%blockSize == 0 && i != failure-radius {
s := strconv.Itoa(i)
sb.WriteString(" ")
sb.WriteString(s)
i += len(s) / 2
if len(s)%2 != 0 {
sb.WriteString(" ")
i++
}
} else {
if i == failure {
sb.WriteString("^^")
} else {
sb.WriteString(" ")
}
i++
}
}
sb.WriteString("\n")
return &BytesEqualError{
Index: failure,
error: sb.String(),
}
}
type BytesEqualError struct {
Index int
error string
}
func (e *BytesEqualError) Error() string {
return e.error
}
func ReadFromFile(path string, to io.ReaderFrom) error {
f, err := os.Open(path)
if err != nil {
return err
}
_, err = to.ReadFrom(f)
return errors.Join(err, f.Close())
}
func WriteToFile(path string, from io.WriterTo) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) // TODO @Tabaie option for permissions?
if err != nil {
return err
}
_, err = from.WriteTo(f)
return errors.Join(err, f.Close())
}