refactor: witness calculator improvement

This commit is contained in:
Chance
2023-10-10 20:05:52 -05:00
parent 90b496a2d7
commit 96b7cb66fb
2 changed files with 115 additions and 35 deletions

18
package-lock.json generated
View File

@@ -35,7 +35,8 @@
},
"node_modules/@noble/hashes": {
"version": "1.3.2",
"license": "MIT",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz",
"integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ==",
"engines": {
"node": ">= 16"
},
@@ -1219,7 +1220,8 @@
},
"node_modules/randomf": {
"version": "0.0.3",
"license": "MIT"
"resolved": "https://registry.npmjs.org/randomf/-/randomf-0.0.3.tgz",
"integrity": "sha512-hcZAFZNPZ3YeNpYLkC3MYjVedW4Hb4I0OhQdEHRd/UdDM8roGRH1vCtt8/UTaIHy4VKPLh6kfzS9ljzlMD2Y2Q=="
},
"node_modules/readdirp": {
"version": "3.6.0",
@@ -1385,7 +1387,7 @@
},
"node_modules/starkstark": {
"version": "0.0.0",
"resolved": "git+ssh://git@github.com/vimwitch/starkstark.git#a761cb5c957f4810444fbcec61b3c2135b73f85b",
"resolved": "git+ssh://git@github.com/vimwitch/starkstark.git#683bf06ef15882a67f86a26a33efedc5055df885",
"license": "ISC",
"dependencies": {
"@noble/hashes": "^1.3.1",
@@ -1691,7 +1693,9 @@
}
},
"@noble/hashes": {
"version": "1.3.2"
"version": "1.3.2",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz",
"integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ=="
},
"@nodelib/fs.scandir": {
"version": "2.1.5",
@@ -2353,7 +2357,9 @@
}
},
"randomf": {
"version": "0.0.3"
"version": "0.0.3",
"resolved": "https://registry.npmjs.org/randomf/-/randomf-0.0.3.tgz",
"integrity": "sha512-hcZAFZNPZ3YeNpYLkC3MYjVedW4Hb4I0OhQdEHRd/UdDM8roGRH1vCtt8/UTaIHy4VKPLh6kfzS9ljzlMD2Y2Q=="
},
"readdirp": {
"version": "3.6.0",
@@ -2439,7 +2445,7 @@
}
},
"starkstark": {
"version": "git+ssh://git@github.com/vimwitch/starkstark.git#a761cb5c957f4810444fbcec61b3c2135b73f85b",
"version": "git+ssh://git@github.com/vimwitch/starkstark.git#683bf06ef15882a67f86a26a33efedc5055df885",
"from": "starkstark@vimwitch/starkstark#main",
"requires": {
"@noble/hashes": "^1.3.1",

View File

@@ -3,7 +3,12 @@ import { ScalarField } from 'starkstark/src/ScalarField.mjs'
import { MultiPolynomial } from 'starkstark/src/MultiPolynomial.mjs'
import { Polynomial } from 'starkstark/src/Polynomial.mjs'
function buildWitness(data, input = [], baseField) {
// take only the input variables and solve a system of
// equations (the constraints) to return a complete witness
//
// TODO: exploit structure of the constraint to solve more efficiently
// when possible
async function buildWitness(data, input = [], baseField) {
const {
nOutputs,
nPubInputs,
@@ -16,6 +21,7 @@ function buildWitness(data, input = [], baseField) {
vars[1+nOutputs+x] = input[x]
}
const constraints = []
// turn the raw constraints into multivariate polynomials
for (const constraint of data.constraints) {
const [a, b, c] = constraint
const polyA = new MultiPolynomial(baseField)
@@ -31,44 +37,108 @@ function buildWitness(data, input = [], baseField) {
polyC.term({ coef: val, exps: { [Number(key)]: 1n } })
}
const finalPoly = polyA.copy().mul(polyB).sub(polyC)
finalPoly.originalConstraint = constraint
constraints.push(finalPoly)
}
// build a map of what constraints need what variables
for (const [i, constraint] of Object.entries(constraints)) {
const unknownVars = []
const unknownVarsMap = {}
constraint.vars = []
const [a, b, c] = constraint.originalConstraint
for (const obj of [a, b, c].flat()) {
for (const key of Object.keys(obj)) {
constraint.vars[+key] = true
}
}
}
// iterate over the set of constraints
// look for constraints that have only 1 unknown
// and solve for that unknown
// then iterate again
// repeat until all variables are known
while (vars.indexOf(null) !== -1) {
let solvedCount = 0
for (const constraint of constraints) {
let unknownVars = []
for (const [key, coef] of constraint.expMap.entries()) {
const v = MultiPolynomial.expStringToVector(key)
for (const [varIndex, degree] of Object.entries(v)) {
if (degree > 0n && vars[+varIndex] === null && unknownVars.indexOf(+varIndex) === -1) {
unknownVars.push(+varIndex)
const solved = []
// determine what variables are unknown. If there is not
// exactly 1 unknown then skip
for (const [key, constraint] of Object.entries(constraints)) {
const unknownVarsMap = {}
for (const [varIndex, ] of Object.entries(constraint.vars)) {
if (vars[varIndex] === null) {
unknownVarsMap[varIndex] = true
if (Object.keys(unknownVarsMap).length > 1) {
break
}
}
}
const unknownVars = Object.keys(unknownVarsMap).map(v => +v)
if (unknownVars.length >= 2 || unknownVars.length === 0) continue
// otherwise solve
const cc = constraint.copy()
for (const [i, v] of Object.entries(vars)) {
if (v === null) continue
cc.evaluateSingle(BigInt(v), Number(i))
cc.evaluatePartial(vars)
// we should end up with either
// 0 = x + c
// or
// 0 = x^2 + x
// or a constraint with c = 0
// in which case we solve a and b
let out
if (cc.expMap.size === 3) {
const [a, b, c] = constraint.originalConstraint
// check that c = 0
if (Object.keys(c).length !== 0) {
throw new Error('expected c = 0 in quadratic polynomial constraint')
}
// solve for the unknown in a and b
const polyA = new Polynomial(baseField)
const polyB = new Polynomial(baseField)
for (const [key, val] of Object.entries(a)) {
if (+unknownVars[0] === +key) {
polyA.term({ coef: val, exp: 1n })
} else {
polyA.term({ coef: baseField.mul(val, vars[+key]), exp: 0n })
}
}
for (const [key, val] of Object.entries(b)) {
if (+unknownVars[0] === +key) {
polyB.term({ coef: val, exp: 1n })
} else {
polyB.term({ coef: baseField.mul(val, vars[+key]), exp: 0n })
}
}
out = polyA.solve() ?? polyB.solve()
} else {
const _expKey = Array(unknownVars[0]).fill(0n)
_expKey.push(1n)
const expKey = MultiPolynomial.expVectorToString(_expKey)
if (cc.expMap.size !== 2) {
continue
throw new Error('expected exactly 2 remaining terms')
}
if (!cc.expMap.has(expKey)) {
throw new Error('cannot find remaining variable')
}
if (cc.expMap.has(expKey.replace('1', '2'))) {
// we're in the case of 0 = x^2 + x
// reduce by dividing an x out
cc.expMap.set('0', cc.expMap.get(expKey))
cc.expMap.set(expKey, cc.expMap.get(expKey.replace('1', '2')))
cc.expMap.delete(expKey.replace('1', '2'))
}
if (!cc.expMap.has('0')) {
throw new Error('exactly one term should be a constant')
}
out = cc.field.div(cc.field.neg(cc.expMap.get('0')), cc.expMap.get(expKey))
}
if (cc.expMap.size !== 2) {
throw new Error('expected exactly 2 remaining terms')
}
if (!cc.expMap.has('0')) {
throw new Error('exactly one term should be a constant')
}
const _expKey = Array(unknownVars[0]).fill(0n)
_expKey.push(1n)
const expKey = MultiPolynomial.expVectorToString(_expKey)
if (!cc.expMap.has(expKey)) {
throw new Error('cannot find remaining variable')
}
const out = cc.field.div(cc.field.neg(cc.expMap.get('0')), cc.expMap.get(expKey))
vars[unknownVars[0]] = out
solvedCount++
solved.push(constraint)
}
if (solvedCount === 0)
for (const c of solved) {
const i = constraints.indexOf(c)
constraints.splice(i, 1)
}
await new Promise(r => setTimeout(r, 10))
if (solved.length === 0)
throw new Error('Unable to solve for remaining variables')
}
return vars
@@ -85,7 +155,7 @@ export async function compileR1cs(file, input = [], memoryOverride) {
nVars
} = data
const baseField = new ScalarField(prime)
const memory = memoryOverride ? memoryOverride : buildWitness(data, input, baseField)
const memory = memoryOverride ? memoryOverride : (await buildWitness(data, input, baseField))
const negOne = baseField.neg(1n)
// order of variables
// ONE, outputs, pub inputs, prv inputs
@@ -116,10 +186,14 @@ export async function compileR1cs(file, input = [], memoryOverride) {
asm.push(`add ${scratch0} ${scratch0} ${scratch3}`)
asm.push(`eq ${scratch0} ${scratch1}`)
}
console.log(`Proving ${asm.length} steps with ${memory.length} memory slots`)
await new Promise(r => setTimeout(r, 100))
return asm.join('\n')
}
// mulsum - multiply 2 numbers and add them to a third register
// abc - constrain 3 registers to a*b - c = 0
function sum(scratch0, scratch, map, negOne) {
if (Object.keys(map).length === 0) {
return [`set ${scratch} 0`]