Prover: implements the aggregation circuit for beta v1 (#3780)

---------

Signed-off-by: Arya Tabaie <arya.pourtabatabaie@gmail.com>
Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
Co-authored-by: AlexandreBelling <alexandrebelling8@gmail.com>
This commit is contained in:
Arya Tabaie
2024-08-14 07:24:08 -05:00
committed by GitHub
parent 78c84da689
commit c0568080a1
35 changed files with 1008 additions and 470 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"github.com/consensys/zkevm-monorepo/prover/backend/blobsubmission"
public_input "github.com/consensys/zkevm-monorepo/prover/public-input"
"path"
@@ -44,6 +45,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
}
)
cf.ExecutionPI = make([]public_input.Execution, 0, len(req.ExecutionProofs))
for i, execReqFPath := range req.ExecutionProofs {
po := &execution.Response{}
fpath := path.Join(cfg.Execution.DirTo(), execReqFPath)
@@ -66,8 +69,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
utils.Panic("conflated batch %v reports a parent state hash mismatch, but this is not the first batch of the sequence", i)
}
// This is purposefuly overwritten at each iteration over i. We want to
// keep the final velue.
// This is purposefully overwritten at each iteration over i. We want to
// keep the final value.
cf.FinalBlockNumber = uint(po.FirstBlockNumber + len(po.BlocksData) - 1)
for _, blockdata := range po.BlocksData {
@@ -86,14 +89,30 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
// Append the proof claim to the list of collected proofs
if !cf.IsProoflessJob {
pClaim, err := parseProofClaim(po.Proof, po.DebugData.FinalHash, po.VerifyingKeyShaSum)
pClaim, err := parseProofClaim(po.Proof, po.DebugData.FinalHash, po.VerifyingKeyShaSum) // @gbotrel Is finalHash the state hash? If so why is it given as the execution circuit's PI?
if err != nil {
return nil, fmt.Errorf("could not parse the proof claim for `%v` : %w", fpath, err)
}
cf.ProofClaims = append(cf.ProofClaims, *pClaim)
// TODO make sure this belongs in the if
finalBlock := &po.BlocksData[len(po.BlocksData)-1]
piq, err := public_input.ExecutionSerializable{
L2MsgHashes: l2MessageHashes,
FinalStateRootHash: po.DebugData.FinalHash, // TODO @tabaie make sure this is the right value
FinalBlockNumber: uint64(cf.FinalBlockNumber),
FinalBlockTimestamp: finalBlock.TimeStamp,
FinalRollingHash: cf.L1RollingHash,
FinalRollingHashNumber: uint64(cf.L1RollingHashMessageNumber),
}.Decode()
if err != nil {
return nil, err
}
cf.ExecutionPI = append(cf.ExecutionPI, piq)
}
}
cf.DecompressionPI = make([]blobsubmission.Response, 0, len(req.CompressionProofs))
for i, decompReqFPath := range req.CompressionProofs {
dp := &blobdecompression.Response{}
fpath := path.Join(cfg.BlobDecompression.DirTo(), decompReqFPath)
@@ -120,6 +139,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
return nil, fmt.Errorf("could not parse the proof claim for `%v` : %w", fpath, err)
}
cf.ProofClaims = append(cf.ProofClaims, *pClaim)
cf.DecompressionPI = append(cf.DecompressionPI, dp.Request)
}
}

View File

