Merge pull request #5 from erhant/erhant/outputs-and-rfk

Easy outputs via witness parsing
This commit is contained in:
erhant.eth
2023-05-20 00:35:21 +03:00
committed by GitHub
12 changed files with 192 additions and 68 deletions

View File

@@ -34,7 +34,8 @@
- [x] **Easily Configurable**: Just change the configured proof-system & elliptic curve at [`.cli.env`](./.cli.env) and you are good to go.
- [x] **Witness Testing**: You can test computations & assertions for every template in a circuit, with minimal code-repetition.
- [x] **Proof Testing**: With prover & verification keys and the WASM circuit, you can test proof generation & verification.
- [x] **Type-safe**: Witness & proof testers, as well as circuit signal inputs & outputs are type-safe.
- [x] **Simple Outputs**: Easily see the output signals of your circuit, without generating a proof.
- [x] **Type-safe**: Witness & proof testers, as well as circuit signal inputs & outputs are all type-safe via generics.
- [x] **Solidity Exports**: Export a verifier contract in Solidity, or export a calldata for your proofs & public signals.
## Usage
@@ -133,7 +134,7 @@ To run a circuit, you need to create a `main` component in Circom, where your ma
```ts
import {instantiate} from '../utils/instantiate';
import {createWasmTester} from '../utils/wasmTester';
import {WasmTester} from '../utils/wasmTester';
describe('multiplier', () => {
// templates parameters!
@@ -150,7 +151,7 @@ describe('multiplier', () => {
publicInputs: [],
templateParams: [N],
});
circuit = await createWasmTester(circuitName);
circuit = await WasmTester.new(circuitName);
// constraint count checks!
await circuit.checkConstraintCount(N - 1);
@@ -183,6 +184,22 @@ With the circuit object, we can do the following:
- `circuit.expectCorrectAssert(input)` to test whether the circuit assertions pass for some given input
- `circuit.expectFailedAssert(input)` to test whether the circuit assertions pass for some given input
#### Circuit outputs
What if we would just like to see what the output is, instead of comparing it to some witness? Well, that would be a trouble because we would have to parse the witness array (which is huge for some circuits) with respect to which signals the output signals correspond to. Thankfully, Circomkit has a function for that:
```ts
const outputSignals = ['foo', 'bar'];
const output = await circuit.compute(INPUT, outputSignals);
console.log(output);
/* {
foo: [[1n, 2n], [3n, 4n]]
bar: 48n
} */
```
Note that this operation requires parsing the symbols file (`.sym`) and reading the witness array, which may be costly for large circuits. Most of the time, you won't need this for testing; instead, you will likely use it to see what the circuit actually does for debugging.
#### Multiple templates
You will often have multiple templates in your circuit code, and you might want to test them in the same test file of your main circuit too. Well, you can!
@@ -204,7 +221,7 @@ describe('multiplier utilities', () => {
},
'test/multiplier'
);
circuit = await createWasmTester(circuitName, 'test/multiplier');
circuit = await WasmTester.new(circuitName, 'test/multiplier');
});
it('should pass for in range', async () => {
@@ -240,7 +257,7 @@ describe('multiplier proofs', () => {
await circuit.expectVerificationPass(fullProof.proof, fullProof.publicSignals);
});
it('should NOT verify a wrong multiplication', async () => {
it('should NOT verify', async () => {
// just give a prime number as the output, assuming none of the inputs are 1
await circuit.expectVerificationFail(fullProof.proof, ['13']);
});

View File

@@ -5,7 +5,7 @@
"author": "erhant",
"license": "MIT",
"engines": {
"node": ">=10.4.0"
"node": ">=12.0.0"
},
"scripts": {
"test:all": "npx mocha",

View File

@@ -2,4 +2,4 @@ $$
\left(g^\tau, g^{\tau^2}, g^{\tau^3}, g^{\tau^4}, \ldots, g^{\tau^d}\right)
$$
For circuit-specific setup phase, you need to have `ptau` files for the first phase. See [SnarkJS docs](https://github.com/iden3/snarkjs#7-prepare-phase-2) for more information.
For circuit setups, you need to have `ptau` files from the universal setup phase. See [SnarkJS docs](https://github.com/iden3/snarkjs#7-prepare-phase-2) for more information.

View File

@@ -1,5 +1,5 @@
import {instantiate} from '../utils/instantiate';
import {WasmTester, createWasmTester} from '../utils/wasmTester';
import instantiate from '../utils/instantiate';
import WasmTester from '../utils/wasmTester';
const CIRCUIT_FILE = 'fibonacci';
describe(CIRCUIT_FILE, () => {
@@ -14,7 +14,7 @@ describe(CIRCUIT_FILE, () => {
publicInputs: [],
templateParams: [N],
});
circuit = await createWasmTester(circuitName);
circuit = await WasmTester.new(circuitName);
await circuit.checkConstraintCount();
});
@@ -43,7 +43,7 @@ describe.skip(CIRCUIT_FILE + ' recursive', () => {
publicInputs: [],
templateParams: [N],
});
circuit = await createWasmTester(circuitName);
circuit = await WasmTester.new(circuitName);
await circuit.checkConstraintCount();
});

View File

@@ -1,5 +1,5 @@
import {instantiate} from '../utils/instantiate';
import {WasmTester, createWasmTester} from '../utils/wasmTester';
import instantiate from '../utils/instantiate';
import WasmTester from '../utils/wasmTester';
// tests adapted from https://github.com/rdi-berkeley/zkp-mooc-lab
@@ -25,7 +25,7 @@ describe('float_add 32-bit', () => {
publicInputs: [],
templateParams: [k, p],
});
circuit = await createWasmTester('fp32');
circuit = await WasmTester.new('fp32');
await circuit.checkConstraintCount(401);
});
@@ -37,6 +37,16 @@ describe('float_add 32-bit', () => {
},
{e_out: '43', m_out: '11672136'}
);
console.log(
await circuit.compute(
{
e: ['43', '5'],
m: ['11672136', '10566265'],
},
['e_out', 'm_out']
)
);
});
it('case II test 1', async () => {
@@ -130,7 +140,7 @@ describe('float_add 64-bit', () => {
publicInputs: [],
templateParams: [k, p],
});
circuit = await createWasmTester('fp64');
circuit = await WasmTester.new('fp64');
await circuit.checkConstraintCount(819);
});
@@ -226,11 +236,11 @@ describe('float_add utilities', () => {
},
'test/float_add'
);
circuit = await createWasmTester(circuitName, 'test/float_add');
circuit = await WasmTester.new(circuitName, 'test/float_add');
await circuit.checkConstraintCount(expectedConstraints.checkBitLength(b));
});
it('should give 1 for in <= b', async () => {
it('should give 1 for in b', async () => {
await circuit.expectCorrectAssert(
{
in: '4903265',
@@ -265,7 +275,7 @@ describe('float_add utilities', () => {
},
'test/float_add'
);
circuit = await createWasmTester(circuitName, 'test/float_add');
circuit = await WasmTester.new(circuitName, 'test/float_add');
await circuit.checkConstraintCount(expectedConstraints.leftShift(shift_bound));
});
@@ -325,7 +335,7 @@ describe('float_add utilities', () => {
},
'test/float_add'
);
circuit = await createWasmTester(circuitName, 'test/float_add');
circuit = await WasmTester.new(circuitName, 'test/float_add');
await circuit.checkConstraintCount(b);
});
@@ -363,7 +373,7 @@ describe('float_add utilities', () => {
},
'test/float_add'
);
circuit = await createWasmTester(circuitName, 'test/float_add');
circuit = await WasmTester.new(circuitName, 'test/float_add');
await circuit.checkConstraintCount(expectedConstraints.normalize(P));
});
@@ -422,7 +432,7 @@ describe('float_add utilities', () => {
},
'test/float_add'
);
circuit = await createWasmTester(circuitName, 'test/float_add');
circuit = await WasmTester.new(circuitName, 'test/float_add');
await circuit.checkConstraintCount(expectedConstraints.msnzb(b));
});

View File

@@ -1,7 +1,7 @@
import {WasmTester, createWasmTester} from '../utils/wasmTester';
import {ProofTester} from '../utils/proofTester';
import WasmTester from '../utils/wasmTester';
import ProofTester from '../utils/proofTester';
import type {FullProof} from '../types/circuit';
import {instantiate} from '../utils/instantiate';
import instantiate from '../utils/instantiate';
describe('multiplier', () => {
// templates parameters!
@@ -18,7 +18,7 @@ describe('multiplier', () => {
publicInputs: [],
templateParams: [N],
});
circuit = await createWasmTester(circuitName);
circuit = await WasmTester.new(circuitName);
// constraint count checks!
await circuit.checkConstraintCount(N - 1);
@@ -53,7 +53,7 @@ describe('multiplier utilities', () => {
},
'test/multiplier'
);
circuit = await createWasmTester(circuitName, 'test/multiplier');
circuit = await WasmTester.new(circuitName, 'test/multiplier');
});
it('should multiply correctly', async () => {
@@ -83,7 +83,7 @@ describe('multiplier proofs', () => {
await circuit.expectVerificationPass(fullProof.proof, fullProof.publicSignals);
});
it('should NOT verify a wrong multiplication', async () => {
it('should NOT verify', async () => {
// just give a prime number as the output, assuming none of the inputs are 1
await circuit.expectVerificationFail(fullProof.proof, ['13']);
});

View File

@@ -1,5 +1,5 @@
import {instantiate} from '../utils/instantiate';
import {WasmTester, createWasmTester} from '../utils/wasmTester';
import instantiate from '../utils/instantiate';
import WasmTester from '../utils/wasmTester';
type BoardSizes = 4 | 9;
@@ -47,7 +47,7 @@ const INPUTS: {[N in BoardSizes]: any} = {
([4, 9] as BoardSizes[]).map(N =>
describe(`sudoku (${N} by ${N})`, () => {
const INPUT = INPUTS[N];
let circuit: WasmTester<['solution', 'puzzle'], []>;
let circuit: WasmTester<['solution', 'puzzle']>;
before(async () => {
const circuitName = `sudoku_${N}x${N}`;
@@ -57,7 +57,7 @@ const INPUTS: {[N in BoardSizes]: any} = {
publicInputs: ['puzzle'],
templateParams: [Math.sqrt(N)],
});
circuit = await createWasmTester(circuitName);
circuit = await WasmTester.new(circuitName);
await circuit.checkConstraintCount();
});
@@ -115,7 +115,7 @@ describe('sudoku utilities', () => {
},
'test/sudoku'
);
circuit = await createWasmTester(circuitName, 'test/sudoku');
circuit = await WasmTester.new(circuitName, 'test/sudoku');
});
it('should pass for input < 2^b', async () => {
@@ -151,7 +151,7 @@ describe('sudoku utilities', () => {
},
'test/sudoku'
);
circuit = await createWasmTester(circuitName, 'test/sudoku');
circuit = await WasmTester.new(circuitName, 'test/sudoku');
});
it('should pass if all inputs are unique', async () => {
@@ -191,7 +191,7 @@ describe('sudoku utilities', () => {
},
'test/sudoku'
);
circuit = await createWasmTester(circuitName, 'test/sudoku');
circuit = await WasmTester.new(circuitName, 'test/sudoku');
});
it('should pass for in range', async () => {

View File

@@ -1,15 +1,23 @@
type IntegerValue = `${number}` | number | bigint;
type SignalValue = IntegerValue | SignalValue[];
/**
* An integer value is a numerical string, a number, or a bigint.
*/
type IntegerValueType = `${number}` | number | bigint;
/**
* A signal value is a number, or an array of numbers (recursively).
*/
export type SignalValueType = IntegerValueType | SignalValueType[];
/**
* An object with string keys and array of numerical values.
* Each key represents a signal name as it appears in the circuit.
*
* By default, signal names are not typed, but you can pass an array of signal names
* to make them type-safe, e.g. `CircuitSignals<['sig1', 'sig2']>`
*/
export type CircuitSignals<T extends readonly string[] = []> = T extends []
? {[signal: string]: SignalValue}
: {[signal in T[number]]: SignalValue};
? {[signal: string]: SignalValueType}
: {[signal in T[number]]: SignalValueType};
/**
* A witness is an array of bigints, corresponding to the values of each wire in
@@ -17,6 +25,18 @@ export type CircuitSignals<T extends readonly string[] = []> = T extends []
*/
export type WitnessType = bigint[];
/**
* Symbols are a mapping of each circuit `wire` to an object with three keys. Within them,
* the most important is `varIdx` which indicates the position of this signal in the witness array.
*/
export type SymbolsType = {
[symbol: string]: {
labelIdx: number;
varIdx: number;
componentIdx: number;
};
};
/**
* A FullProof, as returned from SnarkJS `fullProve` function.
*/
@@ -26,6 +46,6 @@ export type FullProof = {
};
/**
* Proof system to be used
* Proof system to be used by SnarkJS.
*/
export type ProofSystem = 'groth16' | 'plonk' | 'fflonk';

View File

@@ -1,4 +1,4 @@
import {WitnessType, CircuitSignals} from './circuit';
import {WitnessType, CircuitSignals, SymbolsType} from './circuit';
/**
* A simple type-wrapper for `circom_tester` WASM tester class.
@@ -13,6 +13,6 @@ export type CircomWasmTester = {
loadConstraints: () => Promise<void>;
constraints: any[] | undefined;
loadSymbols: () => Promise<void>;
symbols: object | undefined;
symbols: SymbolsType | undefined;
getDecoratedOutput: (witness: WitnessType) => Promise<string>;
};

View File

@@ -9,7 +9,7 @@ import {CircuitConfig} from '../types/config';
* @param circuitConfig circuit configurations, if `undefined` then `circuit.config.ts` will be used.
* @param directory name of the directory under circuits to be created. Can be given sub-folders like `test/myCircuit/foobar`. Defaults to `test`
*/
export function instantiate(name: string, circuitConfig?: CircuitConfig, directory = 'test') {
export default function instantiate(name: string, circuitConfig?: CircuitConfig, directory = 'test') {
// get config from circuit.config.ts if none are given
if (circuitConfig === undefined) {
if (!(name in config)) {

View File

@@ -9,7 +9,7 @@ const PROOF_SYSTEMS = ['groth16', 'plonk', 'fflonk'] as const;
* A more extensive Circuit class, able to generate proofs & verify them.
* Assumes that prover key and verifier key have been computed.
*/
export class ProofTester<IN extends string[] = []> {
export default class ProofTester<IN extends string[] = []> {
public readonly protocol: ProofSystem;
private readonly wasmPath: string;
private readonly proverKeyPath: string;

View File

@@ -1,12 +1,12 @@
const wasm_tester = require('circom_tester').wasm;
import {WitnessType, CircuitSignals} from '../types/circuit';
import {WitnessType, CircuitSignals, SymbolsType, SignalValueType} from '../types/circuit';
import {CircomWasmTester} from '../types/wasmTester';
import {assert, expect} from 'chai';
/**
A utility class to test your circuits. Use `expectFailedAssert` and `expectCorrectAssert` to test out evaluations
*/
export class WasmTester<IN extends readonly string[] = [], OUT extends readonly string[] = []> {
export default class WasmTester<IN extends readonly string[] = [], OUT extends readonly string[] = []> {
/**
* The underlying `circom_tester` object
*/
@@ -15,7 +15,7 @@ export class WasmTester<IN extends readonly string[] = [], OUT extends readonly
/**
* A dictionary of symbols
*/
symbols: object | undefined;
symbols: SymbolsType | undefined;
/**
* List of constraints, must call `loadConstraints` before accessing this key
@@ -63,10 +63,13 @@ export class WasmTester<IN extends readonly string[] = [], OUT extends readonly
/**
* Loads the symbols in a dictionary at `this.symbols`
* Symbols are stored under the .sym file
*
* Each line has 4 comma-separated values:
* 0: label index
* 1: variable index
* 2: component index
*
* 1. symbol name
* 2. label index
* 3. variable index
* 4. component index
*/
async loadSymbols(): Promise<void> {
await this.circomWasmTester.loadSymbols();
@@ -103,15 +106,15 @@ export class WasmTester<IN extends readonly string[] = [], OUT extends readonly
await this.loadConstraints();
}
const numConstraints = this.constraints!.length;
console.log(`#constraints: ${numConstraints}`);
console.log(`# constraints: ${numConstraints}`);
if (expected !== undefined) {
if (numConstraints < expected) {
console.log(`\x1b[0;31mx expectation ${expected}\x1b[0m`);
console.log(`\x1b[0;31mx expectation: ${expected}\x1b[0m`);
} else if (numConstraints > expected) {
console.log(`\x1b[0;33m! expectation ${expected}\x1b[0m`);
console.log(`\x1b[0;33m! expectation: ${expected}\x1b[0m`);
} else {
console.log(`\x1b[0;32m✔\x1b[2;37m expectation ${expected}\x1b[0m`);
console.log(`\x1b[0;32m✔\x1b[2;37m expectation: ${expected}\x1b[0m`);
}
}
}
@@ -139,20 +142,94 @@ export class WasmTester<IN extends readonly string[] = [], OUT extends readonly
await this.assertOut(witness, output);
}
}
}
/**
* Compiles and reutrns a circuit tester class instance.
* @param circuit name of circuit
* @param dir directory to read the circuit from, defaults to `test`
* @returns a `WasmTester` instance
*/
export async function createWasmTester<IN extends string[] = [], OUT extends string[] = []>(
circuitName: string,
dir = 'test'
): Promise<WasmTester<IN, OUT>> {
const circomWasmTester: CircomWasmTester = await wasm_tester(`./circuits/${dir}/${circuitName}.circom`, {
include: 'node_modules', // will link circomlib circuits
});
return new WasmTester<IN, OUT>(circomWasmTester);
/**
* Computes the output.
*
* This is an **expensive operation** in the following sense:
*
* 1. the witness is calculated via `calculateWitness`
* 2. symbols are loaded via `loadSymbols`, which is a bit expensive in it's own sense
* 3. for the requested output signals, the symbols are parsed and the required symbols are retrieved
* 4. for each signal & it's required symbols, corresponding witness values are retrieved from witness
* 5. the results are aggregated in a final object, of the same type of circuit output signals
*
* @param input input signals
* @param outputSignals an array of signal names
* @returns output signals
*/
async compute(input: CircuitSignals<IN>, outputSignals: OUT): Promise<Partial<CircuitSignals<typeof outputSignals>>> {
const witness = await this.calculateWitness(input, true);
// get symbols of main component
await this.loadSymbols();
const symbolNames = Object.keys(this.symbols!).filter(signal => !signal.includes('.', 5)); // non-main signals have an additional `.` in them after `main.symbol`
// for each out signal, process the respective symbol
const entries: [OUT[number], SignalValueType][] = [];
for (const outSignal of outputSignals) {
// get the symbol values from symbol names
const symbols = symbolNames.filter(s => s.startsWith(outSignal, 5));
/*
we can assume that a symbol with this name appears only once in `main`, and that the depth is same for
all occurences of this symbol, given the type system used in Circom. So, we can just count the number
of `[`s in any symbol of this signal to find the number of dimensions of this signal.
we particularly choose the last symbol in the array, as that holds the maximum index of each dimension of this array.
*/
const splits = symbols.at(-1)!.split('[');
// since we chose the last symbol, we have something like `main.signal[dim1][dim2]...[dimN]` which we can parse
const dims = splits.slice(1).map(dim => parseInt(dim.slice(0, -1)) + 1); // +1 is needed because the final value is 0-indexed
// since signal names are consequent, we only need to know the witness index of the first symbol
let idx = this.symbols![symbols[0]].varIdx;
if (dims.length === 0) {
// easy case, just return the witness of this symbol
entries.push([outSignal, witness[idx]]);
} else {
/*
at this point, we have an array of signals like `main.signal[0..dim1][0..dim2]..[0..dimN]` and we must construct
the necessary multi-dimensional array out of it.
*/
// eslint-disable-next-line no-inner-declarations
function processDepth(d: number): SignalValueType {
const acc: SignalValueType = [];
if (d === dims.length - 1) {
// final depth, count witnesses
for (let i = 0; i < dims[d]; i++) {
acc.push(witness[idx++]);
}
} else {
// not final depth, recurse to next
for (let i = 0; i < dims[d]; i++) {
acc.push(processDepth(d + 1));
}
}
return acc;
}
entries.push([outSignal, processDepth(0)]);
}
}
return Object.fromEntries(entries) as CircuitSignals<OUT>;
}
/**
* Compiles and reutrns a circuit tester class instance.
* @param circuit name of circuit
* @param dir directory to read the circuit from, defaults to `test`
* @returns a `WasmTester` instance
*/
static async new<IN extends string[] = [], OUT extends string[] = []>(
circuitName: string,
dir = 'test'
): Promise<WasmTester<IN, OUT>> {
const circomWasmTester: CircomWasmTester = await wasm_tester(`./circuits/${dir}/${circuitName}.circom`, {
include: 'node_modules', // will link circomlib circuits
});
return new WasmTester<IN, OUT>(circomWasmTester);
}
}