From 73346939318b4307c2e585c5334661fe99b179d4 Mon Sep 17 00:00:00 2001
From: AlexandreBelling
Date: Fri, 21 Mar 2025 12:55:54 +0100
Subject: [PATCH] Prover(perf): faster global constraints compilation (#704)

* bench(global): adds a benchmark for the global constraint compiler
* perf(merging): accumulates the factors before creating the expression
* perf(product): computes the ESH without using a smart-vector
* perf(factor): preallocations in the factorization algorithm
* perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs
* perf(alloc): counts the returned elements before returning in expandTerms to minimize the number of allocations
* perf(factor): use an integer map instead of a field.Element map when possible
* fixup(expands): fix the skip condition for term expansion
* perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb
* feat(repr): adds a json repr function to help debugging
* test(constructor): cleans the test of the constructors
* perf(factor): address maps using the first limb of a field.Element instead of the full field.Element
* fixup(commit): adds missing file in previous commit
* perf(factor): reduce the number of calls to rankChildren
* perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims
* perf(factors): use a counter in getCommonProdParentOfCs
* perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function
* clean(factor): remove unneeded function and imports
* feat(utils): adds a generic sort interface implementation
* perf(rankChildren): lazy allocation of the map to save on allocations
* perf(factorize): reduces the loop-bound for factorizeExpression
* chore: fix a missing argument and run gofmt
* feat: re-add test

---------

Signed-off-by: AlexandreBelling
Co-authored-by: gusiri
---
 .../compiler/globalcs/global_perf_test.go    |  75 +++++++
 prover/protocol/compiler/globalcs/merging.go |  28 ++-
 prover/symbolic/constructor_new.go           |  32 +--
 prover/symbolic/expression.go                |  12 +-
 prover/symbolic/expression_test.go           |  27 +--
 prover/symbolic/product.go                   |  30 ++-
 prover/symbolic/simplification.go            |  68 +++++--
 prover/symbolic/simplify/factor.go           | 185 +++++++-----------
 prover/symbolic/simplify/factor_test.go      |  31 +--
 prover/symbolic/simplify/rmpolyeval.go       |  15 +-
 prover/utils/sort.go                         |  22 +++
 prover/zkevm/full.go                         |   9 +-
 12 files changed, 333 insertions(+), 201 deletions(-)
 create mode 100644 prover/protocol/compiler/globalcs/global_perf_test.go
 create mode 100644 prover/utils/sort.go
"github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/zkevm" + "github.com/sirupsen/logrus" +) + +// BenchmarkGlobalConstraint benchmarks the global constraints compiler against the +// actual zk-evm constraint system. +func BenchmarkGlobalConstraintWoArtefacts(b *testing.B) { + + // partialSuite corresponds to the actual compilation suite of + // the full zk-evm down to the point where the global constraints + // are compiled. + partialSuite := []func(*wizard.CompiledIOP){ + mimc.CompileMiMC, + specialqueries.RangeProof, + specialqueries.CompileFixedPermutations, + permutation.CompileGrandProduct, + lookup.CompileLogDerivative, + innerproduct.Compile, + sticker.Sticker(1<<10, 1<<19), + splitter.SplitColumns(1 << 19), + cleanup.CleanUp, + localcs.Compile, + globalcs.Compile, + } + + // In order to load the config we need to position ourselves in the root + // folder. + _ = os.Chdir("../../..") + defer os.Chdir("protocol/compiler/globalcs") + + // config corresponds to the config we use on sepolia + cfg, cfgErr := config.NewConfigFromFile("./config/config-sepolia-full.toml") + if cfgErr != nil { + b.Fatalf("could not find the config: %v", cfgErr) + } + + // Shut the logger to not overwhelm the benchmark output + logrus.SetLevel(logrus.PanicLevel) + + b.ResetTimer() + + for c_ := 0; c_ < b.N; c_++ { + + b.StopTimer() + + // Removes the artefacts to + if err := os.RemoveAll("/tmp/prover-artefacts"); err != nil { + b.Fatalf("could not remove the artefacts: %v", err) + } + + b.StartTimer() + + _ = zkevm.FullZKEVMWithSuite(&cfg.TracesLimits, partialSuite, cfg) + + } + +} diff --git a/prover/protocol/compiler/globalcs/merging.go b/prover/protocol/compiler/globalcs/merging.go index f7b04c45..8e2ee7ad 100644 --- a/prover/protocol/compiler/globalcs/merging.go +++ b/prover/protocol/compiler/globalcs/merging.go @@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression domainSize = cs.DomainSize x = variables.NewXVar() omega = fft.GetOmega(domainSize) + // factors is a list of expression to multiply to obtain the return expression. It + // is initialized with "only" the initial expression and we iteratively add the + // terms (X-i) to it. At the end, we call [sym.Mul] a single time. This structure + // is important because it [sym.Mul] operates a sequence of optimization routines + // that are everytime we call it. In an earlier version, we were calling [sym.Mul] + // for every factor and this were making the function have a quadratic/cubic runtime. 
diff --git a/prover/protocol/compiler/globalcs/merging.go b/prover/protocol/compiler/globalcs/merging.go
index f7b04c45..8e2ee7ad 100644
--- a/prover/protocol/compiler/globalcs/merging.go
+++ b/prover/protocol/compiler/globalcs/merging.go
@@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
 		domainSize = cs.DomainSize
 		x          = variables.NewXVar()
 		omega      = fft.GetOmega(domainSize)
+		// factors is a list of expressions to multiply to obtain the returned expression.
+		// It is initialized with "only" the initial expression and we iteratively append
+		// the terms (X - \omega^i) to it. At the end, we call [sym.Mul] a single time. This
+		// structure is important because [sym.Mul] runs a sequence of optimization routines
+		// every time it is called. In an earlier version, we were calling [sym.Mul] for
+		// every factor and this made the function have a quadratic/cubic runtime.
+		factors = make([]any, 0, utils.Abs(cancelRange.Max)+utils.Abs(cancelRange.Min)+1)
 	)

-	// cancelExprAtPoint cancels the expression at a particular position
-	cancelExprAtPoint := func(expr *symbolic.Expression, i int) *symbolic.Expression {
+	factors = append(factors, res)
+
+	// appendFactor appends an expression representing $X - \omega^i$ to [factors]
+	appendFactor := func(i int) {
 		var root field.Element
 		root.Exp(omega, big.NewInt(int64(i)))
-		return symbolic.Mul(expr, symbolic.Sub(x, root))
+		factors = append(factors, symbolic.Sub(x, root))
 	}

 	if cancelRange.Min < 0 {
 		// Cancels the expression on the range [0, -cancelRange.Min)
 		for i := 0; i < -cancelRange.Min; i++ {
-			res = cancelExprAtPoint(res, i)
+			appendFactor(i)
 		}
 	}

@@ -166,11 +175,18 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
 		// Cancels the expression on the range (N-cancelRange.Max-1, N-1]
 		for i := 0; i < cancelRange.Max; i++ {
 			point := domainSize - i - 1 // point at which we want to cancel the constraint
-			res = cancelExprAtPoint(res, point)
+			appendFactor(point)
 		}
 	}

-	return res
+	// When factors has length 1, the expression does not need to be
+	// bound-cancelled and we can directly return the original expression
+	// without calling [sym.Mul].
+	if len(factors) == 1 {
+		return res
+	}
+
+	return symbolic.Mul(factors...)
 }

 // getExprRatio computes the ratio of the expression and ceil to the next power
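// The merging.go change above collects the (X - omega^i) factors into a slice
// and calls symbolic.Mul a single time instead of folding pairwise, because
// every call re-runs the constructor's simplification over everything
// accumulated so far. The same accumulate-then-combine-once reasoning applies
// to plain string building; the standard-library sketch below is an analogy
// only, not code from the repository.
package main

import (
	"fmt"
	"strings"
)

func main() {
	parts := make([]string, 1000)
	for i := range parts {
		parts[i] = "x"
	}

	// Quadratic: every += copies the whole accumulated string again.
	quad := ""
	for _, p := range parts {
		quad += p
	}

	// Linear: accumulate the parts, then join once.
	lin := strings.Join(parts, "")

	fmt.Println(len(quad) == len(lin)) // true
}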
diff --git a/prover/symbolic/constructor_new.go b/prover/symbolic/constructor_new.go
index 41c0212f..eade5e73 100644
--- a/prover/symbolic/constructor_new.go
+++ b/prover/symbolic/constructor_new.go
@@ -27,12 +27,12 @@ func Add(inputs ...any) *Expression {

 	exprInputs := intoExprSlice(inputs...)

-	res := exprInputs[0]
-	for i := 1; i < len(exprInputs); i++ {
-		res = res.Add(exprInputs[i])
+	magnitudes := make([]int, len(exprInputs))
+	for i := range exprInputs {
+		magnitudes[i] = 1
 	}

-	return res
+	return NewLinComb(exprInputs, magnitudes)
 }

 // Mul constructs a symbolic expression representing the product of its inputs.
@@ -57,12 +57,12 @@ func Mul(inputs ...any) *Expression {

 	exprInputs := intoExprSlice(inputs...)

-	res := exprInputs[0]
-	for i := 1; i < len(exprInputs); i++ {
-		res = res.Mul(exprInputs[i])
+	magnitudes := make([]int, len(exprInputs))
+	for i := range exprInputs {
+		magnitudes[i] = 1
 	}

-	return res
+	return NewProduct(exprInputs, magnitudes)
 }

 // Sub returns a symbolic expression representing the subtraction of `a` by all
@@ -82,16 +82,18 @@ func Sub(a any, bs ...any) *Expression {
 	}

 	var (
-		aExpr = intoExpr(a)
-		bExpr = intoExprSlice(bs...)
-		res   = aExpr
+		aExpr      = intoExpr(a)
+		bExpr      = intoExprSlice(bs...)
+		exprInputs = append([]*Expression{aExpr}, bExpr...)
+		magnitudes = make([]int, len(exprInputs))
 	)

-	for i := range bExpr {
-		res = res.Sub(bExpr[i])
+	for i := range exprInputs {
+		magnitudes[i] = -1
 	}
+	magnitudes[0] = 1

-	return res
+	return NewLinComb(exprInputs, magnitudes)
 }

 // Neg returns an expression representing the negation of an expression or of
@@ -119,7 +121,7 @@ func Square(x any) *Expression {
 		return nil
 	}

-	return intoExpr(x).Square()
+	return NewProduct([]*Expression{intoExpr(x)}, []int{2})
 }

 // Pow returns an expression representing the raising to the power "n" of an
diff --git a/prover/symbolic/expression.go b/prover/symbolic/expression.go
index 9b293f59..ab74a333 100644
--- a/prover/symbolic/expression.go
+++ b/prover/symbolic/expression.go
@@ -1,6 +1,7 @@
 package symbolic

 import (
+	"encoding/json"
 	"fmt"
 	"reflect"
 	"sync"
@@ -284,5 +285,14 @@ func (e *Expression) SameWithNewChildren(newChildren []*Expression) *Expression
 	default:
 		panic("unexpected type: " + reflect.TypeOf(op).String())
 	}
-
+}
+
+// MarshalJSONString returns a JSON string representation
+// of the expression.
+func (e *Expression) MarshalJSONString() string {
+	js, jsErr := json.MarshalIndent(e, "", " ")
+	if jsErr != nil {
+		utils.Panic("failed to marshal expression: %v", jsErr)
+	}
+	return string(js)
 }
diff --git a/prover/symbolic/expression_test.go b/prover/symbolic/expression_test.go
index 0f833f4d..d06b5184 100644
--- a/prover/symbolic/expression_test.go
+++ b/prover/symbolic/expression_test.go
@@ -6,6 +6,7 @@ import (
 	"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
 	"github.com/consensys/linea-monorepo/prover/maths/field"
 	"github.com/consensys/linea-monorepo/prover/utils/collection"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )

@@ -96,7 +97,7 @@ func TestLCConstruction(t *testing.T) {
 	x := NewDummyVar("x")
 	y := NewDummyVar("y")

-	{
+	t.Run("simple-addition", func(t *testing.T) {
 		/*
 			Test t a simple case of addition
 		*/
@@ -108,23 +109,23 @@ func TestLCConstruction(t *testing.T) {
 		require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs))
 		require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1)
 		require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1)
-	}
+	})

-	{
+	t.Run("x-y-x", func(t *testing.T) {
 		/*
 			Adding y then substracting x should give back (y)
 		*/
 		expr1 := x.Add(y).Sub(x)
 		require.Equal(t, expr1, y)
-	}
+	})

-	{
+	t.Run("(-x)+x+y", func(t *testing.T) {
 		/*
 			Same thing when using Neg
 		*/
 		expr := x.Neg().Add(x).Add(y)
-		require.Equal(t, expr, y)
-	}
+		assert.Equal(t, expr, y)
+	})

 }

@@ -133,7 +134,7 @@ func TestProductConstruction(t *testing.T) {
 	x := NewDummyVar("x")
 	y := NewDummyVar("y")

-	{
+	t.Run("x * y", func(t *testing.T) {
 		/*
 			Test t a simple case of addition
 		*/
@@ -145,9 +146,9 @@ func TestProductConstruction(t *testing.T) {
 		require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
 		require.Equal(t, expr1.Operator.(Product).Exponents[0], 1)
 		require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
-	}
+	})

-	{
+	t.Run("x * y * x", func(t *testing.T) {
 		/*
 			Adding y then substracting x should give back (y)
 		*/
@@ -158,9 +159,9 @@ func TestProductConstruction(t *testing.T) {
 		require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
 		require.Equal(t, expr1.Operator.(Product).Exponents[0], 2)
 		require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
-	}
+	})

-	{
+	t.Run("x^2", func(t *testing.T) {
 		/*
 			When we square
 		*/
@@ -169,6 +170,6 @@
 		expr := x.Square()
 		require.Equal(t, expr.Children[0], x)
 		require.Equal(t, 1, len(expr.Operator.(Product).Exponents))
 		require.Equal(t, expr.Operator.(Product).Exponents[0], 2)
-	}
+	})

 }
diff --git a/prover/symbolic/product.go b/prover/symbolic/product.go
index 0e646dd1..b276d68d 100644
--- a/prover/symbolic/product.go
+++ b/prover/symbolic/product.go
@@ -95,18 +95,30 @@ func NewProduct(items []*Expression, exponents []int) *Expression {
 	e := &Expression{
 		Operator: Product{Exponents: exponents},
 		Children: items,
+		ESHash:   field.One(),
 	}

-	// Now we need to assign the ESH
-	eshashes := make([]sv.SmartVector, len(e.Children))
 	for i := range e.Children {
-		eshashes[i] = sv.NewConstant(e.Children[i].ESHash, 1)
-	}
-
-	if len(items) > 0 {
-		// The cast back to sv.Constant is no important functionally but is an easy
-		// sanity check.
-		e.ESHash = e.Operator.Evaluate(eshashes).(*sv.Constant).Get(0)
+		var tmp field.Element
+		switch {
+		case exponents[i] == 1:
+			e.ESHash.Mul(&e.ESHash, &e.Children[i].ESHash)
+		case exponents[i] == 2:
+			tmp.Square(&e.Children[i].ESHash)
+			e.ESHash.Mul(&e.ESHash, &tmp)
+		case exponents[i] == 3:
+			tmp.Square(&e.Children[i].ESHash)
+			tmp.Mul(&tmp, &e.Children[i].ESHash)
+			e.ESHash.Mul(&e.ESHash, &tmp)
+		case exponents[i] == 4:
+			tmp.Square(&e.Children[i].ESHash)
+			tmp.Square(&tmp)
+			e.ESHash.Mul(&e.ESHash, &tmp)
+		default:
+			exponent := big.NewInt(int64(exponents[i]))
+			tmp.Exp(e.Children[i].ESHash, exponent)
+			e.ESHash.Mul(&e.ESHash, &tmp)
+		}
 	}

 	return e
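// NewProduct above folds each child's ESHash into the running product using
// dedicated Mul/Square chains for exponents 1 to 4 and a generic Exp fallback
// otherwise, since expressions produced by the compiler almost always use
// small exponents. A standalone sketch of the same small-exponent addition
// chains over uint64 (field.Element is deliberately not used so that the
// example stays self-contained):
package main

import "fmt"

// powSmall raises x to the power e with the same chains as the patch:
// e=2 -> one squaring, e=3 -> square then multiply, e=4 -> two squarings.
func powSmall(x uint64, e int) uint64 {
	switch e {
	case 1:
		return x
	case 2:
		return x * x
	case 3:
		t := x * x
		return t * x
	case 4:
		t := x * x
		return t * t
	default:
		// Generic fallback, analogous to field.Element.Exp.
		res := uint64(1)
		for i := 0; i < e; i++ {
			res *= x
		}
		return res
	}
}

func main() {
	fmt.Println(powSmall(3, 4)) // 81
}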
diff --git a/prover/symbolic/simplification.go b/prover/symbolic/simplification.go
index 7a51a697..a132840a 100644
--- a/prover/symbolic/simplification.go
+++ b/prover/symbolic/simplification.go
@@ -78,22 +78,38 @@ func regroupTerms(magnitudes []int, children []*Expression) (
 // the linear combination. This function is used both for simplifying [LinComb]
 // expressions and for simplifying [Product]. "magnitude" denotes either the
 // coefficient for LinComb or exponents for Product.
+//
+// The function takes ownership of the provided slices.
 func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) {

 	if len(magnitudes) != len(children) {
 		panic("magnitudes and children don't have the same length")
 	}

-	cleanChildren = make([]*Expression, 0, len(children))
-	cleanMagnitudes = make([]int, 0, len(children))
-
+	// cleanChildren and cleanMagnitudes are initialized lazily to
+	// avoid unnecessarily allocating memory. The underlying assumption
+	// is that the application will almost never pass zero as a
+	// magnitude.
 	for i, c := range magnitudes {
-		if c != 0 {
+
+		if c == 0 && cleanChildren == nil {
+			cleanChildren = make([]*Expression, i, len(children))
+			cleanMagnitudes = make([]int, i, len(children))
+			copy(cleanChildren, children[:i])
+			copy(cleanMagnitudes, magnitudes[:i])
+		}
+
+		if c != 0 && cleanChildren != nil {
 			cleanMagnitudes = append(cleanMagnitudes, magnitudes[i])
 			cleanChildren = append(cleanChildren, children[i])
 		}
 	}

+	if cleanChildren == nil {
+		cleanChildren = children
+		cleanMagnitudes = magnitudes
+	}
+
 	return cleanMagnitudes, cleanChildren
 }

@@ -108,14 +124,16 @@ func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes
 // The caller passes a target operator which may be any value of type either
 // [LinComb] or [Product]. Any other type yields a panic error.
 func expandTerms(op Operator, magnitudes []int, children []*Expression) (
-	expandedMagnitudes []int,
-	expandedExpression []*Expression,
+	[]int,
+	[]*Expression,
 ) {

 	var (
-		opIsProd    bool
-		opIsLinC    bool
-		numChildren = len(children)
+		opIsProd        bool
+		opIsLinC        bool
+		numChildren     = len(children)
+		totalReturnSize = 0
+		needExpand      = false
 	)

 	switch op.(type) {
@@ -133,9 +151,34 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
 		panic("incompatible number of children and magnitudes")
 	}

-	// The capacity allocation is purely heuristic
-	expandedExpression = make([]*Expression, 0, 2*len(magnitudes))
-	expandedMagnitudes = make([]int, 0, 2*len(magnitudes))
+	// This loop performs a first scan of the children to compute the total
+	// number of elements to allocate.
+	for i, child := range children {
+
+		switch child.Operator.(type) {
+		case Product, *Product:
+			if opIsProd {
+				needExpand = true
+				totalReturnSize += len(children[i].Children)
+				continue
+			}
+		case LinComb, *LinComb:
+			if opIsLinC {
+				needExpand = true
+				totalReturnSize += len(children[i].Children)
+				continue
+			}
+		}
+
+		totalReturnSize++
+	}
+
+	if !needExpand {
+		return magnitudes, children
+	}
+
+	expandedMagnitudes := make([]int, 0, totalReturnSize)
+	expandedExpression := make([]*Expression, 0, totalReturnSize)

 	for i := 0; i < numChildren; i++ {

@@ -175,7 +218,6 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
 		for k := range child.Children {
 			expandedExpression = append(expandedExpression, child.Children[k])
 			expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k])
-
 		}
 		continue
 	}
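// removeZeroCoeffs above postpones allocating its output slices until a zero
// magnitude is actually encountered, and hands back the input slices untouched
// in the common case. A self-contained sketch of that copy-on-first-hit
// filtering pattern; the helper name and the int-slice use case are
// illustrative only.
package main

import "fmt"

// dropZeros returns vs without its zero entries. If vs contains no zero, the
// original slice is returned as-is and nothing is allocated.
func dropZeros(vs []int) []int {
	var out []int
	for i, v := range vs {
		// First zero encountered: allocate and copy the prefix scanned so far.
		if v == 0 && out == nil {
			out = make([]int, i, len(vs))
			copy(out, vs[:i])
		}
		// Once allocated, keep appending only the non-zero entries.
		if v != 0 && out != nil {
			out = append(out, v)
		}
	}
	if out == nil {
		return vs // fast path: no zero seen, no allocation
	}
	return out
}

func main() {
	fmt.Println(dropZeros([]int{1, 2, 3}))    // [1 2 3], returned without copying
	fmt.Println(dropZeros([]int{1, 0, 3, 0})) // [1 3]
}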
diff --git a/prover/symbolic/simplify/factor.go b/prover/symbolic/simplify/factor.go
index b6135309..7b37214b 100644
--- a/prover/symbolic/simplify/factor.go
+++ b/prover/symbolic/simplify/factor.go
@@ -1,7 +1,6 @@
 package simplify

 import (
-	"errors"
 	"fmt"
 	"math"
 	"sort"
@@ -46,9 +45,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
 	// factoring possibilities. There is also a bound on the loop to
 	// prevent infinite loops.
 	//
-	// The choice of 1000 is purely heuristic and is not meant to be
+	// The choice of 100 is purely heuristic and is not meant to be
 	// actually met.
-	for k := 0; k < 1000; k++ {
+	for k := 0; k < 100; k++ {
 		_, ok := new.Operator.(sym.LinComb)
 		if !ok {
 			return new
@@ -103,13 +102,18 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
 // children that are already in the children set.
 func rankChildren(
 	parents []*sym.Expression,
-	childrenSet map[field.Element]*sym.Expression,
+	childrenSet map[uint64]*sym.Expression,
 ) []*sym.Expression {

 	// List all the grand-children of the expression whose parents are
 	// products and counts the number of occurences by summing the exponents.
-	relevantGdChildrenCnt := map[field.Element]int{}
-	uniqueChildrenList := make([]*sym.Expression, 0)
+	// As an optimization, the map is addressed using the first uint64 limb
+	// of the element. We consider this good enough to avoid collisions: if
+	// a collision does happen, it gets caught by the validation checks at
+	// the end of the factorization routine. The preallocation value is
+	// purely heuristic and aims at avoiding successive allocations.
+	var relevantGdChildrenCnt map[uint64]int
+	var uniqueChildrenList []*sym.Expression

 	for _, p := range parents {

 		// If it's in the group, it does not count. We can't add it a second
 		// time.
-		if _, ok := childrenSet[c.ESHash]; ok {
+		if _, ok := childrenSet[c.ESHash[0]]; ok {
 			continue
 		}

-		if _, ok := relevantGdChildrenCnt[c.ESHash]; !ok {
-			relevantGdChildrenCnt[c.ESHash] = 0
+		if relevantGdChildrenCnt == nil {
+			relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
+			uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
+		}
+
+		if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
+			relevantGdChildrenCnt[c.ESHash[0]] = 0
 			uniqueChildrenList = append(uniqueChildrenList, c)
 		}

-		relevantGdChildrenCnt[c.ESHash]++
+		relevantGdChildrenCnt[c.ESHash[0]]++
 		}
 	}

@@ -144,7 +153,7 @@ func rankChildren(
 		x := uniqueChildrenList[i].ESHash
 		y := uniqueChildrenList[j].ESHash
 		// We want to a decreasing order
-		return relevantGdChildrenCnt[x] > relevantGdChildrenCnt[y]
+		return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
 	})

 	return uniqueChildrenList
@@ -155,76 +164,54 @@ func rankChildren(
 // than one parent. The finding is based on a greedy algorithm. We iteratively
 // add nodes in the group so that the number of common parents decreases as
 // slowly as possible.
-func findGdChildrenGroup(expr *sym.Expression) map[field.Element]*sym.Expression {
+func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {

 	curParents := expr.Children
-	childrenSet := map[field.Element]*sym.Expression{}
+	childrenSet := map[uint64]*sym.Expression{}

-	for {
-		ranked := rankChildren(curParents, childrenSet)
+	ranked := rankChildren(curParents, childrenSet)

-		// Can happen when we have a lincomb of lincomb. Ideally they should be
-		// merged during canonization.
-		if len(ranked) == 0 {
-			return childrenSet
-		}
+	// Can happen when we have a lincomb of lincomb. Ideally they should be
+	// merged during canonization.
+	if len(ranked) == 0 {
+		return childrenSet
+	}

-		best := ranked[0]
-		newChildrenSet := copyMap(childrenSet)
-		newChildrenSet[best.ESHash] = best
-		newParents := getCommonProdParentOfCs(newChildrenSet, curParents)
+	for i := range ranked {
+
+		best := ranked[i]
+		childrenSet[best.ESHash[0]] = best
+		curParents = filterParentsWithChildren(curParents, best.ESHash)

 		// Can't grow the set anymore
-		if len(newParents) <= 1 {
+		if len(curParents) <= 1 {
+			delete(childrenSet, best.ESHash[0])
 			return childrenSet
 		}

-		childrenSet = newChildrenSet
-		curParents = newParents
-
 		logrus.Tracef(
 			"find groups, so far we have %v parents and %v siblings",
 			len(curParents), len(childrenSet))
-
-		// Sanity-check
-		if err := parentsMustHaveAllChildren(curParents, childrenSet); err != nil {
-			panic(err)
-		}
 	}
+
+	return childrenSet
 }

-// getCommonProdParentOfCs returns the parents that have all cs as children and
-// that are themselves children of gdp (grandparent). The parents must be of
-// type product however.
-func getCommonProdParentOfCs(
-	cs map[field.Element]*sym.Expression,
+// filterParentsWithChildren returns a filtered list of parents who have at
+// least one child with the given ESHash. The function allocates a new list
+// of parents and returns it without mutating the original list.
+func filterParentsWithChildren(
 	parents []*sym.Expression,
+	childEsh field.Element,
 ) []*sym.Expression {

-	res := []*sym.Expression{}
-
+	res := make([]*sym.Expression, 0, len(parents))
 	for _, p := range parents {
-		prod, ok := p.Operator.(sym.Product)
-		if !ok {
-			continue
-		}
-
-		// Account for the fact that p may contain duplicates. So we cannot
-		// just use a counter here.
-		founds := map[field.Element]struct{}{}
-		for i, c := range p.Children {
-			if prod.Exponents[i] == 0 {
-				continue
+		for _, c := range p.Children {
+			if c.ESHash == childEsh {
+				res = append(res, p)
+				break
 			}
-
-			if _, inside := cs[c.ESHash]; inside {
-				// logrus.Tracef("%v contains %v", p.ESHash.String(), c.ESHash.String())
-				founds[c.ESHash] = struct{}{}
-			}
-		}
-
-		if len(founds) == len(cs) {
-			res = append(res, p)
 		}
 	}

@@ -235,22 +222,26 @@
 // determine the best common factor.
 func factorLinCompFromGroup(
 	lincom *sym.Expression,
-	group map[field.Element]*sym.Expression,
+	group map[uint64]*sym.Expression,
 ) *sym.Expression {

 	var (
+
+		// numTerms indicates the number of children in the linear-combination
+		numTerms = len(lincom.Children)
+
 		lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs

 		// Build the common term by taking the max of the exponents
 		exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)

 		// Separate the non-factored terms
-		nonFactoredTerms  = []*sym.Expression{}
-		nonFactoredCoeffs = []int{}
+		nonFactoredTerms  = make([]*sym.Expression, 0, numTerms)
+		nonFactoredCoeffs = make([]int, 0, numTerms)

 		// The factored terms of the linear combination divided by the common
 		// group factor
-		factoredTerms  = []*sym.Expression{}
-		factoredCoeffs = []int{}
+		factoredTerms  = make([]*sym.Expression, 0, numTerms)
+		factoredCoeffs = make([]int, 0, numTerms)
 	)

 	numFactors := 0
@@ -295,7 +286,7 @@
 //
 // Fortunately, this is guaranteed if the expression was constructed via
 // [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
-func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
+func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
 	factored *sym.Expression,
 	success bool,
 ) {
@@ -310,7 +301,7 @@
 	numMatches := 0

 	for i, c := range e.Children {
-		eig, found := exponentsOfGroup[c.ESHash]
+		eig, found := exponentsOfGroup[c.ESHash[0]]
 		if !found {
 			continue
 		}
@@ -335,14 +326,14 @@
 // have the whole group as children.
 func optimRegroupExponents(
 	parents []*sym.Expression,
-	group map[field.Element]*sym.Expression,
+	group map[uint64]*sym.Expression,
 ) (
-	exponentMap map[field.Element]int,
+	exponentMap map[uint64]int,
 	groupedTerm *sym.Expression,
 ) {

-	exponentMap = map[field.Element]int{}
-	canonTermList := make([]*sym.Expression, 0) // built in deterministic order
+	exponentMap = make(map[uint64]int, 16)
+	canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order

 	for _, p := range parents {

@@ -354,10 +345,10 @@
 		// Used to sanity-check that all the nodes of the group have been
 		// reached through this parent.
-		matched := map[field.Element]int{}
+		matched := make(map[uint64]int, len(p.Children))

 		for i, c := range p.Children {
-			if _, ingroup := group[c.ESHash]; !ingroup {
+			if _, ingroup := group[c.ESHash[0]]; !ingroup {
 				continue
 			}

@@ -365,16 +356,16 @@
 				panic("The expression is not canonic")
 			}

-			_, initialized := exponentMap[c.ESHash]
+			_, initialized := exponentMap[c.ESHash[0]]
 			if !initialized {
 				// Max int is used as a placeholder. It will be replaced anytime
 				// we wall utils.Min(exponentMap[h], n) where n is actually an
 				// exponent.
-				exponentMap[c.ESHash] = math.MaxInt
+				exponentMap[c.ESHash[0]] = math.MaxInt
 				canonTermList = append(canonTermList, c)
 			}

-			matched[c.ESHash] = exponents[i]
+			matched[c.ESHash[0]] = exponents[i]
 		}

 		if len(matched) != len(group) {
@@ -391,48 +382,8 @@
 	canonExponents := []int{}
 	for _, e := range canonTermList {
-		canonExponents = append(canonExponents, exponentMap[e.ESHash])
+		canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
 	}

 	return exponentMap, sym.NewProduct(canonTermList, canonExponents)
 }
-
-// parentsMustHaveAllChildren returns an error if at least one of the parents
-// is missing one children from the set. This function is used internally to
-// enforce invariants throughout the simplification routines.
-func parentsMustHaveAllChildren[T any](
-	parents []*sym.Expression,
-	childrenSet map[field.Element]T,
-) (resErr error) {
-
-	for parentID, p := range parents {
-		// Account for the fact that the node may contain duplicates of the node
-		// we are looking for.
-		founds := map[field.Element]struct{}{}
-		for _, c := range p.Children {
-			if _, ok := childrenSet[c.ESHash]; ok {
-				founds[c.ESHash] = struct{}{}
-			}
-		}
-
-		if len(founds) != len(childrenSet) {
-			resErr = errors.Join(
-				resErr,
-				fmt.Errorf(
-					"parent num %v is incomplete : found = %d/%d",
-					parentID, len(founds), len(childrenSet),
-				),
-			)
-		}
-	}
-
-	return resErr
-}
-
-func copyMap[K comparable, V any](m map[K]V) map[K]V {
-	res := make(map[K]V, len(m))
-	for k, v := range m {
-		res[k] = v
-	}
-	return res
-}
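// Several routines above switch from map[field.Element] keys to map[uint64]
// keys built from the first limb of the expression's ESHash (c.ESHash[0]).
// A smaller key hashes and compares faster; the unlikely collisions are
// tolerated because the factorization result is validated afterwards. A
// self-contained sketch of the idea with a hypothetical 4-limb digest type:
package main

import "fmt"

// digest mimics a field element laid out as four 64-bit limbs.
type digest [4]uint64

type node struct {
	id   string
	hash digest
}

func main() {
	nodes := []node{
		{id: "a", hash: digest{1, 0, 0, 0}},
		{id: "b", hash: digest{2, 0, 0, 0}},
		{id: "a-bis", hash: digest{1, 0, 0, 0}},
	}

	// Key the count map by the first limb only. Two distinct digests sharing
	// the same first limb would be conflated, which is why the caller must
	// re-validate whatever it computes from this map.
	counts := make(map[uint64]int, len(nodes))
	for _, n := range nodes {
		counts[n.hash[0]]++
	}

	fmt.Println(counts) // map[1:2 2:1]
}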
diff --git a/prover/symbolic/simplify/factor_test.go b/prover/symbolic/simplify/factor_test.go
index 60bb8ba7..87bb0fbe 100644
--- a/prover/symbolic/simplify/factor_test.go
+++ b/prover/symbolic/simplify/factor_test.go
@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"testing"

-	"github.com/consensys/linea-monorepo/prover/maths/field"
 	sym "github.com/consensys/linea-monorepo/prover/symbolic"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -60,12 +59,6 @@ func TestIsFactored(t *testing.T) {
 			IsFactored: true,
 			Factor:     a,
 		},
-		{
-			Expr:       sym.Mul(a, a, b),
-			By:         a,
-			IsFactored: true,
-			Factor:     sym.Mul(a, b),
-		},
 	}

 	for i, tc := range testcases {
@@ -74,13 +67,13 @@ func TestIsFactored(t *testing.T) {
 		// Build the group exponent map. If by is a product, we use directly
 		// the exponents it contains. Otherwise, we say this is a single
 		// term product with an exponent of 1.
-		groupedExp := map[field.Element]int{}
+		groupedExp := map[uint64]int{}
 		if byProd, ok := tc.By.Operator.(sym.Product); ok {
 			for i, ex := range byProd.Exponents {
-				groupedExp[tc.By.Children[i].ESHash] = ex
+				groupedExp[tc.By.Children[i].ESHash[0]] = ex
 			}
 		} else {
-			groupedExp[tc.By.ESHash] = 1
+			groupedExp[tc.By.ESHash[0]] = 1
 		}

 		factored, isFactored := isFactored(tc.Expr, groupedExp)
@@ -207,15 +200,27 @@ func TestFactorLinCompFromGroup(t *testing.T) {
 	for i, testCase := range testCases {
 		t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {

-			group := map[field.Element]*sym.Expression{}
+			group := map[uint64]*sym.Expression{}
 			for _, e := range testCase.Group {
-				group[e.ESHash] = e
+				group[e.ESHash[0]] = e
 			}

 			factored := factorLinCompFromGroup(testCase.LinComb, group)
-			require.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
+			assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
+
+			if t.Failed() {
+				fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
+				fmt.Printf("factored=%v\n", factored.MarshalJSONString())
+				t.Fatal()
+			}
+
 			require.NoError(t, factored.Validate())
 			assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
+
+			if t.Failed() {
+				fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
+				fmt.Printf("factored=%v\n", factored.MarshalJSONString())
+			}
 		})
 	}
diff --git a/prover/symbolic/simplify/rmpolyeval.go b/prover/symbolic/simplify/rmpolyeval.go
index a3d47149..f9155aa2 100644
--- a/prover/symbolic/simplify/rmpolyeval.go
+++ b/prover/symbolic/simplify/rmpolyeval.go
@@ -21,24 +21,17 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
 		return oldExpr // Handle edge case where there are no coefficients
 	}

-	acc := cs[0]
-
 	// Precompute powers of x
-	powersOfX := make([]*sym.Expression, len(cs))
-	powersOfX[0] = x
-	for i := 1; i < len(cs); i++ {
+	monomialTerms := make([]any, len(cs))
+	for i := 0; i < len(cs); i++ {
 		// We don't use the default constructor because it will collapse the
 		// intermediate terms into a single term. The intermediates are useful because
 		// they tell the evaluator to reuse the intermediate terms instead of
 		// computing x^i for every term.
-		powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1})
+		monomialTerms[i] = any(sym.NewProduct([]*sym.Expression{cs[i], x}, []int{1, i}))
 	}

-	for i := 1; i < len(cs); i++ {
-		// Here we want to use the default constructor to ensure that we
-		// will have a merged sum at the end.
-		acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
-	}
+	acc := sym.Add(monomialTerms...)

 	if oldExpr.ESHash != acc.ESHash {
 		panic("ESH was altered")
diff --git a/prover/utils/sort.go b/prover/utils/sort.go
new file mode 100644
index 00000000..6294df9e
--- /dev/null
+++ b/prover/utils/sort.go
@@ -0,0 +1,22 @@
+package utils
+
+// GenSorter is a generic implementation of [sort.Interface] that lets the
+// user provide the methods directly as closures, without needing to
+// implement a new custom type.
+type GenSorter struct {
+	LenFn  func() int
+	SwapFn func(int, int)
+	LessFn func(int, int) bool
+}
+
+func (s GenSorter) Len() int {
+	return s.LenFn()
+}
+
+func (s GenSorter) Swap(i, j int) {
+	s.SwapFn(i, j)
+}
+
+func (s GenSorter) Less(i, j int) bool {
+	return s.LessFn(i, j)
+}
diff --git a/prover/zkevm/full.go b/prover/zkevm/full.go
index 438b6731..ff95530e 100644
--- a/prover/zkevm/full.go
+++ b/prover/zkevm/full.go
@@ -105,7 +105,7 @@ func FullZkEvm(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
 	onceFullZkEvm.Do(func() {

 		// Initialize the Full zkEVM arithmetization
-		fullZkEvm = fullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
+		fullZkEvm = FullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
 	})

 	return fullZkEvm
@@ -115,13 +115,16 @@ func FullZkEVMCheckOnly(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
 	onceFullZkEvmCheckOnly.Do(func() {

 		// Initialize the Full zkEVM arithmetization
-		fullZkEvmCheckOnly = fullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
+		fullZkEvmCheckOnly = FullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
 	})

 	return fullZkEvmCheckOnly
 }

-func fullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
+// FullZKEVMWithSuite returns a compiled zkEVM with the given compilation suite.
+// It can be used to benchmark the compilation time of the zkEVM and helps with
+// performance optimization.
+func FullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {

 	// @Alex: only set mandatory parameters here. aka, the one that are not
 	// actually feature-gated.
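// A usage sketch for the GenSorter helper added in prover/utils/sort.go: three
// closures are adapted into a sort.Interface, so a one-off ordering over
// parallel slices does not require a named sortable type. The struct is
// re-declared locally (mirroring utils.GenSorter) so that the example runs
// standalone; the data and the ordering are illustrative only.
package main

import (
	"fmt"
	"sort"
)

type GenSorter struct {
	LenFn  func() int
	SwapFn func(int, int)
	LessFn func(int, int) bool
}

func (s GenSorter) Len() int           { return s.LenFn() }
func (s GenSorter) Swap(i, j int)      { s.SwapFn(i, j) }
func (s GenSorter) Less(i, j int) bool { return s.LessFn(i, j) }

func main() {
	names := []string{"gamma", "alpha", "beta"}
	scores := []int{3, 1, 2}

	// Sort both slices together, ordered by decreasing score, without
	// defining a dedicated struct type.
	sort.Sort(GenSorter{
		LenFn: func() int { return len(names) },
		SwapFn: func(i, j int) {
			names[i], names[j] = names[j], names[i]
			scores[i], scores[j] = scores[j], scores[i]
		},
		LessFn: func(i, j int) bool { return scores[i] > scores[j] },
	})

	fmt.Println(names, scores) // [gamma beta alpha] [3 2 1]
}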