@@ -2,6 +2,8 @@ package aggregation
import (
"fmt"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
public_input "github.com/consensys/zkevm-monorepo/prover/public-input"
"math"
"path/filepath"
@@ -43,7 +45,12 @@ func makeProof(
return makeDummyProof(cfg, publicInput, circuits.MockCircuitIDEmulation), nil
}
proofBW6, circuitID, err := makeBw6Proof(cfg, cf, publicInput)
piProof, err := makePiProof(cfg, cf)
if err != nil {
return "", fmt.Errorf("could not create the public input proof: %w", err)
}
proofBW6, circuitID, err := makeBw6Proof(cfg, cf, piProof, publicInput)
if err != nil {
return "", fmt.Errorf("error when running the BW6 proof: %w", err)
}
@@ -56,6 +63,43 @@ func makeProof(
return circuits.SerializeProofSolidityBn254(proofBn254), nil
}
func makePiProof(cfg *config.Config, cf *CollectedFields) (plonk.Proof, error) {
c, err := pi_interconnection.Compile(cfg.PublicInputInterconnection, pi_interconnection.WizardCompilationParameters()...)
if err != nil {
return nil, fmt.Errorf("could not create the public-input circuit: %w", err)
}
assignment, err := c.Assign(pi_interconnection.Request{
Decompressions: cf.DecompressionPI,
Executions: cf.ExecutionPI,
Aggregation: public_input.Aggregation{
FinalShnarf: cf.FinalShnarf,
ParentAggregationFinalShnarf: cf.ParentAggregationFinalShnarf,
ParentStateRootHash: cf.ParentStateRootHash,
ParentAggregationLastBlockTimestamp: cf.ParentAggregationLastBlockTimestamp,
FinalTimestamp: cf.FinalTimestamp,
LastFinalizedBlockNumber: cf.LastFinalizedBlockNumber,
FinalBlockNumber: cf.FinalBlockNumber,
LastFinalizedL1RollingHash: cf.LastFinalizedL1RollingHash,
L1RollingHash: cf.L1RollingHash,
LastFinalizedL1RollingHashMessageNumber: cf.LastFinalizedL1RollingHashMessageNumber,
L1RollingHashMessageNumber: cf.L1RollingHashMessageNumber,
L2MsgRootHashes: cf.L2MsgRootHashes,
L2MsgMerkleTreeDepth: int(cf.L2MsgTreeDepth),
},
})
if err != nil {
return nil, fmt.Errorf("could not assign the public input circuit: %w", err)
}
setup, err := circuits.LoadSetup(cfg, circuits.PublicInputInterconnectionCircuitID)
if err != nil {
return nil, fmt.Errorf("could not load the setup: %w", err)
}
return circuits.ProveCheck(&setup, &assignment)
}
// Generates a fake proof. The public input is given in hex string format.
// Returns the proof in hex string format. The circuit ID parameter specifies
// for which circuit should the proof be generated.
@@ -86,6 +130,7 @@ func makeDummyProof(cfg *config.Config, input string, circID circuits.MockCircui
func makeBw6Proof(
cfg *config.Config,
cf *CollectedFields,
piProof plonk.Proof,
publicInput string,
) (proof plonk.Proof, circuitID int, err error) {
@@ -156,7 +201,7 @@ func makeBw6Proof(
assignCircuitIDToProofClaims(bestAllowedVkForAggregation, cf.ProofClaims)
// Although the public input is restrained to fit on the BN254 scalar field,
// the BW6 field is larger. This allow us representing the public input as a
// the BW6 field is larger. This allows us to represent the public input as a
// single field element.
var piBW6 frBW6.Element
@@ -166,7 +211,7 @@ func makeBw6Proof(
}
logrus.Infof("running the BW6 prover")
proofBW6, err := aggregation.MakeProof(&setup, bestSize, cf.ProofClaims, piBW6)
proofBW6, err := aggregation.MakeProof(&setup, bestSize, cf.ProofClaims, piProof, piBW6)
if err != nil {
return nil, 0, fmt.Errorf("could not create BW6 proof: %w", err)
}

View File

@@ -1,7 +1,9 @@
package aggregation
import (
"github.com/consensys/zkevm-monorepo/prover/backend/blobdecompression"
"github.com/consensys/zkevm-monorepo/prover/circuits/aggregation"
public_input "github.com/consensys/zkevm-monorepo/prover/public-input"
)
// Request collects all the fields used to perform an aggregation request.
@@ -96,4 +98,7 @@ type CollectedFields struct {
// The proof claims for the execution prover
ProofClaims []aggregation.ProofClaimAssignment
ExecutionPI []public_input.Execution
DecompressionPI []blobdecompression.Request
}

View File

@@ -181,7 +181,7 @@ func Prove(cfg *config.Config, req *Request) (*Response, error) {
Request: *req,
ProverVersion: cfg.Version,
DecompressionProof: circuits.SerializeProofRaw(proof),
VerifyingKeyShaSum: setup.VerifiyingKeyDigest(),
VerifyingKeyShaSum: setup.VerifyingKeyDigest(),
}
resp.Debug.PublicInput = "0x" + pubInput.Text(16)
@@ -227,7 +227,7 @@ func dummyProve(cfg *config.Config, req *Request) (*Response, error) {
Request: *req,
ProverVersion: cfg.Version,
DecompressionProof: proof,
VerifyingKeyShaSum: setup.VerifiyingKeyDigest(),
VerifyingKeyShaSum: setup.VerifyingKeyDigest(),
}
inputString := utils.HexEncodeToString(input)

View File

@@ -113,7 +113,7 @@ func mustProveAndPass(
utils.Panic(err.Error())
}
return dummy.MakeProof(&setup, publicInput, circuits.MockCircuitIDExecution), setup.VerifiyingKeyDigest()
return dummy.MakeProof(&setup, publicInput, circuits.MockCircuitIDExecution), setup.VerifyingKeyDigest()
case config.ProverModeFull:
logrus.Info("Running the FULL prover")
@@ -157,7 +157,7 @@ func mustProveAndPass(
}
// TODO: implements the collection of the functional inputs from the prover response
return execution.MakeProof(setup, fullZkEvm.WizardIOP, proof, execution.FunctionalPublicInput{}, publicInput), setup.VerifiyingKeyDigest()
return execution.MakeProof(setup, fullZkEvm.WizardIOP, proof, execution.FunctionalPublicInput{}, publicInput), setup.VerifyingKeyDigest()
default:
panic("not implemented")
}

View File

@@ -11,33 +11,43 @@ import (
)
type builder struct {
maxNbProofs int
vKeys []plonk.VerifyingKey
maxNbProofs int
vKeys []plonk.VerifyingKey
allowedInputs []string
piVKey plonk.VerifyingKey
}
func NewBuilder(
maxNbProofs int,
allowedInputs []string,
piVKey plonk.VerifyingKey,
vKeys []plonk.VerifyingKey,
) *builder {
return &builder{
maxNbProofs: maxNbProofs,
vKeys: vKeys,
piVKey: piVKey,
allowedInputs: allowedInputs,
maxNbProofs: maxNbProofs,
vKeys: vKeys,
}
}
func (b *builder) Compile() (constraint.ConstraintSystem, error) {
return MakeCS(b.maxNbProofs, b.vKeys)
return MakeCS(b.maxNbProofs, b.allowedInputs, b.piVKey, b.vKeys)
}
// Initializes the bw6 aggregation circuit and returns a compiled constraint
// system.
func MakeCS(
maxNbProofs int,
allowedInputs []string,
piVKey plonk.VerifyingKey,
vKeys []plonk.VerifyingKey,
) (constraint.ConstraintSystem, error) {
aggCircuit, err := AllocateAggregationCircuit(
aggCircuit, err := AllocateCircuit(
maxNbProofs,
allowedInputs,
piVKey,
vKeys,
)

View File

@@ -3,18 +3,15 @@
package aggregation
import (
"errors"
"fmt"
"github.com/consensys/gnark/backend/plonk"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/native/sw_bls12377"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
emPlonk "github.com/consensys/gnark/std/recursion/plonk"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak"
public_input "github.com/consensys/zkevm-monorepo/prover/public-input"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/consensys/zkevm-monorepo/prover/utils/types"
"slices"
)
@@ -32,9 +29,9 @@ type (
// emBaseVkey = emPlonk.BaseVerifyingKey[emFr, emG1, emG2]
)
// The AggregationCircuit is used to aggregate multiple execution proofs and
// The Circuit is used to aggregate multiple execution proofs and
// aggregation proofs together.
type AggregationCircuit struct {
type Circuit struct {
// The list of claims to be provided to the circuit.
ProofClaims []proofClaim `gnark:",secret"`
@@ -42,29 +39,76 @@ type AggregationCircuit struct {
// is treated as a constant by the circuit.
verifyingKeys []emVkey `gnark:"-"`
// Dummy general public input
DummyPublicInput frontend.Variable `gnark:",public"`
publicInputVerifyingKey emVkey `gnark:"-"`
PublicInputProof emProof `gnark:",secret"`
PublicInputWitness emWitness `gnark:",secret"` // ordered for the PI circuit
PublicInputWitnessClaimIndexes []frontend.Variable `gnark:",secret"`
// general public input
PublicInput frontend.Variable `gnark:",public"`
}
func (c *AggregationCircuit) Define(api frontend.API) error {
// Verify the constraints the execution proofs
err := verifyClaimBatch(api, c.verifyingKeys, c.ProofClaims)
func (c *Circuit) Define(api frontend.API) error {
// match the public input with the public input of the PI circuit
field, err := emulated.NewField[emFr](api)
if err != nil {
return fmt.Errorf("processing execution proofs: %w", err)
return err
}
internal.AssertSliceEquals(api, api.ToBinary(c.PublicInput, emFr{}.Modulus().BitLen()), field.ToBitsCanonical(&c.PublicInputWitness.Public[0]))
vks := append(slices.Clone(c.verifyingKeys), c.publicInputVerifyingKey)
piVkIndex := len(vks) - 1
for i := range c.ProofClaims {
api.AssertIsLessOrEqual(c.ProofClaims[i].CircuitID, len(c.verifyingKeys)-1) // make sure the prover can't sneak in an extra PI circuit
}
// TODO incorporate statements (circuitID+publicInput?) and vk's into public input
// create a lookup table of actual public inputs
actualPI := make([]*logderivlookup.Table, (emFr{}).NbLimbs())
for i := range actualPI {
actualPI[i] = logderivlookup.New(api)
}
for _, claim := range c.ProofClaims {
if len(claim.PublicInput.Public) != 1 {
return errors.New("expected 1 public input per decompression/execution circuit")
}
pi := claim.PublicInput.Public[0]
for i := range actualPI {
actualPI[i].Insert(pi.Limbs[i])
}
}
if len(c.PublicInputWitnessClaimIndexes)+1 != len(c.PublicInputWitness.Public) {
return errors.New("expected the number of public inputs to match the number of public input witness claim indexes")
}
// verify that every valid input to the PI circuit is accounted for
for i, actualI := range c.PublicInputWitnessClaimIndexes {
hubPI := &c.PublicInputWitness.Public[i+1]
isNonZero := api.Sub(1, field.IsZero(hubPI)) // if a PI is zero, due to preimage resistance we can infer that the PI circuit is not using it
for j := range actualPI {
internal.AssertEqualIf(api, isNonZero, actualPI[j].Lookup(actualI)[0], hubPI.Limbs[j])
}
}
claims := append(slices.Clone(c.ProofClaims), proofClaim{
CircuitID: piVkIndex,
Proof: c.PublicInputProof,
PublicInput: c.PublicInputWitness,
})
// Verify the constraints the execution proofs
if err = verifyClaimBatch(api, vks, claims); err != nil {
return fmt.Errorf("processing execution proofs: %w", err)
}
return err
}
// Instantiate a new AggregationCircuit from a list of verification keys and
// Instantiate a new Circuit from a list of verification keys and
// a maximal number of proofs. The function should only be called with the
// purpose of running `frontend.Compile` over it.
func AllocateAggregationCircuit(
nbProofs int,
verifyingKeys []plonk.VerifyingKey,
) (*AggregationCircuit, error) {
func AllocateCircuit(nbProofs int, allowedInputs []string, key plonk.VerifyingKey, verifyingKeys []plonk.VerifyingKey) (*Circuit, error) {
var (
err error
@@ -84,7 +128,7 @@ func AllocateAggregationCircuit(
proofClaims[i] = allocatableClaimPlaceHolder(csPlaceHolder)
}
return &AggregationCircuit{
return &Circuit{
verifyingKeys: emVKeys,
ProofClaims: proofClaims,
}, nil
@@ -121,160 +165,3 @@ func verifyClaimBatch(api frontend.API, vks []emVkey, claims []proofClaim) error
}
return nil
}
type FunctionalPublicInputQSnark struct {
ParentShnarf [32]frontend.Variable
NbDecompression frontend.Variable
InitialStateRootHash frontend.Variable
InitialBlockNumber frontend.Variable
InitialBlockTimestamp frontend.Variable
InitialRollingHash [32]frontend.Variable
InitialRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
}
type FunctionalPublicInputSnark struct {
FunctionalPublicInputQSnark
NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]frontend.Variable
// FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
}
// FunctionalPublicInput holds the same info as public_input.Aggregation, except in parsed form
type FunctionalPublicInput struct {
ParentShnarf [32]byte
NbDecompression uint64
InitialStateRootHash [32]byte
InitialBlockNumber uint64
InitialBlockTimestamp uint64
InitialRollingHash [32]byte
InitialRollingHashNumber uint64
ChainID uint64 // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr types.EthAddress
NbL2Messages uint64 // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]byte
//FinalStateRootHash [32]byte redundant: incorporated into shnarf
FinalBlockNumber uint64
FinalBlockTimestamp uint64
FinalRollingHash [32]byte
FinalRollingHashNumber uint64
FinalShnarf [32]byte
L2MsgMerkleTreeDepth int
}
// NewFunctionalPublicInput does NOT set all fields, only the ones covered in public_input.Aggregation
func NewFunctionalPublicInput(fpi *public_input.Aggregation) (s *FunctionalPublicInput, err error) {
s = &FunctionalPublicInput{
InitialBlockNumber: uint64(fpi.LastFinalizedBlockNumber),
InitialBlockTimestamp: uint64(fpi.ParentAggregationLastBlockTimestamp),
InitialRollingHashNumber: uint64(fpi.LastFinalizedL1RollingHashMessageNumber),
L2MsgMerkleTreeRoots: make([][32]byte, len(fpi.L2MsgRootHashes)),
FinalBlockNumber: uint64(fpi.FinalBlockNumber),
FinalBlockTimestamp: uint64(fpi.FinalTimestamp),
FinalRollingHashNumber: uint64(fpi.L1RollingHashMessageNumber),
L2MsgMerkleTreeDepth: fpi.L2MsgMerkleTreeDepth,
}
if err = copyFromHex(s.InitialStateRootHash[:], fpi.ParentStateRootHash); err != nil {
return
}
if err = copyFromHex(s.FinalRollingHash[:], fpi.L1RollingHash); err != nil {
return
}
if err = copyFromHex(s.InitialRollingHash[:], fpi.LastFinalizedL1RollingHash); err != nil {
return
}
if err = copyFromHex(s.ParentShnarf[:], fpi.ParentAggregationFinalShnarf); err != nil {
return
}
if err = copyFromHex(s.FinalShnarf[:], fpi.FinalShnarf); err != nil {
return
}
for i := range s.L2MsgMerkleTreeRoots {
if err = copyFromHex(s.L2MsgMerkleTreeRoots[i][:], fpi.L2MsgRootHashes[i]); err != nil {
return
}
}
return
}
func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
s := FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{
InitialBlockNumber: pi.InitialBlockNumber,
InitialBlockTimestamp: pi.InitialBlockTimestamp,
InitialRollingHash: [32]frontend.Variable{},
InitialRollingHashNumber: pi.InitialRollingHashNumber,
InitialStateRootHash: pi.InitialStateRootHash[:],
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
},
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
FinalRollingHashNumber: pi.FinalRollingHashNumber,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
}
internal.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:])
internal.Copy(s.InitialRollingHash[:], pi.InitialRollingHash[:])
internal.Copy(s.ParentShnarf[:], pi.ParentShnarf[:])
internal.Copy(s.FinalShnarf[:], pi.FinalShnarf[:])
for i := range s.L2MsgMerkleTreeRoots {
internal.Copy(s.L2MsgMerkleTreeRoots[i][:], pi.L2MsgMerkleTreeRoots[i][:])
}
return s
}
func (pi *FunctionalPublicInputSnark) Sum(api frontend.API, hash keccak.BlockHasher) [32]frontend.Variable {
// number of hashes: 12
sum := hash.Sum(nil,
pi.ParentShnarf,
pi.FinalShnarf,
internal.ToBytes(api, pi.InitialBlockTimestamp),
internal.ToBytes(api, pi.FinalBlockTimestamp),
internal.ToBytes(api, pi.InitialBlockNumber),
internal.ToBytes(api, pi.FinalBlockNumber),
pi.InitialRollingHash,
pi.FinalRollingHash,
internal.ToBytes(api, pi.InitialRollingHashNumber),
internal.ToBytes(api, pi.FinalRollingHashNumber),
internal.ToBytes(api, pi.L2MsgMerkleTreeDepth),
hash.Sum(nil, pi.L2MsgMerkleTreeRoots...),
)
// turn the hash into a bn254 element
var res [32]frontend.Variable
copy(res[:], internal.ReduceBytes[emulated.BN254Fr](api, sum[:]))
return res
}
func (pi *FunctionalPublicInputQSnark) RangeCheck(api frontend.API) {
rc := rangecheck.New(api)
for _, v := range append(slices.Clone(pi.InitialRollingHash[:]), pi.ParentShnarf[:]...) {
rc.Check(v, 8)
}
// not checking L2MsgServiceAddr as its range is never assumed in the pi circuit
// not checking NbDecompressions as the NewRange in the pi circuit range checks it; TODO do it here instead
}
func copyFromHex(dst []byte, src string) error {
b, err := utils.HexDecodeString(src)
if err != nil {
return err
}
copy(dst[len(dst)-len(b):], b) // panics if src is too long
return nil
}

View File

@@ -32,7 +32,7 @@ func TestPublicInput(t *testing.T) {
for i := range testCases {
fpi, err := NewFunctionalPublicInput(&testCases[i])
fpi, err := public_input.NewAggregationFPI(&testCases[i])
assert.NoError(t, err)
sfpi := fpi.ToSnarkType()

View File

@@ -2,7 +2,6 @@ package aggregation
import (
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
"github.com/consensys/gnark/backend/plonk"
@@ -17,6 +16,7 @@ func MakeProof(
setup *circuits.Setup,
maxNbProof int,
proofClaims []ProofClaimAssignment,
piProof plonk.Proof,
publicInput fr.Element,
) (
plonk.Proof,
@@ -27,6 +27,7 @@ func MakeProof(
assignment, err := AssignAggregationCircuit(
maxNbProof,
proofClaims,
piProof,
publicInput,
)
@@ -44,18 +45,11 @@ func MakeProof(
}
// Assigns the proof using placeholders
func AssignAggregationCircuit(
maxNbProof int,
proofClaims []ProofClaimAssignment,
publicInput fr.Element,
) (
c *AggregationCircuit,
err error,
) {
func AssignAggregationCircuit(maxNbProof int, proofClaims []ProofClaimAssignment, publicInputProof plonk.Proof, publicInput fr.Element) (c *Circuit, err error) {
c = &AggregationCircuit{
ProofClaims: make([]proofClaim, maxNbProof),
DummyPublicInput: publicInput,
c = &Circuit{
ProofClaims: make([]proofClaim, maxNbProof),
PublicInput: publicInput,
}
for i := range c.ProofClaims {

View File

@@ -6,7 +6,7 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/compress"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/utils"
"math/big"
"math/bits"
)
@@ -87,7 +87,7 @@ func VerifyBlobConsistency(api frontend.API, blobCrumbs []frontend.Variable, eva
}
blobEmulated := packCrumbsEmulated(api, blobCrumbs) // perf-TODO use the original blob bytes
evaluationChallengeEmulated := internal.NewElementFromBytes[emulated.BLS12381Fr](api, evaluationChallenge[:])
evaluationChallengeEmulated := utils.NewElementFromBytes[emulated.BLS12381Fr](api, evaluationChallenge[:])
blobEmulatedBitReversed := make([]*emulated.Element[emulated.BLS12381Fr], len(blobEmulated))
copy(blobEmulatedBitReversed, blobEmulated)

View File

@@ -9,6 +9,7 @@ import (
"encoding/json"
"fmt"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/utils"
"math/big"
"os"
"path/filepath"
@@ -89,11 +90,11 @@ func TestInterpolateLagrange(t *testing.T) {
scalars, err := internal.Bls12381ScalarToBls12377Scalars(evaluationPointFr)
assert.NoError(t, err)
internal.Copy(assignment.EvaluationPoint[:], scalars[:])
utils.Copy(assignment.EvaluationPoint[:], scalars[:])
scalars, err = internal.Bls12381ScalarToBls12377Scalars(evaluation)
assert.NoError(t, err)
internal.Copy(assignment.Evaluation[:], scalars[:])
utils.Copy(assignment.Evaluation[:], scalars[:])
return &assignment
}
@@ -198,7 +199,7 @@ func decodeHexHL(t *testing.T, s string) (r [2]frontend.Variable) {
scalars, err := internal.Bls12381ScalarToBls12377Scalars(b)
assert.NoError(t, err)
internal.Copy(r[:], scalars[:])
utils.Copy(r[:], scalars[:])
return
}
@@ -238,7 +239,7 @@ func TestVerifyBlobConsistencyIntegration(t *testing.T) {
assignment.Eip4844Enabled = 1
}
internal.Copy(assignment.X[:], decodeHex(t, testCase.ExpectedX))
utils.Copy(assignment.X[:], decodeHex(t, testCase.ExpectedX))
assignment.Y = decodeHexHL(t, testCase.ExpectedY)
t.Run(folderAndFile, func(t *testing.T) {
@@ -375,7 +376,7 @@ func TestFrConversions(t *testing.T) {
assert.Equal(t, tmp, xBack, fmt.Sprintf("out-of-snark conversion round-trip failed on %s or 0x%s", tmp.Text(10), tmp.Text(16)))
var assignment testFrConversionCircuit
internal.Copy(assignment.X[:], xPartitioned[:])
utils.Copy(assignment.X[:], xPartitioned[:])
options = append(options, test.WithValidAssignment(&assignment))
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/utils"
fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
@@ -137,7 +138,7 @@ func (c *Circuit) Define(api frontend.API) error {
}
}
xBytes := internal.ToBytes(api, c.X[1])
xBytes := utils.ToBytes(api, c.X[1])
rc := rangecheck.New(api)
const nbBitsLower = (fr377.Bits - 1) % 8
rc.Check(xBytes[0], nbBitsLower)

View File

@@ -101,7 +101,8 @@ func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Elem
// in the Assign function although this creates an unintuitive side effect.
// It should be harmless though.
lzss.RegisterHints()
internal.RegisterHints()
//internal.RegisterHints()
utils.RegisterHints()
a := Circuit{
BlobBytesLen: len(blobDataUnpacked),

View File

@@ -18,6 +18,7 @@ import (
test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc"
"github.com/consensys/zkevm-monorepo/prover/utils"
"math/big"
blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1"
@@ -111,11 +112,11 @@ func (i *FunctionalPublicInput) ToSnarkType() (FunctionalPublicInputSnark, error
NbBatches: len(i.BatchSums),
},
}
internal.Copy(res.X[:], i.X[:])
utils.Copy(res.X[:], i.X[:])
if len(i.BatchSums) > len(res.BatchSums) {
return res, errors.New("batches do not fit in circuit")
}
for n := internal.Copy(res.BatchSums[:], i.BatchSums); n < len(res.BatchSums); n++ {
for n := utils.Copy(res.BatchSums[:], i.BatchSums); n < len(res.BatchSums); n++ {
res.BatchSums[n] = 0
}
return res, nil

View File

@@ -7,6 +7,7 @@ import (
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/crypto/mimc"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/consensys/zkevm-monorepo/prover/utils/types"
"hash"
"slices"
@@ -15,11 +16,11 @@ import (
// FunctionalPublicInputQSnark the information on this execution that cannot be extracted from other input in the same aggregation batch
type FunctionalPublicInputQSnark struct {
DataChecksum frontend.Variable
L2MessageHashes internal.Var32Slice // TODO range check
L2MessageHashes internal.Var32Slice
FinalStateRootHash frontend.Variable
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalRollingHash [32]frontend.Variable // TODO range check
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
}
@@ -103,15 +104,14 @@ func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
ChainID: pi.ChainID,
L2MessageServiceAddr: slices.Clone(pi.L2MessageServiceAddr[:]),
}
internal.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:])
internal.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:])
utils.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:])
utils.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:])
return res
}
func (pi *FunctionalPublicInput) Sum() []byte { // all mimc; no need to provide a keccak hasher
hsh := mimc.NewMiMC()
// TODO incorporate length too? Technically not necessary
var zero [1]byte
for i := range pi.L2MessageHashes {

View File

@@ -17,7 +17,7 @@ func CopyHexEncodedBytes(dst []frontend.Variable, hex string) error {
dst[i] = 0
}
Copy(dst[slack:], b) // This will panic if b is too long
utils.Copy(dst[slack:], b) // This will panic if b is too long
return nil
}

View File

@@ -180,3 +180,20 @@ func (FakeTestingT) Errorf(format string, args ...interface{}) {
func (FakeTestingT) FailNow() {
os.Exit(-1)
}
func RandIntN(n int) int {
var b [8]byte
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
}
func RandIntSliceN(length, n int) []int {
res := make([]int, length)
for i := range res {
res[i] = RandIntN(n)
}
return res
}

View File

@@ -13,8 +13,8 @@ import (
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal/plonk"
"github.com/consensys/zkevm-monorepo/prover/utils"
"golang.org/x/exp/constraints"
"math/big"
"slices"
@@ -128,7 +128,7 @@ func (r *Range) LastArray32F(provider func(int) [32]frontend.Variable) [32]front
}
func RegisterHints() {
hint.RegisterHint(toCrumbsHint, concatHint, decomposeIntoBytesHint, checksumSubSlicesHint)
hint.RegisterHint(toCrumbsHint, concatHint, checksumSubSlicesHint, partitionSliceHint)
}
func toCrumbsHint(_ *big.Int, ins, outs []*big.Int) error {
@@ -605,71 +605,6 @@ func Ite[T any](cond bool, ifSo, ifNot T) T {
return ifNot
}
func ToBytes(api frontend.API, x frontend.Variable) [32]frontend.Variable {
var res [32]frontend.Variable
d := decomposeIntoBytes(api, x, fr377.Bits)
slack := 32 - len(d) // should be zero
copy(res[slack:], d)
for i := 0; i < slack; i++ {
res[i] = 0
}
return res
}
func decomposeIntoBytes(api frontend.API, data frontend.Variable, nbBits int) []frontend.Variable {
if nbBits == 0 {
nbBits = api.Compiler().FieldBitLen()
}
nbBytes := (api.Compiler().FieldBitLen() + 7) / 8
bytes, err := api.Compiler().NewHint(decomposeIntoBytesHint, nbBytes, data)
if err != nil {
panic(err)
}
lastNbBits := nbBits % 8
if lastNbBits == 0 {
lastNbBits = 8
}
rc := rangecheck.New(api)
api.AssertIsLessOrEqual(bytes[0], 1<<lastNbBits-1) //TODO try range checking this as well
for i := 1; i < nbBytes; i++ {
rc.Check(bytes[i], 8)
}
return bytes
}
func decomposeIntoBytesHint(_ *big.Int, ins, outs []*big.Int) error {
nbBytes := len(outs) / len(ins)
if nbBytes*len(ins) != len(outs) {
return errors.New("incongruent number of ins/outs")
}
var v, radix, zero big.Int
radix.SetUint64(256)
for i := range ins {
v.Set(ins[i])
for j := nbBytes - 1; j >= 0; j-- {
outs[i*nbBytes+j].Mod(&v, &radix)
v.Rsh(&v, 8)
}
if v.Cmp(&zero) != 0 {
return errors.New("not fitting in len(outs)/len(ins) many bytes")
}
}
return nil
}
func Copy[T any](dst []frontend.Variable, src []T) (n int) {
n = min(len(dst), len(src))
for i := 0; i < n; i++ {
dst[i] = src[i]
}
return
}
// PartialSums returns s[0], s[0]+s[1], ..., s[0]+s[1]+...+s[len(s)-1]
func PartialSums(api frontend.API, s []frontend.Variable) []frontend.Variable {
res := make([]frontend.Variable, len(s))
@@ -690,46 +625,6 @@ func Differences(api frontend.API, s []frontend.Variable) []frontend.Variable {
return res
}
// NewElementFromBytes range checks the bytes and gives a reduced field element
func NewElementFromBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) *emulated.Element[T] {
bits := make([]frontend.Variable, 8*len(bytes))
for i := range bytes {
copy(bits[8*i:], api.ToBinary(bytes[len(bytes)-i-1], 8))
}
f, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
return f.Reduce(f.Add(f.FromBits(bits...), f.Zero()))
}
// ReduceBytes reduces given bytes modulo a given field. As a side effect, the "bytes" are range checked
func ReduceBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) []frontend.Variable {
f, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
bits := f.ToBits(NewElementFromBytes[T](api, bytes))
res := make([]frontend.Variable, (len(bits)+7)/8)
copy(bits[:], bits)
for i := len(bits); i < len(bits); i++ {
bits[i] = 0
}
for i := range res {
bitsStart := 8 * (len(res) - i - 1)
bitsEnd := bitsStart + 8
if i == 0 {
bitsEnd = len(bits)
}
res[i] = api.FromBinary(bits[bitsStart:bitsEnd]...)
}
return res
}
func NewSliceOf32Array[T any](values [][32]T, maxLen int) Var32Slice {
if maxLen < len(values) {
panic("maxLen too small")
@@ -739,11 +634,11 @@ func NewSliceOf32Array[T any](values [][32]T, maxLen int) Var32Slice {
Length: len(values),
}
for i := range values {
Copy(res.Values[i][:], values[i][:])
utils.Copy(res.Values[i][:], values[i][:])
}
var zeros [32]byte
for i := len(values); i < maxLen; i++ {
Copy(res.Values[i][:], zeros[:])
utils.Copy(res.Values[i][:], zeros[:])
}
return res
}
@@ -811,3 +706,168 @@ func CloneSlice[T any](s []T, cap ...int) []T {
copy(res, s)
return res
}
// PartitionSlice populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i
// There are no guarantee on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact.
// It may produce an incorrect result if selectors are out of range
func PartitionSlice(api frontend.API, s []frontend.Variable, selectors []frontend.Variable, subs ...[]frontend.Variable) {
if len(s) != len(selectors) {
panic("s and selectors must have the same length")
}
hintIn := make([]frontend.Variable, 1+len(subs)+len(s)+len(selectors))
hintIn[0] = len(subs)
hintOutLen := 0
for i := range subs {
hintIn[1+i] = len(subs[i])
hintOutLen += len(subs[i])
}
for i := range s {
hintIn[1+len(subs)+i] = s[i]
hintIn[1+len(subs)+len(s)+i] = selectors[i]
}
subsGlued, err := api.Compiler().NewHint(partitionSliceHint, hintOutLen, hintIn...)
if err != nil {
panic(err)
}
subsT := make([]*logderivlookup.Table, len(subs))
for i := range subs {
copy(subs[i], subsGlued[:len(subs[i])])
subsGlued = subsGlued[len(subs[i]):]
subsT[i] = SliceToTable(api, subs[i])
subsT[i].Insert(0)
}
subI := make([]frontend.Variable, len(subs))
for i := range subI {
subI[i] = 0
}
indicators := make([]frontend.Variable, len(subs))
subHeads := make([]frontend.Variable, len(subs))
for i := range s {
for j := range subs[:len(subs)-1] {
indicators[j] = api.IsZero(api.Sub(selectors[i], j))
}
indicators[len(subs)-1] = api.Sub(1, SumSnark(api, indicators[:len(subs)-1]...))
for j := range subs {
subHeads[j] = subsT[j].Lookup(subI[j])[0]
subI[j] = api.Add(subI[j], indicators[j])
}
api.AssertIsEqual(s[i], InnerProd(api, subHeads, indicators))
}
// Check that the dummy trailing values weren't actually picked
for i := range subI {
api.AssertIsDifferent(subI[i], len(subs[i])+1)
}
}
func SumSnark(api frontend.API, x ...frontend.Variable) frontend.Variable {
res := frontend.Variable(0)
for i := range x {
res = api.Add(res, x[i])
}
return res
}
// ins: [nbSubs, maxLen_0, ..., maxLen_{nbSubs-1}, s..., indicators...]
func partitionSliceHint(_ *big.Int, ins, outs []*big.Int) error {
subs := make([][]*big.Int, ins[0].Uint64())
for i := range subs {
subs[i] = outs[:ins[1+i].Uint64()]
outs = outs[len(subs[i]):]
}
if len(outs) != 0 {
return errors.New("the sum of subslice max lengths does not equal output length")
}
ins = ins[1+len(subs):]
s := ins[:len(ins)/2]
indicators := ins[len(s):]
if len(s) != len(indicators) {
return errors.New("s and indicators must be of the same length")
}
for i := range s {
b := int(indicators[i].Uint64())
if b < 0 || b >= len(subs) || !indicators[i].IsUint64() {
return errors.New("indicator out of range")
}
subs[b][0] = s[i]
subs[b] = subs[b][1:]
}
for i := range subs {
for j := range subs[i] {
subs[i][j].SetInt64(0)
}
}
return nil
}
// PartitionSliceEmulated populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i
// There are no guarantee on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact.
// It may produce an incorrect result if selectors are out of range
func PartitionSliceEmulated[T emulated.FieldParams](api frontend.API, s []emulated.Element[T], selectors []frontend.Variable, subSliceMaxLens ...int) [][]emulated.Element[T] {
field, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
// transpose limbs for selection
limbs := make([][]frontend.Variable, len(s[0].Limbs)) // limbs are indexed limb first, element second
for i := range limbs {
limbs[i] = make([]frontend.Variable, len(s))
}
for i := range s {
if len(limbs) != len(s[i].Limbs) {
panic("expected uniform number of limbs")
}
for j := range limbs {
limbs[j][i] = s[i].Limbs[j]
}
}
subLimbs := make([][][]frontend.Variable, len(limbs)) // subLimbs is indexed limb first, sub-slice second, element third
for i := range limbs { // construct the sub-slices limb by limb
subLimbs[i] = make([][]frontend.Variable, len(subSliceMaxLens))
for j := range subSliceMaxLens {
subLimbs[i][j] = make([]frontend.Variable, subSliceMaxLens[j])
}
PartitionSlice(api, limbs[i], selectors, subLimbs[i]...)
}
// put the limbs back together
subSlices := make([][]emulated.Element[T], len(subSliceMaxLens))
for i := range subSlices {
subSlices[i] = make([]emulated.Element[T], subSliceMaxLens[i])
for j := range subSlices[i] {
currLimbs := make([]frontend.Variable, len(limbs))
for k := range currLimbs {
currLimbs[k] = subLimbs[k][i][j]
}
subSlices[i][j] = *field.NewElement(currLimbs) // TODO make sure dereferencing is not problematic
}
}
return subSlices
}
func InnerProd(api frontend.API, x, y []frontend.Variable) frontend.Variable {
if len(x) != len(y) {
panic("mismatched lengths")
}
res := frontend.Variable(0)
for i := range x {
res = api.Add(res, api.Mul(x[i], y[i]))
}
return res
}

View File

@@ -3,6 +3,8 @@ package internal_test
import (
"crypto/rand"
"fmt"
fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash/mimc"
@@ -55,7 +57,7 @@ func testChecksumSubSlices(t *testing.T, bigSliceLength, lengthsSliceLength int,
}
endPointsSnark := make([]frontend.Variable, lengthsSliceLength)
for n := internal.Copy(endPointsSnark, internal.PartialSumsInt(lengths)); n < lengthsSliceLength; n++ {
for n := utils.Copy(endPointsSnark, internal.PartialSumsInt(lengths)); n < lengthsSliceLength; n++ {
endPointsSnark[n] = n * 234
}
@@ -103,7 +105,7 @@ func TestReduceBytes(t *testing.T) {
test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable {
for i := range cases {
got := internal.ReduceBytes[emulated.BN254Fr](api, test_vector_utils.ToVariableSlice(cases[i]))
got := utils.ReduceBytes[emulated.BN254Fr](api, test_vector_utils.ToVariableSlice(cases[i]))
internal.AssertSliceEquals(api,
got,
test_vector_utils.ToVariableSlice(reduced[i]),
@@ -113,3 +115,113 @@ func TestReduceBytes(t *testing.T) {
return nil
})(t)
}
func TestPartitionSliceEmulated(t *testing.T) {
selectors := []int{1, 0, 2, 2, 1}
s := make([]fr381.Element, len(selectors))
for i := range s {
_, err := s[i].SetRandom()
assert.NoError(t, err)
}
subs := make([][]fr381.Element, 3)
for i := range subs {
subs[i] = make([]fr381.Element, 0, len(selectors)-1)
}
for i := range s {
subs[selectors[i]] = append(subs[selectors[i]], s[i])
}
test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable {
field, err := emulated.NewField[emulated.BLS12381Fr](api)
assert.NoError(t, err)
// convert randomized elements to emulated
sEmulated := elementsToEmulated(field, s)
subsEmulatedExpected := internal.MapSlice(func(s []fr381.Element) []emulated.Element[emulated.BLS12381Fr] {
return elementsToEmulated(field, append(s, make([]fr381.Element, cap(s)-len(s))...)) // pad with zeros to see if padding is done correctly
}, subs...)
subsEmulated := internal.PartitionSliceEmulated(api, sEmulated, test_vector_utils.ToVariableSlice(selectors), internal.MapSlice(func(s []fr381.Element) int { return cap(s) }, subs...)...)
assert.Equal(t, len(subsEmulatedExpected), len(subsEmulated))
for i := range subsEmulated {
assert.Equal(t, len(subsEmulatedExpected[i]), len(subsEmulated[i]))
for j := range subsEmulated[i] {
field.AssertIsEqual(&subsEmulated[i][j], &subsEmulatedExpected[i][j])
}
}
return nil
})(t)
}
func elementsToEmulated(field *emulated.Field[emulated.BLS12381Fr], s []fr381.Element) []emulated.Element[emulated.BLS12381Fr] {
return internal.MapSlice(func(element fr381.Element) emulated.Element[emulated.BLS12381Fr] {
return *field.NewElement(internal.MapSlice(func(x uint64) frontend.Variable { return x }, element[:]...))
}, s...)
}
func TestPartitionSlice(t *testing.T) {
const (
nbSubs = 3
sliceLen = 10
)
test := func(slice []frontend.Variable, selectors []int, subsSlack []int) func(*testing.T) {
assert.Equal(t, len(selectors), len(slice))
assert.Equal(t, len(subsSlack), nbSubs)
subs := make([][]frontend.Variable, nbSubs)
for j := range subs {
subs[j] = make([]frontend.Variable, 0, sliceLen)
}
for j := range slice {
subs[selectors[j]] = append(subs[selectors[j]], slice[j])
}
for j := range subs {
subs[j] = append(subs[j], test_vector_utils.ToVariableSlice(make([]int, subsSlack[j]))...) // add some padding
}
return test_utils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable {
slice := test_vector_utils.ToVariableSlice(slice)
subsEncountered := internal.MapSlice(func(s []frontend.Variable) []frontend.Variable { return make([]frontend.Variable, len(s)) }, subs...)
internal.PartitionSlice(api, slice, test_vector_utils.ToVariableSlice(selectors), subsEncountered...)
assert.Equal(t, len(subs), len(subsEncountered))
for j := range subsEncountered {
internal.AssertSliceEquals(api, subsEncountered[j], subs[j])
}
return nil
})
}
test([]frontend.Variable{5}, []int{2}, []int{1, 0, 0})(t)
test([]frontend.Variable{1, 2, 3}, []int{0, 1, 2}, []int{0, 0, 0})
test(test_vector_utils.ToVariableSlice(test_utils.Range[int](10)), []int{0, 1, 2, 0, 0, 0, 1, 1, 1, 2}, []int{0, 0, 0})
for i := 0; i < 200; i++ {
slice := make([]frontend.Variable, sliceLen)
for j := range slice {
var x fr377.Element
_, err := x.SetRandom()
slice[j] = &x
assert.NoError(t, err)
}
selectors := test_utils.RandIntSliceN(sliceLen, nbSubs)
subsSlack := test_utils.RandIntSliceN(nbSubs, 2)
t.Run(fmt.Sprintf("selectors=%v,slack=%v", selectors, subsSlack), test(slice, selectors, subsSlack))
}
}

View File

@@ -6,7 +6,6 @@ import (
"errors"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/zkevm-monorepo/prover/backend/blobsubmission"
"github.com/consensys/zkevm-monorepo/prover/circuits/aggregation"
decompression "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v1"
"github.com/consensys/zkevm-monorepo/prover/circuits/execution"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
@@ -18,31 +17,22 @@ import (
"hash"
)
type ExecutionRequest struct {
L2MsgHashes [][32]byte
FinalStateRootHash [32]byte
FinalBlockNumber uint64
FinalBlockTimestamp uint64
FinalRollingHash [32]byte
FinalRollingHashNumber uint64
}
type Request struct {
DecompDict []byte
Decompressions []blobsubmission.Response
Executions []ExecutionRequest
Executions []public_input.Execution
Aggregation public_input.Aggregation
}
func (c *Compiled) Assign(r Request) (a Circuit, err error) {
internal.RegisterHints()
keccak.RegisterHints()
utils.RegisterHints()
// TODO there is data duplication in the request. Check consistency
// infer config
config := c.getConfig()
a = config.allocateCircuit()
a = allocateCircuit(config)
if len(r.Decompressions) > config.MaxNbDecompression {
err = errors.New("number of decompression proofs exceeds maximum")
@@ -69,7 +59,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
if err != nil {
return
}
internal.Copy(a.ParentShnarf[:], prevShnarf)
utils.Copy(a.ParentShnarf[:], prevShnarf)
execDataChecksums := make([][]byte, 0, len(r.Executions))
shnarfs := make([][]byte, config.MaxNbDecompression)
@@ -156,7 +146,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
} else {
a.DecompressionFPIQ[i] = fpis.FunctionalPublicInputQSnark
}
internal.Copy(a.DecompressionFPIQ[i].X[:], zero[:])
utils.Copy(a.DecompressionFPIQ[i].X[:], zero[:])
if a.DecompressionPublicInput[i], err = fpi.Sum(decompression.WithBatchesSum(zero[:])); err != nil { // TODO zero batches sum is probably incorrect
return
}
@@ -164,7 +154,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
// Aggregation FPI
aggregationFPI, err := aggregation.NewFunctionalPublicInput(&r.Aggregation)
aggregationFPI, err := public_input.NewAggregationFPI(&r.Aggregation)
if err != nil {
return
}
@@ -173,10 +163,10 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
return
}
aggregationFPI.NbDecompression = uint64(len(r.Decompressions))
a.FunctionalPublicInputQSnark = aggregationFPI.ToSnarkType().FunctionalPublicInputQSnark
a.AggregationFPIQSnark = aggregationFPI.ToSnarkType().AggregationFPIQSnark
merkleNbLeaves := 1 << config.L2MsgMerkleDepth
maxNbL2MessageHashes := config.L2MessageMaxNbMerkle * merkleNbLeaves
maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves
l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes)
// Execution FPI
executionFPI := execution.FunctionalPublicInput{
@@ -187,7 +177,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
FinalRollingHashNumber: aggregationFPI.InitialRollingHashNumber,
L2MessageServiceAddr: aggregationFPI.L2MessageServiceAddr,
ChainID: aggregationFPI.ChainID,
MaxNbL2MessageHashes: config.MaxNbMsgPerExecution,
MaxNbL2MessageHashes: config.ExecutionMaxNbMsg,
}
for i := range a.ExecutionFPIQ {
executionFPI.InitialRollingHash = executionFPI.FinalRollingHash
@@ -251,16 +241,16 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}
// pad the merkle roots
if len(r.Aggregation.L2MsgRootHashes) > config.L2MessageMaxNbMerkle {
if len(r.Aggregation.L2MsgRootHashes) > config.L2MsgMaxNbMerkle {
err = errors.New("more merkle trees than there is capacity")
return
}
{
roots := internal.CloneSlice(r.Aggregation.L2MsgRootHashes, config.L2MessageMaxNbMerkle)
roots := internal.CloneSlice(r.Aggregation.L2MsgRootHashes, config.L2MsgMaxNbMerkle)
emptyRootHex := utils.HexEncodeToString(emptyTree[len(emptyTree)-1][:32])
for i := len(r.Aggregation.L2MsgRootHashes); i < config.L2MessageMaxNbMerkle; i++ {
for i := len(r.Aggregation.L2MsgRootHashes); i < config.L2MsgMaxNbMerkle; i++ {
for depth := config.L2MsgMerkleDepth; depth > 0; depth-- {
for j := 0; j < 1<<(depth-1); j++ {
hshK.Skip(emptyTree[config.L2MsgMerkleDepth-depth])

View File

@@ -11,6 +11,7 @@ import (
"github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
pitesting "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/test_utils"
"github.com/consensys/zkevm-monorepo/prover/config"
"github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc"
"time"
@@ -23,14 +24,14 @@ func main() {
var b test_utils.FakeTestingT
req := pitesting.AssignSingleBlockBlob(b)
c, err := pi_interconnection.Config{
MaxNbDecompression: 400,
MaxNbExecution: 400,
MaxNbKeccakF: 10000,
MaxNbMsgPerExecution: 16,
L2MsgMerkleDepth: 5,
L2MessageMaxNbMerkle: 10,
}.Compile(dummy.Compile) // note that the solving/proving time will not reflect the wizard proof or verification
c, err := pi_interconnection.Compile(config.PublicInput{
MaxNbDecompression: 400,
MaxNbExecution: 400,
MaxNbKeccakF: 10000,
ExecutionMaxNbMsg: 16,
L2MsgMerkleDepth: 5,
L2MsgMaxNbMerkle: 10,
}, dummy.Compile) // note that the solving/proving time will not reflect the wizard proof or verification
assert.NoError(b, err)
a, err := c.Assign(req)

View File

@@ -2,6 +2,13 @@ package pi_interconnection
import (
"errors"
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/constraint"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/zkevm-monorepo/prover/circuits"
"github.com/consensys/zkevm-monorepo/prover/config"
public_input "github.com/consensys/zkevm-monorepo/prover/public-input"
"math/big"
"slices"
@@ -11,7 +18,6 @@ import (
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/zkevm-monorepo/prover/circuits/aggregation"
decompression "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v1"
"github.com/consensys/zkevm-monorepo/prover/circuits/execution"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
@@ -23,12 +29,13 @@ import (
type Circuit struct {
AggregationPublicInput frontend.Variable `gnark:",public"` // the public input of the aggregation circuit
DecompressionPublicInput []frontend.Variable `gnark:",public"`
ExecutionPublicInput []frontend.Variable `gnark:",public"`
DecompressionPublicInput []frontend.Variable `gnark:",public"`
DecompressionFPIQ []decompression.FunctionalPublicInputQSnark
ExecutionFPIQ []execution.FunctionalPublicInputQSnark
aggregation.FunctionalPublicInputQSnark
public_input.AggregationFPIQSnark
Keccak keccak.StrictHasherCircuit
@@ -36,7 +43,7 @@ type Circuit struct {
L2MessageMerkleDepth int
L2MessageMaxNbMerkle int
MaxNbCircuits int
MaxNbCircuits int // possibly useless TODO consider removing
UseGkrMimc bool
}
@@ -46,7 +53,7 @@ func (c *Circuit) Define(api frontend.API) error {
return errors.New("public / functional public input length mismatch")
}
c.FunctionalPublicInputQSnark.RangeCheck(api)
c.AggregationFPIQSnark.RangeCheck(api)
rDecompression := internal.NewRange(api, c.NbDecompression, len(c.DecompressionPublicInput))
hshK := c.Keccak.NewHasher(api)
@@ -90,8 +97,8 @@ func (c *Circuit) Define(api frontend.API) error {
piq.RangeCheck(api)
shnarfParams[i] = ShnarfIteration{ // prepare shnarf verification data
BlobDataSnarkHash: internal.ToBytes(api, piq.SnarkHash),
NewStateRootHash: internal.ToBytes(api, finalStateRootHashes.Lookup(api.Sub(nbBatchesSums[i], 1))[0]),
BlobDataSnarkHash: utils.ToBytes(api, piq.SnarkHash),
NewStateRootHash: utils.ToBytes(api, finalStateRootHashes.Lookup(api.Sub(nbBatchesSums[i], 1))[0]),
EvaluationPointBytes: piq.X,
EvaluationClaimBytes: fr377EncodedFr381ToBytes(api, piq.Y),
}
@@ -162,20 +169,20 @@ func (c *Circuit) Define(api frontend.API) error {
}
rExecution := internal.NewRange(api, nbExecution, maxNbExecution)
pi := aggregation.FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: c.FunctionalPublicInputQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }),
FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
pi := public_input.AggregationFPISnark{
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }),
FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
}
for i := range pi.L2MsgMerkleTreeRoots {
pi.L2MsgMerkleTreeRoots[i] = merkleRoot(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves])
pi.L2MsgMerkleTreeRoots[i] = MerkleRootSnark(&hshK, merkleLeavesConcat.Values[i*merkleNbLeaves:(i+1)*merkleNbLeaves])
}
// "open" aggregation public input
@@ -185,7 +192,7 @@ func (c *Circuit) Define(api frontend.API) error {
return hshK.Finalize()
}
func merkleRoot(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]frontend.Variable {
func MerkleRootSnark(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]frontend.Variable {
values := slices.Clone(leaves)
if !utils.IsPowerOfTwo(len(values)) {
@@ -201,38 +208,28 @@ func merkleRoot(hshK keccak.BlockHasher, leaves [][32]frontend.Variable) [32]fro
return values[0]
}
type Config struct {
MaxNbDecompression int
MaxNbExecution int
MaxNbCircuits int
MaxNbKeccakF int
MaxNbMsgPerExecution int
L2MsgMerkleDepth int
L2MessageMaxNbMerkle int // if not explicitly provided (i.e. non-positive) it will be set to maximum
}
type Compiled struct {
Circuit *Circuit
Keccak keccak.CompiledStrictHasher
}
func (c Config) Compile(wizardCompilationOpts ...func(iop *wizard.CompiledIOP)) (*Compiled, error) {
func Compile(c config.PublicInput, wizardCompilationOpts ...func(iop *wizard.CompiledIOP)) (*Compiled, error) {
if c.L2MessageMaxNbMerkle <= 0 {
if c.L2MsgMaxNbMerkle <= 0 {
merkleNbLeaves := 1 << c.L2MsgMerkleDepth
c.L2MessageMaxNbMerkle = (c.MaxNbExecution*c.MaxNbMsgPerExecution + merkleNbLeaves - 1) / merkleNbLeaves
c.L2MsgMaxNbMerkle = (c.MaxNbExecution*c.ExecutionMaxNbMsg + merkleNbLeaves - 1) / merkleNbLeaves
}
sh := c.newKeccakCompiler().Compile(c.MaxNbKeccakF, wizardCompilationOpts...)
sh := newKeccakCompiler(c).Compile(c.MaxNbKeccakF, wizardCompilationOpts...)
shc, err := sh.GetCircuit()
if err != nil {
return nil, err
}
circuit := c.allocateCircuit()
circuit := allocateCircuit(c)
circuit.Keccak = shc
for i := range circuit.ExecutionFPIQ {
circuit.ExecutionFPIQ[i].L2MessageHashes.Values = make([][32]frontend.Variable, c.MaxNbMsgPerExecution)
circuit.ExecutionFPIQ[i].L2MessageHashes.Values = make([][32]frontend.Variable, c.ExecutionMaxNbMsg)
}
return &Compiled{
@@ -241,33 +238,33 @@ func (c Config) Compile(wizardCompilationOpts ...func(iop *wizard.CompiledIOP))
}, nil
}
func (c *Compiled) getConfig() Config {
return Config{
MaxNbDecompression: len(c.Circuit.DecompressionFPIQ),
MaxNbExecution: len(c.Circuit.ExecutionFPIQ),
MaxNbKeccakF: c.Keccak.MaxNbKeccakF(),
MaxNbMsgPerExecution: len(c.Circuit.ExecutionFPIQ[0].L2MessageHashes.Values),
L2MsgMerkleDepth: c.Circuit.L2MessageMerkleDepth,
L2MessageMaxNbMerkle: c.Circuit.L2MessageMaxNbMerkle,
MaxNbCircuits: c.Circuit.MaxNbCircuits,
func (c *Compiled) getConfig() config.PublicInput {
return config.PublicInput{
MaxNbDecompression: len(c.Circuit.DecompressionFPIQ),
MaxNbExecution: len(c.Circuit.ExecutionFPIQ),
MaxNbKeccakF: c.Keccak.MaxNbKeccakF(),
ExecutionMaxNbMsg: len(c.Circuit.ExecutionFPIQ[0].L2MessageHashes.Values),
L2MsgMerkleDepth: c.Circuit.L2MessageMerkleDepth,
L2MsgMaxNbMerkle: c.Circuit.L2MessageMaxNbMerkle,
MaxNbCircuits: c.Circuit.MaxNbCircuits,
}
}
func (c Config) allocateCircuit() Circuit {
func allocateCircuit(c config.PublicInput) Circuit {
return Circuit{
DecompressionPublicInput: make([]frontend.Variable, c.MaxNbDecompression),
ExecutionPublicInput: make([]frontend.Variable, c.MaxNbExecution),
DecompressionFPIQ: make([]decompression.FunctionalPublicInputQSnark, c.MaxNbDecompression),
ExecutionFPIQ: make([]execution.FunctionalPublicInputQSnark, c.MaxNbExecution),
L2MessageMerkleDepth: c.L2MsgMerkleDepth,
L2MessageMaxNbMerkle: c.L2MessageMaxNbMerkle,
L2MessageMaxNbMerkle: c.L2MsgMaxNbMerkle,
MaxNbCircuits: c.MaxNbCircuits,
}
}
func (c Config) newKeccakCompiler() *keccak.StrictHasherCompiler {
func newKeccakCompiler(c config.PublicInput) *keccak.StrictHasherCompiler {
nbShnarf := c.MaxNbDecompression
nbMerkle := c.L2MessageMaxNbMerkle * ((1 << c.L2MsgMerkleDepth) - 1)
nbMerkle := c.L2MsgMaxNbMerkle * ((1 << c.L2MsgMerkleDepth) - 1)
res := keccak.NewStrictHasherCompiler(nbShnarf, nbMerkle, 2)
for i := 0; i < nbShnarf; i++ {
res.WithHashLengths(160) // 5 components in every shnarf
@@ -278,8 +275,36 @@ func (c Config) newKeccakCompiler() *keccak.StrictHasherCompiler {
}
// aggregation PI opening
res.WithHashLengths(32 * c.L2MessageMaxNbMerkle)
res.WithHashLengths(32 * c.L2MsgMaxNbMerkle)
res.WithHashLengths(384)
return &res
}
type builder struct {
*config.PublicInput
}
func NewBuilder(c config.PublicInput) circuits.Builder {
return builder{&c}
}
func (b builder) Compile() (constraint.ConstraintSystem, error) {
c, err := Compile(*b.PublicInput, WizardCompilationParameters()...)
if err != nil {
return nil, err
}
const estimatedNbConstraints = 35_000_000
cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, c.Circuit, frontend.WithCapacity(estimatedNbConstraints))
if err != nil {
return nil, err
}
if nbC := cs.GetNbConstraints(); nbC > estimatedNbConstraints || estimatedNbConstraints-nbC > 5_000_000 {
return nil, fmt.Errorf("constraint estimate is off; got %d", nbC)
}
return cs, nil
}
func WizardCompilationParameters() []func(iop *wizard.CompiledIOP) {
panic("implement me")
}

View File

@@ -1,4 +1,4 @@
package pi_interconnection
package pi_interconnection_test
import (
"errors"
@@ -9,6 +9,7 @@ import (
"github.com/consensys/gnark/test"
"github.com/consensys/zkevm-monorepo/prover/backend/aggregation"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
"github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/stretchr/testify/assert"
@@ -38,7 +39,7 @@ func TestMerkle(t *testing.T) {
assert.NoError(t, err)
toHashBytes[i] = x.Bytes()
toHashHex[i] = utils.HexEncodeToString(toHashBytes[i][:])
assert.Equal(t, 32, internal.Copy(toHashSnark[i][:], toHashBytes[i][:]))
assert.Equal(t, 32, utils.Copy(toHashSnark[i][:], toHashBytes[i][:]))
}
for i := len(toHashHex); i < len(toHashSnark); i++ { // pad with zeros
for j := range toHashSnark[i] {
@@ -49,7 +50,7 @@ func TestMerkle(t *testing.T) {
hsh := sha3.NewLegacyKeccak256()
for i := range rootsHex {
root := MerkleRoot(hsh, c.nbLeaves, toHashBytes[i*c.nbLeaves:min(len(toHashBytes), (i+1)*c.nbLeaves)])
root := pi_interconnection.MerkleRoot(hsh, c.nbLeaves, toHashBytes[i*c.nbLeaves:min(len(toHashBytes), (i+1)*c.nbLeaves)])
rootHex := utils.HexEncodeToString(root[:])
assert.Equal(t, rootsHex[i], rootHex)
}
@@ -90,7 +91,7 @@ func (c *testMerkleCircuit) Define(api frontend.API) error {
}
for i := range c.Roots {
root := merkleRoot(hshK, c.ToHash[i*nbLeaves:(i+1)*nbLeaves])
root := pi_interconnection.MerkleRootSnark(hshK, c.ToHash[i*nbLeaves:(i+1)*nbLeaves])
internal.AssertSliceEquals(api, c.Roots[i][:], root[:])
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/consensys/gnark/profile"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
"github.com/consensys/zkevm-monorepo/prover/config"
"github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy"
"github.com/stretchr/testify/assert"
)
@@ -16,14 +17,14 @@ func main() {
fmt.Println("creating wizard circuit")
c, err := pi_interconnection.Config{
MaxNbDecompression: 400,
MaxNbExecution: 400,
MaxNbKeccakF: 10000,
MaxNbMsgPerExecution: 16,
L2MsgMerkleDepth: 5,
L2MessageMaxNbMerkle: 10,
}.Compile(dummy.Compile)
c, err := pi_interconnection.Compile(config.PublicInput{
MaxNbDecompression: 400,
MaxNbExecution: 400,
MaxNbKeccakF: 10000,
ExecutionMaxNbMsg: 16,
L2MsgMerkleDepth: 5,
L2MsgMaxNbMerkle: 10,
}, dummy.Compile)
var t test_utils.FakeTestingT
assert.NoError(t, err)

View File

@@ -17,6 +17,7 @@ import (
"github.com/consensys/zkevm-monorepo/prover/circuits/internal/test_utils"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
pitesting "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/test_utils"
"github.com/consensys/zkevm-monorepo/prover/config"
"github.com/consensys/zkevm-monorepo/prover/crypto/mimc/gkrmimc"
blobtesting "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy"
@@ -34,15 +35,15 @@ func TestSingleBlockBlob(t *testing.T) {
func TestSingleBlobBlobE2E(t *testing.T) {
req := pitesting.AssignSingleBlockBlob(t)
config := pi_interconnection.Config{
MaxNbDecompression: len(req.Decompressions),
MaxNbExecution: len(req.Executions),
MaxNbKeccakF: 100,
MaxNbMsgPerExecution: 1,
L2MsgMerkleDepth: 5,
L2MessageMaxNbMerkle: 1,
cfg := config.PublicInput{
MaxNbDecompression: len(req.Decompressions),
MaxNbExecution: len(req.Executions),
MaxNbKeccakF: 100,
ExecutionMaxNbMsg: 1,
L2MsgMerkleDepth: 5,
L2MsgMaxNbMerkle: 1,
}
compiled, err := config.Compile(dummy.Compile)
compiled, err := pi_interconnection.Compile(cfg, dummy.Compile)
assert.NoError(t, err)
a, err := compiled.Assign(req)
@@ -74,7 +75,7 @@ func TestTinyTwoBatchBlob(t *testing.T) {
blob := blobtesting.TinyTwoBatchBlob(t)
execReq := []pi_interconnection.ExecutionRequest{{
execReq := []public_input.Execution{{
L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(3)},
FinalStateRootHash: internal.Uint64To32Bytes(4),
FinalBlockNumber: 5,
@@ -104,7 +105,6 @@ func TestTinyTwoBatchBlob(t *testing.T) {
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes))
req := pi_interconnection.Request{
DecompDict: blobtesting.GetDict(t),
Decompressions: []blobsubmission.Response{*blobResp},
Executions: execReq,
Aggregation: public_input.Aggregation{
@@ -131,7 +131,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
t.Skipf("Flacky test due to the number of keccakf outgoing the limit specified for the test")
blobs := blobtesting.ConsecutiveBlobs(t, 2, 2)
execReq := []pi_interconnection.ExecutionRequest{{
execReq := []public_input.Execution{{
L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(3)},
FinalStateRootHash: internal.Uint64To32Bytes(4),
FinalBlockNumber: 5,
@@ -186,7 +186,6 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes))
req := pi_interconnection.Request{
DecompDict: blobtesting.GetDict(t),
Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1},
Executions: execReq,
Aggregation: public_input.Aggregation{
@@ -215,17 +214,17 @@ func testPI(t *testing.T, maxNbKeccakF int, req pi_interconnection.Request) {
decomposeLittleEndian(t, slack[:], i, 3)
config := pi_interconnection.Config{
MaxNbDecompression: len(req.Decompressions) + slack[0],
MaxNbExecution: len(req.Executions) + slack[1],
MaxNbKeccakF: maxNbKeccakF,
MaxNbMsgPerExecution: 1 + slack[2],
L2MsgMerkleDepth: 5,
L2MessageMaxNbMerkle: 1 + slack[3],
config := config.PublicInput{
MaxNbDecompression: len(req.Decompressions) + slack[0],
MaxNbExecution: len(req.Executions) + slack[1],
MaxNbKeccakF: maxNbKeccakF,
ExecutionMaxNbMsg: 1 + slack[2],
L2MsgMerkleDepth: 5,
L2MsgMaxNbMerkle: 1 + slack[3],
}
t.Run(fmt.Sprintf("slack profile %v", slack), func(t *testing.T) {
compiled, err := config.Compile(dummy.Compile)
compiled, err := pi_interconnection.Compile(config, dummy.Compile)
assert.NoError(t, err)
a, err := compiled.Assign(req)

View File

@@ -6,6 +6,7 @@ import (
"github.com/consensys/gnark/test"
"github.com/consensys/zkevm-monorepo/prover/circuits/internal"
"github.com/consensys/zkevm-monorepo/prover/protocol/compiler/dummy"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/stretchr/testify/assert"
"testing"
)
@@ -36,8 +37,8 @@ func TestAssign(t *testing.T) {
}
assignment.H, err = hsh.Assign()
assert.NoError(t, err)
internal.Copy(assignment.Outs[0][:], res)
internal.Copy(assignment.Ins[0][0][:], zero[:])
utils.Copy(assignment.Outs[0][:], res)
utils.Copy(assignment.Ins[0][0][:], zero[:])
assert.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField()))
}

View File

@@ -31,7 +31,7 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request {
blobResp, err := blobsubmission.CraftResponse(&blobReq)
assert.NoError(t, err)
execReq := pi_interconnection.ExecutionRequest{
execReq := public_input.Execution{
L2MsgHashes: [][32]byte{internal.Uint64To32Bytes(4)},
FinalStateRootHash: finalStateRootHash,
FinalBlockNumber: 9,
@@ -43,9 +43,8 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request {
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq.L2MsgHashes))
return pi_interconnection.Request{
DecompDict: blobtesting.GetDict(t),
Decompressions: []blobsubmission.Response{*blobResp},
Executions: []pi_interconnection.ExecutionRequest{execReq},
Executions: []public_input.Execution{execReq},
Aggregation: public_input.Aggregation{
FinalShnarf: blobResp.ExpectedShnarf,
ParentAggregationFinalShnarf: blobReq.PrevShnarf,

View File

@@ -1,19 +1,20 @@
package circuits
// CircuitID is a type to represent the different circuits.
// Is is used to identify the circuit to be used in the prover.
// It is used to identify the circuit to be used in the prover.
type CircuitID string
const (
ExecutionCircuitID CircuitID = "execution"
ExecutionLargeCircuitID CircuitID = "execution-large"
BlobDecompressionV0CircuitID CircuitID = "blob-decompression-v0"
BlobDecompressionV1CircuitID CircuitID = "blob-decompression-v1"
AggregationCircuitID CircuitID = "aggregation"
EmulationCircuitID CircuitID = "emulation"
EmulationDummyCircuitID CircuitID = "emulation-dummy"
ExecutionDummyCircuitID CircuitID = "execution-dummy"
BlobDecompressionDummyCircuitID CircuitID = "blob-decompression-dummy"
ExecutionCircuitID CircuitID = "execution"
ExecutionLargeCircuitID CircuitID = "execution-large"
BlobDecompressionV0CircuitID CircuitID = "blob-decompression-v0"
BlobDecompressionV1CircuitID CircuitID = "blob-decompression-v1"
AggregationCircuitID CircuitID = "aggregation"
EmulationCircuitID CircuitID = "emulation"
EmulationDummyCircuitID CircuitID = "emulation-dummy"
ExecutionDummyCircuitID CircuitID = "execution-dummy"
BlobDecompressionDummyCircuitID CircuitID = "blob-decompression-dummy"
PublicInputInterconnectionCircuitID CircuitID = "public-input-interconnection"
)
// MockCircuitID is a type to represent the different mock circuits.

View File

@@ -89,7 +89,7 @@ func (s *Setup) CurveID() ecc.ID {
return fieldToCurve(s.Circuit.Field())
}
func (s *Setup) VerifiyingKeyDigest() string {
func (s *Setup) VerifyingKeyDigest() string {
r, err := objectChecksum(s.VerifyingKey)
if err != nil {
utils.Panic("could not get the verifying key digest: %v", err)

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"fmt"
pi_interconnection "github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection"
"io"
"os"
"path/filepath"
@@ -52,6 +53,7 @@ var allCircuits = []string{
string(circuits.AggregationCircuitID),
string(circuits.EmulationCircuitID),
string(circuits.EmulationDummyCircuitID), // we want to generate Verifier.sol for this one
string(circuits.PublicInputInterconnectionCircuitID),
}
func init() {
@@ -102,7 +104,7 @@ func cmdSetup(cmd *cobra.Command, args []string) error {
}
// for each circuit, we start by compiling the circuit
// the we do a shashum and compare against the one in the manifest.json
// then we do a sha sum and compare against the one in the manifest.json
for c, setup := range inCircuits {
if !setup {
// we skip aggregation in this first loop since the setup is more complex
@@ -139,11 +141,13 @@ func cmdSetup(cmd *cobra.Command, args []string) error {
extraFlags["maxUncompressedBytes"] = blob_v1.MaxUncompressedBytes
builder = v1.NewBuilder(len(dict))
}
case circuits.PublicInputInterconnectionCircuitID:
builder = pi_interconnection.NewBuilder(cfg.PublicInputInterconnection)
case circuits.EmulationDummyCircuitID:
// we can get the Verifier.sol from there.
builder = dummy.NewBuilder(circuits.MockCircuitIDEmulation, ecc.BN254.ScalarField())
default:
continue // dummy, aggregation or emulation circuits are handled later
continue // dummy, aggregation, emulation or public input circuits are handled later
}
if err := updateSetup(cmd.Context(), cfg, srsProvider, c, builder, extraFlags); err != nil {
@@ -164,6 +168,12 @@ func cmdSetup(cmd *cobra.Command, args []string) error {
return nil
}
// get verifying key for public-input circuit
piSetup, err := circuits.LoadSetup(cfg, circuits.PublicInputInterconnectionCircuitID)
if err != nil {
return fmt.Errorf("%s failed to load public input interconnection setup: %w", cmd.Name(), err)
}
// first, we need to collect the verifying keys
var allowedVkForAggregation []plonk.VerifyingKey
for _, allowedInput := range cfg.Aggregation.AllowedInputs {
@@ -217,7 +227,7 @@ func cmdSetup(cmd *cobra.Command, args []string) error {
c := circuits.CircuitID(fmt.Sprintf("%s-%d", string(circuits.AggregationCircuitID), numProofs))
logrus.Infof("setting up %s (numProofs=%d)", c, numProofs)
builder := aggregation.NewBuilder(numProofs, allowedVkForAggregation)
builder := aggregation.NewBuilder(numProofs, cfg.Aggregation.AllowedInputs, piSetup.VerifyingKey, allowedVkForAggregation)
if err := updateSetup(cmd.Context(), cfg, srsProvider, c, builder, extraFlagsForAggregationCircuit); err != nil {
return err
}

View File

@@ -100,10 +100,11 @@ type Config struct {
// accessed (prover). The file structure is described in TODO @gbotrel.
AssetsDir string `mapstructure:"assets_dir" validate:"required,dir"`
Controller Controller
Execution Execution
BlobDecompression BlobDecompression `mapstructure:"blob_decompression"`
Aggregation Aggregation
Controller Controller
Execution Execution
BlobDecompression BlobDecompression `mapstructure:"blob_decompression"`
Aggregation Aggregation
PublicInputInterconnection PublicInput `mapstructure:"public_input_interconnection"` // TODO add wizard compilation params
Debug struct {
// Profiling indicates whether we want to generate profiles using the [runtime/pprof] pkg.
@@ -249,3 +250,13 @@ func (cfg *WithRequestDir) DirTo() string {
func (cfg *WithRequestDir) DirDone() string {
return path.Join(cfg.RequestsRootDir, RequestsDoneSubDir)
}
type PublicInput struct {
MaxNbDecompression int `mapstructure:"max_nb_decompression" validate:"gte=0"`
MaxNbExecution int `mapstructure:"max_nb_execution" validate:"gte=0"`
MaxNbCircuits int `mapstructure:"max_nb_circuits" validate:"gte=0"` // if not set, will be set to MaxNbDecompression + MaxNbExecution
MaxNbKeccakF int `mapstructure:"max_nb_keccakf" validate:"gte=0"`
ExecutionMaxNbMsg int `mapstructure:"execution_max_nb_msg" validate:"gte=0"`
L2MsgMerkleDepth int `mapstructure:"l2_msg_merkle_depth" validate:"gte=0"`
L2MsgMaxNbMerkle int `mapstructure:"l2_msg_max_nb_merkle" validate:"gte=0"` // if not explicitly provided (i.e. non-positive) it will be set to maximum
}

View File

@@ -45,7 +45,7 @@ func main() {
// Building aggregation circuit for max `nc` proofs
logrus.Infof("Building aggregation circuit for size of %v\n", nc)
ccs, err := aggregation.MakeCS(nc, vkeys)
ccs, err := aggregation.MakeCS(nc, []string{"blob-decompression-v1", "execution"}, nil, vkeys) // TODO @Tabaie add a PI key
if err != nil {
panic(err)
}
@@ -89,7 +89,8 @@ func main() {
// Assigning the BW6 circuit
logrus.Infof("Generating the aggregation proof for arity %v", nc)
bw6Proof, err := aggregation.MakeProof(&ppBw6, nc, innerProofClaims, frBw6.NewElement(10))
// TODO @Tabaie add a PI proof
bw6Proof, err := aggregation.MakeProof(&ppBw6, nc, innerProofClaims, nil, frBw6.NewElement(10))
if err != nil {
panic(err)
}

View File

@@ -2,9 +2,15 @@ package public_input
import (
bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/zkevm-monorepo/prover/circuits/pi-interconnection/keccak"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/consensys/zkevm-monorepo/prover/utils/types"
"golang.org/x/crypto/sha3"
"hash"
"slices"
)
// Aggregation collects all the field that are used to construct the public
@@ -79,3 +85,160 @@ func (p Aggregation) Sum(hsh hash.Hash) []byte {
func (p Aggregation) GetPublicInputHex() string {
return utils.HexEncodeToString(p.Sum(sha3.NewLegacyKeccak256()))
}
// AggregationFPI holds the same info as public_input.Aggregation, except in parsed form
type AggregationFPI struct {
ParentShnarf [32]byte
NbDecompression uint64
InitialStateRootHash [32]byte
InitialBlockNumber uint64
InitialBlockTimestamp uint64
InitialRollingHash [32]byte
InitialRollingHashNumber uint64
ChainID uint64 // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr types.EthAddress
NbL2Messages uint64 // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]byte
//FinalStateRootHash [32]byte redundant: incorporated into shnarf
FinalBlockNumber uint64
FinalBlockTimestamp uint64
FinalRollingHash [32]byte
FinalRollingHashNumber uint64
FinalShnarf [32]byte
L2MsgMerkleTreeDepth int
}
func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark {
s := AggregationFPISnark{
AggregationFPIQSnark: AggregationFPIQSnark{
InitialBlockNumber: pi.InitialBlockNumber,
InitialBlockTimestamp: pi.InitialBlockTimestamp,
InitialRollingHash: [32]frontend.Variable{},
InitialRollingHashNumber: pi.InitialRollingHashNumber,
InitialStateRootHash: pi.InitialStateRootHash[:],
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
},
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
FinalRollingHashNumber: pi.FinalRollingHashNumber,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
}
utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:])
utils.Copy(s.InitialRollingHash[:], pi.InitialRollingHash[:])
utils.Copy(s.ParentShnarf[:], pi.ParentShnarf[:])
utils.Copy(s.FinalShnarf[:], pi.FinalShnarf[:])
for i := range s.L2MsgMerkleTreeRoots {
utils.Copy(s.L2MsgMerkleTreeRoots[i][:], pi.L2MsgMerkleTreeRoots[i][:])
}
return s
}
type AggregationFPIQSnark struct {
ParentShnarf [32]frontend.Variable
NbDecompression frontend.Variable
InitialStateRootHash frontend.Variable
InitialBlockNumber frontend.Variable
InitialBlockTimestamp frontend.Variable
InitialRollingHash [32]frontend.Variable
InitialRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
}
type AggregationFPISnark struct {
AggregationFPIQSnark
NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]frontend.Variable
// FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
}
// NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation
func NewAggregationFPI(fpi *Aggregation) (s *AggregationFPI, err error) {
s = &AggregationFPI{
InitialBlockNumber: uint64(fpi.LastFinalizedBlockNumber),
InitialBlockTimestamp: uint64(fpi.ParentAggregationLastBlockTimestamp),
InitialRollingHashNumber: uint64(fpi.LastFinalizedL1RollingHashMessageNumber),
L2MsgMerkleTreeRoots: make([][32]byte, len(fpi.L2MsgRootHashes)),
FinalBlockNumber: uint64(fpi.FinalBlockNumber),
FinalBlockTimestamp: uint64(fpi.FinalTimestamp),
FinalRollingHashNumber: uint64(fpi.L1RollingHashMessageNumber),
L2MsgMerkleTreeDepth: fpi.L2MsgMerkleTreeDepth,
}
if err = copyFromHex(s.InitialStateRootHash[:], fpi.ParentStateRootHash); err != nil {
return
}
if err = copyFromHex(s.FinalRollingHash[:], fpi.L1RollingHash); err != nil {
return
}
if err = copyFromHex(s.InitialRollingHash[:], fpi.LastFinalizedL1RollingHash); err != nil {
return
}
if err = copyFromHex(s.ParentShnarf[:], fpi.ParentAggregationFinalShnarf); err != nil {
return
}
if err = copyFromHex(s.FinalShnarf[:], fpi.FinalShnarf); err != nil {
return
}
for i := range s.L2MsgMerkleTreeRoots {
if err = copyFromHex(s.L2MsgMerkleTreeRoots[i][:], fpi.L2MsgRootHashes[i]); err != nil {
return
}
}
return
}
func (pi *AggregationFPISnark) Sum(api frontend.API, hash keccak.BlockHasher) [32]frontend.Variable {
// number of hashes: 12
sum := hash.Sum(nil,
pi.ParentShnarf,
pi.FinalShnarf,
utils.ToBytes(api, pi.InitialBlockTimestamp),
utils.ToBytes(api, pi.FinalBlockTimestamp),
utils.ToBytes(api, pi.InitialBlockNumber),
utils.ToBytes(api, pi.FinalBlockNumber),
pi.InitialRollingHash,
pi.FinalRollingHash,
utils.ToBytes(api, pi.InitialRollingHashNumber),
utils.ToBytes(api, pi.FinalRollingHashNumber),
utils.ToBytes(api, pi.L2MsgMerkleTreeDepth),
hash.Sum(nil, pi.L2MsgMerkleTreeRoots...),
)
// turn the hash into a bn254 element
var res [32]frontend.Variable
copy(res[:], utils.ReduceBytes[emulated.BN254Fr](api, sum[:]))
return res
}
func (pi *AggregationFPIQSnark) RangeCheck(api frontend.API) {
rc := rangecheck.New(api)
for _, v := range append(slices.Clone(pi.InitialRollingHash[:]), pi.ParentShnarf[:]...) {
rc.Check(v, 8)
}
// not checking L2MsgServiceAddr as its range is never assumed in the pi circuit
// not checking NbDecompressions as the NewRange in the pi circuit range checks it; TODO do it here instead
}
func copyFromHex(dst []byte, src string) error {
b, err := utils.HexDecodeString(src)
if err != nil {
return err
}
copy(dst[len(dst)-len(b):], b) // panics if src is too long
return nil
}

View File

@@ -0,0 +1,61 @@
package public_input
import (
"errors"
"github.com/consensys/zkevm-monorepo/prover/utils"
)
type Execution struct {
L2MsgHashes [][32]byte
FinalStateRootHash [32]byte
FinalBlockNumber uint64
FinalBlockTimestamp uint64
FinalRollingHash [32]byte
FinalRollingHashNumber uint64
}
type ExecutionSerializable struct {
L2MsgHashes []string `json:"l2MsgHashes"`
FinalStateRootHash string `json:"finalStateRootHash"`
FinalBlockNumber uint64 `json:"finalBlockNumber"`
FinalBlockTimestamp uint64 `json:"finalBlockTimestamp"`
FinalRollingHash string `json:"finalRollingHash"`
FinalRollingHashNumber uint64 `json:"finalRollingHashNumber"`
}
func (e ExecutionSerializable) Decode() (decoded Execution, err error) {
decoded = Execution{
L2MsgHashes: make([][32]byte, len(e.L2MsgHashes)),
FinalBlockNumber: e.FinalBlockNumber,
FinalBlockTimestamp: e.FinalBlockTimestamp,
FinalRollingHashNumber: e.FinalRollingHashNumber,
}
fillWithHex := func(dst []byte, src string) {
var d []byte
if d, err = utils.HexDecodeString(src); err != nil {
return
}
if len(d) > len(dst) {
err = errors.New("decoded bytes too long")
}
n := len(dst) - len(d)
copy(dst[n:], d)
for n > 0 {
n--
dst[n] = 0
}
}
for i, hex := range e.L2MsgHashes {
if fillWithHex(decoded.L2MsgHashes[i][:], hex); err != nil {
return
}
}
if fillWithHex(decoded.FinalStateRootHash[:], e.FinalStateRootHash); err != nil {
return
}
fillWithHex(decoded.FinalRollingHash[:], e.FinalRollingHash)
return
}

120
prover/utils/snark.go Normal file
View File

@@ -0,0 +1,120 @@
package utils
import (
"errors"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"math/big"
)
func Copy[T any](dst []frontend.Variable, src []T) (n int) {
n = min(len(dst), len(src))
for i := 0; i < n; i++ {
dst[i] = src[i]
}
return
}
func ToBytes(api frontend.API, x frontend.Variable) [32]frontend.Variable {
var res [32]frontend.Variable
d := decomposeIntoBytes(api, x, fr.Bits)
slack := 32 - len(d) // should be zero
copy(res[slack:], d)
for i := 0; i < slack; i++ {
res[i] = 0
}
return res
}
func decomposeIntoBytes(api frontend.API, data frontend.Variable, nbBits int) []frontend.Variable {
if nbBits == 0 {
nbBits = api.Compiler().FieldBitLen()
}
nbBytes := (api.Compiler().FieldBitLen() + 7) / 8
bytes, err := api.Compiler().NewHint(decomposeIntoBytesHint, nbBytes, data)
if err != nil {
panic(err)
}
lastNbBits := nbBits % 8
if lastNbBits == 0 {
lastNbBits = 8
}
rc := rangecheck.New(api)
api.AssertIsLessOrEqual(bytes[0], 1<<lastNbBits-1) //TODO try range checking this as well
for i := 1; i < nbBytes; i++ {
rc.Check(bytes[i], 8)
}
return bytes
}
func decomposeIntoBytesHint(_ *big.Int, ins, outs []*big.Int) error {
nbBytes := len(outs) / len(ins)
if nbBytes*len(ins) != len(outs) {
return errors.New("incongruent number of ins/outs")
}
var v, radix, zero big.Int
radix.SetUint64(256)
for i := range ins {
v.Set(ins[i])
for j := nbBytes - 1; j >= 0; j-- {
outs[i*nbBytes+j].Mod(&v, &radix)
v.Rsh(&v, 8)
}
if v.Cmp(&zero) != 0 {
return errors.New("not fitting in len(outs)/len(ins) many bytes")
}
}
return nil
}
func RegisterHints() {
solver.RegisterHint(decomposeIntoBytesHint)
}
// ReduceBytes reduces given bytes modulo a given field. As a side effect, the "bytes" are range checked
func ReduceBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) []frontend.Variable {
f, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
bits := f.ToBits(NewElementFromBytes[T](api, bytes))
res := make([]frontend.Variable, (len(bits)+7)/8)
copy(bits[:], bits)
for i := len(bits); i < len(bits); i++ {
bits[i] = 0
}
for i := range res {
bitsStart := 8 * (len(res) - i - 1)
bitsEnd := bitsStart + 8
if i == 0 {
bitsEnd = len(bits)
}
res[i] = api.FromBinary(bits[bitsStart:bitsEnd]...)
}
return res
}
// NewElementFromBytes range checks the bytes and gives a reduced field element
func NewElementFromBytes[T emulated.FieldParams](api frontend.API, bytes []frontend.Variable) *emulated.Element[T] {
bits := make([]frontend.Variable, 8*len(bytes))
for i := range bytes {
copy(bits[8*i:], api.ToBinary(bytes[len(bytes)-i-1], 8))
}
f, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
return f.Reduce(f.Add(f.FromBits(bits...), f.Zero()))
}