Prover: optimize memory allocations in symbolic expression (#3910)

* feat(pool): implements a multi-level caching system for the pool
* fix(sym): make sure that all the nodes are freed
* feat(pool): implements a DebugPool to debug leakage to gc
* feat(parallel): implements a channel based parallelization template
* test(pool): ensure that symbolic evaluation does not leak to gc
* fixup(par): adds a wg to the workload worker
* create a prover-checker that can be used for checking
* feat(check-only): adds a check-only mode in the config
* test(smartvectors): update the tests
* chore(makefile): rm corset flags from the commands that no longer need them
* remove the change on config-benchmark
* perf(pool): Use the pool for regular massive FFT workload
This commit is contained in:
AlexandreBelling
2024-09-10 20:58:55 +02:00
committed by GitHub
parent f92f743849
commit d079e98549
27 changed files with 766 additions and 244 deletions

View File

@@ -134,14 +134,14 @@ lib/compressor-and-shnarf-calculator-local: lib/compressor lib/shnarf_calculator
##
## Run all the unit-tests
##
test: zkevm/arithmetization/zkevm.bin
$(CORSET_FLAGS) go test ./...
test:
go test ./...
##
## Run the CI linting
##
ci-lint: zkevm/arithmetization/zkevm.bin
$(CORSET_FLAGS) golangci-lint run --timeout 5m
golangci-lint run --timeout 5m
##
## Echo, the CGO flags. Useful for testing manually

View File

@@ -0,0 +1,146 @@
package mempool
import (
"errors"
"fmt"
"runtime"
"strconv"
"unsafe"
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/utils"
)
// DebuggeableCall wraps a [MemPool] and records every Alloc/Free call along
// with the caller's file:line position. The recorded history can then be
// analyzed with [DebuggeableCall.Errors] to detect leaked vectors,
// double-frees and frees of vectors foreign to the pool.
//
// NOTE(review): the Logs map is read and written without synchronization;
// this looks like it assumes single-goroutine use — confirm with callers.
type DebuggeableCall struct {
	// Parent is the underlying pool actually serving the allocations.
	Parent MemPool
	// Logs maps the address of each allocated slice header to the ordered
	// list of alloc/free records observed for it.
	Logs map[uintptr]*[]Record
}

// NewDebugPool wraps p into a [DebuggeableCall] with an empty record history.
func NewDebugPool(p MemPool) *DebuggeableCall {
	return &DebuggeableCall{
		Parent: p,
		Logs: make(map[uintptr]*[]Record),
	}
}

// Record is a single alloc or free event, tagged with the file:line of the
// call-site that triggered it.
type Record struct {
	Where string
	What recordType
}

// Prewarm forwards the prewarming to the parent pool and returns the wrapper
// itself so that calls can be chained.
func (m *DebuggeableCall) Prewarm(nbPrewarm int) MemPool {
	m.Parent.Prewarm(nbPrewarm)
	return m
}

// recordType discriminates between allocation and free events.
type recordType string

const (
	AllocRecord recordType = "alloc"
	FreeRecord recordType = "free"
)
// record appends a Record entry (alloc or free) for the vector pointed to by
// v, tagging it with the position of the frame `skip` levels up the stack so
// that the recorded location is the pool user's call site. This factors out
// the logic that was previously duplicated between Alloc and Free.
//
// NOTE(review): the Logs map is mutated without locking; assumes
// single-goroutine use — confirm with callers.
func (m *DebuggeableCall) record(v *[]field.Element, what recordType, skip int) {
	var (
		uptr             = uintptr(unsafe.Pointer(v))
		_, file, line, _ = runtime.Caller(skip)
	)

	logs, found := m.Logs[uptr]
	if !found {
		logs = &[]Record{}
		m.Logs[uptr] = logs
	}

	*logs = append(*logs, Record{
		Where: file + ":" + strconv.Itoa(line),
		What:  what,
	})
}

// Alloc returns a vector taken from the parent pool and logs the allocation
// with the caller's location so that leaks can later be reported by
// [DebuggeableCall.Errors].
func (m *DebuggeableCall) Alloc() *[]field.Element {
	v := m.Parent.Alloc()
	// skip=3: Caller(0)=record, 1=Alloc, 2=Alloc's caller, 3=two frames up.
	// The original code used runtime.Caller(2) from within Alloc directly;
	// the extra helper frame bumps the skip by one to keep the same target.
	m.record(v, AllocRecord, 3)
	return v
}

// Free logs the free operation with the caller's location and forwards the
// vector to the parent pool, returning the parent's error verbatim.
func (m *DebuggeableCall) Free(v *[]field.Element) error {
	m.record(v, FreeRecord, 3)
	return m.Parent.Free(v)
}
// Size returns the size of the vectors served by the parent pool.
func (m *DebuggeableCall) Size() int {
	return m.Parent.Size()
}

// TearDown releases the parent's cached memory when the parent is a
// [SliceArena]; for any other parent type this is a no-op.
//
// NOTE(review): parents other than *SliceArena that also hold memory are
// silently ignored here — confirm this is intended.
func (m *DebuggeableCall) TearDown() {
	if p, ok := m.Parent.(*SliceArena); ok {
		p.TearDown()
	}
}
// Errors scans the recorded alloc/free history of every vector and returns a
// joined error describing all the anomalies found:
//   - a history starting with a free: the vector never came from the pool
//   - a history ending with an alloc: the vector leaked out of the pool
//   - two consecutive allocs (resp. frees): the vector was allocated
//     (resp. freed) multiple times without the matching opposite event
//
// It returns nil when no anomaly is found.
func (m *DebuggeableCall) Errors() error {

	var err error

	for _, logs_ := range m.Logs {

		// Entries are only ever created right before a record is appended,
		// so an empty entry indicates a bug in the debug pool itself.
		if logs_ == nil || len(*logs_) == 0 {
			utils.Panic("got a nil entry")
		}

		logs := *logs_

		for i := range logs {

			if i == 0 && logs[i].What == FreeRecord {
				err = errors.Join(err, fmt.Errorf("freed a vector that was not from the pool: where=%v", logs[i].Where))
			}

			if i == len(logs)-1 && logs[i].What == AllocRecord {
				err = errors.Join(err, fmt.Errorf("leaked a vector out of the pool: where=%v", logs[i].Where))
			}

			// The remaining checks compare logs[i] against logs[i-1]
			if i == 0 {
				continue
			}

			if logs[i-1].What == AllocRecord && logs[i].What == AllocRecord {
				// Collect the whole run of consecutive allocs for reporting
				wheres := []string{logs[i-1].Where, logs[i].Where}
				for k := i + 1; k < len(logs) && logs[k].What == AllocRecord; k++ {
					wheres = append(wheres, logs[k].Where)
				}
				err = errors.Join(err, fmt.Errorf("vector was allocated multiple times concurrently where=%v", wheres))
			}

			if logs[i-1].What == FreeRecord && logs[i].What == FreeRecord {
				// Collect the whole run of consecutive frees for reporting
				wheres := []string{logs[i-1].Where, logs[i].Where}
				for k := i + 1; k < len(logs) && logs[k].What == FreeRecord; k++ {
					wheres = append(wheres, logs[k].Where)
				}
				err = errors.Join(err, fmt.Errorf("vector was freed multiple times concurrently where=%v", wheres))
			}
		}
	}

	return err
}

View File

@@ -0,0 +1,52 @@
package mempool
import (
"strings"
"testing"
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/stretchr/testify/assert"
)
// TestDebugPool exercises the three anomaly detections of the debug pool:
// leaked vectors, double frees and frees of vectors foreign to the pool.
// Each subtest only checks the prefix of the first joined error message.
func TestDebugPool(t *testing.T) {

	t.Run("leak-detection", func(t *testing.T) {

		pool := NewDebugPool(CreateFromSyncPool(32))

		// Allocate vectors and never return them to the pool
		for i := 0; i < 16; i++ {
			func() {
				_ = pool.Alloc()
			}()
		}

		err := pool.Errors().Error()
		assert.True(t, strings.HasPrefix(err, "leaked a vector out of the pool"))
	})

	t.Run("double-free", func(t *testing.T) {

		pool := NewDebugPool(CreateFromSyncPool(32))
		v := pool.Alloc()

		// Return the same vector to the pool several times
		for i := 0; i < 16; i++ {
			pool.Free(v)
		}

		err := pool.Errors().Error()
		assert.Truef(t, strings.HasPrefix(err, "vector was freed multiple times concurrently"), err)
	})

	t.Run("foreign-free", func(t *testing.T) {

		pool := NewDebugPool(CreateFromSyncPool(32))

		// The vector was allocated outside of the pool
		v := make([]field.Element, 32)
		pool.Free(&v)

		err := pool.Errors().Error()
		assert.Truef(t, strings.HasPrefix(err, "freed a vector that was not from the pool"), err)
	})
}

View File

@@ -0,0 +1,73 @@
package mempool
import (
"sync"
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/consensys/zkevm-monorepo/prover/utils/parallel"
)
// FromSyncPool pools the allocation for slices of [field.Element] of size `Size`.
// It should be used with great caution and every slice allocated via this pool
// must be manually freed and only once.
//
// FromSyncPool is used to reduce the number of allocation which can be significant
// when doing operations over field elements.
type FromSyncPool struct {
	// size is the fixed length of every slice served by the pool.
	size int
	// P is the underlying sync.Pool holding the free slices.
	P sync.Pool
}

// CreateFromSyncPool initializes the Pool with the given number of elements in it.
func CreateFromSyncPool(size int) *FromSyncPool {
	// Initializes the pool; New lazily allocates a fresh slice of the
	// configured size whenever the pool runs empty.
	return &FromSyncPool{
		size: size,
		P: sync.Pool{
			New: func() any {
				res := make([]field.Element, size)
				return &res
			},
		},
	}
}
// Prewarm the Pool by preallocating `nbPrewarm` in it.
func (p *FromSyncPool) Prewarm(nbPrewarm int) MemPool {
	// A single large backing array is allocated and sliced into nbPrewarm
	// disjoint chunks, each pushed into the pool: one allocation instead of
	// nbPrewarm.
	prewarmed := make([]field.Element, p.size*nbPrewarm)
	parallel.Execute(nbPrewarm, func(start, stop int) {
		for i := start; i < stop; i++ {
			vec := prewarmed[i*p.size : (i+1)*p.size]
			p.P.Put(&vec)
		}
	})
	return p
}
// Alloc returns a vector taken from the pool. Vectors obtained this way
// should ideally be handed back via Free; if they are not, the GC will
// still reclaim them eventually.
func (p *FromSyncPool) Alloc() *[]field.Element {
	return p.P.Get().(*[]field.Element)
}
// Free returns an object to the pool. It must never be called twice over
// the same object or undefined behaviours are going to arise. It is fine to
// pass objects allocated to outside of the pool as long as they have the right
// dimension.
func (p *FromSyncPool) Free(vec *[]field.Element) error {
	// Sanity-check the vector has the right size. The previous message read
	// "expected size %v, expected %v", mislabelling the actual size as
	// "expected"; it now reports actual vs expected.
	if len(*vec) != p.size {
		utils.Panic("got a vector of size %v, expected %v", len(*vec), p.Size())
	}
	p.P.Put(vec)
	return nil
}
// Size returns the fixed length of the slices served by the pool.
func (p *FromSyncPool) Size() int {
	return p.size
}

View File

@@ -0,0 +1,60 @@
package mempool
import (
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/utils"
)
// MemPool is the common interface of the memory pools used to recycle slices
// of [field.Element] of a fixed size.
type MemPool interface {
	// Prewarm preallocates nbPrewarm slices into the pool and returns the pool.
	Prewarm(nbPrewarm int) MemPool
	// Alloc returns a slice of length Size(), possibly recycled.
	Alloc() *[]field.Element
	// Free hands a slice back to the pool for reuse.
	Free(vec *[]field.Element) error
	// Size returns the fixed length of the pooled slices.
	Size() int
}
// ExtractCheckOptionalStrict returns
//   - p[0], true if the expectedSize matches the one of the provided pool
//   - nil, false if no `p` is provided or if `p[0]` is nil
//   - panic if the assigned size of the pool does not match
//
// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an
// optional variadic parameter and at the same time validating that the pool
// object has the right size.
func ExtractCheckOptionalStrict(expectedSize int, p ...MemPool) (pool MemPool, ok bool) {
	// Checks if there is a pool; a nil entry is treated as "no pool"
	hasPool := len(p) > 0 && p[0] != nil
	if hasPool {
		pool = p[0]
	}

	// Sanity-check that the size of the pool is actually what we expected
	if hasPool && pool.Size() != expectedSize {
		utils.Panic("pooled vector size are %v, but required %v", pool.Size(), expectedSize)
	}

	return pool, hasPool
}
// ExtractCheckOptionalSoft returns
//   - p[0], true if the expectedSize matches the one of the provided pool
//   - nil, false if no `p` is provided
//   - nil, false if the length of the vector does not match the one of the pool
//   - panic if the caller provides `nil` as argument for `p`
//
// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an
// optional variadic parameter.
//
// NOTE(review): unlike [ExtractCheckOptionalStrict], a nil p[0] is not
// filtered out here, so pool.Size() panics on it — confirm the asymmetry
// between the two variants is intended.
func ExtractCheckOptionalSoft(expectedSize int, p ...MemPool) (pool MemPool, ok bool) {
	// Checks if there is a pool
	hasPool := len(p) > 0
	if hasPool {
		pool = p[0]
	}

	// Sanity-check that the size of the pool is actually what we expected
	if hasPool && pool.Size() != expectedSize {
		return nil, false
	}

	return pool, hasPool
}

View File

@@ -0,0 +1,51 @@
package mempool
import (
"github.com/consensys/zkevm-monorepo/prover/maths/field"
)
// SliceArena is a simple not-threadsafe arena implementation that uses a
// mempool to carry its allocation. It will only put back free memory in the
// the parent pool when TearDown is called.
type SliceArena struct {
	// frees is the LIFO cache of slices returned to the arena and not yet
	// handed back to the parent.
	frees []*[]field.Element
	// parent serves the allocations whenever the local cache is empty.
	parent MemPool
}

// WrapsWithMemCache wraps pool into a [SliceArena] with an empty local cache.
func WrapsWithMemCache(pool MemPool) *SliceArena {
	return &SliceArena{
		frees: make([]*[]field.Element, 0, 1<<7),
		parent: pool,
	}
}
// Prewarm forwards the prewarming to the parent pool and returns the arena
// itself so that calls can be chained.
func (m *SliceArena) Prewarm(nbPrewarm int) MemPool {
	m.parent.Prewarm(nbPrewarm)
	return m
}
// Alloc pops the most recently freed slice from the arena's cache when one
// is available and falls back to the parent pool otherwise.
func (m *SliceArena) Alloc() *[]field.Element {
	if n := len(m.frees); n > 0 {
		res := m.frees[n-1]
		m.frees = m.frees[:n-1]
		return res
	}
	return m.parent.Alloc()
}
// Free caches v in the arena for later reuse; nothing is given back to the
// parent pool until TearDown is called. It always returns nil.
func (m *SliceArena) Free(v *[]field.Element) error {
	m.frees = append(m.frees, v)
	return nil
}

// Size returns the slice length served by the parent pool.
func (m *SliceArena) Size() int {
	return m.parent.Size()
}
// TearDown returns every cached slice to the parent pool and empties the
// cache. Clearing the cache makes TearDown idempotent: the original kept the
// pointers around, so a second call would have double-freed every cached
// slice into the parent.
func (m *SliceArena) TearDown() {
	for i := range m.frees {
		// The only error path of the known parents is a panic on size
		// mismatch; the error value is deliberately dropped.
		_ = m.parent.Free(m.frees[i])
	}
	m.frees = m.frees[:0]
}

View File

@@ -1,116 +0,0 @@
package mempool
import (
"sync"
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/consensys/zkevm-monorepo/prover/utils/parallel"
)
// Pool pools the allocation for slices of [field.Element] of size `Size`. It
// should be used with great caution and every slice allocated via this pool
// must be manually freed and only once.
//
// Pool is used to reduce the number of allocation which can be significant
// when doing operations over field elements.
type Pool struct {
	// Size is the fixed length of every slice served by the pool.
	Size int
	// P is the underlying sync.Pool holding the free slices.
	P sync.Pool
}

// Create initializes the Pool with the given number of elements in it.
func Create(size int) *Pool {
	// Initializes the pool; New lazily allocates a fresh slice of the
	// configured size whenever the pool runs empty.
	return &Pool{
		Size: size,
		P: sync.Pool{
			New: func() any {
				res := make([]field.Element, size)
				return &res
			},
		},
	}
}
// Prewarm the Pool by preallocating `nbPrewarm` in it.
func (p *Pool) Prewarm(nbPrewarm int) *Pool {
	// One backing array is allocated and sliced into nbPrewarm disjoint
	// chunks, each pushed into the pool: one allocation instead of nbPrewarm.
	prewarmed := make([]field.Element, p.Size*nbPrewarm)
	parallel.Execute(nbPrewarm, func(start, stop int) {
		for i := start; i < stop; i++ {
			vec := prewarmed[i*p.Size : (i+1)*p.Size]
			p.P.Put(&vec)
		}
	})
	return p
}

// Alloc returns a vector allocated from the pool. Vector allocated via the
// pool should ideally be returned to the pool. If not, they are still going to
// be picked up by the GC.
func (p *Pool) Alloc() *[]field.Element {
	res := p.P.Get().(*[]field.Element)
	return res
}
// Free returns an object to the pool. It must never be called twice over
// the same object or undefined behaviours are going to arise. It is fine to
// pass objects allocated to outside of the pool as long as they have the right
// dimension.
func (p *Pool) Free(vec *[]field.Element) error {
	// Sanity-check the vector has the right size. The previous message read
	// "expected size %v, expected %v", mislabelling the actual size as
	// "expected"; it now reports actual vs expected.
	if len(*vec) != p.Size {
		utils.Panic("got a vector of size %v, expected %v", len(*vec), p.Size)
	}
	p.P.Put(vec)
	return nil
}
// ExtractCheckOptionalStrict returns
//   - p[0], true if the expectedSize matches the one of the provided pool
//   - nil, false if no `p` is provided or if `p[0]` is nil
//   - panic if the assigned size of the pool does not match
//
// This is used to unwrap a [Pool] that is commonly passed to functions as an
// optional variadic parameter and at the same time validating that the pool
// object has the right size.
func ExtractCheckOptionalStrict(expectedSize int, p ...*Pool) (pool *Pool, ok bool) {
	// Checks if there is a pool; a nil entry is treated as "no pool"
	hasPool := len(p) > 0 && p[0] != nil
	if hasPool {
		pool = p[0]
	}

	// Sanity-check that the size of the pool is actually what we expected
	if hasPool && pool.Size != expectedSize {
		utils.Panic("pooled vector size are %v, but required %v", pool.Size, expectedSize)
	}

	return pool, hasPool
}
// ExtractCheckOptionalSoft returns
//   - p[0], true if the expectedSize matches the one of the provided pool
//   - nil, false if no `p` is provided
//   - nil, false if the length of the vector does not match the one of the pool
//   - panic if the caller provides `nil` as argument for `p`
//
// This is used to unwrap a [Pool] that is commonly passed to functions as an
// optional variadic parameter.
//
// NOTE(review): unlike the strict variant, a nil p[0] is not filtered out
// here, so pool.Size dereferences a nil pointer and panics — confirm the
// asymmetry is intended.
func ExtractCheckOptionalSoft(expectedSize int, p ...*Pool) (pool *Pool, ok bool) {
	// Checks if there is a pool
	hasPool := len(p) > 0
	if hasPool {
		pool = p[0]
	}

	// Sanity-check that the size of the pool is actually what we expected
	if hasPool && pool.Size != expectedSize {
		return nil, false
	}

	return pool, hasPool
}

View File

@@ -65,7 +65,7 @@ func InnerProduct(a, b SmartVector) field.Element {
// result = vecs[0] + vecs[1] * x + vecs[2] * x^2 + vecs[3] * x^3 + ...
//
// where `x` is a scalar and `vecs[i]` are [SmartVector]
func PolyEval(vecs []SmartVector, x field.Element, p ...*mempool.Pool) (result SmartVector) {
func PolyEval(vecs []SmartVector, x field.Element, p ...mempool.MemPool) (result SmartVector) {
if len(vecs) == 0 {
panic("no input vectors")
@@ -80,10 +80,11 @@ func PolyEval(vecs []SmartVector, x field.Element, p ...*mempool.Pool) (result S
resReg = make([]field.Element, length)
tmpVec = make([]field.Element, length)
} else {
a, b := pool.Alloc(), pool.Alloc()
resReg, tmpVec = *a, *b
a := AllocFromPool(pool)
b := AllocFromPool(pool)
resReg, tmpVec = a.Regular, b.Regular
vector.Fill(resReg, field.Zero())
defer pool.Free(b)
defer b.Free(pool)
}
var tmpF, resCon field.Element

View File

@@ -11,7 +11,7 @@ import (
// - The function panics if svecs is empty
// - The function panics if the length of coeffs does not match the length of
// svecs
func LinComb(coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
func LinComb(coeffs []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
// Sanity check : all svec should have the same length
length := svecs[0].Len()
for i := 0; i < len(svecs); i++ {
@@ -27,7 +27,7 @@ func LinComb(coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector
// - The function panics if svecs is empty
// - The function panics if the length of exponents does not match the length of
// svecs
func Product(exponents []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
func Product(exponents []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
return processOperator(productOp{}, exponents, svecs, p...)
}
@@ -36,7 +36,7 @@ func Product(exponents []int, svecs []SmartVector, p ...*mempool.Pool) SmartVect
// - The function panics if svecs is empty
// - The function panics if the length of coeffs does not match the length of
// svecs
func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
// There should be as many coeffs than there are vectors
if len(coeffs) != len(svecs) {
@@ -114,13 +114,13 @@ func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempo
case matchedRegular+matchedConst == totalToMatch:
// In this case, there are no windowed in the list. This means we only
// need to merge the const one into the regular one before returning
op.constTermIntoVec(*regularRes, &constRes.val)
op.constTermIntoVec(regularRes.Regular, &constRes.val)
return regularRes
default:
// If windowRes is a regular (can happen if all windows arguments cover the full circle)
if w, ok := windowRes.(*Regular); ok {
op.vecTermIntoVec(*regularRes, *w)
op.vecTermIntoVec(regularRes.Regular, *w)
return regularRes
}
@@ -130,7 +130,7 @@ func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempo
// In this case, the constant is already accumulated into the windowed.
// Thus, we just have to merge the windowed one into the regular one.
interval := windowRes.interval()
regvec := *regularRes
regvec := regularRes.Regular
length := len(regvec)
// The windows rolls over

View File

@@ -85,7 +85,7 @@ func TestFuzzProductWithPool(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
t.Logf("TEST CASE %v\n", tcase.String())
@@ -112,7 +112,7 @@ func TestFuzzProductWithPoolCompare(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
t.Logf("TEST CASE %v\n", tcase.String())
@@ -142,7 +142,7 @@ func TestFuzzLinCombWithPool(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
t.Logf("TEST CASE %v\n", tcase.String())
@@ -159,7 +159,6 @@ func TestFuzzLinCombWithPool(t *testing.T) {
t.FailNow()
}
}
}
func TestFuzzLinCombWithPoolCompare(t *testing.T) {
@@ -169,7 +168,7 @@ func TestFuzzLinCombWithPoolCompare(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
t.Logf("TEST CASE %v\n", tcase.String())
@@ -237,8 +236,8 @@ func TestOpBasicEdgeCases(t *testing.T) {
for i, testCase := range testCases {
t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) {
t.Logf("test-case details: %v", testCase.explainer)
res := testCase.fn(testCase.inputs...)
require.Equal(t, testCase.expectedRes, res, "expectedRes=%v\nres=%v", testCase.expectedRes.Pretty(), res.Pretty())
res := testCase.fn(testCase.inputs...).(*Pooled)
require.Equal(t, testCase.expectedRes, &res.Regular, "expectedRes=%v\nres=%v", testCase.expectedRes.Pretty(), res.Pretty())
})
}
}
@@ -290,7 +289,7 @@ func TestFuzzPolyEvalWithPool(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
// PolyEval() with pool
polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool)
@@ -315,7 +314,7 @@ func TestFuzzPolyEvalWithPoolCompare(t *testing.T) {
success := t.Run(tcase.name, func(t *testing.T) {
pool := mempool.Create(tcase.svecs[0].Len())
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
// PolyEval() with pool
polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool)

View File

@@ -77,13 +77,13 @@ func (r *Regular) Pretty() string {
return fmt.Sprintf("Regular[%v]", vector.Prettify(*r))
}
func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*mempool.Pool) (result *Regular, numMatches int) {
func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...mempool.MemPool) (result *Pooled, numMatches int) {
length := svecs[0].Len()
pool, hasPool := mempool.ExtractCheckOptionalStrict(length, p...)
var resvec []field.Element
var resvec *Pooled
isFirst := true
numMatches = 0
@@ -97,6 +97,10 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
svec = rotatedAsRegular(rot)
}
if pooled, ok := svec.(*Pooled); ok {
svec = &pooled.Regular
}
if reg, ok := svec.(*Regular); ok {
numMatches++
// For the first one, we can save by just copying the result
@@ -104,16 +108,17 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
// zero.
if isFirst {
if hasPool {
resvec = *pool.Alloc()
resvec = AllocFromPool(pool)
} else {
resvec = make([]field.Element, length)
resvec = &Pooled{Regular: make([]field.Element, length)}
}
isFirst = false
op.vecIntoTerm(resvec, *reg, coeffs[i])
op.vecIntoTerm(resvec.Regular, *reg, coeffs[i])
continue
}
op.vecIntoVec(resvec, *reg, coeffs[i])
op.vecIntoVec(resvec.Regular, *reg, coeffs[i])
}
}
@@ -121,8 +126,7 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
return nil, 0
}
res := Regular(resvec)
return &res, numMatches
return resvec, numMatches
}
func (r *Regular) DeepCopy() SmartVector {
@@ -134,3 +138,24 @@ func (r *Regular) DeepCopy() SmartVector {
func (r *Regular) IntoRegVecSaveAlloc() []field.Element {
return (*r)[:]
}
type Pooled struct {
Regular
poolPtr *[]field.Element
}
func AllocFromPool(pool mempool.MemPool) *Pooled {
poolPtr := pool.Alloc()
return &Pooled{
Regular: *poolPtr,
poolPtr: poolPtr,
}
}
func (p *Pooled) Free(pool mempool.MemPool) {
if p.poolPtr != nil {
pool.Free(p.poolPtr)
}
p.poolPtr = nil
p.Regular = nil
}

View File

@@ -204,7 +204,8 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
stopTimer = profiling.LogTimer("Computing the coeffs %v pols of size %v", len(ctx.AllInvolvedColumns), ctx.DomainSize)
lock = sync.Mutex{}
lockRun = sync.Mutex{}
pool = mempool.Create(symbolic.MaxChunkSize).Prewarm(runtime.NumCPU() * ctx.MaxNbExprNode)
pool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize).Prewarm(runtime.NumCPU() * ctx.MaxNbExprNode)
largePool = mempool.CreateFromSyncPool(ctx.DomainSize).Prewarm(len(ctx.AllInvolvedColumns))
)
if ctx.DomainSize >= GC_DOMAIN_SIZE {
@@ -236,7 +237,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
if !isNatural {
witness = pol.GetColAssignment(run)
}
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
lock.Lock()
coeffs[name] = witness
@@ -261,7 +262,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
// normal case for interleaved or repeated columns
witness := pol.GetColAssignment(run)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
name := pol.GetColID()
lock.Lock()
coeffs[name] = witness
@@ -356,8 +357,12 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
stopTimer := profiling.LogTimer("ReEvaluate %v pols of size %v on coset %v/%v", len(handles), ctx.DomainSize, share, ratio)
parallel.ExecuteChunky(len(roots), func(start, stop int) {
for k := start; k < stop; k++ {
parallel.ExecuteFromChan(len(roots), func(wg *sync.WaitGroup, indexChan chan int) {
localPool := mempool.WrapsWithMemCache(largePool)
defer localPool.TearDown()
for k := range indexChan {
root := roots[k]
name := root.GetColID()
@@ -365,18 +370,25 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
if found {
// it was already computed in a previous iteration of `j`
wg.Done()
continue
}
// else it's the first value of j that sees it. so we compute the
// coset reevaluation.
reevaledRoot := sv.FFT(coeffs[name], fft.DIT, false, ratio, share)
reevaledRoot := sv.FFT(coeffs[name], fft.DIT, false, ratio, share, localPool)
computedReeval.Store(name, reevaledRoot)
wg.Done()
}
})
parallel.ExecuteChunky(len(handles), func(start, stop int) {
for k := start; k < stop; k++ {
parallel.ExecuteFromChan(len(handles), func(wg *sync.WaitGroup, indexChan chan int) {
localPool := mempool.WrapsWithMemCache(largePool)
defer localPool.TearDown()
for k := range indexChan {
pol := handles[k]
// short-path, the column is a purely Shifted(Natural) or a Natural
@@ -397,6 +409,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
// Now, we can reuse a soft-rotation of the smart-vector to save memory
if !pol.IsComposite() {
// in this case, the right vector was the root so we are done
wg.Done()
continue
}
@@ -404,6 +417,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
polName := pol.GetColID()
res := sv.SoftRotate(reevaledRoot.(sv.SmartVector), shifted.Offset)
computedReeval.Store(polName, res)
wg.Done()
continue
}
@@ -412,14 +426,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
name := pol.GetColID()
_, ok := computedReeval.Load(name)
if ok {
wg.Done()
continue
}
if _, ok := coeffs[name]; !ok {
utils.Panic("handle %v not found in the coeffs\n", name)
}
res := sv.FFT(coeffs[name], fft.DIT, false, ratio, share)
res := sv.FFT(coeffs[name], fft.DIT, false, ratio, share, localPool)
computedReeval.Store(name, res)
wg.Done()
}
})
@@ -477,6 +495,11 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
// Forcefuly clean the memory for the computed reevals
computedReeval.Range(func(k, v interface{}) bool {
if pooled, ok := v.(*sv.Pooled); ok {
pooled.Free(largePool)
}
computedReeval.Delete(k)
return true
})

View File

@@ -36,7 +36,7 @@ func TestMPTS(t *testing.T) {
hLProver := func(assi *wizard.ProverRuntime) {
p1 := smartvectors.Rand(PolSize)
p2 := smartvectors.Rand(PolSize)
p2 := smartvectors.NewConstant(field.NewElement(42), PolSize)
assi.AssignColumn(P1, p1)
assi.AssignColumn(P2, p2)

View File

@@ -4,9 +4,11 @@ import (
"fmt"
"math/big"
"reflect"
"runtime"
"sync"
"github.com/consensys/gnark/frontend"
"github.com/consensys/zkevm-monorepo/prover/maths/common/mempool"
"github.com/consensys/zkevm-monorepo/prover/maths/common/poly"
sv "github.com/consensys/zkevm-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/zkevm-monorepo/prover/maths/fft"
@@ -109,6 +111,8 @@ type mptsCtx struct {
// maxSize int
quotientSize int
targetSize int
// maxSize is the size of the largest column being processed by the compiler
maxSize int
// Number of rounds in the original protocol
numRound int
// Various indentifiers created in the protocol
@@ -129,7 +133,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
xPoly := make(map[ifaces.ColID][]int)
hs := []ifaces.QueryID{}
polys := []ifaces.Column{}
maxDeg := 0
maxSize := 0
/*
Adding coins in the protocol can add extra rounds,
@@ -182,7 +186,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
if _, ok := xPoly[poly.GetColID()]; !ok {
polys = append(polys, poly)
xPoly[poly.GetColID()] = []int{}
maxDeg = utils.Max(maxDeg, poly.Size())
maxSize = utils.Max(maxSize, poly.Size())
}
xPoly[poly.GetColID()] = append(xPoly[poly.GetColID()], queryCount-1)
}
@@ -207,6 +211,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
quotientSize: quotientSize,
targetSize: targetSize,
numRound: numRound,
maxSize: maxSize,
QuotientName: deriveName[ifaces.ColID](comp, MPTS, "", MTSP_QUOTIENT_SUFFIX),
LinCombCoeff: deriveName[coin.Name](comp, MPTS, "", MTSP_LIN_COMB),
EvaluationPoint: deriveName[coin.Name](comp, MPTS, "", MTSP_RAND_EVAL),
@@ -226,31 +231,48 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
r := run.GetRandomCoinField(ctx.LinCombCoeff)
// Preallocate the value of the quotient
quotient := sv.AllocateRegular(ctx.quotientSize)
quoLock := sync.Mutex{}
ys, hs := ctx.getYsHs(run.GetUnivariateParams, run.GetUnivariateEval)
logrus.Tracef("computed Ys and Hs")
var (
quotient = sv.AllocateRegular(ctx.quotientSize)
quoLock = sync.Mutex{}
ys, hs = ctx.getYsHs(run.GetUnivariateParams, run.GetUnivariateEval)
// precompute the lagrange polynomials
lagranges := getLagrangesPolys(hs)
// precompute the lagrange polynomials
lagranges = getLagrangesPolys(hs)
pool = mempool.CreateFromSyncPool(ctx.targetSize).Prewarm(runtime.NumCPU())
mainWg = &sync.WaitGroup{}
)
parallel.Execute(len(ctx.polys), func(start, stop int) {
mainWg.Add(runtime.NumCPU())
maxSize := 0
for i := start; i < stop; i++ {
maxSize = utils.Max(maxSize, ctx.polys[i].Size())
}
parallel.ExecuteFromChan(len(ctx.polys), func(wg *sync.WaitGroup, indexChan chan int) {
// Preallocate the value of the quotient
subQuotient := sv.AllocateRegular(maxSize)
var (
pool = mempool.WrapsWithMemCache(pool)
for i := start; i < stop; i++ {
polHandle := ctx.polys[i]
/*
Get the coefficients of the witness polynomial
*/
polWitness := polHandle.GetColAssignment(run)
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0)
// Preallocate the value of the quotient
subQuotientReg = sv.AllocateRegular(ctx.maxSize)
subQuotientCnst = field.Zero()
)
defer pool.TearDown()
for i := range indexChan {
var (
polHandle = ctx.polys[i]
polWitness = polHandle.GetColAssignment(run)
ri field.Element
)
ri.Exp(r, big.NewInt(int64(i)))
if cnst, isCnst := polWitness.(*sv.Constant); isCnst {
polWitness = sv.NewRegular([]field.Element{cnst.Val()})
} else if pool.Size() == polWitness.Len() {
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0, pool)
} else {
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0, nil)
}
/*
Substract by the lagrange interpolator and get the quotient
@@ -298,32 +320,37 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
}
}
/*
Then accumulate the subQuotient. The subQuotient is supposedly
larger than the pol,
*/
var ri field.Element
ri.Exp(r, big.NewInt(int64(i)))
// Should be very uncommon
if subQuotient.Len() < polWitness.Len() {
if subQuotientReg.Len() < polWitness.Len() {
logrus.Warnf("Warning reallocation of the subquotient for MPTS. If there are too many it's an issue")
// It's the only known use-case for concatenating smart-vectors
newSubquotient := make([]field.Element, polWitness.Len())
subQuotient.WriteInSlice(newSubquotient[:subQuotient.Len()])
subQuotient = sv.NewRegular(newSubquotient)
subQuotientReg.WriteInSlice(newSubquotient[:subQuotientReg.Len()])
subQuotientReg = sv.NewRegular(newSubquotient)
}
tmp := sv.ScalarMul(polWitness, ri)
subQuotient = sv.PolyAdd(tmp, subQuotient)
subQuotientReg = sv.PolyAdd(tmp, subQuotientReg)
if pooled, ok := polWitness.(*sv.Pooled); ok {
pooled.Free(pool)
}
wg.Done()
}
subQuotientReg = sv.PolyAdd(subQuotientReg, sv.NewConstant(subQuotientCnst, 1))
// This locking mechanism is completely subOptimal, but this should be good enough
quoLock.Lock()
quotient = sv.PolyAdd(quotient, subQuotient)
quotient = sv.PolyAdd(quotient, subQuotientReg)
quoLock.Unlock()
mainWg.Done()
})
mainWg.Wait()
if quotient.Len() < ctx.targetSize {
quo := sv.IntoRegVec(quotient)
quotient = sv.RightZeroPadded(quo, ctx.targetSize)
@@ -332,7 +359,7 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
for i := range ctx.Quotients {
// each subquotient is a slice of the original larger quotient
subQuotient := quotient.SubVector(i*ctx.targetSize, (i+1)*ctx.targetSize)
subQuotient = sv.FFT(subQuotient, fft.DIF, true, 0, 0)
subQuotient = sv.FFT(subQuotient, fft.DIF, true, 0, 0, nil)
run.AssignColumn(ctx.Quotients[i].GetColID(), subQuotient)
}

View File

@@ -42,7 +42,7 @@ func CheckReedSolomon(comp *wizard.CompiledIOP, rate int, h ifaces.Column) {
//
comp.SubProvers.AppendToInner(round, func(assi *wizard.ProverRuntime) {
witness := h.GetColAssignment(assi)
coeffs := smartvectors.FFTInverse(witness, fft.DIF, true, 0, 0).SubVector(0, codeDim)
coeffs := smartvectors.FFTInverse(witness, fft.DIF, true, 0, 0, nil).SubVector(0, codeDim)
assi.AssignColumn(ifaces.ColIDf("%v_%v", REED_SOLOMON_COEFF, h.GetColID()), coeffs)
})

View File

@@ -15,7 +15,7 @@ import (
func TestReedSolomon(t *testing.T) {
wp := smartvectors.ForTest(1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0)
wp = smartvectors.FFT(wp, fft.DIF, true, 0, 0)
wp = smartvectors.FFT(wp, fft.DIF, true, 0, 0, nil)
definer := func(b *wizard.Builder) {
p := b.RegisterCommit("P", wp.Len())

View File

@@ -156,8 +156,8 @@ func TestPeriodicSampleCoset(t *testing.T) {
for cosetID := 0; cosetID < ratio; cosetID++ {
testEval := sampling.EvalCoset(domain, cosetID, ratio, true)
testEval = smartvectors.FFTInverse(testEval, fft.DIF, true, ratio, cosetID)
testEval = smartvectors.FFT(testEval, fft.DIT, true, 0, 0)
testEval = smartvectors.FFTInverse(testEval, fft.DIF, true, ratio, cosetID, nil)
testEval = smartvectors.FFT(testEval, fft.DIT, true, 0, 0, nil)
require.Equal(t, vanillaEval.Pretty(), testEval.Pretty(),
"domain %v, period %v, offset %v, ratio %v, cosetID %v",

View File

@@ -22,7 +22,7 @@ func (Constant) Degree([]int) int {
}
// Evaluates implements the [Operator] interface
func (c Constant) Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector {
func (c Constant) Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector {
panic("we never call it for a constant")
}

View File

@@ -2,6 +2,7 @@ package symbolic
import (
"fmt"
"sync"
"github.com/consensys/gnark/frontend"
"github.com/consensys/zkevm-monorepo/prover/maths/common/mempool"
@@ -35,18 +36,18 @@ func (b *ExpressionBoard) ListVariableMetadata() []Metadata {
}
// Evaluate the board for a batch of inputs in parallel
func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
/*
Find the size of the vector
*/
totalSize := 0
for i, inp := range inputs {
if totalSize > 0 && totalSize != inp.Len() {
// Expects that all vector inputs have the same size
utils.Panic("Mismatch in the size: len(v) %v, totalsize %v, pos %v", inp.Len(), totalSize, i)
}
if totalSize == 0 {
totalSize = inp.Len()
}
@@ -58,6 +59,10 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
utils.Panic("Either there is no input or the inputs all have size 0")
}
if len(p) > 0 && p[0].Size() != MaxChunkSize {
utils.Panic("the pool should be a pool of vectors of size=%v but it is %v", MaxChunkSize, p[0].Size())
}
/*
The is no vector input iff totalSize is 0
Thus the condition below catch the two cases where:
@@ -65,9 +70,18 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
- The vectors are smaller than the min chunk size
*/
if totalSize <= MaxChunkSize {
// never pass the pool here
return b.evaluateSingleThread(inputs)
if totalSize < MaxChunkSize {
// never pass the pool here as the pool assumes that all vectors have a
// size of MaxChunkSize. Thus, it would not work here.
return b.evaluateSingleThread(inputs).DeepCopy()
}
// This is the code-path that is used for benchmarking when the size of the
// vectors is exactly MaxChunkSize. In production, it will rather use the
// multi-threaded option. The above condition cannot use the pool because we
// assume here that the pool has a vector size of exactly MaxChunkSize.
if totalSize == MaxChunkSize {
return b.evaluateSingleThread(inputs, p...).DeepCopy()
}
if totalSize%MaxChunkSize != 0 {
@@ -77,12 +91,23 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
numChunks := totalSize / MaxChunkSize
res := make([]field.Element, totalSize)
parallel.ExecuteChunky(numChunks, func(start, stop int) {
for chunkID := start; chunkID < stop; chunkID++ {
parallel.ExecuteFromChan(numChunks, func(wg *sync.WaitGroup, idChan chan int) {
chunkStart := chunkID * MaxChunkSize
chunkStop := (chunkID + 1) * MaxChunkSize
chunkInputs := make([]sv.SmartVector, len(inputs))
var pool []mempool.MemPool
if len(p) > 0 {
if _, ok := p[0].(*mempool.DebuggeableCall); !ok {
pool = append(pool, mempool.WrapsWithMemCache(p[0]))
}
}
chunkInputs := make([]sv.SmartVector, len(inputs))
for chunkID := range idChan {
var (
chunkStart = chunkID * MaxChunkSize
chunkStop = (chunkID + 1) * MaxChunkSize
)
for i, inp := range inputs {
chunkInputs[i] = inp.SubVector(chunkStart, chunkStop)
@@ -94,11 +119,19 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
// We don't parallelize evaluations where the inputs are all scalars
// Therefore the cast is safe.
chunkRes := b.evaluateSingleThread(chunkInputs, p...)
chunkRes := b.evaluateSingleThread(chunkInputs, pool...)
// No race condition here as each call write to different places
// of vec.
chunkRes.WriteInSlice(res[chunkStart:chunkStop])
wg.Done()
}
if len(p) > 0 {
if sa, ok := pool[0].(*mempool.SliceArena); ok {
sa.TearDown()
}
}
})
@@ -107,7 +140,7 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
// evaluateSingleThread evaluates a boarded expression. The inputs can be either
// vector or scalars. The vector's input length should be smaller than a chunk.
func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
var (
length = inputs[0].Len()
@@ -138,10 +171,9 @@ func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...*me
// Deep-copy the last node and put resBuf back in the pool. It's cleanier
// that way.
if hasPool {
if reg, ok := resBuf.(*sv.Regular); ok {
if reg, ok := resBuf.(*sv.Pooled); ok {
resGC := reg.DeepCopy()
v := []field.Element(*reg)
pool.Free(&v)
reg.Free(pool)
resBuf = resGC
}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/protocol/serialization"
"github.com/consensys/zkevm-monorepo/prover/symbolic"
"github.com/stretchr/testify/assert"
)
func BenchmarkEvaluationSingleThreaded(b *testing.B) {
@@ -38,18 +39,18 @@ func BenchmarkEvaluationSingleThreaded(b *testing.B) {
b.Run(fmt.Sprintf("ratio-%v", ratio), func(b *testing.B) {
var (
testDir = "testdata/evaluation-benchmark"
constanthoodFName = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
exprFName = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
constanthoodFPath = path.Join(testDir, constanthoodFName)
exprFPath = path.Join(testDir, exprFName)
constantHoodFile = files.MustRead(constanthoodFPath)
exprFile = files.MustReadCompressed(exprFPath)
constantHood = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
expr = serialization.UnmarshalExprCBOR(exprFile)
inputs = make([]smartvectors.SmartVector, len(constantHood))
board = expr.Board()
pool = mempool.Create(symbolic.MaxChunkSize)
testDir = "testdata/evaluation-benchmark"
constanthoodFName = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
exprFName = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
constanthoodFPath = path.Join(testDir, constanthoodFName)
exprFPath = path.Join(testDir, exprFName)
constantHoodFile = files.MustRead(constanthoodFPath)
exprFile = files.MustReadCompressed(exprFPath)
constantHood = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
expr = serialization.UnmarshalExprCBOR(exprFile)
inputs = make([]smartvectors.SmartVector, len(constantHood))
board = expr.Board()
pool mempool.MemPool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize)
)
for i := range inputs {
@@ -71,6 +72,78 @@ func BenchmarkEvaluationSingleThreaded(b *testing.B) {
_ = board.Evaluate(inputs, pool)
}
})
}
}
// TestEvaluationSingleThreaded evaluates serialized test expressions with a
// debug-instrumented memory pool and checks that (a) the pool is actually
// used during the evaluation and (b) no pooled vector leaks to the GC:
// pool.Errors() reports any allocation without a matching free.
func TestEvaluationSingleThreaded(t *testing.T) {

	makeRegular := func() smartvectors.SmartVector {
		return smartvectors.Rand(symbolic.MaxChunkSize)
	}

	makeConst := func() smartvectors.SmartVector {
		var x field.Element
		x.SetRandom()
		return smartvectors.NewConstant(x, symbolic.MaxChunkSize)
	}

	makeFullZero := func() smartvectors.SmartVector {
		return smartvectors.NewConstant(field.Zero(), symbolic.MaxChunkSize)
	}

	makeFullOnes := func() smartvectors.SmartVector {
		return smartvectors.NewConstant(field.One(), symbolic.MaxChunkSize)
	}

	for ratio := 1; ratio <= 32; ratio *= 2 {

		// The subtest parameter deliberately shadows the outer t so that
		// failures inside the closure are attributed to the subtest they
		// occur in (the previous `b` name was a leftover from the benchmark
		// this test was derived from, and made `t.Fatalf` below target the
		// parent test instead of the subtest).
		t.Run(fmt.Sprintf("ratio-%v", ratio), func(t *testing.T) {

			var (
				testDir                           = "testdata/evaluation-benchmark"
				constanthoodFName                 = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
				exprFName                         = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
				constanthoodFPath                 = path.Join(testDir, constanthoodFName)
				exprFPath                         = path.Join(testDir, exprFName)
				constantHoodFile                  = files.MustRead(constanthoodFPath)
				exprFile                          = files.MustReadCompressed(exprFPath)
				constantHood                      = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
				expr                              = serialization.UnmarshalExprCBOR(exprFile)
				inputs                            = make([]smartvectors.SmartVector, len(constantHood))
				board                             = expr.Board()
				basePool          mempool.MemPool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize)
			)

			// Wrap the sync-pool-backed pool with a per-goroutine cache and a
			// debug layer recording every alloc/free.
			pool := mempool.NewDebugPool(mempool.WrapsWithMemCache(basePool))

			// Sanity-check that the debug pool is accepted by both the soft
			// and the strict optional-pool extractors.
			_, mustBeTrue := mempool.ExtractCheckOptionalSoft(symbolic.MaxChunkSize, pool)
			_, mustBeTrue2 := mempool.ExtractCheckOptionalStrict(symbolic.MaxChunkSize, pool)
			assert.True(t, mustBeTrue)
			assert.True(t, mustBeTrue2)

			// Build one input per variable, picking a representation that is
			// consistent with the recorded constanthood of the variable.
			for i := range inputs {
				switch {
				case !constantHood[i][0]:
					inputs[i] = makeRegular()
				case constantHood[i][1]:
					inputs[i] = makeFullZero()
				case constantHood[i][2]:
					inputs[i] = makeFullOnes()
				default:
					inputs[i] = makeConst()
				}
			}

			_ = board.Evaluate(inputs, pool)

			if len(pool.Logs) == 0 {
				t.Fatalf("the pool was not used")
			}

			// Errors returns a non-nil error if any allocation was never
			// freed back to the pool (i.e. leaked to the GC).
			assert.NoError(t, pool.Errors())
		})
	}
}

View File

@@ -5,6 +5,7 @@ import (
sv "github.com/consensys/zkevm-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/utils"
"github.com/sirupsen/logrus"
)
type boardAssignment [][]nodeAssignment
@@ -66,7 +67,7 @@ func (b *ExpressionBoard) prepareNodeAssignments(inputs []sv.SmartVector) boardA
if success {
for i := range inputs {
nodeAssignments.incParentKnownCountOf(&inputs[i], nil, true)
nodeAssignments.incParentKnownCountOf(inputs[i], nil, true)
}
}
@@ -77,7 +78,7 @@ func (b *ExpressionBoard) prepareNodeAssignments(inputs []sv.SmartVector) boardA
return nodeAssignments
}
func (b boardAssignment) eval(na *nodeAssignment, pool *mempool.Pool) {
func (b boardAssignment) eval(na *nodeAssignment, pool mempool.MemPool) {
if (na.allParentsKnown() && na.hasParents()) || na.hasAValue() {
return
@@ -99,11 +100,11 @@ func (b boardAssignment) eval(na *nodeAssignment, pool *mempool.Pool) {
na.Value = na.Node.Operator.Evaluate(smv, pool)
for i := range val {
b.incParentKnownCountOf(&val[i], pool, false)
b.incParentKnownCountOf(val[i], pool, false)
}
}
func (na *nodeAssignment) tryGuessEval(val []nodeAssignment) bool {
func (na *nodeAssignment) tryGuessEval(val []*nodeAssignment) bool {
if na.hasAValue() {
return true
@@ -146,6 +147,7 @@ func (na *nodeAssignment) tryGuessEval(val []nodeAssignment) bool {
return true
}
return false
default:
panic("unexpected type")
}
@@ -176,13 +178,13 @@ func (na *nodeAssignment) constValue() (*sv.Constant, bool) {
return nil, false
}
func (b boardAssignment) inputOf(na *nodeAssignment) []nodeAssignment {
func (b boardAssignment) inputOf(na *nodeAssignment) []*nodeAssignment {
if na.Node == nil {
panic("na has a nil node")
}
nodeInputs := make([]nodeAssignment, len(na.Node.Children))
nodeInputs := make([]*nodeAssignment, len(na.Node.Children))
for i, childID := range na.Node.Children {
var (
@@ -190,25 +192,31 @@ func (b boardAssignment) inputOf(na *nodeAssignment) []nodeAssignment {
pil = childID.posInLevel()
)
nodeInputs[i] = b[lvl][pil]
nodeInputs[i] = &b[lvl][pil]
}
return nodeInputs
}
func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool *mempool.Pool, recursive bool) (wasDeleted bool) {
func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool mempool.MemPool, recursive bool) (wasDeleted bool) {
na.NumKnownParents++
// Sanity-checking that this function is not called too many time
if na.NumKnownParents > len(na.Node.Parents) {
panic("invalid count: overflowing the total number of parent")
utils.Panic("invalid count: overflowing the total number of parent")
}
if na.allParentsKnown() {
if recursive {
// The recursive call to incParentKnownCount is needed only if the node
// that we "completed" by marking all its parent as known was completed
// **only** for that reason. It could also have been completed because
// all its children are constants. When that is the case, all the children
// will have been incremented already.
if recursive && na.Value == nil {
children := b.inputOf(na)
for i := range children {
b.incParentKnownCountOf(&children[i], pool, recursive)
b.incParentKnownCountOf(children[i], pool, recursive)
}
}
@@ -218,7 +226,7 @@ func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool *mempool
return false
}
func (na *nodeAssignment) tryFree(pool *mempool.Pool) bool {
func (na *nodeAssignment) tryFree(pool mempool.MemPool) bool {
if pool == nil {
return false
}
@@ -236,12 +244,39 @@ func (na *nodeAssignment) tryFree(pool *mempool.Pool) bool {
return false
}
if reg, ok := na.Value.(*sv.Regular); ok {
if reg, ok := na.Value.(*sv.Pooled); ok {
na.Value = nil
v := []field.Element(*reg)
pool.Free(&v)
reg.Free(pool)
return true
}
return false
}
// inspectCleaning scans every non-input node of the board and logs an error
// for each node whose parent counter was never incremented up to its total
// number of parents, or whose pooled result vector was not returned to the
// pool. It is a debugging aid and does not mutate the board.
func (b boardAssignment) inspectCleaning() {
	// Level 0 holds the inputs, which are not subject to cleaning.
	for level := 1; level < len(b); level++ {
		for pos, node := range b[level] {

			expected := len(node.Node.Parents)
			if node.NumKnownParents != expected {
				logrus.Errorf(
					"the parent count was not updated till the end lvl=%v pil=%v parentCnt=%v nbParent=%v valueT=%T parents=%v",
					level, pos, node.NumKnownParents, expected, node.Value, node.Node.Parents,
				)
			}

			// A type assertion on a nil interface yields ok=false, so nodes
			// without a value are skipped here as well. A Pooled value whose
			// Regular slice is still attached was never freed.
			if pooled, ok := node.Value.(*sv.Pooled); ok && pooled.Regular != nil {
				logrus.Errorf("the result of node [%v %v] was not cleaned", level, pos)
			}
		}
	}
}

View File

@@ -57,7 +57,7 @@ type Expression struct {
type Operator interface {
// Evaluate returns an evaluation of the operator from a list of assignments:
// one for each operand (children) of the expression.
Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector
Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector
// Validate performs a sanity-check of the expression the Operator belongs
// to.
Validate(e *Expression) error

View File

@@ -96,7 +96,7 @@ func (LinComb) Degree(inputDegrees []int) int {
}
// Evaluate implements the [Operator] interface.
func (lc LinComb) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
func (lc LinComb) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
return sv.LinComb(lc.Coeffs, inputs, p...)
}

View File

@@ -62,7 +62,7 @@ func (PolyEval) Degree(inputDegrees []int) int {
/*
Evaluates a polynomial evaluation
*/
func (PolyEval) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
func (PolyEval) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
// We assume that the first element is always a scalar
// Get the constant value. We use Get(0) to get the value, but any integer would
// also work provided it is also in range. 0 ensures that.

View File

@@ -124,7 +124,7 @@ func (prod Product) Degree(inputDegrees []int) int {
}
// Evaluate implements the [Operator] interface.
func (prod Product) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
func (prod Product) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
return sv.Product(prod.Exponents, inputs, p...)
}

View File

@@ -34,7 +34,7 @@ func (Variable) Degree([]int) int {
}
// Evaluate implements the [Operator] interface. Yet, this panics if this is called.
func (v Variable) Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector {
func (v Variable) Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector {
panic("we never call it for variables")
}

View File

@@ -0,0 +1,41 @@
package parallel
import (
"runtime"
"sync"
)
// ExecuteJobStealing parallelizes a workload specified by a function consuming
// a channel distributing the workload. It is appropriate when each iteration
// takes an order of magnitude more time than the other functions.
//
// This is as [ExecuteChunky] but gives more freedom to the caller to initialize
// its threads.
func ExecuteFromChan(nbIterations int, work func(wg *sync.WaitGroup, indexChan chan int), numcpus ...int) {
numcpu := runtime.GOMAXPROCS(0)
if len(numcpus) > 0 && numcpus[0] > 0 {
numcpu = numcpus[0]
}
// The jobs are sent one by one to the workers
jobChan := make(chan int, nbIterations)
for i := 0; i < nbIterations; i++ {
jobChan <- i
}
// The wait group ensures that all the children goroutine have terminated
// before we close the
wg := &sync.WaitGroup{}
wg.Add(nbIterations)
// Each goroutine consumes the jobChan to
for p := 0; p < numcpu; p++ {
go func() {
work(wg, jobChan)
}()
}
wg.Wait()
close(jobChan)
}