Prover(perf): faster global constraints compilation (#704)

* bench(global): adds a benchmark for the global constraint compiler

* perf(merging): accumulates the factors before creating the expression

* perf(product): computes the ESH without using a smart-vector

* perf(factor): preallocations in the factorization algorithm

* perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs

* perfs(alloc): counts the ret elements before returning in expandTerms to minimize the number of allocations.

* perf(factor): use an integer map instead of a field.Element map when possible

* fixup(expands): fix the skip condition for term expansion

* perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb

* feat(repr): adds a json repr function to help debugging

* test(constructor): cleans the test of the constructors

* perf(factor): address maps using the first limb of a field.Element instead of the full field.Element

* fixup(commit): adds missing file in previous commit

* perf(factor): reduce the number of calls to rankChildren

* perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims

* perf(factors): use a counter in getCommonProdParentOfCs

* perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function

* clean(factor): remove unneeded function and imports

* feat(utils): adds a generic sort interface implementation

* perf(rankChildren): lazy allocation of the map to save on allocations

* perf(factorize): reduces the loop-bound for factorizeExpression

* (chore): fix a missing argument and format gofmt

* feat: readd test

---------

Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com>
Co-authored-by: gusiri <dreamerty@postech.ac.kr>
This commit is contained in:
AlexandreBelling
2025-03-21 12:55:54 +01:00
committed by GitHub
parent 869c0c63d4
commit 7334693931
12 changed files with 333 additions and 201 deletions

View File

@@ -0,0 +1,75 @@
package globalcs_test
import (
"os"
"testing"
"github.com/consensys/linea-monorepo/prover/config"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/cleanup"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/zkevm"
"github.com/sirupsen/logrus"
)
// BenchmarkGlobalConstraint benchmarks the global constraints compiler against the
// actual zk-evm constraint system.
func BenchmarkGlobalConstraintWoArtefacts(b *testing.B) {
// partialSuite corresponds to the actual compilation suite of
// the full zk-evm down to the point where the global constraints
// are compiled. The steps are applied in order and the
// global-constraints compiler (globalcs.Compile) is the final step.
partialSuite := []func(*wizard.CompiledIOP){
mimc.CompileMiMC,
specialqueries.RangeProof,
specialqueries.CompileFixedPermutations,
permutation.CompileGrandProduct,
lookup.CompileLogDerivative,
innerproduct.Compile,
sticker.Sticker(1<<10, 1<<19),
splitter.SplitColumns(1 << 19),
cleanup.CleanUp,
localcs.Compile,
globalcs.Compile,
}
// In order to load the config we need to position ourselves in the root
// folder. The chdir errors are deliberately ignored: a wrong working
// directory is surfaced just below when loading the config fails.
_ = os.Chdir("../../..")
defer os.Chdir("protocol/compiler/globalcs")
// config corresponds to the config we use on sepolia
cfg, cfgErr := config.NewConfigFromFile("./config/config-sepolia-full.toml")
if cfgErr != nil {
b.Fatalf("could not find the config: %v", cfgErr)
}
// Shut the logger to not overwhelm the benchmark output
logrus.SetLevel(logrus.PanicLevel)
b.ResetTimer()
for c_ := 0; c_ < b.N; c_++ {
// The artefact cleanup below is excluded from the measured time
// via the StopTimer/StartTimer bracket.
b.StopTimer()
// Removes the artefacts so that each iteration recompiles from
// scratch rather than reusing cached results — presumably; TODO
// confirm against the prover artefact-caching mechanism.
if err := os.RemoveAll("/tmp/prover-artefacts"); err != nil {
b.Fatalf("could not remove the artefacts: %v", err)
}
b.StartTimer()
// The compiled IOP is discarded: only the compilation time matters
// for this benchmark.
_ = zkevm.FullZKEVMWithSuite(&cfg.TracesLimits, partialSuite, cfg)
}
}

View File

@@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
domainSize = cs.DomainSize domainSize = cs.DomainSize
x = variables.NewXVar() x = variables.NewXVar()
omega = fft.GetOmega(domainSize) omega = fft.GetOmega(domainSize)
// factors is a list of expression to multiply to obtain the return expression. It
// is initialized with "only" the initial expression and we iteratively add the
// terms (X-i) to it. At the end, we call [sym.Mul] a single time. This structure
// is important because [sym.Mul] operates a sequence of optimization routines
// that are run every time we call it. In an earlier version, we were calling [sym.Mul]
// for every factor and this was making the function have a quadratic/cubic runtime.
factors = make([]any, 0, utils.Abs(cancelRange.Max)+utils.Abs(cancelRange.Min)+1)
) )
// cancelExprAtPoint cancels the expression at a particular position factors = append(factors, res)
cancelExprAtPoint := func(expr *symbolic.Expression, i int) *symbolic.Expression {
// appendFactor appends an expressions representing $X-\rho^i$ to [factors]
appendFactor := func(i int) {
var root field.Element var root field.Element
root.Exp(omega, big.NewInt(int64(i))) root.Exp(omega, big.NewInt(int64(i)))
return symbolic.Mul(expr, symbolic.Sub(x, root)) factors = append(factors, symbolic.Sub(x, root))
} }
if cancelRange.Min < 0 { if cancelRange.Min < 0 {
// Cancels the expression on the range [0, -cancelRange.Min) // Cancels the expression on the range [0, -cancelRange.Min)
for i := 0; i < -cancelRange.Min; i++ { for i := 0; i < -cancelRange.Min; i++ {
res = cancelExprAtPoint(res, i) appendFactor(i)
} }
} }
@@ -166,11 +175,18 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
// Cancels the expression on the range (N-cancelRange.Max-1, N-1] // Cancels the expression on the range (N-cancelRange.Max-1, N-1]
for i := 0; i < cancelRange.Max; i++ { for i := 0; i < cancelRange.Max; i++ {
point := domainSize - i - 1 // point at which we want to cancel the constraint point := domainSize - i - 1 // point at which we want to cancel the constraint
res = cancelExprAtPoint(res, point) appendFactor(point)
} }
} }
return res // When factors is of length 1, it means the expression does not need to be
// bound-cancelled and we can directly return the original expression
// without calling [sym.Mul].
if len(factors) == 1 {
return res
}
return symbolic.Mul(factors...)
} }
// getExprRatio computes the ratio of the expression and ceil to the next power // getExprRatio computes the ratio of the expression and ceil to the next power

