mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 04:08:01 -05:00
* bench(global): adds a benchmark for the global constraint compiler * perf(merging): accumulates the factors before creating the expression * perf(product): computes the ESH without using a smart-vector * perf(factor): preallocations in the factorization algorithm * perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs * perfs(alloc): counts the ret elements before returning in expandTerms to minimze the number of allocations. * perf(factor): use an integer map instead of a field.Element map when possible * fixup(expands): fix the skip condition for term expansion * perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb * feat(repr): adds a json repr function to help debugging * test(constructor): cleans the test of the constructors * perf(factor): address maps using the first limb of a field.Element instead of the full field.Element * fixup(commit): adds missing file in previous commit * perf(factor): reduce the number of calls to rankChildren * perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims * perf(factors): use a counter in getCommonProdParentOfCs * perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function * clean(factor): remove unneeded function and imports * feat(utils): adds a generic sort interface implementation * perf(rankChildren): lazy allocation of the map to save on allocations * perf(factorize): reduces the loop-bound for factorizeExpression * (chore): fix a missing argument and format gofmt * feat: readd test --------- Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com> Co-authored-by: gusiri <dreamerty@postech.ac.kr>
228 lines
4.5 KiB
Go
228 lines
4.5 KiB
Go
package simplify
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
sym "github.com/consensys/linea-monorepo/prover/symbolic"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestIsFactored(t *testing.T) {
|
|
|
|
testcases := []struct {
|
|
Expr *sym.Expression
|
|
By *sym.Expression
|
|
IsFactored bool
|
|
Factor *sym.Expression
|
|
}{
|
|
{
|
|
Expr: sym.Mul(a, b, c),
|
|
By: sym.Mul(a, b),
|
|
IsFactored: true,
|
|
Factor: c,
|
|
},
|
|
{
|
|
Expr: sym.Mul(a, b, c),
|
|
By: sym.Mul(a, c),
|
|
IsFactored: true,
|
|
Factor: b,
|
|
},
|
|
{
|
|
Expr: sym.Mul(a, b, c),
|
|
By: sym.Mul(a, d),
|
|
IsFactored: false,
|
|
Factor: nil,
|
|
},
|
|
{
|
|
Expr: sym.Add(a, b, c),
|
|
By: sym.Add(a, d),
|
|
IsFactored: false,
|
|
Factor: nil,
|
|
},
|
|
{
|
|
Expr: sym.Add(a, b, c),
|
|
By: sym.Add(a, b, b),
|
|
IsFactored: false,
|
|
Factor: nil,
|
|
},
|
|
{
|
|
Expr: sym.Mul(a, b),
|
|
By: sym.Mul(a, b),
|
|
IsFactored: true,
|
|
Factor: sym.NewConstant(1),
|
|
},
|
|
{
|
|
Expr: sym.Mul(a, a),
|
|
By: a,
|
|
IsFactored: true,
|
|
Factor: a,
|
|
},
|
|
}
|
|
|
|
for i, tc := range testcases {
|
|
|
|
t.Run(fmt.Sprintf("testcase-%v", i), func(t *testing.T) {
|
|
// Build the group exponent map. If by is a product, we use directly
|
|
// the exponents it contains. Otherwise, we say this is a single
|
|
// term product with an exponent of 1.
|
|
groupedExp := map[uint64]int{}
|
|
if byProd, ok := tc.By.Operator.(sym.Product); ok {
|
|
for i, ex := range byProd.Exponents {
|
|
groupedExp[tc.By.Children[i].ESHash[0]] = ex
|
|
}
|
|
} else {
|
|
groupedExp[tc.By.ESHash[0]] = 1
|
|
}
|
|
|
|
factored, isFactored := isFactored(tc.Expr, groupedExp)
|
|
assert.Equalf(t, tc.IsFactored, isFactored, "missed factor identification")
|
|
|
|
if isFactored && tc.IsFactored {
|
|
assert.Equalf(t, tc.Factor.ESHash.String(), factored.ESHash.String(), "wrong factor")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFactorization(t *testing.T) {
|
|
|
|
var (
|
|
a = sym.NewDummyVar("a")
|
|
b = sym.NewDummyVar("b")
|
|
c = sym.NewDummyVar("c")
|
|
d = sym.NewDummyVar("d")
|
|
)
|
|
|
|
testCases := []struct {
|
|
Origin *sym.Expression
|
|
Factored *sym.Expression
|
|
}{
|
|
{
|
|
Origin: sym.Add(
|
|
sym.Mul(a, b),
|
|
sym.Mul(a, c),
|
|
sym.Mul(a, d),
|
|
),
|
|
Factored: sym.Mul(
|
|
a,
|
|
sym.Add(b, c, d),
|
|
),
|
|
},
|
|
{
|
|
Origin: sym.Add(
|
|
sym.Mul(a, b, b),
|
|
sym.Mul(a, b, c),
|
|
sym.Mul(a, b, d),
|
|
),
|
|
Factored: sym.Mul(
|
|
a,
|
|
b,
|
|
sym.Add(b, c, d),
|
|
),
|
|
},
|
|
{
|
|
Origin: sym.Add(
|
|
sym.Mul(a, b),
|
|
sym.Mul(a, b, c),
|
|
sym.Mul(a, b, d),
|
|
),
|
|
Factored: sym.Mul(
|
|
a,
|
|
b,
|
|
sym.Add(1, c, d),
|
|
),
|
|
},
|
|
{
|
|
Origin: sym.Add(
|
|
sym.Mul(a, b),
|
|
sym.Mul(a, c),
|
|
sym.Mul(d),
|
|
),
|
|
Factored: sym.Add(
|
|
sym.Mul(a, sym.Add(b, c)),
|
|
d,
|
|
),
|
|
},
|
|
}
|
|
|
|
for i, testCase := range testCases {
|
|
t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
|
|
factored := factorizeExpression(testCase.Origin, 10)
|
|
require.Equal(t, testCase.Origin.ESHash.String(), factored.ESHash.String())
|
|
require.NoError(t, factored.Validate())
|
|
assert.Equal(t, evaluateCostStat(testCase.Factored), evaluateCostStat(factored))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFactorLinCompFromGroup(t *testing.T) {
|
|
|
|
testCases := []struct {
|
|
LinComb *sym.Expression
|
|
Group []*sym.Expression
|
|
Res *sym.Expression
|
|
}{
|
|
{
|
|
LinComb: sym.Add(
|
|
sym.Mul(a, b),
|
|
sym.Mul(a, b, c),
|
|
sym.Mul(a, b, d),
|
|
),
|
|
Group: []*sym.Expression{a, b},
|
|
Res: sym.Mul(
|
|
a,
|
|
b,
|
|
sym.Add(1, c, d),
|
|
),
|
|
},
|
|
{
|
|
LinComb: sym.Add(
|
|
sym.Mul(a, c),
|
|
sym.Mul(a, a, b),
|
|
1,
|
|
),
|
|
Group: []*sym.Expression{a},
|
|
Res: sym.Add(
|
|
sym.Mul(
|
|
a,
|
|
sym.Add(
|
|
sym.Mul(a, b),
|
|
c,
|
|
),
|
|
),
|
|
1,
|
|
),
|
|
},
|
|
}
|
|
|
|
for i, testCase := range testCases {
|
|
t.Run(fmt.Sprintf("test-case-%v", i), func(t *testing.T) {
|
|
|
|
group := map[uint64]*sym.Expression{}
|
|
for _, e := range testCase.Group {
|
|
group[e.ESHash[0]] = e
|
|
}
|
|
|
|
factored := factorLinCompFromGroup(testCase.LinComb, group)
|
|
assert.Equal(t, testCase.LinComb.ESHash.String(), factored.ESHash.String())
|
|
|
|
if t.Failed() {
|
|
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
|
|
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
|
|
t.Fatal()
|
|
}
|
|
|
|
require.NoError(t, factored.Validate())
|
|
assert.Equal(t, evaluateCostStat(testCase.Res), evaluateCostStat(factored))
|
|
|
|
if t.Failed() {
|
|
fmt.Printf("res=%v\n", testCase.Res.MarshalJSONString())
|
|
fmt.Printf("factored=%v\n", factored.MarshalJSONString())
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|