Fix verification: Add public input validation

This commit is contained in:
Daniel Tehrani
2023-02-02 15:06:39 +09:00
parent fcb816816e
commit ed0993a0a5
10 changed files with 157 additions and 134 deletions

View File

@@ -1,6 +1,7 @@
{
"editor.formatOnSave": true,
"cSpell.words": [
"merkle",
"NIZK"
]
}

View File

@@ -78,7 +78,7 @@ const benchAddrMembership = async () => {
await verifier.initWasm();
// Verify proof
await verifier.verify(proof, publicInput);
await verifier.verify(proof, publicInput.serialize());
};
export default benchAddrMembership;

View File

@@ -75,7 +75,7 @@ const benchPubKeyMembership = async () => {
await verifier.initWasm();
// Verify proof
await verifier.verify(proof, publicInput);
await verifier.verify(proof, publicInput.serialize());
};
export default benchPubKeyMembership;

View File

@@ -80,7 +80,7 @@ export default function Home() {
await verifier.initWasm();
console.time("Verification time");
const result = await verifier.verify(proof, publicInput);
const result = await verifier.verify(proof, publicInput.serialize());
console.timeEnd("Verification time");
if (result) {
@@ -152,7 +152,7 @@ export default function Home() {
await verifier.initWasm();
console.time("Verification time");
const result = await verifier.verify(proof, publicInput);
const result = await verifier.verify(proof, publicInput.serialize());
console.timeEnd("Verification time");
if (result) {

View File

@@ -1,14 +1,10 @@
import { Profiler } from "../helpers/profiler";
import { IProver, MerkleProof, NIZK, ProverConfig } from "../types";
import { loadCircuit, fromSig, snarkJsWitnessGen } from "../helpers/utils";
import {
bigIntToBytes,
loadCircuit,
fromSig,
snarkJsWitnessGen
} from "../helpers/utils";
import {
EffEcdsaPubInput,
EffEcdsaCircuitPubInput
PublicInput,
computeEffEcdsaPubInput,
CircuitPubInput
} from "../helpers/efficient_ecdsa";
import wasm, { init } from "../wasm";
import {
@@ -70,32 +66,20 @@ export class MembershipProver extends Profiler implements IProver {
): Promise<NIZK> {
const { r, s, v } = fromSig(sig);
const circuitPubInput = EffEcdsaCircuitPubInput.computeFromSig(
r,
v,
msgHash
const effEcdsaPubInput = computeEffEcdsaPubInput(r, v, msgHash);
const circuitPubInput = new CircuitPubInput(
merkleProof.root,
effEcdsaPubInput.Tx,
effEcdsaPubInput.Ty,
effEcdsaPubInput.Ux,
effEcdsaPubInput.Uy
);
const effEcdsaPubInput = new EffEcdsaPubInput(
r,
v,
msgHash,
circuitPubInput
);
const merkleRootSer: Uint8Array = bigIntToBytes(merkleProof.root, 32);
const circuitPubInputSer = circuitPubInput.serialize();
// Concatenate circuitPubInputSer and merkleRootSer to construct the full public input
const pubInput = new Uint8Array(
merkleRootSer.length + circuitPubInputSer.length
);
pubInput.set(merkleRootSer);
pubInput.set(circuitPubInputSer, merkleRootSer.length);
const publicInput = new PublicInput(r, v, msgHash, circuitPubInput);
const witnessGenInput = {
s,
...merkleProof,
...effEcdsaPubInput.circuitPubInput
...effEcdsaPubInput
};
this.time("Generate witness");
@@ -109,13 +93,17 @@ export class MembershipProver extends Profiler implements IProver {
const circuitBin = await loadCircuit(this.circuit);
this.timeEnd("Load circuit");
// Get the public input in bytes
const circuitPublicInput: Uint8Array =
publicInput.circuitPubInput.serialize();
this.time("Prove");
let proof = wasm.prove(circuitBin, witness.data, pubInput);
let proof = wasm.prove(circuitBin, witness.data, circuitPublicInput);
this.timeEnd("Prove");
return {
proof,
publicInput: pubInput
publicInput
};
}
}

View File

@@ -6,6 +6,10 @@ import { Profiler } from "../helpers/profiler";
import { loadCircuit } from "../helpers/utils";
import { IVerifier, VerifyConfig } from "../types";
import wasm, { init } from "../wasm";
import {
PublicInput,
verifyEffEcdsaPubInput
} from "../helpers/efficient_ecdsa";
/**
* ECDSA Membership Verifier
@@ -35,19 +39,32 @@ export class MembershipVerifier extends Profiler implements IVerifier {
await init();
}
async verify(proof: Uint8Array, publicInput: Uint8Array): Promise<boolean> {
async verify(
proof: Uint8Array,
publicInputSer: Uint8Array
): Promise<boolean> {
this.time("Load circuit");
const circuitBin = await loadCircuit(this.circuit);
this.timeEnd("Load circuit");
this.time("Verify public input");
const publicInput = PublicInput.deserialize(publicInputSer);
const isPubInputValid = verifyEffEcdsaPubInput(publicInput);
this.timeEnd("Verify public input");
this.time("Verify proof");
let result;
let isProofValid;
try {
result = await wasm.verify(circuitBin, proof, publicInput);
isProofValid = await wasm.verify(
circuitBin,
proof,
publicInput.circuitPubInput.serialize()
);
} catch (_e) {
result = false;
isProofValid = false;
}
this.timeEnd("Verify proof");
return result;
return isProofValid && isPubInputValid;
}
}

View File

@@ -2,6 +2,7 @@ var EC = require("elliptic").ec;
const BN = require("bn.js");
import { bytesToBigInt, bigIntToBytes } from "./utils";
import { EffECDSAPubInput } from "../types";
const ec = new EC("secp256k1");
@@ -11,78 +12,68 @@ const SECP256K1_N = new BN(
);
/**
* Public inputs that are passed into the efficient ECDSA circuit
* This doesn't include the other public values, which are the group element R and the msgHash.
* Public inputs that are passed into the membership circuit
* This doesn't include the public values that aren't passed into the circuit,
* which are the group element R and the msgHash.
*/
export class EffEcdsaCircuitPubInput {
export class CircuitPubInput {
merkleRoot: bigint;
Tx: bigint;
Ty: bigint;
Ux: bigint;
Uy: bigint;
constructor(Tx: bigint, Ty: bigint, Ux: bigint, Uy: bigint) {
constructor(
merkleRoot: bigint,
Tx: bigint,
Ty: bigint,
Ux: bigint,
Uy: bigint
) {
this.merkleRoot = merkleRoot;
this.Tx = Tx;
this.Ty = Ty;
this.Ux = Ux;
this.Uy = Uy;
}
static computeFromSig(
r: bigint,
v: bigint,
msgHash: Buffer
): EffEcdsaCircuitPubInput {
const isYOdd = (v - BigInt(27)) % BigInt(2);
const rPoint = ec.keyFromPublic(
ec.curve.pointFromX(new BN(r), isYOdd).encode("hex"),
"hex"
);
// Get the group element: -(m * r^1 * G)
const rInv = new BN(r).invm(SECP256K1_N);
// w = -(r^-1 * msg)
const w = rInv.mul(new BN(msgHash)).neg().umod(SECP256K1_N);
// U = -(w * G) = -(r^-1 * msg * G)
const U = ec.curve.g.mul(w);
// T = r^-1 * R
const T = rPoint.getPublic().mul(rInv);
return new EffEcdsaCircuitPubInput(
BigInt(T.getX().toString()),
BigInt(T.getY().toString()),
BigInt(U.getX().toString()),
BigInt(U.getY().toString())
);
}
serialize(): Uint8Array {
let serialized = new Uint8Array(32 * 4);
let serialized = new Uint8Array(32 * 5);
serialized.set(bigIntToBytes(this.Tx, 32), 0);
serialized.set(bigIntToBytes(this.Ty, 32), 32);
serialized.set(bigIntToBytes(this.Ux, 32), 64);
serialized.set(bigIntToBytes(this.Uy, 32), 96);
serialized.set(bigIntToBytes(this.merkleRoot, 32), 0);
serialized.set(bigIntToBytes(this.Tx, 32), 32);
serialized.set(bigIntToBytes(this.Ty, 32), 64);
serialized.set(bigIntToBytes(this.Ux, 32), 96);
serialized.set(bigIntToBytes(this.Uy, 32), 128);
return serialized;
}
static deserialize(serialized: Uint8Array): CircuitPubInput {
const merkleRoot = bytesToBigInt(serialized.slice(0, 32));
const Tx = bytesToBigInt(serialized.slice(32, 64));
const Ty = bytesToBigInt(serialized.slice(64, 96));
const Ux = bytesToBigInt(serialized.slice(96, 128));
const Uy = bytesToBigInt(serialized.slice(128, 160));
return new CircuitPubInput(merkleRoot, Tx, Ty, Ux, Uy);
}
}
/**
* Public values of efficient ECDSA
* Public values of the membership circuit
*/
export class EffEcdsaPubInput {
export class PublicInput {
r: bigint;
rV: bigint;
msgHash: Buffer;
circuitPubInput: EffEcdsaCircuitPubInput;
circuitPubInput: CircuitPubInput;
constructor(
r: bigint,
v: bigint,
msgHash: Buffer,
circuitPubInput: EffEcdsaCircuitPubInput
circuitPubInput: CircuitPubInput
) {
this.r = r;
this.rV = v;
@@ -90,52 +81,71 @@ export class EffEcdsaPubInput {
this.circuitPubInput = circuitPubInput;
}
/**
* Serialize the public input into a Uint8Array
* @returns the serialized public input
*/
serialize(): Uint8Array {
let serialized = new Uint8Array(32 * 6 + 1);
const circuitPubInput: Uint8Array = this.circuitPubInput.serialize();
let serialized = new Uint8Array(
32 + 1 + this.msgHash.byteLength + circuitPubInput.byteLength
);
serialized.set(bigIntToBytes(this.r, 32), 0);
serialized.set(bigIntToBytes(this.rV, 1), 32);
serialized.set(this.msgHash, 33);
serialized.set(bigIntToBytes(this.circuitPubInput.Tx, 32), 65);
serialized.set(bigIntToBytes(this.circuitPubInput.Ty, 32), 97);
serialized.set(bigIntToBytes(this.circuitPubInput.Ux, 32), 129);
serialized.set(bigIntToBytes(this.circuitPubInput.Uy, 32), 161);
serialized.set(circuitPubInput, 33);
serialized.set(this.msgHash, 33 + circuitPubInput.byteLength);
return serialized;
}
/**
* Instantiate EffEcdsaPubInput from a serialized Uint8Array
* @param serialized Uint8Array serialized by the serialize() function
* @returns EffEcdsaPubInput
*/
static deserialize(serialized: Uint8Array): EffEcdsaPubInput {
static deserialize(serialized: Uint8Array): PublicInput {
const r = bytesToBigInt(serialized.slice(0, 32));
const rV = bytesToBigInt(serialized.slice(32, 33));
const msg = serialized.slice(33, 65);
const Tx = bytesToBigInt(serialized.slice(65, 97));
const Ty = bytesToBigInt(serialized.slice(97, 129));
const Ux = bytesToBigInt(serialized.slice(129, 161));
const Uy = bytesToBigInt(serialized.slice(161, 193));
return new EffEcdsaPubInput(
r,
rV,
Buffer.from(msg),
new EffEcdsaCircuitPubInput(Tx, Ty, Ux, Uy)
const circuitPubInput: CircuitPubInput = CircuitPubInput.deserialize(
serialized.slice(32 + 1, 32 + 1 + 32 * 5)
);
const msgHash = serialized.slice(32 + 1 + 32 * 5);
return new PublicInput(r, rV, Buffer.from(msgHash), circuitPubInput);
}
}
/**
* Compute the group elements T and U for efficient ecdsa
* http://localhost:1313/posts/efficient-ecdsa-1/
*/
export const computeEffEcdsaPubInput = (
r: bigint,
v: bigint,
msgHash: Buffer
): EffECDSAPubInput => {
const isYOdd = (v - BigInt(27)) % BigInt(2);
const rPoint = ec.keyFromPublic(
ec.curve.pointFromX(new BN(r), isYOdd).encode("hex"),
"hex"
);
// Get the group element: -(m * r^1 * G)
const rInv = new BN(r).invm(SECP256K1_N);
// w = -(r^-1 * msg)
const w = rInv.mul(new BN(msgHash)).neg().umod(SECP256K1_N);
// U = -(w * G) = -(r^-1 * msg * G)
const U = ec.curve.g.mul(w);
// T = r^-1 * R
const T = rPoint.getPublic().mul(rInv);
return {
Tx: BigInt(T.getX().toString()),
Ty: BigInt(T.getY().toString()),
Ux: BigInt(U.getX().toString()),
Uy: BigInt(U.getY().toString())
};
};
/**
* Verify the public values of the efficient ECDSA circuit
*/
export const verifyEffEcdsaPubInput = (pubInput: EffEcdsaPubInput): boolean => {
const expectedCircuitInput = EffEcdsaCircuitPubInput.computeFromSig(
export const verifyEffEcdsaPubInput = (pubInput: PublicInput): boolean => {
const expectedCircuitInput = computeEffEcdsaPubInput(
pubInput.r,
pubInput.rV,
pubInput.msgHash

View File

@@ -1,3 +1,5 @@
import { PublicInput } from "./helpers/efficient_ecdsa";
// The same structure as MerkleProof in @zk-kit/incremental-merkle-tree.
// Not directly using MerkleProof defined in @zk-kit/incremental-merkle-tree so
// library users can choose whatever merkle tree management method they want.
@@ -6,10 +8,16 @@ export interface MerkleProof {
siblings: bigint[];
pathIndices: number[];
}
export interface EffECDSAPubInput {
Tx: bigint;
Ty: bigint;
Ux: bigint;
Uy: bigint;
}
export interface NIZK {
proof: Uint8Array;
publicInput: Uint8Array;
publicInput: PublicInput;
}
export interface ProverConfig {

View File

@@ -1,6 +1,6 @@
import {
EffEcdsaCircuitPubInput,
EffEcdsaPubInput,
CircuitPubInput,
PublicInput,
verifyEffEcdsaPubInput
} from "../src/helpers/efficient_ecdsa";
import { hashPersonalMessage } from "@ethereumjs/util";
@@ -27,6 +27,7 @@ describe("efficient_ecdsa", () => {
*/
it("should verify valid public input", () => {
const merkleRoot = BigInt("0xbeef");
const msg = Buffer.from("harry potter");
const msgHash = hashPersonalMessage(msg);
@@ -47,13 +48,8 @@ describe("efficient_ecdsa", () => {
);
const v = BigInt(28);
const circuitPubInput = new EffEcdsaCircuitPubInput(Tx, Ty, Ux, Uy);
const effEcdsaPubInput = new EffEcdsaPubInput(
rX,
v,
msgHash,
circuitPubInput
);
const circuitPubInput = new CircuitPubInput(merkleRoot, Tx, Ty, Ux, Uy);
const effEcdsaPubInput = new PublicInput(rX, v, msgHash, circuitPubInput);
const isValid = verifyEffEcdsaPubInput(effEcdsaPubInput);
expect(isValid).toBe(true);

View File

@@ -87,23 +87,23 @@ describe("membership prove and verify", () => {
nizk = await pubKeyMembershipProver.prove(sig, msgHash, merkleProof);
const { proof, publicInput } = nizk;
expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe(
true
);
expect(
await pubKeyMembershipVerifier.verify(proof, publicInput.serialize())
).toBe(true);
});
it("should assert invalid proof", async () => {
const { publicInput } = nizk;
let proof = nizk.proof;
proof[0] = proof[0] += 1;
expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe(
false
);
expect(
await pubKeyMembershipVerifier.verify(proof, publicInput.serialize())
).toBe(false);
});
it("should assert invalid public input", async () => {
const { proof } = nizk;
let publicInput = nizk.publicInput;
let publicInput = nizk.publicInput.serialize();
publicInput[0] = publicInput[0] += 1;
expect(await pubKeyMembershipVerifier.verify(proof, publicInput)).toBe(
false
@@ -158,7 +158,10 @@ describe("membership prove and verify", () => {
await addressMembershipVerifier.initWasm();
expect(
await addressMembershipVerifier.verify(nizk.proof, nizk.publicInput)
await addressMembershipVerifier.verify(
nizk.proof,
nizk.publicInput.serialize()
)
).toBe(true);
});
@@ -166,14 +169,14 @@ describe("membership prove and verify", () => {
const { publicInput } = nizk;
let proof = nizk.proof;
proof[0] = proof[0] += 1;
expect(await addressMembershipVerifier.verify(proof, publicInput)).toBe(
false
);
expect(
await addressMembershipVerifier.verify(proof, publicInput.serialize())
).toBe(false);
});
it("should assert invalid public input", async () => {
const { proof } = nizk;
let publicInput = nizk.publicInput;
let publicInput = nizk.publicInput.serialize();
publicInput[0] = publicInput[0] += 1;
expect(await addressMembershipVerifier.verify(proof, publicInput)).toBe(
false