Prover(perf): faster global constraints compilation (#704)

* bench(global): adds a benchmark for the global constraint compiler

* perf(merging): accumulates the factors before creating the expression

* perf(product): computes the ESH without using a smart-vector

* perf(factor): preallocations in the factorization algorithm

* perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs

* perfs(alloc): counts the ret elements before returning in expandTerms to minimze the number of allocations.

* perf(factor): use an integer map instead of a field.Element map when possible

* fixup(expands): fix the skip condition for term expansion

* perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb

* feat(repr): adds a json repr function to help debugging

* test(constructor): cleans the test of the constructors

* perf(factor): address maps using the first limb of a field.Element instead of the full field.Element

* fixup(commit): adds missing file in previous commit

* perf(factor): reduce the number of calls to rankChildren

* perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims

* perf(factors): use a counter in getCommonProdParentOfCs

* perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function

* clean(factor): remove unneeded function and imports

* feat(utils): adds a generic sort interface implementation

* perf(rankChildren): lazy allocation of the map to save on allocations

* perf(factorize): reduces the loop-bound for factorizeExpression

* (chore): fix a missing argument and format gofmt

* feat: readd test

---------

Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com>
Co-authored-by: gusiri <dreamerty@postech.ac.kr>
This commit is contained in:
AlexandreBelling
2025-03-21 12:55:54 +01:00
committed by GitHub
parent 869c0c63d4
commit 7334693931
12 changed files with 333 additions and 201 deletions

View File

@@ -0,0 +1,75 @@
package globalcs_test
import (
"os"
"testing"
"github.com/consensys/linea-monorepo/prover/config"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/cleanup"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/zkevm"
"github.com/sirupsen/logrus"
)
// BenchmarkGlobalConstraint benchmarks the global constraints compiler against the
// actual zk-evm constraint system.
func BenchmarkGlobalConstraintWoArtefacts(b *testing.B) {
// partialSuite corresponds to the actual compilation suite of
// the full zk-evm down to the point where the global constraints
// are compiled.
partialSuite := []func(*wizard.CompiledIOP){
mimc.CompileMiMC,
specialqueries.RangeProof,
specialqueries.CompileFixedPermutations,
permutation.CompileGrandProduct,
lookup.CompileLogDerivative,
innerproduct.Compile,
sticker.Sticker(1<<10, 1<<19),
splitter.SplitColumns(1 << 19),
cleanup.CleanUp,
localcs.Compile,
globalcs.Compile,
}
// In order to load the config we need to position ourselves in the root
// folder.
_ = os.Chdir("../../..")
defer os.Chdir("protocol/compiler/globalcs")
// config corresponds to the config we use on sepolia
cfg, cfgErr := config.NewConfigFromFile("./config/config-sepolia-full.toml")
if cfgErr != nil {
b.Fatalf("could not find the config: %v", cfgErr)
}
// Shut the logger to not overwhelm the benchmark output
logrus.SetLevel(logrus.PanicLevel)
b.ResetTimer()
for c_ := 0; c_ < b.N; c_++ {
b.StopTimer()
// Removes the artefacts to
if err := os.RemoveAll("/tmp/prover-artefacts"); err != nil {
b.Fatalf("could not remove the artefacts: %v", err)
}
b.StartTimer()
_ = zkevm.FullZKEVMWithSuite(&cfg.TracesLimits, partialSuite, cfg)
}
}

View File

@@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
domainSize = cs.DomainSize
x = variables.NewXVar()
omega = fft.GetOmega(domainSize)
// factors is a list of expression to multiply to obtain the return expression. It
// is initialized with "only" the initial expression and we iteratively add the
// terms (X-i) to it. At the end, we call [sym.Mul] a single time. This structure
// is important because it [sym.Mul] operates a sequence of optimization routines
// that are everytime we call it. In an earlier version, we were calling [sym.Mul]
// for every factor and this were making the function have a quadratic/cubic runtime.
factors = make([]any, 0, utils.Abs(cancelRange.Max)+utils.Abs(cancelRange.Min)+1)
)
// cancelExprAtPoint cancels the expression at a particular position
cancelExprAtPoint := func(expr *symbolic.Expression, i int) *symbolic.Expression {
factors = append(factors, res)
// appendFactor appends an expressions representing $X-\rho^i$ to [factors]
appendFactor := func(i int) {
var root field.Element
root.Exp(omega, big.NewInt(int64(i)))
return symbolic.Mul(expr, symbolic.Sub(x, root))
factors = append(factors, symbolic.Sub(x, root))
}
if cancelRange.Min < 0 {
// Cancels the expression on the range [0, -cancelRange.Min)
for i := 0; i < -cancelRange.Min; i++ {
res = cancelExprAtPoint(res, i)
appendFactor(i)
}
}
@@ -166,11 +175,18 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
// Cancels the expression on the range (N-cancelRange.Max-1, N-1]
for i := 0; i < cancelRange.Max; i++ {
point := domainSize - i - 1 // point at which we want to cancel the constraint
res = cancelExprAtPoint(res, point)
appendFactor(point)
}
}
return res
// When factors is of length 1, it means the expression does not need to be
// bound-cancelled and we can directly return the original expression
// without calling [sym.Mul].
if len(factors) == 1 {
return res
}
return symbolic.Mul(factors...)
}
// getExprRatio computes the ratio of the expression and ceil to the next power

