Fix: test PI Interconnection with Sepolia data (#208)

* feat better pi assignment error finding
* fix append to allL2MessageHashes
* feat relax rolling hash verification in circuit
* PR feedback: remove unit test

---------

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
Co-authored-by: AlexandreBelling <alexandrebelling8@gmail.com>
This commit is contained in:
Arya Tabaie
2024-10-21 12:48:40 -05:00
committed by GitHub
parent 8bb5fe7ac2
commit 333bc7718b
9 changed files with 122 additions and 49 deletions

View File

@@ -39,9 +39,9 @@ const (
func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
var (
l2MessageHashes []string
l2MsgBlockOffsets []bool
cf = &CollectedFields{
allL2MessageHashes []string
l2MsgBlockOffsets []bool
cf = &CollectedFields{
L2MsgTreeDepth: l2MsgMerkleTreeDepth,
ParentAggregationLastBlockTimestamp: uint(req.ParentAggregationLastBlockTimestamp),
LastFinalizedL1RollingHash: req.ParentAggregationLastL1RollingHash,
@@ -55,9 +55,10 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
for i, execReqFPath := range req.ExecutionProofs {
var (
po = &execution.Response{}
fpath = path.Join(cfg.Execution.DirTo(), execReqFPath)
f = files.MustRead(fpath)
po = &execution.Response{}
l2MessageHashes []string
fpath = path.Join(cfg.Execution.DirTo(), execReqFPath)
f = files.MustRead(fpath)
)
if err := json.NewDecoder(f).Decode(po); err != nil {
@@ -112,7 +113,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
finalBlock := &po.BlocksData[len(po.BlocksData)-1]
piq, err := public_input.ExecutionSerializable{
L2MsgHashes: l2MessageHashes,
FinalStateRootHash: po.PublicInput.Hex(), // TODO @tabaie make sure this is the right value
FinalStateRootHash: finalBlock.RootHash.Hex(),
FinalBlockNumber: uint64(cf.FinalBlockNumber),
FinalBlockTimestamp: finalBlock.TimeStamp,
FinalRollingHash: cf.L1RollingHash,
@@ -123,6 +124,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
}
cf.ExecutionPI = append(cf.ExecutionPI, piq)
}
allL2MessageHashes = append(allL2MessageHashes, l2MessageHashes...)
}
cf.DecompressionPI = make([]blobsubmission.Response, 0, len(req.DecompressionProofs))
@@ -166,7 +169,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
}
cf.L2MessagingBlocksOffsets = utils.HexEncodeToString(PackOffsets(l2MsgBlockOffsets))
cf.L2MsgRootHashes = PackInMiniTrees(l2MessageHashes)
cf.L2MsgRootHashes = PackInMiniTrees(allL2MessageHashes)
return cf, nil

View File

@@ -61,9 +61,13 @@ func assign(
funcInputs FunctionalPublicInput,
) CircuitExecution {
wizardVerifier := wizard.GetWizardVerifierCircuitAssignment(comp, proof)
fpiSnark, err := funcInputs.ToSnarkType()
if err != nil {
panic(err) // TODO error handling
}
return CircuitExecution{
WizardVerifier: *wizardVerifier,
FuncInputs: funcInputs.ToSnarkType(),
FuncInputs: fpiSnark,
PublicInput: new(big.Int).SetBytes(funcInputs.Sum()),
}
}

View File

@@ -2,6 +2,7 @@ package execution
import (
"encoding/binary"
"fmt"
"hash"
"slices"
@@ -147,7 +148,7 @@ func (pi *FunctionalPublicInputSnark) Sum(api frontend.API, hsh gnarkHash.FieldH
return hsh.Sum()
}
func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
func (pi *FunctionalPublicInput) ToSnarkType() (FunctionalPublicInputSnark, error) {
res := FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{
DataChecksum: slices.Clone(pi.DataChecksum[:]),
@@ -167,7 +168,12 @@ func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
utils.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:])
utils.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:])
return res
var err error
if nbMsg := len(pi.L2MessageHashes); nbMsg > pi.MaxNbL2MessageHashes {
err = fmt.Errorf("has %d L2 message hashes but a maximum of %d is allowed", nbMsg, pi.MaxNbL2MessageHashes)
}
return res, err
}
func (pi *FunctionalPublicInput) Sum() []byte { // all mimc; no need to provide a keccak hasher

View File

@@ -1,6 +1,7 @@
package execution
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/consensys/gnark/frontend"
@@ -35,7 +36,8 @@ func TestPIConsistency(t *testing.T) {
pi.InitialStateRootHash[0] &= 0x0f
pi.FinalStateRootHash[0] &= 0x0f
snarkPi := pi.ToSnarkType()
snarkPi, err := pi.ToSnarkType()
require.NoError(t, err)
piSum := pi.Sum()
snarkTestUtils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable {

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"hash"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
@@ -121,7 +122,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}
if prevShnarf = shnarf.Compute(); !bytes.Equal(prevShnarf, shnarfs[i]) {
err = errors.New("shnarf mismatch")
err = fmt.Errorf("shnarf mismatch, i:%d, shnarf: %x, prevShnarf: %x, ", i, shnarfs[i], prevShnarf)
return
}
}
@@ -163,7 +164,6 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}
// Aggregation FPI
aggregationFPI, err := public_input.NewAggregationFPI(&r.Aggregation)
if err != nil {
return
@@ -178,6 +178,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
merkleNbLeaves := 1 << config.L2MsgMerkleDepth
maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves
l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes)
// Execution FPI
executionFPI := execution.FunctionalPublicInput{
FinalStateRootHash: aggregationFPI.InitialStateRootHash,
@@ -213,21 +214,48 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}
a.ExecutionPublicInput[i] = executionFPI.Sum()
a.ExecutionFPIQ[i] = executionFPI.ToSnarkType().FunctionalPublicInputQSnark
if snarkFPI, _err := executionFPI.ToSnarkType(); _err != nil {
err = fmt.Errorf("execution #%d: %w", i, _err)
return
} else {
a.ExecutionFPIQ[i] = snarkFPI.FunctionalPublicInputQSnark
}
}
// consistency check
if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp ||
executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber ||
executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash ||
executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber {
err = errors.New("final execution values not matching final aggregation values")
if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp {
err = fmt.Errorf("final block timestamps do not match: execution=%x, aggregation=%x",
executionFPI.FinalBlockTimestamp, aggregationFPI.FinalBlockTimestamp)
return
}
if executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber {
err = fmt.Errorf("final block numbers do not match: execution=%v, aggregation=%x",
executionFPI.FinalBlockNumber, aggregationFPI.FinalBlockNumber)
return
}
if executionFPI.FinalRollingHash != [32]byte{} {
if executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash {
err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x",
executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash)
return
}
if executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber {
err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v",
executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber)
return
}
}
if len(l2MessageHashes) > maxNbL2MessageHashes {
err = errors.New("too many L2 messages")
return
}
if minNbRoots := (len(l2MessageHashes) + merkleNbLeaves - 1) / merkleNbLeaves; len(r.Aggregation.L2MsgRootHashes) < minNbRoots {
err = fmt.Errorf("the %d merkle roots provided are too few to accommodate all %d execution messages. A minimum of %d is needed", len(r.Aggregation.L2MsgRootHashes), len(l2MessageHashes), minNbRoots)
return
}
for i := range r.Aggregation.L2MsgRootHashes {
var expectedRoot []byte
if expectedRoot, err = utils.HexDecodeString(r.Aggregation.L2MsgRootHashes[i]); err != nil {
@@ -239,6 +267,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
return
}
}
// padding merkle root hashes
emptyTree := make([][]byte, config.L2MsgMerkleDepth+1)
emptyTree[0] = make([]byte, 64)

