diff --git a/packages/protocols/src/rln.ts b/packages/protocols/src/rln.ts index 544ee77..1639660 100644 --- a/packages/protocols/src/rln.ts +++ b/packages/protocols/src/rln.ts @@ -1,10 +1,40 @@ import { MerkleProof } from "@zk-kit/incremental-merkle-tree" import { poseidon } from "circomlibjs" -import { StrBigInt } from "./types" +import { groth16 } from "snarkjs" +import { FullProof, RLNPublicSignals, StrBigInt } from "./types" import { Fq, genSignalHash } from "./utils" import ZkProtocol from "./zk-protocol" export default class RLN extends ZkProtocol { + /** + * The number of public signals that should be returned by snarkjs when generating a proof. + */ + private static PUBLIC_SIGNALS_COUNT: number = 6 + + /** + * Generates a SnarkJS full proof with Groth16. + * @param witness The parameters for creating the proof. + * @param wasmFilePath The WASM file path. + * @param finalZkeyPath The ZKey file path. + * @returns The full SnarkJS proof. + */ + public static async genProof(witness: any, wasmFilePath: string, finalZkeyPath: string): Promise { + const { proof, publicSignalsArray } = await groth16.fullProve(witness, wasmFilePath, finalZkeyPath, null) + + if (publicSignalsArray.length !== RLN.PUBLIC_SIGNALS_COUNT) throw new Error("Error while generating proof") + + const publicSignals: RLNPublicSignals = { + yShare: publicSignalsArray[0], + merkleRoot: publicSignalsArray[1], + internalNullifier: publicSignalsArray[2], + signalHash: publicSignalsArray[3], + epoch: publicSignalsArray[4], + rlnIdentifier: publicSignalsArray[5] + } + + return { proof, publicSignals } + } + /** * Creates witness for rln proof * @param identitySecret identity secret diff --git a/packages/protocols/src/semaphore.ts b/packages/protocols/src/semaphore.ts index 03e6bd7..5e9738d 100644 --- a/packages/protocols/src/semaphore.ts +++ b/packages/protocols/src/semaphore.ts @@ -1,10 +1,38 @@ import { MerkleProof } from "@zk-kit/incremental-merkle-tree" import { poseidon } from "circomlibjs" -import { SemaphoreWitness, StrBigInt } from "./types" +import { groth16 } from "snarkjs" +import { FullProof, StrBigInt, SemaphoreWitness, SemaphorePublicSignals } from "./types" import { genSignalHash } from "./utils" import ZkProtocol from "./zk-protocol" export default class Semaphore extends ZkProtocol { + /** + * The number of public signals that should be returned by snarkjs when generating a proof. + */ + private static PUBLIC_SIGNALS_COUNT: number = 6 + + /** + * Generates a SnarkJS full proof with Groth16. + * @param witness The parameters for creating the proof. + * @param wasmFilePath The WASM file path. + * @param finalZkeyPath The ZKey file path. + * @returns The full SnarkJS proof. + */ + public static async genProof(witness: any, wasmFilePath: string, finalZkeyPath: string): Promise { + const { proof, publicSignalsArray } = await groth16.fullProve(witness, wasmFilePath, finalZkeyPath, null) + + if (publicSignalsArray.length !== Semaphore.PUBLIC_SIGNALS_COUNT) throw new Error("Error while generating proof") + + const publicSignals: SemaphorePublicSignals = { + merkleRoot: publicSignalsArray[0], + nullifierHash: publicSignalsArray[1], + signalHash: publicSignalsArray[2], + externalNullifier: publicSignalsArray[3] + } + + return { proof, publicSignals } + } + /** * Creates a Semaphore witness for the Semaphore ZK proof. * @param identityTrapdoor The identity trapdoor. diff --git a/packages/protocols/src/types/index.ts b/packages/protocols/src/types/index.ts index 2d75845..df8d759 100644 --- a/packages/protocols/src/types/index.ts +++ b/packages/protocols/src/types/index.ts @@ -10,7 +10,23 @@ export type Proof = { export type FullProof = { proof: Proof - publicSignals: StrBigInt[] + publicSignals: RLNPublicSignals | SemaphorePublicSignals +} + +export type RLNPublicSignals = { + yShare: StrBigInt + merkleRoot: StrBigInt + internalNullifier: StrBigInt + signalHash: StrBigInt + epoch: StrBigInt + rlnIdentifier: StrBigInt +} + +export type SemaphorePublicSignals = { + merkleRoot: StrBigInt + nullifierHash: StrBigInt + signalHash: StrBigInt + externalNullifier: StrBigInt } export type SolidityProof = StrBigInt[] diff --git a/packages/protocols/src/zk-protocol.ts b/packages/protocols/src/zk-protocol.ts index ec8d36f..89db91b 100644 --- a/packages/protocols/src/zk-protocol.ts +++ b/packages/protocols/src/zk-protocol.ts @@ -1,20 +1,8 @@ /* istanbul ignore file */ import { groth16 } from "snarkjs" -import { FullProof, SolidityProof } from "./types" +import { FullProof, SolidityProof, StrBigInt } from "./types" export default class ZkProtocol { - /** - * Generates a SnarkJS full proof with Groth16. - * @param witness The parameters for creating the proof. - * @param wasmFilePath The WASM file path. - * @param finalZkeyPath The ZKey file path. - * @returns The full SnarkJS proof. - */ - public static async genProof(witness: any, wasmFilePath: string, finalZkeyPath: string): Promise { - const { proof, publicSignals } = await groth16.fullProve(witness, wasmFilePath, finalZkeyPath, null) - return { proof, publicSignals } - } - /** * Verifies a zero-knowledge SnarkJS proof. * @param verificationKey The zero-knowledge verification key. @@ -24,7 +12,9 @@ export default class ZkProtocol { public static verifyProof(verificationKey: string, fullProof: FullProof): Promise { const { proof, publicSignals } = fullProof - return groth16.verify(verificationKey, publicSignals, proof) + const publicSignalsArray: StrBigInt[] = Object.values(publicSignals) + + return groth16.verify(verificationKey, publicSignalsArray, proof) } /** diff --git a/packages/protocols/tests/rln.test.ts b/packages/protocols/tests/rln.test.ts index 637df14..de5d6d6 100644 --- a/packages/protocols/tests/rln.test.ts +++ b/packages/protocols/tests/rln.test.ts @@ -4,6 +4,7 @@ import * as fs from "fs" import * as path from "path" import { RLN } from "../src" import { generateMerkleProof, genExternalNullifier, genSignalHash } from "../src/utils" +import { RLNPublicSignals } from "../src/types" describe("RLN", () => { const zkeyFiles = "./packages/protocols/zkeyFiles" @@ -65,7 +66,15 @@ describe("RLN", () => { const witness = RLN.genWitness(secretHash, merkleProof, epoch, signal, rlnIdentifier) const [y, nullifier] = RLN.calculateOutput(secretHash, BigInt(epoch), rlnIdentifier, signalHash) - const publicSignals = [y, merkleProof.root, nullifier, signalHash, epoch, rlnIdentifier] + + const publicSignals: RLNPublicSignals = { + yShare: y, + merkleRoot: merkleProof.root, + internalNullifier: nullifier, + signalHash, + epoch, + rlnIdentifier + } const vkeyPath = path.join(zkeyFiles, "rln", "verification_key.json") const vKey = JSON.parse(fs.readFileSync(vkeyPath, "utf-8")) @@ -77,6 +86,7 @@ describe("RLN", () => { const response = await RLN.verifyProof(vKey, { proof: fullProof.proof, publicSignals }) expect(response).toBe(true) + expect(fullProof.publicSignals).toEqual(publicSignals) }, 30000) it("Should retrieve user secret after spaming", () => { diff --git a/packages/protocols/tests/semaphore.test.ts b/packages/protocols/tests/semaphore.test.ts index 29f27ea..9718254 100644 --- a/packages/protocols/tests/semaphore.test.ts +++ b/packages/protocols/tests/semaphore.test.ts @@ -3,6 +3,7 @@ import { getCurveFromName } from "ffjavascript" import fs from "fs" import path from "path" import { Semaphore } from "../src" +import { SemaphorePublicSignals } from "../src/types" import { generateMerkleProof, genExternalNullifier, genSignalHash } from "../src/utils" describe("Semaphore", () => { @@ -64,11 +65,18 @@ describe("Semaphore", () => { const vkeyPath = path.join("./packages/protocols/zkeyFiles", "semaphore", "verification_key.json") const vKey = JSON.parse(fs.readFileSync(vkeyPath, "utf-8")) const nullifierHash = Semaphore.genNullifierHash(externalNullifier, identity.getNullifier()) - const publicSignals = [merkleProof.root.toString(), nullifierHash, genSignalHash(signal), externalNullifier] + + const publicSignals: SemaphorePublicSignals = { + merkleRoot: merkleProof.root.toString(), + nullifierHash, + signalHash: genSignalHash(signal), + externalNullifier + } const response = await Semaphore.verifyProof(vKey, { proof: fullProof.proof, publicSignals }) expect(response).toBe(true) + expect(fullProof.publicSignals).toEqual(publicSignals) }, 30000) }) })