From ed0993a0a5bb2ebc8a4df5fd53cab69f58930c40 Mon Sep 17 00:00:00 2001 From: Daniel Tehrani Date: Thu, 2 Feb 2023 15:06:39 +0900 Subject: [PATCH] Fix verification: Add public input validation --- .vscode/settings.json | 1 + .../node/src/node.bench_addr_membership.ts | 2 +- .../node/src/node.bench_pubkey_membership.ts | 2 +- packages/benchmark/web/pages/index.tsx | 4 +- packages/lib/src/core/membership_prover.ts | 50 +++--- packages/lib/src/core/membership_verifier.ts | 27 ++- packages/lib/src/helpers/efficient_ecdsa.ts | 154 ++++++++++-------- packages/lib/src/types.ts | 10 +- packages/lib/tests/efficient_ecdsa.test.ts | 14 +- packages/lib/tests/membership_nizk.test.ts | 27 +-- 10 files changed, 157 insertions(+), 134 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3317b9e..2bff312 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "editor.formatOnSave": true, "cSpell.words": [ + "merkle", "NIZK" ] } \ No newline at end of file diff --git a/packages/benchmark/node/src/node.bench_addr_membership.ts b/packages/benchmark/node/src/node.bench_addr_membership.ts index 39d0963..6605282 100644 --- a/packages/benchmark/node/src/node.bench_addr_membership.ts +++ b/packages/benchmark/node/src/node.bench_addr_membership.ts @@ -78,7 +78,7 @@ const benchAddrMembership = async () => { await verifier.initWasm(); // Verify proof - await verifier.verify(proof, publicInput); + await verifier.verify(proof, publicInput.serialize()); }; export default benchAddrMembership; diff --git a/packages/benchmark/node/src/node.bench_pubkey_membership.ts b/packages/benchmark/node/src/node.bench_pubkey_membership.ts index a19609a..84f6a06 100644 --- a/packages/benchmark/node/src/node.bench_pubkey_membership.ts +++ b/packages/benchmark/node/src/node.bench_pubkey_membership.ts @@ -75,7 +75,7 @@ const benchPubKeyMembership = async () => { await verifier.initWasm(); // Verify proof - await verifier.verify(proof, publicInput); + await verifier.verify(proof, publicInput.serialize()); }; export default benchPubKeyMembership; diff --git a/packages/benchmark/web/pages/index.tsx b/packages/benchmark/web/pages/index.tsx index e6d2dd5..4cd46fb 100644 --- a/packages/benchmark/web/pages/index.tsx +++ b/packages/benchmark/web/pages/index.tsx @@ -80,7 +80,7 @@ export default function Home() { await verifier.initWasm(); console.time("Verification time"); - const result = await verifier.verify(proof, publicInput); + const result = await verifier.verify(proof, publicInput.serialize()); console.timeEnd("Verification time"); if (result) { @@ -152,7 +152,7 @@ export default function Home() { await verifier.initWasm(); console.time("Verification time"); - const result = await verifier.verify(proof, publicInput); + const result = await verifier.verify(proof, publicInput.serialize()); console.timeEnd("Verification time"); if (result) { diff --git a/packages/lib/src/core/membership_prover.ts b/packages/lib/src/core/membership_prover.ts index a3bab50..1a14348 100644 --- a/packages/lib/src/core/membership_prover.ts +++ b/packages/lib/src/core/membership_prover.ts @@ -1,14 +1,10 @@ import { Profiler } from "../helpers/profiler"; import { IProver, MerkleProof, NIZK, ProverConfig } from "../types"; +import { loadCircuit, fromSig, snarkJsWitnessGen } from "../helpers/utils"; import { - bigIntToBytes, - loadCircuit, - fromSig, - snarkJsWitnessGen -} from "../helpers/utils"; -import { - EffEcdsaPubInput, - EffEcdsaCircuitPubInput + PublicInput, + computeEffEcdsaPubInput, + CircuitPubInput } from "../helpers/efficient_ecdsa"; import wasm, { init } from "../wasm"; import { @@ -70,32 +66,20 @@ export class MembershipProver extends Profiler implements IProver { ): Promise { const { r, s, v } = fromSig(sig); - const circuitPubInput = EffEcdsaCircuitPubInput.computeFromSig( - r, - v, - msgHash + const effEcdsaPubInput = computeEffEcdsaPubInput(r, v, msgHash); + const circuitPubInput = new CircuitPubInput( + merkleProof.root, + effEcdsaPubInput.Tx, + effEcdsaPubInput.Ty, + effEcdsaPubInput.Ux, + effEcdsaPubInput.Uy ); - const effEcdsaPubInput = new EffEcdsaPubInput( - r, - v, - msgHash, - circuitPubInput - ); - - const merkleRootSer: Uint8Array = bigIntToBytes(merkleProof.root, 32); - const circuitPubInputSer = circuitPubInput.serialize(); - - // Concatenate circuitPubInputSer and merkleRootSer to construct the full public input - const pubInput = new Uint8Array( - merkleRootSer.length + circuitPubInputSer.length - ); - pubInput.set(merkleRootSer); - pubInput.set(circuitPubInputSer, merkleRootSer.length); + const publicInput = new PublicInput(r, v, msgHash, circuitPubInput); const witnessGenInput = { s, ...merkleProof, - ...effEcdsaPubInput.circuitPubInput + ...effEcdsaPubInput }; this.time("Generate witness"); @@ -109,13 +93,17 @@ export class MembershipProver extends Profiler implements IProver { const circuitBin = await loadCircuit(this.circuit); this.timeEnd("Load circuit"); + // Get the public input in bytes + const circuitPublicInput: Uint8Array = + publicInput.circuitPubInput.serialize(); + this.time("Prove"); - let proof = wasm.prove(circuitBin, witness.data, pubInput); + let proof = wasm.prove(circuitBin, witness.data, circuitPublicInput); this.timeEnd("Prove"); return { proof, - publicInput: pubInput + publicInput }; } } diff --git a/packages/lib/src/core/membership_verifier.ts b/packages/lib/src/core/membership_verifier.ts index 95378b0..abb3e24 100644 --- a/packages/lib/src/core/membership_verifier.ts +++ b/packages/lib/src/core/membership_verifier.ts @@ -6,6 +6,10 @@ import { Profiler } from "../helpers/profiler"; import { loadCircuit } from "../helpers/utils"; import { IVerifier, VerifyConfig } from "../types"; import wasm, { init } from "../wasm"; +import { + PublicInput, + verifyEffEcdsaPubInput +} from "../helpers/efficient_ecdsa"; /** * ECDSA Membership Verifier @@ -35,19 +39,32 @@ export class MembershipVerifier extends Profiler implements IVerifier { await init(); } - async verify(proof: Uint8Array, publicInput: Uint8Array): Promise { + async verify( + proof: Uint8Array, + publicInputSer: Uint8Array + ): Promise { this.time("Load circuit"); const circuitBin = await loadCircuit(this.circuit); this.timeEnd("Load circuit"); + this.time("Verify public input"); + const publicInput = PublicInput.deserialize(publicInputSer); + const isPubInputValid = verifyEffEcdsaPubInput(publicInput); + this.timeEnd("Verify public input"); + this.time("Verify proof"); - let result; + let isProofValid; try { - result = await wasm.verify(circuitBin, proof, publicInput); + isProofValid = await wasm.verify( + circuitBin, + proof, + publicInput.circuitPubInput.serialize() + ); } catch (_e) { - result = false; + isProofValid = false; } + this.timeEnd("Verify proof"); - return result; + return isProofValid && isPubInputValid; } } diff --git a/packages/lib/src/helpers/efficient_ecdsa.ts b/packages/lib/src/helpers/efficient_ecdsa.ts index a9b4eee..6484e5e 100644 --- a/packages/lib/src/helpers/efficient_ecdsa.ts +++ b/packages/lib/src/helpers/efficient_ecdsa.ts @@ -2,6 +2,7 @@ var EC = require("elliptic").ec; const BN = require("bn.js"); import { bytesToBigInt, bigIntToBytes } from "./utils"; +import { EffECDSAPubInput } from "../types"; const ec = new EC("secp256k1"); @@ -11,78 +12,68 @@ const SECP256K1_N = new BN( ); /** - * Public inputs that are passed into the efficient ECDSA circuit - * This doesn't include the other public values, which are the group element R and the msgHash. + * Public inputs that are passed into the membership circuit + * This doesn't include the public values that aren't passed into the circuit, + * which are the group element R and the msgHash. */ -export class EffEcdsaCircuitPubInput { +export class CircuitPubInput { + merkleRoot: bigint; Tx: bigint; Ty: bigint; Ux: bigint; Uy: bigint; - constructor(Tx: bigint, Ty: bigint, Ux: bigint, Uy: bigint) { + constructor( + merkleRoot: bigint, + Tx: bigint, + Ty: bigint, + Ux: bigint, + Uy: bigint + ) { + this.merkleRoot = merkleRoot; this.Tx = Tx; this.Ty = Ty; this.Ux = Ux; this.Uy = Uy; } - static computeFromSig( - r: bigint, - v: bigint, - msgHash: Buffer - ): EffEcdsaCircuitPubInput { - const isYOdd = (v - BigInt(27)) % BigInt(2); - const rPoint = ec.keyFromPublic( - ec.curve.pointFromX(new BN(r), isYOdd).encode("hex"), - "hex" - ); - - // Get the group element: -(m * r^−1 * G) - const rInv = new BN(r).invm(SECP256K1_N); - - // w = -(r^-1 * msg) - const w = rInv.mul(new BN(msgHash)).neg().umod(SECP256K1_N); - // U = -(w * G) = -(r^-1 * msg * G) - const U = ec.curve.g.mul(w); - - // T = r^-1 * R - const T = rPoint.getPublic().mul(rInv); - - return new EffEcdsaCircuitPubInput( - BigInt(T.getX().toString()), - BigInt(T.getY().toString()), - BigInt(U.getX().toString()), - BigInt(U.getY().toString()) - ); - } - serialize(): Uint8Array { - let serialized = new Uint8Array(32 * 4); + let serialized = new Uint8Array(32 * 5); - serialized.set(bigIntToBytes(this.Tx, 32), 0); - serialized.set(bigIntToBytes(this.Ty, 32), 32); - serialized.set(bigIntToBytes(this.Ux, 32), 64); - serialized.set(bigIntToBytes(this.Uy, 32), 96); + serialized.set(bigIntToBytes(this.merkleRoot, 32), 0); + serialized.set(bigIntToBytes(this.Tx, 32), 32); + serialized.set(bigIntToBytes(this.Ty, 32), 64); + serialized.set(bigIntToBytes(this.Ux, 32), 96); + serialized.set(bigIntToBytes(this.Uy, 32), 128); return serialized; } + + static deserialize(serialized: Uint8Array): CircuitPubInput { + const merkleRoot = bytesToBigInt(serialized.slice(0, 32)); + const Tx = bytesToBigInt(serialized.slice(32, 64)); + const Ty = bytesToBigInt(serialized.slice(64, 96)); + const Ux = bytesToBigInt(serialized.slice(96, 128)); + const Uy = bytesToBigInt(serialized.slice(128, 160)); + + return new CircuitPubInput(merkleRoot, Tx, Ty, Ux, Uy); + } } /** - * Public values of efficient ECDSA + * Public values of the membership circuit */ -export class EffEcdsaPubInput { +export class PublicInput { r: bigint; rV: bigint; msgHash: Buffer; - circuitPubInput: EffEcdsaCircuitPubInput; + circuitPubInput: CircuitPubInput; constructor( r: bigint, v: bigint, msgHash: Buffer, - circuitPubInput: EffEcdsaCircuitPubInput + circuitPubInput: CircuitPubInput ) { this.r = r; this.rV = v; @@ -90,52 +81,71 @@ export class EffEcdsaPubInput { this.circuitPubInput = circuitPubInput; } - /** - * Serialize the public input into a Uint8Array - * @returns the serialized public input - */ serialize(): Uint8Array { - let serialized = new Uint8Array(32 * 6 + 1); + const circuitPubInput: Uint8Array = this.circuitPubInput.serialize(); + let serialized = new Uint8Array( + 32 + 1 + this.msgHash.byteLength + circuitPubInput.byteLength + ); serialized.set(bigIntToBytes(this.r, 32), 0); serialized.set(bigIntToBytes(this.rV, 1), 32); - serialized.set(this.msgHash, 33); - serialized.set(bigIntToBytes(this.circuitPubInput.Tx, 32), 65); - serialized.set(bigIntToBytes(this.circuitPubInput.Ty, 32), 97); - serialized.set(bigIntToBytes(this.circuitPubInput.Ux, 32), 129); - serialized.set(bigIntToBytes(this.circuitPubInput.Uy, 32), 161); + serialized.set(circuitPubInput, 33); + serialized.set(this.msgHash, 33 + circuitPubInput.byteLength); return serialized; } - /** - * Instantiate EffEcdsaPubInput from a serialized Uint8Array - * @param serialized Uint8Array serialized by the serialize() function - * @returns EffEcdsaPubInput - */ - static deserialize(serialized: Uint8Array): EffEcdsaPubInput { + static deserialize(serialized: Uint8Array): PublicInput { const r = bytesToBigInt(serialized.slice(0, 32)); const rV = bytesToBigInt(serialized.slice(32, 33)); - const msg = serialized.slice(33, 65); - const Tx = bytesToBigInt(serialized.slice(65, 97)); - const Ty = bytesToBigInt(serialized.slice(97, 129)); - const Ux = bytesToBigInt(serialized.slice(129, 161)); - const Uy = bytesToBigInt(serialized.slice(161, 193)); - - return new EffEcdsaPubInput( - r, - rV, - Buffer.from(msg), - new EffEcdsaCircuitPubInput(Tx, Ty, Ux, Uy) + const circuitPubInput: CircuitPubInput = CircuitPubInput.deserialize( + serialized.slice(32 + 1, 32 + 1 + 32 * 5) ); + const msgHash = serialized.slice(32 + 1 + 32 * 5); + + return new PublicInput(r, rV, Buffer.from(msgHash), circuitPubInput); } } +/** + * Compute the group elements T and U for efficient ecdsa + * http://localhost:1313/posts/efficient-ecdsa-1/ + */ +export const computeEffEcdsaPubInput = ( + r: bigint, + v: bigint, + msgHash: Buffer +): EffECDSAPubInput => { + const isYOdd = (v - BigInt(27)) % BigInt(2); + const rPoint = ec.keyFromPublic( + ec.curve.pointFromX(new BN(r), isYOdd).encode("hex"), + "hex" + ); + + // Get the group element: -(m * r^−1 * G) + const rInv = new BN(r).invm(SECP256K1_N); + + // w = -(r^-1 * msg) + const w = rInv.mul(new BN(msgHash)).neg().umod(SECP256K1_N); + // U = -(w * G) = -(r^-1 * msg * G) + const U = ec.curve.g.mul(w); + + // T = r^-1 * R + const T = rPoint.getPublic().mul(rInv); + + return { + Tx: BigInt(T.getX().toString()), + Ty: BigInt(T.getY().toString()), + Ux: BigInt(U.getX().toString()), + Uy: BigInt(U.getY().toString()) + }; +}; + /** * Verify the public values of the efficient ECDSA circuit */ -export const verifyEffEcdsaPubInput = (pubInput: EffEcdsaPubInput): boolean => { - const expectedCircuitInput = EffEcdsaCircuitPubInput.computeFromSig( +export const verifyEffEcdsaPubInput = (pubInput: PublicInput): boolean => { + const expectedCircuitInput = computeEffEcdsaPubInput( pubInput.r, pubInput.rV, pubInput.msgHash diff --git a/packages/lib/src/types.ts b/packages/lib/src/types.ts index 349f575..9bcfa87 100644 --- a/packages/lib/src/types.ts +++ b/packages/lib/src/types.ts @@ -1,3 +1,5 @@ +import { PublicInput } from "./helpers/efficient_ecdsa"; + // The same structure as MerkleProof in @zk-kit/incremental-merkle-tree. // Not directly using MerkleProof defined in @zk-kit/incremental-merkle-tree so // library users can choose whatever merkle tree management method they want. @@ -6,10 +8,16 @@ export interface MerkleProof { siblings: bigint[]; pathIndices: number[]; } +export interface EffECDSAPubInput { + Tx: bigint; + Ty: bigint; + Ux: bigint; + Uy: bigint; +} export interface NIZK { proof: Uint8Array; - publicInput: Uint8Array; + publicInput: PublicInput; } export interface ProverConfig { diff --git a/packages/lib/tests/efficient_ecdsa.test.ts b/packages/lib/tests/efficient_ecdsa.test.ts index e244c7f..8053a4b 100644 --- a/packages/lib/tests/efficient_ecdsa.test.ts +++ b/packages/lib/tests/efficient_ecdsa.test.ts @@ -1,6 +1,6 @@ import { - EffEcdsaCircuitPubInput, - EffEcdsaPubInput, + CircuitPubInput, + PublicInput, verifyEffEcdsaPubInput } from "../src/helpers/efficient_ecdsa"; import { hashPersonalMessage } from "@ethereumjs/util"; @@ -27,6 +27,7 @@ describe("efficient_ecdsa", () => { */ it("should verify valid public input", () => { + const merkleRoot = BigInt("0xbeef"); const msg = Buffer.from("harry potter"); const msgHash = hashPersonalMessage(msg); @@ -47,13 +48,8 @@ describe("efficient_ecdsa", () => { ); const v = BigInt(28); - const circuitPubInput = new EffEcdsaCircuitPubInput(Tx, Ty, Ux, Uy); - const effEcdsaPubInput = new EffEcdsaPubInput( - rX, - v, - msgHash, - circuitPubInput - ); + const circuitPubInput = new CircuitPubInput(merkleRoot, Tx, Ty, Ux, Uy); + const effEcdsaPubInput = new PublicInput(rX, v, msgHash, circuitPubInput); const isValid = verifyEffEcdsaPubInput(effEcdsaPubInput); expect(isValid).toBe(true); diff --git a/packages/lib/tests/membership_nizk.test.ts b/packages/lib/tests/membership_nizk.test.ts index adc9c38..60cafa8 100644 --- a/packages/lib/tests/membership_nizk.test.ts +++ b/packages/lib/tests/membership_nizk.test.ts @@ -87,23 +87,23 @@ describe("membership prove and verify", () => { nizk = await pubKeyMembershipProver.prove(sig, msgHash, merkleProof); const { proof, publicInput } = nizk; - expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe( - true - ); + expect( + await pubKeyMembershipVerifier.verify(proof, publicInput.serialize()) + ).toBe(true); }); it("should assert invalid proof", async () => { const { publicInput } = nizk; let proof = nizk.proof; proof[0] = proof[0] += 1; - expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe( - false - ); + expect( + await pubKeyMembershipVerifier.verify(proof, publicInput.serialize()) + ).toBe(false); }); it("should assert invalid public input", async () => { const { proof } = nizk; - let publicInput = nizk.publicInput; + let publicInput = nizk.publicInput.serialize(); publicInput[0] = publicInput[0] += 1; expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe( false @@ -158,7 +158,10 @@ describe("membership prove and verify", () => { await addressMembershipVerifier.initWasm(); expect( - await addressMembershipVerifier.verify(nizk.proof, nizk.publicInput) + await addressMembershipVerifier.verify( + nizk.proof, + nizk.publicInput.serialize() + ) ).toBe(true); }); @@ -166,14 +169,14 @@ describe("membership prove and verify", () => { const { publicInput } = nizk; let proof = nizk.proof; proof[0] = proof[0] += 1; - expect(await addressMembershipVerifier.verify(proof, publicInput)).toBe( - false - ); + expect( + await addressMembershipVerifier.verify(proof, publicInput.serialize()) + ).toBe(false); }); it("should assert invalid public input", async () => { const { proof } = nizk; - let publicInput = nizk.publicInput; + let publicInput = nizk.publicInput.serialize(); publicInput[0] = publicInput[0] += 1; expect(await addressMembershipVerifier.verify(proof, publicInput)).toBe( false