View File

@@ -172,16 +172,35 @@ func (c *Circuit) Define(api frontend.API) error {
}
rExecution := internal.NewRange(api, nbExecution, maxNbExecution)
twoPow8 := big.NewInt(256)
hi16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[:16], twoPow8)
}
lo16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[16:], twoPow8)
}
{ // if rolling hash values are present in the last execution, they must match those of aggregation
finalRollingHashFromExec := rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash })
finalRollingHashNumFromExec := rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber })
h, l := hi16B(finalRollingHashFromExec), lo16B(finalRollingHashFromExec)
finalRollingHashPresent := api.Sub(1, api.Mul(api.IsZero(h), api.IsZero(l)))
internal.AssertEqualIf(api, finalRollingHashPresent, h, hi16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, l, lo16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, finalRollingHashNumFromExec, c.FinalRollingHashNumber)
}
pi := public_input.AggregationFPISnark{
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }),
FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
}
for i := range pi.L2MsgMerkleTreeRoots {
@@ -190,7 +209,6 @@ func (c *Circuit) Define(api frontend.API) error {
// "open" aggregation public input
aggregationPIBytes := pi.Sum(api, &hshK)
twoPow8 := big.NewInt(256)
api.AssertIsEqual(c.AggregationPublicInput[0], compress.ReadNum(api, aggregationPIBytes[:16], twoPow8))
api.AssertIsEqual(c.AggregationPublicInput[1], compress.ReadNum(api, aggregationPIBytes[16:], twoPow8))

View File

@@ -15,7 +15,7 @@ import (
"github.com/consensys/linea-monorepo/prover/backend/aggregation"
"github.com/consensys/linea-monorepo/prover/backend/blobsubmission"
"github.com/consensys/linea-monorepo/prover/circuits/internal"
"github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils"
circuittesting "github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils"
pi_interconnection "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection"
pitesting "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection/test_utils"
"github.com/consensys/linea-monorepo/prover/config"
@@ -101,7 +101,7 @@ func TestTinyTwoBatchBlob(t *testing.T) {
blobResp, err := blobsubmission.CraftResponse(&blobReq)
assert.NoError(t, err)
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes))
merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes))
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp},
@@ -182,7 +182,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
blobResp1, err := blobsubmission.CraftResponse(&blobReq1)
assert.NoError(t, err)
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes))
merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes))
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1},

View File

@@ -123,15 +123,15 @@ func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark {
InitialRollingHashNumber: pi.InitialRollingHashNumber,
InitialStateRootHash: pi.InitialStateRootHash[:],
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
FinalRollingHashNumber: pi.FinalRollingHashNumber,
},
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
FinalRollingHashNumber: pi.FinalRollingHashNumber,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
}
utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:])
@@ -154,8 +154,12 @@ type AggregationFPIQSnark struct {
InitialBlockTimestamp frontend.Variable
InitialRollingHash [32]frontend.Variable
InitialRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
// Ideally, FinalRollingHash and FinalRollingHashNumber would be inferred from the executions
// but sometimes executions are missing those values
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
}
type AggregationFPISnark struct {
@@ -163,12 +167,10 @@ type AggregationFPISnark struct {
NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]frontend.Variable
// FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
}
// NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation

View File

@@ -3,9 +3,12 @@ package test_utils
import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"fmt"
"github.com/stretchr/testify/require"
"math"
"os"
"testing"
)
type FakeTestingT struct{}
@@ -37,3 +40,9 @@ func RandIntSliceN(length, n int) []int {
}
return res
}
func LoadJson(t *testing.T, path string, v any) {
in, err := os.Open(path)
require.NoError(t, err)
require.NoError(t, json.NewDecoder(in).Decode(v))
}