Perf(prover): uses padded-circular windows for the MiMC assignment (#770)

* use padded circular window to reduce runtime memory

* (feat): use padded circular window for mimc col assignment
This commit is contained in:
Lakshminarayanan Nandakumar
2025-03-13 13:39:48 +01:00
committed by GitHub
parent f0649e7fb8
commit 9988ebbee0
2 changed files with 154 additions and 69 deletions

View File

@@ -19,6 +19,10 @@ type Witness struct {
}
func Prove(cfg *config.Config, req *Request, large bool) (*Response, error) {
// Set MonitorParams before any proving happens
profiling.SetMonitorParams(cfg)
traces := &cfg.TracesLimits
if large {
traces = &cfg.TracesLimitsLarge

View File

@@ -4,89 +4,170 @@ import (
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/utils/parallel"
)
// assign assigns the columns to the prover runtime
// assign: Assigns the columns to the prover runtime using PaddedCircularWindow
func (ctx *mimcCtx) assign(run *wizard.ProverRuntime) {
var (
oldState = ctx.oldStates.GetColAssignment(run).IntoRegVecSaveAlloc()
blocks = ctx.blocks.GetColAssignment(run).IntoRegVecSaveAlloc()
// Initialize slices to hold intermediate results and intermediatePow4
// The first entry is left empty for consistency with ctx.intermediateResult
// We don't need to assign it because it is assigned already.
intermediateRes = make([][]field.Element, len(ctx.intermediateResult))
intermediatePow4 = make([][]field.Element, len(ctx.intermediateResult))
oldStateSV = ctx.oldStates.GetColAssignment(run)
blocksSV = ctx.blocks.GetColAssignment(run)
totalRows = oldStateSV.Len()
numRounds = len(ctx.intermediateResult)
)
// Initialize intermediateRes and intermediatePow4 with correct lengths
for i := range intermediateRes {
// For each intermediate result, create a slice of field.Elements with length numRows
intermediateRes[i] = make([]field.Element, len(oldState))
intermediatePow4[i] = make([]field.Element, len(oldState))
offset, windowLen := identifyActiveWindow(oldStateSV, blocksSV, totalRows)
oldStateWindow, blocksWindow := extractWindowSlices(oldStateSV, blocksSV, offset, windowLen)
intermediateResWindow, intermediatePow4Window := computeIntermediateValues(numRounds, oldStateWindow, blocksWindow, windowLen)
var resPad, pow4Pad []field.Element
// Precompute padding only when `PaddedCircularWindow` is tobe used
// i.e. Whenever there is sparsity in the (oldState, blocks) pair
if windowLen != totalRows {
resPad, pow4Pad = precomputePaddingValues(numRounds)
}
// Set the initial intermediate res as the block itself
intermediateRes[0] = blocks
// Compute intermediate values for each round
for i := range ctx.intermediateResult {
computeIntermediateValues(i, oldState, intermediateRes, intermediatePow4)
}
// Assign columns
for i := range ctx.intermediateResult {
// Assign computed values to the runtime
if i > 0 {
// Skip the first intermediate result
// Recall that the first intermediate res is the block itself
run.AssignColumn(
ctx.intermediateResult[i].GetColID(),
smartvectors.NewRegular(intermediateRes[i]),
)
}
// Assign intermediatePow4 to the runtime
run.AssignColumn(
ctx.intermediatePow4[i].GetColID(),
smartvectors.NewRegular(intermediatePow4[i]),
)
}
assignOptimizedVectors(run, ctx, intermediateResWindow, intermediatePow4Window, resPad, pow4Pad, offset, totalRows)
}
// computeIntermediateValues computes intermediate values for the given round
func computeIntermediateValues(round int, oldState []field.Element, intermediateRes, intermediatePow4 [][]field.Element) {
parallel.Execute(len(oldState), func(start, stop int) {
for k := start; k < stop; k++ {
if round == 0 {
// For the first round, compute initial intermediatePow4
tmp := intermediateRes[0][k]
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldState[k])
intermediatePow4[0][k].Square(&tmp).Square(&intermediatePow4[0][k])
// identifyActiveWindow finds the smallest active window scanning through the oldState and blocks
func identifyActiveWindow(oldStateSV, blocksSV smartvectors.SmartVector, totalRows int) (offset int, windowLen int) {
// Convert to regular vectors to scan all elements
var (
oldState = smartvectors.IntoRegVec(oldStateSV)
blocks = smartvectors.IntoRegVec(blocksSV)
)
// Initialize firstNonZero and lastNonZero indices to default values
firstNonZero, lastNonZero := totalRows, -1
for i := 0; i < totalRows; i++ {
if !oldState[i].IsZero() || !blocks[i].IsZero() {
firstNonZero = min(firstNonZero, i)
lastNonZero = max(lastNonZero, i)
}
}
if firstNonZero <= lastNonZero {
offset = firstNonZero
windowLen = lastNonZero - firstNonZero + 1
return offset, windowLen
}
// Default window => Full window
return 0, totalRows
}
// computeIntermediateValues computes intermediate values for the window
func computeIntermediateValues(numRounds int, oldStateWindow, blocksWindow []field.Element, windowLen int) ([][]field.Element, [][]field.Element) {
intermediateResWindow := make([][]field.Element, numRounds)
intermediatePow4Window := make([][]field.Element, numRounds)
for i := range intermediateResWindow {
intermediateResWindow[i] = make([]field.Element, windowLen)
intermediatePow4Window[i] = make([]field.Element, windowLen)
}
// Initalize intermediateResWindow to the blocksWindow
copy(intermediateResWindow[0], blocksWindow)
// r => round
for r := 0; r < numRounds; r++ {
parallel.Execute(windowLen, func(start, stop int) {
for k := start; k < stop; k++ {
if r == 0 {
tmp := intermediateResWindow[0][k]
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldStateWindow[k])
intermediatePow4Window[0][k].Square(&tmp).Square(&intermediatePow4Window[0][k])
} else {
// For subsequent rounds, compute intermediate values based on previous results
ark := mimc.Constants[r-1]
nextArk := mimc.Constants[r]
tmp := intermediatePow4Window[r-1][k]
tmp.Square(&tmp).Square(&tmp)
// Compute intermediate result using previous result and oldState
intermediateResWindow[r][k] = intermediateResWindow[r-1][k]
intermediateResWindow[r][k].Add(&intermediateResWindow[r][k], &ark).Add(&intermediateResWindow[r][k], &oldStateWindow[k])
intermediateResWindow[r][k].Mul(&intermediateResWindow[r][k], &tmp)
// Compute intermediatePow4
tmp = intermediateResWindow[r][k]
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldStateWindow[k])
intermediatePow4Window[r][k].Square(&tmp).Square(&intermediatePow4Window[r][k])
}
}
})
}
return intermediateResWindow, intermediatePow4Window
}
// assignOptimizedVectors assigns optimized vectors to the prover runtime
func assignOptimizedVectors(run *wizard.ProverRuntime, ctx *mimcCtx, intermediateResWindow, intermediatePow4Window [][]field.Element, resPad, pow4Pad []field.Element, offset, totalRows int) {
for round := range ctx.intermediateResult {
windowLen := len(intermediateResWindow[round])
// Full-length window: use Regular vector
isRegSmartVec := windowLen == totalRows
// Helper function to assign a column with the appropriate smart vector
assignColumn := func(colID ifaces.ColID, window []field.Element, padVal field.Element) {
if isRegSmartVec {
fullVec := make([]field.Element, totalRows)
copy(fullVec[offset:offset+windowLen], window)
run.AssignColumn(colID, smartvectors.NewRegular(fullVec))
} else {
// For subsequent rounds, compute intermediate values based on previous results
ark := mimc.Constants[round-1]
nextArk := mimc.Constants[round]
tmp := intermediatePow4[round-1][k]
tmp.Square(&tmp).Square(&tmp)
// Compute intermediate result using previous result and oldState
intermediateRes[round][k] = intermediateRes[round-1][k]
intermediateRes[round][k].Add(&intermediateRes[round][k], &ark).Add(&intermediateRes[round][k], &oldState[k])
intermediateRes[round][k].Mul(&intermediateRes[round][k], &tmp)
// Compute intermediatePow4
tmp = intermediateRes[round][k]
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldState[k])
tmp.Square(&tmp).Square(&tmp)
intermediatePow4[round][k] = tmp
// Partial window: use PaddedCircularWindow with lazily evaluated padding
run.AssignColumn(colID, smartvectors.NewPaddedCircularWindow(window, padVal, offset, totalRows))
}
}
})
// Determine padding values
var resPadVal, pow4PadVal field.Element
if resPad != nil && len(resPad) > round {
resPadVal = resPad[round]
}
if pow4Pad != nil && len(pow4Pad) > round {
pow4PadVal = pow4Pad[round]
}
// Assign intermediateResult (skip round=0 as it is initialized to the blocks)
if round > 0 {
assignColumn(ctx.intermediateResult[round].GetColID(), intermediateResWindow[round], resPadVal)
}
// Assign intermediatePow4
assignColumn(ctx.intermediatePow4[round].GetColID(), intermediatePow4Window[round], pow4PadVal)
}
}
// precomputePaddingValues precomputes padding values for constant regions
func precomputePaddingValues(numRounds int) ([]field.Element, []field.Element) {
resPad := make([]field.Element, numRounds)
pow4Pad := make([]field.Element, numRounds)
resPad[0].SetZero()
var tmp field.Element
tmp.Add(&resPad[0], &mimc.Constants[0])
pow4Pad[0].Square(&tmp).Square(&pow4Pad[0])
for r := 1; r < numRounds; r++ {
tmp.Square(&pow4Pad[r-1]).Square(&tmp)
resPad[r].Add(&resPad[r-1], &mimc.Constants[r-1])
resPad[r].Mul(&resPad[r], &tmp)
tmp.Add(&resPad[r], &mimc.Constants[r])
pow4Pad[r].Square(&tmp).Square(&pow4Pad[r])
}
return resPad, pow4Pad
}
// extractWindowSlices extracts window slices from the smart vectors
func extractWindowSlices(oldStateSV, blocksSV smartvectors.SmartVector, l, h int) ([]field.Element, []field.Element) {
var (
oldStateWindow = smartvectors.IntoRegVec(oldStateSV)[l : l+h]
blocksWindow = smartvectors.IntoRegVec(blocksSV)[l : l+h]
)
return oldStateWindow, blocksWindow
}