diff --git a/tests/float_add.test.ts b/tests/float_add.test.ts index 714b2d2..65f11cf 100644 --- a/tests/float_add.test.ts +++ b/tests/float_add.test.ts @@ -3,6 +3,8 @@ import {createWasmTester} from '../utils/wasmTester'; // tests adapted from https://github.com/rdi-berkeley/zkp-mooc-lab describe('float_add 32-bit', () => { + const k = 8; + const p = 23; let circuit: Awaited>; before(async () => { @@ -10,7 +12,7 @@ describe('float_add 32-bit', () => { file: 'float_add', template: 'FloatAdd', publicInputs: [], - templateParams: [8, 23], + templateParams: [k, p], }); circuit = await createWasmTester('fp32', 'test'); await circuit.printConstraintCount(401); @@ -106,6 +108,8 @@ describe('float_add 32-bit', () => { }); describe('float_add 64-bit', () => { + const k = 11; + const p = 52; let circuit: Awaited>; before(async () => { @@ -113,7 +117,7 @@ describe('float_add 64-bit', () => { file: 'float_add', template: 'FloatAdd', publicInputs: [], - templateParams: [11, 52], + templateParams: [k, p], }); circuit = await createWasmTester('fp64', 'test'); await circuit.printConstraintCount(819); @@ -201,7 +205,7 @@ describe('float_add utilities', () => { let circuit: Awaited>; before(async () => { - instantiate(circuitName, 'float_add/test', { + instantiate(circuitName, 'test/float_add', { file: 'float_add', template: 'CheckBitLength', publicInputs: [], @@ -230,5 +234,224 @@ describe('float_add utilities', () => { }); }); - describe('right shift', () => {}); + describe('left shift', () => { + const shift_bound = 25; + const circuitName = 'shl_' + shift_bound; + let circuit: Awaited>; + + before(async () => { + instantiate(circuitName, 'test/float_add', { + file: 'float_add', + template: 'LeftShift', + publicInputs: [], + templateParams: [shift_bound], + }); + circuit = await createWasmTester(circuitName, 'test/float_add'); + await circuit.printConstraintCount(shift_bound + 2); + }); + + it("should pass test 1 - don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + x: '65', + shift: '24', + skip_checks: '0', + }, + {y: '1090519040'} + ); + }); + + it("should pass test 2 - don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + x: '65', + shift: '0', + skip_checks: '0', + }, + {y: '65'} + ); + }); + + it("should fail - don't skip checks", async () => { + await circuit.expectFailedAssert({ + x: '65', + shift: '25', + skip_checks: '0', + }); + }); + + it('should pass when `skip_checks` = 1 and `shift` is >= shift_bound', async () => { + await circuit.expectCorrectAssert({ + x: '65', + shift: '25', + skip_checks: '1', + }); + }); + }); + + describe('right shift', () => { + const b = 49; + const shift = 24; + const circuitName = 'shr_' + b; + let circuit: Awaited>; + + before(async () => { + instantiate(circuitName, 'test/float_add', { + file: 'float_add', + template: 'RightShift', + publicInputs: [], + templateParams: [b, shift], + }); + circuit = await createWasmTester(circuitName, 'test/float_add'); + await circuit.printConstraintCount(b); + }); + + it('should pass - small bitwidth', async () => { + instantiate(circuitName, 'test/float_add', { + file: 'float_add', + template: 'RightShift', + publicInputs: [], + templateParams: [b, shift], + }); + circuit = await createWasmTester(circuitName, 'test/float_add'); + await circuit.printConstraintCount(b); + + await circuit.expectCorrectAssert( + { + x: '82263136010365', + }, + {y: '4903265'} + ); + }); + + it('should fail - large bitwidth', async () => { + await circuit.expectFailedAssert({ + x: '15087340228765024367', + }); + }); + }); + + describe('normalize', () => { + // var circ_file = path.join(__dirname, 'circuits', 'normalize.circom'); + // var circ_file_msnzb = path.join(__dirname, 'circuits', 'msnzb.circom'); + // var circ, num_constraints; + const k = 8; + const p = 23; + const P = 47; + let circuit: Awaited>; + + // before(async () => { + // circ = await wasm_tester(circ_file); + // await circ.loadConstraints(); + // num_constraints = circ.constraints.length; + + // console.log('Normalize #Constraints:', num_constraints, 'Expected:', 3 * (P + 1)); + + // circ_msnzb = await wasm_tester(circ_file_msnzb); + // await circ_msnzb.loadConstraints(); + // num_constraints_msnzb = circ_msnzb.constraints.length; + // if (num_constraints < num_constraints_msnzb + 1) { + // console.log( + // 'WARNING: the #constraints is less than (#constraints for MSNZB + 1). It is likely that you are not constraining the witnesses appropriately.' + // ); + // } + // }); + + it("should pass - don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + e: '100', + m: '20565784002591', + skip_checks: '0', + }, + {e_out: '121', m_out: '164526272020728'} + ); + }); + + it("should pass - already normalized and don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + e: '100', + m: '164526272020728', + skip_checks: '0', + }, + {e_out: '124', m_out: '164526272020728'} + ); + }); + + it("should fail when `m` = 0 - don't skip checks", async () => { + await circuit.expectFailedAssert({ + e: '100', + m: '0', + skip_checks: '0', + }); + }); + + it('should pass when `skip_checks` = 1 and `m` is 0', async () => { + await circuit.expectCorrectAssert({ + e: '100', + m: '0', + skip_checks: '1', + }); + }); + }); + + describe('msnzb', () => { + const b = 48; + let circuit: Awaited>; + + // before(async () => { + // circ = await wasm_tester(circ_file); + // await circ.loadConstraints(); + // num_constraints = circ.constraints.length; + // var b = 48; + // var expected_constraints = 3 * b - 1; + // console.log('MSNZB #Constraints:', num_constraints, 'Expected:', expected_constraints); + // if (num_constraints < expected_constraints) { + // console.log( + // 'WARNING: number of constraints is less than 3b-1. It is likely that you are not constraining the witnesses appropriately.' + // ); + // } + // }); + + it("should pass test 1 - don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + in: '1', + skip_checks: '0', + }, + { + // prettier-ignore + one_hot: ["1", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0"], + } + ); + }); + + it("should pass test 2 - don't skip checks", async () => { + await circuit.expectCorrectAssert( + { + in: '281474976710655', + skip_checks: '0', + }, + { + // prettier-ignore + one_hot: ["0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "1"], + } + ); + }); + + it("should fail when `in` = 0 - don't skip checks", async () => { + await circuit.expectFailedAssert({ + in: '0', + skip_checks: '0', + }); + }); + + it('should pass when `skip_checks` = 1 and `in` is 0', async () => { + await circuit.expectCorrectAssert({ + in: '0', + skip_checks: '1', + }); + }); + }); }); diff --git a/utils/instantiate.ts b/utils/instantiate.ts index f41d9a9..581328b 100644 --- a/utils/instantiate.ts +++ b/utils/instantiate.ts @@ -16,9 +16,13 @@ export function instantiate(name: string, directory: string, circuitConfig?: Cir // generate the main component code const ejsPath = './circuits/ejs/template.circom'; + // add "../" to the filename in include, one for each "/" in directory name + // if none, the prefix becomes empty string + const filePrefix = '../'.repeat((directory.match(/\//g) || []).length); + let circuit = ejs.render(readFileSync(ejsPath).toString(), { ...circuitConfig, - file: circuitConfig.file, // TODO: add ../'s based on dir + file: filePrefix + circuitConfig.file, // TODO: add ../'s based on dir }); // output to file @@ -30,7 +34,7 @@ export function instantiate(name: string, directory: string, circuitConfig?: Cir } const targetPath = `${targetDir}/${name}.circom`; writeFileSync(targetPath, circuit); - console.log(`Main component created at: ${targetPath}\n`); + // console.log(`Main component created at: ${targetPath}\n`); } export function clearTestInstance(name: string, directory: string) { diff --git a/utils/wasmTester.ts b/utils/wasmTester.ts index 0ea03ef..ce74925 100644 --- a/utils/wasmTester.ts +++ b/utils/wasmTester.ts @@ -104,14 +104,19 @@ class WasmTester { await this.loadConstraints(); } const numConstraints = this.constraints!.length; + + // if expecting a specific number, check if you match that let expectionMessage = ''; if (expected !== undefined) { let alertType = ''; if (numConstraints < expected) { + // need more alertType = '🔴'; } else if (numConstraints > expected) { + // too many alertType = '🟡'; } else { + // on point alertType = '🟢'; } expectionMessage = ` (${alertType} expected ${expected})`; @@ -145,11 +150,10 @@ class WasmTester { } /** - * Compiles and reutrns a circuit via `circom_tester`'s `wasm_tester`. + * Compiles and reutrns a circuit tester class instance. * @param circuit name of circuit * @param dir directory to read the circuit from, defaults to `main` - * @param showNumConstraints print number of constraints, defualts to `false` - * @returns a `wasm_tester` object + * @returns a `WasmTester` instance */ export async function createWasmTester(circuitName: string, dir: string = 'main'): Promise { const circomWasmTester: CircomWasmTester = await wasm_tester(`./circuits/${dir}/${circuitName}.circom`, {