mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 04:08:01 -05:00
fix: minor mem leak in quotient compute in prover (#968)
* fix race condition * fix: add a mutex to avoid mem leak in quotient computation --------- Co-authored-by: AlexandreBelling <alexandrebelling8@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@ package mempool
|
||||
|
||||
import (
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field"
|
||||
"github.com/consensys/linea-monorepo/prover/utils"
|
||||
)
|
||||
|
||||
// SliceArena is a simple not-threadsafe arena implementation that uses a
|
||||
@@ -46,6 +47,8 @@ func (m *SliceArena) Size() int {
|
||||
|
||||
func (m *SliceArena) TearDown() {
|
||||
for i := range m.frees {
|
||||
m.parent.Free(m.frees[i])
|
||||
if err := m.parent.Free(m.frees[i]); err != nil {
|
||||
utils.Panic("failed to free slice in arena: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// CosetRatio > CosetID:
|
||||
// - Specifies on which coset to perform the operation
|
||||
// - 0, 0 to assert that the transformation should not be done over a coset
|
||||
func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool) SmartVector {
|
||||
func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool, opts ...fft.Option) SmartVector {
|
||||
|
||||
// Sanity-check on the size of the vector v
|
||||
assertPowerOfTwoLen(v.Len())
|
||||
@@ -66,10 +66,9 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
|
||||
v.WriteInSlice(res.Regular)
|
||||
|
||||
domain := fft.NewDomain(v.Len())
|
||||
opt := fft.EmptyOption()
|
||||
|
||||
if cosetID != 0 || cosetRatio != 0 {
|
||||
opt = fft.OnCoset()
|
||||
opts = append(opts, fft.OnCoset())
|
||||
domain = domain.WithCustomCoset(cosetRatio, cosetID)
|
||||
}
|
||||
|
||||
@@ -78,10 +77,10 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
|
||||
if bitReverse {
|
||||
fft.BitReverse(res.Regular)
|
||||
}
|
||||
domain.FFT(res.Regular, fft.DIT, opt)
|
||||
domain.FFT(res.Regular, fft.DIT, opts...)
|
||||
} else {
|
||||
// Likewise, the optionally rearrange the input in correct order
|
||||
domain.FFT(res.Regular, fft.DIF, opt)
|
||||
domain.FFT(res.Regular, fft.DIF, opts...)
|
||||
if bitReverse {
|
||||
fft.BitReverse(res.Regular)
|
||||
}
|
||||
@@ -102,7 +101,7 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
|
||||
// CosetRatio > CosetID:
|
||||
// - Specifies on which coset to perform the operation
|
||||
// - 0, 0 to assert that the transformation should not be done over a coset
|
||||
func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool) SmartVector {
|
||||
func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool, opts ...fft.Option) SmartVector {
|
||||
|
||||
// Sanity-check on the size of the vector v
|
||||
assertPowerOfTwoLen(v.Len())
|
||||
@@ -147,19 +146,18 @@ func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, coset
|
||||
res = &Pooled{Regular: make([]field.Element, v.Len())}
|
||||
}
|
||||
|
||||
opt := fft.EmptyOption()
|
||||
v.WriteInSlice(res.Regular)
|
||||
|
||||
domain := fft.NewDomain(v.Len())
|
||||
if cosetID != 0 || cosetRatio != 0 {
|
||||
// Optionally equip the domain with a coset
|
||||
opt = fft.OnCoset()
|
||||
opts = append(opts, fft.OnCoset())
|
||||
domain = domain.WithCustomCoset(cosetRatio, cosetID)
|
||||
}
|
||||
|
||||
if decimation == fft.DIF {
|
||||
// Optionally, bitReverse the output
|
||||
domain.FFTInverse(res.Regular, fft.DIF, opt)
|
||||
domain.FFTInverse(res.Regular, fft.DIF, opts...)
|
||||
if bitReverse {
|
||||
fft.BitReverse(res.Regular)
|
||||
}
|
||||
@@ -168,7 +166,7 @@ func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, coset
|
||||
if bitReverse {
|
||||
fft.BitReverse(res.Regular)
|
||||
}
|
||||
domain.FFTInverse(res.Regular, fft.DIT, opt)
|
||||
domain.FFTInverse(res.Regular, fft.DIT, opts...)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package smartvectors
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/poly"
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/vector"
|
||||
@@ -111,7 +112,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
|
||||
polys = make([][]field.Element, len(vs))
|
||||
results = make([]field.Element, len(vs))
|
||||
computed = make([]bool, len(vs))
|
||||
totalConstant = 0
|
||||
totalConstant = uint64(0)
|
||||
)
|
||||
|
||||
// smartvector to []fr.element
|
||||
@@ -122,7 +123,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
|
||||
// constant vectors
|
||||
results[i] = con.val
|
||||
computed[i] = true
|
||||
totalConstant++
|
||||
atomic.AddUint64(&totalConstant, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -131,7 +132,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
|
||||
}
|
||||
})
|
||||
|
||||
if totalConstant == len(vs) {
|
||||
if int(totalConstant) == len(vs) {
|
||||
return results
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package smartvectors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field/fext"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/mempool"
|
||||
@@ -192,7 +193,9 @@ func AllocFromPool(pool mempool.MemPool) *Pooled {
|
||||
|
||||
func (p *Pooled) Free(pool mempool.MemPool) {
|
||||
if p.poolPtr != nil {
|
||||
pool.Free(p.poolPtr)
|
||||
if err := pool.Free(p.poolPtr); err != nil {
|
||||
utils.Panic("failed to free slice in pool: %v", err)
|
||||
}
|
||||
}
|
||||
p.poolPtr = nil
|
||||
p.Regular = nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package smartvectors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field/fext"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/vector"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package fft
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/linea-monorepo/prover/utils/parallel"
|
||||
"math/bits"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field"
|
||||
)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package globalcs
|
||||
|
||||
import (
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/coin"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/variables"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/coin"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/variables"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/mempool"
|
||||
@@ -196,12 +197,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
|
||||
logrus.Infof("run the prover for the global constraint (quotient computation)")
|
||||
|
||||
// threadSafeVector is a wrapper around the smart vector that makes it lockable
|
||||
type threadSafeVector struct {
|
||||
*sync.Mutex
|
||||
sv.SmartVector
|
||||
}
|
||||
|
||||
var (
|
||||
// Tracks the time spent on garbage collection
|
||||
totalTimeGc = int64(0)
|
||||
|
||||
// Initial step is to compute the FFTs for all committed vectors
|
||||
coeffs = sync.Map{} // (ifaces.ColID <=> sv.SmartVector)
|
||||
coeffs = sync.Map{} // (ifaces.ColID <=> threadSafeVector{sv.SmartVector})
|
||||
stopTimer = profiling.LogTimer("Computing the coeffs %v pols of size %v", len(ctx.AllInvolvedColumns), ctx.DomainSize)
|
||||
pool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize).Prewarm(runtime.GOMAXPROCS(0) * ctx.MaxNbExprNode)
|
||||
largePool = mempool.CreateFromSyncPool(ctx.DomainSize).Prewarm(len(ctx.AllInvolvedColumns))
|
||||
@@ -235,9 +242,9 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
witness = pol.GetColAssignment(run)
|
||||
}
|
||||
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil, fft.WithNbTasks(2))
|
||||
|
||||
coeffs.Store(name, witness)
|
||||
coeffs.Store(name, threadSafeVector{&sync.Mutex{}, witness})
|
||||
})
|
||||
|
||||
wg.Done()
|
||||
@@ -258,11 +265,11 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
// normal case for interleaved or repeated columns
|
||||
witness := pol.GetColAssignment(run)
|
||||
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
|
||||
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil, fft.WithNbTasks(2))
|
||||
|
||||
name := pol.GetColID()
|
||||
|
||||
coeffs.Store(name, witness)
|
||||
coeffs.Store(name, threadSafeVector{&sync.Mutex{}, witness})
|
||||
})
|
||||
wg.Done()
|
||||
}()
|
||||
@@ -359,18 +366,31 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
root := roots[k]
|
||||
name := root.GetColID()
|
||||
|
||||
_, found := computedReeval.Load(name)
|
||||
|
||||
if found {
|
||||
// it was already computed in a previous iteration of `j`
|
||||
// check if it was already computed
|
||||
if _, ok := computedReeval.Load(name); ok {
|
||||
return
|
||||
}
|
||||
|
||||
// else it's the first value of j that sees it. so we compute the
|
||||
// coset reevaluation.
|
||||
|
||||
v, _ := coeffs.Load(name)
|
||||
reevaledRoot := sv.FFT(v.(sv.SmartVector), fft.DIT, false, ratio, share, localPool)
|
||||
v, ok := coeffs.Load(name)
|
||||
if !ok {
|
||||
utils.Panic("handle %v not found in the coeffs (a)\n", name)
|
||||
}
|
||||
|
||||
// lock the vector to ensure we don't do twice the same compute
|
||||
// and that we don't (over)write an entry in the map with a (leaking) pool vector
|
||||
v.(threadSafeVector).Lock()
|
||||
defer v.(threadSafeVector).Unlock()
|
||||
|
||||
// check again if it was already computed
|
||||
// (can happen if 2 go routines hit the lock at the same time)
|
||||
if _, ok := computedReeval.Load(name); ok {
|
||||
return
|
||||
}
|
||||
|
||||
reevaledRoot := sv.FFT(v.(threadSafeVector).SmartVector, fft.DIT, false, ratio, share, localPool, fft.WithNbTasks(2))
|
||||
computedReeval.Store(name, reevaledRoot)
|
||||
})
|
||||
|
||||
@@ -382,6 +402,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
// short-path, the column is a purely Shifted(Natural) or a Natural
|
||||
// (this excludes repeats and/or interleaved columns)
|
||||
rootCols := column.RootParents(pol)
|
||||
|
||||
if len(rootCols) == 1 && rootCols[0].Size() == pol.Size() {
|
||||
|
||||
root := rootCols[0]
|
||||
@@ -410,8 +431,9 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
}
|
||||
|
||||
name := pol.GetColID()
|
||||
_, ok := computedReeval.Load(name)
|
||||
if ok {
|
||||
|
||||
if _, ok := computedReeval.Load(name); ok {
|
||||
// already computed
|
||||
return
|
||||
}
|
||||
|
||||
@@ -420,7 +442,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
utils.Panic("handle %v not found in the coeffs\n", name)
|
||||
}
|
||||
|
||||
res := sv.FFT(v.(sv.SmartVector), fft.DIT, false, ratio, share, localPool)
|
||||
// lock the vector to ensure we don't do twice the same compute
|
||||
// and that we don't (over)write an entry in the map with a (leaking) pool vector
|
||||
v.(threadSafeVector).Lock()
|
||||
defer v.(threadSafeVector).Unlock()
|
||||
|
||||
// check again if it was already computed
|
||||
// (can happen if 2 go routines hit the lock at the same time)
|
||||
if _, ok := computedReeval.Load(name); ok {
|
||||
return
|
||||
}
|
||||
|
||||
res := sv.FFT(v.(threadSafeVector).SmartVector, fft.DIT, false, ratio, share, localPool, fft.WithNbTasks(2))
|
||||
computedReeval.Store(name, res)
|
||||
|
||||
})
|
||||
@@ -445,7 +478,10 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
|
||||
case ifaces.Column:
|
||||
//name := metadata.GetColID()
|
||||
//evalInputs[k] = computedReeval[name]
|
||||
value, _ := computedReeval.Load(metadata.GetColID())
|
||||
value, ok := computedReeval.Load(metadata.GetColID())
|
||||
if !ok {
|
||||
utils.Panic("did not find the reevaluation of %v", metadata.GetColID())
|
||||
}
|
||||
evalInputs[k] = value.(sv.SmartVector)
|
||||
case coin.Info:
|
||||
evalInputs[k] = sv.NewConstant(run.GetRandomCoinField(metadata.Name), ctx.DomainSize)
|
||||
|
||||
Reference in New Issue
Block a user