sudok & fp-add mostly done

This commit is contained in:
Erhan Tezcan
2023-04-02 17:47:15 +03:00
parent e5b68f122c
commit 40e5c034ed
13 changed files with 445 additions and 173 deletions

View File

@@ -2,6 +2,13 @@
> An opinionated Circom circuit development environment.
You can develop & test Circom circuits with ease using this repository. We have several example circuits to help guide you:
- **Multiplier**: Proves that you know the factors of a number.
- **Floating Point Addition**: A floating-point addition circuit, as written in [Berkeley ZKP MOOC 2023- Lab 1](https://github.com/rdi-berkeley/zkp-mooc-lab).
- **Fibonacci**: Calculate N'th Fibonacci number, has both recursive & iterative implementations.
- **Sudoku**: Prove that you know the solution to a sudoku puzzle where the board size is a perfect square.
## Usage
Clone the repository or create a new one with this as the template! You need [Circom](https://docs.circom.io/getting-started/installation/) to compile circuits. Other than that, just `yarn` or `npm install` to get started. It will also install [Circomlib](https://github.com/iden3/circomlib/tree/master/circuits) which has many utility circuits.

View File

@@ -1,9 +1,9 @@
/**
* @type {import("./circuit.config").Config}
* @type {import("./types/circuit").Config}
*/
const config = {
// multiplication of 3 numbers
multiplier3: {
multiplier_3: {
file: 'multiplier',
template: 'Multiplier',
publicInputs: [],
@@ -16,6 +16,13 @@ const config = {
publicInputs: ['puzzle'],
templateParams: [Math.sqrt(9)],
},
// A 4x4 sudoku board
sudoku_4x4: {
file: 'sudoku',
template: 'Sudoku',
publicInputs: ['puzzle'],
templateParams: [Math.sqrt(4)],
},
// 64-bit floating point, 11-bit exponent and 52-bit mantissa
fp64: {
file: 'float_add',

32
circuit.config.d.ts vendored
View File

@@ -1,32 +0,0 @@
/**
* Configuration file for your circuits.
*/
export type Config = {
[circuitName: string]: {
/**
* File to read the template from
*/
file: string;
/**
* The template name to instantiate
*/
template: string;
/**
* An array of public input signal names
*/
publicInputs: string[];
/**
* An array of template parameters
*/
templateParams: (number | bigint)[];
/**
* Directory to output under `circuits`, defaults to `main`
* @depracated work in progress, use `main` for now (leave empty)
*/
dir?: string;
};
};

View File

@@ -101,12 +101,12 @@ template RightShift(b, shift) {
// do the shifting
signal y_bits[b-shift];
for (var i = 0; i < b-shift; i++) {
y_bits[i] <== x_bits.bits[shift+i];
y_bits[i] <== x_bits.out[shift+i];
}
// convert shifted bits to number
component y_num = Bits2Num(b-shift);
y_num.bits <== y_bits;
y_num.in <== y_bits;
y <== y_num.out;
}
@@ -159,15 +159,15 @@ template RoundAndCheck(k, p, P) {
template Num2BitsWithSkipChecks(b) {
signal input in;
signal input skip_checks;
signal output bits[b];
signal output out[b];
for (var i = 0; i < b; i++) {
bits[i] <-- (in >> i) & 1;
bits[i] * (1 - bits[i]) === 0;
out[i] <-- (in >> i) & 1;
out[i] * (1 - out[i]) === 0;
}
var sum_of_bits = 0;
for (var i = 0; i < b; i++) {
sum_of_bits += (2 ** i) * bits[i];
sum_of_bits += (2 ** i) * out[i];
}
// is always true if skip_checks is 1
@@ -182,9 +182,8 @@ template LessThanWithSkipChecks(n) {
component n2b = Num2BitsWithSkipChecks(n+1);
n2b.in <== in[0] + (1<<n) - in[1];
n2b.skip_checks <== skip_checks;
out <== 1-n2b.bits[n];
n2b.skip_checks <== skip_checks;
out <== 1-n2b.out[n];
}
/*
@@ -220,7 +219,7 @@ template LeftShift(shift_bound) {
component muxes[n];
for (var i = 0; i < n; i++) {
muxes[i] = IfElse();
muxes[i].cond <== shift_bits.bits[i];
muxes[i].cond <== shift_bits.out[i];
muxes[i].ifTrue <== pow2_shift * (2 ** (2 ** i));
muxes[i].ifFalse <== pow2_shift;
pow2_shift = muxes[i].out;

View File

@@ -1,6 +0,0 @@
// auto-generated by instantiate.js
pragma circom 2.0.0;
include "../multiplier.circom";
component main = Multiplier(3);

View File

@@ -0,0 +1,6 @@
// auto-generated by instantiate.js
pragma circom 2.0.0;
include "../sudoku.circom";
component main {public[puzzle]} = Sudoku(2);

View File

@@ -3,7 +3,19 @@ pragma circom 2.0.0;
include "circomlib/circuits/bitify.circom";
include "functions/bits.circom";
// Ensures that number is representable by b-bits
// Assert that two elements are not equal
template NonEqual() {
signal input in[2];
signal output out;
signal inv;
// we check if (in[0] - in[1] != 0)
// because 1/0 results in 0, so the constraint won't hold
inv <-- 1 / (in[1] - in[0]);
inv * (in[1] - in[0]) === 1;
}
// Assert that number is representable by b-bits
template CheckBitLength(b) {
assert(b < 254);
signal input in;
@@ -21,17 +33,16 @@ template CheckBitLength(b) {
in === sum_of_bits;
}
// Assert that two elements are not equal
template NonEqual() {
signal input in[2];
signal output out;
signal inv;
// we check if (in[0] - in[1] != 0)
// because 1/0 results in 0, so the constraint won't hold
inv <-- 1 / (in[1] - in[0]);
0 * inv === 0; // silence error
out <== inv * (in[1] - in[0]);
// Checks that `in` is in range [MIN, MAX]
template InRange(MIN, MAX) {
assert(MIN < MAX);
signal input in;
var b = numOfBits(MAX);
component lowerBound = CheckBitLength(b);
component upperBound = CheckBitLength(b);
lowerBound.in <== in - MIN; // e.g. 1 - 1 = 0 (for 0 <= in)
upperBound.in <== in + (2 ** b) - MAX - 1; // e.g. 9 + (15 - 9) = 15 (for in <= 15)
}
// Assert that all given values are unique
@@ -42,22 +53,10 @@ template Distinct(n) {
for(var j = 0; j < i; j++){
nonEqual[i][j] = NonEqual();
nonEqual[i][j].in <== [in[i], in[j]];
nonEqual[i][j].out === 1;
}
}
}
// Checks that `in` is in range [MIN, MAX]
template InRange(MIN, MAX) {
signal input in;
var b = numOfBits(MAX);
component lowerBound = CheckBitLength(b);
component upperBound = CheckBitLength(b);
lowerBound.in <== in - MIN; // e.g. 1 - 1 = 0 (for 0 <= in)
upperBound.in <== in + (2 ** b) - MAX - 1; // e.g. 9 + 6 = 15 (for in <= 15)
}
template Sudoku(n_sqrt) {
var n = n_sqrt * n_sqrt;
signal input solution[n][n]; // solution is a 2D array of numbers
@@ -96,27 +95,30 @@ template Sudoku(n_sqrt) {
if (col_i == 0) {
distinctCols[row_i] = Distinct(n);
}
distinctCols[row_i].in[col_i] <== solution[col_i][row_i];
distinctCols[row_i].in[col_i] <== solution[row_i][col_i];
}
}
// ensure that all values in squares are distinct
component distinctSquares[n];
var s_i = 0;
for (var sr_i = 0; sr_i < n_sqrt; sr_i++) {
for (var sc_i = 0; sc_i < n_sqrt; sc_i++) {
// square index
var idx = sr_i * n_sqrt + sc_i;
distinctSquares[idx] = Distinct(n);
distinctSquares[s_i] = Distinct(n);
// (r, c) now marks the start of this square
var r = sr_i * n_sqrt;
var c = sc_i * n_sqrt;
var i = 0;
for (var row_i = r; row_i < r + n_sqrt; row_i++) {
for (var col_i = c; col_i < c + n_sqrt; col_i++) {
distinctSquares[idx].in[i] <== solution[row_i][col_i];
distinctSquares[s_i].in[i] <== solution[row_i][col_i];
i++;
}
}
s_i++;
}
}

View File

@@ -1,22 +1,7 @@
import {createWasmTester} from '../utils/wasmTester';
import type {CircuitSignals} from '../types/circuit';
// simple fibonacci with 2 variables
function fibonacci(init: [number, number], n: number): number {
if (n < 0) {
throw new Error('N must be positive');
}
let [a, b] = init;
for (let i = 2; i <= n; i++) {
b = a + b;
a = b - a;
}
return n == 0 ? a : b;
}
const CIRCUIT_NAME = 'fibonacci_11';
describe(CIRCUIT_NAME, () => {
describe('fibonacci_11', () => {
const INPUT: CircuitSignals = {
in: [1, 1],
};
@@ -24,7 +9,7 @@ describe(CIRCUIT_NAME, () => {
let circuit: Awaited<ReturnType<typeof createWasmTester>>;
before(async () => {
circuit = await createWasmTester(CIRCUIT_NAME);
circuit = await createWasmTester('fibonacci_11');
});
it('should compute correctly', async () => {
@@ -41,3 +26,17 @@ describe(CIRCUIT_NAME, () => {
await circuit.assertOut(witness, output);
});
});
// simple fibonacci with 2 variables
function fibonacci(init: [number, number], n: number): number {
if (n < 0) {
throw new Error('N must be positive');
}
let [a, b] = init;
for (let i = 2; i <= n; i++) {
b = a + b;
a = b - a;
}
return n == 0 ? a : b;
}

View File

@@ -1,27 +1,245 @@
import {createWasmTester} from '../utils/wasmTester';
import {createWasmTester, printConstraintCount} from '../utils/wasmTester';
import {ProofTester} from '../utils/proofTester';
import type {CircuitSignals, FullProof} from '../types/circuit';
import {assert, expect} from 'chai';
// TODO: write tests
const CIRCUIT_NAME = 'cbl_32';
describe('utils', () => {
describe('fp32', () => {
let circuit: Awaited<ReturnType<typeof createWasmTester>>;
before(async () => {
circuit = await createWasmTester(CIRCUIT_NAME, 'test');
circuit = await createWasmTester('fp32');
await circuit.loadConstraints();
await printConstraintCount(circuit, 401);
});
it('should compute correctly', async () => {
it('case I test', async () => {
const witness = await circuit.calculateWitness(
{
in: 3,
e: ['43', '5'],
m: ['11672136', '10566265'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {
out: 1,
});
await circuit.assertOut(witness, {e_out: '43', m_out: '11672136'});
});
it('case II test 1', async () => {
const witness = await circuit.calculateWitness(
{
e: ['104', '106'],
m: ['12444445', '14159003'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {e_out: '107', m_out: '8635057'});
});
it('case II test 2', async () => {
const witness = await circuit.calculateWitness(
{
e: ['176', '152'],
m: ['16777215', '16777215'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {e_out: '177', m_out: '8388608'});
});
it('case II test 3', async () => {
const witness = await circuit.calculateWitness(
{
e: ['142', '142'],
m: ['13291872', '13291872'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {e_out: '143', m_out: '13291872'});
});
it('one input zero test', async () => {
const witness = await circuit.calculateWitness(
{
e: ['0', '43'],
m: ['0', '10566265'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {e_out: '43', m_out: '10566265'});
});
it('both inputs zero test', async () => {
const witness = await circuit.calculateWitness(
{
e: ['0', '0'],
m: ['0', '0'],
},
true
);
await circuit.checkConstraints(witness);
await circuit.assertOut(witness, {e_out: '0', m_out: '0'});
});
it('should fail - exponent zero but mantissa non-zero', async () => {
await circuit
.calculateWitness(
{
e: ['0', '0'],
m: ['0', '10566265'],
},
true
)
.then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should fail - mantissa >= 2^{p+1}', async () => {
await circuit
.calculateWitness(
{
e: ['0', '43'],
m: ['0', '16777216'],
},
true
)
.then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should fail - mantissa < 2^{p}', async () => {
await circuit
.calculateWitness(
{
e: ['0', '43'],
m: ['0', '6777216'],
},
true
)
.then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should fail - exponent >= 2^k', async () => {
await circuit
.calculateWitness(
{
e: ['0', '256'],
m: ['0', '10566265'],
},
true
)
.then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
});
// describe('FP64Add', () => {
// var circ_file = path.join(__dirname, 'circuits', 'fp64_add.circom');
// var circ, num_constraints;
// before(async () => {
// circ = await wasm_tester(circ_file);
// await circuit.loadConstraints();
// num_constraints = circuit.constraints.length;
// console.log('Float64 Add #Constraints:', num_constraints, 'Expected:', 819);
// });
// it('case I test', async () => {
// const input = {
// e: ['1122', '1024'],
// m: ['7807742059002284', '7045130465601185'],
// };
// const witness = await circuit.calculateWitness(input, 1);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '1122', m_out: '7807742059002284'});
// });
// it('case II test 1', async () => {
// const input = {
// e: ['1056', '1053'],
// m: ['8879495032259305', '5030141535601637'],
// };
// const witness = await circuit.calculateWitness(input);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '1057', m_out: '4754131362104755'});
// });
// it('case II test 2', async () => {
// const input = {
// e: ['1035', '982'],
// m: ['4804509148660890', '8505192799372177'],
// };
// const witness = await circuit.calculateWitness(input);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '1035', m_out: '4804509148660891'});
// });
// it('case II test 3', async () => {
// const input = {
// e: ['982', '982'],
// m: ['8505192799372177', '8505192799372177'],
// };
// const witness = await circuit.calculateWitness(input);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '983', m_out: '8505192799372177'});
// });
// it('one input zero test', async () => {
// const input = {
// e: ['0', '982'],
// m: ['0', '8505192799372177'],
// };
// const witness = await circuit.calculateWitness(input);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '982', m_out: '8505192799372177'});
// });
// it('both inputs zero test', async () => {
// const input = {
// e: ['0', '0'],
// m: ['0', '0'],
// };
// const witness = await circuit.calculateWitness(input);
// await circuit.checkConstraints(witness);
// await circuit.assertOut(witness, {e_out: '0', m_out: '0'});
// });
// it('should fail - exponent zero but mantissa non-zero', async () => {
// const input = {
// e: ['0', '0'],
// m: ['0', '8505192799372177'],
// };
// try {
// const witness = await circuit.calculateWitness(input);
// } catch (e) {
// return 0;
// }
// assert.fail('should have thrown an error');
// });
// it('should fail - mantissa < 2^{p}', async () => {
// const input = {
// e: ['0', '43'],
// m: ['0', '16777216'],
// };
// try {
// const witness = await circuit.calculateWitness(input);
// } catch (e) {
// return 0;
// }
// assert.fail('should have thrown an error');
// });
// });

View File

@@ -4,12 +4,11 @@ import {assert, expect} from 'chai';
// read inputs from file
import input80 from '../inputs/multiplier3/80.json';
const CIRCUIT_NAME = 'multiplier3';
describe(CIRCUIT_NAME + ' (proofs)', () => {
describe('multiplier3 (proofs)', () => {
const INPUT: CircuitSignals = input80;
let fullProof: FullProof;
const circuit = new ProofTester(CIRCUIT_NAME);
const circuit = new ProofTester('multiplier3');
before(async () => {
fullProof = await circuit.prove(INPUT);

View File

@@ -1,9 +1,8 @@
import {createWasmTester} from '../utils/wasmTester';
import {assert, expect} from 'chai';
const CIRCUIT_NAME = 'sudoku_9x9';
describe(CIRCUIT_NAME, () => {
const INPUT = {
const INPUTS = {
sudoku_9x9: {
solution: [
[1, 9, 4, 8, 6, 5, 2, 3, 7],
[7, 3, 5, 4, 1, 2, 9, 6, 8],
@@ -26,63 +25,90 @@ describe(CIRCUIT_NAME, () => {
[0, 4, 0, 1, 0, 9, 0, 8, 0],
[5, 0, 7, 0, 8, 0, 0, 9, 4],
],
};
},
sudoku_4x4: {
solution: [
[4, 1, 3, 2],
[3, 2, 4, 1],
[2, 4, 1, 3],
[1, 3, 2, 4],
],
puzzle: [
[0, 1, 0, 2],
[3, 2, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0],
],
},
};
let circuit: Awaited<ReturnType<typeof createWasmTester>>;
['sudoku_9x9', 'sudoku_4x4'].map(circuitName =>
describe(circuitName, () => {
// @ts-ignore
const INPUT = INPUTS[circuitName];
before(async () => {
circuit = await createWasmTester(CIRCUIT_NAME);
});
let circuit: Awaited<ReturnType<typeof createWasmTester>>;
it('should compute correctly', async () => {
// compute witness
const witness = await circuit.calculateWitness(INPUT, true);
before(async () => {
circuit = await createWasmTester(circuitName);
});
// witness should have valid constraints
await circuit.checkConstraints(witness);
});
it('should compute correctly', async () => {
// compute witness
const witness = await circuit.calculateWitness(INPUT, true);
it('should NOT accept non-distinct rows', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
// witness should have valid constraints
await circuit.checkConstraints(witness);
});
badInput.solution[0][0] = badInput.solution[0][1];
console.log(badInput.solution[0], badInput.solution[1]);
await circuit.calculateWitness(INPUT, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should NOT accept non-distinct rows', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
it('should NOT accept non-distinct columns', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
badInput.solution[0][0] = badInput.solution[0][1];
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
badInput.solution[0][0] = badInput.solution[1][0];
console.log(badInput.solution[0], badInput.solution[1]);
await circuit.calculateWitness(INPUT, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should NOT accept non-distinct columns', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
it('should NOT accept non-distinct square', async () => {
const badInput: typeof INPUT = JSON.parse(JSON.stringify(INPUT));
badInput.solution[0][0] = badInput.solution[1][0];
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
badInput.solution[0][0] = badInput.solution[1][1];
console.log(badInput.solution[0], badInput.solution[1]);
await circuit.calculateWitness(INPUT, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should NOT accept non-distinct square', async () => {
const badInput: typeof INPUT = JSON.parse(JSON.stringify(INPUT));
it('should NOT accept empty value in solution', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
badInput.solution[0][0] = badInput.solution[1][1];
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
badInput.solution[0][0] = 0;
console.log(badInput.solution[0], badInput.solution[1]);
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
});
it('should NOT accept empty value in solution', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
badInput.solution[0][0] = 0;
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
it('should NOT accept out-of-range values', async () => {
const badInput = JSON.parse(JSON.stringify(INPUT));
badInput.solution[0][0] = 99999;
await circuit.calculateWitness(badInput, true).then(
() => assert.fail(),
err => expect(err.message.slice(0, 21)).to.eq('Error: Assert Failed.')
);
});
})
);

View File

@@ -17,3 +17,37 @@ export type FullProof = {
proof: object;
publicSignals: string[];
};
/**
* Configuration file for your circuits.
* @see `circuit.config.cjs` in the project root.
*/
export type Config = {
[circuitName: string]: {
/**
* File to read the template from
*/
file: string;
/**
* The template name to instantiate
*/
template: string;
/**
* An array of public input signal names
*/
publicInputs: string[];
/**
* An array of template parameters
*/
templateParams: (number | bigint)[];
/**
* Directory to output under `circuits`, defaults to `main`
* @depracated work in progress, use `main` for now (leave empty)
*/
dir?: string;
};
};

View File

@@ -28,7 +28,7 @@ type WasmTester = {
/**
* Compute witness given the input signals.
* @param input all signals, private and public.
* @param sanityCheck ?
* @param sanityCheck check if input signals are sanitized
*/
calculateWitness: (input: CircuitSignals, sanityCheck: boolean) => Promise<WitnessType>;
@@ -71,19 +71,32 @@ type WasmTester = {
* @param showNumConstraints print number of constraints, defualts to `false`
* @returns a `wasm_tester` object
*/
export async function createWasmTester(
circuitName: string,
dir: string = 'main',
showNumConstraints: boolean = false
): Promise<WasmTester> {
const circuit = await wasm_tester(`./circuits/${dir}/${circuitName}.circom`, {
export async function createWasmTester(circuitName: string, dir: string = 'main'): Promise<WasmTester> {
return wasm_tester(`./circuits/${dir}/${circuitName}.circom`, {
include: 'node_modules', // will link circomlib circuits
});
if (showNumConstraints) {
await circuit.loadConstraints();
console.log(' number of constraints:', circuit.constraints!.length);
}
return circuit;
}
/**
* Prints the number of constraints of the circuit.
* If expected count is provided, will also include that in the log.
* @param circuit WasmTester circuit
* @param expected expected number of constraints
*/
export async function printConstraintCount(circuit: WasmTester, expected?: number) {
await circuit.loadConstraints();
const numConstraints = circuit.constraints!.length;
let expectionMessage = '';
if (expected !== undefined) {
let alertType = '';
if (numConstraints < expected) {
alertType = '🔴';
} else if (numConstraints > expected) {
alertType = '🟡';
} else {
alertType = '🟢';
}
expectionMessage = ` (${alertType} expected ${expected})`;
}
console.log(`#constraints: ${numConstraints}` + expectionMessage);
}