View File

@@ -27,12 +27,12 @@ func Add(inputs ...any) *Expression {
exprInputs := intoExprSlice(inputs...) exprInputs := intoExprSlice(inputs...)
res := exprInputs[0] magnitudes := make([]int, len(exprInputs))
for i := 1; i < len(exprInputs); i++ { for i := range exprInputs {
res = res.Add(exprInputs[i]) magnitudes[i] = 1
} }
return res return NewLinComb(exprInputs, magnitudes)
} }
// Mul constructs a symbolic expression representing the product of its inputs. // Mul constructs a symbolic expression representing the product of its inputs.
@@ -57,12 +57,12 @@ func Mul(inputs ...any) *Expression {
exprInputs := intoExprSlice(inputs...) exprInputs := intoExprSlice(inputs...)
res := exprInputs[0] magnitudes := make([]int, len(exprInputs))
for i := 1; i < len(exprInputs); i++ { for i := range exprInputs {
res = res.Mul(exprInputs[i]) magnitudes[i] = 1
} }
return res return NewProduct(exprInputs, magnitudes)
} }
// Sub returns a symbolic expression representing the subtraction of `a` by all // Sub returns a symbolic expression representing the subtraction of `a` by all
@@ -82,16 +82,18 @@ func Sub(a any, bs ...any) *Expression {
} }
var ( var (
aExpr = intoExpr(a) aExpr = intoExpr(a)
bExpr = intoExprSlice(bs...) bExpr = intoExprSlice(bs...)
res = aExpr exprInputs = append([]*Expression{aExpr}, bExpr...)
magnitudes = make([]int, len(exprInputs))
) )
for i := range bExpr { for i := range exprInputs {
res = res.Sub(bExpr[i]) magnitudes[i] = -1
} }
magnitudes[0] = 1
return res return NewLinComb(exprInputs, magnitudes)
} }
// Neg returns an expression representing the negation of an expression or of // Neg returns an expression representing the negation of an expression or of
@@ -119,7 +121,7 @@ func Square(x any) *Expression {
return nil return nil
} }
return intoExpr(x).Square() return NewProduct([]*Expression{intoExpr(x)}, []int{2})
} }
// Pow returns an expression representing the raising to the power "n" of an // Pow returns an expression representing the raising to the power "n" of an

View File

