mirror of https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 04:08:01 -05:00
Prover(perf): faster global constraints compilation (#704)
* bench(global): adds a benchmark for the global constraint compiler
* perf(merging): accumulates the factors before creating the expression
* perf(product): computes the ESH without using a smart-vector
* perf(factor): preallocations in the factorization algorithm
* perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs
* perf(alloc): counts the ret elements before returning in expandTerms to minimize the number of allocations
* perf(factor): use an integer map instead of a field.Element map when possible
* fixup(expands): fix the skip condition for term expansion
* perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb
* feat(repr): adds a json repr function to help debugging
* test(constructor): cleans the test of the constructors
* perf(factor): address maps using the first limb of a field.Element instead of the full field.Element
* fixup(commit): adds missing file in previous commit
* perf(factor): reduce the number of calls to rankChildren
* perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims
* perf(factors): use a counter in getCommonProdParentOfCs
* perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function
* clean(factor): remove unneeded function and imports
* feat(utils): adds a generic sort interface implementation
* perf(rankChildren): lazy allocation of the map to save on allocations
* perf(factorize): reduces the loop-bound for factorizeExpression
* chore: fix a missing argument and format with gofmt
* feat: re-add test

Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com>
Co-authored-by: gusiri <dreamerty@postech.ac.kr>
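A common thread in these changes is replacing repeated binary constructor calls (each of which re-runs the simplification passes over the whole accumulated tree) with a single variadic construction over a pre-collected slice of terms. The following is a minimal, self-contained sketch of that pattern; the Expression type and the mul constructor below are simplified stand-ins, not the prover's actual symbolic API:

package main

import "fmt"

// Expression is a simplified stand-in for the prover's symbolic expression type.
type Expression struct {
	Op       string
	Children []*Expression
	Name     string
}

// mul stands in for a constructor that runs (potentially expensive)
// simplification passes over its inputs every time it is called.
func mul(factors ...*Expression) *Expression {
	return &Expression{Op: "mul", Children: factors}
}

func main() {
	factors := make([]*Expression, 0, 4)
	for i := 0; i < 4; i++ {
		factors = append(factors, &Expression{Name: fmt.Sprintf("x%d", i)})
	}

	// Old pattern: fold with binary calls; the simplification work is redone
	// on a growing tree at every step, which is what made the bound
	// cancellation quadratic/cubic.
	acc := factors[0]
	for _, f := range factors[1:] {
		acc = mul(acc, f)
	}

	// New pattern: accumulate the factors first, then build the product with
	// a single constructor call, so the optimization passes run once.
	flat := mul(factors...)

	fmt.Println(len(acc.Children), len(flat.Children)) // 2 vs 4
}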
This commit is contained in:

75  prover/protocol/compiler/globalcs/global_perf_test.go  (new file)
@@ -0,0 +1,75 @@
+package globalcs_test
+
+import (
+    "os"
+    "testing"
+
+    "github.com/consensys/linea-monorepo/prover/config"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/cleanup"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter"
+    "github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter/sticker"
+    "github.com/consensys/linea-monorepo/prover/protocol/wizard"
+    "github.com/consensys/linea-monorepo/prover/zkevm"
+    "github.com/sirupsen/logrus"
+)
+
+// BenchmarkGlobalConstraint benchmarks the global constraints compiler against the
+// actual zk-evm constraint system.
+func BenchmarkGlobalConstraintWoArtefacts(b *testing.B) {
+
+    // partialSuite corresponds to the actual compilation suite of
+    // the full zk-evm down to the point where the global constraints
+    // are compiled.
+    partialSuite := []func(*wizard.CompiledIOP){
+        mimc.CompileMiMC,
+        specialqueries.RangeProof,
+        specialqueries.CompileFixedPermutations,
+        permutation.CompileGrandProduct,
+        lookup.CompileLogDerivative,
+        innerproduct.Compile,
+        sticker.Sticker(1<<10, 1<<19),
+        splitter.SplitColumns(1 << 19),
+        cleanup.CleanUp,
+        localcs.Compile,
+        globalcs.Compile,
+    }
+
+    // In order to load the config we need to position ourselves in the root
+    // folder.
+    _ = os.Chdir("../../..")
+    defer os.Chdir("protocol/compiler/globalcs")
+
+    // config corresponds to the config we use on sepolia
+    cfg, cfgErr := config.NewConfigFromFile("./config/config-sepolia-full.toml")
+    if cfgErr != nil {
+        b.Fatalf("could not find the config: %v", cfgErr)
+    }
+
+    // Shut the logger to not overwhelm the benchmark output
+    logrus.SetLevel(logrus.PanicLevel)
+
+    b.ResetTimer()
+
+    for c_ := 0; c_ < b.N; c_++ {
+
+        b.StopTimer()
+
+        // Removes the artefacts to
+        if err := os.RemoveAll("/tmp/prover-artefacts"); err != nil {
+            b.Fatalf("could not remove the artefacts: %v", err)
+        }
+
+        b.StartTimer()
+
+        _ = zkevm.FullZKEVMWithSuite(&cfg.TracesLimits, partialSuite, cfg)
+
+    }
+
+}
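Assuming the standard Go toolchain and running from the prover module root (which the benchmark itself chdirs into), the new benchmark can be invoked with something along the lines of:

    go test -run=NONE -bench=BenchmarkGlobalConstraintWoArtefacts -benchtime=1x ./protocol/compiler/globalcs/

The package path and flags above are a suggested invocation, not part of the commit.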
@@ -146,19 +146,28 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
         domainSize = cs.DomainSize
         x          = variables.NewXVar()
         omega      = fft.GetOmega(domainSize)
+
+        // factors is a list of expression to multiply to obtain the return expression. It
+        // is initialized with "only" the initial expression and we iteratively add the
+        // terms (X-i) to it. At the end, we call [sym.Mul] a single time. This structure
+        // is important because it [sym.Mul] operates a sequence of optimization routines
+        // that are everytime we call it. In an earlier version, we were calling [sym.Mul]
+        // for every factor and this were making the function have a quadratic/cubic runtime.
+        factors = make([]any, 0, utils.Abs(cancelRange.Max)+utils.Abs(cancelRange.Min)+1)
     )
 
-    // cancelExprAtPoint cancels the expression at a particular position
-    cancelExprAtPoint := func(expr *symbolic.Expression, i int) *symbolic.Expression {
+    factors = append(factors, res)
+
+    // appendFactor appends an expressions representing $X-\rho^i$ to [factors]
+    appendFactor := func(i int) {
         var root field.Element
         root.Exp(omega, big.NewInt(int64(i)))
-        return symbolic.Mul(expr, symbolic.Sub(x, root))
+        factors = append(factors, symbolic.Sub(x, root))
     }
 
     if cancelRange.Min < 0 {
         // Cancels the expression on the range [0, -cancelRange.Min)
         for i := 0; i < -cancelRange.Min; i++ {
-            res = cancelExprAtPoint(res, i)
+            appendFactor(i)
         }
     }
 
@@ -166,11 +175,18 @@ func getBoundCancelledExpression(cs query.GlobalConstraint) *symbolic.Expression
         // Cancels the expression on the range (N-cancelRange.Max-1, N-1]
         for i := 0; i < cancelRange.Max; i++ {
             point := domainSize - i - 1 // point at which we want to cancel the constraint
-            res = cancelExprAtPoint(res, point)
+            appendFactor(point)
         }
     }
 
-    return res
+    // When factors is of length 1, it means the expression does not need to be
+    // bound-cancelled and we can directly return the original expression
+    // without calling [sym.Mul].
+    if len(factors) == 1 {
+        return res
+    }
+
+    return symbolic.Mul(factors...)
 }
 
 // getExprRatio computes the ratio of the expression and ceil to the next power
@@ -27,12 +27,12 @@ func Add(inputs ...any) *Expression {
 
     exprInputs := intoExprSlice(inputs...)
 
-    res := exprInputs[0]
-    for i := 1; i < len(exprInputs); i++ {
-        res = res.Add(exprInputs[i])
+    magnitudes := make([]int, len(exprInputs))
+    for i := range exprInputs {
+        magnitudes[i] = 1
     }
 
-    return res
+    return NewLinComb(exprInputs, magnitudes)
 }
 
 // Mul constructs a symbolic expression representing the product of its inputs.
@@ -57,12 +57,12 @@ func Mul(inputs ...any) *Expression {
 
     exprInputs := intoExprSlice(inputs...)
 
-    res := exprInputs[0]
-    for i := 1; i < len(exprInputs); i++ {
-        res = res.Mul(exprInputs[i])
+    magnitudes := make([]int, len(exprInputs))
+    for i := range exprInputs {
+        magnitudes[i] = 1
     }
 
-    return res
+    return NewProduct(exprInputs, magnitudes)
 }
 
 // Sub returns a symbolic expression representing the subtraction of `a` by all
@@ -82,16 +82,18 @@ func Sub(a any, bs ...any) *Expression {
     }
 
     var (
         aExpr = intoExpr(a)
         bExpr = intoExprSlice(bs...)
-        res   = aExpr
+        exprInputs = append([]*Expression{aExpr}, bExpr...)
+        magnitudes = make([]int, len(exprInputs))
     )
 
-    for i := range bExpr {
-        res = res.Sub(bExpr[i])
+    for i := range exprInputs {
+        magnitudes[i] = -1
     }
+    magnitudes[0] = 1
 
-    return res
+    return NewLinComb(exprInputs, magnitudes)
 }
 
 // Neg returns an expression representing the negation of an expression or of
@@ -119,7 +121,7 @@ func Square(x any) *Expression {
         return nil
     }
 
-    return intoExpr(x).Square()
+    return NewProduct([]*Expression{intoExpr(x)}, []int{2})
 }
 
 // Pow returns an expression representing the raising to the power "n" of an
@@ -1,6 +1,7 @@
 package symbolic
 
 import (
+    "encoding/json"
     "fmt"
     "reflect"
     "sync"
@@ -284,5 +285,14 @@ func (e *Expression) SameWithNewChildren(newChildren []*Expression) *Expression
     default:
         panic("unexpected type: " + reflect.TypeOf(op).String())
     }
 }
+
+// MarshalJSONString returns a JSON string representation
+// of the expression.
+func (e *Expression) MarshalJSONString() string {
+    js, jsErr := json.MarshalIndent(e, "", " ")
+    if jsErr != nil {
+        utils.Panic("failed to marshal expression: %v", jsErr)
+    }
+    return string(js)
+}
@@ -6,6 +6,7 @@ import (
     "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
     "github.com/consensys/linea-monorepo/prover/maths/field"
     "github.com/consensys/linea-monorepo/prover/utils/collection"
+    "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
 )
 
@@ -96,7 +97,7 @@ func TestLCConstruction(t *testing.T) {
     x := NewDummyVar("x")
     y := NewDummyVar("y")
 
-    {
+    t.Run("simple-addition", func(t *testing.T) {
         /*
             Test t a simple case of addition
         */
@@ -108,23 +109,23 @@ func TestLCConstruction(t *testing.T) {
         require.Equal(t, 2, len(expr1.Operator.(LinComb).Coeffs))
         require.Equal(t, expr1.Operator.(LinComb).Coeffs[0], 1)
         require.Equal(t, expr1.Operator.(LinComb).Coeffs[1], 1)
-    }
+    })
 
-    {
+    t.Run("x-y-x", func(t *testing.T) {
         /*
             Adding y then substracting x should give back (y)
         */
         expr1 := x.Add(y).Sub(x)
         require.Equal(t, expr1, y)
-    }
+    })
 
-    {
+    t.Run("(-x)+x+y", func(t *testing.T) {
         /*
             Same thing when using Neg
         */
         expr := x.Neg().Add(x).Add(y)
-        require.Equal(t, expr, y)
-    }
+        assert.Equal(t, expr, y)
+    })
 
 }
 
@@ -133,7 +134,7 @@ func TestProductConstruction(t *testing.T) {
     x := NewDummyVar("x")
     y := NewDummyVar("y")
 
-    {
+    t.Run("x * y", func(t *testing.T) {
         /*
             Test t a simple case of addition
         */
@@ -145,9 +146,9 @@ func TestProductConstruction(t *testing.T) {
         require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
         require.Equal(t, expr1.Operator.(Product).Exponents[0], 1)
         require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
-    }
+    })
 
-    {
+    t.Run("x * y * x", func(t *testing.T) {
         /*
             Adding y then substracting x should give back (y)
         */
@@ -158,9 +159,9 @@ func TestProductConstruction(t *testing.T) {
         require.Equal(t, 2, len(expr1.Operator.(Product).Exponents))
         require.Equal(t, expr1.Operator.(Product).Exponents[0], 2)
         require.Equal(t, expr1.Operator.(Product).Exponents[1], 1)
-    }
+    })
 
-    {
+    t.Run("x^2", func(t *testing.T) {
         /*
             When we square
         */
@@ -169,6 +170,6 @@ func TestProductConstruction(t *testing.T) {
         require.Equal(t, expr.Children[0], x)
         require.Equal(t, 1, len(expr.Operator.(Product).Exponents))
         require.Equal(t, expr.Operator.(Product).Exponents[0], 2)
-    }
+    })
 
 }
@@ -95,18 +95,30 @@ func NewProduct(items []*Expression, exponents []int) *Expression {
     e := &Expression{
         Operator: Product{Exponents: exponents},
         Children: items,
+        ESHash:   field.One(),
     }
 
-    // Now we need to assign the ESH
-    eshashes := make([]sv.SmartVector, len(e.Children))
     for i := range e.Children {
-        eshashes[i] = sv.NewConstant(e.Children[i].ESHash, 1)
-    }
-
-    if len(items) > 0 {
-        // The cast back to sv.Constant is no important functionally but is an easy
-        // sanity check.
-        e.ESHash = e.Operator.Evaluate(eshashes).(*sv.Constant).Get(0)
+        var tmp field.Element
+        switch {
+        case exponents[i] == 1:
+            e.ESHash.Mul(&e.ESHash, &e.Children[i].ESHash)
+        case exponents[i] == 2:
+            tmp.Square(&e.Children[i].ESHash)
+            e.ESHash.Mul(&e.ESHash, &tmp)
+        case exponents[i] == 3:
+            tmp.Square(&e.Children[i].ESHash)
+            tmp.Mul(&tmp, &e.Children[i].ESHash)
+            e.ESHash.Mul(&e.ESHash, &tmp)
+        case exponents[i] == 4:
+            tmp.Square(&e.Children[i].ESHash)
+            tmp.Square(&tmp)
+            e.ESHash.Mul(&e.ESHash, &tmp)
+        default:
+            exponent := big.NewInt(int64(exponents[i]))
+            tmp.Exp(e.Children[i].ESHash, exponent)
+            e.ESHash.Mul(&e.ESHash, &tmp)
+        }
     }
 
     return e
@@ -78,22 +78,38 @@ func regroupTerms(magnitudes []int, children []*Expression) (
 // the linear combination. This function is used both for simplifying [LinComb]
 // expressions and for simplifying [Product]. "magnitude" denotes either the
 // coefficient for LinComb or exponents for Product.
+//
+// The function takes ownership of the provided slices.
 func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes []int, cleanChildren []*Expression) {
 
     if len(magnitudes) != len(children) {
         panic("magnitudes and children don't have the same length")
     }
 
-    cleanChildren = make([]*Expression, 0, len(children))
-    cleanMagnitudes = make([]int, 0, len(children))
+    // cleanChildren and cleanMagnitudes are initialized lazily to
+    // avoid unnecessarily allocating memory. The underlying assumption
+    // is that the application will 99% of time never pass zero as a
+    // magnitude.
     for i, c := range magnitudes {
-        if c != 0 {
+
+        if c == 0 && cleanChildren == nil {
+            cleanChildren = make([]*Expression, i, len(children))
+            cleanMagnitudes = make([]int, i, len(children))
+            copy(cleanChildren, children[:i])
+            copy(cleanMagnitudes, magnitudes[:i])
+        }
+
+        if c != 0 && cleanChildren != nil {
             cleanMagnitudes = append(cleanMagnitudes, magnitudes[i])
             cleanChildren = append(cleanChildren, children[i])
         }
     }
 
+    if cleanChildren == nil {
+        cleanChildren = children
+        cleanMagnitudes = magnitudes
+    }
+
     return cleanMagnitudes, cleanChildren
 }
 
@@ -108,14 +124,16 @@ func removeZeroCoeffs(magnitudes []int, children []*Expression) (cleanMagnitudes
 // The caller passes a target operator which may be any value of type either
 // [LinComb] or [Product]. Any other type yields a panic error.
 func expandTerms(op Operator, magnitudes []int, children []*Expression) (
-    expandedMagnitudes []int,
-    expandedExpression []*Expression,
+    []int,
+    []*Expression,
 ) {
 
     var (
         opIsProd    bool
         opIsLinC    bool
         numChildren = len(children)
+        totalReturnSize = 0
+        needExpand      = false
     )
 
     switch op.(type) {
@@ -133,9 +151,34 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
         panic("incompatible number of children and magnitudes")
     }
 
-    // The capacity allocation is purely heuristic
-    expandedExpression = make([]*Expression, 0, 2*len(magnitudes))
-    expandedMagnitudes = make([]int, 0, 2*len(magnitudes))
+    // This loops performs a first scan of the children to compute the total
+    // number of elements to allocate.
+    for i, child := range children {
+
+        switch child.Operator.(type) {
+        case Product, *Product:
+            if opIsProd {
+                needExpand = true
+                totalReturnSize += len(children[i].Children)
+                continue
+            }
+        case LinComb, *LinComb:
+            if opIsLinC {
+                needExpand = true
+                totalReturnSize += len(children[i].Children)
+                continue
+            }
+        }
+
+        totalReturnSize++
+    }
+
+    if !needExpand {
+        return magnitudes, children
+    }
+
+    expandedMagnitudes := make([]int, 0, totalReturnSize)
+    expandedExpression := make([]*Expression, 0, totalReturnSize)
+
     for i := 0; i < numChildren; i++ {
 
@@ -175,7 +218,6 @@ func expandTerms(op Operator, magnitudes []int, children []*Expression) (
             for k := range child.Children {
                 expandedExpression = append(expandedExpression, child.Children[k])
                 expandedMagnitudes = append(expandedMagnitudes, magnitude*cLinC.Coeffs[k])
-
             }
             continue
         }
@@ -1,7 +1,6 @@
 package simplify
 
 import (
-    "errors"
     "fmt"
     "math"
     "sort"
@@ -46,9 +45,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
     // factoring possibilities. There is also a bound on the loop to
     // prevent infinite loops.
     //
-    // The choice of 1000 is purely heuristic and is not meant to be
+    // The choice of 100 is purely heuristic and is not meant to be
     // actually met.
-    for k := 0; k < 1000; k++ {
+    for k := 0; k < 100; k++ {
         _, ok := new.Operator.(sym.LinComb)
         if !ok {
             return new
@@ -103,13 +102,18 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
 // children that are already in the children set.
 func rankChildren(
     parents []*sym.Expression,
-    childrenSet map[field.Element]*sym.Expression,
+    childrenSet map[uint64]*sym.Expression,
 ) []*sym.Expression {
 
     // List all the grand-children of the expression whose parents are
     // products and counts the number of occurences by summing the exponents.
-    relevantGdChildrenCnt := map[field.Element]int{}
-    uniqueChildrenList := make([]*sym.Expression, 0)
+    // As an optimization the map is addressed using the first uint64 repr
+    // of the element. We consider this is good enough to avoid collisions.
+    // The risk if it happens is that it gets caught by the validation checks
+    // at the end of the factorization routine. The preallocation value is
+    // purely heuristic to avoid successive allocations.
+    var relevantGdChildrenCnt map[uint64]int
+    var uniqueChildrenList []*sym.Expression
 
     for _, p := range parents {
 
@@ -127,16 +131,21 @@ func rankChildren(
 
             // If it's in the group, it does not count. We can't add it a second
             // time.
-            if _, ok := childrenSet[c.ESHash]; ok {
+            if _, ok := childrenSet[c.ESHash[0]]; ok {
                 continue
             }
 
-            if _, ok := relevantGdChildrenCnt[c.ESHash]; !ok {
-                relevantGdChildrenCnt[c.ESHash] = 0
+            if relevantGdChildrenCnt == nil {
+                relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
+                uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
+            }
+
+            if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
+                relevantGdChildrenCnt[c.ESHash[0]] = 0
                 uniqueChildrenList = append(uniqueChildrenList, c)
             }
 
-            relevantGdChildrenCnt[c.ESHash]++
+            relevantGdChildrenCnt[c.ESHash[0]]++
         }
     }
 
@@ -144,7 +153,7 @@ func rankChildren(
         x := uniqueChildrenList[i].ESHash
         y := uniqueChildrenList[j].ESHash
         // We want to a decreasing order
-        return relevantGdChildrenCnt[x] > relevantGdChildrenCnt[y]
+        return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
     })
 
     return uniqueChildrenList
@@ -155,76 +164,54 @@ func rankChildren(
 // than one parent. The finding is based on a greedy algorithm. We iteratively
 // add nodes in the group so that the number of common parents decreases as
 // slowly as possible.
-func findGdChildrenGroup(expr *sym.Expression) map[field.Element]*sym.Expression {
+func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {
 
     curParents := expr.Children
-    childrenSet := map[field.Element]*sym.Expression{}
+    childrenSet := map[uint64]*sym.Expression{}
 
-    for {
-        ranked := rankChildren(curParents, childrenSet)
+    ranked := rankChildren(curParents, childrenSet)
 
-        // Can happen when we have a lincomb of lincomb. Ideally they should be
-        // merged during canonization.
-        if len(ranked) == 0 {
-            return childrenSet
-        }
+    // Can happen when we have a lincomb of lincomb. Ideally they should be
+    // merged during canonization.
+    if len(ranked) == 0 {
+        return childrenSet
+    }
 
-        best := ranked[0]
-        newChildrenSet := copyMap(childrenSet)
-        newChildrenSet[best.ESHash] = best
-        newParents := getCommonProdParentOfCs(newChildrenSet, curParents)
+    for i := range ranked {
+
+        best := ranked[i]
+        childrenSet[best.ESHash[0]] = best
+        curParents = filterParentsWithChildren(curParents, best.ESHash)
 
         // Can't grow the set anymore
-        if len(newParents) <= 1 {
+        if len(curParents) <= 1 {
+            delete(childrenSet, best.ESHash[0])
             return childrenSet
         }
 
-        childrenSet = newChildrenSet
-        curParents = newParents
-
         logrus.Tracef(
             "find groups, so far we have %v parents and %v siblings",
             len(curParents), len(childrenSet))
-
-        // Sanity-check
-        if err := parentsMustHaveAllChildren(curParents, childrenSet); err != nil {
-            panic(err)
-        }
     }
+
+    return childrenSet
 }
 
-// getCommonProdParentOfCs returns the parents that have all cs as children and
-// that are themselves children of gdp (grandparent). The parents must be of
-// type product however.
-func getCommonProdParentOfCs(
-    cs map[field.Element]*sym.Expression,
+// filterParentsWithChildren returns a filtered list of parents who have at
+// least one child with the given ESHash. The function allocates a new list
+// of parents and returns it without mutating he original list.
+func filterParentsWithChildren(
     parents []*sym.Expression,
+    childEsh field.Element,
 ) []*sym.Expression {
 
-    res := []*sym.Expression{}
+    res := make([]*sym.Expression, 0, len(parents))
 
     for _, p := range parents {
-        prod, ok := p.Operator.(sym.Product)
-        if !ok {
-            continue
-        }
-
-        // Account for the fact that p may contain duplicates. So we cannot
-        // just use a counter here.
-        founds := map[field.Element]struct{}{}
-        for i, c := range p.Children {
-            if prod.Exponents[i] == 0 {
-                continue
-            }
-
-            if _, inside := cs[c.ESHash]; inside {
-                // logrus.Tracef("%v contains %v", p.ESHash.String(), c.ESHash.String())
-                founds[c.ESHash] = struct{}{}
-            }
-        }
-
-        if len(founds) == len(cs) {
-            res = append(res, p)
-        }
+        for _, c := range p.Children {
+            if c.ESHash == childEsh {
+                res = append(res, p)
+                break
+            }
+        }
     }
 
@@ -235,22 +222,26 @@
 // determine the best common factor.
 func factorLinCompFromGroup(
     lincom *sym.Expression,
-    group map[field.Element]*sym.Expression,
+    group map[uint64]*sym.Expression,
 ) *sym.Expression {
 
     var (
+
+        // numTerms indicates the number of children in the linear-combination
+        numTerms = len(lincom.Children)
+
         lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs
         // Build the common term by taking the max of the exponents
         exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)
 
         // Separate the non-factored terms
-        nonFactoredTerms  = []*sym.Expression{}
-        nonFactoredCoeffs = []int{}
+        nonFactoredTerms  = make([]*sym.Expression, 0, numTerms)
+        nonFactoredCoeffs = make([]int, 0, numTerms)
 
         // The factored terms of the linear combination divided by the common
         // group factor
-        factoredTerms  = []*sym.Expression{}
-        factoredCoeffs = []int{}
+        factoredTerms  = make([]*sym.Expression, 0, numTerms)
+        factoredCoeffs = make([]int, 0, numTerms)
     )
 
     numFactors := 0
@@ -295,7 +286,7 @@ func factorLinCompFromGroup(
 //
 // Fortunately, this is guaranteed if the expression was constructed via
 // [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
-func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
+func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
     factored *sym.Expression,
     success bool,
 ) {
@@ -310,7 +301,7 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
 
     numMatches := 0
     for i, c := range e.Children {
-        eig, found := exponentsOfGroup[c.ESHash]
+        eig, found := exponentsOfGroup[c.ESHash[0]]
         if !found {
             continue
         }
@@ -335,14 +326,14 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
 // have the whole group as children.
 func optimRegroupExponents(
     parents []*sym.Expression,
-    group map[field.Element]*sym.Expression,
+    group map[uint64]*sym.Expression,
 ) (
-    exponentMap map[field.Element]int,
+    exponentMap map[uint64]int,
     groupedTerm *sym.Expression,
 ) {
 
-    exponentMap = map[field.Element]int{}
-    canonTermList := make([]*sym.Expression, 0) // built in deterministic order
+    exponentMap = make(map[uint64]int, 16)
+    canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order
 
     for _, p := range parents {
 
@@ -354,10 +345,10 @@ func optimRegroupExponents(
 
         // Used to sanity-check that all the nodes of the group have been
         // reached through this parent.
-        matched := map[field.Element]int{}
+        matched := make(map[uint64]int, len(p.Children))
 
         for i, c := range p.Children {
-            if _, ingroup := group[c.ESHash]; !ingroup {
+            if _, ingroup := group[c.ESHash[0]]; !ingroup {
                 continue
             }
 
@@ -365,16 +356,16 @@ func optimRegroupExponents(
                 panic("The expression is not canonic")
             }
 
-            _, initialized := exponentMap[c.ESHash]
+            _, initialized := exponentMap[c.ESHash[0]]
             if !initialized {
                 // Max int is used as a placeholder. It will be replaced anytime
                 // we wall utils.Min(exponentMap[h], n) where n is actually an
                 // exponent.
-                exponentMap[c.ESHash] = math.MaxInt
+                exponentMap[c.ESHash[0]] = math.MaxInt
                 canonTermList = append(canonTermList, c)
             }
 
-            matched[c.ESHash] = exponents[i]
+            matched[c.ESHash[0]] = exponents[i]
         }
 
         if len(matched) != len(group) {
@@ -391,48 +382,8 @@ func optimRegroupExponents(
 
     canonExponents := []int{}
     for _, e := range canonTermList {
-        canonExponents = append(canonExponents, exponentMap[e.ESHash])
+        canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
     }
 
     return exponentMap, sym.NewProduct(canonTermList, canonExponents)
 }
-
-// parentsMustHaveAllChildren returns an error if at least one of the parents
-// is missing one children from the set. This function is used internally to
-// enforce invariants throughout the simplification routines.
-func parentsMustHaveAllChildren[T any](
-    parents []*sym.Expression,
-    childrenSet map[field.Element]T,
-) (resErr error) {
-
-    for parentID, p := range parents {
-        // Account for the fact that the node may contain duplicates of the node
-        // we are looking for.
-        founds := map[field.Element]struct{}{}
-        for _, c := range p.Children {
-            if _, ok := childrenSet[c.ESHash]; ok {
-                founds[c.ESHash] = struct{}{}
-            }
-        }
-
-        if len(founds) != len(childrenSet) {
-            resErr = errors.Join(
-                resErr,
-                fmt.Errorf(
-                    "parent num %v is incomplete : found = %d/%d",
-                    parentID, len(founds), len(childrenSet),
-                ),
-            )
-        }
-    }
-
-    return resErr
-}
-
-func copyMap[K comparable, V any](m map[K]V) map[K]V {
-    res := make(map[K]V, len(m))
-    for k, v := range m {
-        res[k] = v
-    }
-    return res
-}
@@ -4,7 +4,6 @@ import (
     "fmt"
     "testing"
 
-    "github.com/consensys/linea-monorepo/prover/maths/field"
     sym "github.com/consensys/linea-monorepo/prover/symbolic"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
@@ -60,12 +59,6 @@ func TestIsFactored(t *testing.T) {
             IsFactored: true,
             Factor:     a,
         },
-        {
-            Expr:       sym.Mul(a, a, b),
-            By:         a,
-            IsFactored: true,
-            Factor:     sym.Mul(a, b),
-        },
     }
 
     for i, tc := range testcases {
@@ -74,13 +67,13 @@ func TestIsFactored(t *testing.T) {
         // Build the group exponent map. If by is a product, we use directly
         // the exponents it contains. Otherwise, we say this is a single
         // term product with an exponent of 1.
-        groupedExp := map[field.Element]int{}
+        groupedExp := map[uint64]int{}
         if byProd, ok := tc.By.Operator.(sym.Product); ok {
             for i, ex := range byProd.Exponents {
-                groupedExp[tc.By.Children[i].ESHash] = ex
+                groupedExp[tc.By.Children[i].ESHash[0]] = ex
             }
         } else {
-            groupedExp[tc.By.ESHash] = 1
+            groupedExp[tc.By.ESHash[0]] = 1
         }
 
         factored, isFactored := isFactored(tc.Expr, groupedExp)
@@ -207,15 +200,27 @@ func TestFactorLinCompFromGroup(t *testing.T) {
     for i, testCase := range testCases {
         t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
 
-            group := map[field.Element]*sym.Expression{}
+            group := map[uint64]*sym.Expression{}
             for _, e := range testCase.Group {
-                group[e.ESHash] = e
+                group[e.ESHash[0]] = e
             }
 
             factored := factorLinCompFromGroup(testCase.LinComb, group)
-            require.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
+            assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
+
+            if t.Failed() {
+                fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
+                fmt.Printf("factored=%v\n", factored.MarshalJSONString())
+                t.Fatal()
+            }
+
             require.NoError(t, factored.Validate())
             assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
+
+            if t.Failed() {
+                fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
+                fmt.Printf("factored=%v\n", factored.MarshalJSONString())
+            }
         })
     }
 
@@ -21,24 +21,17 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
         return oldExpr // Handle edge case where there are no coefficients
     }
 
-    acc := cs[0]
-
     // Precompute powers of x
-    powersOfX := make([]*sym.Expression, len(cs))
-    powersOfX[0] = x
-    for i := 1; i < len(cs); i++ {
+    monomialTerms := make([]any, len(cs))
+    for i := 0; i < len(cs); i++ {
         // We don't use the default constructor because it will collapse the
         // intermediate terms into a single term. The intermediates are useful because
         // they tell the evaluator to reuse the intermediate terms instead of
         // computing x^i for every term.
-        powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1})
+        monomialTerms[i] = any(sym.NewProduct([]*sym.Expression{cs[i], x}, []int{1, i}))
     }
 
-    for i := 1; i < len(cs); i++ {
-        // Here we want to use the default constructor to ensure that we
-        // will have a merged sum at the end.
-        acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
-    }
+    acc := sym.Add(monomialTerms...)
 
     if oldExpr.ESHash != acc.ESHash {
         panic("ESH was altered")
22  prover/utils/sort.go  (new file)

@@ -0,0 +1,22 @@
+package utils
+
+// GenSorter is a generic interface implementing [sort.Interface]
+// but let the user provide the methods directly as closures.
+// Without needing to implement a new custom type.
+type GenSorter struct {
+    LenFn  func() int
+    SwapFn func(int, int)
+    LessFn func(int, int) bool
+}
+
+func (s GenSorter) Len() int {
+    return s.LenFn()
+}
+
+func (s GenSorter) Swap(i, j int) {
+    s.SwapFn(i, j)
+}
+
+func (s GenSorter) Less(i, j int) bool {
+    return s.LessFn(i, j)
+}
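As a usage sketch (not part of this commit), GenSorter satisfies sort.Interface, so it can be handed directly to sort.Sort with closures capturing the slice to order; the main import path is assumed to be the monorepo's prover/utils package:

package main

import (
	"fmt"
	"sort"

	"github.com/consensys/linea-monorepo/prover/utils"
)

func main() {
	xs := []int{3, 1, 2}
	sort.Sort(utils.GenSorter{
		LenFn:  func() int { return len(xs) },
		SwapFn: func(i, j int) { xs[i], xs[j] = xs[j], xs[i] },
		LessFn: func(i, j int) bool { return xs[i] < xs[j] },
	})
	fmt.Println(xs) // prints [1 2 3]
}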
@@ -105,7 +105,7 @@ func FullZkEvm(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
 
     onceFullZkEvm.Do(func() {
         // Initialize the Full zkEVM arithmetization
-        fullZkEvm = fullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
+        fullZkEvm = FullZKEVMWithSuite(tl, fullCompilationSuite, cfg)
     })
 
     return fullZkEvm
@@ -115,13 +115,16 @@ func FullZkEVMCheckOnly(tl *config.TracesLimits, cfg *config.Config) *ZkEvm {
 
     onceFullZkEvmCheckOnly.Do(func() {
         // Initialize the Full zkEVM arithmetization
-        fullZkEvmCheckOnly = fullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
+        fullZkEvmCheckOnly = FullZKEVMWithSuite(tl, dummyCompilationSuite, cfg)
     })
 
     return fullZkEvmCheckOnly
 }
 
-func fullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
+// FullZKEVMWithSuite returns a compiled zkEVM with the given compilation suite.
+// It can be used to benchmark the compilation time of the zkEVM and helps with
+// performance optimization.
+func FullZKEVMWithSuite(tl *config.TracesLimits, suite compilationSuite, cfg *config.Config) *ZkEvm {
 
     // @Alex: only set mandatory parameters here. aka, the one that are not
     // actually feature-gated.