mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 04:08:01 -05:00
* bench(global): adds a benchmark for the global constraint compiler * perf(merging): accumulates the factors before creating the expression * perf(product): computes the ESH without using a smart-vector * perf(factor): preallocations in the factorization algorithm * perf(removeZeroes): implements a lazy allocation mechanism in removeZeroCoeffs * perfs(alloc): counts the ret elements before returning in expandTerms to minimze the number of allocations. * perf(factor): use an integer map instead of a field.Element map when possible * fixup(expands): fix the skip condition for term expansion * perf(constructor): improves the immutable constructors to reduce the number of calls to NewProduct and NewLinComb * feat(repr): adds a json repr function to help debugging * test(constructor): cleans the test of the constructors * perf(factor): address maps using the first limb of a field.Element instead of the full field.Element * fixup(commit): adds missing file in previous commit * perf(factor): reduce the number of calls to rankChildren * perf(rmpolyeval): creates the equivalent expression more directly to save on unnecessary optims * perf(factors): use a counter in getCommonProdParentOfCs * perf(factor): remove map copy from findGdChildrenGroup and replace getCommonProdParent by a simpler function * clean(factor): remove unneeded function and imports * feat(utils): adds a generic sort interface implementation * perf(rankChildren): lazy allocation of the map to save on allocations * perf(factorize): reduces the loop-bound for factorizeExpression * (chore): fix a missing argument and format gofmt * feat: readd test --------- Signed-off-by: AlexandreBelling <alexandrebelling8@gmail.com> Co-authored-by: gusiri <dreamerty@postech.ac.kr>
390 lines
11 KiB
Go
390 lines
11 KiB
Go
package simplify
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/consensys/linea-monorepo/prover/maths/field"
|
|
sym "github.com/consensys/linea-monorepo/prover/symbolic"
|
|
"github.com/consensys/linea-monorepo/prover/utils"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// factorizeExpression attempt to simplify the expression by identifying common
|
|
// factors within sums and factor them into a single term.
|
|
func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
|
|
res := expr
|
|
initEsh := expr.ESHash
|
|
alreadyWalked := sync.Map{}
|
|
factorMemo := sync.Map{}
|
|
|
|
logrus.Infof("factoring expression : init stats %v", evaluateCostStat(expr))
|
|
|
|
for i := 0; i < iteration; i++ {
|
|
|
|
scoreInit := evaluateCostStat(res)
|
|
|
|
res = res.ReconstructBottomUp(func(lincomb *sym.Expression, newChildren []*sym.Expression) *sym.Expression {
|
|
// Time save, we reuse the results we got for that particular node.
|
|
if ret, ok := alreadyWalked.Load(lincomb.ESHash); ok {
|
|
return ret.(*sym.Expression)
|
|
}
|
|
|
|
// Incorporate the new children inside of the expression to account
|
|
// for them.
|
|
new := lincomb.SameWithNewChildren(newChildren)
|
|
// To ensure that it is not accessed anymore. Note that this does
|
|
// not mutate the input argument but makes it inaccessible to the
|
|
// rest of the function for safety.
|
|
lincomb = nil
|
|
prevSize := len(new.Children)
|
|
|
|
// The function returns only once it has figured out all the
|
|
// factoring possibilities. There is also a bound on the loop to
|
|
// prevent infinite loops.
|
|
//
|
|
// The choice of 100 is purely heuristic and is not meant to be
|
|
// actually met.
|
|
for k := 0; k < 100; k++ {
|
|
_, ok := new.Operator.(sym.LinComb)
|
|
if !ok {
|
|
return new
|
|
}
|
|
|
|
group := findGdChildrenGroup(new)
|
|
|
|
if len(group) < 1 {
|
|
return new
|
|
}
|
|
|
|
// Memoize the factorLinCompFromGroup result
|
|
cacheKey := fmt.Sprintf("%v-%v", new.ESHash, group)
|
|
|
|
if cachedResult, ok := factorMemo.Load(cacheKey); ok {
|
|
new = cachedResult.(*sym.Expression)
|
|
|
|
} else {
|
|
new = factorLinCompFromGroup(new, group)
|
|
factorMemo.Store(cacheKey, new)
|
|
}
|
|
|
|
if len(new.Children) >= prevSize {
|
|
return new
|
|
}
|
|
|
|
prevSize = len(new.Children)
|
|
}
|
|
|
|
alreadyWalked.Store(new.ESHash, new)
|
|
return new
|
|
})
|
|
|
|
if res.ESHash != initEsh {
|
|
panic("altered esh")
|
|
}
|
|
|
|
newScore := evaluateCostStat(res)
|
|
logrus.Infof("finished iteration : new stats %v", newScore)
|
|
|
|
if newScore.NumMul >= scoreInit.NumMul {
|
|
break
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
// rankChildren ranks the children nodes of a list of parents based on which
|
|
// node has the highest number of parents in the list.
|
|
//
|
|
// The childrenSet is used as an exclusion set, the function shall not return
|
|
// children that are already in the children set.
|
|
func rankChildren(
|
|
parents []*sym.Expression,
|
|
childrenSet map[uint64]*sym.Expression,
|
|
) []*sym.Expression {
|
|
|
|
// List all the grand-children of the expression whose parents are
|
|
// products and counts the number of occurences by summing the exponents.
|
|
// As an optimization the map is addressed using the first uint64 repr
|
|
// of the element. We consider this is good enough to avoid collisions.
|
|
// The risk if it happens is that it gets caught by the validation checks
|
|
// at the end of the factorization routine. The preallocation value is
|
|
// purely heuristic to avoid successive allocations.
|
|
var relevantGdChildrenCnt map[uint64]int
|
|
var uniqueChildrenList []*sym.Expression
|
|
|
|
for _, p := range parents {
|
|
|
|
prod, ok := p.Operator.(sym.Product)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
for i, c := range p.Children {
|
|
// If the exponent is zero, then the term does not actually
|
|
// contribute in the expression.
|
|
if prod.Exponents[i] == 0 {
|
|
continue
|
|
}
|
|
|
|
// If it's in the group, it does not count. We can't add it a second
|
|
// time.
|
|
if _, ok := childrenSet[c.ESHash[0]]; ok {
|
|
continue
|
|
}
|
|
|
|
if relevantGdChildrenCnt == nil {
|
|
relevantGdChildrenCnt = make(map[uint64]int, len(parents)+2)
|
|
uniqueChildrenList = make([]*sym.Expression, 0, len(parents)+2)
|
|
}
|
|
|
|
if _, ok := relevantGdChildrenCnt[c.ESHash[0]]; !ok {
|
|
relevantGdChildrenCnt[c.ESHash[0]] = 0
|
|
uniqueChildrenList = append(uniqueChildrenList, c)
|
|
}
|
|
|
|
relevantGdChildrenCnt[c.ESHash[0]]++
|
|
}
|
|
}
|
|
|
|
sort.SliceStable(uniqueChildrenList, func(i, j int) bool {
|
|
x := uniqueChildrenList[i].ESHash
|
|
y := uniqueChildrenList[j].ESHash
|
|
// We want to a decreasing order
|
|
return relevantGdChildrenCnt[x[0]] > relevantGdChildrenCnt[y[0]]
|
|
})
|
|
|
|
return uniqueChildrenList
|
|
}
|
|
|
|
// findGdChildrenGroup finds a large set of grandchildren including c that are
|
|
// grandchildren of expr such that they are as big as possible and share more
|
|
// than one parent. The finding is based on a greedy algorithm. We iteratively
|
|
// add nodes in the group so that the number of common parents decreases as
|
|
// slowly as possible.
|
|
func findGdChildrenGroup(expr *sym.Expression) map[uint64]*sym.Expression {
|
|
|
|
curParents := expr.Children
|
|
childrenSet := map[uint64]*sym.Expression{}
|
|
|
|
ranked := rankChildren(curParents, childrenSet)
|
|
|
|
// Can happen when we have a lincomb of lincomb. Ideally they should be
|
|
// merged during canonization.
|
|
if len(ranked) == 0 {
|
|
return childrenSet
|
|
}
|
|
|
|
for i := range ranked {
|
|
|
|
best := ranked[i]
|
|
childrenSet[best.ESHash[0]] = best
|
|
curParents = filterParentsWithChildren(curParents, best.ESHash)
|
|
|
|
// Can't grow the set anymore
|
|
if len(curParents) <= 1 {
|
|
delete(childrenSet, best.ESHash[0])
|
|
return childrenSet
|
|
}
|
|
|
|
logrus.Tracef(
|
|
"find groups, so far we have %v parents and %v siblings",
|
|
len(curParents), len(childrenSet))
|
|
}
|
|
|
|
return childrenSet
|
|
}
|
|
|
|
// filterParentsWithChildren returns a filtered list of parents who have at
|
|
// least one child with the given ESHash. The function allocates a new list
|
|
// of parents and returns it without mutating he original list.
|
|
func filterParentsWithChildren(
|
|
parents []*sym.Expression,
|
|
childEsh field.Element,
|
|
) []*sym.Expression {
|
|
|
|
res := make([]*sym.Expression, 0, len(parents))
|
|
for _, p := range parents {
|
|
for _, c := range p.Children {
|
|
if c.ESHash == childEsh {
|
|
res = append(res, p)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
// factorLinCompFromGroup rebuilds lincomb by factoring it using `group` to
|
|
// determine the best common factor.
|
|
func factorLinCompFromGroup(
|
|
lincom *sym.Expression,
|
|
group map[uint64]*sym.Expression,
|
|
) *sym.Expression {
|
|
|
|
var (
|
|
|
|
// numTerms indicates the number of children in the linear-combination
|
|
numTerms = len(lincom.Children)
|
|
|
|
lcCoeffs = lincom.Operator.(sym.LinComb).Coeffs
|
|
// Build the common term by taking the max of the exponents
|
|
exponentsOfGroup, groupExpr = optimRegroupExponents(lincom.Children, group)
|
|
|
|
// Separate the non-factored terms
|
|
nonFactoredTerms = make([]*sym.Expression, 0, numTerms)
|
|
nonFactoredCoeffs = make([]int, 0, numTerms)
|
|
|
|
// The factored terms of the linear combination divided by the common
|
|
// group factor
|
|
factoredTerms = make([]*sym.Expression, 0, numTerms)
|
|
factoredCoeffs = make([]int, 0, numTerms)
|
|
)
|
|
|
|
numFactors := 0
|
|
for i, p := range lincom.Children {
|
|
factored, ok := isFactored(p, exponentsOfGroup)
|
|
if ok {
|
|
numFactors++
|
|
factoredTerms = append(factoredTerms, factored)
|
|
factoredCoeffs = append(factoredCoeffs, lcCoeffs[i])
|
|
} else {
|
|
nonFactoredTerms = append(nonFactoredTerms, p)
|
|
nonFactoredCoeffs = append(nonFactoredCoeffs, lcCoeffs[i])
|
|
}
|
|
}
|
|
|
|
logrus.Tracef("found %v factors for the group of size %v", numFactors, len(group))
|
|
|
|
// Could not factor anything
|
|
if numFactors == 0 {
|
|
return lincom
|
|
}
|
|
|
|
factoredExpr := sym.NewLinComb(factoredTerms, factoredCoeffs)
|
|
res := sym.Mul(factoredExpr, groupExpr)
|
|
|
|
// This is a conditional because it might be that the linear combination is
|
|
// fully factorized by the found factor.
|
|
if len(nonFactoredTerms) > 0 {
|
|
nonFactoredExpr := sym.NewLinComb(nonFactoredTerms, nonFactoredCoeffs)
|
|
res = sym.Add(res, nonFactoredExpr)
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
// Returns true if the product is factored by the given group. The current
|
|
// expression must be canonical.
|
|
//
|
|
// Assumption that the expression is canonical and that the exponent is
|
|
// not contained more than once. If the expression contains duplicates
|
|
// this will not be found.
|
|
//
|
|
// Fortunately, this is guaranteed if the expression was constructed via
|
|
// [sym.NewLinComb] or [sym.NewProduct] which is almost mandatory.
|
|
func isFactored(e *sym.Expression, exponentsOfGroup map[uint64]int) (
|
|
factored *sym.Expression,
|
|
success bool,
|
|
) {
|
|
|
|
op, isProduct := e.Operator.(sym.Product)
|
|
if !isProduct {
|
|
return nil, false
|
|
}
|
|
|
|
exponents := op.Exponents
|
|
factoredExponents := append([]int{}, exponents...)
|
|
|
|
numMatches := 0
|
|
for i, c := range e.Children {
|
|
eig, found := exponentsOfGroup[c.ESHash[0]]
|
|
if !found {
|
|
continue
|
|
}
|
|
|
|
if eig > exponents[i] {
|
|
return nil, false
|
|
}
|
|
|
|
numMatches++
|
|
factoredExponents[i] -= eig
|
|
}
|
|
|
|
if numMatches != len(exponentsOfGroup) {
|
|
return nil, false
|
|
}
|
|
|
|
return sym.NewProduct(e.Children, factoredExponents), true
|
|
}
|
|
|
|
// optimRegroupExponents returns an expression maximizing the exponents of an
|
|
// other expression. Panics if one of the parent is not a product or does not
|
|
// have the whole group as children.
|
|
func optimRegroupExponents(
|
|
parents []*sym.Expression,
|
|
group map[uint64]*sym.Expression,
|
|
) (
|
|
exponentMap map[uint64]int,
|
|
groupedTerm *sym.Expression,
|
|
) {
|
|
|
|
exponentMap = make(map[uint64]int, 16)
|
|
canonTermList := make([]*sym.Expression, 0, 16) // built in deterministic order
|
|
|
|
for _, p := range parents {
|
|
|
|
op, isProd := p.Operator.(sym.Product)
|
|
if !isProd {
|
|
continue
|
|
}
|
|
exponents := op.Exponents
|
|
|
|
// Used to sanity-check that all the nodes of the group have been
|
|
// reached through this parent.
|
|
matched := make(map[uint64]int, len(p.Children))
|
|
|
|
for i, c := range p.Children {
|
|
if _, ingroup := group[c.ESHash[0]]; !ingroup {
|
|
continue
|
|
}
|
|
|
|
if exponents[i] == 0 {
|
|
panic("The expression is not canonic")
|
|
}
|
|
|
|
_, initialized := exponentMap[c.ESHash[0]]
|
|
if !initialized {
|
|
// Max int is used as a placeholder. It will be replaced anytime
|
|
// we wall utils.Min(exponentMap[h], n) where n is actually an
|
|
// exponent.
|
|
exponentMap[c.ESHash[0]] = math.MaxInt
|
|
canonTermList = append(canonTermList, c)
|
|
}
|
|
|
|
matched[c.ESHash[0]] = exponents[i]
|
|
}
|
|
|
|
if len(matched) != len(group) {
|
|
continue
|
|
}
|
|
|
|
for esh, ex := range matched {
|
|
// Recall that the values of the exponent maps are initialized to
|
|
// MaxInt. So this will always pass ex the first time this loc is
|
|
// reached for esh.
|
|
exponentMap[esh] = utils.Min(ex, exponentMap[esh])
|
|
}
|
|
}
|
|
|
|
canonExponents := []int{}
|
|
for _, e := range canonTermList {
|
|
canonExponents = append(canonExponents, exponentMap[e.ESHash[0]])
|
|
}
|
|
|
|
return exponentMap, sym.NewProduct(canonTermList, canonExponents)
|
|
}
|