@@ -1,6 +1,7 @@
package symbolic package symbolic
import ( import (
"encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"sync" "sync"
@@ -284,5 +285,14 @@ func (e *Expression) SameWithNewChildren(newChildren []*Expression) *Expression
default: default:
panic("unexpected type: " + reflect.TypeOf(op).String()) panic("unexpected type: " + reflect.TypeOf(op).String())
} }
}
// MarshalJSONString returns a JSON string returns a JSON string representation
// of the expression.
func (e *Expression) MarshalJSONString() string {
js, jsErr := json.MarshalIndent(e, "", " ")
if jsErr != nil {
utils.Panic("failed to marshal expression: %v", jsErr)
}
return string(js)
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils/collection" "github.com/consensys/linea-monorepo/prover/utils/collection"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -96,7 +97,7 @@ func TestLCConstruction(t *testing.T) {
x := NewDummyVar("x") x := NewDummyVar("x")
y := NewDummyVar("y") y := NewDummyVar("y")
{ t.Run("simple-addition", func(t *testing.T) {
/* /*
Tests a simple case of addition Tests a simple case of addition
*/ */
@@ -108,23 +109,23 @@ func TestLCConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs)) require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs))
require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1) require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1)
require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1) require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1)
} })
{ t.Run("x-y-x", func(t *testing.T) {
/* /*
Adding y then subtracting x should give back (y) Adding y then subtracting x should give back (y)
*/ */
expr1 := x.Add(y).Sub(x) expr1 := x.Add(y).Sub(x)
require.Equal(t, expr1, y) require.Equal(t, expr1, y)
} })
{ t.Run("(-x)+x+y", func(t *testing.T) {
/* /*
Same thing when using Neg Same thing when using Neg
*/ */
expr := x.Neg().Add(x).Add(y) expr := x.Neg().Add(x).Add(y)
require.Equal(t, expr, y) assert.Equal(t, expr, y)
} })
} }
@@ -133,7 +134,7 @@ func TestProductConstruction(t *testing.T) {
x := NewDummyVar("x") x := NewDummyVar("x")
y := NewDummyVar("y") y := NewDummyVar("y")
{ t.Run("x * y", func(t *testing.T) {
/* /*
Test t a simple case of addition Test t a simple case of addition
*/ */
@@ -145,9 +146,9 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents)) require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
require.Equal(t, expr1.Operator.(Product).Exponents[0], 1) require.Equal(t, expr1.Operator.(Product).Exponents[0], 1)
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1) require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
} })
{ t.Run("x * y * x", func(t *testing.T) {
/* /*
Adding y then substracting x should give back (y) Adding y then substracting x should give back (y)
*/ */
@@ -158,9 +159,9 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents)) require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
require.Equal(t, expr1.Operator.(Product).Exponents[0], 2) require.Equal(t, expr1.Operator.(Product).Exponents[0], 2)
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1) require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
} })
{ t.Run("x^2", func(t *testing.T) {
/* /*
When we square When we square
*/ */
@@ -169,6 +170,6 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, expr.Children[0], x) require.Equal(t, expr.Children[0], x)
require.Equal(t, 1, len(expr.Operator.(Product).Exponents)) require.Equal(t, 1, len(expr.Operator.(Product).Exponents))
require.Equal(t, expr.Operator.(Product).Exponents[0], 2) require.Equal(t, expr.Operator.(Product).Exponents[0], 2)
} })
} }

View File

