Prover: Zeroize empty Plonk witnesses for Plonk in Wizard (#3766)

* add supports for cancellable plonk
* optimize the range-checker assignment
* prover: add verifier checks for the activators
* fixup: set the check to the right column
* remove trailing printf's
This commit is contained in:
AlexandreBelling
2024-08-07 11:47:07 +02:00
committed by GitHub
parent f1fc8bcb7d
commit a0c3d3d52c
12 changed files with 203 additions and 32 deletions

View File

@@ -35,7 +35,7 @@ func NewFromPublicColumn(col ifaces.Column, pos int) ifaces.Accessor {
}
if !nat.Status().IsPublic() {
panic("expected an coin.IntegerVec")
panic("expected a public column")
}
if nat.Size() <= pos {
utils.Panic("the column has size %v, but requested position %v", nat.Size(), pos)

View File

@@ -1,6 +1,7 @@
package plonk
import (
"fmt"
"sync"
"github.com/consensys/gnark-crypto/ecc"
@@ -11,6 +12,7 @@ import (
"github.com/consensys/zkevm-monorepo/prover/protocol/column"
"github.com/consensys/zkevm-monorepo/prover/protocol/dedicated/projection"
"github.com/consensys/zkevm-monorepo/prover/protocol/ifaces"
"github.com/consensys/zkevm-monorepo/prover/protocol/query"
"github.com/consensys/zkevm-monorepo/prover/protocol/wizard"
"github.com/consensys/zkevm-monorepo/prover/symbolic"
"github.com/consensys/zkevm-monorepo/prover/utils"
@@ -61,12 +63,18 @@ type CircuitAlignmentInput struct {
// input. If it is nil, then we use zero value.
InputFiller func(circuitInstance, inputIndex int) field.Element
witnesses []witness.Witness
witnessesOnce sync.Once
witnesses []witness.Witness
witnessesOnce sync.Once
numEffWitnesses int
// nbPublicInputs is the number of public inputs for the circuit. It is
// computed from the circuit and then stored here for later use.
nbPublicInputs int
// circMaskOpenings are local opening queries over the ToCircuitMask that
// we use to checks that the "activators" of the Plonk in Wizard are
// correctly set w.r.t. circMaskOpening
circMaskOpenings []query.LocalOpening
}
func (ci *CircuitAlignmentInput) NbInstances() int {
@@ -130,6 +138,9 @@ func (ci *CircuitAlignmentInput) prepareWitnesses(run *wizard.ProverRuntime) {
close(witnessFillers[(filled-1)/ci.nbPublicInputs])
}
}
ci.numEffWitnesses = (filled-1)/ci.nbPublicInputs + 1
for filled < ci.nbPublicInputs*ci.NbCircuitInstances {
select {
case <-ctx.Done():
@@ -158,6 +169,11 @@ func (ci *CircuitAlignmentInput) Assign(run *wizard.ProverRuntime, i int) (priva
return ci.witnesses[i], ci.witnesses[i], nil
}
func (ci *CircuitAlignmentInput) NumEffWitnesses(run *wizard.ProverRuntime) int {
ci.prepareWitnesses(run)
return ci.numEffWitnesses
}
// Alignment is the prepared structure where the Data field is aligned to gnark
// circuit PI column. It considers the cases where we call multiple instances of
// the circuit so that the inputs for every circuit is padded to power of two
@@ -235,6 +251,7 @@ func DefineAlignment(comp *wizard.CompiledIOP, toAlign *CircuitAlignmentInput) *
res.csIsActive(comp)
res.csProjection(comp)
res.csProjectionSelector(comp)
res.checkActivators(comp)
return res
}
@@ -259,6 +276,7 @@ func (a *Alignment) csProjectionSelector(comp *wizard.CompiledIOP) {
func (a *Alignment) Assign(run *wizard.ProverRuntime) {
a.plonkInWizardCtx.GetPlonkProverAction().Run(run, a.CircuitAlignmentInput)
a.assignMasks(run)
a.assignCircMaskOpenings(run)
}
func (a *Alignment) assignMasks(run *wizard.ProverRuntime) {
@@ -300,6 +318,14 @@ func (a *Alignment) assignMasks(run *wizard.ProverRuntime) {
run.AssignColumn(a.ActualCircuitInputMask.GetColID(), smartvectors.NewRegular(actualCircMaskAssignment))
}
// assignCircMaskOpenings assigns the openings queries over [actualCircMaskAssignment]
func (a *Alignment) assignCircMaskOpenings(run *wizard.ProverRuntime) {
for i := range a.circMaskOpenings {
v := a.circMaskOpenings[i].Pol.GetColAssignmentAt(run, 0)
run.AssignLocalPoint(a.circMaskOpenings[i].ID, v)
}
}
// getCircuitMaskValue returns the
func getCircuitMaskValue(nbPublicInputPerCircuit, nbCircuitInstance int) smartvectors.SmartVector {
@@ -316,3 +342,55 @@ func getCircuitMaskValue(nbPublicInputPerCircuit, nbCircuitInstance int) smartve
return smartvectors.NewRegular(maskValue)
}
// check the activators are well-set w.r.t to the circuit mask column
func (ci *Alignment) checkActivators(comp *wizard.CompiledIOP) {
var (
openings = make([]query.LocalOpening, ci.NbCircuitInstances)
mask = ci.ActualCircuitInputMask
offset = utils.NextPowerOfTwo(ci.nbPublicInputs)
activators = ci.plonkInWizardCtx.Columns.Activators
round = activators[0].Round()
)
for i := range openings {
openings[i] = comp.InsertLocalOpening(
round,
ifaces.QueryIDf("%v_ACTIVATOR_LOCAL_OP_%v", ci.Name, i),
column.Shift(mask, i*offset),
)
}
ci.circMaskOpenings = openings
comp.RegisterVerifierAction(ci.Round, checkActivatorAndMask(*ci))
}
type checkActivatorAndMask Alignment
func (c checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error {
for i := range c.circMaskOpenings {
var (
valOpened = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID).Y
valActiv = c.plonkInWizardCtx.Columns.Activators[i].GetColAssignment(run).Get(0)
)
if valOpened != valActiv {
return fmt.Errorf("activator does not match the circMask %v", i)
}
}
return nil
}
func (c checkActivatorAndMask) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) {
for i := range c.circMaskOpenings {
var (
valOpened = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID).Y
valActiv = c.plonkInWizardCtx.Columns.Activators[i].GetColAssignmentGnarkAt(run, 0)
)
api.AssertIsEqual(valOpened, valActiv)
}
}

View File

@@ -45,6 +45,7 @@ func TestAlignment(t *testing.T) {
ct.Assign(run, "DATA", "DATA_MASK")
alignment.Assign(run)
})
ct.CheckAssignmentColumn(runLeaked, "IS_ACTIVE", alignment.IsActive)
ct.CheckAssignmentColumn(runLeaked, "CIRCUIT_INPUT", alignment.CircuitInput)
ct.CheckAssignmentColumn(runLeaked, "FULL_CIRCUIT_INPUT_MASK", alignment.FullCircuitInputMask)

View File

@@ -1,6 +1,7 @@
package plonk
import (
"fmt"
"sync"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop"
@@ -9,6 +10,7 @@ import (
"github.com/consensys/zkevm-monorepo/prover/maths/field"
"github.com/consensys/zkevm-monorepo/prover/protocol/accessors"
"github.com/consensys/zkevm-monorepo/prover/protocol/coin"
"github.com/consensys/zkevm-monorepo/prover/protocol/column"
"github.com/consensys/zkevm-monorepo/prover/protocol/column/verifiercol"
"github.com/consensys/zkevm-monorepo/prover/protocol/dedicated/expr_handle"
"github.com/consensys/zkevm-monorepo/prover/protocol/ifaces"
@@ -63,6 +65,8 @@ func PlonkCheck(
comp.RegisterProverAction(round+1, lroCommitProverAction{compilationCtx: ctx, proverStateLock: &sync.Mutex{}})
}
comp.RegisterVerifierAction(round, checkingActivators(ctx.Columns.Activators))
return ctx
}
@@ -86,6 +90,7 @@ func (ctx *compilationCtx) commitGateColumns() {
ctx.Columns.L = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.R = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.O = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.Activators = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.PI = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.TinyPI = make([]ifaces.Column, ctx.maxNbInstances)
ctx.Columns.Cp = make([]ifaces.Column, ctx.maxNbInstances)
@@ -103,6 +108,7 @@ func (ctx *compilationCtx) commitGateColumns() {
ctx.Columns.PI[i] = verifiercol.NewConstantCol(field.Zero(), ctx.DomainSize())
}
ctx.Columns.Cp[i] = ctx.comp.InsertCommit(ctx.round, ctx.colIDf("Cp_%v", i), ctx.DomainSize())
ctx.Columns.Activators[i] = ctx.comp.InsertProof(ctx.round, ctx.colIDf("ACTIVATOR_%v", i), 1)
}
// Second rounds, after sampling HCP
@@ -126,6 +132,7 @@ func (ctx *compilationCtx) commitGateColumns() {
ctx.Columns.L[i] = ctx.comp.InsertCommit(ctx.round, ctx.colIDf("L_%v", i), ctx.DomainSize())
ctx.Columns.R[i] = ctx.comp.InsertCommit(ctx.round, ctx.colIDf("R_%v", i), ctx.DomainSize())
ctx.Columns.O[i] = ctx.comp.InsertCommit(ctx.round, ctx.colIDf("O_%v", i), ctx.DomainSize())
ctx.Columns.Activators[i] = ctx.comp.InsertColumn(ctx.round, ctx.colIDf("ACTIVATOR_%v", i), 1, column.Proof)
}
}
}
@@ -187,13 +194,19 @@ func (ctx *compilationCtx) addGateConstraint() {
for i := 0; i < ctx.maxNbInstances; i++ {
// Declare the expression
exp := sym.Add(
sym.Mul(ctx.Columns.L[i], ctx.Columns.Ql),
sym.Mul(ctx.Columns.R[i], ctx.Columns.Qr),
sym.Mul(ctx.Columns.O[i], ctx.Columns.Qo),
sym.Mul(ctx.Columns.L[i], ctx.Columns.R[i], ctx.Columns.Qm),
ctx.Columns.PI[i],
ctx.Columns.Qk,
exp := sym.Mul(
// The conversion into an activator is required for the system
// to understand that the expression is multiplied by a scalar
// and not by a wrongfully constructed column
accessors.NewFromPublicColumn(ctx.Columns.Activators[i], 0),
sym.Add(
sym.Mul(ctx.Columns.L[i], ctx.Columns.Ql),
sym.Mul(ctx.Columns.R[i], ctx.Columns.Qr),
sym.Mul(ctx.Columns.O[i], ctx.Columns.Qo),
sym.Mul(ctx.Columns.L[i], ctx.Columns.R[i], ctx.Columns.Qm),
ctx.Columns.PI[i],
ctx.Columns.Qk,
),
)
roundLRO := ctx.round
@@ -277,3 +290,41 @@ func (ctx *compilationCtx) addCopyConstraint() {
[]ifaces.Column{l, r, o},
)
}
// checkingActivators implements the [wizard.VerifierAction] interface and
// checks that the [Activators] columns are correctly assigned
type checkingActivators []ifaces.Column
var _ wizard.VerifierAction = checkingActivators{}
func (ca checkingActivators) Run(run *wizard.VerifierRuntime) error {
for i := range ca {
curr := ca[i].GetColAssignmentAt(run, 0)
if !curr.IsOne() && !curr.IsZero() {
return fmt.Errorf("error the activators must be 0 or 1")
}
if i+1 < len(ca) {
next := ca[i+1].GetColAssignmentAt(run, 0)
if curr.IsZero() && !next.IsZero() {
return fmt.Errorf("the activators must never go from 0 to 1")
}
}
}
return nil
}
func (ca checkingActivators) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) {
for i := range ca {
curr := ca[i].GetColAssignmentGnarkAt(run, 0)
api.AssertIsBoolean(curr)
if i+1 < len(ca) {
next := ca[i+1].GetColAssignmentGnarkAt(run, 0)
api.AssertIsEqual(next, api.Mul(curr, next))
}
}
}

View File

@@ -56,6 +56,12 @@ type compilationCtx struct {
Ql, Qr, Qm, Qo, Qk, Qcp ifaces.Column
// Witness columns
L, R, O, PI, TinyPI, Cp []ifaces.Column
// Activators are tiny verifier-visible columns that are used to
// deactivate the constraints happening for constraints that are not
// happening in the system. The verifier is required to check that the
// columns are assigned to binary values and that they are structured
// as a sequence of 1s followed by a sequence of 0s.
Activators []ifaces.Column
// Columns representing the permutation
S [3]ifaces.ColAssignment
// Commitment randomness

View File

@@ -49,11 +49,23 @@ type (
// Run implements the [wizard.ProverAction] interface.
func (pa initialBBSProverAction) Run(run *wizard.ProverRuntime, wa WitnessAssigner) {
ctx := compilationCtx(pa.compilationCtx)
var (
ctx = compilationCtx(pa.compilationCtx)
numEffInstances = wa.NumEffWitnesses(run)
)
// Store the information
parallel.Execute(pa.maxNbInstances, func(start, stop int) {
for i := start; i < stop; i++ {
if i >= numEffInstances {
run.AssignColumn(ctx.Columns.TinyPI[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.TinyPI[i].Size()))
run.AssignColumn(ctx.Columns.Cp[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.Cp[i].Size()))
run.AssignColumn(ctx.Columns.Activators[i].GetColID(), smartvectors.NewConstant(field.Zero(), 1))
continue
}
// Initialize the channels
solSync := solverSync{
comChan: make(chan []field.Element, 1),
@@ -96,6 +108,7 @@ func (pa initialBBSProverAction) Run(run *wizard.ProverRuntime, wa WitnessAssign
// And assign it in the runtime
run.AssignColumn(ctx.Columns.Cp[i].GetColID(), smartvectors.NewRegular(com))
run.AssignColumn(ctx.Columns.Activators[i].GetColID(), smartvectors.NewConstant(field.One(), 1))
}
})
}
@@ -107,14 +120,23 @@ func (pa lroCommitProverAction) Run(run *wizard.ProverRuntime) {
parallel.Execute(ctx.maxNbInstances, func(start, stop int) {
for i := start; i < stop; i++ {
// Retrive the solsync
// Retrieve the solsync. Not finding it means the instance is not
// used.
pa.proverStateLock.Lock()
solsync := run.State.MustGet(ctx.Sprintf("SOLSYNC_%v", i)).(solverSync)
solsync_, foundSolSync := run.State.TryGet(ctx.Sprintf("SOLSYNC_%v", i))
run.State.TryDel(ctx.Sprintf("SOLSYNC_%v", i))
pa.proverStateLock.Unlock()
if !foundSolSync {
zeroCol := smartvectors.NewConstant(field.Zero(), ctx.Columns.L[i].Size())
run.AssignColumn(ctx.Columns.L[i].GetColID(), zeroCol)
run.AssignColumn(ctx.Columns.R[i].GetColID(), zeroCol)
run.AssignColumn(ctx.Columns.O[i].GetColID(), zeroCol)
}
// Inject the coin which will be assigned to the randomness
solsync := solsync_.(solverSync)
solsync.randChan <- run.GetRandomCoinField(ctx.Columns.Hcp.Name)
close(solsync.randChan)

View File

@@ -37,18 +37,30 @@ var (
func (pa noCommitProverAction) Run(run *wizard.ProverRuntime, wa WitnessAssigner) {
var (
ctx = compilationCtx(pa)
maxNbInstance = pa.maxNbInstances
ctx = compilationCtx(pa)
maxNbInstance = pa.maxNbInstances
numEffInstances = wa.NumEffWitnesses(run)
)
parallel.Execute(maxNbInstance, func(start, stop int) {
for i := start; i < stop; i++ {
if i >= numEffInstances {
run.AssignColumn(ctx.Columns.TinyPI[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.TinyPI[i].Size()))
run.AssignColumn(ctx.Columns.L[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.L[0].Size()))
run.AssignColumn(ctx.Columns.R[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.R[0].Size()))
run.AssignColumn(ctx.Columns.O[i].GetColID(), smartvectors.NewConstant(field.Zero(), ctx.Columns.O[0].Size()))
run.AssignColumn(ctx.Columns.Activators[i].GetColID(), smartvectors.NewConstant(field.Zero(), 1))
continue
}
// create the witness assignment
witness, pubWitness, err := wa.Assign(run, i)
if err != nil {
utils.Panic("Could not create the witness: %v", err)
}
if ctx.TinyPISize() > 0 {
// Converts it as a smart-vector
pubWitSV := smartvectors.RightZeroPadded(
[]field.Element(pubWitness.Vector().(fr.Vector)),
@@ -70,6 +82,7 @@ func (pa noCommitProverAction) Run(run *wizard.ProverRuntime, wa WitnessAssigner
run.AssignColumn(ctx.Columns.L[i].GetColID(), smartvectors.NewRegular(solution.L))
run.AssignColumn(ctx.Columns.R[i].GetColID(), smartvectors.NewRegular(solution.R))
run.AssignColumn(ctx.Columns.O[i].GetColID(), smartvectors.NewRegular(solution.O))
run.AssignColumn(ctx.Columns.Activators[i].GetColID(), smartvectors.NewConstant(field.One(), 1))
}
if ctx.RangeCheck.Enabled && !ctx.RangeCheck.wasCancelled {

View File

@@ -222,16 +222,25 @@ func (ctx *compilationCtx) assignRangeChecked(run *wizard.ProverRuntime) {
)
parallel.Execute(len(ctx.Columns.RangeChecked), func(start, stop int) {
for i := range ctx.Columns.RangeChecked {
for i := start; i < stop; i++ {
var (
l = ctx.Columns.L[i].GetColAssignment(run)
r = ctx.Columns.R[i].GetColAssignment(run)
o = ctx.Columns.O[i].GetColAssignment(run)
rcSize = ctx.Columns.RangeChecked[i].Size()
rc = make([]field.Element, 0, rcSize)
activated = ctx.Columns.Activators[i].GetColAssignment(run).Get(0)
l = ctx.Columns.L[i].GetColAssignment(run)
r = ctx.Columns.R[i].GetColAssignment(run)
o = ctx.Columns.O[i].GetColAssignment(run)
rcSize = ctx.Columns.RangeChecked[i].Size()
rc = make([]field.Element, 0, rcSize)
)
if activated.IsZero() {
run.AssignColumn(
ctx.Columns.RangeChecked[i].GetColID(),
smartvectors.NewConstant(field.Zero(), rcSize),
)
continue
}
for i := range rcLValue {
if rcLValue[i].IsOne() {
rc = append(rc, l.Get(i))

View File

@@ -12,8 +12,7 @@ import (
// WitnessAssigner allows obtaining witness assignment for a circuit.
type WitnessAssigner interface {
// NBInstance returns the number of concretely provided instances.
NbInstances() int
NumEffWitnesses(run *wizard.ProverRuntime) int
Assign(run *wizard.ProverRuntime, i int) (private, public witness.Witness, err error)
}
@@ -22,7 +21,7 @@ type witnessFuncAssigner struct {
assigners []func() frontend.Circuit
}
func (w *witnessFuncAssigner) NbInstances() int {
func (w *witnessFuncAssigner) NumEffWitnesses(run *wizard.ProverRuntime) int {
return len(w.assigners)
}

View File

@@ -82,9 +82,6 @@ func SerializeValue(v reflect.Value, mode mode) (json.RawMessage, error) {
}
concrete := v.Elem()
if fmt.Sprintf("%++v", concrete) == "<invalid reflect.Value>" {
fmt.Printf("Parent(v) = %++v\n", v)
}
rawValue, err := SerializeValue(concrete, mode)
if err != nil {

View File

@@ -330,8 +330,6 @@ func isFactored(e *sym.Expression, exponentsOfGroup map[field.Element]int) (
return nil, false
}
fmt.Printf("returning the %++v with exponents %++v\n", e.Children, factoredExponents)
return sym.NewProduct(e.Children, factoredExponents), true
}

View File

@@ -1,7 +1,6 @@
package statesummary
import (
"fmt"
"io"
"sync"
@@ -143,8 +142,6 @@ func (ss *stateSummaryAssignmentBuilder) pushBlockTraces(batchNumber int, traces
// a block.
func (ss *stateSummaryAssignmentBuilder) pushAccountSegment(batchNumber int, segment accountSegmentWitness) {
fmt.Printf("ss accumulator statement = %v\n", ss.accumulatorStatement)
for segID, seg := range segment {
var (