View File

@@ -27,12 +27,12 @@ func Add(inputs ...any) *Expression {
exprInputs := intoExprSlice(inputs...)
res := exprInputs[0]
for i := 1; i < len(exprInputs); i++ {
res = res.Add(exprInputs[i])
magnitudes := make([]int, len(exprInputs))
for i := range exprInputs {
magnitudes[i] = 1
}
return res
return NewLinComb(exprInputs, magnitudes)
}
// Mul constructs a symbolic expression representing the product of its inputs.
@@ -57,12 +57,12 @@ func Mul(inputs ...any) *Expression {
exprInputs := intoExprSlice(inputs...)
res := exprInputs[0]
for i := 1; i < len(exprInputs); i++ {
res = res.Mul(exprInputs[i])
magnitudes := make([]int, len(exprInputs))
for i := range exprInputs {
magnitudes[i] = 1
}
return res
return NewProduct(exprInputs, magnitudes)
}
// Sub returns a symbolic expression representing the subtraction of `a` by all
@@ -82,16 +82,18 @@ func Sub(a any, bs ...any) *Expression {
}
var (
aExpr = intoExpr(a)
bExpr = intoExprSlice(bs...)
res = aExpr
aExpr = intoExpr(a)
bExpr = intoExprSlice(bs...)
exprInputs = append([]*Expression{aExpr}, bExpr...)
magnitudes = make([]int, len(exprInputs))
)
for i := range bExpr {
res = res.Sub(bExpr[i])
for i := range exprInputs {
magnitudes[i] = -1
}
magnitudes[0] = 1
return res
return NewLinComb(exprInputs, magnitudes)
}
// Neg returns an expression representing the negation of an expression or of
@@ -119,7 +121,7 @@ func Square(x any) *Expression {
return nil
}
return intoExpr(x).Square()
return NewProduct([]*Expression{intoExpr(x)}, []int{2})
}
// Pow returns an expression representing the raising to the power "n" of an

View File

@@ -1,6 +1,7 @@
package symbolic
import (
"encoding/json"
"fmt"
"reflect"
"sync"
@@ -284,5 +285,14 @@ func (e *Expression) SameWithNewChildren(newChildren []*Expression) *Expression
default:
panic("unexpected type: " + reflect.TypeOf(op).String())
}
}
// MarshalJSONString returns a JSON string returns a JSON string representation
// of the expression.
func (e *Expression) MarshalJSONString() string {
js, jsErr := json.MarshalIndent(e, "", " ")
if jsErr != nil {
utils.Panic("failed to marshal expression: %v", jsErr)
}
return string(js)
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils/collection"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -96,7 +97,7 @@ func TestLCConstruction(t *testing.T) {
x := NewDummyVar("x")
y := NewDummyVar("y")
{
t.Run("simple-addition", func(t *testing.T) {
/*
Test t a simple case of addition
*/
@@ -108,23 +109,23 @@ func TestLCConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs))
require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1)
require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1)
}
})
{
t.Run("x-y-x", func(t *testing.T) {
/*
Adding y then substracting x should give back (y)
*/
expr1 := x.Add(y).Sub(x)
require.Equal(t, expr1, y)
}
})
{
t.Run("(-x)+x+y", func(t *testing.T) {
/*
Same thing when using Neg
*/
expr := x.Neg().Add(x).Add(y)
require.Equal(t, expr, y)
}
assert.Equal(t, expr, y)
})
}
@@ -133,7 +134,7 @@ func TestProductConstruction(t *testing.T) {
x := NewDummyVar("x")
y := NewDummyVar("y")
{
t.Run("x * y", func(t *testing.T) {
/*
Test t a simple case of addition
*/
@@ -145,9 +146,9 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
require.Equal(t, expr1.Operator.(Product).Exponents[0], 1)
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
}
})
{
t.Run("x * y * x", func(t *testing.T) {
/*
Adding y then substracting x should give back (y)
*/
@@ -158,9 +159,9 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
require.Equal(t, expr1.Operator.(Product).Exponents[0], 2)
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
}
})
{
t.Run("x^2", func(t *testing.T) {
/*
When we square
*/
@@ -169,6 +170,6 @@ func TestProductConstruction(t *testing.T) {
require.Equal(t, expr.Children[0], x)
require.Equal(t, 1, len(expr.Operator.(Product).Exponents))
require.Equal(t, expr.Operator.(Product).Exponents[0], 2)
}
})
}

View File

@@ -95,18 +95,30 @@ func NewProduct(items []*Expression, exponents []int) *Expression {
e := &Expression{
Operator: Product{Exponents: exponents},
Children: items,
ESHash: field.One(),
}
// Now we need to assign the ESH
eshashes := make([]sv.SmartVector, len(e.Children))
for i := range e.Children {
eshashes[i] = sv.NewConstant(e.Children[i].ESHash, 1)
}
if len(items) > 0 {
// The cast back to sv.Constant is no important functionally but is an easy
// sanity check.
e.ESHash = e.Operator.Evaluate(eshashes).(*sv.Constant).Get(0)
var tmp field.Element
switch {
case exponents[i] == 1:
e.ESHash.Mul(&e.ESHash, &e.Children[i].ESHash)
case exponents[i] == 2:
tmp.Square(&e.Children[i].ESHash)
e.ESHash.Mul(&e.ESHash, &tmp)
case exponents[i] == 3:
tmp.Square(&e.Children[i].ESHash)
tmp.Mul(&tmp, &e.Children[i].ESHash)
e.ESHash.Mul(&e.ESHash, &tmp)
case exponents[i] == 4:
tmp.Square(&e.Children[i].ESHash)
tmp.Square(&tmp)
e.ESHash.Mul(&e.ESHash, &tmp)
default:
exponent := big.NewInt(int64(exponents[i]))
tmp.Exp(e.Children[i].ESHash, exponent)
e.ESHash.Mul(&e.ESHash, &tmp)
}
}
return e

View File

@@ -78,22 +78,38 @@ func regroupTerms(magnitudes []int, children []*Expression) (
// the linear combination. This function is used both for simplifying [LinComb]
// expressions and for simplifying [Product]. "magnitude" denotes either the
// coefficient for LinComb or exponents for Product.
//
// The function takes ownership of the provided slices.
func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) {
if len(magnitudes) != len(children) {
panic("magnitudes and children don't have the same length")
}
cleanChildren = make([]*Expression, 0, len(children))
cleanMagnitudes = make([]int, 0, len(children))
// cleanChildren and cleanMagnitudes are initialized lazily to
// avoid unnecessarily allocating memory. The underlying assumption
// is that the application will 99% of time never pass zero as a
// magnitude.
for i, c := range magnitudes {
if c != 0 {
if c == 0 && cleanChildren == nil {
cleanChildren = make([]*Expression, i, len(children))
cleanMagnitudes = make([]int, i, len(children))
copy(cleanChildren, children[:i])
copy(cleanMagnitudes, magnitudes[:i])
}
if c != 0 && cleanChildren != nil {
cleanMagnitudes = append(cleanMagnitudes, magnitudes[i])
cleanChildren = append(cleanChildren, children[i])
}
}
if cleanChildren == nil {
cleanChildren = children
cleanMagnitudes = magnitudes
}
return cleanMagnitudes, cleanChildren
}
@@ -108,14 +124,16 @@ func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes
// The caller passes a target operator which may be any value of type either
// [LinComb] or [Product]. Any other type yields a panic error.
func expandTerms(op Operator, magnitudes []int, children []*Expression) (
expandedMagnitudes []int,
expandedExpression []*Expression,
[]int,
[]*Expression,
) {
var (
opIsProd bool
opIsLinC bool
numChildren = len(children)
opIsProd bool
opIsLinC bool
numChildren = len(children)
totalReturnSize = 0
needExpand = false
)
switch op.(type) {
@@ -133,9 +151,34 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
panic("incompatible number of children and magnitudes")
}
// The capacity allocation is purely heuristic
expandedExpression = make([]*Expression, 0, 2*len(magnitudes))
expandedMagnitudes = make([]int, 0, 2*len(magnitudes))
// This loops performs a first scan of the children to compute the total
// number of elements to allocate.
for i, child := range children {
switch child.Operator.(type) {
case Product, *Product:
if opIsProd {
needExpand = true
totalReturnSize += len(children[i].Children)
continue
}
case LinComb, *LinComb:
if opIsLinC {
needExpand = true
totalReturnSize += len(children[i].Children)
continue
}
}
totalReturnSize++
}
if !needExpand {
return magnitudes, children
}
expandedMagnitudes := make([]int, 0, totalReturnSize)
expandedExpression := make([]*Expression, 0, totalReturnSize)
for i := 0; i < numChildren; i++ {
@@ -175,7 +218,6 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
for k := range child.Children {
expandedExpression = append(expandedExpression, child.Children[k])
expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k])
}
continue
}

View File

@@ -1,7 +1,6 @@
package simplify
import (
"errors"
"fmt"
"math"
"sort"
@@ -46,9 +45,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
// factoring possibilities. There is also a bound on the loop to
// prevent infinite loops.
//
// The choice of 1000 is purely heuristic and is not meant to be
// The choice of 100 is purely heuristic and is not meant to be
// actually met.
for k := 0; k < 1000; k++ {
for k := 0; k < 100; k++ {
_, ok := new.Operator.(sym.LinComb)
if !ok {
return new
@@ -103,13 +102,18 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
// children that are already in the children set.
func rankChildren(
parents []*sym.Expression,
childrenSet map[field.Element]*sym.Expression,
childrenSet map[uint64]*sym.Expression,
) []*sym.Expression {
// List all the grand-children of the expression whose parents are
// products and counts the number of occurences by summing the exponents.
relevantGdChildrenCnt := map[field.Element]int{}
uniqueChildrenList := make([]*sym.Expression, 0)
// As an optimization the map is addressed using the first uint64 repr
// of the element. We consider this is good enough to avoid collisions.
// The risk if it happens is that it gets caught by the validation checks
// at the end of the factorization routine. The preallocation value is
// purely heuristic to avoid successive allocations.
var relevantGdChildrenCnt map[uint64]int
var uniqueChildrenList []*sym.Expression
for _, p := range parents {
@@ -127,16 +131,21 @@ func rankChildren(
// If it's in the group, it does not count. We can't add it a second
// time.
if _, ok := childrenSet[c.ESHash]; ok {
if _, ok := childrenSet[c.ESHash[0]]; ok {
continue
}
if _, ok := relevantGdChildrenCnt[c.ESHash]; !ok {
relevantGdChildrenCnt[c.ESHash] = 0
if relevantGdChildrenCnt == nil {
relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
}
if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
relevantGdChildrenCnt[c.ESHash[0]] = 0
uniqueChildrenList = append(uniqueChildrenList, c)
}
relevantGdChildrenCnt[c.ESHash]++
relevantGdChildrenCnt[c.ESHash[0]]++
}
}
@@ -144,7 +153,7 @@ func rankChildren(
x := uniqueChildrenList[i].ESHash
y := uniqueChildrenList[j].ESHash
// We want to a decreasing order
return relevantGdChildrenCnt[x] > relevantGdChildrenCnt[y]
return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
})
return uniqueChildrenList
@@ -155,76 +164,54 @@ func rankChildren(
// than one parent. The finding is based on a greedy algorithm. We iteratively
// add nodes in the group so that the number of common parents decreases as
// slowly as possible.
func findGdChildrenGroup(expr *sym.Expression) map[field.Element]*sym.Expression {
func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {
curParents := expr.Children
childrenSet := map[field.Element]*sym.Expression{}
childrenSet := map[uint64]*sym.Expression{}
for {
ranked := rankChildren(curParents, childrenSet)
ranked := rankChildren(curParents, childrenSet)
// Can happen when we have a lincomb of lincomb. Ideally they should be
// merged during canonization.
if len(ranked) == 0 {
return childrenSet
}
// Can happen when we have a lincomb of lincomb. Ideally they should be
// merged during canonization.
if len(ranked) == 0 {
return childrenSet
}
best := ranked[0]
newChildrenSet := copyMap(childrenSet)
newChildrenSet[best.ESHash] = best
newParents := getCommonProdParentOfCs(newChildrenSet, curParents)
for i := range ranked {
best := ranked[i]
childrenSet[best.ESHash[0]] = best
curParents = filterParentsWithChildren(curParents, best.ESHash)
// Can't grow the set anymore
if len(newParents) <= 1 {
if len(curParents) <= 1 {
delete(childrenSet, best.ESHash[0])
return childrenSet
}
childrenSet = newChildrenSet
curParents = newParents
logrus.Tracef(
"find groups, so far we have %v parents and %v siblings",
len(curParents), len(childrenSet))
// Sanity-check
if err := parentsMustHaveAllChildren(curParents, childrenSet); err != nil {
panic(err)
}
}
return childrenSet
}
// getCommonProdParentOfCs returns the parents that have all cs as children and
// that are themselves children of gdp (grandparent). The parents must be of
// type product however.
func getCommonProdParentOfCs(
cs map[field.Element]*sym.Expression,
// filterParentsWithChildren returns a filtered list of parents who have at
// least one child with the given ESHash. The function allocates a new list
// of parents and returns it without mutating he original list.
func filterParentsWithChildren(
parents []*sym.Expression,
childEsh field.Element,
) []*sym.Expression {
res := []*sym.Expression{}
res := make([]*sym.Expression, 0, len(parents))
for _, p := range parents {
prod, ok := p.Operator.(sym.Product)
if !ok {
continue
}
// Account for the fact that p may contain duplicates. So we cannot
// just use a counter here.
founds := map[field.Element]struct{}{}
for i, c := range p.Children {
if prod.Exponents[i] == 0 {
continue
for _, c := range p.Children {
if c.ESHash == childEsh {
res = append(res, p)
break
}
if _, inside := cs[c.ESHash]; inside {
// logrus.Tracef("%v contains %v", p.ESHash.String(), c.ESHash.String())
founds[c.ESHash] = struct{}{}
}
}
if len(founds) == len(cs) {
res = append(res, p)
}
}
@@ -235,22 +222,26 @@ func getCommonProdParentOfCs(
// determine the best common factor.
func factorLinCompFromGroup(
lincom *sym.Expression,
group map[field.Element]*sym.Expression,
group map[uint64]*sym.Expression,
) *sym.Expression {
var (
// numTerms indicates the number of children in the linear-combination
numTerms = len(lincom.Children)
lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs
// Build the common term by taking the max of the exponents
exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)
// Separate the non-factored terms
nonFactoredTerms = []*sym.Expression{}
nonFactoredCoeffs = []int{}
nonFactoredTerms = make([]*sym.Expression, 0, numTerms)
nonFactoredCoeffs = make([]int, 0, numTerms)
// The factored terms of the linear combination divided by the common
// group factor
factoredTerms = []*sym.Expression{}
factoredCoeffs = []int{}
factoredTerms = make([]*sym.Expression, 0, numTerms)
factoredCoeffs = make([]int, 0, numTerms)
)
numFactors := 0
@@ -295,7 +286,7 @@ func factorLinCompFromGroup(
//
// Fortunately, this is guaranteed if the expression was constructed via
// [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
factored *sym.Expression,
success bool,
) {
@@ -310,7 +301,7 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
numMatches := 0
for i, c := range e.Children {
eig, found := exponentsOfGroup[c.ESHash]
eig, found := exponentsOfGroup[c.ESHash[0]]
if !found {
continue
}
@@ -335,14 +326,14 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
// have the whole group as children.
func optimRegroupExponents(
parents []*sym.Expression,
group map[field.Element]*sym.Expression,
group map[uint64]*sym.Expression,
) (
exponentMap map[field.Element]int,
exponentMap map[uint64]int,
groupedTerm *sym.Expression,
) {
exponentMap = map[field.Element]int{}
canonTermList := make([]*sym.Expression, 0) // built in deterministic order
exponentMap = make(map[uint64]int, 16)
canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order
for _, p := range parents {
@@ -354,10 +345,10 @@ func optimRegroupExponents(
// Used to sanity-check that all the nodes of the group have been
// reached through this parent.
matched := map[field.Element]int{}
matched := make(map[uint64]int, len(p.Children))
for i, c := range p.Children {
if _, ingroup := group[c.ESHash]; !ingroup {
if _, ingroup := group[c.ESHash[0]]; !ingroup {
continue
}
@@ -365,16 +356,16 @@ func optimRegroupExponents(
panic("The expression is not canonic")
}
_, initialized := exponentMap[c.ESHash]
_, initialized := exponentMap[c.ESHash[0]]
if !initialized {
// Max int is used as a placeholder. It will be replaced anytime
// we wall utils.Min(exponentMap[h], n) where n is actually an
// exponent.
exponentMap[c.ESHash] = math.MaxInt
exponentMap[c.ESHash[0]] = math.MaxInt
canonTermList = append(canonTermList, c)
}
matched[c.ESHash] = exponents[i]
matched[c.ESHash[0]] = exponents[i]
}
if len(matched) != len(group) {
@@ -391,48 +382,8 @@ func optimRegroupExponents(
canonExponents := []int{}
for _, e := range canonTermList {
canonExponents = append(canonExponents, exponentMap[e.ESHash])
canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
}
return exponentMap, sym.NewProduct(canonTermList, canonExponents)
}
// parentsMustHaveAllChildren returns an error if at least one of the parents
// is missing one children from the set. This function is used internally to
// enforce invariants throughout the simplification routines.
func parentsMustHaveAllChildren[T any](
parents []*sym.Expression,
childrenSet map[field.Element]T,
) (resErr error) {
for parentID, p := range parents {
// Account for the fact that the node may contain duplicates of the node
// we are looking for.
founds := map[field.Element]struct{}{}
for _, c := range p.Children {
if _, ok := childrenSet[c.ESHash]; ok {
founds[c.ESHash] = struct{}{}
}
}
if len(founds) != len(childrenSet) {
resErr = errors.Join(
resErr,
fmt.Errorf(
"parent num %v is incomplete : found = %d/%d",
parentID, len(founds), len(childrenSet),
),
)
}
}
return resErr
}
func copyMap[K comparable, V any](m map[K]V) map[K]V {
res := make(map[K]V, len(m))
for k, v := range m {
res[k] = v
}
return res
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"testing"
"github.com/consensys/linea-monorepo/prover/maths/field"
sym "github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -60,12 +59,6 @@ func TestIsFactored(t *testing.T) {
IsFactored: true,
Factor: a,
},
{
Expr: sym.Mul(a, a, b),
By: a,
IsFactored: true,
Factor: sym.Mul(a, b),
},
}
for i, tc := range testcases {
@@ -74,13 +67,13 @@ func TestIsFactored(t *testing.T) {
// Build the group exponent map. If by is a product, we use directly
// the exponents it contains. Otherwise, we say this is a single
// term product with an exponent of 1.
groupedExp := map[field.Element]int{}
groupedExp := map[uint64]int{}
if byProd, ok := tc.By.Operator.(sym.Product); ok {
for i, ex := range byProd.Exponents {
groupedExp[tc.By.Children[i].ESHash] = ex
groupedExp[tc.By.Children[i].ESHash[0]] = ex
}
} else {
groupedExp[tc.By.ESHash] = 1
groupedExp[tc.By.ESHash[0]] = 1
}
factored, isFactored := isFactored(tc.Expr, groupedExp)
@@ -207,15 +200,27 @@ func TestFactorLinCompFromGroup(t *testing.T) {
for i, testCase := range testCases {
t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
group := map[field.Element]*sym.Expression{}
group := map[uint64]*sym.Expression{}
for _, e := range testCase.Group {
group[e.ESHash] = e
group[e.ESHash[0]] = e
}
factored := factorLinCompFromGroup(testCase.LinComb, group)
require.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
if t.Failed() {
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
t.Fatal()
}
require.NoError(t, factored.Validate())
assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
if t.Failed() {
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
}
})
}

View File

@@ -21,24 +21,17 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
return oldExpr // Handle edge case where there are no coefficients
}
acc := cs[0]
// Precompute powers of x
powersOfX := make([]*sym.Expression, len(cs))
powersOfX[0] = x
for i := 1; i < len(cs); i++ {
monomialTerms := make([]any, len(cs))
for i := 0; i < len(cs); i++ {
// We don't use the default constructor because it will collapse the
// intermediate terms into a single term. The intermediates are useful because
// they tell the evaluator to reuse the intermediate terms instead of
// computing x^i for every term.
powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1})
monomialTerms[i] = any(sym.NewProduct([]*sym.Expression{cs[i], x}, []int{1, i}))
}
for i := 1; i < len(cs); i++ {
// Here we want to use the default constructor to ensure that we
// will have a merged sum at the end.
acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
}
acc := sym.Add(monomialTerms...)
if oldExpr.ESHash != acc.ESHash {
panic("ESH was altered")

22
prover/utils/sort.go Normal file
View File

@@ -0,0 +1,22 @@
package utils
// GenSorter is a generic interface implementing [sort.Interface]
// but let the user provide the methods directly as closures.
// Without needing to implement a new custom type.
type GenSorter struct {
LenFn func() int
SwapFn func(int, int)
LessFn func(int, int) bool
}
func (s GenSorter) Len() int {
return s.LenFn()
}
func (s GenSorter) Swap(i, j int) {
s.SwapFn(i, j)
}
func (s GenSorter) Less(i, j int) bool {
return s.LessFn(i, j)
}

View File

@@ -105,7 +105,7 @@ func FullZkEvm(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
onceFullZkEvm.Do(func() {
// Initialize the Full zkEVM arithmetization
fullZkEvm = fullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
fullZkEvm = FullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
})
return fullZkEvm
@@ -115,13 +115,16 @@ func FullZkEVMCheckOnly(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
onceFullZkEvmCheckOnly.Do(func() {
// Initialize the Full zkEVM arithmetization
fullZkEvmCheckOnly = fullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
fullZkEvmCheckOnly = FullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
})
return fullZkEvmCheckOnly
}
func fullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
// FullZKEVMWithSuite returns a compiled zkEVM with the given compilation suite.
// It can be used to benchmark the compilation time of the zkEVM and helps with
// performance optimization.
func FullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
// @Alex: only set mandatory parameters here. aka, the one that are not
// actually feature-gated.