fix: minor mem leak in quotient compute in prover (#968)

* fix race condition

* fix: add a mutex to avoid mem leak in quotient computation

---------

Co-authored-by: AlexandreBelling <alexandrebelling8@gmail.com>
This commit is contained in:
Gautam Botrel
2025-05-16 05:53:06 -05:00
committed by GitHub
parent be03e86021
commit 8a0bcc8fb0
7 changed files with 76 additions and 33 deletions

View File

@@ -2,6 +2,7 @@ package mempool
import (
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
)
// SliceArena is a simple not-threadsafe arena implementation that uses a
@@ -46,6 +47,8 @@ func (m *SliceArena) Size() int {
func (m *SliceArena) TearDown() {
for i := range m.frees {
m.parent.Free(m.frees[i])
if err := m.parent.Free(m.frees[i]); err != nil {
utils.Panic("failed to free slice in arena: %v", err)
}
}
}

View File

@@ -19,7 +19,7 @@ import (
// CosetRatio > CosetID:
// - Specifies on which coset to perform the operation
// - 0, 0 to assert that the transformation should not be done over a coset
func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool) SmartVector {
func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool, opts ...fft.Option) SmartVector {
// Sanity-check on the size of the vector v
assertPowerOfTwoLen(v.Len())
@@ -66,10 +66,9 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
v.WriteInSlice(res.Regular)
domain := fft.NewDomain(v.Len())
opt := fft.EmptyOption()
if cosetID != 0 || cosetRatio != 0 {
opt = fft.OnCoset()
opts = append(opts, fft.OnCoset())
domain = domain.WithCustomCoset(cosetRatio, cosetID)
}
@@ -78,10 +77,10 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
if bitReverse {
fft.BitReverse(res.Regular)
}
domain.FFT(res.Regular, fft.DIT, opt)
domain.FFT(res.Regular, fft.DIT, opts...)
} else {
// Likewise, the optionally rearrange the input in correct order
domain.FFT(res.Regular, fft.DIF, opt)
domain.FFT(res.Regular, fft.DIF, opts...)
if bitReverse {
fft.BitReverse(res.Regular)
}
@@ -102,7 +101,7 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i
// CosetRatio > CosetID:
// - Specifies on which coset to perform the operation
// - 0, 0 to assert that the transformation should not be done over a coset
func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool) SmartVector {
func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio int, cosetID int, pool mempool.MemPool, opts ...fft.Option) SmartVector {
// Sanity-check on the size of the vector v
assertPowerOfTwoLen(v.Len())
@@ -147,19 +146,18 @@ func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, coset
res = &Pooled{Regular: make([]field.Element, v.Len())}
}
opt := fft.EmptyOption()
v.WriteInSlice(res.Regular)
domain := fft.NewDomain(v.Len())
if cosetID != 0 || cosetRatio != 0 {
// Optionally equip the domain with a coset
opt = fft.OnCoset()
opts = append(opts, fft.OnCoset())
domain = domain.WithCustomCoset(cosetRatio, cosetID)
}
if decimation == fft.DIF {
// Optionally, bitReverse the output
domain.FFTInverse(res.Regular, fft.DIF, opt)
domain.FFTInverse(res.Regular, fft.DIF, opts...)
if bitReverse {
fft.BitReverse(res.Regular)
}
@@ -168,7 +166,7 @@ func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, coset
if bitReverse {
fft.BitReverse(res.Regular)
}
domain.FFTInverse(res.Regular, fft.DIT, opt)
domain.FFTInverse(res.Regular, fft.DIT, opts...)
}
return res
}

View File

@@ -2,6 +2,7 @@ package smartvectors
import (
"math/big"
"sync/atomic"
"github.com/consensys/linea-monorepo/prover/maths/common/poly"
"github.com/consensys/linea-monorepo/prover/maths/common/vector"
@@ -111,7 +112,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
polys = make([][]field.Element, len(vs))
results = make([]field.Element, len(vs))
computed = make([]bool, len(vs))
totalConstant = 0
totalConstant = uint64(0)
)
// smartvector to []fr.element
@@ -122,7 +123,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
// constant vectors
results[i] = con.val
computed[i] = true
totalConstant++
atomic.AddUint64(&totalConstant, 1)
continue
}
@@ -131,7 +132,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel
}
})
if totalConstant == len(vs) {
if int(totalConstant) == len(vs) {
return results
}

View File

@@ -2,6 +2,7 @@ package smartvectors
import (
"fmt"
"github.com/consensys/linea-monorepo/prover/maths/field/fext"
"github.com/consensys/linea-monorepo/prover/maths/common/mempool"
@@ -192,7 +193,9 @@ func AllocFromPool(pool mempool.MemPool) *Pooled {
func (p *Pooled) Free(pool mempool.MemPool) {
if p.poolPtr != nil {
pool.Free(p.poolPtr)
if err := pool.Free(p.poolPtr); err != nil {
utils.Panic("failed to free slice in pool: %v", err)
}
}
p.poolPtr = nil
p.Regular = nil

View File

@@ -2,6 +2,7 @@ package smartvectors
import (
"fmt"
"github.com/consensys/linea-monorepo/prover/maths/field/fext"
"github.com/consensys/linea-monorepo/prover/maths/common/vector"

View File

@@ -1,9 +1,10 @@
package fft
import (
"math/bits"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/linea-monorepo/prover/utils/parallel"
"math/bits"
"github.com/consensys/linea-monorepo/prover/maths/field"
)

View File

@@ -1,14 +1,15 @@
package globalcs
import (
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/variables"
"math/big"
"reflect"
"runtime"
"sync"
"time"
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/variables"
"github.com/sirupsen/logrus"
"github.com/consensys/linea-monorepo/prover/maths/common/mempool"
@@ -196,12 +197,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
logrus.Infof("run the prover for the global constraint (quotient computation)")
// threadSafeVector is a wrapper around the smart vector that makes it lockable
type threadSafeVector struct {
*sync.Mutex
sv.SmartVector
}
var (
// Tracks the time spent on garbage collection
totalTimeGc = int64(0)
// Initial step is to compute the FFTs for all committed vectors
coeffs = sync.Map{} // (ifaces.ColID <=> sv.SmartVector)
coeffs = sync.Map{} // (ifaces.ColID <=> threadSafeVector{sv.SmartVector})
stopTimer = profiling.LogTimer("Computing the coeffs %v pols of size %v", len(ctx.AllInvolvedColumns), ctx.DomainSize)
pool = mempool.CreateFromSyncPool(symbolic.MaxChunkSize).Prewarm(runtime.GOMAXPROCS(0) * ctx.MaxNbExprNode)
largePool = mempool.CreateFromSyncPool(ctx.DomainSize).Prewarm(len(ctx.AllInvolvedColumns))
@@ -235,9 +242,9 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
witness = pol.GetColAssignment(run)
}
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil, fft.WithNbTasks(2))
coeffs.Store(name, witness)
coeffs.Store(name, threadSafeVector{&sync.Mutex{}, witness})
})
wg.Done()
@@ -258,11 +265,11 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
// normal case for interleaved or repeated columns
witness := pol.GetColAssignment(run)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil)
witness = sv.FFTInverse(witness, fft.DIF, false, 0, 0, nil, fft.WithNbTasks(2))
name := pol.GetColID()
coeffs.Store(name, witness)
coeffs.Store(name, threadSafeVector{&sync.Mutex{}, witness})
})
wg.Done()
}()
@@ -359,18 +366,31 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
root := roots[k]
name := root.GetColID()
_, found := computedReeval.Load(name)
if found {
// it was already computed in a previous iteration of `j`
// check if it was already computed
if _, ok := computedReeval.Load(name); ok {
return
}
// else it's the first value of j that sees it. so we compute the
// coset reevaluation.
v, _ := coeffs.Load(name)
reevaledRoot := sv.FFT(v.(sv.SmartVector), fft.DIT, false, ratio, share, localPool)
v, ok := coeffs.Load(name)
if !ok {
utils.Panic("handle %v not found in the coeffs (a)\n", name)
}
// lock the vector to ensure we don't do twice the same compute
// and that we don't (over)write an entry in the map with a (leaking) pool vector
v.(threadSafeVector).Lock()
defer v.(threadSafeVector).Unlock()
// check again if it was already computed
// (can happen if 2 go routines hit the lock at the same time)
if _, ok := computedReeval.Load(name); ok {
return
}
reevaledRoot := sv.FFT(v.(threadSafeVector).SmartVector, fft.DIT, false, ratio, share, localPool, fft.WithNbTasks(2))
computedReeval.Store(name, reevaledRoot)
})
@@ -382,6 +402,7 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
// short-path, the column is a purely Shifted(Natural) or a Natural
// (this excludes repeats and/or interleaved columns)
rootCols := column.RootParents(pol)
if len(rootCols) == 1 && rootCols[0].Size() == pol.Size() {
root := rootCols[0]
@@ -410,8 +431,9 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
}
name := pol.GetColID()
_, ok := computedReeval.Load(name)
if ok {
if _, ok := computedReeval.Load(name); ok {
// already computed
return
}
@@ -420,7 +442,18 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
utils.Panic("handle %v not found in the coeffs\n", name)
}
res := sv.FFT(v.(sv.SmartVector), fft.DIT, false, ratio, share, localPool)
// lock the vector to ensure we don't do twice the same compute
// and that we don't (over)write an entry in the map with a (leaking) pool vector
v.(threadSafeVector).Lock()
defer v.(threadSafeVector).Unlock()
// check again if it was already computed
// (can happen if 2 go routines hit the lock at the same time)
if _, ok := computedReeval.Load(name); ok {
return
}
res := sv.FFT(v.(threadSafeVector).SmartVector, fft.DIT, false, ratio, share, localPool, fft.WithNbTasks(2))
computedReeval.Store(name, res)
})
@@ -445,7 +478,10 @@ func (ctx *quotientCtx) Run(run *wizard.ProverRuntime) {
case ifaces.Column:
//name := metadata.GetColID()
//evalInputs[k] = computedReeval[name]
value, _ := computedReeval.Load(metadata.GetColID())
value, ok := computedReeval.Load(metadata.GetColID())
if !ok {
utils.Panic("did not find the reevaluation of %v", metadata.GetColID())
}
evalInputs[k] = value.(sv.SmartVector)
case coin.Info:
evalInputs[k] = sv.NewConstant(run.GetRandomCoinField(metadata.Name), ctx.DomainSize)