@@ -95,18 +95,30 @@ func NewProduct(items []*Expression, exponents []int) *Expression {
e := &Expression{ e := &Expression{
Operator: Product{Exponents: exponents}, Operator: Product{Exponents: exponents},
Children: items, Children: items,
ESHash: field.One(),
} }
// Now we need to assign the ESH
eshashes := make([]sv.SmartVector, len(e.Children))
for i := range e.Children { for i := range e.Children {
eshashes[i] = sv.NewConstant(e.Children[i].ESHash, 1) var tmp field.Element
} switch {
case exponents[i] == 1:
if len(items) > 0 { e.ESHash.Mul(&e.ESHash, &e.Children[i].ESHash)
// The cast back to sv.Constant is no important functionally but is an easy case exponents[i] == 2:
// sanity check. tmp.Square(&e.Children[i].ESHash)
e.ESHash = e.Operator.Evaluate(eshashes).(*sv.Constant).Get(0) e.ESHash.Mul(&e.ESHash, &tmp)
case exponents[i] == 3:
tmp.Square(&e.Children[i].ESHash)
tmp.Mul(&tmp, &e.Children[i].ESHash)
e.ESHash.Mul(&e.ESHash, &tmp)
case exponents[i] == 4:
tmp.Square(&e.Children[i].ESHash)
tmp.Square(&tmp)
e.ESHash.Mul(&e.ESHash, &tmp)
default:
exponent := big.NewInt(int64(exponents[i]))
tmp.Exp(e.Children[i].ESHash, exponent)
e.ESHash.Mul(&e.ESHash, &tmp)
}
} }
return e return e

View File

@@ -78,22 +78,38 @@ func regroupTerms(magnitudes []int, children []*Expression) (
// the linear combination. This function is used both for simplifying [LinComb] // the linear combination. This function is used both for simplifying [LinComb]
// expressions and for simplifying [Product]. "magnitude" denotes either the // expressions and for simplifying [Product]. "magnitude" denotes either the
// coefficient for LinComb or exponents for Product. // coefficient for LinComb or exponents for Product.
//
// The function takes ownership of the provided slices.
func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) { func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) {
if len(magnitudes) != len(children) { if len(magnitudes) != len(children) {
panic("magnitudes and children don't have the same length") panic("magnitudes and children don't have the same length")
} }
cleanChildren = make([]*Expression, 0, len(children)) // cleanChildren and cleanMagnitudes are initialized lazily to
cleanMagnitudes = make([]int, 0, len(children)) // avoid unnecessarily allocating memory. The underlying assumption
// is that the application will 99% of time never pass zero as a
// magnitude.
for i, c := range magnitudes { for i, c := range magnitudes {
if c != 0 {
if c == 0 && cleanChildren == nil {
cleanChildren = make([]*Expression, i, len(children))
cleanMagnitudes = make([]int, i, len(children))
copy(cleanChildren, children[:i])
copy(cleanMagnitudes, magnitudes[:i])
}
if c != 0 && cleanChildren != nil {
cleanMagnitudes = append(cleanMagnitudes, magnitudes[i]) cleanMagnitudes = append(cleanMagnitudes, magnitudes[i])
cleanChildren = append(cleanChildren, children[i]) cleanChildren = append(cleanChildren, children[i])
} }
} }
if cleanChildren == nil {
cleanChildren = children
cleanMagnitudes = magnitudes
}
return cleanMagnitudes, cleanChildren return cleanMagnitudes, cleanChildren
} }
@@ -108,14 +124,16 @@ func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes
// The caller passes a target operator which may be any value of type either // The caller passes a target operator which may be any value of type either
// [LinComb] or [Product]. Any other type yields a panic error. // [LinComb] or [Product]. Any other type yields a panic error.
func expandTerms(op Operator, magnitudes []int, children []*Expression) ( func expandTerms(op Operator, magnitudes []int, children []*Expression) (
expandedMagnitudes []int, []int,
expandedExpression []*Expression, []*Expression,
) { ) {
var ( var (
opIsProd bool opIsProd bool
opIsLinC bool opIsLinC bool
numChildren = len(children) numChildren = len(children)
totalReturnSize = 0
needExpand = false
) )
switch op.(type) { switch op.(type) {
@@ -133,9 +151,34 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
panic("incompatible number of children and magnitudes") panic("incompatible number of children and magnitudes")
} }
// The capacity allocation is purely heuristic // This loops performs a first scan of the children to compute the total
expandedExpression = make([]*Expression, 0, 2*len(magnitudes)) // number of elements to allocate.
expandedMagnitudes = make([]int, 0, 2*len(magnitudes)) for i, child := range children {
switch child.Operator.(type) {
case Product, *Product:
if opIsProd {
needExpand = true
totalReturnSize += len(children[i].Children)
continue
}
case LinComb, *LinComb:
if opIsLinC {
needExpand = true
totalReturnSize += len(children[i].Children)
continue
}
}
totalReturnSize++
}
if !needExpand {
return magnitudes, children
}
expandedMagnitudes := make([]int, 0, totalReturnSize)
expandedExpression := make([]*Expression, 0, totalReturnSize)
for i := 0; i < numChildren; i++ { for i := 0; i < numChildren; i++ {
@@ -175,7 +218,6 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
for k := range child.Children { for k := range child.Children {
expandedExpression = append(expandedExpression, child.Children[k]) expandedExpression = append(expandedExpression, child.Children[k])
expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k]) expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k])
} }
continue continue
} }

