mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-08 03:43:56 -05:00
Prover(perf): faster global constraints compilation (#704)
* bench(global): adds a benchmark for the global constraint compiler * perf(merging): accumulates the factors before creating the expression * perf(product): computes the ESH without using a smart-vector * perf(factor): preallocations in the factorization algorithm * perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs * perfs(alloc): counts the ret elements before returning in expandTerms to minimze the number of allocations. * perf(factor): use an integer map instead of a field.Element map when possible * fixup(expands): fix the skip condition for term expansion * perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb * feat(repr): adds a json repr function to help debugging * test(constructor): cleans the test of the constructors * perf(factor): address maps using the first limb of a field.Element instead of the full field.Element * fixup(commit): adds missing file in previous commit * perf(factor): reduce the number of calls to rankChildren * perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims * perf(factors): use a counter in getCommonProdParentOfCs * perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function * clean(factor): remove unneeded function and imports * feat(utils): adds a generic sort interface implementation * perf(rankChildren): lazy allocation of the map to save on allocations * perf(factorize): reduces the loop-bound for factorizeExpression * (chore): fix a missing argument and format gofmt * feat: readd test --------- Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com> Co-authored-by: gusiri <dreamerty@postech.ac.kr>
This commit is contained in:
75
prover/protocol/compiler/globalcs/global_perf_test.go
Normal file
75
prover/protocol/compiler/globalcs/global_perf_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package globalcs_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/config"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/cleanup"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
|
||||
"github.com/consensys/linea-monorepo/prover/zkevm"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// BenchmarkGlobalConstraint benchmarks the global constraints compiler against the
|
||||
// actual zk-evm constraint system.
|
||||
func BenchmarkGlobalConstraintWoArtefacts(b *testing.B) {
|
||||
|
||||
// partialSuite corresponds to the actual compilation suite of
|
||||
// the full zk-evm down to the point where the global constraints
|
||||
// are compiled.
|
||||
partialSuite := []func(*wizard.CompiledIOP){
|
||||
mimc.CompileMiMC,
|
||||
specialqueries.RangeProof,
|
||||
specialqueries.CompileFixedPermutations,
|
||||
permutation.CompileGrandProduct,
|
||||
lookup.CompileLogDerivative,
|
||||
innerproduct.Compile,
|
||||
sticker.Sticker(1<<10, 1<<19),
|
||||
splitter.SplitColumns(1 << 19),
|
||||
cleanup.CleanUp,
|
||||
localcs.Compile,
|
||||
globalcs.Compile,
|
||||
}
|
||||
|
||||
// In order to load the config we need to position ourselves in the root
|
||||
// folder.
|
||||
_ = os.Chdir("../../..")
|
||||
defer os.Chdir("protocol/compiler/globalcs")
|
||||
|
||||
// config corresponds to the config we use on sepolia
|
||||
cfg, cfgErr := config.NewConfigFromFile("./config/config-sepolia-full.toml")
|
||||
if cfgErr != nil {
|
||||
b.Fatalf("could not find the config: %v", cfgErr)
|
||||
}
|
||||
|
||||
// Shut the logger to not overwhelm the benchmark output
|
||||
logrus.SetLevel(logrus.PanicLevel)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for c_ := 0; c_ < b.N; c_++ {
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
// Removes the artefacts to
|
||||
if err := os.RemoveAll("/tmp/prover-artefacts"); err != nil {
|
||||
b.Fatalf("could not remove the artefacts: %v", err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
_ = zkevm.FullZKEVMWithSuite(&cfg.TracesLimits, partialSuite, cfg)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
|
||||
domainSize = cs.DomainSize
|
||||
x = variables.NewXVar()
|
||||
omega = fft.GetOmega(domainSize)
|
||||
// factors is a list of expression to multiply to obtain the return expression. It
|
||||
// is initialized with "only" the initial expression and we iteratively add the
|
||||
// terms (X-i) to it. At the end, we call [sym.Mul] a single time. This structure
|
||||
// is important because it [sym.Mul] operates a sequence of optimization routines
|
||||
// that are everytime we call it. In an earlier version, we were calling [sym.Mul]
|
||||
// for every factor and this were making the function have a quadratic/cubic runtime.
|
||||
factors = make([]any, 0, utils.Abs(cancelRange.Max)+utils.Abs(cancelRange.Min)+1)
|
||||
)
|
||||
|
||||
// cancelExprAtPoint cancels the expression at a particular position
|
||||
cancelExprAtPoint := func(expr *symbolic.Expression, i int) *symbolic.Expression {
|
||||
factors = append(factors, res)
|
||||
|
||||
// appendFactor appends an expressions representing $X-\rho^i$ to [factors]
|
||||
appendFactor := func(i int) {
|
||||
var root field.Element
|
||||
root.Exp(omega, big.NewInt(int64(i)))
|
||||
return symbolic.Mul(expr, symbolic.Sub(x, root))
|
||||
factors = append(factors, symbolic.Sub(x, root))
|
||||
}
|
||||
|
||||
if cancelRange.Min < 0 {
|
||||
// Cancels the expression on the range [0, -cancelRange.Min)
|
||||
for i := 0; i < -cancelRange.Min; i++ {
|
||||
res = cancelExprAtPoint(res, i)
|
||||
appendFactor(i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,11 +175,18 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
|
||||
// Cancels the expression on the range (N-cancelRange.Max-1, N-1]
|
||||
for i := 0; i < cancelRange.Max; i++ {
|
||||
point := domainSize - i - 1 // point at which we want to cancel the constraint
|
||||
res = cancelExprAtPoint(res, point)
|
||||
appendFactor(point)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
// When factors is of length 1, it means the expression does not need to be
|
||||
// bound-cancelled and we can directly return the original expression
|
||||
// without calling [sym.Mul].
|
||||
if len(factors) == 1 {
|
||||
return res
|
||||
}
|
||||
|
||||
return symbolic.Mul(factors...)
|
||||
}
|
||||
|
||||
// getExprRatio computes the ratio of the expression and ceil to the next power
|
||||
|
||||
@@ -27,12 +27,12 @@ func Add(inputs ...any) *Expression {
|
||||
|
||||
exprInputs := intoExprSlice(inputs...)
|
||||
|
||||
res := exprInputs[0]
|
||||
for i := 1; i < len(exprInputs); i++ {
|
||||
res = res.Add(exprInputs[i])
|
||||
magnitudes := make([]int, len(exprInputs))
|
||||
for i := range exprInputs {
|
||||
magnitudes[i] = 1
|
||||
}
|
||||
|
||||
return res
|
||||
return NewLinComb(exprInputs, magnitudes)
|
||||
}
|
||||
|
||||
// Mul constructs a symbolic expression representing the product of its inputs.
|
||||
@@ -57,12 +57,12 @@ func Mul(inputs ...any) *Expression {
|
||||
|
||||
exprInputs := intoExprSlice(inputs...)
|
||||
|
||||
res := exprInputs[0]
|
||||
for i := 1; i < len(exprInputs); i++ {
|
||||
res = res.Mul(exprInputs[i])
|
||||
magnitudes := make([]int, len(exprInputs))
|
||||
for i := range exprInputs {
|
||||
magnitudes[i] = 1
|
||||
}
|
||||
|
||||
return res
|
||||
return NewProduct(exprInputs, magnitudes)
|
||||
}
|
||||
|
||||
// Sub returns a symbolic expression representing the subtraction of `a` by all
|
||||
@@ -82,16 +82,18 @@ func Sub(a any, bs ...any) *Expression {
|
||||
}
|
||||
|
||||
var (
|
||||
aExpr = intoExpr(a)
|
||||
bExpr = intoExprSlice(bs...)
|
||||
res = aExpr
|
||||
aExpr = intoExpr(a)
|
||||
bExpr = intoExprSlice(bs...)
|
||||
exprInputs = append([]*Expression{aExpr}, bExpr...)
|
||||
magnitudes = make([]int, len(exprInputs))
|
||||
)
|
||||
|
||||
for i := range bExpr {
|
||||
res = res.Sub(bExpr[i])
|
||||
for i := range exprInputs {
|
||||
magnitudes[i] = -1
|
||||
}
|
||||
magnitudes[0] = 1
|
||||
|
||||
return res
|
||||
return NewLinComb(exprInputs, magnitudes)
|
||||
}
|
||||
|
||||
// Neg returns an expression representing the negation of an expression or of
|
||||
@@ -119,7 +121,7 @@ func Square(x any) *Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
return intoExpr(x).Square()
|
||||
return NewProduct([]*Expression{intoExpr(x)}, []int{2})
|
||||
}
|
||||
|
||||
// Pow returns an expression representing the raising to the power "n" of an
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package symbolic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
@@ -284,5 +285,14 @@ func (e *Expression) SameWithNewChildren(newChildren []*Expression) *Expression
|
||||
default:
|
||||
panic("unexpected type: " + reflect.TypeOf(op).String())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// MarshalJSONString returns a JSON string returns a JSON string representation
|
||||
// of the expression.
|
||||
func (e *Expression) MarshalJSONString() string {
|
||||
js, jsErr := json.MarshalIndent(e, "", " ")
|
||||
if jsErr != nil {
|
||||
utils.Panic("failed to marshal expression: %v", jsErr)
|
||||
}
|
||||
return string(js)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field"
|
||||
"github.com/consensys/linea-monorepo/prover/utils/collection"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -96,7 +97,7 @@ func TestLCConstruction(t *testing.T) {
|
||||
x := NewDummyVar("x")
|
||||
y := NewDummyVar("y")
|
||||
|
||||
{
|
||||
t.Run("simple-addition", func(t *testing.T) {
|
||||
/*
|
||||
Test t a simple case of addition
|
||||
*/
|
||||
@@ -108,23 +109,23 @@ func TestLCConstruction(t *testing.T) {
|
||||
require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs))
|
||||
require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1)
|
||||
require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("x-y-x", func(t *testing.T) {
|
||||
/*
|
||||
Adding y then substracting x should give back (y)
|
||||
*/
|
||||
expr1 := x.Add(y).Sub(x)
|
||||
require.Equal(t, expr1, y)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("(-x)+x+y", func(t *testing.T) {
|
||||
/*
|
||||
Same thing when using Neg
|
||||
*/
|
||||
expr := x.Neg().Add(x).Add(y)
|
||||
require.Equal(t, expr, y)
|
||||
}
|
||||
assert.Equal(t, expr, y)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -133,7 +134,7 @@ func TestProductConstruction(t *testing.T) {
|
||||
x := NewDummyVar("x")
|
||||
y := NewDummyVar("y")
|
||||
|
||||
{
|
||||
t.Run("x * y", func(t *testing.T) {
|
||||
/*
|
||||
Test t a simple case of addition
|
||||
*/
|
||||
@@ -145,9 +146,9 @@ func TestProductConstruction(t *testing.T) {
|
||||
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
|
||||
require.Equal(t, expr1.Operator.(Product).Exponents[0], 1)
|
||||
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("x * y * x", func(t *testing.T) {
|
||||
/*
|
||||
Adding y then substracting x should give back (y)
|
||||
*/
|
||||
@@ -158,9 +159,9 @@ func TestProductConstruction(t *testing.T) {
|
||||
require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
|
||||
require.Equal(t, expr1.Operator.(Product).Exponents[0], 2)
|
||||
require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
|
||||
}
|
||||
})
|
||||
|
||||
{
|
||||
t.Run("x^2", func(t *testing.T) {
|
||||
/*
|
||||
When we square
|
||||
*/
|
||||
@@ -169,6 +170,6 @@ func TestProductConstruction(t *testing.T) {
|
||||
require.Equal(t, expr.Children[0], x)
|
||||
require.Equal(t, 1, len(expr.Operator.(Product).Exponents))
|
||||
require.Equal(t, expr.Operator.(Product).Exponents[0], 2)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -95,18 +95,30 @@ func NewProduct(items []*Expression, exponents []int) *Expression {
|
||||
e := &Expression{
|
||||
Operator: Product{Exponents: exponents},
|
||||
Children: items,
|
||||
ESHash: field.One(),
|
||||
}
|
||||
|
||||
// Now we need to assign the ESH
|
||||
eshashes := make([]sv.SmartVector, len(e.Children))
|
||||
for i := range e.Children {
|
||||
eshashes[i] = sv.NewConstant(e.Children[i].ESHash, 1)
|
||||
}
|
||||
|
||||
if len(items) > 0 {
|
||||
// The cast back to sv.Constant is no important functionally but is an easy
|
||||
// sanity check.
|
||||
e.ESHash = e.Operator.Evaluate(eshashes).(*sv.Constant).Get(0)
|
||||
var tmp field.Element
|
||||
switch {
|
||||
case exponents[i] == 1:
|
||||
e.ESHash.Mul(&e.ESHash, &e.Children[i].ESHash)
|
||||
case exponents[i] == 2:
|
||||
tmp.Square(&e.Children[i].ESHash)
|
||||
e.ESHash.Mul(&e.ESHash, &tmp)
|
||||
case exponents[i] == 3:
|
||||
tmp.Square(&e.Children[i].ESHash)
|
||||
tmp.Mul(&tmp, &e.Children[i].ESHash)
|
||||
e.ESHash.Mul(&e.ESHash, &tmp)
|
||||
case exponents[i] == 4:
|
||||
tmp.Square(&e.Children[i].ESHash)
|
||||
tmp.Square(&tmp)
|
||||
e.ESHash.Mul(&e.ESHash, &tmp)
|
||||
default:
|
||||
exponent := big.NewInt(int64(exponents[i]))
|
||||
tmp.Exp(e.Children[i].ESHash, exponent)
|
||||
e.ESHash.Mul(&e.ESHash, &tmp)
|
||||
}
|
||||
}
|
||||
|
||||
return e
|
||||
|
||||
@@ -78,22 +78,38 @@ func regroupTerms(magnitudes []int, children []*Expression) (
|
||||
// the linear combination. This function is used both for simplifying [LinComb]
|
||||
// expressions and for simplifying [Product]. "magnitude" denotes either the
|
||||
// coefficient for LinComb or exponents for Product.
|
||||
//
|
||||
// The function takes ownership of the provided slices.
|
||||
func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) {
|
||||
|
||||
if len(magnitudes) != len(children) {
|
||||
panic("magnitudes and children don't have the same length")
|
||||
}
|
||||
|
||||
cleanChildren = make([]*Expression, 0, len(children))
|
||||
cleanMagnitudes = make([]int, 0, len(children))
|
||||
|
||||
// cleanChildren and cleanMagnitudes are initialized lazily to
|
||||
// avoid unnecessarily allocating memory. The underlying assumption
|
||||
// is that the application will 99% of time never pass zero as a
|
||||
// magnitude.
|
||||
for i, c := range magnitudes {
|
||||
if c != 0 {
|
||||
|
||||
if c == 0 && cleanChildren == nil {
|
||||
cleanChildren = make([]*Expression, i, len(children))
|
||||
cleanMagnitudes = make([]int, i, len(children))
|
||||
copy(cleanChildren, children[:i])
|
||||
copy(cleanMagnitudes, magnitudes[:i])
|
||||
}
|
||||
|
||||
if c != 0 && cleanChildren != nil {
|
||||
cleanMagnitudes = append(cleanMagnitudes, magnitudes[i])
|
||||
cleanChildren = append(cleanChildren, children[i])
|
||||
}
|
||||
}
|
||||
|
||||
if cleanChildren == nil {
|
||||
cleanChildren = children
|
||||
cleanMagnitudes = magnitudes
|
||||
}
|
||||
|
||||
return cleanMagnitudes, cleanChildren
|
||||
}
|
||||
|
||||
@@ -108,14 +124,16 @@ func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes
|
||||
// The caller passes a target operator which may be any value of type either
|
||||
// [LinComb] or [Product]. Any other type yields a panic error.
|
||||
func expandTerms(op Operator, magnitudes []int, children []*Expression) (
|
||||
expandedMagnitudes []int,
|
||||
expandedExpression []*Expression,
|
||||
[]int,
|
||||
[]*Expression,
|
||||
) {
|
||||
|
||||
var (
|
||||
opIsProd bool
|
||||
opIsLinC bool
|
||||
numChildren = len(children)
|
||||
opIsProd bool
|
||||
opIsLinC bool
|
||||
numChildren = len(children)
|
||||
totalReturnSize = 0
|
||||
needExpand = false
|
||||
)
|
||||
|
||||
switch op.(type) {
|
||||
@@ -133,9 +151,34 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
|
||||
panic("incompatible number of children and magnitudes")
|
||||
}
|
||||
|
||||
// The capacity allocation is purely heuristic
|
||||
expandedExpression = make([]*Expression, 0, 2*len(magnitudes))
|
||||
expandedMagnitudes = make([]int, 0, 2*len(magnitudes))
|
||||
// This loops performs a first scan of the children to compute the total
|
||||
// number of elements to allocate.
|
||||
for i, child := range children {
|
||||
|
||||
switch child.Operator.(type) {
|
||||
case Product, *Product:
|
||||
if opIsProd {
|
||||
needExpand = true
|
||||
totalReturnSize += len(children[i].Children)
|
||||
continue
|
||||
}
|
||||
case LinComb, *LinComb:
|
||||
if opIsLinC {
|
||||
needExpand = true
|
||||
totalReturnSize += len(children[i].Children)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
totalReturnSize++
|
||||
}
|
||||
|
||||
if !needExpand {
|
||||
return magnitudes, children
|
||||
}
|
||||
|
||||
expandedMagnitudes := make([]int, 0, totalReturnSize)
|
||||
expandedExpression := make([]*Expression, 0, totalReturnSize)
|
||||
|
||||
for i := 0; i < numChildren; i++ {
|
||||
|
||||
@@ -175,7 +218,6 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
|
||||
for k := range child.Children {
|
||||
expandedExpression = append(expandedExpression, child.Children[k])
|
||||
expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k])
|
||||
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package simplify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
@@ -46,9 +45,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
|
||||
// factoring possibilities. There is also a bound on the loop to
|
||||
// prevent infinite loops.
|
||||
//
|
||||
// The choice of 1000 is purely heuristic and is not meant to be
|
||||
// The choice of 100 is purely heuristic and is not meant to be
|
||||
// actually met.
|
||||
for k := 0; k < 1000; k++ {
|
||||
for k := 0; k < 100; k++ {
|
||||
_, ok := new.Operator.(sym.LinComb)
|
||||
if !ok {
|
||||
return new
|
||||
@@ -103,13 +102,18 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
|
||||
// children that are already in the children set.
|
||||
func rankChildren(
|
||||
parents []*sym.Expression,
|
||||
childrenSet map[field.Element]*sym.Expression,
|
||||
childrenSet map[uint64]*sym.Expression,
|
||||
) []*sym.Expression {
|
||||
|
||||
// List all the grand-children of the expression whose parents are
|
||||
// products and counts the number of occurences by summing the exponents.
|
||||
relevantGdChildrenCnt := map[field.Element]int{}
|
||||
uniqueChildrenList := make([]*sym.Expression, 0)
|
||||
// As an optimization the map is addressed using the first uint64 repr
|
||||
// of the element. We consider this is good enough to avoid collisions.
|
||||
// The risk if it happens is that it gets caught by the validation checks
|
||||
// at the end of the factorization routine. The preallocation value is
|
||||
// purely heuristic to avoid successive allocations.
|
||||
var relevantGdChildrenCnt map[uint64]int
|
||||
var uniqueChildrenList []*sym.Expression
|
||||
|
||||
for _, p := range parents {
|
||||
|
||||
@@ -127,16 +131,21 @@ func rankChildren(
|
||||
|
||||
// If it's in the group, it does not count. We can't add it a second
|
||||
// time.
|
||||
if _, ok := childrenSet[c.ESHash]; ok {
|
||||
if _, ok := childrenSet[c.ESHash[0]]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := relevantGdChildrenCnt[c.ESHash]; !ok {
|
||||
relevantGdChildrenCnt[c.ESHash] = 0
|
||||
if relevantGdChildrenCnt == nil {
|
||||
relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
|
||||
uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
|
||||
}
|
||||
|
||||
if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
|
||||
relevantGdChildrenCnt[c.ESHash[0]] = 0
|
||||
uniqueChildrenList = append(uniqueChildrenList, c)
|
||||
}
|
||||
|
||||
relevantGdChildrenCnt[c.ESHash]++
|
||||
relevantGdChildrenCnt[c.ESHash[0]]++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +153,7 @@ func rankChildren(
|
||||
x := uniqueChildrenList[i].ESHash
|
||||
y := uniqueChildrenList[j].ESHash
|
||||
// We want to a decreasing order
|
||||
return relevantGdChildrenCnt[x] > relevantGdChildrenCnt[y]
|
||||
return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
|
||||
})
|
||||
|
||||
return uniqueChildrenList
|
||||
@@ -155,76 +164,54 @@ func rankChildren(
|
||||
// than one parent. The finding is based on a greedy algorithm. We iteratively
|
||||
// add nodes in the group so that the number of common parents decreases as
|
||||
// slowly as possible.
|
||||
func findGdChildrenGroup(expr *sym.Expression) map[field.Element]*sym.Expression {
|
||||
func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {
|
||||
|
||||
curParents := expr.Children
|
||||
childrenSet := map[field.Element]*sym.Expression{}
|
||||
childrenSet := map[uint64]*sym.Expression{}
|
||||
|
||||
for {
|
||||
ranked := rankChildren(curParents, childrenSet)
|
||||
ranked := rankChildren(curParents, childrenSet)
|
||||
|
||||
// Can happen when we have a lincomb of lincomb. Ideally they should be
|
||||
// merged during canonization.
|
||||
if len(ranked) == 0 {
|
||||
return childrenSet
|
||||
}
|
||||
// Can happen when we have a lincomb of lincomb. Ideally they should be
|
||||
// merged during canonization.
|
||||
if len(ranked) == 0 {
|
||||
return childrenSet
|
||||
}
|
||||
|
||||
best := ranked[0]
|
||||
newChildrenSet := copyMap(childrenSet)
|
||||
newChildrenSet[best.ESHash] = best
|
||||
newParents := getCommonProdParentOfCs(newChildrenSet, curParents)
|
||||
for i := range ranked {
|
||||
|
||||
best := ranked[i]
|
||||
childrenSet[best.ESHash[0]] = best
|
||||
curParents = filterParentsWithChildren(curParents, best.ESHash)
|
||||
|
||||
// Can't grow the set anymore
|
||||
if len(newParents) <= 1 {
|
||||
if len(curParents) <= 1 {
|
||||
delete(childrenSet, best.ESHash[0])
|
||||
return childrenSet
|
||||
}
|
||||
|
||||
childrenSet = newChildrenSet
|
||||
curParents = newParents
|
||||
|
||||
logrus.Tracef(
|
||||
"find groups, so far we have %v parents and %v siblings",
|
||||
len(curParents), len(childrenSet))
|
||||
|
||||
// Sanity-check
|
||||
if err := parentsMustHaveAllChildren(curParents, childrenSet); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return childrenSet
|
||||
}
|
||||
|
||||
// getCommonProdParentOfCs returns the parents that have all cs as children and
|
||||
// that are themselves children of gdp (grandparent). The parents must be of
|
||||
// type product however.
|
||||
func getCommonProdParentOfCs(
|
||||
cs map[field.Element]*sym.Expression,
|
||||
// filterParentsWithChildren returns a filtered list of parents who have at
|
||||
// least one child with the given ESHash. The function allocates a new list
|
||||
// of parents and returns it without mutating he original list.
|
||||
func filterParentsWithChildren(
|
||||
parents []*sym.Expression,
|
||||
childEsh field.Element,
|
||||
) []*sym.Expression {
|
||||
|
||||
res := []*sym.Expression{}
|
||||
|
||||
res := make([]*sym.Expression, 0, len(parents))
|
||||
for _, p := range parents {
|
||||
prod, ok := p.Operator.(sym.Product)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Account for the fact that p may contain duplicates. So we cannot
|
||||
// just use a counter here.
|
||||
founds := map[field.Element]struct{}{}
|
||||
for i, c := range p.Children {
|
||||
if prod.Exponents[i] == 0 {
|
||||
continue
|
||||
for _, c := range p.Children {
|
||||
if c.ESHash == childEsh {
|
||||
res = append(res, p)
|
||||
break
|
||||
}
|
||||
|
||||
if _, inside := cs[c.ESHash]; inside {
|
||||
// logrus.Tracef("%v contains %v", p.ESHash.String(), c.ESHash.String())
|
||||
founds[c.ESHash] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(founds) == len(cs) {
|
||||
res = append(res, p)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,22 +222,26 @@ func getCommonProdParentOfCs(
|
||||
// determine the best common factor.
|
||||
func factorLinCompFromGroup(
|
||||
lincom *sym.Expression,
|
||||
group map[field.Element]*sym.Expression,
|
||||
group map[uint64]*sym.Expression,
|
||||
) *sym.Expression {
|
||||
|
||||
var (
|
||||
|
||||
// numTerms indicates the number of children in the linear-combination
|
||||
numTerms = len(lincom.Children)
|
||||
|
||||
lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs
|
||||
// Build the common term by taking the max of the exponents
|
||||
exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)
|
||||
|
||||
// Separate the non-factored terms
|
||||
nonFactoredTerms = []*sym.Expression{}
|
||||
nonFactoredCoeffs = []int{}
|
||||
nonFactoredTerms = make([]*sym.Expression, 0, numTerms)
|
||||
nonFactoredCoeffs = make([]int, 0, numTerms)
|
||||
|
||||
// The factored terms of the linear combination divided by the common
|
||||
// group factor
|
||||
factoredTerms = []*sym.Expression{}
|
||||
factoredCoeffs = []int{}
|
||||
factoredTerms = make([]*sym.Expression, 0, numTerms)
|
||||
factoredCoeffs = make([]int, 0, numTerms)
|
||||
)
|
||||
|
||||
numFactors := 0
|
||||
@@ -295,7 +286,7 @@ func factorLinCompFromGroup(
|
||||
//
|
||||
// Fortunately, this is guaranteed if the expression was constructed via
|
||||
// [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
|
||||
func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
|
||||
func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
|
||||
factored *sym.Expression,
|
||||
success bool,
|
||||
) {
|
||||
@@ -310,7 +301,7 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
|
||||
|
||||
numMatches := 0
|
||||
for i, c := range e.Children {
|
||||
eig, found := exponentsOfGroup[c.ESHash]
|
||||
eig, found := exponentsOfGroup[c.ESHash[0]]
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
@@ -335,14 +326,14 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
|
||||
// have the whole group as children.
|
||||
func optimRegroupExponents(
|
||||
parents []*sym.Expression,
|
||||
group map[field.Element]*sym.Expression,
|
||||
group map[uint64]*sym.Expression,
|
||||
) (
|
||||
exponentMap map[field.Element]int,
|
||||
exponentMap map[uint64]int,
|
||||
groupedTerm *sym.Expression,
|
||||
) {
|
||||
|
||||
exponentMap = map[field.Element]int{}
|
||||
canonTermList := make([]*sym.Expression, 0) // built in deterministic order
|
||||
exponentMap = make(map[uint64]int, 16)
|
||||
canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order
|
||||
|
||||
for _, p := range parents {
|
||||
|
||||
@@ -354,10 +345,10 @@ func optimRegroupExponents(
|
||||
|
||||
// Used to sanity-check that all the nodes of the group have been
|
||||
// reached through this parent.
|
||||
matched := map[field.Element]int{}
|
||||
matched := make(map[uint64]int, len(p.Children))
|
||||
|
||||
for i, c := range p.Children {
|
||||
if _, ingroup := group[c.ESHash]; !ingroup {
|
||||
if _, ingroup := group[c.ESHash[0]]; !ingroup {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -365,16 +356,16 @@ func optimRegroupExponents(
|
||||
panic("The expression is not canonic")
|
||||
}
|
||||
|
||||
_, initialized := exponentMap[c.ESHash]
|
||||
_, initialized := exponentMap[c.ESHash[0]]
|
||||
if !initialized {
|
||||
// Max int is used as a placeholder. It will be replaced anytime
|
||||
// we wall utils.Min(exponentMap[h], n) where n is actually an
|
||||
// exponent.
|
||||
exponentMap[c.ESHash] = math.MaxInt
|
||||
exponentMap[c.ESHash[0]] = math.MaxInt
|
||||
canonTermList = append(canonTermList, c)
|
||||
}
|
||||
|
||||
matched[c.ESHash] = exponents[i]
|
||||
matched[c.ESHash[0]] = exponents[i]
|
||||
}
|
||||
|
||||
if len(matched) != len(group) {
|
||||
@@ -391,48 +382,8 @@ func optimRegroupExponents(
|
||||
|
||||
canonExponents := []int{}
|
||||
for _, e := range canonTermList {
|
||||
canonExponents = append(canonExponents, exponentMap[e.ESHash])
|
||||
canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
|
||||
}
|
||||
|
||||
return exponentMap, sym.NewProduct(canonTermList, canonExponents)
|
||||
}
|
||||
|
||||
// parentsMustHaveAllChildren returns an error if at least one of the parents
|
||||
// is missing one children from the set. This function is used internally to
|
||||
// enforce invariants throughout the simplification routines.
|
||||
func parentsMustHaveAllChildren[T any](
|
||||
parents []*sym.Expression,
|
||||
childrenSet map[field.Element]T,
|
||||
) (resErr error) {
|
||||
|
||||
for parentID, p := range parents {
|
||||
// Account for the fact that the node may contain duplicates of the node
|
||||
// we are looking for.
|
||||
founds := map[field.Element]struct{}{}
|
||||
for _, c := range p.Children {
|
||||
if _, ok := childrenSet[c.ESHash]; ok {
|
||||
founds[c.ESHash] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(founds) != len(childrenSet) {
|
||||
resErr = errors.Join(
|
||||
resErr,
|
||||
fmt.Errorf(
|
||||
"parent num %v is incomplete : found = %d/%d",
|
||||
parentID, len(founds), len(childrenSet),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resErr
|
||||
}
|
||||
|
||||
func copyMap[K comparable, V any](m map[K]V) map[K]V {
|
||||
res := make(map[K]V, len(m))
|
||||
for k, v := range m {
|
||||
res[k] = v
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field"
|
||||
sym "github.com/consensys/linea-monorepo/prover/symbolic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -60,12 +59,6 @@ func TestIsFactored(t *testing.T) {
|
||||
IsFactored: true,
|
||||
Factor: a,
|
||||
},
|
||||
{
|
||||
Expr: sym.Mul(a, a, b),
|
||||
By: a,
|
||||
IsFactored: true,
|
||||
Factor: sym.Mul(a, b),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testcases {
|
||||
@@ -74,13 +67,13 @@ func TestIsFactored(t *testing.T) {
|
||||
// Build the group exponent map. If by is a product, we use directly
|
||||
// the exponents it contains. Otherwise, we say this is a single
|
||||
// term product with an exponent of 1.
|
||||
groupedExp := map[field.Element]int{}
|
||||
groupedExp := map[uint64]int{}
|
||||
if byProd, ok := tc.By.Operator.(sym.Product); ok {
|
||||
for i, ex := range byProd.Exponents {
|
||||
groupedExp[tc.By.Children[i].ESHash] = ex
|
||||
groupedExp[tc.By.Children[i].ESHash[0]] = ex
|
||||
}
|
||||
} else {
|
||||
groupedExp[tc.By.ESHash] = 1
|
||||
groupedExp[tc.By.ESHash[0]] = 1
|
||||
}
|
||||
|
||||
factored, isFactored := isFactored(tc.Expr, groupedExp)
|
||||
@@ -207,15 +200,27 @@ func TestFactorLinCompFromGroup(t *testing.T) {
|
||||
for i, testCase := range testCases {
|
||||
t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
|
||||
|
||||
group := map[field.Element]*sym.Expression{}
|
||||
group := map[uint64]*sym.Expression{}
|
||||
for _, e := range testCase.Group {
|
||||
group[e.ESHash] = e
|
||||
group[e.ESHash[0]] = e
|
||||
}
|
||||
|
||||
factored := factorLinCompFromGroup(testCase.LinComb, group)
|
||||
require.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
|
||||
assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
|
||||
|
||||
if t.Failed() {
|
||||
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
|
||||
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
|
||||
t.Fatal()
|
||||
}
|
||||
|
||||
require.NoError(t, factored.Validate())
|
||||
assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
|
||||
|
||||
if t.Failed() {
|
||||
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
|
||||
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -21,24 +21,17 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
|
||||
return oldExpr // Handle edge case where there are no coefficients
|
||||
}
|
||||
|
||||
acc := cs[0]
|
||||
|
||||
// Precompute powers of x
|
||||
powersOfX := make([]*sym.Expression, len(cs))
|
||||
powersOfX[0] = x
|
||||
for i := 1; i < len(cs); i++ {
|
||||
monomialTerms := make([]any, len(cs))
|
||||
for i := 0; i < len(cs); i++ {
|
||||
// We don't use the default constructor because it will collapse the
|
||||
// intermediate terms into a single term. The intermediates are useful because
|
||||
// they tell the evaluator to reuse the intermediate terms instead of
|
||||
// computing x^i for every term.
|
||||
powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1})
|
||||
monomialTerms[i] = any(sym.NewProduct([]*sym.Expression{cs[i], x}, []int{1, i}))
|
||||
}
|
||||
|
||||
for i := 1; i < len(cs); i++ {
|
||||
// Here we want to use the default constructor to ensure that we
|
||||
// will have a merged sum at the end.
|
||||
acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
|
||||
}
|
||||
acc := sym.Add(monomialTerms...)
|
||||
|
||||
if oldExpr.ESHash != acc.ESHash {
|
||||
panic("ESH was altered")
|
||||
|
||||
22
prover/utils/sort.go
Normal file
22
prover/utils/sort.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
// GenSorter is a generic interface implementing [sort.Interface]
|
||||
// but let the user provide the methods directly as closures.
|
||||
// Without needing to implement a new custom type.
|
||||
type GenSorter struct {
|
||||
LenFn func() int
|
||||
SwapFn func(int, int)
|
||||
LessFn func(int, int) bool
|
||||
}
|
||||
|
||||
func (s GenSorter) Len() int {
|
||||
return s.LenFn()
|
||||
}
|
||||
|
||||
func (s GenSorter) Swap(i, j int) {
|
||||
s.SwapFn(i, j)
|
||||
}
|
||||
|
||||
func (s GenSorter) Less(i, j int) bool {
|
||||
return s.LessFn(i, j)
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func FullZkEvm(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
|
||||
|
||||
onceFullZkEvm.Do(func() {
|
||||
// Initialize the Full zkEVM arithmetization
|
||||
fullZkEvm = fullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
|
||||
fullZkEvm = FullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
|
||||
})
|
||||
|
||||
return fullZkEvm
|
||||
@@ -115,13 +115,16 @@ func FullZkEVMCheckOnly(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
|
||||
|
||||
onceFullZkEvmCheckOnly.Do(func() {
|
||||
// Initialize the Full zkEVM arithmetization
|
||||
fullZkEvmCheckOnly = fullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
|
||||
fullZkEvmCheckOnly = FullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
|
||||
})
|
||||
|
||||
return fullZkEvmCheckOnly
|
||||
}
|
||||
|
||||
func fullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
|
||||
// FullZKEVMWithSuite returns a compiled zkEVM with the given compilation suite.
|
||||
// It can be used to benchmark the compilation time of the zkEVM and helps with
|
||||
// performance optimization.
|
||||
func FullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
|
||||
|
||||
// @Alex: only set mandatory parameters here. aka, the one that are not
|
||||
// actually feature-gated.
|
||||
|
||||
Reference in New Issue
Block a user