mirror of
https://github.com/chancehudson/circom-stark.git
synced 2026-01-10 06:17:54 -05:00
refactor: witness calculator improvement
This commit is contained in:
18
package-lock.json
generated
18
package-lock.json
generated
@@ -35,7 +35,8 @@
|
||||
},
|
||||
"node_modules/@noble/hashes": {
|
||||
"version": "1.3.2",
|
||||
"license": "MIT",
|
||||
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz",
|
||||
"integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ==",
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
},
|
||||
@@ -1219,7 +1220,8 @@
|
||||
},
|
||||
"node_modules/randomf": {
|
||||
"version": "0.0.3",
|
||||
"license": "MIT"
|
||||
"resolved": "https://registry.npmjs.org/randomf/-/randomf-0.0.3.tgz",
|
||||
"integrity": "sha512-hcZAFZNPZ3YeNpYLkC3MYjVedW4Hb4I0OhQdEHRd/UdDM8roGRH1vCtt8/UTaIHy4VKPLh6kfzS9ljzlMD2Y2Q=="
|
||||
},
|
||||
"node_modules/readdirp": {
|
||||
"version": "3.6.0",
|
||||
@@ -1385,7 +1387,7 @@
|
||||
},
|
||||
"node_modules/starkstark": {
|
||||
"version": "0.0.0",
|
||||
"resolved": "git+ssh://git@github.com/vimwitch/starkstark.git#a761cb5c957f4810444fbcec61b3c2135b73f85b",
|
||||
"resolved": "git+ssh://git@github.com/vimwitch/starkstark.git#683bf06ef15882a67f86a26a33efedc5055df885",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@noble/hashes": "^1.3.1",
|
||||
@@ -1691,7 +1693,9 @@
|
||||
}
|
||||
},
|
||||
"@noble/hashes": {
|
||||
"version": "1.3.2"
|
||||
"version": "1.3.2",
|
||||
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz",
|
||||
"integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ=="
|
||||
},
|
||||
"@nodelib/fs.scandir": {
|
||||
"version": "2.1.5",
|
||||
@@ -2353,7 +2357,9 @@
|
||||
}
|
||||
},
|
||||
"randomf": {
|
||||
"version": "0.0.3"
|
||||
"version": "0.0.3",
|
||||
"resolved": "https://registry.npmjs.org/randomf/-/randomf-0.0.3.tgz",
|
||||
"integrity": "sha512-hcZAFZNPZ3YeNpYLkC3MYjVedW4Hb4I0OhQdEHRd/UdDM8roGRH1vCtt8/UTaIHy4VKPLh6kfzS9ljzlMD2Y2Q=="
|
||||
},
|
||||
"readdirp": {
|
||||
"version": "3.6.0",
|
||||
@@ -2439,7 +2445,7 @@
|
||||
}
|
||||
},
|
||||
"starkstark": {
|
||||
"version": "git+ssh://git@github.com/vimwitch/starkstark.git#a761cb5c957f4810444fbcec61b3c2135b73f85b",
|
||||
"version": "git+ssh://git@github.com/vimwitch/starkstark.git#683bf06ef15882a67f86a26a33efedc5055df885",
|
||||
"from": "starkstark@vimwitch/starkstark#main",
|
||||
"requires": {
|
||||
"@noble/hashes": "^1.3.1",
|
||||
|
||||
132
src/r1cs.mjs
132
src/r1cs.mjs
@@ -3,7 +3,12 @@ import { ScalarField } from 'starkstark/src/ScalarField.mjs'
|
||||
import { MultiPolynomial } from 'starkstark/src/MultiPolynomial.mjs'
|
||||
import { Polynomial } from 'starkstark/src/Polynomial.mjs'
|
||||
|
||||
function buildWitness(data, input = [], baseField) {
|
||||
// take only the input variables and solve a system of
|
||||
// equations (the constraints) to return a complete witness
|
||||
//
|
||||
// TODO: exploit structure of the constraint to solve more efficiently
|
||||
// when possible
|
||||
async function buildWitness(data, input = [], baseField) {
|
||||
const {
|
||||
nOutputs,
|
||||
nPubInputs,
|
||||
@@ -16,6 +21,7 @@ function buildWitness(data, input = [], baseField) {
|
||||
vars[1+nOutputs+x] = input[x]
|
||||
}
|
||||
const constraints = []
|
||||
// turn the raw constraints into multivariate polynomials
|
||||
for (const constraint of data.constraints) {
|
||||
const [a, b, c] = constraint
|
||||
const polyA = new MultiPolynomial(baseField)
|
||||
@@ -31,44 +37,108 @@ function buildWitness(data, input = [], baseField) {
|
||||
polyC.term({ coef: val, exps: { [Number(key)]: 1n } })
|
||||
}
|
||||
const finalPoly = polyA.copy().mul(polyB).sub(polyC)
|
||||
finalPoly.originalConstraint = constraint
|
||||
constraints.push(finalPoly)
|
||||
}
|
||||
// build a map of what constraints need what variables
|
||||
for (const [i, constraint] of Object.entries(constraints)) {
|
||||
const unknownVars = []
|
||||
const unknownVarsMap = {}
|
||||
constraint.vars = []
|
||||
const [a, b, c] = constraint.originalConstraint
|
||||
for (const obj of [a, b, c].flat()) {
|
||||
for (const key of Object.keys(obj)) {
|
||||
constraint.vars[+key] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// iterate over the set of constraints
|
||||
// look for constraints that have only 1 unknown
|
||||
// and solve for that unknown
|
||||
// then iterate again
|
||||
// repeat until all variables are known
|
||||
while (vars.indexOf(null) !== -1) {
|
||||
let solvedCount = 0
|
||||
for (const constraint of constraints) {
|
||||
let unknownVars = []
|
||||
for (const [key, coef] of constraint.expMap.entries()) {
|
||||
const v = MultiPolynomial.expStringToVector(key)
|
||||
for (const [varIndex, degree] of Object.entries(v)) {
|
||||
if (degree > 0n && vars[+varIndex] === null && unknownVars.indexOf(+varIndex) === -1) {
|
||||
unknownVars.push(+varIndex)
|
||||
const solved = []
|
||||
// determine what variables are unknown. If there is not
|
||||
// exactly 1 unknown then skip
|
||||
for (const [key, constraint] of Object.entries(constraints)) {
|
||||
const unknownVarsMap = {}
|
||||
for (const [varIndex, ] of Object.entries(constraint.vars)) {
|
||||
if (vars[varIndex] === null) {
|
||||
unknownVarsMap[varIndex] = true
|
||||
if (Object.keys(unknownVarsMap).length > 1) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
const unknownVars = Object.keys(unknownVarsMap).map(v => +v)
|
||||
if (unknownVars.length >= 2 || unknownVars.length === 0) continue
|
||||
// otherwise solve
|
||||
const cc = constraint.copy()
|
||||
for (const [i, v] of Object.entries(vars)) {
|
||||
if (v === null) continue
|
||||
cc.evaluateSingle(BigInt(v), Number(i))
|
||||
cc.evaluatePartial(vars)
|
||||
// we should end up with either
|
||||
// 0 = x + c
|
||||
// or
|
||||
// 0 = x^2 + x
|
||||
// or a constraint with c = 0
|
||||
// in which case we solve a and b
|
||||
let out
|
||||
if (cc.expMap.size === 3) {
|
||||
const [a, b, c] = constraint.originalConstraint
|
||||
// check that c = 0
|
||||
if (Object.keys(c).length !== 0) {
|
||||
throw new Error('expected c = 0 in quadratic polynomial constraint')
|
||||
}
|
||||
// solve for the unknown in a and b
|
||||
const polyA = new Polynomial(baseField)
|
||||
const polyB = new Polynomial(baseField)
|
||||
for (const [key, val] of Object.entries(a)) {
|
||||
if (+unknownVars[0] === +key) {
|
||||
polyA.term({ coef: val, exp: 1n })
|
||||
} else {
|
||||
polyA.term({ coef: baseField.mul(val, vars[+key]), exp: 0n })
|
||||
}
|
||||
}
|
||||
for (const [key, val] of Object.entries(b)) {
|
||||
if (+unknownVars[0] === +key) {
|
||||
polyB.term({ coef: val, exp: 1n })
|
||||
} else {
|
||||
polyB.term({ coef: baseField.mul(val, vars[+key]), exp: 0n })
|
||||
}
|
||||
}
|
||||
out = polyA.solve() ?? polyB.solve()
|
||||
} else {
|
||||
const _expKey = Array(unknownVars[0]).fill(0n)
|
||||
_expKey.push(1n)
|
||||
const expKey = MultiPolynomial.expVectorToString(_expKey)
|
||||
if (cc.expMap.size !== 2) {
|
||||
continue
|
||||
throw new Error('expected exactly 2 remaining terms')
|
||||
}
|
||||
if (!cc.expMap.has(expKey)) {
|
||||
throw new Error('cannot find remaining variable')
|
||||
}
|
||||
if (cc.expMap.has(expKey.replace('1', '2'))) {
|
||||
// we're in the case of 0 = x^2 + x
|
||||
// reduce by dividing an x out
|
||||
cc.expMap.set('0', cc.expMap.get(expKey))
|
||||
cc.expMap.set(expKey, cc.expMap.get(expKey.replace('1', '2')))
|
||||
cc.expMap.delete(expKey.replace('1', '2'))
|
||||
}
|
||||
if (!cc.expMap.has('0')) {
|
||||
throw new Error('exactly one term should be a constant')
|
||||
}
|
||||
out = cc.field.div(cc.field.neg(cc.expMap.get('0')), cc.expMap.get(expKey))
|
||||
}
|
||||
if (cc.expMap.size !== 2) {
|
||||
throw new Error('expected exactly 2 remaining terms')
|
||||
}
|
||||
if (!cc.expMap.has('0')) {
|
||||
throw new Error('exactly one term should be a constant')
|
||||
}
|
||||
const _expKey = Array(unknownVars[0]).fill(0n)
|
||||
_expKey.push(1n)
|
||||
const expKey = MultiPolynomial.expVectorToString(_expKey)
|
||||
if (!cc.expMap.has(expKey)) {
|
||||
throw new Error('cannot find remaining variable')
|
||||
}
|
||||
const out = cc.field.div(cc.field.neg(cc.expMap.get('0')), cc.expMap.get(expKey))
|
||||
vars[unknownVars[0]] = out
|
||||
solvedCount++
|
||||
solved.push(constraint)
|
||||
}
|
||||
if (solvedCount === 0)
|
||||
for (const c of solved) {
|
||||
const i = constraints.indexOf(c)
|
||||
constraints.splice(i, 1)
|
||||
}
|
||||
await new Promise(r => setTimeout(r, 10))
|
||||
if (solved.length === 0)
|
||||
throw new Error('Unable to solve for remaining variables')
|
||||
}
|
||||
return vars
|
||||
@@ -85,7 +155,7 @@ export async function compileR1cs(file, input = [], memoryOverride) {
|
||||
nVars
|
||||
} = data
|
||||
const baseField = new ScalarField(prime)
|
||||
const memory = memoryOverride ? memoryOverride : buildWitness(data, input, baseField)
|
||||
const memory = memoryOverride ? memoryOverride : (await buildWitness(data, input, baseField))
|
||||
const negOne = baseField.neg(1n)
|
||||
// order of variables
|
||||
// ONE, outputs, pub inputs, prv inputs
|
||||
@@ -116,10 +186,14 @@ export async function compileR1cs(file, input = [], memoryOverride) {
|
||||
asm.push(`add ${scratch0} ${scratch0} ${scratch3}`)
|
||||
asm.push(`eq ${scratch0} ${scratch1}`)
|
||||
}
|
||||
console.log(`Proving ${asm.length} steps with ${memory.length} memory slots`)
|
||||
await new Promise(r => setTimeout(r, 100))
|
||||
return asm.join('\n')
|
||||
|
||||
}
|
||||
|
||||
// mulsum - multiply 2 numbers and add them to a third register
|
||||
// abc - constrain 3 registers to a*b - c = 0
|
||||
|
||||
function sum(scratch0, scratch, map, negOne) {
|
||||
if (Object.keys(map).length === 0) {
|
||||
return [`set ${scratch} 0`]
|
||||
|
||||
Reference in New Issue
Block a user