View File

@@ -1,7 +1,6 @@
package simplify package simplify
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"sort" "sort"
@@ -46,9 +45,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
// factoring possibilities. There is also a bound on the loop to // factoring possibilities. There is also a bound on the loop to
// prevent infinite loops. // prevent infinite loops.
// //
// The choice of 1000 is purely heuristic and is not meant to be // The choice of 100 is purely heuristic and is not meant to be
// actually met. // actually met.
for k := 0; k < 1000; k++ { for k := 0; k < 100; k++ {
_, ok := new.Operator.(sym.LinComb) _, ok := new.Operator.(sym.LinComb)
if !ok { if !ok {
return new return new
@@ -103,13 +102,18 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
// children that are already in the children set. // children that are already in the children set.
func rankChildren( func rankChildren(
parents []*sym.Expression, parents []*sym.Expression,
childrenSet map[field.Element]*sym.Expression, childrenSet map[uint64]*sym.Expression,
) []*sym.Expression { ) []*sym.Expression {
// List all the grand-children of the expression whose parents are // List all the grand-children of the expression whose parents are
// products and counts the number of occurrences by summing the exponents. // products and counts the number of occurrences by summing the exponents.
relevantGdChildrenCnt := map[field.Element]int{} // As an optimization the map is addressed using the first uint64 repr
uniqueChildrenList := make([]*sym.Expression, 0) // of the element. We consider this is good enough to avoid collisions.
// The risk if it happens is that it gets caught by the validation checks
// at the end of the factorization routine. The preallocation value is
// purely heuristic to avoid successive allocations.
var relevantGdChildrenCnt map[uint64]int
var uniqueChildrenList []*sym.Expression
for _, p := range parents { for _, p := range parents {
@@ -127,16 +131,21 @@ func rankChildren(
// If it's in the group, it does not count. We can't add it a second // If it's in the group, it does not count. We can't add it a second
// time. // time.
if _, ok := childrenSet[c.ESHash]; ok { if _, ok := childrenSet[c.ESHash[0]]; ok {
continue continue
} }
if _, ok := relevantGdChildrenCnt[c.ESHash]; !ok { if relevantGdChildrenCnt == nil {
relevantGdChildrenCnt[c.ESHash] = 0 relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
}
if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
relevantGdChildrenCnt[c.ESHash[0]] = 0
uniqueChildrenList = append(uniqueChildrenList, c) uniqueChildrenList = append(uniqueChildrenList, c)
} }
relevantGdChildrenCnt[c.ESHash]++ relevantGdChildrenCnt[c.ESHash[0]]++
} }
} }
@@ -144,7 +153,7 @@ func rankChildren(
x := uniqueChildrenList[i].ESHash x := uniqueChildrenList[i].ESHash
y := uniqueChildrenList[j].ESHash y := uniqueChildrenList[j].ESHash
// We want a decreasing order // We want a decreasing order
return relevantGdChildrenCnt[x] > relevantGdChildrenCnt[y] return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
}) })
return uniqueChildrenList return uniqueChildrenList
@@ -155,76 +164,54 @@ func rankChildren(
// than one parent. The finding is based on a greedy algorithm. We iteratively // than one parent. The finding is based on a greedy algorithm. We iteratively
// add nodes in the group so that the number of common parents decreases as // add nodes in the group so that the number of common parents decreases as
// slowly as possible. // slowly as possible.
func findGdChildrenGroup(expr *sym.Expression) map[field.Element]*sym.Expression { func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {
curParents := expr.Children curParents := expr.Children
childrenSet := map[field.Element]*sym.Expression{} childrenSet := map[uint64]*sym.Expression{}
for { ranked := rankChildren(curParents, childrenSet)
ranked := rankChildren(curParents, childrenSet)
// Can happen when we have a lincomb of lincomb. Ideally they should be // Can happen when we have a lincomb of lincomb. Ideally they should be
// merged during canonization. // merged during canonization.
if len(ranked) == 0 { if len(ranked) == 0 {
return childrenSet return childrenSet
} }
best := ranked[0] for i := range ranked {
newChildrenSet := copyMap(childrenSet)
newChildrenSet[best.ESHash] = best best := ranked[i]
newParents := getCommonProdParentOfCs(newChildrenSet, curParents) childrenSet[best.ESHash[0]] = best
curParents = filterParentsWithChildren(curParents, best.ESHash)
// Can't grow the set anymore // Can't grow the set anymore
if len(newParents) <= 1 { if len(curParents) <= 1 {
delete(childrenSet, best.ESHash[0])
return childrenSet return childrenSet
} }
childrenSet = newChildrenSet
curParents = newParents
logrus.Tracef( logrus.Tracef(
"find groups, so far we have %v parents and %v siblings", "find groups, so far we have %v parents and %v siblings",
len(curParents), len(childrenSet)) len(curParents), len(childrenSet))
// Sanity-check
if err := parentsMustHaveAllChildren(curParents, childrenSet); err != nil {
panic(err)
}
} }
return childrenSet
} }
// getCommonProdParentOfCs returns the parents that have all cs as children and // filterParentsWithChildren returns a filtered list of parents who have at
// that are themselves children of gdp (grandparent). The parents must be of // least one child with the given ESHash. The function allocates a new list
// type product however. // of parents and returns it without mutating the original list.
func getCommonProdParentOfCs( func filterParentsWithChildren(
cs map[field.Element]*sym.Expression,
parents []*sym.Expression, parents []*sym.Expression,
childEsh field.Element,
) []*sym.Expression { ) []*sym.Expression {
res := []*sym.Expression{} res := make([]*sym.Expression, 0, len(parents))
for _, p := range parents { for _, p := range parents {
prod, ok := p.Operator.(sym.Product) for _, c := range p.Children {
if !ok { if c.ESHash == childEsh {
continue res = append(res, p)
} break
// Account for the fact that p may contain duplicates. So we cannot
// just use a counter here.
founds := map[field.Element]struct{}{}
for i, c := range p.Children {
if prod.Exponents[i] == 0 {
continue
} }
if _, inside := cs[c.ESHash]; inside {
// logrus.Tracef("%v contains %v", p.ESHash.String(), c.ESHash.String())
founds[c.ESHash] = struct{}{}
}
}
if len(founds) == len(cs) {
res = append(res, p)
} }
} }
@@ -235,22 +222,26 @@ func getCommonProdParentOfCs(
// determine the best common factor. // determine the best common factor.
func factorLinCompFromGroup( func factorLinCompFromGroup(
lincom *sym.Expression, lincom *sym.Expression,
group map[field.Element]*sym.Expression, group map[uint64]*sym.Expression,
) *sym.Expression { ) *sym.Expression {
var ( var (
// numTerms indicates the number of children in the linear-combination
numTerms = len(lincom.Children)
lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs
// Build the common term by taking the max of the exponents // Build the common term by taking the max of the exponents
exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group) exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)
// Separate the non-factored terms // Separate the non-factored terms
nonFactoredTerms = []*sym.Expression{} nonFactoredTerms = make([]*sym.Expression, 0, numTerms)
nonFactoredCoeffs = []int{} nonFactoredCoeffs = make([]int, 0, numTerms)
// The factored terms of the linear combination divided by the common // The factored terms of the linear combination divided by the common
// group factor // group factor
factoredTerms = []*sym.Expression{} factoredTerms = make([]*sym.Expression, 0, numTerms)
factoredCoeffs = []int{} factoredCoeffs = make([]int, 0, numTerms)
) )
numFactors := 0 numFactors := 0
@@ -295,7 +286,7 @@ func factorLinCompFromGroup(
// //
// Fortunately, this is guaranteed if the expression was constructed via // Fortunately, this is guaranteed if the expression was constructed via
// [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory. // [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) ( func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
factored *sym.Expression, factored *sym.Expression,
success bool, success bool,
) { ) {
@@ -310,7 +301,7 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
numMatches := 0 numMatches := 0
for i, c := range e.Children { for i, c := range e.Children {
eig, found := exponentsOfGroup[c.ESHash] eig, found := exponentsOfGroup[c.ESHash[0]]
if !found { if !found {
continue continue
} }
@@ -335,14 +326,14 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
// have the whole group as children. // have the whole group as children.
func optimRegroupExponents( func optimRegroupExponents(
parents []*sym.Expression, parents []*sym.Expression,
group map[field.Element]*sym.Expression, group map[uint64]*sym.Expression,
) ( ) (
exponentMap map[field.Element]int, exponentMap map[uint64]int,
groupedTerm *sym.Expression, groupedTerm *sym.Expression,
) { ) {
exponentMap = map[field.Element]int{} exponentMap = make(map[uint64]int, 16)
canonTermList := make([]*sym.Expression, 0) // built in deterministic order canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order
for _, p := range parents { for _, p := range parents {
@@ -354,10 +345,10 @@ func optimRegroupExponents(
// Used to sanity-check that all the nodes of the group have been // Used to sanity-check that all the nodes of the group have been
// reached through this parent. // reached through this parent.
matched := map[field.Element]int{} matched := make(map[uint64]int, len(p.Children))
for i, c := range p.Children { for i, c := range p.Children {
if _, ingroup := group[c.ESHash]; !ingroup { if _, ingroup := group[c.ESHash[0]]; !ingroup {
continue continue
} }
@@ -365,16 +356,16 @@ func optimRegroupExponents(
panic("The expression is not canonic") panic("The expression is not canonic")
} }
_, initialized := exponentMap[c.ESHash] _, initialized := exponentMap[c.ESHash[0]]
if !initialized { if !initialized {
// Max int is used as a placeholder. It will be replaced anytime // Max int is used as a placeholder. It will be replaced anytime
// we call utils.Min(exponentMap[h], n) where n is actually an // we call utils.Min(exponentMap[h], n) where n is actually an
// exponent. // exponent.
exponentMap[c.ESHash] = math.MaxInt exponentMap[c.ESHash[0]] = math.MaxInt
canonTermList = append(canonTermList, c) canonTermList = append(canonTermList, c)
} }
matched[c.ESHash] = exponents[i] matched[c.ESHash[0]] = exponents[i]
} }
if len(matched) != len(group) { if len(matched) != len(group) {
@@ -391,48 +382,8 @@ func optimRegroupExponents(
canonExponents := []int{} canonExponents := []int{}
for _, e := range canonTermList { for _, e := range canonTermList {
canonExponents = append(canonExponents, exponentMap[e.ESHash]) canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
} }
return exponentMap, sym.NewProduct(canonTermList, canonExponents) return exponentMap, sym.NewProduct(canonTermList, canonExponents)
} }
// parentsMustHaveAllChildren returns an error if at least one of the parents
// is missing one children from the set. This function is used internally to
// enforce invariants throughout the simplification routines.
func parentsMustHaveAllChildren[T any](
parents []*sym.Expression,
childrenSet map[field.Element]T,
) (resErr error) {
for parentID, p := range parents {
// Account for the fact that the node may contain duplicates of the node
// we are looking for.
founds := map[field.Element]struct{}{}
for _, c := range p.Children {
if _, ok := childrenSet[c.ESHash]; ok {
founds[c.ESHash] = struct{}{}
}
}
if len(founds) != len(childrenSet) {
resErr = errors.Join(
resErr,
fmt.Errorf(
"parent num %v is incomplete : found = %d/%d",
parentID, len(founds), len(childrenSet),
),
)
}
}
return resErr
}
func copyMap[K comparable, V any](m map[K]V) map[K]V {
res := make(map[K]V, len(m))
for k, v := range m {
res[k] = v
}
return res
}

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/consensys/linea-monorepo/prover/maths/field"
sym "github.com/consensys/linea-monorepo/prover/symbolic" sym "github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -60,12 +59,6 @@ func TestIsFactored(t *testing.T) {
IsFactored: true, IsFactored: true,
Factor: a, Factor: a,
}, },
{
Expr: sym.Mul(a, a, b),
By: a,
IsFactored: true,
Factor: sym.Mul(a, b),
},
} }
for i, tc := range testcases { for i, tc := range testcases {
@@ -74,13 +67,13 @@ func TestIsFactored(t *testing.T) {
// Build the group exponent map. If by is a product, we use directly // Build the group exponent map. If by is a product, we use directly
// the exponents it contains. Otherwise, we say this is a single // the exponents it contains. Otherwise, we say this is a single
// term product with an exponent of 1. // term product with an exponent of 1.
groupedExp := map[field.Element]int{} groupedExp := map[uint64]int{}
if byProd, ok := tc.By.Operator.(sym.Product); ok { if byProd, ok := tc.By.Operator.(sym.Product); ok {
for i, ex := range byProd.Exponents { for i, ex := range byProd.Exponents {
groupedExp[tc.By.Children[i].ESHash] = ex groupedExp[tc.By.Children[i].ESHash[0]] = ex
} }
} else { } else {
groupedExp[tc.By.ESHash] = 1 groupedExp[tc.By.ESHash[0]] = 1
} }
factored, isFactored := isFactored(tc.Expr, groupedExp) factored, isFactored := isFactored(tc.Expr, groupedExp)
@@ -207,15 +200,27 @@ func TestFactorLinCompFromGroup(t *testing.T) {
for i, testCase := range testCases { for i, testCase := range testCases {
t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) { t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
group := map[field.Element]*sym.Expression{} group := map[uint64]*sym.Expression{}
for _, e := range testCase.Group { for _, e := range testCase.Group {
group[e.ESHash] = e group[e.ESHash[0]] = e
} }
factored := factorLinCompFromGroup(testCase.LinComb, group) factored := factorLinCompFromGroup(testCase.LinComb, group)
require.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String()) assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
if t.Failed() {
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
t.Fatal()
}
require.NoError(t, factored.Validate()) require.NoError(t, factored.Validate())
assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored)) assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
if t.Failed() {
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
}
}) })
} }

View File

@@ -21,24 +21,17 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
return oldExpr // Handle edge case where there are no coefficients return oldExpr // Handle edge case where there are no coefficients
} }
acc := cs[0]
// Precompute powers of x // Precompute powers of x
powersOfX := make([]*sym.Expression, len(cs)) monomialTerms := make([]any, len(cs))
powersOfX[0] = x for i := 0; i < len(cs); i++ {
for i := 1; i < len(cs); i++ {
// We don't use the default constructor because it will collapse the // We don't use the default constructor because it will collapse the
// intermediate terms into a single term. The intermediates are useful because // intermediate terms into a single term. The intermediates are useful because
// they tell the evaluator to reuse the intermediate terms instead of // they tell the evaluator to reuse the intermediate terms instead of
// computing x^i for every term. // computing x^i for every term.
powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1}) monomialTerms[i] = any(sym.NewProduct([]*sym.Expression{cs[i], x}, []int{1, i}))
} }
for i := 1; i < len(cs); i++ { acc := sym.Add(monomialTerms...)
// Here we want to use the default constructor to ensure that we
// will have a merged sum at the end.
acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
}
if oldExpr.ESHash != acc.ESHash { if oldExpr.ESHash != acc.ESHash {
panic("ESH was altered") panic("ESH was altered")

22
prover/utils/sort.go Normal file
View File

@@ -0,0 +1,22 @@
package utils
// GenSorter adapts a triple of closures into an implementation of
// [sort.Interface], sparing the caller from declaring a dedicated
// named type just to sort an ad-hoc collection.
type GenSorter struct {
	LenFn  func() int
	SwapFn func(int, int)
	LessFn func(int, int) bool
}

// Len reports the number of elements by delegating to LenFn.
func (gs GenSorter) Len() int { return gs.LenFn() }

// Swap exchanges the elements at positions i and j via SwapFn.
func (gs GenSorter) Swap(i, j int) { gs.SwapFn(i, j) }

// Less reports whether element i must sort before element j via LessFn.
func (gs GenSorter) Less(i, j int) bool { return gs.LessFn(i, j) }

View File

@@ -105,7 +105,7 @@ func FullZkEvm(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
onceFullZkEvm.Do(func() { onceFullZkEvm.Do(func() {
// Initialize the Full zkEVM arithmetization // Initialize the Full zkEVM arithmetization
fullZkEvm = fullZKEVMWithSuite(tl, fullCompilationSuite, cfg) fullZkEvm = FullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
}) })
return fullZkEvm return fullZkEvm
@@ -115,13 +115,16 @@ func FullZkEVMCheckOnly(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
onceFullZkEvmCheckOnly.Do(func() { onceFullZkEvmCheckOnly.Do(func() {
// Initialize the Full zkEVM arithmetization // Initialize the Full zkEVM arithmetization
fullZkEvmCheckOnly = fullZKEVMWithSuite(tl, dummyCompilationSuite, cfg) fullZkEvmCheckOnly = FullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
}) })
return fullZkEvmCheckOnly return fullZkEvmCheckOnly
} }
func fullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm { // FullZKEVMWithSuite returns a compiled zkEVM with the given compilation suite.
// It can be used to benchmark the compilation time of the zkEVM and helps with
// performance optimization.
func FullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
// @Alex: only set mandatory parameters here. aka, the one that are not // @Alex: only set mandatory parameters here. aka, the one that are not
// actually feature-gated. // actually feature-gated.