sudoku rfks

This commit is contained in:
Erhan Tezcan
2023-04-01 12:50:27 +03:00
parent ddb2e296fa
commit 5114061283
9 changed files with 140 additions and 218 deletions

View File

@@ -1,5 +1,5 @@
# compiler args
CLIENV_COMPILER_ARGS="-l ./node_modules"
CLIENV_COMPILER_ARGS="-l ./node_modules --r1cs --wasm --sym --inspect"
# colors for swag
CLIENV_COLOR_TITLE='\033[0;34m' # blue

View File

@@ -1,150 +1,34 @@
pragma circom 2.0.0;
// code from ZKP MOOC 2023 lab
// this code is taken from ZKP MOOC 2023 lab, and then modified
/////////////////////////////////////////////////////////////////////////////////////
/////////////////////// Templates from the circomlib ////////////////////////////////
////////////////// Copy-pasted here for easy reference //////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////
include "circomlib/circuits/comparators.sol";
include "circomlib/circuits/switcher.sol";
include "circomlib/circuits/gates.sol";
include "circomlib/circuits/bitify.sol";
include "./math/bits.sol";
/*
* Outputs `a` AND `b`
* Basically `out = cond ? ifTrue : ifFalse`
*/
template AND() {
signal input a;
signal input b;
signal output out;
out <== a*b;
}
/*
* Outputs `a` OR `b`
*/
template OR() {
signal input a;
signal input b;
signal output out;
out <== a + b - a*b;
}
/*
* `out` = `cond` ? `L` : `R`
*/
template IfThenElse() {
template IfElse() {
signal input cond;
signal input L;
signal input R;
signal input ifTrue;
signal input ifFalse;
signal output out;
out <== cond * (L - R) + R;
// cond * T - cond * F + F
// 0 * T - 0 * F + F = 0 - 0 + F = F
// 1 * T - 1 * F + F = T - F + F = T
out <== cond * (ifTrue - ifFalse) + ifFalse;
}
/*
* (`outL`, `outR`) = `sel` ? (`R`, `L`) : (`L`, `R`)
*/
template Switcher() {
signal input sel;
signal input L;
signal input R;
signal output outL;
signal output outR;
signal aux;
aux <== (R-L)*sel;
outL <== aux + L;
outR <== -aux + R;
}
/*
* Decomposes `in` into `b` bits, given by `bits`.
* Least significant bit in `bits[0]`.
* Enforces that `in` is at most `b` bits long.
*/
template Num2Bits(b) {
signal input in;
signal output bits[b];
for (var i = 0; i < b; i++) {
bits[i] <-- (in >> i) & 1;
bits[i] * (1 - bits[i]) === 0;
}
var sum_of_bits = 0;
for (var i = 0; i < b; i++) {
sum_of_bits += (2 ** i) * bits[i];
}
sum_of_bits === in;
}
/*
* Reconstructs `out` from `b` bits, given by `bits`.
* Least significant bit in `bits[0]`.
*/
template Bits2Num(b) {
signal input bits[b];
signal output out;
var lc = 0;
for (var i = 0; i < b; i++) {
lc += (bits[i] * (1 << i));
}
out <== lc;
}
/*
* Checks if `in` is zero and returns the output in `out`.
*/
template IsZero() {
signal input in;
signal output out;
signal inv;
inv <-- in!=0 ? 1/in : 0;
out <== -in*inv +1;
in*out === 0;
}
/*
* Checks if `in[0]` == `in[1]` and returns the output in `out`.
*/
template IsEqual() {
signal input in[2];
signal output out;
component isz = IsZero();
in[1] - in[0] ==> isz.in;
isz.out ==> out;
}
/*
* Checks if `in[0]` < `in[1]` and returns the output in `out`.
*/
template LessThan(n) {
assert(n <= 252);
signal input in[2];
signal output out;
component n2b = Num2Bits(n+1);
n2b.in <== in[0] + (1<<n) - in[1];
out <== 1-n2b.bits[n];
}
/////////////////////////////////////////////////////////////////////////////////////
///////////////////////// Templates for this lab ////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////
/*
* Outputs `out` = 1 if `in` is at most `b` bits long, and 0 otherwise.
* ✅ this is in line with what I have done!
*/
template CheckBitLength(b) {
assert(b < 254); // ⚠️ added this
assert(b < 254);
signal input in;
signal output out;
@@ -192,11 +76,11 @@ template CheckWellFormedness(k, p) {
check_m_bits.in <== m - (1 << p);
// choose the right checks based on `is_e_zero`
component if_else = IfThenElse();
component if_else = IfElse();
if_else.cond <== is_e_zero.out;
if_else.L <== is_m_zero.out;
if_else.ifTrue <== is_m_zero.out;
//// check_m_bits.out * check_e_bits.out is equivalent to check_m_bits.out AND check_e_bits.out
if_else.R <== check_m_bits.out * check_e_bits.out;
if_else.ifFalse <== check_m_bits.out * check_e_bits.out;
// assert that those checks passed
if_else.out === 1;
@@ -206,7 +90,7 @@ template CheckWellFormedness(k, p) {
* Right-shifts `x` by `shift` bits to output `y`, where `shift` is a public circuit parameter.
*/
template RightShift(b, shift) {
assert(shift < b); // ⚠️ number is b bits long, and shift must be less than that
assert(shift < b);
signal input x;
signal output y;
@@ -261,13 +145,13 @@ template RoundAndCheck(k, p, P) {
// select right output based on no_overflow
component if_else[2];
for (var i = 0; i < 2; i++) {
if_else[i] = IfThenElse();
if_else[i] = IfElse();
if_else[i].cond <== no_overflow;
}
if_else[0].L <== e_out_1;
if_else[0].R <== e_out_2;
if_else[1].L <== m_out_1;
if_else[1].R <== m_out_2;
if_else[0].ifTrue <== e_out_1;
if_else[0].ifFalse <== e_out_2;
if_else[1].ifTrue <== m_out_1;
if_else[1].ifFalse <== m_out_2;
e_out <== if_else[0].out;
m_out <== if_else[1].out;
}
@@ -289,6 +173,7 @@ template Num2BitsWithSkipChecks(b) {
// is always true if skip_checks is 1
(sum_of_bits - in) * (1 - skip_checks) === 0;
}
template LessThanWithSkipChecks(n) {
assert(n <= 252);
signal input in[2];
@@ -301,15 +186,6 @@ template LessThanWithSkipChecks(n) {
out <== 1-n2b.bits[n];
}
// a rough log2 function
function log2(a) {
var n = 1, r = 1;
while (n < a) {
r++;
n *= 2;
}
return r;
}
/*
* Left-shifts `x` by `shift` bits to output `y`.
@@ -324,7 +200,7 @@ template LeftShift(shift_bound) {
signal output y;
// find number of bits in shift_bound
var n = log2(shift_bound);
var n = numOfBits(shift_bound);
// convert "shift" to bits
component shift_bits = Num2BitsWithSkipChecks(n);
@@ -343,18 +219,18 @@ template LeftShift(shift_bound) {
var pow2_shift = 1;
component muxes[n];
for (var i = 0; i < n; i++) {
muxes[i] = IfThenElse();
muxes[i] = IfElse();
muxes[i].cond <== shift_bits.bits[i];
muxes[i].L <== pow2_shift * (2 ** (2 ** i));
muxes[i].R <== pow2_shift;
muxes[i].ifTrue <== pow2_shift * (2 ** (2 ** i));
muxes[i].ifFalse <== pow2_shift;
pow2_shift = muxes[i].out;
}
// if skip checks, set pow2_shift to 0
component if_else = IfThenElse();
component if_else = IfElse();
if_else.cond <== skip_checks;
if_else.L <== 0;
if_else.R <== pow2_shift;
if_else.ifTrue <== 0;
if_else.ifFalse <== pow2_shift;
pow2_shift = if_else.out; // not <== because it's a variable
// do the shift
@@ -522,10 +398,10 @@ template FloatAdd(k, p) {
// return
component if_else[2];
for (var i = 0; i < 2; i++) {
if_else[i] = IfThenElse();
if_else[i] = IfElse();
if_else[i].cond <== is_case_1;
if_else[i].L <== case_1_output[i];
if_else[i].R <== case_2_output[i];
if_else[i].ifTrue <== case_1_output[i];
if_else[i].ifFalse <== case_2_output[i];
}
e_out <== if_else[0].out;
m_out <== if_else[1].out;

View File

@@ -0,0 +1,10 @@
pragma circom 2.0.0;
include "./math.circom";
/*
* Get number of bits
*/
function numOfBits(n) {
return log2(n) + 1;
}

View File

@@ -0,0 +1,13 @@
pragma circom 2.0.0;
/*
* Computes Math.floor(log2(n))
*/
function log2(n) {
var tmp = 1, ans = 1;
while (tmp < n) {
ans++;
tmp *= 2;
}
return ans;
}

View File

@@ -1,14 +1,15 @@
pragma circom 2.0.0;
// include "circomlib/circuits/" TODO TODO
include "circomlib/circuits/bitify.circom";
include "./functions/bits.circom";
// Assert that two elements are not equal.
// Done via the check if in0 - in1 is non-zero.
// Assert that two elements are not equal
template NonEqual() {
signal input in[2];
signal inv;
// 1/0 results in 0, so the constraint wont hold
// we check if (in[0] - in[1] != 0)
// because 1/0 results in 0, so the constraint won't hold
inv <-- 1 / (in[1] - in[0]);
inv * (in[1] - in[0]) === 1;
}
@@ -16,7 +17,7 @@ template NonEqual() {
// Assert that all given values are unique
template Distinct(n) {
signal input in[n];
component nonEqual[n][n];
component nonEqual[n][n]; // TODO; has extra comps here
for(var i = 0; i < n; i++){
for(var j = 0; j < i; j++){
nonEqual[i][j] = NonEqual();
@@ -25,56 +26,78 @@ template Distinct(n) {
}
}
// Enforce that 0 <= in <= 15
template Bits4() {
// Checks that `in` is in range [MIN, MAX]
template InRange(MIN, MAX) {
signal input in;
signal bits[4];
var bitsum = 0;
for (var i = 0; i < 4; i++) {
bits[i] <-- (in >> i) & 1;
bits[i] * (bits[i] - 1) === 0;
bitsum = bitsum + 2 ** i * bits[i];
}
bitsum === in;
}
// Check if a given signal is in range [1, 9]
// This is true if 1-1 is
template OneToNine() {
signal input in;
component lowerBound = Bits4();
component upperBound = Bits4();
lowerBound.in <== in - 1; // 1 - 1 = 0 (for 0 <= in)
upperBound.in <== in + 6; // 9 + 6 = 15 (for in <= 15)
var b = numOfBits(MAX);
component lowerBound = Num2Bits(b);
component upperBound = Num2Bits(b);
lowerBound.in <== in - MIN; // e.g. 1 - 1 = 0 (for 0 <= in)
upperBound.in <== in + (2 ** b) - MAX - 1; // e.g. 9 + 6 = 15 (for in <= 15)
}
template Sudoku(n_sqrt) {
var n = n_sqrt * n_sqrt;
signal input solution[n][n]; // solution is a 2D array: indices are (row_i, col_i)
signal input solution[n][n]; // solution is a 2D array of numbers
signal input puzzle[n][n]; // puzzle is the same, but a zero indicates a blank
// make sure the solution agrees with the puzzle
// meaning that non-empty cells of the puzzle should equal to the corresponding solution value
// other values of the puzzle must be 0 (empty cell)
// ensure that solution & puzzle agrees
for (var row_i = 0; row_i < n; row_i++) {
for (var col_i = 0; col_i < n; col_i++) {
// puzzle is either empty (0), or the same as solution
puzzle[row_i][col_i] * (puzzle[row_i][col_i] - solution[row_i][col_i]) === 0;
}
}
// ensure all values in the solution are distinct and in range
// ensure all values in the solution are in range
component inRange[n][n];
component distinct[n];
for (var row_i = 0; row_i < n; row_i++) {
for (var col_i = 0; col_i < n; col_i++) {
inRange[row_i][col_i] = InRange(1, n);
inRange[row_i][col_i].in <== solution[row_i][col_i];
}
}
// ensure all values in the solution are distinct
component distinctRows[n];
for (var row_i = 0; row_i < n; row_i++) {
for (var col_i = 0; col_i < n; col_i++) {
if (row_i == 0) {
distinct[col_i] = Distinct(n);
distinctRows[col_i] = Distinct(n);
}
inRange[row_i][col_i] = OneToNine();
inRange[row_i][col_i].in <== solution[row_i][col_i];
distinct[col_i].in[row_i] <== solution[row_i][col_i];
distinctRows[col_i].in[row_i] <== solution[row_i][col_i];
}
}
component distinctCols[n];
for (var col_i = 0; col_i < n; col_i++) {
for (var row_i = 0; row_i < n; row_i++) {
if (col_i == 0) {
distinctCols[row_i] = Distinct(n);
}
distinctCols[row_i].in[col_i] <== solution[col_i][row_i];
}
}
// ensure that all values in squares are distinct
component distinctSquares[n];
for (var sr_i = 0; sr_i < n_sqrt; sr_i++) {
for (var sc_i = 0; sc_i < n_sqrt; sc_i++) {
// square index
var idx = sr_i * n_sqrt + sc_i;
distinctSquares[idx] = Distinct(n);
// (r, c) now marks the start of this square
var r = sr_i * n_sqrt;
var c = sc_i * n_sqrt;
var i = 0;
for (var row_i = r; row_i < r + n_sqrt; row_i++) {
for (var col_i = c; col_i < c + n_sqrt; col_i++) {
distinctSquares[idx].in[i] <== solution[row_i][col_i];
i++;
}
}
}
}
}

View File

@@ -17,6 +17,7 @@
"author": "erhant",
"devDependencies": {
"@types/chai": "^4.3.4",
"@types/chai-as-promised": "^7.1.5",
"@types/mocha": "^10.0.1",
"@types/node": "^18.11.18",
"chai": "^4.3.7",

View File

@@ -12,7 +12,8 @@ compile() {
mkdir -p $CIRCOM_OUT
# compile with circom
circom $CIRCOM_IN -o $CIRCOM_OUT --r1cs --wasm --sym
echo "circom $CIRCOM_IN -o $CIRCOM_OUT $CLIENV_COMPILER_ARGS"
circom $CIRCOM_IN -o $CIRCOM_OUT $CLIENV_COMPILER_ARGS
echo -e "${CLIENV_COLOR_LOG}Built artifacts under $CIRCOM_OUT${CLIENV_COLOR_RESET}"
}

View File

@@ -17,7 +17,6 @@ describe(CIRCUIT_NAME, () => {
});
it('should compute correctly', async () => {
// compute witness
const witness = await circuit.calculateWitness(INPUT, true);
// witness should have valid constraints
@@ -31,25 +30,17 @@ describe(CIRCUIT_NAME, () => {
});
it('should NOT compute with wrong number of inputs', async () => {
try {
await circuit.calculateWitness(
{
in: INPUT.in.slice(1), // fewer inputs
},
true
);
assert.fail('expected to fail on fewer inputs');
} catch (err) {}
const fewInputs = INPUT.in.slice(1);
await circuit.calculateWitness({in: fewInputs}, true).then(
() => assert.fail('expected to fail on fewer inputs'),
err => expect(err.message).to.eq('Not enough values for input signal in\n')
);
try {
await circuit.calculateWitness(
{
in: [2n, ...INPUT.in], // more inputs
},
true
);
assert.fail('expected to fail on too many inputs');
} catch (err) {}
const manyInputs = [2n, ...INPUT.in];
await circuit.calculateWitness({in: manyInputs}, true).then(
() => assert.fail('expected to fail on too many inputs'),
err => expect(err.message).to.eq('Too many values for input signal in\n')
);
});
});

View File

@@ -480,7 +480,14 @@
resolved "https://registry.yarnpkg.com/@tsconfig/node16/-/node16-1.0.3.tgz#472eaab5f15c1ffdd7f8628bd4c4f753995ec79e"
integrity sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==
"@types/chai@^4.3.4":
"@types/chai-as-promised@^7.1.5":
version "7.1.5"
resolved "https://registry.yarnpkg.com/@types/chai-as-promised/-/chai-as-promised-7.1.5.tgz#6e016811f6c7a64f2eed823191c3a6955094e255"
integrity sha512-jStwss93SITGBwt/niYrkf2C+/1KTeZCZl1LaeezTlqppAKeoQC7jxyqYuP72sxBGKCIbw7oHgbYssIRzT5FCQ==
dependencies:
"@types/chai" "*"
"@types/chai@*", "@types/chai@^4.3.4":
version "4.3.4"
resolved "https://registry.yarnpkg.com/@types/chai/-/chai-4.3.4.tgz#e913e8175db8307d78b4e8fa690408ba6b65dee4"
integrity sha512-KnRanxnpfpjUTqTCXslZSEdLfXExwgNxYPdiO2WGUj8+HDjFi8R3k5RVKPeSCzLjCcshCAtVO2QBbVuAV4kTnw==