mirror of
https://github.com/vacp2p/linea-monorepo.git
synced 2026-01-09 04:08:01 -05:00
Perf(prover): uses padded-circular windows for the MiMC assignment (#770)
* use padded circular window to reduce runtime memory * (feat): use padded circular window for mimc col assignment
This commit is contained in:
committed by
GitHub
parent
f0649e7fb8
commit
9988ebbee0
@@ -19,6 +19,10 @@ type Witness struct {
|
||||
}
|
||||
|
||||
func Prove(cfg *config.Config, req *Request, large bool) (*Response, error) {
|
||||
|
||||
// Set MonitorParams before any proving happens
|
||||
profiling.SetMonitorParams(cfg)
|
||||
|
||||
traces := &cfg.TracesLimits
|
||||
if large {
|
||||
traces = &cfg.TracesLimitsLarge
|
||||
|
||||
@@ -4,89 +4,170 @@ import (
|
||||
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
|
||||
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
|
||||
"github.com/consensys/linea-monorepo/prover/maths/field"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
|
||||
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
|
||||
"github.com/consensys/linea-monorepo/prover/utils/parallel"
|
||||
)
|
||||
|
||||
// assign assigns the columns to the prover runtime
|
||||
// assign: Assigns the columns to the prover runtime using PaddedCircularWindow
|
||||
func (ctx *mimcCtx) assign(run *wizard.ProverRuntime) {
|
||||
|
||||
var (
|
||||
oldState = ctx.oldStates.GetColAssignment(run).IntoRegVecSaveAlloc()
|
||||
blocks = ctx.blocks.GetColAssignment(run).IntoRegVecSaveAlloc()
|
||||
|
||||
// Initialize slices to hold intermediate results and intermediatePow4
|
||||
// The first entry is left empty for consistency with ctx.intermediateResult
|
||||
// We don't need to assign it because it is assigned already.
|
||||
intermediateRes = make([][]field.Element, len(ctx.intermediateResult))
|
||||
intermediatePow4 = make([][]field.Element, len(ctx.intermediateResult))
|
||||
oldStateSV = ctx.oldStates.GetColAssignment(run)
|
||||
blocksSV = ctx.blocks.GetColAssignment(run)
|
||||
totalRows = oldStateSV.Len()
|
||||
numRounds = len(ctx.intermediateResult)
|
||||
)
|
||||
|
||||
// Initialize intermediateRes and intermediatePow4 with correct lengths
|
||||
for i := range intermediateRes {
|
||||
// For each intermediate result, create a slice of field.Elements with length numRows
|
||||
intermediateRes[i] = make([]field.Element, len(oldState))
|
||||
intermediatePow4[i] = make([]field.Element, len(oldState))
|
||||
offset, windowLen := identifyActiveWindow(oldStateSV, blocksSV, totalRows)
|
||||
oldStateWindow, blocksWindow := extractWindowSlices(oldStateSV, blocksSV, offset, windowLen)
|
||||
intermediateResWindow, intermediatePow4Window := computeIntermediateValues(numRounds, oldStateWindow, blocksWindow, windowLen)
|
||||
|
||||
var resPad, pow4Pad []field.Element
|
||||
|
||||
// Precompute padding only when `PaddedCircularWindow` is tobe used
|
||||
// i.e. Whenever there is sparsity in the (oldState, blocks) pair
|
||||
if windowLen != totalRows {
|
||||
resPad, pow4Pad = precomputePaddingValues(numRounds)
|
||||
}
|
||||
|
||||
// Set the initial intermediate res as the block itself
|
||||
intermediateRes[0] = blocks
|
||||
|
||||
// Compute intermediate values for each round
|
||||
for i := range ctx.intermediateResult {
|
||||
computeIntermediateValues(i, oldState, intermediateRes, intermediatePow4)
|
||||
}
|
||||
|
||||
// Assign columns
|
||||
for i := range ctx.intermediateResult {
|
||||
// Assign computed values to the runtime
|
||||
if i > 0 {
|
||||
// Skip the first intermediate result
|
||||
// Recall that the first intermediate res is the block itself
|
||||
run.AssignColumn(
|
||||
ctx.intermediateResult[i].GetColID(),
|
||||
smartvectors.NewRegular(intermediateRes[i]),
|
||||
)
|
||||
}
|
||||
|
||||
// Assign intermediatePow4 to the runtime
|
||||
run.AssignColumn(
|
||||
ctx.intermediatePow4[i].GetColID(),
|
||||
smartvectors.NewRegular(intermediatePow4[i]),
|
||||
)
|
||||
}
|
||||
|
||||
assignOptimizedVectors(run, ctx, intermediateResWindow, intermediatePow4Window, resPad, pow4Pad, offset, totalRows)
|
||||
}
|
||||
|
||||
// computeIntermediateValues computes intermediate values for the given round
|
||||
func computeIntermediateValues(round int, oldState []field.Element, intermediateRes, intermediatePow4 [][]field.Element) {
|
||||
parallel.Execute(len(oldState), func(start, stop int) {
|
||||
for k := start; k < stop; k++ {
|
||||
if round == 0 {
|
||||
// For the first round, compute initial intermediatePow4
|
||||
tmp := intermediateRes[0][k]
|
||||
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldState[k])
|
||||
intermediatePow4[0][k].Square(&tmp).Square(&intermediatePow4[0][k])
|
||||
// identifyActiveWindow finds the smallest active window scanning through the oldState and blocks
|
||||
func identifyActiveWindow(oldStateSV, blocksSV smartvectors.SmartVector, totalRows int) (offset int, windowLen int) {
|
||||
// Convert to regular vectors to scan all elements
|
||||
var (
|
||||
oldState = smartvectors.IntoRegVec(oldStateSV)
|
||||
blocks = smartvectors.IntoRegVec(blocksSV)
|
||||
)
|
||||
|
||||
// Initialize firstNonZero and lastNonZero indices to default values
|
||||
firstNonZero, lastNonZero := totalRows, -1
|
||||
for i := 0; i < totalRows; i++ {
|
||||
if !oldState[i].IsZero() || !blocks[i].IsZero() {
|
||||
firstNonZero = min(firstNonZero, i)
|
||||
lastNonZero = max(lastNonZero, i)
|
||||
}
|
||||
}
|
||||
|
||||
if firstNonZero <= lastNonZero {
|
||||
offset = firstNonZero
|
||||
windowLen = lastNonZero - firstNonZero + 1
|
||||
return offset, windowLen
|
||||
}
|
||||
// Default window => Full window
|
||||
return 0, totalRows
|
||||
}
|
||||
|
||||
// computeIntermediateValues computes intermediate values for the window
|
||||
func computeIntermediateValues(numRounds int, oldStateWindow, blocksWindow []field.Element, windowLen int) ([][]field.Element, [][]field.Element) {
|
||||
intermediateResWindow := make([][]field.Element, numRounds)
|
||||
intermediatePow4Window := make([][]field.Element, numRounds)
|
||||
for i := range intermediateResWindow {
|
||||
intermediateResWindow[i] = make([]field.Element, windowLen)
|
||||
intermediatePow4Window[i] = make([]field.Element, windowLen)
|
||||
}
|
||||
|
||||
// Initalize intermediateResWindow to the blocksWindow
|
||||
copy(intermediateResWindow[0], blocksWindow)
|
||||
|
||||
// r => round
|
||||
for r := 0; r < numRounds; r++ {
|
||||
parallel.Execute(windowLen, func(start, stop int) {
|
||||
for k := start; k < stop; k++ {
|
||||
if r == 0 {
|
||||
tmp := intermediateResWindow[0][k]
|
||||
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldStateWindow[k])
|
||||
intermediatePow4Window[0][k].Square(&tmp).Square(&intermediatePow4Window[0][k])
|
||||
} else {
|
||||
// For subsequent rounds, compute intermediate values based on previous results
|
||||
ark := mimc.Constants[r-1]
|
||||
nextArk := mimc.Constants[r]
|
||||
|
||||
tmp := intermediatePow4Window[r-1][k]
|
||||
tmp.Square(&tmp).Square(&tmp)
|
||||
|
||||
// Compute intermediate result using previous result and oldState
|
||||
intermediateResWindow[r][k] = intermediateResWindow[r-1][k]
|
||||
intermediateResWindow[r][k].Add(&intermediateResWindow[r][k], &ark).Add(&intermediateResWindow[r][k], &oldStateWindow[k])
|
||||
intermediateResWindow[r][k].Mul(&intermediateResWindow[r][k], &tmp)
|
||||
|
||||
// Compute intermediatePow4
|
||||
tmp = intermediateResWindow[r][k]
|
||||
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldStateWindow[k])
|
||||
intermediatePow4Window[r][k].Square(&tmp).Square(&intermediatePow4Window[r][k])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
return intermediateResWindow, intermediatePow4Window
|
||||
}
|
||||
|
||||
// assignOptimizedVectors assigns optimized vectors to the prover runtime
|
||||
func assignOptimizedVectors(run *wizard.ProverRuntime, ctx *mimcCtx, intermediateResWindow, intermediatePow4Window [][]field.Element, resPad, pow4Pad []field.Element, offset, totalRows int) {
|
||||
for round := range ctx.intermediateResult {
|
||||
windowLen := len(intermediateResWindow[round])
|
||||
|
||||
// Full-length window: use Regular vector
|
||||
isRegSmartVec := windowLen == totalRows
|
||||
|
||||
// Helper function to assign a column with the appropriate smart vector
|
||||
assignColumn := func(colID ifaces.ColID, window []field.Element, padVal field.Element) {
|
||||
if isRegSmartVec {
|
||||
fullVec := make([]field.Element, totalRows)
|
||||
copy(fullVec[offset:offset+windowLen], window)
|
||||
run.AssignColumn(colID, smartvectors.NewRegular(fullVec))
|
||||
} else {
|
||||
// For subsequent rounds, compute intermediate values based on previous results
|
||||
ark := mimc.Constants[round-1]
|
||||
nextArk := mimc.Constants[round]
|
||||
|
||||
tmp := intermediatePow4[round-1][k]
|
||||
tmp.Square(&tmp).Square(&tmp)
|
||||
|
||||
// Compute intermediate result using previous result and oldState
|
||||
intermediateRes[round][k] = intermediateRes[round-1][k]
|
||||
intermediateRes[round][k].Add(&intermediateRes[round][k], &ark).Add(&intermediateRes[round][k], &oldState[k])
|
||||
intermediateRes[round][k].Mul(&intermediateRes[round][k], &tmp)
|
||||
|
||||
// Compute intermediatePow4
|
||||
tmp = intermediateRes[round][k]
|
||||
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldState[k])
|
||||
tmp.Square(&tmp).Square(&tmp)
|
||||
intermediatePow4[round][k] = tmp
|
||||
// Partial window: use PaddedCircularWindow with lazily evaluated padding
|
||||
run.AssignColumn(colID, smartvectors.NewPaddedCircularWindow(window, padVal, offset, totalRows))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Determine padding values
|
||||
var resPadVal, pow4PadVal field.Element
|
||||
if resPad != nil && len(resPad) > round {
|
||||
resPadVal = resPad[round]
|
||||
}
|
||||
if pow4Pad != nil && len(pow4Pad) > round {
|
||||
pow4PadVal = pow4Pad[round]
|
||||
}
|
||||
|
||||
// Assign intermediateResult (skip round=0 as it is initialized to the blocks)
|
||||
if round > 0 {
|
||||
assignColumn(ctx.intermediateResult[round].GetColID(), intermediateResWindow[round], resPadVal)
|
||||
}
|
||||
|
||||
// Assign intermediatePow4
|
||||
assignColumn(ctx.intermediatePow4[round].GetColID(), intermediatePow4Window[round], pow4PadVal)
|
||||
}
|
||||
}
|
||||
|
||||
// precomputePaddingValues precomputes padding values for constant regions
|
||||
func precomputePaddingValues(numRounds int) ([]field.Element, []field.Element) {
|
||||
resPad := make([]field.Element, numRounds)
|
||||
pow4Pad := make([]field.Element, numRounds)
|
||||
resPad[0].SetZero()
|
||||
|
||||
var tmp field.Element
|
||||
tmp.Add(&resPad[0], &mimc.Constants[0])
|
||||
pow4Pad[0].Square(&tmp).Square(&pow4Pad[0])
|
||||
|
||||
for r := 1; r < numRounds; r++ {
|
||||
tmp.Square(&pow4Pad[r-1]).Square(&tmp)
|
||||
resPad[r].Add(&resPad[r-1], &mimc.Constants[r-1])
|
||||
resPad[r].Mul(&resPad[r], &tmp)
|
||||
tmp.Add(&resPad[r], &mimc.Constants[r])
|
||||
pow4Pad[r].Square(&tmp).Square(&pow4Pad[r])
|
||||
}
|
||||
|
||||
return resPad, pow4Pad
|
||||
}
|
||||
|
||||
// extractWindowSlices extracts window slices from the smart vectors
|
||||
func extractWindowSlices(oldStateSV, blocksSV smartvectors.SmartVector, l, h int) ([]field.Element, []field.Element) {
|
||||
var (
|
||||
oldStateWindow = smartvectors.IntoRegVec(oldStateSV)[l : l+h]
|
||||
blocksWindow = smartvectors.IntoRegVec(blocksSV)[l : l+h]
|
||||
)
|
||||
return oldStateWindow, blocksWindow
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user