mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 23:47:55 -05:00
Prover: optimize memory allocations in symbolic expression (#3910)
* feat(pool): implements a multi-level caching system for the pool * fix(sym): make sure that all the nodes are freed * feat(pool): implements a DebugPool to debug leakage to gc * feat(parallel): implements a channel based parallelization template * test(pool): ensure that symbolic evaluation does not leak to gc * fixup(par): adds a wg to the workload worker * create an prover-checker that can be used for checking * feat(check-only): adds a check-only mode in the config * test(smartvectors): update the tests * chores(makefile): rm corset flags from the unneeding commands * remove the change on config-benchmark * perf(pool): Use the pool for regular massive FFT workload
This commit is contained in:
@@ -134,14 +134,14 @@ lib/compressor-and-shnarf-calculator-local: lib/compressor lib/shnarf_calculator
|
||||
##
|
||||
## Run all the unit-tests
|
||||
##
|
||||
test: zkevm/arithmetization/zkevm.bin
|
||||
$(CORSET_FLAGS) go test ./...
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
##
|
||||
## Run the CI linting
|
||||
##
|
||||
ci-lint: zkevm/arithmetization/zkevm.bin
|
||||
$(CORSET_FLAGS) golangci-lint run --timeout 5m
|
||||
golangci-lint run --timeout 5m
|
||||
|
||||
##
|
||||
## Echo, the CGO flags. Usefull for testing manually
|
||||
|
||||
146
prover/maths/common/mempool/debug_pool.go
Normal file
146
prover/maths/common/mempool/debug_pool.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils"
|
||||
)
|
||||
|
||||
type DebuggeableCall struct {
|
||||
Parent MemPool
|
||||
Logs map[uintptr]*[]Record
|
||||
}
|
||||
|
||||
func NewDebugPool(p MemPool) *DebuggeableCall {
|
||||
return &DebuggeableCall{
|
||||
Parent: p,
|
||||
Logs: make(map[uintptr]*[]Record),
|
||||
}
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
Where string
|
||||
What recordType
|
||||
}
|
||||
|
||||
func (m *DebuggeableCall) Prewarm(nbPrewarm int) MemPool {
|
||||
m.Parent.Prewarm(nbPrewarm)
|
||||
return m
|
||||
}
|
||||
|
||||
type recordType string
|
||||
|
||||
const (
|
||||
AllocRecord recordType = "alloc"
|
||||
FreeRecord recordType = "free"
|
||||
)
|
||||
|
||||
func (m *DebuggeableCall) Alloc() *[]field.Element {
|
||||
|
||||
var (
|
||||
v = m.Parent.Alloc()
|
||||
uptr = uintptr(unsafe.Pointer(v))
|
||||
logs *[]Record
|
||||
_, file, line, _ = runtime.Caller(2)
|
||||
)
|
||||
|
||||
logs, found := m.Logs[uptr]
|
||||
|
||||
if !found {
|
||||
logs = &[]Record{}
|
||||
m.Logs[uptr] = logs
|
||||
}
|
||||
|
||||
*logs = append(*logs, Record{
|
||||
Where: file + ":" + strconv.Itoa(line),
|
||||
What: AllocRecord,
|
||||
})
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (m *DebuggeableCall) Free(v *[]field.Element) error {
|
||||
|
||||
var (
|
||||
uptr = uintptr(unsafe.Pointer(v))
|
||||
logs *[]Record
|
||||
_, file, line, _ = runtime.Caller(2)
|
||||
)
|
||||
|
||||
logs, found := m.Logs[uptr]
|
||||
|
||||
if !found {
|
||||
logs = &[]Record{}
|
||||
m.Logs[uptr] = logs
|
||||
}
|
||||
|
||||
*logs = append(*logs, Record{
|
||||
Where: file + ":" + strconv.Itoa(line),
|
||||
What: FreeRecord,
|
||||
})
|
||||
|
||||
return m.Parent.Free(v)
|
||||
}
|
||||
|
||||
func (m *DebuggeableCall) Size() int {
|
||||
return m.Parent.Size()
|
||||
}
|
||||
|
||||
func (m *DebuggeableCall) TearDown() {
|
||||
if p, ok := m.Parent.(*SliceArena); ok {
|
||||
p.TearDown()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DebuggeableCall) Errors() error {
|
||||
|
||||
var err error
|
||||
|
||||
for _, logs_ := range m.Logs {
|
||||
|
||||
if logs_ == nil || len(*logs_) == 0 {
|
||||
utils.Panic("got a nil entry")
|
||||
}
|
||||
|
||||
logs := *logs_
|
||||
|
||||
for i := range logs {
|
||||
if i == 0 && logs[i].What == FreeRecord {
|
||||
err = errors.Join(err, fmt.Errorf("freed a vector that was not from the pool: where=%v", logs[i].Where))
|
||||
}
|
||||
|
||||
if i == len(logs)-1 && logs[i].What == AllocRecord {
|
||||
err = errors.Join(err, fmt.Errorf("leaked a vector out of the pool: where=%v", logs[i].Where))
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if logs[i-1].What == AllocRecord && logs[i].What == AllocRecord {
|
||||
wheres := []string{logs[i-1].Where, logs[i].Where}
|
||||
for k := i + 1; k < len(logs) && logs[k].What == AllocRecord; k++ {
|
||||
wheres = append(wheres, logs[k].Where)
|
||||
}
|
||||
|
||||
err = errors.Join(err, fmt.Errorf("vector was allocated multiple times concurrently where=%v", wheres))
|
||||
}
|
||||
|
||||
if logs[i-1].What == FreeRecord && logs[i].What == FreeRecord {
|
||||
wheres := []string{logs[i-1].Where, logs[i].Where}
|
||||
for k := i + 1; k < len(logs) && logs[k].What == FreeRecord; k++ {
|
||||
wheres = append(wheres, logs[k].Where)
|
||||
}
|
||||
|
||||
err = errors.Join(err, fmt.Errorf("vector was freed multiple times concurrently where=%v", wheres))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
52
prover/maths/common/mempool/debug_pool_test.go
Normal file
52
prover/maths/common/mempool/debug_pool_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDebugPool(t *testing.T) {
|
||||
|
||||
t.Run("leak-detection", func(t *testing.T) {
|
||||
|
||||
pool := NewDebugPool(CreateFromSyncPool(32))
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
func() {
|
||||
_ = pool.Alloc()
|
||||
}()
|
||||
}
|
||||
|
||||
err := pool.Errors().Error()
|
||||
assert.True(t, strings.HasPrefix(err, "leaked a vector out of the pool"))
|
||||
})
|
||||
|
||||
t.Run("double-free", func(t *testing.T) {
|
||||
|
||||
pool := NewDebugPool(CreateFromSyncPool(32))
|
||||
|
||||
v := pool.Alloc()
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
pool.Free(v)
|
||||
}
|
||||
|
||||
err := pool.Errors().Error()
|
||||
assert.Truef(t, strings.HasPrefix(err, "vector was freed multiple times concurrently"), err)
|
||||
})
|
||||
|
||||
t.Run("foreign-free", func(t *testing.T) {
|
||||
|
||||
pool := NewDebugPool(CreateFromSyncPool(32))
|
||||
|
||||
v := make([]field.Element, 32)
|
||||
pool.Free(&v)
|
||||
|
||||
err := pool.Errors().Error()
|
||||
assert.Truef(t, strings.HasPrefix(err, "freed a vector that was not from the pool"), err)
|
||||
})
|
||||
|
||||
}
|
||||
73
prover/maths/common/mempool/from_sync_pool.go
Normal file
73
prover/maths/common/mempool/from_sync_pool.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils/parallel"
|
||||
)
|
||||
|
||||
// FromSyncPool pools the allocation for slices of [field.Element] of size `Size`.
|
||||
// It should be used with great caution and every slice allocated via this pool
|
||||
// must be manually freed and only once.
|
||||
//
|
||||
// FromSyncPool is used to reduce the number of allocation which can be significant
|
||||
// when doing operations over field elements.
|
||||
type FromSyncPool struct {
|
||||
size int
|
||||
P sync.Pool
|
||||
}
|
||||
|
||||
// CreateFromSyncPool initializes the Pool with the given number of elements in it.
|
||||
func CreateFromSyncPool(size int) *FromSyncPool {
|
||||
// Initializes the pool
|
||||
return &FromSyncPool{
|
||||
size: size,
|
||||
P: sync.Pool{
|
||||
New: func() any {
|
||||
res := make([]field.Element, size)
|
||||
return &res
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Prewarm the Pool by preallocating `nbPrewarm` in it.
|
||||
func (p *FromSyncPool) Prewarm(nbPrewarm int) MemPool {
|
||||
prewarmed := make([]field.Element, p.size*nbPrewarm)
|
||||
parallel.Execute(nbPrewarm, func(start, stop int) {
|
||||
for i := start; i < stop; i++ {
|
||||
vec := prewarmed[i*p.size : (i+1)*p.size]
|
||||
p.P.Put(&vec)
|
||||
}
|
||||
})
|
||||
return p
|
||||
}
|
||||
|
||||
// Alloc returns a vector allocated from the pool. Vector allocated via the
|
||||
// pool should ideally be returned to the pool. If not, they are still going to
|
||||
// be picked up by the GC.
|
||||
func (p *FromSyncPool) Alloc() *[]field.Element {
|
||||
res := p.P.Get().(*[]field.Element)
|
||||
return res
|
||||
}
|
||||
|
||||
// Free returns an object to the pool. It must never be called twice over
|
||||
// the same object or undefined behaviours are going to arise. It is fine to
|
||||
// pass objects allocated to outside of the pool as long as they have the right
|
||||
// dimension.
|
||||
func (p *FromSyncPool) Free(vec *[]field.Element) error {
|
||||
// Check the vector has the right size
|
||||
if len(*vec) != p.size {
|
||||
utils.Panic("expected size %v, expected %v", len(*vec), p.Size())
|
||||
}
|
||||
|
||||
p.P.Put(vec)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *FromSyncPool) Size() int {
|
||||
return p.size
|
||||
}
|
||||
60
prover/maths/common/mempool/mempool.go
Normal file
60
prover/maths/common/mempool/mempool.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils"
|
||||
)
|
||||
|
||||
type MemPool interface {
|
||||
Prewarm(nbPrewarm int) MemPool
|
||||
Alloc() *[]field.Element
|
||||
Free(vec *[]field.Element) error
|
||||
Size() int
|
||||
}
|
||||
|
||||
// ExtractCheckOptionalStrict returns
|
||||
// - p[0], true if the expectedSize matches the one of the provided pool
|
||||
// - nil, false if no `p` is provided
|
||||
// - panic if the assigned size of the pool does not match
|
||||
// - panic if the caller provides `nil` as argument for `p`
|
||||
//
|
||||
// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an
|
||||
// optional variadic parameter and at the same time validating that the pool
|
||||
// object has the right size.
|
||||
func ExtractCheckOptionalStrict(expectedSize int, p ...MemPool) (pool MemPool, ok bool) {
|
||||
// Checks if there is a pool
|
||||
hasPool := len(p) > 0 && p[0] != nil
|
||||
if hasPool {
|
||||
pool = p[0]
|
||||
}
|
||||
|
||||
// Sanity-check that the size of the pool is actually what we expected
|
||||
if hasPool && pool.Size() != expectedSize {
|
||||
utils.Panic("pooled vector size are %v, but required %v", pool.Size(), expectedSize)
|
||||
}
|
||||
|
||||
return pool, hasPool
|
||||
}
|
||||
|
||||
// ExtractCheckOptionalSoft returns
|
||||
// - p[0], true if the expectedSize matches the one of the provided pool
|
||||
// - nil, false if no `p` is provided
|
||||
// - nil, false if the length of the vector does not match the one of the pool
|
||||
// - panic if the caller provides `nil` as argument for `p`
|
||||
//
|
||||
// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an
|
||||
// optional variadic parameter.
|
||||
func ExtractCheckOptionalSoft(expectedSize int, p ...MemPool) (pool MemPool, ok bool) {
|
||||
// Checks if there is a pool
|
||||
hasPool := len(p) > 0
|
||||
if hasPool {
|
||||
pool = p[0]
|
||||
}
|
||||
|
||||
// Sanity-check that the size of the pool is actually what we expected
|
||||
if hasPool && pool.Size() != expectedSize {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return pool, hasPool
|
||||
}
|
||||
51
prover/maths/common/mempool/slice_arena.go
Normal file
51
prover/maths/common/mempool/slice_arena.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
)
|
||||
|
||||
// SliceArena is a simple not-threadsafe arena implementation that uses a
|
||||
// mempool to carry its allocation. It will only put back free memory in the
|
||||
// the parent pool when TearDown is called.
|
||||
type SliceArena struct {
|
||||
frees []*[]field.Element
|
||||
parent MemPool
|
||||
}
|
||||
|
||||
func WrapsWithMemCache(pool MemPool) *SliceArena {
|
||||
return &SliceArena{
|
||||
frees: make([]*[]field.Element, 0, 1<<7),
|
||||
parent: pool,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SliceArena) Prewarm(nbPrewarm int) MemPool {
|
||||
m.parent.Prewarm(nbPrewarm)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *SliceArena) Alloc() *[]field.Element {
|
||||
|
||||
if len(m.frees) == 0 {
|
||||
return m.parent.Alloc()
|
||||
}
|
||||
|
||||
last := m.frees[len(m.frees)-1]
|
||||
m.frees = m.frees[:len(m.frees)-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func (m *SliceArena) Free(v *[]field.Element) error {
|
||||
m.frees = append(m.frees, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SliceArena) Size() int {
|
||||
return m.parent.Size()
|
||||
}
|
||||
|
||||
func (m *SliceArena) TearDown() {
|
||||
for i := range m.frees {
|
||||
m.parent.Free(m.frees[i])
|
||||
}
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils/parallel"
|
||||
)
|
||||
|
||||
// Pool pools the allocation for slices of [field.Element] of size `Size`. It
|
||||
// should be used with great caution and every slice allocated via this pool
|
||||
// must be manually freed and only once.
|
||||
//
|
||||
// Pool is used to reduce the number of allocation which can be significant
|
||||
// when doing operations over field elements.
|
||||
type Pool struct {
|
||||
Size int
|
||||
P sync.Pool
|
||||
}
|
||||
|
||||
// Create initializes the Pool with the given number of elements in it.
|
||||
func Create(size int) *Pool {
|
||||
// Initializes the pool
|
||||
return &Pool{
|
||||
Size: size,
|
||||
P: sync.Pool{
|
||||
New: func() any {
|
||||
res := make([]field.Element, size)
|
||||
return &res
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Prewarm the Pool by preallocating `nbPrewarm` in it.
|
||||
func (p *Pool) Prewarm(nbPrewarm int) *Pool {
|
||||
prewarmed := make([]field.Element, p.Size*nbPrewarm)
|
||||
parallel.Execute(nbPrewarm, func(start, stop int) {
|
||||
for i := start; i < stop; i++ {
|
||||
vec := prewarmed[i*p.Size : (i+1)*p.Size]
|
||||
p.P.Put(&vec)
|
||||
}
|
||||
})
|
||||
return p
|
||||
}
|
||||
|
||||
// Alloc returns a vector allocated from the pool. Vector allocated via the
|
||||
// pool should ideally be returned to the pool. If not, they are still going to
|
||||
// be picked up by the GC.
|
||||
func (p *Pool) Alloc() *[]field.Element {
|
||||
res := p.P.Get().(*[]field.Element)
|
||||
return res
|
||||
}
|
||||
|
||||
// Free returns an object to the pool. It must never be called twice over
|
||||
// the same object or undefined behaviours are going to arise. It is fine to
|
||||
// pass objects allocated to outside of the pool as long as they have the right
|
||||
// dimension.
|
||||
func (p *Pool) Free(vec *[]field.Element) error {
|
||||
// Check the vector has the right size
|
||||
if len(*vec) != p.Size {
|
||||
utils.Panic("expected size %v, expected %v", len(*vec), p.Size)
|
||||
}
|
||||
|
||||
p.P.Put(vec)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractCheckOptionalStrict returns
|
||||
// - p[0], true if the expectedSize matches the one of the provided pool
|
||||
// - nil, false if no `p` is provided
|
||||
// - panic if the assigned size of the pool does not match
|
||||
// - panic if the caller provides `nil` as argument for `p`
|
||||
//
|
||||
// This is used to unwrap a [Pool] that is commonly passed to functions as an
|
||||
// optional variadic parameter and at the same time validating that the pool
|
||||
// object has the right size.
|
||||
func ExtractCheckOptionalStrict(expectedSize int, p ...*Pool) (pool *Pool, ok bool) {
|
||||
// Checks if there is a pool
|
||||
hasPool := len(p) > 0 && p[0] != nil
|
||||
if hasPool {
|
||||
pool = p[0]
|
||||
}
|
||||
|
||||
// Sanity-check that the size of the pool is actually what we expected
|
||||
if hasPool && pool.Size != expectedSize {
|
||||
utils.Panic("pooled vector size are %v, but required %v", pool.Size, expectedSize)
|
||||
}
|
||||
|
||||
return pool, hasPool
|
||||
}
|
||||
|
||||
// ExtractCheckOptionalSoft returns
|
||||
// - p[0], true if the expectedSize matches the one of the provided pool
|
||||
// - nil, false if no `p` is provided
|
||||
// - nil, false if the length of the vector does not match the one of the pool
|
||||
// - panic if the caller provides `nil` as argument for `p`
|
||||
//
|
||||
// This is used to unwrap a [Pool] that is commonly passed to functions as an
|
||||
// optional variadic parameter.
|
||||
func ExtractCheckOptionalSoft(expectedSize int, p ...*Pool) (pool *Pool, ok bool) {
|
||||
// Checks if there is a pool
|
||||
hasPool := len(p) > 0
|
||||
if hasPool {
|
||||
pool = p[0]
|
||||
}
|
||||
|
||||
// Sanity-check that the size of the pool is actually what we expected
|
||||
if hasPool && pool.Size != expectedSize {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return pool, hasPool
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func InnerProduct(a, b SmartVector) field.Element {
|
||||
// result = vecs[0] + vecs[1] * x + vecs[2] * x^2 + vecs[3] * x^3 + ...
|
||||
//
|
||||
// where `x` is a scalar and `vecs[i]` are [SmartVector]
|
||||
func PolyEval(vecs []SmartVector, x field.Element, p ...*mempool.Pool) (result SmartVector) {
|
||||
func PolyEval(vecs []SmartVector, x field.Element, p ...mempool.MemPool) (result SmartVector) {
|
||||
|
||||
if len(vecs) == 0 {
|
||||
panic("no input vectors")
|
||||
@@ -80,10 +80,11 @@ func PolyEval(vecs []SmartVector, x field.Element, p ...*mempool.Pool) (result S
|
||||
resReg = make([]field.Element, length)
|
||||
tmpVec = make([]field.Element, length)
|
||||
} else {
|
||||
a, b := pool.Alloc(), pool.Alloc()
|
||||
resReg, tmpVec = *a, *b
|
||||
a := AllocFromPool(pool)
|
||||
b := AllocFromPool(pool)
|
||||
resReg, tmpVec = a.Regular, b.Regular
|
||||
vector.Fill(resReg, field.Zero())
|
||||
defer pool.Free(b)
|
||||
defer b.Free(pool)
|
||||
}
|
||||
|
||||
var tmpF, resCon field.Element
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// - The function panics if svecs is empty
|
||||
// - The function panics if the length of coeffs does not match the length of
|
||||
// svecs
|
||||
func LinComb(coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
|
||||
func LinComb(coeffs []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
|
||||
// Sanity check : all svec should have the same length
|
||||
length := svecs[0].Len()
|
||||
for i := 0; i < len(svecs); i++ {
|
||||
@@ -27,7 +27,7 @@ func LinComb(coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector
|
||||
// - The function panics if svecs is empty
|
||||
// - The function panics if the length of exponents does not match the length of
|
||||
// svecs
|
||||
func Product(exponents []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
|
||||
func Product(exponents []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
|
||||
return processOperator(productOp{}, exponents, svecs, p...)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func Product(exponents []int, svecs []SmartVector, p ...*mempool.Pool) SmartVect
|
||||
// - The function panics if svecs is empty
|
||||
// - The function panics if the length of coeffs does not match the length of
|
||||
// svecs
|
||||
func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempool.Pool) SmartVector {
|
||||
func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...mempool.MemPool) SmartVector {
|
||||
|
||||
// There should be as many coeffs than there are vectors
|
||||
if len(coeffs) != len(svecs) {
|
||||
@@ -114,13 +114,13 @@ func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempo
|
||||
case matchedRegular+matchedConst == totalToMatch:
|
||||
// In this case, there are no windowed in the list. This means we only
|
||||
// need to merge the const one into the regular one before returning
|
||||
op.constTermIntoVec(*regularRes, &constRes.val)
|
||||
op.constTermIntoVec(regularRes.Regular, &constRes.val)
|
||||
return regularRes
|
||||
default:
|
||||
|
||||
// If windowRes is a regular (can happen if all windows arguments cover the full circle)
|
||||
if w, ok := windowRes.(*Regular); ok {
|
||||
op.vecTermIntoVec(*regularRes, *w)
|
||||
op.vecTermIntoVec(regularRes.Regular, *w)
|
||||
return regularRes
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...*mempo
|
||||
// In this case, the constant is already accumulated into the windowed.
|
||||
// Thus, we just have to merge the windowed one into the regular one.
|
||||
interval := windowRes.interval()
|
||||
regvec := *regularRes
|
||||
regvec := regularRes.Regular
|
||||
length := len(regvec)
|
||||
|
||||
// The windows rolls over
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestFuzzProductWithPool(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
t.Logf("TEST CASE %v\n", tcase.String())
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestFuzzProductWithPoolCompare(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
t.Logf("TEST CASE %v\n", tcase.String())
|
||||
|
||||
@@ -142,7 +142,7 @@ func TestFuzzLinCombWithPool(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
t.Logf("TEST CASE %v\n", tcase.String())
|
||||
|
||||
@@ -159,7 +159,6 @@ func TestFuzzLinCombWithPool(t *testing.T) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFuzzLinCombWithPoolCompare(t *testing.T) {
|
||||
@@ -169,7 +168,7 @@ func TestFuzzLinCombWithPoolCompare(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
t.Logf("TEST CASE %v\n", tcase.String())
|
||||
|
||||
@@ -237,8 +236,8 @@ func TestOpBasicEdgeCases(t *testing.T) {
|
||||
for i, testCase := range testCases {
|
||||
t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) {
|
||||
t.Logf("test-case details: %v", testCase.explainer)
|
||||
res := testCase.fn(testCase.inputs...)
|
||||
require.Equal(t, testCase.expectedRes, res, "expectedRes=%v\nres=%v", testCase.expectedRes.Pretty(), res.Pretty())
|
||||
res := testCase.fn(testCase.inputs...).(*Pooled)
|
||||
require.Equal(t, testCase.expectedRes, &res.Regular, "expectedRes=%v\nres=%v", testCase.expectedRes.Pretty(), res.Pretty())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -290,7 +289,7 @@ func TestFuzzPolyEvalWithPool(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
// PolyEval() with pool
|
||||
polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool)
|
||||
@@ -315,7 +314,7 @@ func TestFuzzPolyEvalWithPoolCompare(t *testing.T) {
|
||||
|
||||
success := t.Run(tcase.name, func(t *testing.T) {
|
||||
|
||||
pool := mempool.Create(tcase.svecs[0].Len())
|
||||
pool := mempool.CreateFromSyncPool(tcase.svecs[0].Len())
|
||||
|
||||
// PolyEval() with pool
|
||||
polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool)
|
||||
|
||||
@@ -77,13 +77,13 @@ func (r *Regular) Pretty() string {
|
||||
return fmt.Sprintf("Regular[%v]", vector.Prettify(*r))
|
||||
}
|
||||
|
||||
func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*mempool.Pool) (result *Regular, numMatches int) {
|
||||
func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...mempool.MemPool) (result *Pooled, numMatches int) {
|
||||
|
||||
length := svecs[0].Len()
|
||||
|
||||
pool, hasPool := mempool.ExtractCheckOptionalStrict(length, p...)
|
||||
|
||||
var resvec []field.Element
|
||||
var resvec *Pooled
|
||||
|
||||
isFirst := true
|
||||
numMatches = 0
|
||||
@@ -97,6 +97,10 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
|
||||
svec = rotatedAsRegular(rot)
|
||||
}
|
||||
|
||||
if pooled, ok := svec.(*Pooled); ok {
|
||||
svec = &pooled.Regular
|
||||
}
|
||||
|
||||
if reg, ok := svec.(*Regular); ok {
|
||||
numMatches++
|
||||
// For the first one, we can save by just copying the result
|
||||
@@ -104,16 +108,17 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
|
||||
// zero.
|
||||
if isFirst {
|
||||
if hasPool {
|
||||
resvec = *pool.Alloc()
|
||||
resvec = AllocFromPool(pool)
|
||||
} else {
|
||||
resvec = make([]field.Element, length)
|
||||
resvec = &Pooled{Regular: make([]field.Element, length)}
|
||||
}
|
||||
|
||||
isFirst = false
|
||||
op.vecIntoTerm(resvec, *reg, coeffs[i])
|
||||
op.vecIntoTerm(resvec.Regular, *reg, coeffs[i])
|
||||
continue
|
||||
}
|
||||
op.vecIntoVec(resvec, *reg, coeffs[i])
|
||||
|
||||
op.vecIntoVec(resvec.Regular, *reg, coeffs[i])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,8 +126,7 @@ func processRegularOnly(op operator, svecs []SmartVector, coeffs []int, p ...*me
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
res := Regular(resvec)
|
||||
return &res, numMatches
|
||||
return resvec, numMatches
|
||||
}
|
||||
|
||||
func (r *Regular) DeepCopy() SmartVector {
|
||||
@@ -134,3 +138,24 @@ func (r *Regular) DeepCopy() SmartVector {
|
||||
func (r *Regular) IntoRegVecSaveAlloc() []field.Element {
|
||||
return (*r)[:]
|
||||
}
|
||||
|
||||
type Pooled struct {
|
||||
Regular
|
||||
poolPtr *[]field.Element
|
||||
}
|
||||
|
||||
func AllocFromPool(pool mempool.MemPool) *Pooled {
|
||||
poolPtr := pool.Alloc()
|
||||
return &Pooled{
|
||||
Regular: *poolPtr,
|
||||
poolPtr: poolPtr,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pooled) Free(pool mempool.MemPool) {
|
||||
if p.poolPtr != nil {
|
||||
pool.Free(p.poolPtr)
|
||||
}
|
||||
p.poolPtr = nil
|
||||
p.Regular = nil
|
||||
}
|
||||
|
||||
@@ -204,7 +204,8 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
stopTimer = profiling.LogTimer("Computing the coeffs %v pols of size %v", len(ctx.AllInvolvedColumns), ctx.DomainSize)
|
||||
lock = sync.Mutex{}
|
||||
lockRun = sync.Mutex{}
|
||||
pool = mempool.Create(symbolic.MaxChunkSize).Prewarm(runtime.NumCPU() * ctx.MaxNbExprNode)
|
||||
pool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize).Prewarm(runtime.NumCPU() * ctx.MaxNbExprNode)
|
||||
largePool = mempool.CreateFromSyncPool(ctx.DomainSize).Prewarm(len(ctx.AllInvolvedColumns))
|
||||
)
|
||||
|
||||
if ctx.DomainSize >= GC_DOMAIN_SIZE {
|
||||
@@ -236,7 +237,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
if !isNatural {
|
||||
witness = pol.GetColAssignment(run)
|
||||
}
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0)
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
|
||||
|
||||
lock.Lock()
|
||||
coeffs[name] = witness
|
||||
@@ -261,7 +262,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
|
||||
// normal case for interleaved or repeated columns
|
||||
witness := pol.GetColAssignment(run)
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0)
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
|
||||
name := pol.GetColID()
|
||||
lock.Lock()
|
||||
coeffs[name] = witness
|
||||
@@ -356,8 +357,12 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
|
||||
stopTimer := profiling.LogTimer("ReEvaluate %v pols of size %v on coset %v/%v", len(handles), ctx.DomainSize, share, ratio)
|
||||
|
||||
parallel.ExecuteChunky(len(roots), func(start, stop int) {
|
||||
for k := start; k < stop; k++ {
|
||||
parallel.ExecuteFromChan(len(roots), func(wg *sync.WaitGroup, indexChan chan int) {
|
||||
|
||||
localPool := mempool.WrapsWithMemCache(largePool)
|
||||
defer localPool.TearDown()
|
||||
|
||||
for k := range indexChan {
|
||||
root := roots[k]
|
||||
name := root.GetColID()
|
||||
|
||||
@@ -365,18 +370,25 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
|
||||
if found {
|
||||
// it was already computed in a previous iteration of `j`
|
||||
wg.Done()
|
||||
continue
|
||||
}
|
||||
|
||||
// else it's the first value of j that sees it. so we compute the
|
||||
// coset reevaluation.
|
||||
reevaledRoot := sv.FFT(coeffs[name], fft.DIT, false, ratio, share)
|
||||
reevaledRoot := sv.FFT(coeffs[name], fft.DIT, false, ratio, share, localPool)
|
||||
computedReeval.Store(name, reevaledRoot)
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
})
|
||||
|
||||
parallel.ExecuteChunky(len(handles), func(start, stop int) {
|
||||
for k := start; k < stop; k++ {
|
||||
parallel.ExecuteFromChan(len(handles), func(wg *sync.WaitGroup, indexChan chan int) {
|
||||
|
||||
localPool := mempool.WrapsWithMemCache(largePool)
|
||||
defer localPool.TearDown()
|
||||
|
||||
for k := range indexChan {
|
||||
|
||||
pol := handles[k]
|
||||
// short-path, the column is a purely Shifted(Natural) or a Natural
|
||||
@@ -397,6 +409,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
// Now, we can reuse a soft-rotation of the smart-vector to save memory
|
||||
if !pol.IsComposite() {
|
||||
// in this case, the right vector was the root so we are done
|
||||
wg.Done()
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -404,6 +417,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
polName := pol.GetColID()
|
||||
res := sv.SoftRotate(reevaledRoot.(sv.SmartVector), shifted.Offset)
|
||||
computedReeval.Store(polName, res)
|
||||
wg.Done()
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -412,14 +426,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
name := pol.GetColID()
|
||||
_, ok := computedReeval.Load(name)
|
||||
if ok {
|
||||
wg.Done()
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := coeffs[name]; !ok {
|
||||
utils.Panic("handle %v not found in the coeffs\n", name)
|
||||
}
|
||||
res := sv.FFT(coeffs[name], fft.DIT, false, ratio, share)
|
||||
|
||||
res := sv.FFT(coeffs[name], fft.DIT, false, ratio, share, localPool)
|
||||
computedReeval.Store(name, res)
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
})
|
||||
|
||||
@@ -477,6 +495,11 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
|
||||
// Forcefuly clean the memory for the computed reevals
|
||||
computedReeval.Range(func(k, v interface{}) bool {
|
||||
|
||||
if pooled, ok := v.(*sv.Pooled); ok {
|
||||
pooled.Free(largePool)
|
||||
}
|
||||
|
||||
computedReeval.Delete(k)
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestMPTS(t *testing.T) {
|
||||
|
||||
hLProver := func(assi *wizard.ProverRuntime) {
|
||||
p1 := smartvectors.Rand(PolSize)
|
||||
p2 := smartvectors.Rand(PolSize)
|
||||
p2 := smartvectors.NewConstant(field.NewElement(42), PolSize)
|
||||
|
||||
assi.AssignColumn(P1, p1)
|
||||
assi.AssignColumn(P2, p2)
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/common/mempool"
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/common/poly"
|
||||
sv "github.com/consensys/zkevm-monorepo/prover/maths/common/smartvectors"
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/fft"
|
||||
@@ -109,6 +111,8 @@ type mptsCtx struct {
|
||||
// maxSize int
|
||||
quotientSize int
|
||||
targetSize int
|
||||
// maxSize is the size of the largest column being processed by the compiler
|
||||
maxSize int
|
||||
// Number of rounds in the original protocol
|
||||
numRound int
|
||||
// Various indentifiers created in the protocol
|
||||
@@ -129,7 +133,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
|
||||
xPoly := make(map[ifaces.ColID][]int)
|
||||
hs := []ifaces.QueryID{}
|
||||
polys := []ifaces.Column{}
|
||||
maxDeg := 0
|
||||
maxSize := 0
|
||||
|
||||
/*
|
||||
Adding coins in the protocol can add extra rounds,
|
||||
@@ -182,7 +186,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
|
||||
if _, ok := xPoly[poly.GetColID()]; !ok {
|
||||
polys = append(polys, poly)
|
||||
xPoly[poly.GetColID()] = []int{}
|
||||
maxDeg = utils.Max(maxDeg, poly.Size())
|
||||
maxSize = utils.Max(maxSize, poly.Size())
|
||||
}
|
||||
xPoly[poly.GetColID()] = append(xPoly[poly.GetColID()], queryCount-1)
|
||||
}
|
||||
@@ -207,6 +211,7 @@ func createMptsCtx(comp *wizard.CompiledIOP, targetSize int) mptsCtx {
|
||||
quotientSize: quotientSize,
|
||||
targetSize: targetSize,
|
||||
numRound: numRound,
|
||||
maxSize: maxSize,
|
||||
QuotientName: deriveName[ifaces.ColID](comp, MPTS, "", MTSP_QUOTIENT_SUFFIX),
|
||||
LinCombCoeff: deriveName[coin.Name](comp, MPTS, "", MTSP_LIN_COMB),
|
||||
EvaluationPoint: deriveName[coin.Name](comp, MPTS, "", MTSP_RAND_EVAL),
|
||||
@@ -226,31 +231,48 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
|
||||
r := run.GetRandomCoinField(ctx.LinCombCoeff)
|
||||
|
||||
// Preallocate the value of the quotient
|
||||
quotient := sv.AllocateRegular(ctx.quotientSize)
|
||||
quoLock := sync.Mutex{}
|
||||
ys, hs := ctx.getYsHs(run.GetUnivariateParams, run.GetUnivariateEval)
|
||||
logrus.Tracef("computed Ys and Hs")
|
||||
var (
|
||||
quotient = sv.AllocateRegular(ctx.quotientSize)
|
||||
quoLock = sync.Mutex{}
|
||||
ys, hs = ctx.getYsHs(run.GetUnivariateParams, run.GetUnivariateEval)
|
||||
|
||||
// precompute the lagrange polynomials
|
||||
lagranges := getLagrangesPolys(hs)
|
||||
// precompute the lagrange polynomials
|
||||
lagranges = getLagrangesPolys(hs)
|
||||
pool = mempool.CreateFromSyncPool(ctx.targetSize).Prewarm(runtime.NumCPU())
|
||||
mainWg = &sync.WaitGroup{}
|
||||
)
|
||||
|
||||
parallel.Execute(len(ctx.polys), func(start, stop int) {
|
||||
mainWg.Add(runtime.NumCPU())
|
||||
|
||||
maxSize := 0
|
||||
for i := start; i < stop; i++ {
|
||||
maxSize = utils.Max(maxSize, ctx.polys[i].Size())
|
||||
}
|
||||
parallel.ExecuteFromChan(len(ctx.polys), func(wg *sync.WaitGroup, indexChan chan int) {
|
||||
|
||||
// Preallocate the value of the quotient
|
||||
subQuotient := sv.AllocateRegular(maxSize)
|
||||
var (
|
||||
pool = mempool.WrapsWithMemCache(pool)
|
||||
|
||||
for i := start; i < stop; i++ {
|
||||
polHandle := ctx.polys[i]
|
||||
/*
|
||||
Get the coefficients of the witness polynomial
|
||||
*/
|
||||
polWitness := polHandle.GetColAssignment(run)
|
||||
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0)
|
||||
// Preallocate the value of the quotient
|
||||
subQuotientReg = sv.AllocateRegular(ctx.maxSize)
|
||||
subQuotientCnst = field.Zero()
|
||||
)
|
||||
|
||||
defer pool.TearDown()
|
||||
|
||||
for i := range indexChan {
|
||||
|
||||
var (
|
||||
polHandle = ctx.polys[i]
|
||||
polWitness = polHandle.GetColAssignment(run)
|
||||
ri field.Element
|
||||
)
|
||||
|
||||
ri.Exp(r, big.NewInt(int64(i)))
|
||||
|
||||
if cnst, isCnst := polWitness.(*sv.Constant); isCnst {
|
||||
polWitness = sv.NewRegular([]field.Element{cnst.Val()})
|
||||
} else if pool.Size() == polWitness.Len() {
|
||||
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0, pool)
|
||||
} else {
|
||||
polWitness = sv.FFTInverse(polWitness, fft.DIF, true, 0, 0, nil)
|
||||
}
|
||||
|
||||
/*
|
||||
Substract by the lagrange interpolator and get the quotient
|
||||
@@ -298,32 +320,37 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Then accumulate the subQuotient. The subQuotient is supposedly
|
||||
larger than the pol,
|
||||
*/
|
||||
var ri field.Element
|
||||
ri.Exp(r, big.NewInt(int64(i)))
|
||||
|
||||
// Should be very uncommon
|
||||
if subQuotient.Len() < polWitness.Len() {
|
||||
if subQuotientReg.Len() < polWitness.Len() {
|
||||
logrus.Warnf("Warning reallocation of the subquotient for MPTS. If there are too many it's an issue")
|
||||
// It's the only known use-case for concatenating smart-vectors
|
||||
newSubquotient := make([]field.Element, polWitness.Len())
|
||||
subQuotient.WriteInSlice(newSubquotient[:subQuotient.Len()])
|
||||
subQuotient = sv.NewRegular(newSubquotient)
|
||||
subQuotientReg.WriteInSlice(newSubquotient[:subQuotientReg.Len()])
|
||||
subQuotientReg = sv.NewRegular(newSubquotient)
|
||||
}
|
||||
|
||||
tmp := sv.ScalarMul(polWitness, ri)
|
||||
subQuotient = sv.PolyAdd(tmp, subQuotient)
|
||||
subQuotientReg = sv.PolyAdd(tmp, subQuotientReg)
|
||||
|
||||
if pooled, ok := polWitness.(*sv.Pooled); ok {
|
||||
pooled.Free(pool)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
subQuotientReg = sv.PolyAdd(subQuotientReg, sv.NewConstant(subQuotientCnst, 1))
|
||||
|
||||
// This locking mechanism is completely subOptimal, but this should be good enough
|
||||
quoLock.Lock()
|
||||
quotient = sv.PolyAdd(quotient, subQuotient)
|
||||
quotient = sv.PolyAdd(quotient, subQuotientReg)
|
||||
quoLock.Unlock()
|
||||
|
||||
mainWg.Done()
|
||||
})
|
||||
|
||||
mainWg.Wait()
|
||||
|
||||
if quotient.Len() < ctx.targetSize {
|
||||
quo := sv.IntoRegVec(quotient)
|
||||
quotient = sv.RightZeroPadded(quo, ctx.targetSize)
|
||||
@@ -332,7 +359,7 @@ func (ctx mptsCtx) accumulateQuotients(run *wizard.ProverRuntime) {
|
||||
for i := range ctx.Quotients {
|
||||
// each subquotient is a slice of the original larger quotient
|
||||
subQuotient := quotient.SubVector(i*ctx.targetSize, (i+1)*ctx.targetSize)
|
||||
subQuotient = sv.FFT(subQuotient, fft.DIF, true, 0, 0)
|
||||
subQuotient = sv.FFT(subQuotient, fft.DIF, true, 0, 0, nil)
|
||||
run.AssignColumn(ctx.Quotients[i].GetColID(), subQuotient)
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func CheckReedSolomon(comp *wizard.CompiledIOP, rate int, h ifaces.Column) {
|
||||
//
|
||||
comp.SubProvers.AppendToInner(round, func(assi *wizard.ProverRuntime) {
|
||||
witness := h.GetColAssignment(assi)
|
||||
coeffs := smartvectors.FFTInverse(witness, fft.DIF, true, 0, 0).SubVector(0, codeDim)
|
||||
coeffs := smartvectors.FFTInverse(witness, fft.DIF, true, 0, 0, nil).SubVector(0, codeDim)
|
||||
assi.AssignColumn(ifaces.ColIDf("%v_%v", REED_SOLOMON_COEFF, h.GetColID()), coeffs)
|
||||
})
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func TestReedSolomon(t *testing.T) {
|
||||
|
||||
wp := smartvectors.ForTest(1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
wp = smartvectors.FFT(wp, fft.DIF, true, 0, 0)
|
||||
wp = smartvectors.FFT(wp, fft.DIF, true, 0, 0, nil)
|
||||
|
||||
definer := func(b *wizard.Builder) {
|
||||
p := b.RegisterCommit("P", wp.Len())
|
||||
|
||||
@@ -156,8 +156,8 @@ func TestPeriodicSampleCoset(t *testing.T) {
|
||||
for cosetID := 0; cosetID < ratio; cosetID++ {
|
||||
|
||||
testEval := sampling.EvalCoset(domain, cosetID, ratio, true)
|
||||
testEval = smartvectors.FFTInverse(testEval, fft.DIF, true, ratio, cosetID)
|
||||
testEval = smartvectors.FFT(testEval, fft.DIT, true, 0, 0)
|
||||
testEval = smartvectors.FFTInverse(testEval, fft.DIF, true, ratio, cosetID, nil)
|
||||
testEval = smartvectors.FFT(testEval, fft.DIT, true, 0, 0, nil)
|
||||
|
||||
require.Equal(t, vanillaEval.Pretty(), testEval.Pretty(),
|
||||
"domain %v, period %v, offset %v, ratio %v, cosetID %v",
|
||||
|
||||
@@ -22,7 +22,7 @@ func (Constant) Degree([]int) int {
|
||||
}
|
||||
|
||||
// Evaluates implements the [Operator] interface
|
||||
func (c Constant) Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector {
|
||||
func (c Constant) Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector {
|
||||
panic("we never call it for a constant")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package symbolic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/common/mempool"
|
||||
@@ -35,18 +36,18 @@ func (b *ExpressionBoard) ListVariableMetadata() []Metadata {
|
||||
}
|
||||
|
||||
// Evaluate the board for a batch of inputs in parallel
|
||||
func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
|
||||
func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
|
||||
|
||||
/*
|
||||
Find the size of the vector
|
||||
*/
|
||||
totalSize := 0
|
||||
for i, inp := range inputs {
|
||||
|
||||
if totalSize > 0 && totalSize != inp.Len() {
|
||||
// Expects that all vector inputs have the same size
|
||||
utils.Panic("Mismatch in the size: len(v) %v, totalsize %v, pos %v", inp.Len(), totalSize, i)
|
||||
}
|
||||
|
||||
if totalSize == 0 {
|
||||
totalSize = inp.Len()
|
||||
}
|
||||
@@ -58,6 +59,10 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
|
||||
utils.Panic("Either there is no input or the inputs all have size 0")
|
||||
}
|
||||
|
||||
if len(p) > 0 && p[0].Size() != MaxChunkSize {
|
||||
utils.Panic("the pool should be a pool of vectors of size=%v but it is %v", MaxChunkSize, p[0].Size())
|
||||
}
|
||||
|
||||
/*
|
||||
The is no vector input iff totalSize is 0
|
||||
Thus the condition below catch the two cases where:
|
||||
@@ -65,9 +70,18 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
|
||||
- The vectors are smaller than the min chunk size
|
||||
*/
|
||||
|
||||
if totalSize <= MaxChunkSize {
|
||||
// never pass the pool here
|
||||
return b.evaluateSingleThread(inputs)
|
||||
if totalSize < MaxChunkSize {
|
||||
// never pass the pool here as the pool assumes that all vectors have a
|
||||
// size of MaxChunkSize. Thus, it would not work here.
|
||||
return b.evaluateSingleThread(inputs).DeepCopy()
|
||||
}
|
||||
|
||||
// This is the code-path that is used for benchmarking when the size of the
|
||||
// vectors is exactly MaxChunkSize. In production, it will rather use the
|
||||
// multi-threaded option. The above condition cannot use the pool because we
|
||||
// assume here that the pool has a vector size of exactly MaxChunkSize.
|
||||
if totalSize == MaxChunkSize {
|
||||
return b.evaluateSingleThread(inputs, p...).DeepCopy()
|
||||
}
|
||||
|
||||
if totalSize%MaxChunkSize != 0 {
|
||||
@@ -77,12 +91,23 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
|
||||
numChunks := totalSize / MaxChunkSize
|
||||
res := make([]field.Element, totalSize)
|
||||
|
||||
parallel.ExecuteChunky(numChunks, func(start, stop int) {
|
||||
for chunkID := start; chunkID < stop; chunkID++ {
|
||||
parallel.ExecuteFromChan(numChunks, func(wg *sync.WaitGroup, idChan chan int) {
|
||||
|
||||
chunkStart := chunkID * MaxChunkSize
|
||||
chunkStop := (chunkID + 1) * MaxChunkSize
|
||||
chunkInputs := make([]sv.SmartVector, len(inputs))
|
||||
var pool []mempool.MemPool
|
||||
if len(p) > 0 {
|
||||
if _, ok := p[0].(*mempool.DebuggeableCall); !ok {
|
||||
pool = append(pool, mempool.WrapsWithMemCache(p[0]))
|
||||
}
|
||||
}
|
||||
|
||||
chunkInputs := make([]sv.SmartVector, len(inputs))
|
||||
|
||||
for chunkID := range idChan {
|
||||
|
||||
var (
|
||||
chunkStart = chunkID * MaxChunkSize
|
||||
chunkStop = (chunkID + 1) * MaxChunkSize
|
||||
)
|
||||
|
||||
for i, inp := range inputs {
|
||||
chunkInputs[i] = inp.SubVector(chunkStart, chunkStop)
|
||||
@@ -94,11 +119,19 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
|
||||
|
||||
// We don't parallelize evaluations where the inputs are all scalars
|
||||
// Therefore the cast is safe.
|
||||
chunkRes := b.evaluateSingleThread(chunkInputs, p...)
|
||||
chunkRes := b.evaluateSingleThread(chunkInputs, pool...)
|
||||
|
||||
// No race condition here as each call write to different places
|
||||
// of vec.
|
||||
chunkRes.WriteInSlice(res[chunkStart:chunkStop])
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
if len(p) > 0 {
|
||||
if sa, ok := pool[0].(*mempool.SliceArena); ok {
|
||||
sa.TearDown()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -107,7 +140,7 @@ func (b *ExpressionBoard) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool)
|
||||
|
||||
// evaluateSingleThread evaluates a boarded expression. The inputs can be either
|
||||
// vector or scalars. The vector's input length should be smaller than a chunk.
|
||||
func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
|
||||
func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
|
||||
|
||||
var (
|
||||
length = inputs[0].Len()
|
||||
@@ -138,10 +171,9 @@ func (b *ExpressionBoard) evaluateSingleThread(inputs []sv.SmartVector, p ...*me
|
||||
// Deep-copy the last node and put resBuf back in the pool. It's cleanier
|
||||
// that way.
|
||||
if hasPool {
|
||||
if reg, ok := resBuf.(*sv.Regular); ok {
|
||||
if reg, ok := resBuf.(*sv.Pooled); ok {
|
||||
resGC := reg.DeepCopy()
|
||||
v := []field.Element(*reg)
|
||||
pool.Free(&v)
|
||||
reg.Free(pool)
|
||||
resBuf = resGC
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/protocol/serialization"
|
||||
"github.com/consensys/zkevm-monorepo/prover/symbolic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkEvaluationSingleThreaded(b *testing.B) {
|
||||
@@ -38,18 +39,18 @@ func BenchmarkEvaluationSingleThreaded(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("ratio-%v", ratio), func(b *testing.B) {
|
||||
|
||||
var (
|
||||
testDir = "testdata/evaluation-benchmark"
|
||||
constanthoodFName = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
|
||||
exprFName = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
|
||||
constanthoodFPath = path.Join(testDir, constanthoodFName)
|
||||
exprFPath = path.Join(testDir, exprFName)
|
||||
constantHoodFile = files.MustRead(constanthoodFPath)
|
||||
exprFile = files.MustReadCompressed(exprFPath)
|
||||
constantHood = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
|
||||
expr = serialization.UnmarshalExprCBOR(exprFile)
|
||||
inputs = make([]smartvectors.SmartVector, len(constantHood))
|
||||
board = expr.Board()
|
||||
pool = mempool.Create(symbolic.MaxChunkSize)
|
||||
testDir = "testdata/evaluation-benchmark"
|
||||
constanthoodFName = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
|
||||
exprFName = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
|
||||
constanthoodFPath = path.Join(testDir, constanthoodFName)
|
||||
exprFPath = path.Join(testDir, exprFName)
|
||||
constantHoodFile = files.MustRead(constanthoodFPath)
|
||||
exprFile = files.MustReadCompressed(exprFPath)
|
||||
constantHood = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
|
||||
expr = serialization.UnmarshalExprCBOR(exprFile)
|
||||
inputs = make([]smartvectors.SmartVector, len(constantHood))
|
||||
board = expr.Board()
|
||||
pool mempool.MemPool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize)
|
||||
)
|
||||
|
||||
for i := range inputs {
|
||||
@@ -71,6 +72,78 @@ func BenchmarkEvaluationSingleThreaded(b *testing.B) {
|
||||
_ = board.Evaluate(inputs, pool)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluationSingleThreaded(t *testing.T) {
|
||||
|
||||
makeRegular := func() smartvectors.SmartVector {
|
||||
return smartvectors.Rand(symbolic.MaxChunkSize)
|
||||
}
|
||||
|
||||
makeConst := func() smartvectors.SmartVector {
|
||||
var x field.Element
|
||||
x.SetRandom()
|
||||
return smartvectors.NewConstant(x, symbolic.MaxChunkSize)
|
||||
}
|
||||
|
||||
makeFullZero := func() smartvectors.SmartVector {
|
||||
return smartvectors.NewConstant(field.Zero(), symbolic.MaxChunkSize)
|
||||
}
|
||||
|
||||
makeFullOnes := func() smartvectors.SmartVector {
|
||||
return smartvectors.NewConstant(field.One(), symbolic.MaxChunkSize)
|
||||
}
|
||||
|
||||
for ratio := 1; ratio <= 32; ratio *= 2 {
|
||||
|
||||
t.Run(fmt.Sprintf("ratio-%v", ratio), func(b *testing.T) {
|
||||
|
||||
var (
|
||||
testDir = "testdata/evaluation-benchmark"
|
||||
constanthoodFName = fmt.Sprintf("global-variable-constanthood-%v.csv", ratio)
|
||||
exprFName = fmt.Sprintf("global-cs-ratio-%v.cbor.gz", ratio)
|
||||
constanthoodFPath = path.Join(testDir, constanthoodFName)
|
||||
exprFPath = path.Join(testDir, exprFName)
|
||||
constantHoodFile = files.MustRead(constanthoodFPath)
|
||||
exprFile = files.MustReadCompressed(exprFPath)
|
||||
constantHood = symbolic.ReadConstanthoodFromCsv(constantHoodFile)
|
||||
expr = serialization.UnmarshalExprCBOR(exprFile)
|
||||
inputs = make([]smartvectors.SmartVector, len(constantHood))
|
||||
board = expr.Board()
|
||||
pool_ mempool.MemPool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize)
|
||||
)
|
||||
|
||||
pool_ = mempool.WrapsWithMemCache(pool_)
|
||||
pool := mempool.NewDebugPool(pool_)
|
||||
|
||||
_, mustBeTrue := mempool.ExtractCheckOptionalSoft(symbolic.MaxChunkSize, pool)
|
||||
_, mustBeTrue2 := mempool.ExtractCheckOptionalStrict(symbolic.MaxChunkSize, pool)
|
||||
|
||||
assert.True(t, mustBeTrue)
|
||||
assert.True(t, mustBeTrue2)
|
||||
|
||||
for i := range inputs {
|
||||
switch {
|
||||
case !constantHood[i][0]:
|
||||
inputs[i] = makeRegular()
|
||||
case constantHood[i][1]:
|
||||
inputs[i] = makeFullZero()
|
||||
case constantHood[i][2]:
|
||||
inputs[i] = makeFullOnes()
|
||||
default:
|
||||
inputs[i] = makeConst()
|
||||
}
|
||||
}
|
||||
|
||||
_ = board.Evaluate(inputs, pool)
|
||||
|
||||
if len(pool.Logs) == 0 {
|
||||
t.Fatalf("the pool was not used")
|
||||
}
|
||||
|
||||
assert.NoError(t, pool.Errors())
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
sv "github.com/consensys/zkevm-monorepo/prover/maths/common/smartvectors"
|
||||
"github.com/consensys/zkevm-monorepo/prover/maths/field"
|
||||
"github.com/consensys/zkevm-monorepo/prover/utils"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type boardAssignment [][]nodeAssignment
|
||||
@@ -66,7 +67,7 @@ func (b *ExpressionBoard) prepareNodeAssignments(inputs []sv.SmartVector) boardA
|
||||
|
||||
if success {
|
||||
for i := range inputs {
|
||||
nodeAssignments.incParentKnownCountOf(&inputs[i], nil, true)
|
||||
nodeAssignments.incParentKnownCountOf(inputs[i], nil, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +78,7 @@ func (b *ExpressionBoard) prepareNodeAssignments(inputs []sv.SmartVector) boardA
|
||||
return nodeAssignments
|
||||
}
|
||||
|
||||
func (b boardAssignment) eval(na *nodeAssignment, pool *mempool.Pool) {
|
||||
func (b boardAssignment) eval(na *nodeAssignment, pool mempool.MemPool) {
|
||||
|
||||
if (na.allParentsKnown() && na.hasParents()) || na.hasAValue() {
|
||||
return
|
||||
@@ -99,11 +100,11 @@ func (b boardAssignment) eval(na *nodeAssignment, pool *mempool.Pool) {
|
||||
na.Value = na.Node.Operator.Evaluate(smv, pool)
|
||||
|
||||
for i := range val {
|
||||
b.incParentKnownCountOf(&val[i], pool, false)
|
||||
b.incParentKnownCountOf(val[i], pool, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (na *nodeAssignment) tryGuessEval(val []nodeAssignment) bool {
|
||||
func (na *nodeAssignment) tryGuessEval(val []*nodeAssignment) bool {
|
||||
|
||||
if na.hasAValue() {
|
||||
return true
|
||||
@@ -146,6 +147,7 @@ func (na *nodeAssignment) tryGuessEval(val []nodeAssignment) bool {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
default:
|
||||
panic("unexpected type")
|
||||
}
|
||||
@@ -176,13 +178,13 @@ func (na *nodeAssignment) constValue() (*sv.Constant, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (b boardAssignment) inputOf(na *nodeAssignment) []nodeAssignment {
|
||||
func (b boardAssignment) inputOf(na *nodeAssignment) []*nodeAssignment {
|
||||
|
||||
if na.Node == nil {
|
||||
panic("na has a nil node")
|
||||
}
|
||||
|
||||
nodeInputs := make([]nodeAssignment, len(na.Node.Children))
|
||||
nodeInputs := make([]*nodeAssignment, len(na.Node.Children))
|
||||
|
||||
for i, childID := range na.Node.Children {
|
||||
var (
|
||||
@@ -190,25 +192,31 @@ func (b boardAssignment) inputOf(na *nodeAssignment) []nodeAssignment {
|
||||
pil = childID.posInLevel()
|
||||
)
|
||||
|
||||
nodeInputs[i] = b[lvl][pil]
|
||||
nodeInputs[i] = &b[lvl][pil]
|
||||
}
|
||||
return nodeInputs
|
||||
}
|
||||
|
||||
func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool *mempool.Pool, recursive bool) (wasDeleted bool) {
|
||||
func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool mempool.MemPool, recursive bool) (wasDeleted bool) {
|
||||
|
||||
na.NumKnownParents++
|
||||
|
||||
// Sanity-checking that this function is not called too many time
|
||||
if na.NumKnownParents > len(na.Node.Parents) {
|
||||
panic("invalid count: overflowing the total number of parent")
|
||||
utils.Panic("invalid count: overflowing the total number of parent")
|
||||
}
|
||||
|
||||
if na.allParentsKnown() {
|
||||
|
||||
if recursive {
|
||||
// The recursive call to incParentKnownCount is needed only if the node
|
||||
// that we "completed" by marking all its parent as known was completed
|
||||
// **only** for that reason. It could also have been completed because
|
||||
// all its children are constants. When that is the case, all the children
|
||||
// will have been incremented already.
|
||||
if recursive && na.Value == nil {
|
||||
children := b.inputOf(na)
|
||||
for i := range children {
|
||||
b.incParentKnownCountOf(&children[i], pool, recursive)
|
||||
b.incParentKnownCountOf(children[i], pool, recursive)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +226,7 @@ func (b boardAssignment) incParentKnownCountOf(na *nodeAssignment, pool *mempool
|
||||
return false
|
||||
}
|
||||
|
||||
func (na *nodeAssignment) tryFree(pool *mempool.Pool) bool {
|
||||
func (na *nodeAssignment) tryFree(pool mempool.MemPool) bool {
|
||||
if pool == nil {
|
||||
return false
|
||||
}
|
||||
@@ -236,12 +244,39 @@ func (na *nodeAssignment) tryFree(pool *mempool.Pool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if reg, ok := na.Value.(*sv.Regular); ok {
|
||||
if reg, ok := na.Value.(*sv.Pooled); ok {
|
||||
na.Value = nil
|
||||
v := []field.Element(*reg)
|
||||
pool.Free(&v)
|
||||
reg.Free(pool)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (b boardAssignment) inspectCleaning() {
|
||||
|
||||
for lvl := 1; lvl < len(b); lvl++ {
|
||||
for pil := range b[lvl] {
|
||||
na := b[lvl][pil]
|
||||
if na.NumKnownParents != len(na.Node.Parents) {
|
||||
logrus.Errorf(
|
||||
"the parent count was not updated till the end lvl=%v pil=%v parentCnt=%v nbParent=%v valueT=%T parents=%v",
|
||||
lvl, pil, na.NumKnownParents, len(na.Node.Parents), na.Value, na.Node.Parents,
|
||||
)
|
||||
}
|
||||
|
||||
if na.Value == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
p, ok := na.Value.(*sv.Pooled)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if p.Regular != nil {
|
||||
logrus.Errorf("the result of node [%v %v] was not cleaned", lvl, pil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ type Expression struct {
|
||||
type Operator interface {
|
||||
// Evaluate returns an evaluation of the operator from a list of assignments:
|
||||
// one for each operand (children) of the expression.
|
||||
Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector
|
||||
Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector
|
||||
// Validate performs a sanity-check of the expression the Operator belongs
|
||||
// to.
|
||||
Validate(e *Expression) error
|
||||
|
||||
@@ -96,7 +96,7 @@ func (LinComb) Degree(inputDegrees []int) int {
|
||||
}
|
||||
|
||||
// Evaluate implements the [Operator] interface.
|
||||
func (lc LinComb) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
|
||||
func (lc LinComb) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
|
||||
return sv.LinComb(lc.Coeffs, inputs, p...)
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ func (PolyEval) Degree(inputDegrees []int) int {
|
||||
/*
|
||||
Evaluates a polynomial evaluation
|
||||
*/
|
||||
func (PolyEval) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
|
||||
func (PolyEval) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
|
||||
// We assume that the first element is always a scalar
|
||||
// Get the constant value. We use Get(0) to get the value, but any integer would
|
||||
// also work provided it is also in range. 0 ensures that.
|
||||
|
||||
@@ -124,7 +124,7 @@ func (prod Product) Degree(inputDegrees []int) int {
|
||||
}
|
||||
|
||||
// Evaluate implements the [Operator] interface.
|
||||
func (prod Product) Evaluate(inputs []sv.SmartVector, p ...*mempool.Pool) sv.SmartVector {
|
||||
func (prod Product) Evaluate(inputs []sv.SmartVector, p ...mempool.MemPool) sv.SmartVector {
|
||||
return sv.Product(prod.Exponents, inputs, p...)
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func (Variable) Degree([]int) int {
|
||||
}
|
||||
|
||||
// Evaluate implements the [Operator] interface. Yet, this panics if this is called.
|
||||
func (v Variable) Evaluate([]sv.SmartVector, ...*mempool.Pool) sv.SmartVector {
|
||||
func (v Variable) Evaluate([]sv.SmartVector, ...mempool.MemPool) sv.SmartVector {
|
||||
panic("we never call it for variables")
|
||||
}
|
||||
|
||||
|
||||
41
prover/utils/parallel/job_stealing.go
Normal file
41
prover/utils/parallel/job_stealing.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package parallel
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ExecuteJobStealing parallelizes a workload specified by a function consuming
|
||||
// a channel distributing the workload. It is appropriate when each iteration
|
||||
// takes an order of magnitude more time than the other functions.
|
||||
//
|
||||
// This is as [ExecuteChunky] but gives more freedom to the caller to initialize
|
||||
// its threads.
|
||||
func ExecuteFromChan(nbIterations int, work func(wg *sync.WaitGroup, indexChan chan int), numcpus ...int) {
|
||||
|
||||
numcpu := runtime.GOMAXPROCS(0)
|
||||
if len(numcpus) > 0 && numcpus[0] > 0 {
|
||||
numcpu = numcpus[0]
|
||||
}
|
||||
|
||||
// The jobs are sent one by one to the workers
|
||||
jobChan := make(chan int, nbIterations)
|
||||
for i := 0; i < nbIterations; i++ {
|
||||
jobChan <- i
|
||||
}
|
||||
|
||||
// The wait group ensures that all the children goroutine have terminated
|
||||
// before we close the
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(nbIterations)
|
||||
|
||||
// Each goroutine consumes the jobChan to
|
||||
for p := 0; p < numcpu; p++ {
|
||||
go func() {
|
||||
work(wg, jobChan)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(jobChan)
|
||||
}
|
||||
Reference in New Issue
Block a user