diff --git a/prover/circuits/blobdecompression/large-tests/maximize_blob_size/main.go b/prover/circuits/blobdecompression/large-tests/maximize_blob_size/main.go index 068eba49..40059d87 100644 --- a/prover/circuits/blobdecompression/large-tests/maximize_blob_size/main.go +++ b/prover/circuits/blobdecompression/large-tests/maximize_blob_size/main.go @@ -3,18 +3,17 @@ package main import ( "flag" "fmt" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" v1 "github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v1" blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1" + "runtime" ) const maxNbConstraints = 1 << 27 func nbConstraints(blobSize int) int { - fmt.Printf("*********************\nfor blob of size %dB or %.2fKB:\n", blobSize, float32(blobSize)/1024) c := v1.Circuit{ BlobBytes: make([]frontend.Variable, 32*4096), @@ -22,6 +21,7 @@ func nbConstraints(blobSize int) int { MaxBlobPayloadNbBytes: blobSize, UseGkrMiMC: true, } + runtime.GC() if cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &c, frontend.WithCapacity(maxNbConstraints*6/5)); err != nil { panic(err) } else { diff --git a/prover/circuits/setup.go b/prover/circuits/setup.go index 3bf635cb..c1f57470 100644 --- a/prover/circuits/setup.go +++ b/prover/circuits/setup.go @@ -2,10 +2,10 @@ package circuits import ( "bufio" - "bytes" "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -192,23 +192,37 @@ func LoadSetup(cfg *config.Config, circuitID CircuitID) (Setup, error) { return Setup{}, fmt.Errorf("fetching SRS: %w", err) } pk := plonk.NewProvingKey(curveID) + var kzgVkFromVk, kzgVkFromSrs io.WriterTo switch pk := pk.(type) { case *plonk_bn254.ProvingKey: pk.Vk = vk.(*plonk_bn254.VerifyingKey) - pk.Kzg = srsCanonical.(*kzg254.SRS).Pk + srsC := srsCanonical.(*kzg254.SRS) + pk.Kzg = srsC.Pk pk.KzgLagrange = srsLagrange.(*kzg254.SRS).Pk + kzgVkFromVk = &pk.Vk.Kzg + kzgVkFromSrs = &srsC.Vk case *plonk_bls12377.ProvingKey: pk.Vk = vk.(*plonk_bls12377.VerifyingKey) - pk.Kzg = srsCanonical.(*kzg377.SRS).Pk + srsC := srsCanonical.(*kzg377.SRS) + pk.Kzg = srsC.Pk pk.KzgLagrange = srsLagrange.(*kzg377.SRS).Pk + kzgVkFromVk = &pk.Vk.Kzg + kzgVkFromSrs = &srsC.Vk case *plonk_bw6761.ProvingKey: pk.Vk = vk.(*plonk_bw6761.VerifyingKey) - pk.Kzg = srsCanonical.(*kzgbw6.SRS).Pk + srsC := srsCanonical.(*kzgbw6.SRS) + pk.Kzg = srsC.Pk pk.KzgLagrange = srsLagrange.(*kzgbw6.SRS).Pk + kzgVkFromVk = &pk.Vk.Kzg + kzgVkFromSrs = &srsC.Vk default: panic("not implemented") } + if err = utils.WriterstoEqual(kzgVkFromSrs, kzgVkFromVk); err != nil { + return Setup{}, fmt.Errorf("verifying key <> SRS mismatch: %w", err) + } + return Setup{ Manifest: *manifest, Circuit: circuit, @@ -272,12 +286,13 @@ func readFromFile(path string, into any) error { panic(fmt.Sprintf("unsupported type %T", into)) } - data, err := os.ReadFile(path) + f, err := os.Open(path) if err != nil { return fmt.Errorf("opening %q: %w", path, err) } - if _, err = rFunc(bytes.NewReader(data)); err != nil { + _, err = rFunc(f) + if err = errors.Join(err, f.Close()); err != nil { return fmt.Errorf("reading %q from disk: %w", path, err) } diff --git a/prover/circuits/setup_manifest.go b/prover/circuits/setup_manifest.go index 803c2b17..f4603271 100644 --- a/prover/circuits/setup_manifest.go +++ b/prover/circuits/setup_manifest.go @@ -10,7 +10,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" ) -// SetupManifest is the human readable manifest of the assets generated by the prover setup command +// SetupManifest is the human-readable manifest of the assets generated by the prover setup command type SetupManifest struct { CircuitName string `json:"circuitName"` Timestamp time.Time `json:"timestamp"` diff --git a/prover/circuits/setup_test.go b/prover/circuits/setup_test.go new file mode 100644 index 00000000..f0c887f7 --- /dev/null +++ b/prover/circuits/setup_test.go @@ -0,0 +1,74 @@ +package circuits + +import ( + "context" + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test/unsafekzg" + "github.com/consensys/linea-monorepo/prover/config" + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" +) + +func TestLoadSetup(t *testing.T) { + dir := t.TempDir() + cfg := config.Config{ + AssetsDir: dir, + } + cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit{make([]frontend.Variable, 1)}) + require.NoError(t, err) + + canonicalSize, lagrangeSize := plonk.SRSSize(cs) + canonical, lagrange, err := unsafekzg.NewSRS(cs) + require.NoError(t, err) + const srsFilenameTemplate = "kzg_srs_%s_%d_bn254_aleo.memdump" + srsDir := filepath.Join(dir, "kzgsrs") + require.NoError(t, os.Mkdir(srsDir, 0700)) + dumpToFile(t, canonical, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "canonical", canonicalSize))) + dumpToFile(t, lagrange, filepath.Join(srsDir, fmt.Sprintf(srsFilenameTemplate, "lagrange", lagrangeSize))) + + const circuitName = "test" + + srsProvider := NewUnsafeSRSProvider() + setup, err := MakeSetup(context.TODO(), circuitName, cs, srsProvider, map[string]any{}) + require.NoError(t, err) + + require.NoError(t, setup.WriteTo(filepath.Join(dir, circuitName))) + + _, err = LoadSetup(&cfg, circuitName) + require.NoError(t, err) +} + +type circuit struct { + // this is a dummy circuit that does nothing + // it is used to generate the SRS + Input []frontend.Variable `gnark:",public"` +} + +func (circuit *circuit) Define(api frontend.API) error { + c, err := api.(frontend.Committer).Commit(circuit.Input...) + if err != nil { + return err + } + prod := frontend.Variable(1) + for _, x := range circuit.Input { + prod = api.Mul(prod, x) + } + api.AssertIsDifferent(c, prod) + return nil +} + +func dumpToFile(t *testing.T, o kzg.Serializable, path string) { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) + require.NoError(t, err) + + err = errors.Join(o.WriteDump(f), f.Close()) + require.NoError(t, err) +} diff --git a/prover/utils/test_utils/test_utils.go b/prover/utils/test_utils/test_utils.go index aff17af1..74616356 100644 --- a/prover/utils/test_utils/test_utils.go +++ b/prover/utils/test_utils/test_utils.go @@ -12,7 +12,6 @@ import ( "math" "os" "reflect" - "strconv" "strings" "testing" @@ -52,100 +51,12 @@ func RandIntSliceN(length, n int) []int { return res } -type BytesEqualError struct { - Index int - error string -} - -func (e *BytesEqualError) Error() string { - return e.error -} - func LoadJson(t *testing.T, path string, v any) { in, err := os.Open(path) require.NoError(t, err) require.NoError(t, json.NewDecoder(in).Decode(v)) } -// BytesEqual between byte slices a,b -// a readable error message would show in case of inequality -// TODO error options: block size, check forwards or backwards etc -func BytesEqual(expected, actual []byte) error { - l := min(len(expected), len(actual)) - - failure := 0 - for failure < l { - if expected[failure] != actual[failure] { - break - } - failure++ - } - - if len(expected) == len(actual) { - return nil - } - - // there is a mismatch - var sb strings.Builder - - const ( - radius = 40 - blockSize = 32 - ) - - printCentered := func(b []byte) { - - for i := max(failure-radius, 0); i <= failure+radius; i++ { - if i%blockSize == 0 && i != failure-radius { - sb.WriteString(" ") - } - if i >= 0 && i < len(b) { - sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once - } else { - sb.WriteString(" ") - } - } - } - - sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure)) - - sb.WriteString("expected: ") - printCentered(expected) - sb.WriteString("\n") - - sb.WriteString("actual: ") - printCentered(actual) - sb.WriteString("\n") - - sb.WriteString(" ") - for i := max(failure-radius, 0); i <= failure+radius; { - if i%blockSize == 0 && i != failure-radius { - s := strconv.Itoa(i) - sb.WriteString(" ") - sb.WriteString(s) - i += len(s) / 2 - if len(s)%2 != 0 { - sb.WriteString(" ") - i++ - } - } else { - if i == failure { - sb.WriteString("^^") - } else { - sb.WriteString(" ") - } - i++ - } - } - - sb.WriteString("\n") - - return &BytesEqualError{ - Index: failure, - error: sb.String(), - } -} - func SlicesEqual[T any](expected, actual []T) error { if l1, l2 := len(expected), len(actual); l1 != l2 { return fmt.Errorf("length mismatch %d≠%d", l1, l2) diff --git a/prover/utils/utils.go b/prover/utils/utils.go index a9c38bd7..9a3fa8f7 100644 --- a/prover/utils/utils.go +++ b/prover/utils/utils.go @@ -1,9 +1,11 @@ package utils import ( + "bytes" "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "math" @@ -11,6 +13,7 @@ import ( "os" "reflect" "strconv" + "strings" "github.com/consensys/gnark/frontend" "golang.org/x/exp/constraints" @@ -317,3 +320,126 @@ func WriteToJSON(path string, v interface{}) error { defer f.Close() return json.NewEncoder(f).Encode(v) } + +func WriterstoEqual(expected, actual io.WriterTo) error { + var bb bytes.Buffer + if _, err := expected.WriteTo(&bb); err != nil { + return err + } + ab := bb.Bytes() + bb.Reset() + if _, err := actual.WriteTo(&bb); err != nil { + return err + } + return BytesEqual(ab, bb.Bytes()) +} + +// BytesEqual between byte slices a,b +// a readable error message would show in case of inequality +// TODO error options: block size, check forwards or backwards etc +func BytesEqual(expected, actual []byte) error { + if bytes.Equal(expected, actual) { + return nil // equality fast path + } + + l := min(len(expected), len(actual)) + + failure := 0 + for failure < l { + if expected[failure] != actual[failure] { + break + } + failure++ + } + + if len(expected) == len(actual) && failure == l { + panic("bytes.Equal returned false, but could not find a mismatch") + } + + // there is a mismatch + var sb strings.Builder + + const ( + radius = 40 + blockSize = 32 + ) + + printCentered := func(b []byte) { + + for i := max(failure-radius, 0); i <= failure+radius; i++ { + if i%blockSize == 0 && i != failure-radius { + sb.WriteString(" ") + } + if i >= 0 && i < len(b) { + sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once + } else { + sb.WriteString(" ") + } + } + } + + sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure)) + + sb.WriteString("expected: ") + printCentered(expected) + sb.WriteString("\n") + + sb.WriteString("actual: ") + printCentered(actual) + sb.WriteString("\n") + + sb.WriteString(" ") + for i := max(failure-radius, 0); i <= failure+radius; { + if i%blockSize == 0 && i != failure-radius { + s := strconv.Itoa(i) + sb.WriteString(" ") + sb.WriteString(s) + i += len(s) / 2 + if len(s)%2 != 0 { + sb.WriteString(" ") + i++ + } + } else { + if i == failure { + sb.WriteString("^^") + } else { + sb.WriteString(" ") + } + i++ + } + } + + sb.WriteString("\n") + + return &BytesEqualError{ + Index: failure, + error: sb.String(), + } +} + +type BytesEqualError struct { + Index int + error string +} + +func (e *BytesEqualError) Error() string { + return e.error +} + +func ReadFromFile(path string, to io.ReaderFrom) error { + f, err := os.Open(path) + if err != nil { + return err + } + _, err = to.ReadFrom(f) + return errors.Join(err, f.Close()) +} + +func WriteToFile(path string, from io.WriterTo) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) // TODO @Tabaie option for permissions? + if err != nil { + return err + } + _, err = from.WriteTo(f) + return errors.Join(err, f.Close()) +}