From c0568080a152e1f96f822c64810453eac8b77064 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 14 Aug 2024 07:24:08 -0500 Subject: [PATCH] Prover: implements the aggregation circuit for beta v1 (#3780) --------- Signed-off-by: Arya Tabaie Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Co-authored-by: AlexandreBelling --- prover/backend/aggregation/craft.go | 26 +- prover/backend/aggregation/prove.go | 51 +++- prover/backend/aggregation/request.go | 5 + prover/backend/blobdecompression/prove.go | 4 +- prover/backend/execution/prove.go | 4 +- .../circuits/aggregation/asset_generation.go | 22 +- prover/circuits/aggregation/circuit.go | 241 ++++----------- prover/circuits/aggregation/circuit_test.go | 2 +- prover/circuits/aggregation/prover.go | 18 +- .../blobdecompression/public-input/pi.go | 4 +- .../blobdecompression/public-input/pi_test.go | 11 +- .../circuits/blobdecompression/v0/circuit.go | 3 +- .../circuits/blobdecompression/v0/prelude.go | 3 +- .../circuits/blobdecompression/v1/circuit.go | 5 +- prover/circuits/execution/pi.go | 10 +- prover/circuits/internal/io.go | 2 +- .../internal/test_utils/test_utils.go | 17 ++ prover/circuits/internal/utils.go | 278 +++++++++++------- prover/circuits/internal/utils_test.go | 116 +++++++- prover/circuits/pi-interconnection/assign.go | 34 +-- .../circuits/pi-interconnection/bench/main.go | 17 +- prover/circuits/pi-interconnection/circuit.go | 123 +++++--- .../pi-interconnection/circuit_test.go | 9 +- .../compile/test_compile.go | 17 +- .../circuits/pi-interconnection/e2e_test.go | 39 ++- .../pi-interconnection/keccak/assign_test.go | 5 +- .../test_utils/test_utils.go | 5 +- prover/circuits/registry.go | 21 +- prover/circuits/setup.go | 2 +- prover/cmd/prover/cmd/setup.go | 16 +- prover/config/config.go | 19 +- .../circuit-testing/aggregation/main.go | 5 +- prover/public-input/aggregation.go | 163 ++++++++++ prover/public-input/execution.go | 61 ++++ prover/utils/snark.go | 120 ++++++++ 35 files changed, 1008 insertions(+), 470 deletions(-) create mode 100644 prover/public-input/execution.go create mode 100644 prover/utils/snark.go diff --git a/prover/backend/aggregation/craft.go b/prover/backend/aggregation/craft.go index eef79460..821a3a5c 100644 --- a/prover/backend/aggregation/craft.go +++ b/prover/backend/aggregation/craft.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "github.com/consensys/zkevm-monorepo/prover/backend/blobsubmission" public_input "github.com/consensys/zkevm-monorepo/prover/public-input" "path" @@ -44,6 +45,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { } ) + cf.ExecutionPI = make([]public_input.Execution, 0, len(req.ExecutionProofs)) + for i, execReqFPath := range req.ExecutionProofs { po := &execution.Response{} fpath := path.Join(cfg.Execution.DirTo(), execReqFPath) @@ -66,8 +69,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { utils.Panic("conflated batch %v reports a parent state hash mismatch, but this is not the first batch of the sequence", i) } - // This is purposefuly overwritten at each iteration over i. We want to - // keep the final velue. + // This is purposefully overwritten at each iteration over i. We want to + // keep the final value. cf.FinalBlockNumber = uint(po.FirstBlockNumber + len(po.BlocksData) - 1) for _, blockdata := range po.BlocksData { @@ -86,14 +89,30 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { // Append the proof claim to the list of collected proofs if !cf.IsProoflessJob { - pClaim, err := parseProofClaim(po.Proof, po.DebugData.FinalHash, po.VerifyingKeyShaSum) + pClaim, err := parseProofClaim(po.Proof, po.DebugData.FinalHash, po.VerifyingKeyShaSum) // @gbotrel Is finalHash the state hash? If so why is it given as the execution circuit's PI? if err != nil { return nil, fmt.Errorf("could not parse the proof claim for `%v` : %w", fpath, err) } cf.ProofClaims = append(cf.ProofClaims, *pClaim) + // TODO make sure this belongs in the if + finalBlock := &po.BlocksData[len(po.BlocksData)-1] + piq, err := public_input.ExecutionSerializable{ + L2MsgHashes: l2MessageHashes, + FinalStateRootHash: po.DebugData.FinalHash, // TODO @tabaie make sure this is the right value + FinalBlockNumber: uint64(cf.FinalBlockNumber), + FinalBlockTimestamp: finalBlock.TimeStamp, + FinalRollingHash: cf.L1RollingHash, + FinalRollingHashNumber: uint64(cf.L1RollingHashMessageNumber), + }.Decode() + if err != nil { + return nil, err + } + cf.ExecutionPI = append(cf.ExecutionPI, piq) } } + cf.DecompressionPI = make([]blobsubmission.Response, 0, len(req.CompressionProofs)) + for i, decompReqFPath := range req.CompressionProofs { dp := &blobdecompression.Response{} fpath := path.Join(cfg.BlobDecompression.DirTo(), decompReqFPath) @@ -120,6 +139,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { return nil, fmt.Errorf("could not parse the proof claim for `%v` : %w", fpath, err) } cf.ProofClaims = append(cf.ProofClaims, *pClaim) + cf.DecompressionPI = append(cf.DecompressionPI, dp.Request) } } diff --git a/prover/backend/aggregation/prove.go b/prover/backend/aggregation/prove.go index a9da96fb..a652fbe8 100644 --- a/prover/backend/aggregation/prove.go +++ b/prover/backend/aggregation/prove.go @@ -2,6 +2,8 @@ package aggregation import ( "fmt" + pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" + public_input "github.com/consensys/zkevm-monorepo/prover/public-input" "math" "path/filepath" @@ -43,7 +45,12 @@ func makeProof( return makeDummyProof(cfg, publicInput, circuits.MockCircuitIDEmulation), nil } - proofBW6, circuitID, err := makeBw6Proof(cfg, cf, publicInput) + piProof, err := makePiProof(cfg, cf) + if err != nil { + return "", fmt.Errorf("could not create the public input proof: %w", err) + } + + proofBW6, circuitID, err := makeBw6Proof(cfg, cf, piProof, publicInput) if err != nil { return "", fmt.Errorf("error when running the BW6 proof: %w", err) } @@ -56,6 +63,43 @@ func makeProof( return circuits.SerializeProofSolidityBn254(proofBn254), nil } +func makePiProof(cfg *config.Config, cf *CollectedFields) (plonk.Proof, error) { + + c, err := pi_interconnection.Compile(cfg.PublicInputInterconnection, pi_interconnection.WizardCompilationParameters()...) + if err != nil { + return nil, fmt.Errorf("could not create the public-input circuit: %w", err) + } + + assignment, err := c.Assign(pi_interconnection.Request{ + Decompressions: cf.DecompressionPI, + Executions: cf.ExecutionPI, + Aggregation: public_input.Aggregation{ + FinalShnarf: cf.FinalShnarf, + ParentAggregationFinalShnarf: cf.ParentAggregationFinalShnarf, + ParentStateRootHash: cf.ParentStateRootHash, + ParentAggregationLastBlockTimestamp: cf.ParentAggregationLastBlockTimestamp, + FinalTimestamp: cf.FinalTimestamp, + LastFinalizedBlockNumber: cf.LastFinalizedBlockNumber, + FinalBlockNumber: cf.FinalBlockNumber, + LastFinalizedL1RollingHash: cf.LastFinalizedL1RollingHash, + L1RollingHash: cf.L1RollingHash, + LastFinalizedL1RollingHashMessageNumber: cf.LastFinalizedL1RollingHashMessageNumber, + L1RollingHashMessageNumber: cf.L1RollingHashMessageNumber, + L2MsgRootHashes: cf.L2MsgRootHashes, + L2MsgMerkleTreeDepth: int(cf.L2MsgTreeDepth), + }, + }) + if err != nil { + return nil, fmt.Errorf("could not assign the public input circuit: %w", err) + } + + setup, err := circuits.LoadSetup(cfg, circuits.PublicInputInterconnectionCircuitID) + if err != nil { + return nil, fmt.Errorf("could not load the setup: %w", err) + } + return circuits.ProveCheck(&setup, &assignment) +} + // Generates a fake proof. The public input is given in hex string format. // Returns the proof in hex string format. The circuit ID parameter specifies // for which circuit should the proof be generated. @@ -86,6 +130,7 @@ func makeDummyProof(cfg *config.Config, input string, circID circuits.MockCircui func makeBw6Proof( cfg *config.Config, cf *CollectedFields, + piProof plonk.Proof, publicInput string, ) (proof plonk.Proof, circuitID int, err error) { @@ -156,7 +201,7 @@ func makeBw6Proof( assignCircuitIDToProofClaims(bestAllowedVkForAggregation, cf.ProofClaims) // Although the public input is restrained to fit on the BN254 scalar field, - // the BW6 field is larger. This allow us representing the public input as a + // the BW6 field is larger. This allows us to represent the public input as a // single field element. var piBW6 frBW6.Element @@ -166,7 +211,7 @@ func makeBw6Proof( } logrus.Infof("running the BW6 prover") - proofBW6, err := aggregation.MakeProof(&setup, bestSize, cf.ProofClaims, piBW6) + proofBW6, err := aggregation.MakeProof(&setup, bestSize, cf.ProofClaims, piProof, piBW6) if err != nil { return nil, 0, fmt.Errorf("could not create BW6 proof: %w", err) } diff --git a/prover/backend/aggregation/request.go b/prover/backend/aggregation/request.go index 4736661f..516cd1b2 100644 --- a/prover/backend/aggregation/request.go +++ b/prover/backend/aggregation/request.go @@ -1,7 +1,9 @@ package aggregation import ( + "github.com/consensys/zkevm-monorepo/prover/backend/blobdecompression" "github.com/consensys/zkevm-monorepo/prover/circuits/aggregation" + public_input "github.com/consensys/zkevm-monorepo/prover/public-input" ) // Request collects all the fields used to perform an aggregation request. @@ -96,4 +98,7 @@ type CollectedFields struct { // The proof claims for the execution prover ProofClaims []aggregation.ProofClaimAssignment + + ExecutionPI []public_input.Execution + DecompressionPI []blobdecompression.Request } diff --git a/prover/backend/blobdecompression/prove.go b/prover/backend/blobdecompression/prove.go index 7bae6e0f..3a94895f 100644 --- a/prover/backend/blobdecompression/prove.go +++ b/prover/backend/blobdecompression/prove.go @@ -181,7 +181,7 @@ func Prove(cfg *config.Config, req *Request) (*Response, error) { Request: *req, ProverVersion: cfg.Version, DecompressionProof: circuits.SerializeProofRaw(proof), - VerifyingKeyShaSum: setup.VerifiyingKeyDigest(), + VerifyingKeyShaSum: setup.VerifyingKeyDigest(), } resp.Debug.PublicInput = "0x" + pubInput.Text(16) @@ -227,7 +227,7 @@ func dummyProve(cfg *config.Config, req *Request) (*Response, error) { Request: *req, ProverVersion: cfg.Version, DecompressionProof: proof, - VerifyingKeyShaSum: setup.VerifiyingKeyDigest(), + VerifyingKeyShaSum: setup.VerifyingKeyDigest(), } inputString := utils.HexEncodeToString(input) diff --git a/prover/backend/execution/prove.go b/prover/backend/execution/prove.go index d20212ad..dd889ef1 100644 --- a/prover/backend/execution/prove.go +++ b/prover/backend/execution/prove.go @@ -113,7 +113,7 @@ func mustProveAndPass( utils.Panic(err.Error()) } - return dummy.MakeProof(&setup, publicInput, circuits.MockCircuitIDExecution), setup.VerifiyingKeyDigest() + return dummy.MakeProof(&setup, publicInput, circuits.MockCircuitIDExecution), setup.VerifyingKeyDigest() case config.ProverModeFull: logrus.Info("Running the FULL prover") @@ -157,7 +157,7 @@ func mustProveAndPass( } // TODO: implements the collection of the functional inputs from the prover response - return execution.MakeProof(setup, fullZkEvm.WizardIOP, proof, execution.FunctionalPublicInput{}, publicInput), setup.VerifiyingKeyDigest() + return execution.MakeProof(setup, fullZkEvm.WizardIOP, proof, execution.FunctionalPublicInput{}, publicInput), setup.VerifyingKeyDigest() default: panic("not implemented") } diff --git a/prover/circuits/aggregation/asset_generation.go b/prover/circuits/aggregation/asset_generation.go index 527df51f..ad7d51d3 100644 --- a/prover/circuits/aggregation/asset_generation.go +++ b/prover/circuits/aggregation/asset_generation.go @@ -11,33 +11,43 @@ import ( ) type builder struct { - maxNbProofs int - vKeys []plonk.VerifyingKey + maxNbProofs int + vKeys []plonk.VerifyingKey + allowedInputs []string + piVKey plonk.VerifyingKey } func NewBuilder( maxNbProofs int, + allowedInputs []string, + piVKey plonk.VerifyingKey, vKeys []plonk.VerifyingKey, ) *builder { return &builder{ - maxNbProofs: maxNbProofs, - vKeys: vKeys, + piVKey: piVKey, + allowedInputs: allowedInputs, + maxNbProofs: maxNbProofs, + vKeys: vKeys, } } func (b *builder) Compile() (constraint.ConstraintSystem, error) { - return MakeCS(b.maxNbProofs, b.vKeys) + return MakeCS(b.maxNbProofs, b.allowedInputs, b.piVKey, b.vKeys) } // Initializes the bw6 aggregation circuit and returns a compiled constraint // system. func MakeCS( maxNbProofs int, + allowedInputs []string, + piVKey plonk.VerifyingKey, vKeys []plonk.VerifyingKey, ) (constraint.ConstraintSystem, error) { - aggCircuit, err := AllocateAggregationCircuit( + aggCircuit, err := AllocateCircuit( maxNbProofs, + allowedInputs, + piVKey, vKeys, ) diff --git a/prover/circuits/aggregation/circuit.go b/prover/circuits/aggregation/circuit.go index 9c27cbf8..7f35657c 100644 --- a/prover/circuits/aggregation/circuit.go +++ b/prover/circuits/aggregation/circuit.go @@ -3,18 +3,15 @@ package aggregation import ( + "errors" "fmt" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/lookup/logderivlookup" "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/std/rangecheck" emPlonk "github.com/consensys/gnark/std/recursion/plonk" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" - "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak" - public_input "github.com/consensys/zkevm-monorepo/prover/public-input" - "github.com/consensys/zkevm-monorepo/prover/utils" - "github.com/consensys/zkevm-monorepo/prover/utils/types" "slices" ) @@ -32,9 +29,9 @@ type ( // emBaseVkey = emPlonk.BaseVerifyingKey[emFr, emG1, emG2] ) -// The AggregationCircuit is used to aggregate multiple execution proofs and +// The Circuit is used to aggregate multiple execution proofs and // aggregation proofs together. -type AggregationCircuit struct { +type Circuit struct { // The list of claims to be provided to the circuit. ProofClaims []proofClaim `gnark:",secret"` @@ -42,29 +39,76 @@ type AggregationCircuit struct { // is treated as a constant by the circuit. verifyingKeys []emVkey `gnark:"-"` - // Dummy general public input - DummyPublicInput frontend.Variable `gnark:",public"` + publicInputVerifyingKey emVkey `gnark:"-"` + PublicInputProof emProof `gnark:",secret"` + PublicInputWitness emWitness `gnark:",secret"` // ordered for the PI circuit + PublicInputWitnessClaimIndexes []frontend.Variable `gnark:",secret"` + + // general public input + PublicInput frontend.Variable `gnark:",public"` } -func (c *AggregationCircuit) Define(api frontend.API) error { - // Verify the constraints the execution proofs - err := verifyClaimBatch(api, c.verifyingKeys, c.ProofClaims) +func (c *Circuit) Define(api frontend.API) error { + // match the public input with the public input of the PI circuit + field, err := emulated.NewField[emFr](api) if err != nil { - return fmt.Errorf("processing execution proofs: %w", err) + return err + } + internal.AssertSliceEquals(api, api.ToBinary(c.PublicInput, emFr{}.Modulus().BitLen()), field.ToBitsCanonical(&c.PublicInputWitness.Public[0])) + + vks := append(slices.Clone(c.verifyingKeys), c.publicInputVerifyingKey) + piVkIndex := len(vks) - 1 + + for i := range c.ProofClaims { + api.AssertIsLessOrEqual(c.ProofClaims[i].CircuitID, len(c.verifyingKeys)-1) // make sure the prover can't sneak in an extra PI circuit } - // TODO incorporate statements (circuitID+publicInput?) and vk's into public input + // create a lookup table of actual public inputs + actualPI := make([]*logderivlookup.Table, (emFr{}).NbLimbs()) + for i := range actualPI { + actualPI[i] = logderivlookup.New(api) + } + for _, claim := range c.ProofClaims { + if len(claim.PublicInput.Public) != 1 { + return errors.New("expected 1 public input per decompression/execution circuit") + } + pi := claim.PublicInput.Public[0] + for i := range actualPI { + actualPI[i].Insert(pi.Limbs[i]) + } + } + + if len(c.PublicInputWitnessClaimIndexes)+1 != len(c.PublicInputWitness.Public) { + return errors.New("expected the number of public inputs to match the number of public input witness claim indexes") + } + + // verify that every valid input to the PI circuit is accounted for + for i, actualI := range c.PublicInputWitnessClaimIndexes { + hubPI := &c.PublicInputWitness.Public[i+1] + isNonZero := api.Sub(1, field.IsZero(hubPI)) // if a PI is zero, due to preimage resistance we can infer that the PI circuit is not using it + for j := range actualPI { + internal.AssertEqualIf(api, isNonZero, actualPI[j].Lookup(actualI)[0], hubPI.Limbs[j]) + } + } + + claims := append(slices.Clone(c.ProofClaims), proofClaim{ + CircuitID: piVkIndex, + Proof: c.PublicInputProof, + PublicInput: c.PublicInputWitness, + }) + + // Verify the constraints the execution proofs + if err = verifyClaimBatch(api, vks, claims); err != nil { + return fmt.Errorf("processing execution proofs: %w", err) + } return err } -// Instantiate a new AggregationCircuit from a list of verification keys and +// Instantiate a new Circuit from a list of verification keys and // a maximal number of proofs. The function should only be called with the // purpose of running `frontend.Compile` over it. -func AllocateAggregationCircuit( - nbProofs int, - verifyingKeys []plonk.VerifyingKey, -) (*AggregationCircuit, error) { +func AllocateCircuit(nbProofs int, allowedInputs []string, key plonk.VerifyingKey, verifyingKeys []plonk.VerifyingKey) (*Circuit, error) { var ( err error @@ -84,7 +128,7 @@ func AllocateAggregationCircuit( proofClaims[i] = allocatableClaimPlaceHolder(csPlaceHolder) } - return &AggregationCircuit{ + return &Circuit{ verifyingKeys: emVKeys, ProofClaims: proofClaims, }, nil @@ -121,160 +165,3 @@ func verifyClaimBatch(api frontend.API, vks []emVkey, claims []proofClaim) error } return nil } - -type FunctionalPublicInputQSnark struct { - ParentShnarf [32]frontend.Variable - NbDecompression frontend.Variable - InitialStateRootHash frontend.Variable - InitialBlockNumber frontend.Variable - InitialBlockTimestamp frontend.Variable - InitialRollingHash [32]frontend.Variable - InitialRollingHashNumber frontend.Variable - ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID - L2MessageServiceAddr frontend.Variable // 20 bytes -} - -type FunctionalPublicInputSnark struct { - FunctionalPublicInputQSnark - NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary - L2MsgMerkleTreeRoots [][32]frontend.Variable - // FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf - FinalBlockNumber frontend.Variable - FinalBlockTimestamp frontend.Variable - FinalRollingHash [32]frontend.Variable - FinalRollingHashNumber frontend.Variable - FinalShnarf [32]frontend.Variable - L2MsgMerkleTreeDepth int -} - -// FunctionalPublicInput holds the same info as public_input.Aggregation, except in parsed form -type FunctionalPublicInput struct { - ParentShnarf [32]byte - NbDecompression uint64 - InitialStateRootHash [32]byte - InitialBlockNumber uint64 - InitialBlockTimestamp uint64 - InitialRollingHash [32]byte - InitialRollingHashNumber uint64 - ChainID uint64 // for now we're forcing all executions to have the same chain ID - L2MessageServiceAddr types.EthAddress - NbL2Messages uint64 // TODO not used in hash. delete if not necessary - L2MsgMerkleTreeRoots [][32]byte - //FinalStateRootHash [32]byte redundant: incorporated into shnarf - FinalBlockNumber uint64 - FinalBlockTimestamp uint64 - FinalRollingHash [32]byte - FinalRollingHashNumber uint64 - FinalShnarf [32]byte - L2MsgMerkleTreeDepth int -} - -// NewFunctionalPublicInput does NOT set all fields, only the ones covered in public_input.Aggregation -func NewFunctionalPublicInput(fpi *public_input.Aggregation) (s *FunctionalPublicInput, err error) { - s = &FunctionalPublicInput{ - InitialBlockNumber: uint64(fpi.LastFinalizedBlockNumber), - InitialBlockTimestamp: uint64(fpi.ParentAggregationLastBlockTimestamp), - InitialRollingHashNumber: uint64(fpi.LastFinalizedL1RollingHashMessageNumber), - L2MsgMerkleTreeRoots: make([][32]byte, len(fpi.L2MsgRootHashes)), - FinalBlockNumber: uint64(fpi.FinalBlockNumber), - FinalBlockTimestamp: uint64(fpi.FinalTimestamp), - FinalRollingHashNumber: uint64(fpi.L1RollingHashMessageNumber), - L2MsgMerkleTreeDepth: fpi.L2MsgMerkleTreeDepth, - } - - if err = copyFromHex(s.InitialStateRootHash[:], fpi.ParentStateRootHash); err != nil { - return - } - if err = copyFromHex(s.FinalRollingHash[:], fpi.L1RollingHash); err != nil { - return - } - if err = copyFromHex(s.InitialRollingHash[:], fpi.LastFinalizedL1RollingHash); err != nil { - return - } - if err = copyFromHex(s.ParentShnarf[:], fpi.ParentAggregationFinalShnarf); err != nil { - return - } - if err = copyFromHex(s.FinalShnarf[:], fpi.FinalShnarf); err != nil { - return - } - - for i := range s.L2MsgMerkleTreeRoots { - if err = copyFromHex(s.L2MsgMerkleTreeRoots[i][:], fpi.L2MsgRootHashes[i]); err != nil { - return - } - } - return -} - -func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark { - s := FunctionalPublicInputSnark{ - FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{ - InitialBlockNumber: pi.InitialBlockNumber, - InitialBlockTimestamp: pi.InitialBlockTimestamp, - InitialRollingHash: [32]frontend.Variable{}, - InitialRollingHashNumber: pi.InitialRollingHashNumber, - InitialStateRootHash: pi.InitialStateRootHash[:], - - NbDecompression: pi.NbDecompression, - ChainID: pi.ChainID, - L2MessageServiceAddr: pi.L2MessageServiceAddr[:], - }, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), - FinalBlockNumber: pi.FinalBlockNumber, - FinalBlockTimestamp: pi.FinalBlockTimestamp, - FinalRollingHashNumber: pi.FinalRollingHashNumber, - L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, - } - - internal.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:]) - internal.Copy(s.InitialRollingHash[:], pi.InitialRollingHash[:]) - internal.Copy(s.ParentShnarf[:], pi.ParentShnarf[:]) - internal.Copy(s.FinalShnarf[:], pi.FinalShnarf[:]) - - for i := range s.L2MsgMerkleTreeRoots { - internal.Copy(s.L2MsgMerkleTreeRoots[i][:], pi.L2MsgMerkleTreeRoots[i][:]) - } - - return s -} - -func (pi *FunctionalPublicInputSnark) Sum(api frontend.API, hash keccak.BlockHasher) [32]frontend.Variable { - // number of hashes: 12 - sum := hash.Sum(nil, - pi.ParentShnarf, - pi.FinalShnarf, - internal.ToBytes(api, pi.InitialBlockTimestamp), - internal.ToBytes(api, pi.FinalBlockTimestamp), - internal.ToBytes(api, pi.InitialBlockNumber), - internal.ToBytes(api, pi.FinalBlockNumber), - pi.InitialRollingHash, - pi.FinalRollingHash, - internal.ToBytes(api, pi.InitialRollingHashNumber), - internal.ToBytes(api, pi.FinalRollingHashNumber), - internal.ToBytes(api, pi.L2MsgMerkleTreeDepth), - hash.Sum(nil, pi.L2MsgMerkleTreeRoots...), - ) - - // turn the hash into a bn254 element - var res [32]frontend.Variable - copy(res[:], internal.ReduceBytes[emulated.BN254Fr](api, sum[:])) - return res -} - -func (pi *FunctionalPublicInputQSnark) RangeCheck(api frontend.API) { - rc := rangecheck.New(api) - for _, v := range append(slices.Clone(pi.InitialRollingHash[:]), pi.ParentShnarf[:]...) { - rc.Check(v, 8) - } - // not checking L2MsgServiceAddr as its range is never assumed in the pi circuit - // not checking NbDecompressions as the NewRange in the pi circuit range checks it; TODO do it here instead -} - -func copyFromHex(dst []byte, src string) error { - b, err := utils.HexDecodeString(src) - if err != nil { - return err - } - copy(dst[len(dst)-len(b):], b) // panics if src is too long - return nil -} diff --git a/prover/circuits/aggregation/circuit_test.go b/prover/circuits/aggregation/circuit_test.go index 11bac058..ac03e06e 100644 --- a/prover/circuits/aggregation/circuit_test.go +++ b/prover/circuits/aggregation/circuit_test.go @@ -32,7 +32,7 @@ func TestPublicInput(t *testing.T) { for i := range testCases { - fpi, err := NewFunctionalPublicInput(&testCases[i]) + fpi, err := public_input.NewAggregationFPI(&testCases[i]) assert.NoError(t, err) sfpi := fpi.ToSnarkType() diff --git a/prover/circuits/aggregation/prover.go b/prover/circuits/aggregation/prover.go index fff1ec69..7d65ecb4 100644 --- a/prover/circuits/aggregation/prover.go +++ b/prover/circuits/aggregation/prover.go @@ -2,7 +2,6 @@ package aggregation import ( "fmt" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark/backend/plonk" @@ -17,6 +16,7 @@ func MakeProof( setup *circuits.Setup, maxNbProof int, proofClaims []ProofClaimAssignment, + piProof plonk.Proof, publicInput fr.Element, ) ( plonk.Proof, @@ -27,6 +27,7 @@ func MakeProof( assignment, err := AssignAggregationCircuit( maxNbProof, proofClaims, + piProof, publicInput, ) @@ -44,18 +45,11 @@ func MakeProof( } // Assigns the proof using placeholders -func AssignAggregationCircuit( - maxNbProof int, - proofClaims []ProofClaimAssignment, - publicInput fr.Element, -) ( - c *AggregationCircuit, - err error, -) { +func AssignAggregationCircuit(maxNbProof int, proofClaims []ProofClaimAssignment, publicInputProof plonk.Proof, publicInput fr.Element) (c *Circuit, err error) { - c = &AggregationCircuit{ - ProofClaims: make([]proofClaim, maxNbProof), - DummyPublicInput: publicInput, + c = &Circuit{ + ProofClaims: make([]proofClaim, maxNbProof), + PublicInput: publicInput, } for i := range c.ProofClaims { diff --git a/prover/circuits/blobdecompression/public-input/pi.go b/prover/circuits/blobdecompression/public-input/pi.go index 830966ae..ad6ba5a4 100644 --- a/prover/circuits/blobdecompression/public-input/pi.go +++ b/prover/circuits/blobdecompression/public-input/pi.go @@ -6,7 +6,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/compress" "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/zkevm-monorepo/prover/circuits/internal" + "github.com/consensys/zkevm-monorepo/prover/utils" "math/big" "math/bits" ) @@ -87,7 +87,7 @@ func VerifyBlobConsistency(api frontend.API, blobCrumbs []frontend.Variable, eva } blobEmulated := packCrumbsEmulated(api, blobCrumbs) // perf-TODO use the original blob bytes - evaluationChallengeEmulated := internal.NewElementFromBytes[emulated.BLS12381Fr](api, evaluationChallenge[:]) + evaluationChallengeEmulated := utils.NewElementFromBytes[emulated.BLS12381Fr](api, evaluationChallenge[:]) blobEmulatedBitReversed := make([]*emulated.Element[emulated.BLS12381Fr], len(blobEmulated)) copy(blobEmulatedBitReversed, blobEmulated) diff --git a/prover/circuits/blobdecompression/public-input/pi_test.go b/prover/circuits/blobdecompression/public-input/pi_test.go index 77cbd33a..4dc5a741 100644 --- a/prover/circuits/blobdecompression/public-input/pi_test.go +++ b/prover/circuits/blobdecompression/public-input/pi_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" + "github.com/consensys/zkevm-monorepo/prover/utils" "math/big" "os" "path/filepath" @@ -89,11 +90,11 @@ func TestInterpolateLagrange(t *testing.T) { scalars, err := internal.Bls12381ScalarToBls12377Scalars(evaluationPointFr) assert.NoError(t, err) - internal.Copy(assignment.EvaluationPoint[:], scalars[:]) + utils.Copy(assignment.EvaluationPoint[:], scalars[:]) scalars, err = internal.Bls12381ScalarToBls12377Scalars(evaluation) assert.NoError(t, err) - internal.Copy(assignment.Evaluation[:], scalars[:]) + utils.Copy(assignment.Evaluation[:], scalars[:]) return &assignment } @@ -198,7 +199,7 @@ func decodeHexHL(t *testing.T, s string) (r [2]frontend.Variable) { scalars, err := internal.Bls12381ScalarToBls12377Scalars(b) assert.NoError(t, err) - internal.Copy(r[:], scalars[:]) + utils.Copy(r[:], scalars[:]) return } @@ -238,7 +239,7 @@ func TestVerifyBlobConsistencyIntegration(t *testing.T) { assignment.Eip4844Enabled = 1 } - internal.Copy(assignment.X[:], decodeHex(t, testCase.ExpectedX)) + utils.Copy(assignment.X[:], decodeHex(t, testCase.ExpectedX)) assignment.Y = decodeHexHL(t, testCase.ExpectedY) t.Run(folderAndFile, func(t *testing.T) { @@ -375,7 +376,7 @@ func TestFrConversions(t *testing.T) { assert.Equal(t, tmp, xBack, fmt.Sprintf("out-of-snark conversion round-trip failed on %s or 0x%s", tmp.Text(10), tmp.Text(16))) var assignment testFrConversionCircuit - internal.Copy(assignment.X[:], xPartitioned[:]) + utils.Copy(assignment.X[:], xPartitioned[:]) options = append(options, test.WithValidAssignment(&assignment)) } diff --git a/prover/circuits/blobdecompression/v0/circuit.go b/prover/circuits/blobdecompression/v0/circuit.go index d6825442..6517e19e 100644 --- a/prover/circuits/blobdecompression/v0/circuit.go +++ b/prover/circuits/blobdecompression/v0/circuit.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/consensys/gnark/std/rangecheck" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" + "github.com/consensys/zkevm-monorepo/prover/utils" fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -137,7 +138,7 @@ func (c *Circuit) Define(api frontend.API) error { } } - xBytes := internal.ToBytes(api, c.X[1]) + xBytes := utils.ToBytes(api, c.X[1]) rc := rangecheck.New(api) const nbBitsLower = (fr377.Bits - 1) % 8 rc.Check(xBytes[0], nbBitsLower) diff --git a/prover/circuits/blobdecompression/v0/prelude.go b/prover/circuits/blobdecompression/v0/prelude.go index f7dd15bf..b8916e72 100644 --- a/prover/circuits/blobdecompression/v0/prelude.go +++ b/prover/circuits/blobdecompression/v0/prelude.go @@ -101,7 +101,8 @@ func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Elem // in the Assign function although this creates an unintuitive side effect. // It should be harmless though. lzss.RegisterHints() - internal.RegisterHints() + //internal.RegisterHints() + utils.RegisterHints() a := Circuit{ BlobBytesLen: len(blobDataUnpacked), diff --git a/prover/circuits/blobdecompression/v1/circuit.go b/prover/circuits/blobdecompression/v1/circuit.go index a4baab20..20302da7 100644 --- a/prover/circuits/blobdecompression/v1/circuit.go +++ b/prover/circuits/blobdecompression/v1/circuit.go @@ -18,6 +18,7 @@ import ( test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" "github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc" + "github.com/consensys/zkevm-monorepo/prover/utils" "math/big" blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" @@ -111,11 +112,11 @@ func (i *FunctionalPublicInput) ToSnarkType() (FunctionalPublicInputSnark, error NbBatches: len(i.BatchSums), }, } - internal.Copy(res.X[:], i.X[:]) + utils.Copy(res.X[:], i.X[:]) if len(i.BatchSums) > len(res.BatchSums) { return res, errors.New("batches do not fit in circuit") } - for n := internal.Copy(res.BatchSums[:], i.BatchSums); n < len(res.BatchSums); n++ { + for n := utils.Copy(res.BatchSums[:], i.BatchSums); n < len(res.BatchSums); n++ { res.BatchSums[n] = 0 } return res, nil diff --git a/prover/circuits/execution/pi.go b/prover/circuits/execution/pi.go index 78898bf6..b1afcf11 100644 --- a/prover/circuits/execution/pi.go +++ b/prover/circuits/execution/pi.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/gnark/std/rangecheck" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" "github.com/consensys/zkevm-monorepo/prover/crypto/mimc" + "github.com/consensys/zkevm-monorepo/prover/utils" "github.com/consensys/zkevm-monorepo/prover/utils/types" "hash" "slices" @@ -15,11 +16,11 @@ import ( // FunctionalPublicInputQSnark the information on this execution that cannot be extracted from other input in the same aggregation batch type FunctionalPublicInputQSnark struct { DataChecksum frontend.Variable - L2MessageHashes internal.Var32Slice // TODO range check + L2MessageHashes internal.Var32Slice FinalStateRootHash frontend.Variable FinalBlockNumber frontend.Variable FinalBlockTimestamp frontend.Variable - FinalRollingHash [32]frontend.Variable // TODO range check + FinalRollingHash [32]frontend.Variable FinalRollingHashNumber frontend.Variable } @@ -103,15 +104,14 @@ func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark { ChainID: pi.ChainID, L2MessageServiceAddr: slices.Clone(pi.L2MessageServiceAddr[:]), } - internal.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:]) - internal.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:]) + utils.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:]) + utils.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:]) return res } func (pi *FunctionalPublicInput) Sum() []byte { // all mimc; no need to provide a keccak hasher hsh := mimc.NewMiMC() - // TODO incorporate length too? Technically not necessary var zero [1]byte for i := range pi.L2MessageHashes { diff --git a/prover/circuits/internal/io.go b/prover/circuits/internal/io.go index 9e82522f..e8b4b165 100644 --- a/prover/circuits/internal/io.go +++ b/prover/circuits/internal/io.go @@ -17,7 +17,7 @@ func CopyHexEncodedBytes(dst []frontend.Variable, hex string) error { dst[i] = 0 } - Copy(dst[slack:], b) // This will panic if b is too long + utils.Copy(dst[slack:], b) // This will panic if b is too long return nil } diff --git a/prover/circuits/internal/test_utils/test_utils.go b/prover/circuits/internal/test_utils/test_utils.go index 984e7370..142f2f8f 100644 --- a/prover/circuits/internal/test_utils/test_utils.go +++ b/prover/circuits/internal/test_utils/test_utils.go @@ -180,3 +180,20 @@ func (FakeTestingT) Errorf(format string, args ...interface{}) { func (FakeTestingT) FailNow() { os.Exit(-1) } + +func RandIntN(n int) int { + var b [8]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return int(binary.BigEndian.Uint64(b[:]) % uint64(n)) +} + +func RandIntSliceN(length, n int) []int { + res := make([]int, length) + for i := range res { + res[i] = RandIntN(n) + } + return res +} diff --git a/prover/circuits/internal/utils.go b/prover/circuits/internal/utils.go index f5ef073d..704ee016 100644 --- a/prover/circuits/internal/utils.go +++ b/prover/circuits/internal/utils.go @@ -13,8 +13,8 @@ import ( "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/lookup/logderivlookup" "github.com/consensys/gnark/std/math/emulated" - "github.com/consensys/gnark/std/rangecheck" "github.com/consensys/zkevm-monorepo/prover/circuits/internal/plonk" + "github.com/consensys/zkevm-monorepo/prover/utils" "golang.org/x/exp/constraints" "math/big" "slices" @@ -128,7 +128,7 @@ func (r *Range) LastArray32F(provider func(int) [32]frontend.Variable) [32]front } func RegisterHints() { - hint.RegisterHint(toCrumbsHint, concatHint, decomposeIntoBytesHint, checksumSubSlicesHint) + hint.RegisterHint(toCrumbsHint, concatHint, checksumSubSlicesHint, partitionSliceHint) } func toCrumbsHint(_ *big.Int, ins, outs []*big.Int) error { @@ -605,71 +605,6 @@ func Ite[T any](cond bool, ifSo, ifNot T) T { return ifNot } -func ToBytes(api frontend.API, x frontend.Variable) [32]frontend.Variable { - var res [32]frontend.Variable - d := decomposeIntoBytes(api, x, fr377.Bits) - slack := 32 - len(d) // should be zero - copy(res[slack:], d) - for i := 0; i < slack; i++ { - res[i] = 0 - } - return res -} - -func decomposeIntoBytes(api frontend.API, data frontend.Variable, nbBits int) []frontend.Variable { - if nbBits == 0 { - nbBits = api.Compiler().FieldBitLen() - } - - nbBytes := (api.Compiler().FieldBitLen() + 7) / 8 - - bytes, err := api.Compiler().NewHint(decomposeIntoBytesHint, nbBytes, data) - if err != nil { - panic(err) - } - lastNbBits := nbBits % 8 - if lastNbBits == 0 { - lastNbBits = 8 - } - rc := rangecheck.New(api) - api.AssertIsLessOrEqual(bytes[0], 1<= 0; j-- { - outs[i*nbBytes+j].Mod(&v, &radix) - v.Rsh(&v, 8) - } - if v.Cmp(&zero) != 0 { - return errors.New("not fitting in len(outs)/len(ins) many bytes") - } - } - return nil -} - -func Copy[T any](dst []frontend.Variable, src []T) (n int) { - n = min(len(dst), len(src)) - - for i := 0; i < n; i++ { - dst[i] = src[i] - } - - return -} - // PartialSums returns s[0], s[0]+s[1], ..., s[0]+s[1]+...+s[len(s)-1] func PartialSums(api frontend.API, s []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(s)) @@ -690,46 +625,6 @@ func Differences(api frontend.API, s []frontend.Variable) []frontend.Variable { return res } -// NewElementFromBytes range checks the bytes and gives a reduced field element -func NewElementFromBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) *emulated.Element[T] { - bits := make([]frontend.Variable, 8*len(bytes)) - for i := range bytes { - copy(bits[8*i:], api.ToBinary(bytes[len(bytes)-i-1], 8)) - } - - f, err := emulated.NewField[T](api) - if err != nil { - panic(err) - } - - return f.Reduce(f.Add(f.FromBits(bits...), f.Zero())) -} - -// ReduceBytes reduces given bytes modulo a given field. As a side effect, the "bytes" are range checked -func ReduceBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) []frontend.Variable { - f, err := emulated.NewField[T](api) - if err != nil { - panic(err) - } - - bits := f.ToBits(NewElementFromBytes[T](api, bytes)) - res := make([]frontend.Variable, (len(bits)+7)/8) - copy(bits[:], bits) - for i := len(bits); i < len(bits); i++ { - bits[i] = 0 - } - for i := range res { - bitsStart := 8 * (len(res) - i - 1) - bitsEnd := bitsStart + 8 - if i == 0 { - bitsEnd = len(bits) - } - res[i] = api.FromBinary(bits[bitsStart:bitsEnd]...) - } - - return res -} - func NewSliceOf32Array[T any](values [][32]T, maxLen int) Var32Slice { if maxLen < len(values) { panic("maxLen too small") @@ -739,11 +634,11 @@ func NewSliceOf32Array[T any](values [][32]T, maxLen int) Var32Slice { Length: len(values), } for i := range values { - Copy(res.Values[i][:], values[i][:]) + utils.Copy(res.Values[i][:], values[i][:]) } var zeros [32]byte for i := len(values); i < maxLen; i++ { - Copy(res.Values[i][:], zeros[:]) + utils.Copy(res.Values[i][:], zeros[:]) } return res } @@ -811,3 +706,168 @@ func CloneSlice[T any](s []T, cap ...int) []T { copy(res, s) return res } + +// PartitionSlice populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i +// There are no guarantee on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact. +// It may produce an incorrect result if selectors are out of range +func PartitionSlice(api frontend.API, s []frontend.Variable, selectors []frontend.Variable, subs ...[]frontend.Variable) { + if len(s) != len(selectors) { + panic("s and selectors must have the same length") + } + hintIn := make([]frontend.Variable, 1+len(subs)+len(s)+len(selectors)) + hintIn[0] = len(subs) + hintOutLen := 0 + for i := range subs { + hintIn[1+i] = len(subs[i]) + hintOutLen += len(subs[i]) + } + for i := range s { + hintIn[1+len(subs)+i] = s[i] + hintIn[1+len(subs)+len(s)+i] = selectors[i] + } + subsGlued, err := api.Compiler().NewHint(partitionSliceHint, hintOutLen, hintIn...) + if err != nil { + panic(err) + } + + subsT := make([]*logderivlookup.Table, len(subs)) + for i := range subs { + copy(subs[i], subsGlued[:len(subs[i])]) + subsGlued = subsGlued[len(subs[i]):] + subsT[i] = SliceToTable(api, subs[i]) + subsT[i].Insert(0) + } + + subI := make([]frontend.Variable, len(subs)) + for i := range subI { + subI[i] = 0 + } + + indicators := make([]frontend.Variable, len(subs)) + subHeads := make([]frontend.Variable, len(subs)) + for i := range s { + for j := range subs[:len(subs)-1] { + indicators[j] = api.IsZero(api.Sub(selectors[i], j)) + } + indicators[len(subs)-1] = api.Sub(1, SumSnark(api, indicators[:len(subs)-1]...)) + + for j := range subs { + subHeads[j] = subsT[j].Lookup(subI[j])[0] + subI[j] = api.Add(subI[j], indicators[j]) + } + + api.AssertIsEqual(s[i], InnerProd(api, subHeads, indicators)) + } + + // Check that the dummy trailing values weren't actually picked + for i := range subI { + api.AssertIsDifferent(subI[i], len(subs[i])+1) + } +} + +func SumSnark(api frontend.API, x ...frontend.Variable) frontend.Variable { + res := frontend.Variable(0) + for i := range x { + res = api.Add(res, x[i]) + } + return res +} + +// ins: [nbSubs, maxLen_0, ..., maxLen_{nbSubs-1}, s..., indicators...] +func partitionSliceHint(_ *big.Int, ins, outs []*big.Int) error { + + subs := make([][]*big.Int, ins[0].Uint64()) + for i := range subs { + subs[i] = outs[:ins[1+i].Uint64()] + outs = outs[len(subs[i]):] + } + if len(outs) != 0 { + return errors.New("the sum of subslice max lengths does not equal output length") + } + + ins = ins[1+len(subs):] + + s := ins[:len(ins)/2] + indicators := ins[len(s):] + if len(s) != len(indicators) { + return errors.New("s and indicators must be of the same length") + } + + for i := range s { + b := int(indicators[i].Uint64()) + if b < 0 || b >= len(subs) || !indicators[i].IsUint64() { + return errors.New("indicator out of range") + } + subs[b][0] = s[i] + subs[b] = subs[b][1:] + } + + for i := range subs { + for j := range subs[i] { + subs[i][j].SetInt64(0) + } + } + + return nil +} + +// PartitionSliceEmulated populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i +// There are no guarantee on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact. +// It may produce an incorrect result if selectors are out of range +func PartitionSliceEmulated[T emulated.FieldParams](api frontend.API, s []emulated.Element[T], selectors []frontend.Variable, subSliceMaxLens ...int) [][]emulated.Element[T] { + field, err := emulated.NewField[T](api) + if err != nil { + panic(err) + } + + // transpose limbs for selection + limbs := make([][]frontend.Variable, len(s[0].Limbs)) // limbs are indexed limb first, element second + for i := range limbs { + limbs[i] = make([]frontend.Variable, len(s)) + } + for i := range s { + if len(limbs) != len(s[i].Limbs) { + panic("expected uniform number of limbs") + } + for j := range limbs { + limbs[j][i] = s[i].Limbs[j] + } + } + + subLimbs := make([][][]frontend.Variable, len(limbs)) // subLimbs is indexed limb first, sub-slice second, element third + + for i := range limbs { // construct the sub-slices limb by limb + subLimbs[i] = make([][]frontend.Variable, len(subSliceMaxLens)) + for j := range subSliceMaxLens { + subLimbs[i][j] = make([]frontend.Variable, subSliceMaxLens[j]) + } + + PartitionSlice(api, limbs[i], selectors, subLimbs[i]...) + } + + // put the limbs back together + subSlices := make([][]emulated.Element[T], len(subSliceMaxLens)) + for i := range subSlices { + subSlices[i] = make([]emulated.Element[T], subSliceMaxLens[i]) + for j := range subSlices[i] { + currLimbs := make([]frontend.Variable, len(limbs)) + for k := range currLimbs { + currLimbs[k] = subLimbs[k][i][j] + } + subSlices[i][j] = *field.NewElement(currLimbs) // TODO make sure dereferencing is not problematic + } + } + + return subSlices +} + +func InnerProd(api frontend.API, x, y []frontend.Variable) frontend.Variable { + if len(x) != len(y) { + panic("mismatched lengths") + } + res := frontend.Variable(0) + for i := range x { + res = api.Add(res, api.Mul(x[i], y[i])) + } + return res +} diff --git a/prover/circuits/internal/utils_test.go b/prover/circuits/internal/utils_test.go index 13a4faf7..d203a95a 100644 --- a/prover/circuits/internal/utils_test.go +++ b/prover/circuits/internal/utils_test.go @@ -3,6 +3,8 @@ package internal_test import ( "crypto/rand" "fmt" + fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash/mimc" @@ -55,7 +57,7 @@ func testChecksumSubSlices(t *testing.T, bigSliceLength, lengthsSliceLength int, } endPointsSnark := make([]frontend.Variable, lengthsSliceLength) - for n := internal.Copy(endPointsSnark, internal.PartialSumsInt(lengths)); n < lengthsSliceLength; n++ { + for n := utils.Copy(endPointsSnark, internal.PartialSumsInt(lengths)); n < lengthsSliceLength; n++ { endPointsSnark[n] = n * 234 } @@ -103,7 +105,7 @@ func TestReduceBytes(t *testing.T) { test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable { for i := range cases { - got := internal.ReduceBytes[emulated.BN254Fr](api, test_vector_utils.ToVariableSlice(cases[i])) + got := utils.ReduceBytes[emulated.BN254Fr](api, test_vector_utils.ToVariableSlice(cases[i])) internal.AssertSliceEquals(api, got, test_vector_utils.ToVariableSlice(reduced[i]), @@ -113,3 +115,113 @@ func TestReduceBytes(t *testing.T) { return nil })(t) } + +func TestPartitionSliceEmulated(t *testing.T) { + selectors := []int{1, 0, 2, 2, 1} + + s := make([]fr381.Element, len(selectors)) + for i := range s { + _, err := s[i].SetRandom() + assert.NoError(t, err) + } + + subs := make([][]fr381.Element, 3) + for i := range subs { + subs[i] = make([]fr381.Element, 0, len(selectors)-1) + } + + for i := range s { + subs[selectors[i]] = append(subs[selectors[i]], s[i]) + } + + test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable { + + field, err := emulated.NewField[emulated.BLS12381Fr](api) + assert.NoError(t, err) + + // convert randomized elements to emulated + sEmulated := elementsToEmulated(field, s) + + subsEmulatedExpected := internal.MapSlice(func(s []fr381.Element) []emulated.Element[emulated.BLS12381Fr] { + return elementsToEmulated(field, append(s, make([]fr381.Element, cap(s)-len(s))...)) // pad with zeros to see if padding is done correctly + }, subs...) + + subsEmulated := internal.PartitionSliceEmulated(api, sEmulated, test_vector_utils.ToVariableSlice(selectors), internal.MapSlice(func(s []fr381.Element) int { return cap(s) }, subs...)...) + + assert.Equal(t, len(subsEmulatedExpected), len(subsEmulated)) + for i := range subsEmulated { + assert.Equal(t, len(subsEmulatedExpected[i]), len(subsEmulated[i])) + for j := range subsEmulated[i] { + field.AssertIsEqual(&subsEmulated[i][j], &subsEmulatedExpected[i][j]) + } + } + + return nil + })(t) +} + +func elementsToEmulated(field *emulated.Field[emulated.BLS12381Fr], s []fr381.Element) []emulated.Element[emulated.BLS12381Fr] { + return internal.MapSlice(func(element fr381.Element) emulated.Element[emulated.BLS12381Fr] { + return *field.NewElement(internal.MapSlice(func(x uint64) frontend.Variable { return x }, element[:]...)) + }, s...) +} + +func TestPartitionSlice(t *testing.T) { + const ( + nbSubs = 3 + sliceLen = 10 + ) + + test := func(slice []frontend.Variable, selectors []int, subsSlack []int) func(*testing.T) { + assert.Equal(t, len(selectors), len(slice)) + assert.Equal(t, len(subsSlack), nbSubs) + + subs := make([][]frontend.Variable, nbSubs) + for j := range subs { + subs[j] = make([]frontend.Variable, 0, sliceLen) + } + + for j := range slice { + subs[selectors[j]] = append(subs[selectors[j]], slice[j]) + } + + for j := range subs { + subs[j] = append(subs[j], test_vector_utils.ToVariableSlice(make([]int, subsSlack[j]))...) // add some padding + } + + return test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable { + + slice := test_vector_utils.ToVariableSlice(slice) + + subsEncountered := internal.MapSlice(func(s []frontend.Variable) []frontend.Variable { return make([]frontend.Variable, len(s)) }, subs...) + internal.PartitionSlice(api, slice, test_vector_utils.ToVariableSlice(selectors), subsEncountered...) + + assert.Equal(t, len(subs), len(subsEncountered)) + for j := range subsEncountered { + internal.AssertSliceEquals(api, subsEncountered[j], subs[j]) + } + + return nil + }) + } + + test([]frontend.Variable{5}, []int{2}, []int{1, 0, 0})(t) + test([]frontend.Variable{1, 2, 3}, []int{0, 1, 2}, []int{0, 0, 0}) + test(test_vector_utils.ToVariableSlice(test_utils.Range[int](10)), []int{0, 1, 2, 0, 0, 0, 1, 1, 1, 2}, []int{0, 0, 0}) + + for i := 0; i < 200; i++ { + + slice := make([]frontend.Variable, sliceLen) + for j := range slice { + var x fr377.Element + _, err := x.SetRandom() + slice[j] = &x + assert.NoError(t, err) + } + + selectors := test_utils.RandIntSliceN(sliceLen, nbSubs) + subsSlack := test_utils.RandIntSliceN(nbSubs, 2) + + t.Run(fmt.Sprintf("selectors=%v,slack=%v", selectors, subsSlack), test(slice, selectors, subsSlack)) + } +} diff --git a/prover/circuits/pi-interconnection/assign.go b/prover/circuits/pi-interconnection/assign.go index 36827a86..aa31773b 100644 --- a/prover/circuits/pi-interconnection/assign.go +++ b/prover/circuits/pi-interconnection/assign.go @@ -6,7 +6,6 @@ import ( "errors" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/zkevm-monorepo/prover/backend/blobsubmission" - "github.com/consensys/zkevm-monorepo/prover/circuits/aggregation" decompression "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v1" "github.com/consensys/zkevm-monorepo/prover/circuits/execution" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" @@ -18,31 +17,22 @@ import ( "hash" ) -type ExecutionRequest struct { - L2MsgHashes [][32]byte - FinalStateRootHash [32]byte - FinalBlockNumber uint64 - FinalBlockTimestamp uint64 - FinalRollingHash [32]byte - FinalRollingHashNumber uint64 -} - type Request struct { - DecompDict []byte Decompressions []blobsubmission.Response - Executions []ExecutionRequest + Executions []public_input.Execution Aggregation public_input.Aggregation } func (c *Compiled) Assign(r Request) (a Circuit, err error) { internal.RegisterHints() keccak.RegisterHints() + utils.RegisterHints() // TODO there is data duplication in the request. Check consistency // infer config config := c.getConfig() - a = config.allocateCircuit() + a = allocateCircuit(config) if len(r.Decompressions) > config.MaxNbDecompression { err = errors.New("number of decompression proofs exceeds maximum") @@ -69,7 +59,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { if err != nil { return } - internal.Copy(a.ParentShnarf[:], prevShnarf) + utils.Copy(a.ParentShnarf[:], prevShnarf) execDataChecksums := make([][]byte, 0, len(r.Executions)) shnarfs := make([][]byte, config.MaxNbDecompression) @@ -156,7 +146,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { } else { a.DecompressionFPIQ[i] = fpis.FunctionalPublicInputQSnark } - internal.Copy(a.DecompressionFPIQ[i].X[:], zero[:]) + utils.Copy(a.DecompressionFPIQ[i].X[:], zero[:]) if a.DecompressionPublicInput[i], err = fpi.Sum(decompression.WithBatchesSum(zero[:])); err != nil { // TODO zero batches sum is probably incorrect return } @@ -164,7 +154,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { // Aggregation FPI - aggregationFPI, err := aggregation.NewFunctionalPublicInput(&r.Aggregation) + aggregationFPI, err := public_input.NewAggregationFPI(&r.Aggregation) if err != nil { return } @@ -173,10 +163,10 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { return } aggregationFPI.NbDecompression = uint64(len(r.Decompressions)) - a.FunctionalPublicInputQSnark = aggregationFPI.ToSnarkType().FunctionalPublicInputQSnark + a.AggregationFPIQSnark = aggregationFPI.ToSnarkType().AggregationFPIQSnark merkleNbLeaves := 1 << config.L2MsgMerkleDepth - maxNbL2MessageHashes := config.L2MessageMaxNbMerkle * merkleNbLeaves + maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes) // Execution FPI executionFPI := execution.FunctionalPublicInput{ @@ -187,7 +177,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { FinalRollingHashNumber: aggregationFPI.InitialRollingHashNumber, L2MessageServiceAddr: aggregationFPI.L2MessageServiceAddr, ChainID: aggregationFPI.ChainID, - MaxNbL2MessageHashes: config.MaxNbMsgPerExecution, + MaxNbL2MessageHashes: config.ExecutionMaxNbMsg, } for i := range a.ExecutionFPIQ { executionFPI.InitialRollingHash = executionFPI.FinalRollingHash @@ -251,16 +241,16 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { } // pad the merkle roots - if len(r.Aggregation.L2MsgRootHashes) > config.L2MessageMaxNbMerkle { + if len(r.Aggregation.L2MsgRootHashes) > config.L2MsgMaxNbMerkle { err = errors.New("more merkle trees than there is capacity") return } { - roots := internal.CloneSlice(r.Aggregation.L2MsgRootHashes, config.L2MessageMaxNbMerkle) + roots := internal.CloneSlice(r.Aggregation.L2MsgRootHashes, config.L2MsgMaxNbMerkle) emptyRootHex := utils.HexEncodeToString(emptyTree[len(emptyTree)-1][:32]) - for i := len(r.Aggregation.L2MsgRootHashes); i < config.L2MessageMaxNbMerkle; i++ { + for i := len(r.Aggregation.L2MsgRootHashes); i < config.L2MsgMaxNbMerkle; i++ { for depth := config.L2MsgMerkleDepth; depth > 0; depth-- { for j := 0; j < 1<<(depth-1); j++ { hshK.Skip(emptyTree[config.L2MsgMerkleDepth-depth]) diff --git a/prover/circuits/pi-interconnection/bench/main.go b/prover/circuits/pi-interconnection/bench/main.go index 08cdcc67..ca2289d5 100644 --- a/prover/circuits/pi-interconnection/bench/main.go +++ b/prover/circuits/pi-interconnection/bench/main.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils" pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" pitesting "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/test_utils" + "github.com/consensys/zkevm-monorepo/prover/config" "github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc" "time" @@ -23,14 +24,14 @@ func main() { var b test_utils.FakeTestingT req := pitesting.AssignSingleBlockBlob(b) - c, err := pi_interconnection.Config{ - MaxNbDecompression: 400, - MaxNbExecution: 400, - MaxNbKeccakF: 10000, - MaxNbMsgPerExecution: 16, - L2MsgMerkleDepth: 5, - L2MessageMaxNbMerkle: 10, - }.Compile(dummy.Compile) // note that the solving/proving time will not reflect the wizard proof or verification + c, err := pi_interconnection.Compile(config.PublicInput{ + MaxNbDecompression: 400, + MaxNbExecution: 400, + MaxNbKeccakF: 10000, + ExecutionMaxNbMsg: 16, + L2MsgMerkleDepth: 5, + L2MsgMaxNbMerkle: 10, + }, dummy.Compile) // note that the solving/proving time will not reflect the wizard proof or verification assert.NoError(b, err) a, err := c.Assign(req) diff --git a/prover/circuits/pi-interconnection/circuit.go b/prover/circuits/pi-interconnection/circuit.go index 02edb970..fa621221 100644 --- a/prover/circuits/pi-interconnection/circuit.go +++ b/prover/circuits/pi-interconnection/circuit.go @@ -2,6 +2,13 @@ package pi_interconnection import ( "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/zkevm-monorepo/prover/circuits" + "github.com/consensys/zkevm-monorepo/prover/config" + public_input "github.com/consensys/zkevm-monorepo/prover/public-input" "math/big" "slices" @@ -11,7 +18,6 @@ import ( "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/lookup/logderivlookup" "github.com/consensys/gnark/std/math/cmp" - "github.com/consensys/zkevm-monorepo/prover/circuits/aggregation" decompression "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v1" "github.com/consensys/zkevm-monorepo/prover/circuits/execution" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" @@ -23,12 +29,13 @@ import ( type Circuit struct { AggregationPublicInput frontend.Variable `gnark:",public"` // the public input of the aggregation circuit - DecompressionPublicInput []frontend.Variable `gnark:",public"` ExecutionPublicInput []frontend.Variable `gnark:",public"` + DecompressionPublicInput []frontend.Variable `gnark:",public"` DecompressionFPIQ []decompression.FunctionalPublicInputQSnark ExecutionFPIQ []execution.FunctionalPublicInputQSnark - aggregation.FunctionalPublicInputQSnark + + public_input.AggregationFPIQSnark Keccak keccak.StrictHasherCircuit @@ -36,7 +43,7 @@ type Circuit struct { L2MessageMerkleDepth int L2MessageMaxNbMerkle int - MaxNbCircuits int + MaxNbCircuits int // possibly useless TODO consider removing UseGkrMimc bool } @@ -46,7 +53,7 @@ func (c *Circuit) Define(api frontend.API) error { return errors.New("public / functional public input length mismatch") } - c.FunctionalPublicInputQSnark.RangeCheck(api) + c.AggregationFPIQSnark.RangeCheck(api) rDecompression := internal.NewRange(api, c.NbDecompression, len(c.DecompressionPublicInput)) hshK := c.Keccak.NewHasher(api) @@ -90,8 +97,8 @@ func (c *Circuit) Define(api frontend.API) error { piq.RangeCheck(api) shnarfParams[i] = ShnarfIteration{ // prepare shnarf verification data - BlobDataSnarkHash: internal.ToBytes(api, piq.SnarkHash), - NewStateRootHash: internal.ToBytes(api, finalStateRootHashes.Lookup(api.Sub(nbBatchesSums[i], 1))[0]), + BlobDataSnarkHash: utils.ToBytes(api, piq.SnarkHash), + NewStateRootHash: utils.ToBytes(api, finalStateRootHashes.Lookup(api.Sub(nbBatchesSums[i], 1))[0]), EvaluationPointBytes: piq.X, EvaluationClaimBytes: fr377EncodedFr381ToBytes(api, piq.Y), } @@ -162,20 +169,20 @@ func (c *Circuit) Define(api frontend.API) error { } rExecution := internal.NewRange(api, nbExecution, maxNbExecution) - pi := aggregation.FunctionalPublicInputSnark{ - FunctionalPublicInputQSnark: c.FunctionalPublicInputQSnark, - NbL2Messages: merkleLeavesConcat.Length, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), - FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), - FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), - FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }), - FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }), - FinalShnarf: rDecompression.LastArray32(shnarfs), - L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, + pi := public_input.AggregationFPISnark{ + AggregationFPIQSnark: c.AggregationFPIQSnark, + NbL2Messages: merkleLeavesConcat.Length, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), + FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), + FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), + FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }), + FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }), + FinalShnarf: rDecompression.LastArray32(shnarfs), + L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, } for i := range pi.L2MsgMerkleTreeRoots { - pi.L2MsgMerkleTreeRoots[i] = merkleRoot(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves]) + pi.L2MsgMerkleTreeRoots[i] = MerkleRootSnark(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves]) } // "open" aggregation public input @@ -185,7 +192,7 @@ func (c *Circuit) Define(api frontend.API) error { return hshK.Finalize() } -func merkleRoot(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]frontend.Variable { +func MerkleRootSnark(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]frontend.Variable { values := slices.Clone(leaves) if !utils.IsPowerOfTwo(len(values)) { @@ -201,38 +208,28 @@ func merkleRoot(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]fro return values[0] } -type Config struct { - MaxNbDecompression int - MaxNbExecution int - MaxNbCircuits int - MaxNbKeccakF int - MaxNbMsgPerExecution int - L2MsgMerkleDepth int - L2MessageMaxNbMerkle int // if not explicitly provided (i.e. non-positive) it will be set to maximum -} - type Compiled struct { Circuit *Circuit Keccak keccak.CompiledStrictHasher } -func (c Config) Compile(wizardCompilationOpts ...func(iop *wizard.CompiledIOP)) (*Compiled, error) { +func Compile(c config.PublicInput, wizardCompilationOpts ...func(iop *wizard.CompiledIOP)) (*Compiled, error) { - if c.L2MessageMaxNbMerkle <= 0 { + if c.L2MsgMaxNbMerkle <= 0 { merkleNbLeaves := 1 << c.L2MsgMerkleDepth - c.L2MessageMaxNbMerkle = (c.MaxNbExecution*c.MaxNbMsgPerExecution + merkleNbLeaves - 1) / merkleNbLeaves + c.L2MsgMaxNbMerkle = (c.MaxNbExecution*c.ExecutionMaxNbMsg + merkleNbLeaves - 1) / merkleNbLeaves } - sh := c.newKeccakCompiler().Compile(c.MaxNbKeccakF, wizardCompilationOpts...) + sh := newKeccakCompiler(c).Compile(c.MaxNbKeccakF, wizardCompilationOpts...) shc, err := sh.GetCircuit() if err != nil { return nil, err } - circuit := c.allocateCircuit() + circuit := allocateCircuit(c) circuit.Keccak = shc for i := range circuit.ExecutionFPIQ { - circuit.ExecutionFPIQ[i].L2MessageHashes.Values = make([][32]frontend.Variable, c.MaxNbMsgPerExecution) + circuit.ExecutionFPIQ[i].L2MessageHashes.Values = make([][32]frontend.Variable, c.ExecutionMaxNbMsg) } return &Compiled{ @@ -241,33 +238,33 @@ func (c Config) Compile(wizardCompilationOpts ...func(iop *wizard.CompiledIOP)) }, nil } -func (c *Compiled) getConfig() Config { - return Config{ - MaxNbDecompression: len(c.Circuit.DecompressionFPIQ), - MaxNbExecution: len(c.Circuit.ExecutionFPIQ), - MaxNbKeccakF: c.Keccak.MaxNbKeccakF(), - MaxNbMsgPerExecution: len(c.Circuit.ExecutionFPIQ[0].L2MessageHashes.Values), - L2MsgMerkleDepth: c.Circuit.L2MessageMerkleDepth, - L2MessageMaxNbMerkle: c.Circuit.L2MessageMaxNbMerkle, - MaxNbCircuits: c.Circuit.MaxNbCircuits, +func (c *Compiled) getConfig() config.PublicInput { + return config.PublicInput{ + MaxNbDecompression: len(c.Circuit.DecompressionFPIQ), + MaxNbExecution: len(c.Circuit.ExecutionFPIQ), + MaxNbKeccakF: c.Keccak.MaxNbKeccakF(), + ExecutionMaxNbMsg: len(c.Circuit.ExecutionFPIQ[0].L2MessageHashes.Values), + L2MsgMerkleDepth: c.Circuit.L2MessageMerkleDepth, + L2MsgMaxNbMerkle: c.Circuit.L2MessageMaxNbMerkle, + MaxNbCircuits: c.Circuit.MaxNbCircuits, } } -func (c Config) allocateCircuit() Circuit { +func allocateCircuit(c config.PublicInput) Circuit { return Circuit{ DecompressionPublicInput: make([]frontend.Variable, c.MaxNbDecompression), ExecutionPublicInput: make([]frontend.Variable, c.MaxNbExecution), DecompressionFPIQ: make([]decompression.FunctionalPublicInputQSnark, c.MaxNbDecompression), ExecutionFPIQ: make([]execution.FunctionalPublicInputQSnark, c.MaxNbExecution), L2MessageMerkleDepth: c.L2MsgMerkleDepth, - L2MessageMaxNbMerkle: c.L2MessageMaxNbMerkle, + L2MessageMaxNbMerkle: c.L2MsgMaxNbMerkle, MaxNbCircuits: c.MaxNbCircuits, } } -func (c Config) newKeccakCompiler() *keccak.StrictHasherCompiler { +func newKeccakCompiler(c config.PublicInput) *keccak.StrictHasherCompiler { nbShnarf := c.MaxNbDecompression - nbMerkle := c.L2MessageMaxNbMerkle * ((1 << c.L2MsgMerkleDepth) - 1) + nbMerkle := c.L2MsgMaxNbMerkle * ((1 << c.L2MsgMerkleDepth) - 1) res := keccak.NewStrictHasherCompiler(nbShnarf, nbMerkle, 2) for i := 0; i < nbShnarf; i++ { res.WithHashLengths(160) // 5 components in every shnarf @@ -278,8 +275,36 @@ func (c Config) newKeccakCompiler() *keccak.StrictHasherCompiler { } // aggregation PI opening - res.WithHashLengths(32 * c.L2MessageMaxNbMerkle) + res.WithHashLengths(32 * c.L2MsgMaxNbMerkle) res.WithHashLengths(384) return &res } + +type builder struct { + *config.PublicInput +} + +func NewBuilder(c config.PublicInput) circuits.Builder { + return builder{&c} +} + +func (b builder) Compile() (constraint.ConstraintSystem, error) { + c, err := Compile(*b.PublicInput, WizardCompilationParameters()...) + if err != nil { + return nil, err + } + const estimatedNbConstraints = 35_000_000 + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, c.Circuit, frontend.WithCapacity(estimatedNbConstraints)) + if err != nil { + return nil, err + } + if nbC := cs.GetNbConstraints(); nbC > estimatedNbConstraints || estimatedNbConstraints-nbC > 5_000_000 { + return nil, fmt.Errorf("constraint estimate is off; got %d", nbC) + } + return cs, nil +} + +func WizardCompilationParameters() []func(iop *wizard.CompiledIOP) { + panic("implement me") +} diff --git a/prover/circuits/pi-interconnection/circuit_test.go b/prover/circuits/pi-interconnection/circuit_test.go index dc6c4eab..c9fc48f3 100644 --- a/prover/circuits/pi-interconnection/circuit_test.go +++ b/prover/circuits/pi-interconnection/circuit_test.go @@ -1,4 +1,4 @@ -package pi_interconnection +package pi_interconnection_test import ( "errors" @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark/test" "github.com/consensys/zkevm-monorepo/prover/backend/aggregation" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" + pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak" "github.com/consensys/zkevm-monorepo/prover/utils" "github.com/stretchr/testify/assert" @@ -38,7 +39,7 @@ func TestMerkle(t *testing.T) { assert.NoError(t, err) toHashBytes[i] = x.Bytes() toHashHex[i] = utils.HexEncodeToString(toHashBytes[i][:]) - assert.Equal(t, 32, internal.Copy(toHashSnark[i][:], toHashBytes[i][:])) + assert.Equal(t, 32, utils.Copy(toHashSnark[i][:], toHashBytes[i][:])) } for i := len(toHashHex); i < len(toHashSnark); i++ { // pad with zeros for j := range toHashSnark[i] { @@ -49,7 +50,7 @@ func TestMerkle(t *testing.T) { hsh := sha3.NewLegacyKeccak256() for i := range rootsHex { - root := MerkleRoot(hsh, c.nbLeaves, toHashBytes[i*c.nbLeaves:min(len(toHashBytes), (i+1)*c.nbLeaves)]) + root := pi_interconnection.MerkleRoot(hsh, c.nbLeaves, toHashBytes[i*c.nbLeaves:min(len(toHashBytes), (i+1)*c.nbLeaves)]) rootHex := utils.HexEncodeToString(root[:]) assert.Equal(t, rootsHex[i], rootHex) } @@ -90,7 +91,7 @@ func (c *testMerkleCircuit) Define(api frontend.API) error { } for i := range c.Roots { - root := merkleRoot(hshK, c.ToHash[i*nbLeaves:(i+1)*nbLeaves]) + root := pi_interconnection.MerkleRootSnark(hshK, c.ToHash[i*nbLeaves:(i+1)*nbLeaves]) internal.AssertSliceEquals(api, c.Roots[i][:], root[:]) } diff --git a/prover/circuits/pi-interconnection/compile/test_compile.go b/prover/circuits/pi-interconnection/compile/test_compile.go index c1ff5c76..0d3588dd 100644 --- a/prover/circuits/pi-interconnection/compile/test_compile.go +++ b/prover/circuits/pi-interconnection/compile/test_compile.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils" pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" + "github.com/consensys/zkevm-monorepo/prover/config" "github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy" "github.com/stretchr/testify/assert" ) @@ -16,14 +17,14 @@ func main() { fmt.Println("creating wizard circuit") - c, err := pi_interconnection.Config{ - MaxNbDecompression: 400, - MaxNbExecution: 400, - MaxNbKeccakF: 10000, - MaxNbMsgPerExecution: 16, - L2MsgMerkleDepth: 5, - L2MessageMaxNbMerkle: 10, - }.Compile(dummy.Compile) + c, err := pi_interconnection.Compile(config.PublicInput{ + MaxNbDecompression: 400, + MaxNbExecution: 400, + MaxNbKeccakF: 10000, + ExecutionMaxNbMsg: 16, + L2MsgMerkleDepth: 5, + L2MsgMaxNbMerkle: 10, + }, dummy.Compile) var t test_utils.FakeTestingT assert.NoError(t, err) diff --git a/prover/circuits/pi-interconnection/e2e_test.go b/prover/circuits/pi-interconnection/e2e_test.go index e5ceaf1b..acb5abe4 100644 --- a/prover/circuits/pi-interconnection/e2e_test.go +++ b/prover/circuits/pi-interconnection/e2e_test.go @@ -17,6 +17,7 @@ import ( "github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils" pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" pitesting "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/test_utils" + "github.com/consensys/zkevm-monorepo/prover/config" "github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc" blobtesting "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1/test_utils" "github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy" @@ -34,15 +35,15 @@ func TestSingleBlockBlob(t *testing.T) { func TestSingleBlobBlobE2E(t *testing.T) { req := pitesting.AssignSingleBlockBlob(t) - config := pi_interconnection.Config{ - MaxNbDecompression: len(req.Decompressions), - MaxNbExecution: len(req.Executions), - MaxNbKeccakF: 100, - MaxNbMsgPerExecution: 1, - L2MsgMerkleDepth: 5, - L2MessageMaxNbMerkle: 1, + cfg := config.PublicInput{ + MaxNbDecompression: len(req.Decompressions), + MaxNbExecution: len(req.Executions), + MaxNbKeccakF: 100, + ExecutionMaxNbMsg: 1, + L2MsgMerkleDepth: 5, + L2MsgMaxNbMerkle: 1, } - compiled, err := config.Compile(dummy.Compile) + compiled, err := pi_interconnection.Compile(cfg, dummy.Compile) assert.NoError(t, err) a, err := compiled.Assign(req) @@ -74,7 +75,7 @@ func TestTinyTwoBatchBlob(t *testing.T) { blob := blobtesting.TinyTwoBatchBlob(t) - execReq := []pi_interconnection.ExecutionRequest{{ + execReq := []public_input.Execution{{ L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(3)}, FinalStateRootHash: internal.Uint64To32Bytes(4), FinalBlockNumber: 5, @@ -104,7 +105,6 @@ func TestTinyTwoBatchBlob(t *testing.T) { merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes)) req := pi_interconnection.Request{ - DecompDict: blobtesting.GetDict(t), Decompressions: []blobsubmission.Response{*blobResp}, Executions: execReq, Aggregation: public_input.Aggregation{ @@ -131,7 +131,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) { t.Skipf("Flacky test due to the number of keccakf outgoing the limit specified for the test") blobs := blobtesting.ConsecutiveBlobs(t, 2, 2) - execReq := []pi_interconnection.ExecutionRequest{{ + execReq := []public_input.Execution{{ L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(3)}, FinalStateRootHash: internal.Uint64To32Bytes(4), FinalBlockNumber: 5, @@ -186,7 +186,6 @@ func TestTwoTwoBatchBlobs(t *testing.T) { merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes)) req := pi_interconnection.Request{ - DecompDict: blobtesting.GetDict(t), Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1}, Executions: execReq, Aggregation: public_input.Aggregation{ @@ -215,17 +214,17 @@ func testPI(t *testing.T, maxNbKeccakF int, req pi_interconnection.Request) { decomposeLittleEndian(t, slack[:], i, 3) - config := pi_interconnection.Config{ - MaxNbDecompression: len(req.Decompressions) + slack[0], - MaxNbExecution: len(req.Executions) + slack[1], - MaxNbKeccakF: maxNbKeccakF, - MaxNbMsgPerExecution: 1 + slack[2], - L2MsgMerkleDepth: 5, - L2MessageMaxNbMerkle: 1 + slack[3], + config := config.PublicInput{ + MaxNbDecompression: len(req.Decompressions) + slack[0], + MaxNbExecution: len(req.Executions) + slack[1], + MaxNbKeccakF: maxNbKeccakF, + ExecutionMaxNbMsg: 1 + slack[2], + L2MsgMerkleDepth: 5, + L2MsgMaxNbMerkle: 1 + slack[3], } t.Run(fmt.Sprintf("slack profile %v", slack), func(t *testing.T) { - compiled, err := config.Compile(dummy.Compile) + compiled, err := pi_interconnection.Compile(config, dummy.Compile) assert.NoError(t, err) a, err := compiled.Assign(req) diff --git a/prover/circuits/pi-interconnection/keccak/assign_test.go b/prover/circuits/pi-interconnection/keccak/assign_test.go index f836ce64..b11a9f61 100644 --- a/prover/circuits/pi-interconnection/keccak/assign_test.go +++ b/prover/circuits/pi-interconnection/keccak/assign_test.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark/test" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" "github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/zkevm-monorepo/prover/utils" "github.com/stretchr/testify/assert" "testing" ) @@ -36,8 +37,8 @@ func TestAssign(t *testing.T) { } assignment.H, err = hsh.Assign() assert.NoError(t, err) - internal.Copy(assignment.Outs[0][:], res) - internal.Copy(assignment.Ins[0][0][:], zero[:]) + utils.Copy(assignment.Outs[0][:], res) + utils.Copy(assignment.Ins[0][0][:], zero[:]) assert.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) } diff --git a/prover/circuits/pi-interconnection/test_utils/test_utils.go b/prover/circuits/pi-interconnection/test_utils/test_utils.go index 9d29c124..cf86e28e 100644 --- a/prover/circuits/pi-interconnection/test_utils/test_utils.go +++ b/prover/circuits/pi-interconnection/test_utils/test_utils.go @@ -31,7 +31,7 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request { blobResp, err := blobsubmission.CraftResponse(&blobReq) assert.NoError(t, err) - execReq := pi_interconnection.ExecutionRequest{ + execReq := public_input.Execution{ L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(4)}, FinalStateRootHash: finalStateRootHash, FinalBlockNumber: 9, @@ -43,9 +43,8 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request { merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq.L2MsgHashes)) return pi_interconnection.Request{ - DecompDict: blobtesting.GetDict(t), Decompressions: []blobsubmission.Response{*blobResp}, - Executions: []pi_interconnection.ExecutionRequest{execReq}, + Executions: []public_input.Execution{execReq}, Aggregation: public_input.Aggregation{ FinalShnarf: blobResp.ExpectedShnarf, ParentAggregationFinalShnarf: blobReq.PrevShnarf, diff --git a/prover/circuits/registry.go b/prover/circuits/registry.go index ead0710f..e19e36c7 100644 --- a/prover/circuits/registry.go +++ b/prover/circuits/registry.go @@ -1,19 +1,20 @@ package circuits // CircuitID is a type to represent the different circuits. -// Is is used to identify the circuit to be used in the prover. +// It is used to identify the circuit to be used in the prover. type CircuitID string const ( - ExecutionCircuitID CircuitID = "execution" - ExecutionLargeCircuitID CircuitID = "execution-large" - BlobDecompressionV0CircuitID CircuitID = "blob-decompression-v0" - BlobDecompressionV1CircuitID CircuitID = "blob-decompression-v1" - AggregationCircuitID CircuitID = "aggregation" - EmulationCircuitID CircuitID = "emulation" - EmulationDummyCircuitID CircuitID = "emulation-dummy" - ExecutionDummyCircuitID CircuitID = "execution-dummy" - BlobDecompressionDummyCircuitID CircuitID = "blob-decompression-dummy" + ExecutionCircuitID CircuitID = "execution" + ExecutionLargeCircuitID CircuitID = "execution-large" + BlobDecompressionV0CircuitID CircuitID = "blob-decompression-v0" + BlobDecompressionV1CircuitID CircuitID = "blob-decompression-v1" + AggregationCircuitID CircuitID = "aggregation" + EmulationCircuitID CircuitID = "emulation" + EmulationDummyCircuitID CircuitID = "emulation-dummy" + ExecutionDummyCircuitID CircuitID = "execution-dummy" + BlobDecompressionDummyCircuitID CircuitID = "blob-decompression-dummy" + PublicInputInterconnectionCircuitID CircuitID = "public-input-interconnection" ) // MockCircuitID is a type to represent the different mock circuits. diff --git a/prover/circuits/setup.go b/prover/circuits/setup.go index bce79641..81fe4407 100644 --- a/prover/circuits/setup.go +++ b/prover/circuits/setup.go @@ -89,7 +89,7 @@ func (s *Setup) CurveID() ecc.ID { return fieldToCurve(s.Circuit.Field()) } -func (s *Setup) VerifiyingKeyDigest() string { +func (s *Setup) VerifyingKeyDigest() string { r, err := objectChecksum(s.VerifyingKey) if err != nil { utils.Panic("could not get the verifying key digest: %v", err) diff --git a/prover/cmd/prover/cmd/setup.go b/prover/cmd/prover/cmd/setup.go index 8f5d9538..84777d14 100644 --- a/prover/cmd/prover/cmd/setup.go +++ b/prover/cmd/prover/cmd/setup.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "fmt" + pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection" "io" "os" "path/filepath" @@ -52,6 +53,7 @@ var allCircuits = []string{ string(circuits.AggregationCircuitID), string(circuits.EmulationCircuitID), string(circuits.EmulationDummyCircuitID), // we want to generate Verifier.sol for this one + string(circuits.PublicInputInterconnectionCircuitID), } func init() { @@ -102,7 +104,7 @@ func cmdSetup(cmd *cobra.Command, args []string) error { } // for each circuit, we start by compiling the circuit - // the we do a shashum and compare against the one in the manifest.json + // then we do a sha sum and compare against the one in the manifest.json for c, setup := range inCircuits { if !setup { // we skip aggregation in this first loop since the setup is more complex @@ -139,11 +141,13 @@ func cmdSetup(cmd *cobra.Command, args []string) error { extraFlags["maxUncompressedBytes"] = blob_v1.MaxUncompressedBytes builder = v1.NewBuilder(len(dict)) } + case circuits.PublicInputInterconnectionCircuitID: + builder = pi_interconnection.NewBuilder(cfg.PublicInputInterconnection) case circuits.EmulationDummyCircuitID: // we can get the Verifier.sol from there. builder = dummy.NewBuilder(circuits.MockCircuitIDEmulation, ecc.BN254.ScalarField()) default: - continue // dummy, aggregation or emulation circuits are handled later + continue // dummy, aggregation, emulation or public input circuits are handled later } if err := updateSetup(cmd.Context(), cfg, srsProvider, c, builder, extraFlags); err != nil { @@ -164,6 +168,12 @@ func cmdSetup(cmd *cobra.Command, args []string) error { return nil } + // get verifying key for public-input circuit + piSetup, err := circuits.LoadSetup(cfg, circuits.PublicInputInterconnectionCircuitID) + if err != nil { + return fmt.Errorf("%s failed to load public input interconnection setup: %w", cmd.Name(), err) + } + // first, we need to collect the verifying keys var allowedVkForAggregation []plonk.VerifyingKey for _, allowedInput := range cfg.Aggregation.AllowedInputs { @@ -217,7 +227,7 @@ func cmdSetup(cmd *cobra.Command, args []string) error { c := circuits.CircuitID(fmt.Sprintf("%s-%d", string(circuits.AggregationCircuitID), numProofs)) logrus.Infof("setting up %s (numProofs=%d)", c, numProofs) - builder := aggregation.NewBuilder(numProofs, allowedVkForAggregation) + builder := aggregation.NewBuilder(numProofs, cfg.Aggregation.AllowedInputs, piSetup.VerifyingKey, allowedVkForAggregation) if err := updateSetup(cmd.Context(), cfg, srsProvider, c, builder, extraFlagsForAggregationCircuit); err != nil { return err } diff --git a/prover/config/config.go b/prover/config/config.go index c22f5777..6fbcf7fc 100644 --- a/prover/config/config.go +++ b/prover/config/config.go @@ -100,10 +100,11 @@ type Config struct { // accessed (prover). The file structure is described in TODO @gbotrel. AssetsDir string `mapstructure:"assets_dir" validate:"required,dir"` - Controller Controller - Execution Execution - BlobDecompression BlobDecompression `mapstructure:"blob_decompression"` - Aggregation Aggregation + Controller Controller + Execution Execution + BlobDecompression BlobDecompression `mapstructure:"blob_decompression"` + Aggregation Aggregation + PublicInputInterconnection PublicInput `mapstructure:"public_input_interconnection"` // TODO add wizard compilation params Debug struct { // Profiling indicates whether we want to generate profiles using the [runtime/pprof] pkg. @@ -249,3 +250,13 @@ func (cfg *WithRequestDir) DirTo() string { func (cfg *WithRequestDir) DirDone() string { return path.Join(cfg.RequestsRootDir, RequestsDoneSubDir) } + +type PublicInput struct { + MaxNbDecompression int `mapstructure:"max_nb_decompression" validate:"gte=0"` + MaxNbExecution int `mapstructure:"max_nb_execution" validate:"gte=0"` + MaxNbCircuits int `mapstructure:"max_nb_circuits" validate:"gte=0"` // if not set, will be set to MaxNbDecompression + MaxNbExecution + MaxNbKeccakF int `mapstructure:"max_nb_keccakf" validate:"gte=0"` + ExecutionMaxNbMsg int `mapstructure:"execution_max_nb_msg" validate:"gte=0"` + L2MsgMerkleDepth int `mapstructure:"l2_msg_merkle_depth" validate:"gte=0"` + L2MsgMaxNbMerkle int `mapstructure:"l2_msg_max_nb_merkle" validate:"gte=0"` // if not explicitly provided (i.e. non-positive) it will be set to maximum +} diff --git a/prover/integration/circuit-testing/aggregation/main.go b/prover/integration/circuit-testing/aggregation/main.go index 4f0bcffa..3c1875d2 100644 --- a/prover/integration/circuit-testing/aggregation/main.go +++ b/prover/integration/circuit-testing/aggregation/main.go @@ -45,7 +45,7 @@ func main() { // Building aggregation circuit for max `nc` proofs logrus.Infof("Building aggregation circuit for size of %v\n", nc) - ccs, err := aggregation.MakeCS(nc, vkeys) + ccs, err := aggregation.MakeCS(nc, []string{"blob-decompression-v1", "execution"}, nil, vkeys) // TODO @Tabaie add a PI key if err != nil { panic(err) } @@ -89,7 +89,8 @@ func main() { // Assigning the BW6 circuit logrus.Infof("Generating the aggregation proof for arity %v", nc) - bw6Proof, err := aggregation.MakeProof(&ppBw6, nc, innerProofClaims, frBw6.NewElement(10)) + // TODO @Tabaie add a PI proof + bw6Proof, err := aggregation.MakeProof(&ppBw6, nc, innerProofClaims, nil, frBw6.NewElement(10)) if err != nil { panic(err) } diff --git a/prover/public-input/aggregation.go b/prover/public-input/aggregation.go index d0ad457f..5264cdd6 100644 --- a/prover/public-input/aggregation.go +++ b/prover/public-input/aggregation.go @@ -2,9 +2,15 @@ package public_input import ( bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/rangecheck" + "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak" "github.com/consensys/zkevm-monorepo/prover/utils" + "github.com/consensys/zkevm-monorepo/prover/utils/types" "golang.org/x/crypto/sha3" "hash" + "slices" ) // Aggregation collects all the field that are used to construct the public @@ -79,3 +85,160 @@ func (p Aggregation) Sum(hsh hash.Hash) []byte { func (p Aggregation) GetPublicInputHex() string { return utils.HexEncodeToString(p.Sum(sha3.NewLegacyKeccak256())) } + +// AggregationFPI holds the same info as public_input.Aggregation, except in parsed form +type AggregationFPI struct { + ParentShnarf [32]byte + NbDecompression uint64 + InitialStateRootHash [32]byte + InitialBlockNumber uint64 + InitialBlockTimestamp uint64 + InitialRollingHash [32]byte + InitialRollingHashNumber uint64 + ChainID uint64 // for now we're forcing all executions to have the same chain ID + L2MessageServiceAddr types.EthAddress + NbL2Messages uint64 // TODO not used in hash. delete if not necessary + L2MsgMerkleTreeRoots [][32]byte + //FinalStateRootHash [32]byte redundant: incorporated into shnarf + FinalBlockNumber uint64 + FinalBlockTimestamp uint64 + FinalRollingHash [32]byte + FinalRollingHashNumber uint64 + FinalShnarf [32]byte + L2MsgMerkleTreeDepth int +} + +func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark { + s := AggregationFPISnark{ + AggregationFPIQSnark: AggregationFPIQSnark{ + InitialBlockNumber: pi.InitialBlockNumber, + InitialBlockTimestamp: pi.InitialBlockTimestamp, + InitialRollingHash: [32]frontend.Variable{}, + InitialRollingHashNumber: pi.InitialRollingHashNumber, + InitialStateRootHash: pi.InitialStateRootHash[:], + + NbDecompression: pi.NbDecompression, + ChainID: pi.ChainID, + L2MessageServiceAddr: pi.L2MessageServiceAddr[:], + }, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), + FinalBlockNumber: pi.FinalBlockNumber, + FinalBlockTimestamp: pi.FinalBlockTimestamp, + FinalRollingHashNumber: pi.FinalRollingHashNumber, + L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, + } + + utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:]) + utils.Copy(s.InitialRollingHash[:], pi.InitialRollingHash[:]) + utils.Copy(s.ParentShnarf[:], pi.ParentShnarf[:]) + utils.Copy(s.FinalShnarf[:], pi.FinalShnarf[:]) + + for i := range s.L2MsgMerkleTreeRoots { + utils.Copy(s.L2MsgMerkleTreeRoots[i][:], pi.L2MsgMerkleTreeRoots[i][:]) + } + + return s +} + +type AggregationFPIQSnark struct { + ParentShnarf [32]frontend.Variable + NbDecompression frontend.Variable + InitialStateRootHash frontend.Variable + InitialBlockNumber frontend.Variable + InitialBlockTimestamp frontend.Variable + InitialRollingHash [32]frontend.Variable + InitialRollingHashNumber frontend.Variable + ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID + L2MessageServiceAddr frontend.Variable // 20 bytes +} + +type AggregationFPISnark struct { + AggregationFPIQSnark + NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary + L2MsgMerkleTreeRoots [][32]frontend.Variable + // FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf + FinalBlockNumber frontend.Variable + FinalBlockTimestamp frontend.Variable + FinalRollingHash [32]frontend.Variable + FinalRollingHashNumber frontend.Variable + FinalShnarf [32]frontend.Variable + L2MsgMerkleTreeDepth int +} + +// NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation +func NewAggregationFPI(fpi *Aggregation) (s *AggregationFPI, err error) { + s = &AggregationFPI{ + InitialBlockNumber: uint64(fpi.LastFinalizedBlockNumber), + InitialBlockTimestamp: uint64(fpi.ParentAggregationLastBlockTimestamp), + InitialRollingHashNumber: uint64(fpi.LastFinalizedL1RollingHashMessageNumber), + L2MsgMerkleTreeRoots: make([][32]byte, len(fpi.L2MsgRootHashes)), + FinalBlockNumber: uint64(fpi.FinalBlockNumber), + FinalBlockTimestamp: uint64(fpi.FinalTimestamp), + FinalRollingHashNumber: uint64(fpi.L1RollingHashMessageNumber), + L2MsgMerkleTreeDepth: fpi.L2MsgMerkleTreeDepth, + } + + if err = copyFromHex(s.InitialStateRootHash[:], fpi.ParentStateRootHash); err != nil { + return + } + if err = copyFromHex(s.FinalRollingHash[:], fpi.L1RollingHash); err != nil { + return + } + if err = copyFromHex(s.InitialRollingHash[:], fpi.LastFinalizedL1RollingHash); err != nil { + return + } + if err = copyFromHex(s.ParentShnarf[:], fpi.ParentAggregationFinalShnarf); err != nil { + return + } + if err = copyFromHex(s.FinalShnarf[:], fpi.FinalShnarf); err != nil { + return + } + + for i := range s.L2MsgMerkleTreeRoots { + if err = copyFromHex(s.L2MsgMerkleTreeRoots[i][:], fpi.L2MsgRootHashes[i]); err != nil { + return + } + } + return +} + +func (pi *AggregationFPISnark) Sum(api frontend.API, hash keccak.BlockHasher) [32]frontend.Variable { + // number of hashes: 12 + sum := hash.Sum(nil, + pi.ParentShnarf, + pi.FinalShnarf, + utils.ToBytes(api, pi.InitialBlockTimestamp), + utils.ToBytes(api, pi.FinalBlockTimestamp), + utils.ToBytes(api, pi.InitialBlockNumber), + utils.ToBytes(api, pi.FinalBlockNumber), + pi.InitialRollingHash, + pi.FinalRollingHash, + utils.ToBytes(api, pi.InitialRollingHashNumber), + utils.ToBytes(api, pi.FinalRollingHashNumber), + utils.ToBytes(api, pi.L2MsgMerkleTreeDepth), + hash.Sum(nil, pi.L2MsgMerkleTreeRoots...), + ) + + // turn the hash into a bn254 element + var res [32]frontend.Variable + copy(res[:], utils.ReduceBytes[emulated.BN254Fr](api, sum[:])) + return res +} + +func (pi *AggregationFPIQSnark) RangeCheck(api frontend.API) { + rc := rangecheck.New(api) + for _, v := range append(slices.Clone(pi.InitialRollingHash[:]), pi.ParentShnarf[:]...) { + rc.Check(v, 8) + } + // not checking L2MsgServiceAddr as its range is never assumed in the pi circuit + // not checking NbDecompressions as the NewRange in the pi circuit range checks it; TODO do it here instead +} + +func copyFromHex(dst []byte, src string) error { + b, err := utils.HexDecodeString(src) + if err != nil { + return err + } + copy(dst[len(dst)-len(b):], b) // panics if src is too long + return nil +} diff --git a/prover/public-input/execution.go b/prover/public-input/execution.go new file mode 100644 index 00000000..c970b31c --- /dev/null +++ b/prover/public-input/execution.go @@ -0,0 +1,61 @@ +package public_input + +import ( + "errors" + "github.com/consensys/zkevm-monorepo/prover/utils" +) + +type Execution struct { + L2MsgHashes [][32]byte + FinalStateRootHash [32]byte + FinalBlockNumber uint64 + FinalBlockTimestamp uint64 + FinalRollingHash [32]byte + FinalRollingHashNumber uint64 +} + +type ExecutionSerializable struct { + L2MsgHashes []string `json:"l2MsgHashes"` + FinalStateRootHash string `json:"finalStateRootHash"` + FinalBlockNumber uint64 `json:"finalBlockNumber"` + FinalBlockTimestamp uint64 `json:"finalBlockTimestamp"` + FinalRollingHash string `json:"finalRollingHash"` + FinalRollingHashNumber uint64 `json:"finalRollingHashNumber"` +} + +func (e ExecutionSerializable) Decode() (decoded Execution, err error) { + decoded = Execution{ + L2MsgHashes: make([][32]byte, len(e.L2MsgHashes)), + FinalBlockNumber: e.FinalBlockNumber, + FinalBlockTimestamp: e.FinalBlockTimestamp, + FinalRollingHashNumber: e.FinalRollingHashNumber, + } + + fillWithHex := func(dst []byte, src string) { + var d []byte + if d, err = utils.HexDecodeString(src); err != nil { + return + } + if len(d) > len(dst) { + err = errors.New("decoded bytes too long") + } + n := len(dst) - len(d) + copy(dst[n:], d) + for n > 0 { + n-- + dst[n] = 0 + } + } + + for i, hex := range e.L2MsgHashes { + if fillWithHex(decoded.L2MsgHashes[i][:], hex); err != nil { + return + } + } + if fillWithHex(decoded.FinalStateRootHash[:], e.FinalStateRootHash); err != nil { + return + } + + fillWithHex(decoded.FinalRollingHash[:], e.FinalRollingHash) + return +} diff --git a/prover/utils/snark.go b/prover/utils/snark.go new file mode 100644 index 00000000..72f39bfb --- /dev/null +++ b/prover/utils/snark.go @@ -0,0 +1,120 @@ +package utils + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/rangecheck" + "math/big" +) + +func Copy[T any](dst []frontend.Variable, src []T) (n int) { + n = min(len(dst), len(src)) + + for i := 0; i < n; i++ { + dst[i] = src[i] + } + + return +} + +func ToBytes(api frontend.API, x frontend.Variable) [32]frontend.Variable { + var res [32]frontend.Variable + d := decomposeIntoBytes(api, x, fr.Bits) + slack := 32 - len(d) // should be zero + copy(res[slack:], d) + for i := 0; i < slack; i++ { + res[i] = 0 + } + return res +} + +func decomposeIntoBytes(api frontend.API, data frontend.Variable, nbBits int) []frontend.Variable { + if nbBits == 0 { + nbBits = api.Compiler().FieldBitLen() + } + + nbBytes := (api.Compiler().FieldBitLen() + 7) / 8 + + bytes, err := api.Compiler().NewHint(decomposeIntoBytesHint, nbBytes, data) + if err != nil { + panic(err) + } + lastNbBits := nbBits % 8 + if lastNbBits == 0 { + lastNbBits = 8 + } + rc := rangecheck.New(api) + api.AssertIsLessOrEqual(bytes[0], 1<= 0; j-- { + outs[i*nbBytes+j].Mod(&v, &radix) + v.Rsh(&v, 8) + } + if v.Cmp(&zero) != 0 { + return errors.New("not fitting in len(outs)/len(ins) many bytes") + } + } + return nil +} + +func RegisterHints() { + solver.RegisterHint(decomposeIntoBytesHint) +} + +// ReduceBytes reduces given bytes modulo a given field. As a side effect, the "bytes" are range checked +func ReduceBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) []frontend.Variable { + f, err := emulated.NewField[T](api) + if err != nil { + panic(err) + } + + bits := f.ToBits(NewElementFromBytes[T](api, bytes)) + res := make([]frontend.Variable, (len(bits)+7)/8) + copy(bits[:], bits) + for i := len(bits); i < len(bits); i++ { + bits[i] = 0 + } + for i := range res { + bitsStart := 8 * (len(res) - i - 1) + bitsEnd := bitsStart + 8 + if i == 0 { + bitsEnd = len(bits) + } + res[i] = api.FromBinary(bits[bitsStart:bitsEnd]...) + } + + return res +} + +// NewElementFromBytes range checks the bytes and gives a reduced field element +func NewElementFromBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) *emulated.Element[T] { + bits := make([]frontend.Variable, 8*len(bytes)) + for i := range bytes { + copy(bits[8*i:], api.ToBinary(bytes[len(bytes)-i-1], 8)) + } + + f, err := emulated.NewField[T](api) + if err != nil { + panic(err) + } + + return f.Reduce(f.Add(f.FromBits(bits...), f.Zero())) +}