From 0ebeea50fdd8275a110a0f6272c8e0cd40fb8357 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vinh=20Tr=E1=BB=8Bnh?= <108657096+vinhtc27@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:27:07 +0700 Subject: [PATCH] feat(rln): extend error handling for rln module (#358) Changes: - Unified error types (`PoseidonError`, `HashError`, etc.) across hashing, keygen, witness calculation, and serialization for consistent and descriptive error handling. - Refactored tests and examples to use `unwrap()` where safe, and limited `expect()` in library code to non-panicking cases with clear messaging. - Improved witness and proof generation by removing panicking code paths and enforcing proper error propagation. - Cleaned up outdated imports, removed unused operations in `graph.rs`, and updated public API documentation. - Updated C, Nim, and WASM FFI bindings with more robust serialization and clearer error log messages. - Added keywords to package.json and update dependencies in Makefile.toml and Nightly CI. --- .github/workflows/nightly-release.yml | 3 + rln-cli/src/examples/relay.rs | 15 +- rln-cli/src/examples/stateless.rs | 19 +- rln-wasm/Makefile.toml | 11 +- rln-wasm/examples/index.js | 289 +++++++-- rln-wasm/src/lib.rs | 4 +- rln-wasm/src/wasm_rln.rs | 32 +- rln-wasm/src/wasm_utils.rs | 149 ++++- rln-wasm/tests/browser.rs | 74 +-- rln-wasm/tests/node.rs | 65 +-- rln-wasm/tests/utils.rs | 22 +- rln/Cargo.toml | 6 +- rln/benches/pmtree_benchmark.rs | 2 +- rln/benches/poseidon_tree_benchmark.rs | 2 +- rln/ffi_c_examples/main.c | 138 ++++- rln/ffi_nim_examples/main.nim | 128 +++- rln/src/circuit/error.rs | 18 + rln/src/circuit/iden3calc.rs | 41 +- rln/src/circuit/iden3calc/graph.rs | 584 ++++--------------- rln/src/circuit/iden3calc/proto.rs | 28 +- rln/src/circuit/iden3calc/storage.rs | 92 +-- rln/src/circuit/mod.rs | 10 +- rln/src/error.rs | 33 +- rln/src/ffi/ffi_rln.rs | 38 +- rln/src/ffi/ffi_utils.rs | 137 +++-- rln/src/hashers.rs | 29 +- rln/src/pm_tree_adapter.rs | 69 ++- rln/src/poseidon_tree.rs | 4 +- rln/src/prelude.rs | 3 - rln/src/protocol/keygen.rs | 33 +- rln/src/protocol/proof.rs | 26 +- rln/src/protocol/witness.rs | 18 +- rln/src/public.rs | 80 +-- rln/src/utils.rs | 26 +- rln/tests/ffi.rs | 90 +-- rln/tests/ffi_utils.rs | 83 ++- rln/tests/poseidon_tree.rs | 10 +- rln/tests/protocol.rs | 32 +- rln/tests/public.rs | 116 ++-- utils/benches/merkle_tree_benchmark.rs | 12 +- utils/benches/poseidon_benchmark.rs | 2 +- utils/src/error.rs | 11 + utils/src/lib.rs | 10 +- utils/src/merkle_tree/error.rs | 13 +- utils/src/merkle_tree/full_merkle_tree.rs | 47 +- utils/src/merkle_tree/merkle_tree.rs | 22 +- utils/src/merkle_tree/mod.rs | 9 +- utils/src/merkle_tree/optimal_merkle_tree.rs | 44 +- utils/src/pm_tree/mod.rs | 6 +- utils/src/poseidon/error.rs | 8 + utils/src/poseidon/mod.rs | 7 +- utils/src/poseidon/poseidon_constants.rs | 18 +- utils/src/poseidon/poseidon_hash.rs | 19 +- utils/tests/merkle_tree.rs | 196 +++---- utils/tests/poseidon_constants.rs | 2 +- utils/tests/poseidon_hash_test.rs | 2 +- 56 files changed, 1667 insertions(+), 1320 deletions(-) create mode 100644 utils/src/error.rs create mode 100644 utils/src/poseidon/error.rs diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 17e6d7b..02aea94 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -129,6 +129,9 @@ jobs: sed -i.bak 's/rln-wasm/zerokit-rln-wasm/g' pkg/package.json && rm pkg/package.json.bak fi + jq '. + {keywords: ["zerokit", "rln", "wasm"]}' pkg/package.json > pkg/package.json.tmp && \ + mv pkg/package.json.tmp pkg/package.json + mkdir release cp -r pkg/* release/ tar -czvf rln-wasm-${{ matrix.feature }}.tar.gz release/ diff --git a/rln-cli/src/examples/relay.rs b/rln-cli/src/examples/relay.rs index 380c99b..716eb1c 100644 --- a/rln-cli/src/examples/relay.rs +++ b/rln-cli/src/examples/relay.rs @@ -10,7 +10,7 @@ use rln::prelude::{ hash_to_field_le, keygen, poseidon_hash, recover_id_secret, Fr, IdSecret, PmtreeConfigBuilder, RLNProofValues, RLNWitnessInput, RLN, }; -use zerokit_utils::Mode; +use zerokit_utils::pm_tree::Mode; const MESSAGE_LIMIT: u32 = 1; @@ -49,7 +49,7 @@ struct Identity { impl Identity { fn new() -> Self { - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); Identity { identity_secret, id_commitment, @@ -117,7 +117,8 @@ impl RLNSystem { let index = self.rln.leaves_set(); let identity = Identity::new(); - let rate_commitment = poseidon_hash(&[identity.id_commitment, Fr::from(MESSAGE_LIMIT)]); + let rate_commitment = + poseidon_hash(&[identity.id_commitment, Fr::from(MESSAGE_LIMIT)]).unwrap(); match self.rln.set_next_leaf(rate_commitment) { Ok(_) => { println!("Registered User Index: {index}"); @@ -146,7 +147,7 @@ impl RLNSystem { }; let (path_elements, identity_path_index) = self.rln.get_merkle_proof(user_index)?; - let x = hash_to_field_le(signal.as_bytes()); + let x = hash_to_field_le(signal.as_bytes())?; let witness = RLNWitnessInput::new( identity.identity_secret.clone(), @@ -230,9 +231,9 @@ fn main() -> Result<()> { println!("Initializing RLN instance..."); print!("\x1B[2J\x1B[1;1H"); let mut rln_system = RLNSystem::new()?; - let rln_epoch = hash_to_field_le(b"epoch"); - let rln_identifier = hash_to_field_le(b"rln-identifier"); - let external_nullifier = poseidon_hash(&[rln_epoch, rln_identifier]); + let rln_epoch = hash_to_field_le(b"epoch")?; + let rln_identifier = hash_to_field_le(b"rln-identifier")?; + let external_nullifier = poseidon_hash(&[rln_epoch, rln_identifier]).unwrap(); println!("RLN Relay Example:"); println!("Message Limit: {MESSAGE_LIMIT}"); println!("----------------------------------"); diff --git a/rln-cli/src/examples/stateless.rs b/rln-cli/src/examples/stateless.rs index ba38c73..0279bcb 100644 --- a/rln-cli/src/examples/stateless.rs +++ b/rln-cli/src/examples/stateless.rs @@ -7,10 +7,10 @@ use std::{ use clap::{Parser, Subcommand}; use rln::prelude::{ - hash_to_field_le, keygen, poseidon_hash, recover_id_secret, Fr, IdSecret, PoseidonHash, - RLNProofValues, RLNWitnessInput, DEFAULT_TREE_DEPTH, RLN, + hash_to_field_le, keygen, poseidon_hash, recover_id_secret, Fr, IdSecret, OptimalMerkleTree, + PoseidonHash, RLNProofValues, RLNWitnessInput, ZerokitMerkleProof, ZerokitMerkleTree, + DEFAULT_TREE_DEPTH, RLN, }; -use zerokit_utils::{OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree}; const MESSAGE_LIMIT: u32 = 1; @@ -48,7 +48,7 @@ struct Identity { impl Identity { fn new() -> Self { - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); Identity { identity_secret, id_commitment, @@ -101,7 +101,8 @@ impl RLNSystem { let index = self.tree.leaves_set(); let identity = Identity::new(); - let rate_commitment = poseidon_hash(&[identity.id_commitment, Fr::from(MESSAGE_LIMIT)]); + let rate_commitment = + poseidon_hash(&[identity.id_commitment, Fr::from(MESSAGE_LIMIT)]).unwrap(); self.tree.update_next(rate_commitment)?; println!("Registered User Index: {index}"); @@ -125,7 +126,7 @@ impl RLNSystem { }; let merkle_proof = self.tree.proof(user_index)?; - let x = hash_to_field_le(signal.as_bytes()); + let x = hash_to_field_le(signal.as_bytes())?; let witness = RLNWitnessInput::new( identity.identity_secret.clone(), @@ -219,9 +220,9 @@ fn main() -> Result<()> { println!("Initializing RLN instance..."); print!("\x1B[2J\x1B[1;1H"); let mut rln_system = RLNSystem::new()?; - let rln_epoch = hash_to_field_le(b"epoch"); - let rln_identifier = hash_to_field_le(b"rln-identifier"); - let external_nullifier = poseidon_hash(&[rln_epoch, rln_identifier]); + let rln_epoch = hash_to_field_le(b"epoch")?; + let rln_identifier = hash_to_field_le(b"rln-identifier")?; + let external_nullifier = poseidon_hash(&[rln_epoch, rln_identifier]).unwrap(); println!("RLN Stateless Relay Example:"); println!("Message Limit: {MESSAGE_LIMIT}"); println!("----------------------------------"); diff --git a/rln-wasm/Makefile.toml b/rln-wasm/Makefile.toml index addfe7b..139d868 100644 --- a/rln-wasm/Makefile.toml +++ b/rln-wasm/Makefile.toml @@ -1,14 +1,14 @@ [tasks.build] clear = true -dependencies = ["pack_build", "pack_rename"] +dependencies = ["pack_build", "pack_rename", "pack_add_keywords"] [tasks.build_parallel] clear = true -dependencies = ["pack_build_parallel", "pack_rename"] +dependencies = ["pack_build_parallel", "pack_rename", "pack_add_keywords"] [tasks.build_utils] clear = true -dependencies = ["pack_build_utils", "pack_rename_utils"] +dependencies = ["pack_build_utils", "pack_rename_utils", "pack_add_keywords"] [tasks.pack_build] command = "wasm-pack" @@ -54,6 +54,11 @@ args = [ [tasks.pack_rename_utils] script = "sed -i.bak 's/rln-wasm/zerokit-rln-wasm-utils/g' pkg/package.json && rm pkg/package.json.bak" +[tasks.pack_add_keywords] +script = """ + jq '. + {keywords: ["zerokit", "rln", "wasm"]}' pkg/package.json > pkg/package.json.tmp && \ + mv pkg/package.json.tmp pkg/package.json +""" [tasks.test] command = "wasm-pack" diff --git a/rln-wasm/examples/index.js b/rln-wasm/examples/index.js index 6407389..ee2d16b 100644 --- a/rln-wasm/examples/index.js +++ b/rln-wasm/examples/index.js @@ -13,7 +13,7 @@ function debugUint8Array(uint8Array) { async function calculateWitness(circomPath, inputs, witnessCalculatorFile) { const wasmFile = readFileSync(circomPath); - const wasmFileBuffer = wasmFile.slice( + const wasmFileBuffer = wasmFile.buffer.slice( wasmFile.byteOffset, wasmFile.byteOffset + wasmFile.byteLength ); @@ -49,11 +49,23 @@ async function main() { console.log("Creating RLN instance"); const zkeyData = readFileSync(zkeyPath); - const rlnInstance = new rlnWasm.WasmRLN(new Uint8Array(zkeyData)); + let rlnInstance; + try { + rlnInstance = new rlnWasm.WasmRLN(new Uint8Array(zkeyData)); + } catch (error) { + console.error("Initial RLN instance creation error:", error); + return; + } console.log("RLN instance created successfully"); console.log("\nGenerating identity keys"); - const identity = rlnWasm.Identity.generate(); + let identity; + try { + identity = rlnWasm.Identity.generate(); + } catch (error) { + console.error("Key generation error:", error); + return; + } const identitySecret = identity.getSecretHash(); const idCommitment = identity.getCommitment(); console.log("Identity generated"); @@ -65,10 +77,16 @@ async function main() { console.log(" - user_message_limit = " + userMessageLimit.debug()); console.log("\nComputing rate commitment"); - const rateCommitment = rlnWasm.Hasher.poseidonHashPair( - idCommitment, - userMessageLimit - ); + let rateCommitment; + try { + rateCommitment = rlnWasm.Hasher.poseidonHashPair( + idCommitment, + userMessageLimit + ); + } catch (error) { + console.error("Rate commitment hash error:", error); + return; + } console.log(" - rate_commitment = " + rateCommitment.debug()); console.log("\nWasmFr serialization: WasmFr <-> bytes"); @@ -79,22 +97,59 @@ async function main() { "]" ); - const deserRateCommitment = rlnWasm.WasmFr.fromBytesLE(serRateCommitment); + let deserRateCommitment; + try { + deserRateCommitment = rlnWasm.WasmFr.fromBytesLE(serRateCommitment); + } catch (error) { + console.error("Rate commitment deserialization error:", error); + return; + } console.log( " - deserialized rate_commitment = " + deserRateCommitment.debug() ); + console.log("\nIdentity serialization: Identity <-> bytes"); + const serIdentity = identity.toBytesLE(); + console.log( + " - serialized identity = [" + debugUint8Array(serIdentity) + "]" + ); + + let deserIdentity; + try { + deserIdentity = rlnWasm.Identity.fromBytesLE(serIdentity); + } catch (error) { + console.error("Identity deserialization error:", error); + return; + } + const deserIdentitySecret = deserIdentity.getSecretHash(); + const deserIdCommitment = deserIdentity.getCommitment(); + console.log( + " - deserialized identity = [" + + deserIdentitySecret.debug() + + ", " + + deserIdCommitment.debug() + + "]" + ); + console.log("\nBuilding Merkle path for stateless mode"); const treeDepth = 20; const defaultLeaf = rlnWasm.WasmFr.zero(); const defaultHashes = []; - defaultHashes[0] = rlnWasm.Hasher.poseidonHashPair(defaultLeaf, defaultLeaf); - for (let i = 1; i < treeDepth - 1; i++) { - defaultHashes[i] = rlnWasm.Hasher.poseidonHashPair( - defaultHashes[i - 1], - defaultHashes[i - 1] + try { + defaultHashes[0] = rlnWasm.Hasher.poseidonHashPair( + defaultLeaf, + defaultLeaf ); + for (let i = 1; i < treeDepth - 1; i++) { + defaultHashes[i] = rlnWasm.Hasher.poseidonHashPair( + defaultHashes[i - 1], + defaultHashes[i - 1] + ); + } + } catch (error) { + console.error("Poseidon hash error:", error); + return; } const pathElements = new rlnWasm.VecWasmFr(); @@ -110,7 +165,13 @@ async function main() { " - serialized path_elements = [" + debugUint8Array(serPathElements) + "]" ); - const deserPathElements = rlnWasm.VecWasmFr.fromBytesLE(serPathElements); + let deserPathElements; + try { + deserPathElements = rlnWasm.VecWasmFr.fromBytesLE(serPathElements); + } catch (error) { + console.error("Path elements deserialization error:", error); + return; + } console.log(" - deserialized path_elements = ", deserPathElements.debug()); console.log("\nUint8Array serialization: Uint8Array <-> bytes"); @@ -119,21 +180,30 @@ async function main() { " - serialized path_index = [" + debugUint8Array(serPathIndex) + "]" ); - const deserPathIndex = rlnWasm.Uint8ArrayUtils.fromBytesLE(serPathIndex); + let deserPathIndex; + try { + deserPathIndex = rlnWasm.Uint8ArrayUtils.fromBytesLE(serPathIndex); + } catch (error) { + console.error("Path index deserialization error:", error); + return; + } console.log(" - deserialized path_index =", deserPathIndex); console.log("\nComputing Merkle root for stateless mode"); console.log(" - computing root for index 0 with rate_commitment"); - let computedRoot = rlnWasm.Hasher.poseidonHashPair( - rateCommitment, - defaultLeaf - ); - for (let i = 1; i < treeDepth; i++) { - computedRoot = rlnWasm.Hasher.poseidonHashPair( - computedRoot, - defaultHashes[i - 1] - ); + let computedRoot; + try { + computedRoot = rlnWasm.Hasher.poseidonHashPair(rateCommitment, defaultLeaf); + for (let i = 1; i < treeDepth; i++) { + computedRoot = rlnWasm.Hasher.poseidonHashPair( + computedRoot, + defaultHashes[i - 1] + ); + } + } catch (error) { + console.error("Poseidon hash error:", error); + return; } console.log(" - computed_root = " + computedRoot.debug()); @@ -142,28 +212,47 @@ async function main() { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]); - const x = rlnWasm.Hasher.hashToFieldLE(signal); + let x; + try { + x = rlnWasm.Hasher.hashToFieldLE(signal); + } catch (error) { + console.error("Hash signal error:", error); + return; + } console.log(" - x = " + x.debug()); console.log("\nHashing epoch"); const epochStr = "test-epoch"; - const epoch = rlnWasm.Hasher.hashToFieldLE( - new TextEncoder().encode(epochStr) - ); + let epoch; + try { + epoch = rlnWasm.Hasher.hashToFieldLE(new TextEncoder().encode(epochStr)); + } catch (error) { + console.error("Hash epoch error:", error); + return; + } console.log(" - epoch = " + epoch.debug()); console.log("\nHashing RLN identifier"); const rlnIdStr = "test-rln-identifier"; - const rlnIdentifier = rlnWasm.Hasher.hashToFieldLE( - new TextEncoder().encode(rlnIdStr) - ); + let rlnIdentifier; + try { + rlnIdentifier = rlnWasm.Hasher.hashToFieldLE( + new TextEncoder().encode(rlnIdStr) + ); + } catch (error) { + console.error("Hash RLN identifier error:", error); + return; + } console.log(" - rln_identifier = " + rlnIdentifier.debug()); console.log("\nComputing Poseidon hash for external nullifier"); - const externalNullifier = rlnWasm.Hasher.poseidonHashPair( - epoch, - rlnIdentifier - ); + let externalNullifier; + try { + externalNullifier = rlnWasm.Hasher.poseidonHashPair(epoch, rlnIdentifier); + } catch (error) { + console.error("External nullifier hash error:", error); + return; + } console.log(" - external_nullifier = " + externalNullifier.debug()); console.log("\nCreating message_id"); @@ -182,8 +271,37 @@ async function main() { ); console.log("RLN Witness created successfully"); + console.log( + "\nWasmRLNWitnessInput serialization: WasmRLNWitnessInput <-> bytes" + ); + let serWitness; + try { + serWitness = witness.toBytesLE(); + } catch (error) { + console.error("Witness serialization error:", error); + return; + } + console.log( + " - serialized witness = [" + debugUint8Array(serWitness) + " ]" + ); + + let deserWitness; + try { + deserWitness = rlnWasm.WasmRLNWitnessInput.fromBytesLE(serWitness); + } catch (error) { + console.error("Witness deserialization error:", error); + return; + } + console.log(" - witness deserialized successfully"); + console.log("\nCalculating witness"); - const witnessJson = witness.toBigIntJson(); + let witnessJson; + try { + witnessJson = witness.toBigIntJson(); + } catch (error) { + console.error("Witness to BigInt JSON error:", error); + return; + } const calculatedWitness = await calculateWitness( circomPath, witnessJson, @@ -192,10 +310,16 @@ async function main() { console.log("Witness calculated successfully"); console.log("\nGenerating RLN Proof"); - const rln_proof = rlnInstance.generateRLNProofWithWitness( - calculatedWitness, - witness - ); + let rln_proof; + try { + rln_proof = rlnInstance.generateRLNProofWithWitness( + calculatedWitness, + witness + ); + } catch (error) { + console.error("Proof generation error:", error); + return; + } console.log("Proof generated successfully"); console.log("\nGetting proof values"); @@ -209,10 +333,22 @@ async function main() { ); console.log("\nRLNProof serialization: RLNProof <-> bytes"); - const serProof = rln_proof.toBytesLE(); + let serProof; + try { + serProof = rln_proof.toBytesLE(); + } catch (error) { + console.error("Proof serialization error:", error); + return; + } console.log(" - serialized proof = [" + debugUint8Array(serProof) + " ]"); - const deserProof = rlnWasm.WasmRLNProof.fromBytesLE(serProof); + let deserProof; + try { + deserProof = rlnWasm.WasmRLNProof.fromBytesLE(serProof); + } catch (error) { + console.error("Proof deserialization error:", error); + return; + } console.log(" - proof deserialized successfully"); console.log("\nRLNProofValues serialization: RLNProofValues <-> bytes"); @@ -221,8 +357,13 @@ async function main() { " - serialized proof_values = [" + debugUint8Array(serProofValues) + " ]" ); - const deserProofValues2 = - rlnWasm.WasmRLNProofValues.fromBytesLE(serProofValues); + let deserProofValues2; + try { + deserProofValues2 = rlnWasm.WasmRLNProofValues.fromBytesLE(serProofValues); + } catch (error) { + console.error("Proof values deserialization error:", error); + return; + } console.log(" - proof_values deserialized successfully"); console.log( " - deserialized external_nullifier = " + @@ -232,7 +373,13 @@ async function main() { console.log("\nVerifying Proof"); const roots = new rlnWasm.VecWasmFr(); roots.push(computedRoot); - const isValid = rlnInstance.verifyWithRoots(rln_proof, roots, x); + let isValid; + try { + isValid = rlnInstance.verifyWithRoots(rln_proof, roots, x); + } catch (error) { + console.error("Proof verification error:", error); + return; + } if (isValid) { console.log("Proof verified successfully"); } else { @@ -249,7 +396,13 @@ async function main() { 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]); - const x2 = rlnWasm.Hasher.hashToFieldLE(signal2); + let x2; + try { + x2 = rlnWasm.Hasher.hashToFieldLE(signal2); + } catch (error) { + console.error("Hash second signal error:", error); + return; + } console.log(" - x2 = " + x2.debug()); console.log("\nCreating second message with the same id"); @@ -269,7 +422,13 @@ async function main() { console.log("Second RLN Witness created successfully"); console.log("\nCalculating second witness"); - const witnessJson2 = witness2.toBigIntJson(); + let witnessJson2; + try { + witnessJson2 = witness2.toBigIntJson(); + } catch (error) { + console.error("Second witness to BigInt JSON error:", error); + return; + } const calculatedWitness2 = await calculateWitness( circomPath, witnessJson2, @@ -278,24 +437,42 @@ async function main() { console.log("Second witness calculated successfully"); console.log("\nGenerating second RLN Proof"); - const rln_proof2 = rlnInstance.generateRLNProofWithWitness( - calculatedWitness2, - witness2 - ); + let rln_proof2; + try { + rln_proof2 = rlnInstance.generateRLNProofWithWitness( + calculatedWitness2, + witness2 + ); + } catch (error) { + console.error("Second proof generation error:", error); + return; + } console.log("Second proof generated successfully"); console.log("\nVerifying second proof"); - const isValid2 = rlnInstance.verifyWithRoots(rln_proof2, roots, x2); + let isValid2; + try { + isValid2 = rlnInstance.verifyWithRoots(rln_proof2, roots, x2); + } catch (error) { + console.error("Proof verification error:", error); + return; + } if (isValid2) { console.log("Second proof verified successfully"); console.log("\nRecovering identity secret"); const proofValues1 = rln_proof.getValues(); const proofValues2 = rln_proof2.getValues(); - const recoveredSecret = rlnWasm.WasmRLNProofValues.recoverIdSecret( - proofValues1, - proofValues2 - ); + let recoveredSecret; + try { + recoveredSecret = rlnWasm.WasmRLNProofValues.recoverIdSecret( + proofValues1, + proofValues2 + ); + } catch (error) { + console.error("Identity recovery error:", error); + return; + } console.log(" - recovered_secret = " + recoveredSecret.debug()); console.log(" - original_secret = " + identitySecret.debug()); console.log("Slashing successful: Identity is recovered!"); diff --git a/rln-wasm/src/lib.rs b/rln-wasm/src/lib.rs index 3e8a111..3377c20 100644 --- a/rln-wasm/src/lib.rs +++ b/rln-wasm/src/lib.rs @@ -1,7 +1,7 @@ #![cfg(target_arch = "wasm32")] -mod wasm_rln; -mod wasm_utils; +pub mod wasm_rln; +pub mod wasm_utils; #[cfg(all(feature = "parallel", not(feature = "utils")))] pub use wasm_bindgen_rayon::init_thread_pool; diff --git a/rln-wasm/src/wasm_rln.rs b/rln-wasm/src/wasm_rln.rs index 17566dd..ff71475 100644 --- a/rln-wasm/src/wasm_rln.rs +++ b/rln-wasm/src/wasm_rln.rs @@ -29,10 +29,18 @@ impl WasmRLN { let calculated_witness_bigint: Vec = calculated_witness .iter() .map(|js_bigint| { - let str_val = js_bigint.to_string(10).unwrap().as_string().unwrap(); - str_val.parse::().unwrap() + js_bigint + .to_string(10) + .ok() + .and_then(|js_str| js_str.as_string()) + .ok_or_else(|| "Failed to convert JsBigInt to string".to_string()) + .and_then(|str_val| { + str_val + .parse::() + .map_err(|err| format!("Failed to parse BigInt: {}", err)) + }) }) - .collect(); + .collect::, _>>()?; let (proof, proof_values) = self .0 @@ -76,26 +84,28 @@ impl WasmRLNProof { } #[wasm_bindgen(js_name = toBytesLE)] - pub fn to_bytes_le(&self) -> Uint8Array { - Uint8Array::from(&rln_proof_to_bytes_le(&self.0)[..]) + pub fn to_bytes_le(&self) -> Result { + let bytes = rln_proof_to_bytes_le(&self.0).map_err(|err| err.to_string())?; + Ok(Uint8Array::from(&bytes[..])) } #[wasm_bindgen(js_name = toBytesBE)] - pub fn to_bytes_be(&self) -> Uint8Array { - Uint8Array::from(&rln_proof_to_bytes_be(&self.0)[..]) + pub fn to_bytes_be(&self) -> Result { + let bytes = rln_proof_to_bytes_be(&self.0).map_err(|err| err.to_string())?; + Ok(Uint8Array::from(&bytes[..])) } #[wasm_bindgen(js_name = fromBytesLE)] pub fn from_bytes_le(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); - let (proof, _) = bytes_le_to_rln_proof(&bytes_vec).map_err(|e| e.to_string())?; + let (proof, _) = bytes_le_to_rln_proof(&bytes_vec).map_err(|err| err.to_string())?; Ok(WasmRLNProof(proof)) } #[wasm_bindgen(js_name = fromBytesBE)] pub fn from_bytes_be(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); - let (proof, _) = bytes_be_to_rln_proof(&bytes_vec).map_err(|e| e.to_string())?; + let (proof, _) = bytes_be_to_rln_proof(&bytes_vec).map_err(|err| err.to_string())?; Ok(WasmRLNProof(proof)) } } @@ -144,7 +154,7 @@ impl WasmRLNProofValues { pub fn from_bytes_le(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); let (proof_values, _) = - bytes_le_to_rln_proof_values(&bytes_vec).map_err(|e| e.to_string())?; + bytes_le_to_rln_proof_values(&bytes_vec).map_err(|err| err.to_string())?; Ok(WasmRLNProofValues(proof_values)) } @@ -152,7 +162,7 @@ impl WasmRLNProofValues { pub fn from_bytes_be(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); let (proof_values, _) = - bytes_be_to_rln_proof_values(&bytes_vec).map_err(|e| e.to_string())?; + bytes_be_to_rln_proof_values(&bytes_vec).map_err(|err| err.to_string())?; Ok(WasmRLNProofValues(proof_values)) } diff --git a/rln-wasm/src/wasm_utils.rs b/rln-wasm/src/wasm_utils.rs index 3876e31..f0e8103 100644 --- a/rln-wasm/src/wasm_utils.rs +++ b/rln-wasm/src/wasm_utils.rs @@ -45,14 +45,14 @@ impl WasmFr { #[wasm_bindgen(js_name = fromBytesLE)] pub fn from_bytes_le(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); - let (fr, _) = bytes_le_to_fr(&bytes_vec).map_err(|e| e.to_string())?; + let (fr, _) = bytes_le_to_fr(&bytes_vec).map_err(|err| err.to_string())?; Ok(Self(fr)) } #[wasm_bindgen(js_name = fromBytesBE)] pub fn from_bytes_be(bytes: &Uint8Array) -> Result { let bytes_vec = bytes.to_vec(); - let (fr, _) = bytes_be_to_fr(&bytes_vec).map_err(|e| e.to_string())?; + let (fr, _) = bytes_be_to_fr(&bytes_vec).map_err(|err| err.to_string())?; Ok(Self(fr)) } @@ -194,18 +194,24 @@ pub struct Hasher; #[wasm_bindgen] impl Hasher { #[wasm_bindgen(js_name = hashToFieldLE)] - pub fn hash_to_field_le(input: &Uint8Array) -> WasmFr { - WasmFr(hash_to_field_le(&input.to_vec())) + pub fn hash_to_field_le(input: &Uint8Array) -> Result { + hash_to_field_le(&input.to_vec()) + .map(WasmFr) + .map_err(|err| err.to_string()) } #[wasm_bindgen(js_name = hashToFieldBE)] - pub fn hash_to_field_be(input: &Uint8Array) -> WasmFr { - WasmFr(hash_to_field_be(&input.to_vec())) + pub fn hash_to_field_be(input: &Uint8Array) -> Result { + hash_to_field_be(&input.to_vec()) + .map(WasmFr) + .map_err(|err| err.to_string()) } #[wasm_bindgen(js_name = poseidonHashPair)] - pub fn poseidon_hash_pair(a: &WasmFr, b: &WasmFr) -> WasmFr { - WasmFr(poseidon_hash(&[a.0, b.0])) + pub fn poseidon_hash_pair(a: &WasmFr, b: &WasmFr) -> Result { + poseidon_hash(&[a.0, b.0]) + .map(WasmFr) + .map_err(|err| err.to_string()) } } @@ -218,22 +224,23 @@ pub struct Identity { #[wasm_bindgen] impl Identity { #[wasm_bindgen(js_name = generate)] - pub fn generate() -> Identity { - let (identity_secret, id_commitment) = keygen(); - Identity { + pub fn generate() -> Result { + let (identity_secret, id_commitment) = keygen().map_err(|err| err.to_string())?; + Ok(Identity { identity_secret: *identity_secret, id_commitment, - } + }) } #[wasm_bindgen(js_name = generateSeeded)] - pub fn generate_seeded(seed: &Uint8Array) -> Identity { + pub fn generate_seeded(seed: &Uint8Array) -> Result { let seed_vec = seed.to_vec(); - let (identity_secret, id_commitment) = seeded_keygen(&seed_vec); - Identity { + let (identity_secret, id_commitment) = + seeded_keygen(&seed_vec).map_err(|err| err.to_string())?; + Ok(Identity { identity_secret, id_commitment, - } + }) } #[wasm_bindgen(js_name = getSecretHash)] @@ -250,6 +257,46 @@ impl Identity { pub fn to_array(&self) -> VecWasmFr { VecWasmFr(vec![self.identity_secret, self.id_commitment]) } + + #[wasm_bindgen(js_name = toBytesLE)] + pub fn to_bytes_le(&self) -> Uint8Array { + let vec_fr = vec![self.identity_secret, self.id_commitment]; + let bytes = vec_fr_to_bytes_le(&vec_fr); + Uint8Array::from(&bytes[..]) + } + + #[wasm_bindgen(js_name = toBytesBE)] + pub fn to_bytes_be(&self) -> Uint8Array { + let vec_fr = vec![self.identity_secret, self.id_commitment]; + let bytes = vec_fr_to_bytes_be(&vec_fr); + Uint8Array::from(&bytes[..]) + } + + #[wasm_bindgen(js_name = fromBytesLE)] + pub fn from_bytes_le(bytes: &Uint8Array) -> Result { + let bytes_vec = bytes.to_vec(); + let (vec_fr, _) = bytes_le_to_vec_fr(&bytes_vec).map_err(|err| err.to_string())?; + if vec_fr.len() != 2 { + return Err(format!("Expected 2 elements, got {}", vec_fr.len())); + } + Ok(Identity { + identity_secret: vec_fr[0], + id_commitment: vec_fr[1], + }) + } + + #[wasm_bindgen(js_name = fromBytesBE)] + pub fn from_bytes_be(bytes: &Uint8Array) -> Result { + let bytes_vec = bytes.to_vec(); + let (vec_fr, _) = bytes_be_to_vec_fr(&bytes_vec).map_err(|err| err.to_string())?; + if vec_fr.len() != 2 { + return Err(format!("Expected 2 elements, got {}", vec_fr.len())); + } + Ok(Identity { + identity_secret: vec_fr[0], + id_commitment: vec_fr[1], + }) + } } #[wasm_bindgen] @@ -263,28 +310,28 @@ pub struct ExtendedIdentity { #[wasm_bindgen] impl ExtendedIdentity { #[wasm_bindgen(js_name = generate)] - pub fn generate() -> ExtendedIdentity { + pub fn generate() -> Result { let (identity_trapdoor, identity_nullifier, identity_secret, id_commitment) = - extended_keygen(); - ExtendedIdentity { + extended_keygen().map_err(|err| err.to_string())?; + Ok(ExtendedIdentity { identity_trapdoor, identity_nullifier, identity_secret, id_commitment, - } + }) } #[wasm_bindgen(js_name = generateSeeded)] - pub fn generate_seeded(seed: &Uint8Array) -> ExtendedIdentity { + pub fn generate_seeded(seed: &Uint8Array) -> Result { let seed_vec = seed.to_vec(); let (identity_trapdoor, identity_nullifier, identity_secret, id_commitment) = - extended_seeded_keygen(&seed_vec); - ExtendedIdentity { + extended_seeded_keygen(&seed_vec).map_err(|err| err.to_string())?; + Ok(ExtendedIdentity { identity_trapdoor, identity_nullifier, identity_secret, id_commitment, - } + }) } #[wasm_bindgen(js_name = getTrapdoor)] @@ -316,4 +363,58 @@ impl ExtendedIdentity { self.id_commitment, ]) } + + #[wasm_bindgen(js_name = toBytesLE)] + pub fn to_bytes_le(&self) -> Uint8Array { + let vec_fr = vec![ + self.identity_trapdoor, + self.identity_nullifier, + self.identity_secret, + self.id_commitment, + ]; + let bytes = vec_fr_to_bytes_le(&vec_fr); + Uint8Array::from(&bytes[..]) + } + + #[wasm_bindgen(js_name = toBytesBE)] + pub fn to_bytes_be(&self) -> Uint8Array { + let vec_fr = vec![ + self.identity_trapdoor, + self.identity_nullifier, + self.identity_secret, + self.id_commitment, + ]; + let bytes = vec_fr_to_bytes_be(&vec_fr); + Uint8Array::from(&bytes[..]) + } + + #[wasm_bindgen(js_name = fromBytesLE)] + pub fn from_bytes_le(bytes: &Uint8Array) -> Result { + let bytes_vec = bytes.to_vec(); + let (vec_fr, _) = bytes_le_to_vec_fr(&bytes_vec).map_err(|err| err.to_string())?; + if vec_fr.len() != 4 { + return Err(format!("Expected 4 elements, got {}", vec_fr.len())); + } + Ok(ExtendedIdentity { + identity_trapdoor: vec_fr[0], + identity_nullifier: vec_fr[1], + identity_secret: vec_fr[2], + id_commitment: vec_fr[3], + }) + } + + #[wasm_bindgen(js_name = fromBytesBE)] + pub fn from_bytes_be(bytes: &Uint8Array) -> Result { + let bytes_vec = bytes.to_vec(); + let (vec_fr, _) = bytes_be_to_vec_fr(&bytes_vec).map_err(|err| err.to_string())?; + if vec_fr.len() != 4 { + return Err(format!("Expected 4 elements, got {}", vec_fr.len())); + } + Ok(ExtendedIdentity { + identity_trapdoor: vec_fr[0], + identity_nullifier: vec_fr[1], + identity_secret: vec_fr[2], + id_commitment: vec_fr[3], + }) + } } diff --git a/rln-wasm/tests/browser.rs b/rln-wasm/tests/browser.rs index 7db845d..9069fd4 100644 --- a/rln-wasm/tests/browser.rs +++ b/rln-wasm/tests/browser.rs @@ -10,7 +10,7 @@ mod test { }; use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; use wasm_bindgen_test::{console_log, wasm_bindgen_test, wasm_bindgen_test_configure}; - use zerokit_utils::{ + use zerokit_utils::merkle_tree::{ OptimalMerkleProof, OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree, }; #[cfg(feature = "parallel")] @@ -80,72 +80,64 @@ mod test { pub async fn rln_wasm_benchmark() { // Check if thread pool is supported #[cfg(feature = "parallel")] - if !isThreadpoolSupported().expect("Failed to check thread pool support") { + if !isThreadpoolSupported().unwrap() { panic!("Thread pool is NOT supported"); } else { // Initialize thread pool - let cpu_count = window() - .expect("Failed to get window") - .navigator() - .hardware_concurrency() as usize; - JsFuture::from(init_thread_pool(cpu_count)) - .await - .expect("Failed to initialize thread pool"); + let cpu_count = window().unwrap().navigator().hardware_concurrency() as usize; + JsFuture::from(init_thread_pool(cpu_count)).await.unwrap(); } // Initialize witness calculator - initWitnessCalculator(WITNESS_CALCULATOR_JS) - .expect("Failed to initialize witness calculator"); + initWitnessCalculator(WITNESS_CALCULATOR_JS).unwrap(); let mut results = String::from("\nbenchmarks:\n"); let iterations = 10; - let zkey = readFile(ARKZKEY_BYTES).expect("Failed to read zkey file"); + let zkey = readFile(ARKZKEY_BYTES).unwrap(); // Benchmark RLN instance creation let start_rln_new = Date::now(); for _ in 0..iterations { - let _ = WasmRLN::new(&zkey).expect("Failed to create RLN instance"); + let _ = WasmRLN::new(&zkey).unwrap(); } let rln_new_result = Date::now() - start_rln_new; // Create RLN instance for other benchmarks - let rln_instance = WasmRLN::new(&zkey).expect("Failed to create RLN instance"); + let rln_instance = WasmRLN::new(&zkey).unwrap(); let mut tree: OptimalMerkleTree = - OptimalMerkleTree::default(DEFAULT_TREE_DEPTH).expect("Failed to create tree"); + OptimalMerkleTree::default(DEFAULT_TREE_DEPTH).unwrap(); // Benchmark generate identity let start_identity_gen = Date::now(); for _ in 0..iterations { - let _ = Identity::generate(); + let _ = Identity::generate().unwrap(); } let identity_gen_result = Date::now() - start_identity_gen; // Generate identity for other benchmarks - let identity_pair = Identity::generate(); + let identity_pair = Identity::generate().unwrap(); let identity_secret = identity_pair.get_secret_hash(); let id_commitment = identity_pair.get_commitment(); - let epoch = Hasher::hash_to_field_le(&Uint8Array::from(b"test-epoch" as &[u8])); + let epoch = Hasher::hash_to_field_le(&Uint8Array::from(b"test-epoch" as &[u8])).unwrap(); let rln_identifier = - Hasher::hash_to_field_le(&Uint8Array::from(b"test-rln-identifier" as &[u8])); - let external_nullifier = Hasher::poseidon_hash_pair(&epoch, &rln_identifier); + Hasher::hash_to_field_le(&Uint8Array::from(b"test-rln-identifier" as &[u8])).unwrap(); + let external_nullifier = Hasher::poseidon_hash_pair(&epoch, &rln_identifier).unwrap(); let identity_index = tree.leaves_set(); let user_message_limit = WasmFr::from_uint(100); - let rate_commitment = Hasher::poseidon_hash_pair(&id_commitment, &user_message_limit); - tree.update_next(*rate_commitment) - .expect("Failed to update tree"); + let rate_commitment = + Hasher::poseidon_hash_pair(&id_commitment, &user_message_limit).unwrap(); + tree.update_next(*rate_commitment).unwrap(); let message_id = WasmFr::from_uint(0); let signal: [u8; 32] = [0; 32]; - let x = Hasher::hash_to_field_le(&Uint8Array::from(&signal[..])); + let x = Hasher::hash_to_field_le(&Uint8Array::from(&signal[..])).unwrap(); - let merkle_proof: OptimalMerkleProof = tree - .proof(identity_index) - .expect("Failed to generate merkle proof"); + let merkle_proof: OptimalMerkleProof = tree.proof(identity_index).unwrap(); let mut path_elements = VecWasmFr::new(); for path_element in merkle_proof.get_path_elements() { @@ -162,32 +154,30 @@ mod test { &x, &external_nullifier, ) - .expect("Failed to create WasmRLNWitnessInput"); + .unwrap(); - let bigint_json = witness - .to_bigint_json() - .expect("Failed to convert witness to BigInt JSON"); + let bigint_json = witness.to_bigint_json().unwrap(); // Benchmark witness calculation let start_calculate_witness = Date::now(); for _ in 0..iterations { let _ = calculateWitness(CIRCOM_BYTES, bigint_json.clone()) .await - .expect("Failed to calculate witness"); + .unwrap(); } let calculate_witness_result = Date::now() - start_calculate_witness; // Calculate witness for other benchmarks let calculated_witness_str = calculateWitness(CIRCOM_BYTES, bigint_json.clone()) .await - .expect("Failed to calculate witness") + .unwrap() .as_string() - .expect("Failed to convert calculated witness to string"); + .unwrap(); let calculated_witness_vec_str: Vec = - serde_json::from_str(&calculated_witness_str).expect("Failed to parse JSON"); + serde_json::from_str(&calculated_witness_str).unwrap(); let calculated_witness: Vec = calculated_witness_vec_str .iter() - .map(|x| JsBigInt::new(&x.into()).expect("Failed to create JsBigInt")) + .map(|x| JsBigInt::new(&x.into()).unwrap()) .collect(); // Benchmark proof generation with witness @@ -195,7 +185,7 @@ mod test { for _ in 0..iterations { let _ = rln_instance .generate_rln_proof_with_witness(calculated_witness.clone(), &witness) - .expect("Failed to generate proof"); + .unwrap(); } let generate_rln_proof_with_witness_result = Date::now() - start_generate_rln_proof_with_witness; @@ -203,7 +193,7 @@ mod test { // Generate proof with witness for other benchmarks let proof: WasmRLNProof = rln_instance .generate_rln_proof_with_witness(calculated_witness, &witness) - .expect("Failed to generate proof"); + .unwrap(); let root = WasmFr::from(tree.root()); let mut roots = VecWasmFr::new(); @@ -212,16 +202,12 @@ mod test { // Benchmark proof verification with the root let start_verify_with_roots = Date::now(); for _ in 0..iterations { - let _ = rln_instance - .verify_with_roots(&proof, &roots, &x) - .expect("Failed to verify proof"); + let _ = rln_instance.verify_with_roots(&proof, &roots, &x).unwrap(); } let verify_with_roots_result = Date::now() - start_verify_with_roots; // Verify proof with the root for other benchmarks - let is_proof_valid = rln_instance - .verify_with_roots(&proof, &roots, &x) - .expect("Failed to verify proof"); + let is_proof_valid = rln_instance.verify_with_roots(&proof, &roots, &x).unwrap(); assert!(is_proof_valid, "verification failed"); // Format and display the benchmark results diff --git a/rln-wasm/tests/node.rs b/rln-wasm/tests/node.rs index bc2ef1f..44b4f48 100644 --- a/rln-wasm/tests/node.rs +++ b/rln-wasm/tests/node.rs @@ -10,7 +10,7 @@ mod test { }; use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; use wasm_bindgen_test::{console_log, wasm_bindgen_test}; - use zerokit_utils::{ + use zerokit_utils::merkle_tree::{ OptimalMerkleProof, OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree, }; @@ -40,7 +40,7 @@ mod test { calculateWitness: async function (circom_path, inputs) { const wasmFile = fs.readFileSync(circom_path); - const wasmFileBuffer = wasmFile.slice( + const wasmFileBuffer = wasmFile.buffer.slice( wasmFile.byteOffset, wasmFile.byteOffset + wasmFile.byteLength ); @@ -75,58 +75,55 @@ mod test { #[wasm_bindgen_test] pub async fn rln_wasm_benchmark() { // Initialize witness calculator - initWitnessCalculator(WITNESS_CALCULATOR_JS) - .expect("Failed to initialize witness calculator"); + initWitnessCalculator(WITNESS_CALCULATOR_JS).unwrap(); let mut results = String::from("\nbenchmarks:\n"); let iterations = 10; - let zkey = readFile(ARKZKEY_PATH).expect("Failed to read zkey file"); + let zkey = readFile(ARKZKEY_PATH).unwrap(); // Benchmark RLN instance creation let start_rln_new = Date::now(); for _ in 0..iterations { - let _ = WasmRLN::new(&zkey).expect("Failed to create RLN instance"); + let _ = WasmRLN::new(&zkey).unwrap(); } let rln_new_result = Date::now() - start_rln_new; // Create RLN instance for other benchmarks - let rln_instance = WasmRLN::new(&zkey).expect("Failed to create RLN instance"); + let rln_instance = WasmRLN::new(&zkey).unwrap(); let mut tree: OptimalMerkleTree = - OptimalMerkleTree::default(DEFAULT_TREE_DEPTH).expect("Failed to create tree"); + OptimalMerkleTree::default(DEFAULT_TREE_DEPTH).unwrap(); // Benchmark generate identity let start_identity_gen = Date::now(); for _ in 0..iterations { - let _ = Identity::generate(); + let _ = Identity::generate().unwrap(); } let identity_gen_result = Date::now() - start_identity_gen; // Generate identity for other benchmarks - let identity_pair = Identity::generate(); + let identity_pair = Identity::generate().unwrap(); let identity_secret = identity_pair.get_secret_hash(); let id_commitment = identity_pair.get_commitment(); - let epoch = Hasher::hash_to_field_le(&Uint8Array::from(b"test-epoch" as &[u8])); + let epoch = Hasher::hash_to_field_le(&Uint8Array::from(b"test-epoch" as &[u8])).unwrap(); let rln_identifier = - Hasher::hash_to_field_le(&Uint8Array::from(b"test-rln-identifier" as &[u8])); - let external_nullifier = Hasher::poseidon_hash_pair(&epoch, &rln_identifier); + Hasher::hash_to_field_le(&Uint8Array::from(b"test-rln-identifier" as &[u8])).unwrap(); + let external_nullifier = Hasher::poseidon_hash_pair(&epoch, &rln_identifier).unwrap(); let identity_index = tree.leaves_set(); let user_message_limit = WasmFr::from_uint(100); - let rate_commitment = Hasher::poseidon_hash_pair(&id_commitment, &user_message_limit); - tree.update_next(*rate_commitment) - .expect("Failed to update tree"); + let rate_commitment = + Hasher::poseidon_hash_pair(&id_commitment, &user_message_limit).unwrap(); + tree.update_next(*rate_commitment).unwrap(); let message_id = WasmFr::from_uint(0); let signal: [u8; 32] = [0; 32]; - let x = Hasher::hash_to_field_le(&Uint8Array::from(&signal[..])); + let x = Hasher::hash_to_field_le(&Uint8Array::from(&signal[..])).unwrap(); - let merkle_proof: OptimalMerkleProof = tree - .proof(identity_index) - .expect("Failed to generate merkle proof"); + let merkle_proof: OptimalMerkleProof = tree.proof(identity_index).unwrap(); let mut path_elements = VecWasmFr::new(); for path_element in merkle_proof.get_path_elements() { @@ -143,32 +140,30 @@ mod test { &x, &external_nullifier, ) - .expect("Failed to create WasmRLNWitnessInput"); + .unwrap(); - let bigint_json = witness - .to_bigint_json() - .expect("Failed to convert witness to BigInt JSON"); + let bigint_json = witness.to_bigint_json().unwrap(); // Benchmark witness calculation let start_calculate_witness = Date::now(); for _ in 0..iterations { let _ = calculateWitness(CIRCOM_PATH, bigint_json.clone()) .await - .expect("Failed to calculate witness"); + .unwrap(); } let calculate_witness_result = Date::now() - start_calculate_witness; // Calculate witness for other benchmarks let calculated_witness_str = calculateWitness(CIRCOM_PATH, bigint_json.clone()) .await - .expect("Failed to calculate witness") + .unwrap() .as_string() - .expect("Failed to convert calculated witness to string"); + .unwrap(); let calculated_witness_vec_str: Vec = - serde_json::from_str(&calculated_witness_str).expect("Failed to parse JSON"); + serde_json::from_str(&calculated_witness_str).unwrap(); let calculated_witness: Vec = calculated_witness_vec_str .iter() - .map(|x| JsBigInt::new(&x.into()).expect("Failed to create JsBigInt")) + .map(|x| JsBigInt::new(&x.into()).unwrap()) .collect(); // Benchmark proof generation with witness @@ -176,7 +171,7 @@ mod test { for _ in 0..iterations { let _ = rln_instance .generate_rln_proof_with_witness(calculated_witness.clone(), &witness) - .expect("Failed to generate proof"); + .unwrap(); } let generate_rln_proof_with_witness_result = Date::now() - start_generate_rln_proof_with_witness; @@ -184,7 +179,7 @@ mod test { // Generate proof with witness for other benchmarks let proof: WasmRLNProof = rln_instance .generate_rln_proof_with_witness(calculated_witness, &witness) - .expect("Failed to generate proof"); + .unwrap(); let root = WasmFr::from(tree.root()); let mut roots = VecWasmFr::new(); @@ -193,16 +188,12 @@ mod test { // Benchmark proof verification with the root let start_verify_with_roots = Date::now(); for _ in 0..iterations { - let _ = rln_instance - .verify_with_roots(&proof, &roots, &x) - .expect("Failed to verify proof"); + let _ = rln_instance.verify_with_roots(&proof, &roots, &x).unwrap(); } let verify_with_roots_result = Date::now() - start_verify_with_roots; // Verify proof with the root for other benchmarks - let is_proof_valid = rln_instance - .verify_with_roots(&proof, &roots, &x) - .expect("Failed to verify proof"); + let is_proof_valid = rln_instance.verify_with_roots(&proof, &roots, &x).unwrap(); assert!(is_proof_valid, "verification failed"); // Format and display the benchmark results diff --git a/rln-wasm/tests/utils.rs b/rln-wasm/tests/utils.rs index 5827446..1d00f9e 100644 --- a/rln-wasm/tests/utils.rs +++ b/rln-wasm/tests/utils.rs @@ -13,7 +13,7 @@ mod test { #[wasm_bindgen_test] fn test_keygen_wasm() { - let identity = Identity::generate(); + let identity = Identity::generate().unwrap(); let identity_secret = *identity.get_secret_hash(); let id_commitment = *identity.get_commitment(); @@ -28,7 +28,7 @@ mod test { #[wasm_bindgen_test] fn test_extended_keygen_wasm() { - let identity = ExtendedIdentity::generate(); + let identity = ExtendedIdentity::generate().unwrap(); let identity_trapdoor = *identity.get_trapdoor(); let identity_nullifier = *identity.get_nullifier(); @@ -53,7 +53,7 @@ mod test { let seed_bytes: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let seed = Uint8Array::from(&seed_bytes[..]); - let identity = Identity::generate_seeded(&seed); + let identity = Identity::generate_seeded(&seed).unwrap(); let identity_secret = *identity.get_secret_hash(); let id_commitment = *identity.get_commitment(); @@ -77,7 +77,7 @@ mod test { let seed_bytes: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let seed = Uint8Array::from(&seed_bytes[..]); - let identity = ExtendedIdentity::generate_seeded(&seed); + let identity = ExtendedIdentity::generate_seeded(&seed).unwrap(); let identity_trapdoor = *identity.get_trapdoor(); let identity_nullifier = *identity.get_nullifier(); @@ -128,7 +128,7 @@ mod test { let wasmfr_debug_str = wasmfr_int.debug(); assert_eq!(wasmfr_debug_str.to_string(), "42"); - let identity = Identity::generate(); + let identity = Identity::generate().unwrap(); let mut id_secret_fr = *identity.get_secret_hash(); let id_secret_hash = IdSecret::from(&mut id_secret_fr); let id_commitment = *identity.get_commitment(); @@ -184,12 +184,12 @@ mod test { let signal_gen: [u8; 32] = rng.gen(); let signal = Uint8Array::from(&signal_gen[..]); - let wasmfr_le_1 = Hasher::hash_to_field_le(&signal); - let fr_le_2 = rln::hashers::hash_to_field_le(&signal_gen); + let wasmfr_le_1 = Hasher::hash_to_field_le(&signal).unwrap(); + let fr_le_2 = hash_to_field_le(&signal_gen).unwrap(); assert_eq!(*wasmfr_le_1, fr_le_2); - let wasmfr_be_1 = Hasher::hash_to_field_be(&signal); - let fr_be_2 = rln::hashers::hash_to_field_be(&signal_gen); + let wasmfr_be_1 = Hasher::hash_to_field_be(&signal).unwrap(); + let fr_be_2 = hash_to_field_be(&signal_gen).unwrap(); assert_eq!(*wasmfr_be_1, fr_be_2); assert_eq!(*wasmfr_le_1, *wasmfr_be_1); @@ -212,10 +212,10 @@ mod test { let input_1 = Fr::from(42u8); let input_2 = Fr::from(99u8); - let expected_hash = poseidon_hash(&[input_1, input_2]); + let expected_hash = poseidon_hash(&[input_1, input_2]).unwrap(); let wasmfr_1 = WasmFr::from_uint(42); let wasmfr_2 = WasmFr::from_uint(99); - let received_hash = Hasher::poseidon_hash_pair(&wasmfr_1, &wasmfr_2); + let received_hash = Hasher::poseidon_hash_pair(&wasmfr_1, &wasmfr_2).unwrap(); assert_eq!(*received_hash, expected_hash); } diff --git a/rln/Cargo.toml b/rln/Cargo.toml index 616ada1..e5e3bd2 100644 --- a/rln/Cargo.toml +++ b/rln/Cargo.toml @@ -46,7 +46,7 @@ ruint = { version = "1.17.0", default-features = false, features = [ tiny-keccak = { version = "2.0.2", features = ["keccak"] } zeroize = "1.8.2" tempfile = "3.23.0" -utils = { package = "zerokit_utils", version = "0.7.0", path = "../utils", default-features = false } +zerokit_utils = { version = "0.7.0", path = "../utils", default-features = false } # FFI safer-ffi.version = "0.1" @@ -67,17 +67,17 @@ default = ["parallel", "pmtree-ft"] stateless = [] parallel = [ "rayon", - "utils/parallel", "ark-ff/parallel", "ark-ec/parallel", "ark-std/parallel", "ark-poly/parallel", "ark-groth16/parallel", "ark-serialize/parallel", + "zerokit_utils/parallel", ] fullmerkletree = [] # Pre-allocated tree, fastest access optimalmerkletree = [] # Sparse storage, memory efficient -pmtree-ft = ["utils/pmtree-ft"] # Persistent storage, disk-based +pmtree-ft = ["zerokit_utils/pmtree-ft"] # Persistent storage, disk-based headers = ["safer-ffi/headers"] # Generate C header file with safer-ffi [[bench]] diff --git a/rln/benches/pmtree_benchmark.rs b/rln/benches/pmtree_benchmark.rs index 049193b..74814a5 100644 --- a/rln/benches/pmtree_benchmark.rs +++ b/rln/benches/pmtree_benchmark.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use rln::prelude::*; -use utils::ZerokitMerkleTree; +use zerokit_utils::merkle_tree::ZerokitMerkleTree; pub fn pmtree_benchmark(c: &mut Criterion) { let mut tree = PmTree::default(2).unwrap(); diff --git a/rln/benches/poseidon_tree_benchmark.rs b/rln/benches/poseidon_tree_benchmark.rs index bf52f67..7294897 100644 --- a/rln/benches/poseidon_tree_benchmark.rs +++ b/rln/benches/poseidon_tree_benchmark.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use rln::prelude::*; -use utils::{FullMerkleTree, OptimalMerkleTree, ZerokitMerkleTree}; +use zerokit_utils::merkle_tree::{FullMerkleTree, OptimalMerkleTree, ZerokitMerkleTree}; pub fn get_leaves(n: u32) -> Vec { (0..n).map(Fr::from).collect() diff --git a/rln/ffi_c_examples/main.c b/rln/ffi_c_examples/main.c index eb8d811..7c5b4db 100644 --- a/rln/ffi_c_examples/main.c +++ b/rln/ffi_c_examples/main.c @@ -26,7 +26,14 @@ int main(int argc, char const *const argv[]) printf("RLN instance created successfully\n"); printf("\nGenerating identity keys\n"); - Vec_CFr_t keys = ffi_key_gen(); + CResult_Vec_CFr_Vec_uint8_t keys_result = ffi_key_gen(); + if (keys_result.err.ptr) + { + fprintf(stderr, "Key generation error: %s\n", keys_result.err.ptr); + ffi_c_string_free(keys_result.err); + return EXIT_FAILURE; + } + Vec_CFr_t keys = keys_result.ok; const CFr_t *identity_secret = ffi_vec_cfr_get(&keys, 0); const CFr_t *id_commitment = ffi_vec_cfr_get(&keys, 1); printf("Identity generated\n"); @@ -47,7 +54,14 @@ int main(int argc, char const *const argv[]) ffi_c_string_free(debug); printf("\nComputing rate commitment\n"); - CFr_t *rate_commitment = ffi_poseidon_hash_pair(id_commitment, user_message_limit); + CResult_CFr_ptr_Vec_uint8_t rate_commitment_result = ffi_poseidon_hash_pair(id_commitment, user_message_limit); + if (!rate_commitment_result.ok) + { + fprintf(stderr, "Rate commitment hash error: %s\n", rate_commitment_result.err.ptr); + ffi_c_string_free(rate_commitment_result.err); + return EXIT_FAILURE; + } + CFr_t *rate_commitment = rate_commitment_result.ok; debug = ffi_cfr_debug(rate_commitment); printf(" - rate_commitment = %s\n", debug.ptr); @@ -92,7 +106,7 @@ int main(int argc, char const *const argv[]) } debug = ffi_vec_cfr_debug(&deser_keys_result.ok); - printf(" - deserialized identity_secret = %s\n", debug.ptr); + printf(" - deserialized keys = %s\n", debug.ptr); ffi_c_string_free(debug); Vec_CFr_t deser_keys = deser_keys_result.ok; @@ -108,10 +122,24 @@ int main(int argc, char const *const argv[]) CFr_t *default_leaf = ffi_cfr_zero(); CFr_t *default_hashes[TREE_DEPTH - 1]; - default_hashes[0] = ffi_poseidon_hash_pair(default_leaf, default_leaf); + CResult_CFr_ptr_Vec_uint8_t hash_result = ffi_poseidon_hash_pair(default_leaf, default_leaf); + if (!hash_result.ok) + { + fprintf(stderr, "Poseidon hash error: %s\n", hash_result.err.ptr); + ffi_c_string_free(hash_result.err); + return EXIT_FAILURE; + } + default_hashes[0] = hash_result.ok; for (size_t i = 1; i < TREE_DEPTH - 1; i++) { - default_hashes[i] = ffi_poseidon_hash_pair(default_hashes[i - 1], default_hashes[i - 1]); + hash_result = ffi_poseidon_hash_pair(default_hashes[i - 1], default_hashes[i - 1]); + if (!hash_result.ok) + { + fprintf(stderr, "Poseidon hash error: %s\n", hash_result.err.ptr); + ffi_c_string_free(hash_result.err); + return EXIT_FAILURE; + } + default_hashes[i] = hash_result.ok; } Vec_CFr_t path_elements = ffi_vec_cfr_new(TREE_DEPTH); @@ -177,10 +205,24 @@ int main(int argc, char const *const argv[]) printf("\nComputing Merkle root for stateless mode\n"); printf(" - computing root for index 0 with rate_commitment\n"); - CFr_t *computed_root = ffi_poseidon_hash_pair(rate_commitment, default_leaf); + CResult_CFr_ptr_Vec_uint8_t root_result = ffi_poseidon_hash_pair(rate_commitment, default_leaf); + if (!root_result.ok) + { + fprintf(stderr, "Poseidon hash error: %s\n", root_result.err.ptr); + ffi_c_string_free(root_result.err); + return EXIT_FAILURE; + } + CFr_t *computed_root = root_result.ok; for (size_t i = 1; i < TREE_DEPTH; i++) { - CFr_t *next_root = ffi_poseidon_hash_pair(computed_root, default_hashes[i - 1]); + root_result = ffi_poseidon_hash_pair(computed_root, default_hashes[i - 1]); + if (!root_result.ok) + { + fprintf(stderr, "Poseidon hash error: %s\n", root_result.err.ptr); + ffi_c_string_free(root_result.err); + return EXIT_FAILURE; + } + CFr_t *next_root = root_result.ok; ffi_cfr_free(computed_root); computed_root = next_root; } @@ -216,7 +258,14 @@ int main(int argc, char const *const argv[]) printf("\nHashing signal\n"); uint8_t signal[32] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; Vec_uint8_t signal_vec = {signal, 32, 32}; - CFr_t *x = ffi_hash_to_field_le(&signal_vec); + CResult_CFr_ptr_Vec_uint8_t x_result = ffi_hash_to_field_le(&signal_vec); + if (!x_result.ok) + { + fprintf(stderr, "Hash signal error: %s\n", x_result.err.ptr); + ffi_c_string_free(x_result.err); + return EXIT_FAILURE; + } + CFr_t *x = x_result.ok; debug = ffi_cfr_debug(x); printf(" - x = %s\n", debug.ptr); @@ -225,7 +274,14 @@ int main(int argc, char const *const argv[]) printf("\nHashing epoch\n"); const char *epoch_str = "test-epoch"; Vec_uint8_t epoch_vec = {(uint8_t *)epoch_str, strlen(epoch_str), strlen(epoch_str)}; - CFr_t *epoch = ffi_hash_to_field_le(&epoch_vec); + CResult_CFr_ptr_Vec_uint8_t epoch_result = ffi_hash_to_field_le(&epoch_vec); + if (!epoch_result.ok) + { + fprintf(stderr, "Hash epoch error: %s\n", epoch_result.err.ptr); + ffi_c_string_free(epoch_result.err); + return EXIT_FAILURE; + } + CFr_t *epoch = epoch_result.ok; debug = ffi_cfr_debug(epoch); printf(" - epoch = %s\n", debug.ptr); @@ -234,14 +290,28 @@ int main(int argc, char const *const argv[]) printf("\nHashing RLN identifier\n"); const char *rln_id_str = "test-rln-identifier"; Vec_uint8_t rln_id_vec = {(uint8_t *)rln_id_str, strlen(rln_id_str), strlen(rln_id_str)}; - CFr_t *rln_identifier = ffi_hash_to_field_le(&rln_id_vec); + CResult_CFr_ptr_Vec_uint8_t rln_identifier_result = ffi_hash_to_field_le(&rln_id_vec); + if (!rln_identifier_result.ok) + { + fprintf(stderr, "Hash RLN identifier error: %s\n", rln_identifier_result.err.ptr); + ffi_c_string_free(rln_identifier_result.err); + return EXIT_FAILURE; + } + CFr_t *rln_identifier = rln_identifier_result.ok; debug = ffi_cfr_debug(rln_identifier); printf(" - rln_identifier = %s\n", debug.ptr); ffi_c_string_free(debug); printf("\nComputing Poseidon hash for external nullifier\n"); - CFr_t *external_nullifier = ffi_poseidon_hash_pair(epoch, rln_identifier); + CResult_CFr_ptr_Vec_uint8_t external_nullifier_result = ffi_poseidon_hash_pair(epoch, rln_identifier); + if (!external_nullifier_result.ok) + { + fprintf(stderr, "External nullifier hash error: %s\n", external_nullifier_result.err.ptr); + ffi_c_string_free(external_nullifier_result.err); + return EXIT_FAILURE; + } + CFr_t *external_nullifier = external_nullifier_result.ok; debug = ffi_cfr_debug(external_nullifier); printf(" - external_nullifier = %s\n", debug.ptr); @@ -293,6 +363,34 @@ int main(int argc, char const *const argv[]) printf("RLN Witness created successfully\n"); #endif + printf("\nRLNWitnessInput serialization: RLNWitnessInput <-> bytes\n"); + CResult_Vec_uint8_Vec_uint8_t ser_witness_result = ffi_rln_witness_to_bytes_le(&witness); + if (ser_witness_result.err.ptr) + { + fprintf(stderr, "Witness serialization error: %s\n", ser_witness_result.err.ptr); + ffi_c_string_free(ser_witness_result.err); + return EXIT_FAILURE; + } + Vec_uint8_t ser_witness = ser_witness_result.ok; + + debug = ffi_vec_u8_debug(&ser_witness); + printf(" - serialized witness = %s\n", debug.ptr); + ffi_c_string_free(debug); + + CResult_FFI_RLNWitnessInput_ptr_Vec_uint8_t deser_witness_result = ffi_bytes_le_to_rln_witness(&ser_witness); + if (!deser_witness_result.ok) + { + fprintf(stderr, "Witness deserialization error: %s\n", deser_witness_result.err.ptr); + ffi_c_string_free(deser_witness_result.err); + return EXIT_FAILURE; + } + + FFI_RLNWitnessInput_t *deser_witness = deser_witness_result.ok; + printf(" - witness deserialized successfully\n"); + + ffi_rln_witness_input_free(deser_witness); + ffi_vec_u8_free(ser_witness); + printf("\nGenerating RLN Proof\n"); CResult_FFI_RLNProof_ptr_Vec_uint8_t proof_gen_result = ffi_generate_rln_proof( &rln, @@ -342,7 +440,14 @@ int main(int argc, char const *const argv[]) ffi_cfr_free(ext_nullifier); printf("\nRLNProof serialization: RLNProof <-> bytes\n"); - Vec_uint8_t ser_proof = ffi_rln_proof_to_bytes_le(&rln_proof); + CResult_Vec_uint8_Vec_uint8_t ser_proof_result = ffi_rln_proof_to_bytes_le(&rln_proof); + if (ser_proof_result.err.ptr) + { + fprintf(stderr, "Proof serialization error: %s\n", ser_proof_result.err.ptr); + ffi_c_string_free(ser_proof_result.err); + return EXIT_FAILURE; + } + Vec_uint8_t ser_proof = ser_proof_result.ok; debug = ffi_vec_u8_debug(&ser_proof); printf(" - serialized proof = %s\n", debug.ptr); @@ -411,7 +516,14 @@ int main(int argc, char const *const argv[]) printf("\nHashing second signal\n"); uint8_t signal2[32] = {11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; Vec_uint8_t signal2_vec = {signal2, 32, 32}; - CFr_t *x2 = ffi_hash_to_field_le(&signal2_vec); + CResult_CFr_ptr_Vec_uint8_t x2_result = ffi_hash_to_field_le(&signal2_vec); + if (!x2_result.ok) + { + fprintf(stderr, "Hash second signal error: %s\n", x2_result.err.ptr); + ffi_c_string_free(x2_result.err); + return EXIT_FAILURE; + } + CFr_t *x2 = x2_result.ok; debug = ffi_cfr_debug(x2); printf(" - x2 = %s\n", debug.ptr); diff --git a/rln/ffi_nim_examples/main.nim b/rln/ffi_nim_examples/main.nim index 0dea44c..ebb4bbb 100644 --- a/rln/ffi_nim_examples/main.nim +++ b/rln/ffi_nim_examples/main.nim @@ -140,22 +140,22 @@ proc ffi_vec_u8_free*(v: Vec_uint8) {.importc: "ffi_vec_u8_free", cdecl, dynlib: RLN_LIB.} # Hashing functions -proc ffi_hash_to_field_le*(input: ptr Vec_uint8): ptr CFr {.importc: "ffi_hash_to_field_le", +proc ffi_hash_to_field_le*(input: ptr Vec_uint8): CResultCFrPtrVecU8 {.importc: "ffi_hash_to_field_le", cdecl, dynlib: RLN_LIB.} -proc ffi_hash_to_field_be*(input: ptr Vec_uint8): ptr CFr {.importc: "ffi_hash_to_field_be", +proc ffi_hash_to_field_be*(input: ptr Vec_uint8): CResultCFrPtrVecU8 {.importc: "ffi_hash_to_field_be", cdecl, dynlib: RLN_LIB.} proc ffi_poseidon_hash_pair*(a: ptr CFr, - b: ptr CFr): ptr CFr {.importc: "ffi_poseidon_hash_pair", cdecl, + b: ptr CFr): CResultCFrPtrVecU8 {.importc: "ffi_poseidon_hash_pair", cdecl, dynlib: RLN_LIB.} # Keygen function -proc ffi_key_gen*(): Vec_CFr {.importc: "ffi_key_gen", cdecl, +proc ffi_key_gen*(): CResultVecCFrVecU8 {.importc: "ffi_key_gen", cdecl, dynlib: RLN_LIB.} -proc ffi_seeded_key_gen*(seed: ptr Vec_uint8): Vec_CFr {.importc: "ffi_seeded_key_gen", +proc ffi_seeded_key_gen*(seed: ptr Vec_uint8): CResultVecCFrVecU8 {.importc: "ffi_seeded_key_gen", cdecl, dynlib: RLN_LIB.} -proc ffi_extended_key_gen*(): Vec_CFr {.importc: "ffi_extended_key_gen", +proc ffi_extended_key_gen*(): CResultVecCFrVecU8 {.importc: "ffi_extended_key_gen", cdecl, dynlib: RLN_LIB.} -proc ffi_seeded_extended_key_gen*(seed: ptr Vec_uint8): Vec_CFr {.importc: "ffi_seeded_extended_key_gen", +proc ffi_seeded_extended_key_gen*(seed: ptr Vec_uint8): CResultVecCFrVecU8 {.importc: "ffi_seeded_extended_key_gen", cdecl, dynlib: RLN_LIB.} # RLN instance functions @@ -186,9 +186,9 @@ proc ffi_rln_witness_input_new*( external_nullifier: ptr CFr ): CResultWitnessInputPtrVecU8 {.importc: "ffi_rln_witness_input_new", cdecl, dynlib: RLN_LIB.} -proc ffi_rln_witness_to_bytes_le*(witness: ptr ptr FFI_RLNWitnessInput): Vec_uint8 {.importc: "ffi_rln_witness_to_bytes_le", +proc ffi_rln_witness_to_bytes_le*(witness: ptr ptr FFI_RLNWitnessInput): CResultVecU8VecU8 {.importc: "ffi_rln_witness_to_bytes_le", cdecl, dynlib: RLN_LIB.} -proc ffi_rln_witness_to_bytes_be*(witness: ptr ptr FFI_RLNWitnessInput): Vec_uint8 {.importc: "ffi_rln_witness_to_bytes_be", +proc ffi_rln_witness_to_bytes_be*(witness: ptr ptr FFI_RLNWitnessInput): CResultVecU8VecU8 {.importc: "ffi_rln_witness_to_bytes_be", cdecl, dynlib: RLN_LIB.} proc ffi_bytes_le_to_rln_witness*(bytes: ptr Vec_uint8): CResultWitnessInputPtrVecU8 {.importc: "ffi_bytes_le_to_rln_witness", cdecl, dynlib: RLN_LIB.} @@ -285,9 +285,9 @@ proc ffi_recover_id_secret*(proof_values_1: ptr ptr FFI_RLNProofValues, cdecl, dynlib: RLN_LIB.} # RLNProof serialization -proc ffi_rln_proof_to_bytes_le*(proof: ptr ptr FFI_RLNProof): Vec_uint8 {.importc: "ffi_rln_proof_to_bytes_le", +proc ffi_rln_proof_to_bytes_le*(proof: ptr ptr FFI_RLNProof): CResultVecU8VecU8 {.importc: "ffi_rln_proof_to_bytes_le", cdecl, dynlib: RLN_LIB.} -proc ffi_rln_proof_to_bytes_be*(proof: ptr ptr FFI_RLNProof): Vec_uint8 {.importc: "ffi_rln_proof_to_bytes_be", +proc ffi_rln_proof_to_bytes_be*(proof: ptr ptr FFI_RLNProof): CResultVecU8VecU8 {.importc: "ffi_rln_proof_to_bytes_be", cdecl, dynlib: RLN_LIB.} proc ffi_bytes_le_to_rln_proof*(bytes: ptr Vec_uint8): CResultProofPtrVecU8 {.importc: "ffi_bytes_le_to_rln_proof", cdecl, dynlib: RLN_LIB.} @@ -351,7 +351,13 @@ when isMainModule: echo "RLN instance created successfully" echo "\nGenerating identity keys" - var keys = ffi_key_gen() + var keysResult = ffi_key_gen() + if keysResult.err.dataPtr != nil: + let errMsg = asString(keysResult.err) + ffi_c_string_free(keysResult.err) + echo "Key generation error: ", errMsg + quit 1 + var keys = keysResult.ok let identitySecret = ffi_vec_cfr_get(addr keys, CSize(0)) let idCommitment = ffi_vec_cfr_get(addr keys, CSize(1)) echo "Identity generated" @@ -375,7 +381,13 @@ when isMainModule: ffi_c_string_free(debug) echo "\nComputing rate commitment" - let rateCommitment = ffi_poseidon_hash_pair(idCommitment, userMessageLimit) + let rateCommitmentResult = ffi_poseidon_hash_pair(idCommitment, userMessageLimit) + if rateCommitmentResult.ok.isNil: + let errMsg = asString(rateCommitmentResult.err) + ffi_c_string_free(rateCommitmentResult.err) + echo "Rate commitment hash error: ", errMsg + quit 1 + let rateCommitment = rateCommitmentResult.ok block: let debug = ffi_cfr_debug(rateCommitment) @@ -424,7 +436,7 @@ when isMainModule: block: var okKeys = deserKeysResult.ok let debug = ffi_vec_cfr_debug(addr okKeys) - echo " - deserialized identity_secret = ", asString(debug) + echo " - deserialized keys = ", asString(debug) ffi_c_string_free(debug) ffi_vec_cfr_free(deserKeysResult.ok) @@ -438,10 +450,22 @@ when isMainModule: let defaultLeaf = ffi_cfr_zero() var defaultHashes: array[treeDepth-1, ptr CFr] - defaultHashes[0] = ffi_poseidon_hash_pair(defaultLeaf, defaultLeaf) + block: + let hashResult = ffi_poseidon_hash_pair(defaultLeaf, defaultLeaf) + if hashResult.ok.isNil: + let errMsg = asString(hashResult.err) + ffi_c_string_free(hashResult.err) + echo "Poseidon hash error: ", errMsg + quit 1 + defaultHashes[0] = hashResult.ok for i in 1..treeDepth-2: - defaultHashes[i] = ffi_poseidon_hash_pair(defaultHashes[i-1], - defaultHashes[i-1]) + let hashResult = ffi_poseidon_hash_pair(defaultHashes[i-1], defaultHashes[i-1]) + if hashResult.ok.isNil: + let errMsg = asString(hashResult.err) + ffi_c_string_free(hashResult.err) + echo "Poseidon hash error: ", errMsg + quit 1 + defaultHashes[i] = hashResult.ok var pathElements = ffi_vec_cfr_new(CSize(treeDepth)) ffi_vec_cfr_push(addr pathElements, defaultLeaf) @@ -501,9 +525,21 @@ when isMainModule: echo "\nComputing Merkle root for stateless mode" echo " - computing root for index 0 with rate_commitment" - var computedRoot = ffi_poseidon_hash_pair(rateCommitment, defaultLeaf) + let rootResult = ffi_poseidon_hash_pair(rateCommitment, defaultLeaf) + if rootResult.ok.isNil: + let errMsg = asString(rootResult.err) + ffi_c_string_free(rootResult.err) + echo "Poseidon hash error: ", errMsg + quit 1 + var computedRoot = rootResult.ok for i in 1..treeDepth-1: - let next = ffi_poseidon_hash_pair(computedRoot, defaultHashes[i-1]) + let nextResult = ffi_poseidon_hash_pair(computedRoot, defaultHashes[i-1]) + if nextResult.ok.isNil: + let errMsg = asString(nextResult.err) + ffi_c_string_free(nextResult.err) + echo "Poseidon hash error: ", errMsg + quit 1 + let next = nextResult.ok ffi_cfr_free(computedRoot) computedRoot = next @@ -537,7 +573,12 @@ when isMainModule: 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] var signalVec = Vec_uint8(dataPtr: cast[ptr uint8](addr signal[0]), len: CSize(signal.len), cap: CSize(signal.len)) - let x = ffi_hash_to_field_be(addr signalVec) + let xResult = ffi_hash_to_field_be(addr signalVec) + if xResult.ok.isNil: + stderr.writeLine "Hash signal error: ", asString(xResult.err) + ffi_c_string_free(xResult.err) + quit 1 + let x = xResult.ok block: let debug = ffi_cfr_debug(x) @@ -549,7 +590,12 @@ when isMainModule: var epochBytes = newSeq[uint8](epochStr.len) for i in 0.. bytes" - var serWitness = ffi_rln_witness_to_bytes_be(addr witness) + let serWitnessResult = ffi_rln_witness_to_bytes_be(addr witness) + if serWitnessResult.err.dataPtr != nil: + stderr.writeLine "Witness serialization error: ", asString( + serWitnessResult.err) + ffi_c_string_free(serWitnessResult.err) + quit 1 + var serWitness = serWitnessResult.ok block: let debug = ffi_vec_u8_debug(addr serWitness) @@ -676,7 +740,12 @@ when isMainModule: ffi_cfr_free(extNullifier) echo "\nRLNProof serialization: RLNProof <-> bytes" - var serProof = ffi_rln_proof_to_bytes_be(addr proof) + let serProofResult = ffi_rln_proof_to_bytes_be(addr proof) + if serProofResult.err.dataPtr != nil: + stderr.writeLine "Proof serialization error: ", asString(serProofResult.err) + ffi_c_string_free(serProofResult.err) + quit 1 + var serProof = serProofResult.ok block: let debug = ffi_vec_u8_debug(addr serProof) @@ -747,7 +816,12 @@ when isMainModule: 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] var signal2Vec = Vec_uint8(dataPtr: cast[ptr uint8](addr signal2[0]), len: CSize(signal2.len), cap: CSize(signal2.len)) - let x2 = ffi_hash_to_field_be(addr signal2Vec) + let x2Result = ffi_hash_to_field_be(addr signal2Vec) + if x2Result.ok.isNil: + stderr.writeLine "Hash second signal error: ", asString(x2Result.err) + ffi_c_string_free(x2Result.err) + quit 1 + let x2 = x2Result.ok block: let debug = ffi_cfr_debug(x2) @@ -806,7 +880,7 @@ when isMainModule: let verifyErr2 = ffi_verify_rln_proof(addr rln, addr proof2, x2) if not verifyErr2.ok: - stderr.writeLine "Second proof verification error: ", asString( + stderr.writeLine "Proof verification error: ", asString( verifyErr2.err) ffi_c_string_free(verifyErr2.err) quit 1 diff --git a/rln/src/circuit/error.rs b/rln/src/circuit/error.rs index 82b726e..465b680 100644 --- a/rln/src/circuit/error.rs +++ b/rln/src/circuit/error.rs @@ -1,3 +1,4 @@ +/// Errors that can occur during zkey reading operations #[derive(Debug, thiserror::Error)] pub enum ZKeyReadError { #[error("Empty zkey bytes provided")] @@ -5,3 +6,20 @@ pub enum ZKeyReadError { #[error("{0}")] SerializationError(#[from] ark_serialize::SerializationError), } + +/// Errors that can occur during witness calculation +#[derive(Debug, thiserror::Error)] +pub enum WitnessCalcError { + #[error("Failed to deserialize witness calculation graph: {0}")] + GraphDeserialization(#[from] std::io::Error), + #[error("Failed to evaluate witness calculation graph: {0}")] + GraphEvaluation(String), + #[error("Invalid input length for '{name}': expected {expected}, got {actual}")] + InvalidInputLength { + name: String, + expected: usize, + actual: usize, + }, + #[error("Missing required input: {0}")] + MissingInput(String), +} diff --git a/rln/src/circuit/iden3calc.rs b/rln/src/circuit/iden3calc.rs index 82782db..b4fdcb3 100644 --- a/rln/src/circuit/iden3calc.rs +++ b/rln/src/circuit/iden3calc.rs @@ -1,9 +1,9 @@ // This crate is based on the code by iden3. Its preimage can be found here: // https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/src/lib.rs -pub mod graph; -pub mod proto; -pub mod storage; +mod graph; +mod proto; +mod storage; use std::collections::HashMap; @@ -12,17 +12,16 @@ use ruint::aliases::U256; use storage::deserialize_witnesscalc_graph; use zeroize::zeroize_flat_type; -use crate::{ - circuit::{iden3calc::graph::fr_to_u256, Fr}, - utils::FrOrSecret, -}; +use self::graph::fr_to_u256; +use super::{error::WitnessCalcError, Fr}; +use crate::utils::FrOrSecret; pub(crate) type InputSignalsInfo = HashMap; pub(crate) fn calc_witness)>>( inputs: I, graph_data: &[u8], -) -> Vec { +) -> Result, WitnessCalcError> { let mut inputs: HashMap> = inputs .into_iter() .map(|(key, value)| { @@ -40,11 +39,11 @@ pub(crate) fn calc_witness)>>( .collect(); let (nodes, signals, input_mapping): (Vec, Vec, InputSignalsInfo) = - deserialize_witnesscalc_graph(std::io::Cursor::new(graph_data)).unwrap(); + deserialize_witnesscalc_graph(std::io::Cursor::new(graph_data))?; let mut inputs_buffer = get_inputs_buffer(get_inputs_size(&nodes)); - populate_inputs(&inputs, &input_mapping, &mut inputs_buffer); + populate_inputs(&inputs, &input_mapping, &mut inputs_buffer)?; if let Some(v) = inputs.get_mut("identitySecret") { // DO NOT USE: unsafe { zeroize_flat_type(v) } only clears the Vec pointer, not the data—can cause memory leaks @@ -54,13 +53,14 @@ pub(crate) fn calc_witness)>>( } } - let res = graph::evaluate(&nodes, inputs_buffer.as_slice(), &signals); + let res = graph::evaluate(&nodes, inputs_buffer.as_slice(), &signals) + .map_err(WitnessCalcError::GraphEvaluation)?; for val in inputs_buffer.iter_mut() { unsafe { zeroize_flat_type(val) }; } - res + Ok(res) } fn get_inputs_size(nodes: &[Node]) -> usize { @@ -83,17 +83,26 @@ fn populate_inputs( input_list: &HashMap>, inputs_info: &InputSignalsInfo, input_buffer: &mut [U256], -) { +) -> Result<(), WitnessCalcError> { for (key, value) in input_list { - let (offset, len) = inputs_info[key]; - if len != value.len() { - panic!("Invalid input length for {key}"); + let (offset, len) = inputs_info + .get(key) + .ok_or_else(|| WitnessCalcError::MissingInput(key.clone()))?; + + if *len != value.len() { + return Err(WitnessCalcError::InvalidInputLength { + name: key.clone(), + expected: *len, + actual: value.len(), + }); } for (i, v) in value.iter().enumerate() { input_buffer[offset + i] = *v; } } + + Ok(()) } /// Allocates inputs vec with position 0 set to 1 diff --git a/rln/src/circuit/iden3calc/graph.rs b/rln/src/circuit/iden3calc/graph.rs index a87d330..a1e3b10 100644 --- a/rln/src/circuit/iden3calc/graph.rs +++ b/rln/src/circuit/iden3calc/graph.rs @@ -1,22 +1,17 @@ // This crate is based on the code by iden3. Its preimage can be found here: // https://github.com/iden3/circom-witnesscalc/blob/5cb365b6e4d9052ecc69d4567fcf5bc061c20e94/src/graph.rs -use std::{ - cmp::Ordering, - collections::HashMap, - error::Error, - ops::{Deref, Shl, Shr}, -}; +use std::cmp::Ordering; use ark_ff::{BigInt, BigInteger, One, PrimeField, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; -use rand::Rng; use ruint::{aliases::U256, uint}; use serde::{Deserialize, Serialize}; -use crate::circuit::{iden3calc::proto, Fr}; +use super::proto; +use crate::circuit::Fr; -pub const M: U256 = +const M: U256 = uint!(21888242871839275222246405745257275088548364400416034343698204186575808495617_U256); fn ark_se(a: &A, s: S) -> Result @@ -39,17 +34,18 @@ where } #[inline(always)] -pub fn fr_to_u256(x: &Fr) -> U256 { +pub(crate) fn fr_to_u256(x: &Fr) -> U256 { U256::from_limbs(x.into_bigint().0) } #[inline(always)] -pub fn u256_to_fr(x: &U256) -> Fr { - Fr::from_bigint(BigInt::new(x.into_limbs())).expect("Failed to convert U256 to Fr") +pub(crate) fn u256_to_fr(x: &U256) -> Result { + Fr::from_bigint(BigInt::new(x.into_limbs())) + .ok_or_else(|| "Failed to convert U256 to Fr".to_string()) } #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] -pub enum Operation { +pub(crate) enum Operation { Mul, Div, Add, @@ -73,113 +69,76 @@ pub enum Operation { } impl Operation { - // TODO: rewrite to &U256 type - pub fn eval(&self, a: U256, b: U256) -> U256 { + fn eval_fr(&self, a: Fr, b: Fr) -> Result { use Operation::*; match self { - Mul => a.mul_mod(b, M), - Div => { - if b == U256::ZERO { - // as we are simulating a circuit execution with signals - // values all equal to 0, just return 0 here in case of - // division by zero - U256::ZERO - } else { - a.mul_mod(b.inv_mod(M).unwrap(), M) - } - } - Add => a.add_mod(b, M), - Sub => a.add_mod(M - b, M), - Pow => a.pow_mod(b, M), - Mod => a.div_rem(b).1, - Eq => U256::from(a == b), - Neq => U256::from(a != b), - Lt => u_lt(&a, &b), - Gt => u_gt(&a, &b), - Leq => u_lte(&a, &b), - Geq => u_gte(&a, &b), - Land => U256::from(a != U256::ZERO && b != U256::ZERO), - Lor => U256::from(a != U256::ZERO || b != U256::ZERO), - Shl => compute_shl_uint(a, b), - Shr => compute_shr_uint(a, b), - // TODO test with conner case when it is possible to get the number - // bigger then modulus - Bor => a.bitor(b), - Band => a.bitand(b), - // TODO test with conner case when it is possible to get the number - // bigger then modulus - Bxor => a.bitxor(b), - Idiv => a / b, - } - } - - pub fn eval_fr(&self, a: Fr, b: Fr) -> Fr { - use Operation::*; - match self { - Mul => a * b, + Mul => Ok(a * b), // We always should return something on the circuit execution. // So in case of division by 0 we would return 0. And the proof // should be invalid in the end. Div => { if b.is_zero() { - Fr::zero() + Ok(Fr::zero()) } else { - a / b + Ok(a / b) } } - Add => a + b, - Sub => a - b, + Add => Ok(a + b), + Sub => Ok(a - b), + // Modular exponentiation to prevent overflow and keep result in field + Pow => { + let a_u256 = fr_to_u256(&a); + let b_u256 = fr_to_u256(&b); + let result = a_u256.pow_mod(b_u256, M); + u256_to_fr(&result) + } + // Integer division (not field division) Idiv => { if b.is_zero() { - Fr::zero() + Ok(Fr::zero()) } else { let a_u256 = fr_to_u256(&a); let b_u256 = fr_to_u256(&b); u256_to_fr(&(a_u256 / b_u256)) } } + // Integer modulo (not field arithmetic) Mod => { if b.is_zero() { - Fr::zero() + Ok(Fr::zero()) } else { let a_u256 = fr_to_u256(&a); let b_u256 = fr_to_u256(&b); u256_to_fr(&(a_u256 % b_u256)) } } - Eq => match a.cmp(&b) { + Eq => Ok(match a.cmp(&b) { Ordering::Equal => Fr::one(), _ => Fr::zero(), - }, - Neq => match a.cmp(&b) { + }), + Neq => Ok(match a.cmp(&b) { Ordering::Equal => Fr::zero(), _ => Fr::one(), - }, + }), Lt => u256_to_fr(&u_lt(&fr_to_u256(&a), &fr_to_u256(&b))), Gt => u256_to_fr(&u_gt(&fr_to_u256(&a), &fr_to_u256(&b))), Leq => u256_to_fr(&u_lte(&fr_to_u256(&a), &fr_to_u256(&b))), Geq => u256_to_fr(&u_gte(&fr_to_u256(&a), &fr_to_u256(&b))), - Land => { - if a.is_zero() || b.is_zero() { - Fr::zero() - } else { - Fr::one() - } - } - Lor => { - if a.is_zero() && b.is_zero() { - Fr::zero() - } else { - Fr::one() - } - } + Land => Ok(if a.is_zero() || b.is_zero() { + Fr::zero() + } else { + Fr::one() + }), + Lor => Ok(if a.is_zero() && b.is_zero() { + Fr::zero() + } else { + Fr::one() + }), Shl => shl(a, b), Shr => shr(a, b), Bor => bit_or(a, b), Band => bit_and(a, b), Bxor => bit_xor(a, b), - // TODO implement other operators - _ => unimplemented!("operator {:?} not implemented for Montgomery", self), } } } @@ -212,37 +171,27 @@ impl From<&Operation> for proto::DuoOp { } #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] -pub enum UnoOperation { +pub(crate) enum UnoOperation { Neg, Id, // identity - just return self } impl UnoOperation { - pub fn eval(&self, a: U256) -> U256 { - match self { - UnoOperation::Neg => { - if a == U256::ZERO { - U256::ZERO - } else { - M - a - } - } - UnoOperation::Id => a, - } - } - - pub fn eval_fr(&self, a: Fr) -> Fr { + fn eval_fr(&self, a: Fr) -> Result { match self { UnoOperation::Neg => { if a.is_zero() { - Fr::zero() + Ok(Fr::zero()) } else { let mut x = Fr::MODULUS; x.sub_with_borrow(&a.into_bigint()); - Fr::from_bigint(x).unwrap() + Fr::from_bigint(x).ok_or_else(|| "Failed to compute negation".to_string()) } } - _ => unimplemented!("uno operator {:?} not implemented for Montgomery", self), + _ => Err(format!( + "uno operator {:?} not implemented for Montgomery", + self + )), } } } @@ -257,30 +206,18 @@ impl From<&UnoOperation> for proto::UnoOp { } #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] -pub enum TresOperation { +pub(crate) enum TresOperation { TernCond, } impl TresOperation { - pub fn eval(&self, a: U256, b: U256, c: U256) -> U256 { - match self { - TresOperation::TernCond => { - if a == U256::ZERO { - c - } else { - b - } - } - } - } - - pub fn eval_fr(&self, a: Fr, b: Fr, c: Fr) -> Fr { + fn eval_fr(&self, a: Fr, b: Fr, c: Fr) -> Result { match self { TresOperation::TernCond => { if a.is_zero() { - c + Ok(c) } else { - b + Ok(b) } } } @@ -296,7 +233,7 @@ impl From<&TresOperation> for proto::TresOp { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub enum Node { +pub(crate) enum Node { Input(usize), Constant(U256), #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] @@ -306,133 +243,21 @@ pub enum Node { TresOp(TresOperation, usize, usize, usize), } -// TODO remove pub from Vec -#[derive(Default)] -pub struct Nodes(pub Vec); - -impl Nodes { - pub fn new() -> Self { - Nodes(Vec::new()) - } - - pub fn to_const(&self, idx: NodeIdx) -> Result { - let me = self.0.get(idx.0).ok_or(NodeConstErr::EmptyNode(idx))?; - match me { - Node::Constant(v) => Ok(*v), - Node::UnoOp(op, a) => Ok(op.eval(self.to_const(NodeIdx(*a))?)), - Node::Op(op, a, b) => { - Ok(op.eval(self.to_const(NodeIdx(*a))?, self.to_const(NodeIdx(*b))?)) - } - Node::TresOp(op, a, b, c) => Ok(op.eval( - self.to_const(NodeIdx(*a))?, - self.to_const(NodeIdx(*b))?, - self.to_const(NodeIdx(*c))?, - )), - Node::Input(_) => Err(NodeConstErr::InputSignal), - Node::MontConstant(_) => { - panic!("MontConstant should not be used here") - } - } - } - - pub fn push(&mut self, n: Node) -> NodeIdx { - self.0.push(n); - NodeIdx(self.0.len() - 1) - } - - pub fn get(&self, idx: NodeIdx) -> Option<&Node> { - self.0.get(idx.0) - } -} - -impl Deref for Nodes { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[derive(Debug, Copy, Clone)] -pub struct NodeIdx(pub usize); - -impl From for NodeIdx { - fn from(v: usize) -> Self { - NodeIdx(v) - } -} - -#[derive(Debug)] -pub enum NodeConstErr { - EmptyNode(NodeIdx), - InputSignal, -} - -impl std::fmt::Display for NodeConstErr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - NodeConstErr::EmptyNode(idx) => { - write!(f, "empty node at index {}", idx.0) - } - NodeConstErr::InputSignal => { - write!(f, "input signal is not a constant") - } - } - } -} - -impl Error for NodeConstErr {} - -fn compute_shl_uint(a: U256, b: U256) -> U256 { - debug_assert!(b.lt(&U256::from(256))); - let ls_limb = b.as_limbs()[0]; - a.shl(ls_limb as usize) -} - -fn compute_shr_uint(a: U256, b: U256) -> U256 { - debug_assert!(b.lt(&U256::from(256))); - let ls_limb = b.as_limbs()[0]; - a.shr(ls_limb as usize) -} - -/// All references must be backwards. -fn assert_valid(nodes: &[Node]) { - for (i, &node) in nodes.iter().enumerate() { - if let Node::Op(_, a, b) = node { - assert!(a < i); - assert!(b < i); - } else if let Node::UnoOp(_, a) = node { - assert!(a < i); - } else if let Node::TresOp(_, a, b, c) = node { - assert!(a < i); - assert!(b < i); - assert!(c < i); - } - } -} - -pub fn optimize(nodes: &mut Vec, outputs: &mut [usize]) { - tree_shake(nodes, outputs); - propagate(nodes); - value_numbering(nodes, outputs); - constants(nodes); - tree_shake(nodes, outputs); - montgomery_form(nodes); -} - -pub fn evaluate(nodes: &[Node], inputs: &[U256], outputs: &[usize]) -> Vec { - // assert_valid(nodes); - +pub(crate) fn evaluate( + nodes: &[Node], + inputs: &[U256], + outputs: &[usize], +) -> Result, String> { // Evaluate the graph. let mut values = Vec::with_capacity(nodes.len()); for &node in nodes.iter() { let value = match node { - Node::Constant(c) => u256_to_fr(&c), + Node::Constant(c) => u256_to_fr(&c)?, Node::MontConstant(c) => c, - Node::Input(i) => u256_to_fr(&inputs[i]), - Node::Op(op, a, b) => op.eval_fr(values[a], values[b]), - Node::UnoOp(op, a) => op.eval_fr(values[a]), - Node::TresOp(op, a, b, c) => op.eval_fr(values[a], values[b], values[c]), + Node::Input(i) => u256_to_fr(&inputs[i])?, + Node::Op(op, a, b) => op.eval_fr(values[a], values[b])?, + Node::UnoOp(op, a) => op.eval_fr(values[a])?, + Node::TresOp(op, a, b, c) => op.eval_fr(values[a], values[b], values[c])?, }; values.push(value); } @@ -443,246 +268,31 @@ pub fn evaluate(nodes: &[Node], inputs: &[U256], outputs: &[usize]) -> Vec { out[i] = values[outputs[i]]; } - out + Ok(out) } -/// Constant propagation -pub fn propagate(nodes: &mut [Node]) { - assert_valid(nodes); - for i in 0..nodes.len() { - if let Node::Op(op, a, b) = nodes[i] { - if let (Node::Constant(va), Node::Constant(vb)) = (nodes[a], nodes[b]) { - nodes[i] = Node::Constant(op.eval(va, vb)); - } else if a == b { - // Not constant but equal - use Operation::*; - if let Some(c) = match op { - Eq | Leq | Geq => Some(true), - Neq | Lt | Gt => Some(false), - _ => None, - } { - nodes[i] = Node::Constant(U256::from(c)); - } - } - } else if let Node::UnoOp(op, a) = nodes[i] { - if let Node::Constant(va) = nodes[a] { - nodes[i] = Node::Constant(op.eval(va)); - } - } else if let Node::TresOp(op, a, b, c) = nodes[i] { - if let (Node::Constant(va), Node::Constant(vb), Node::Constant(vc)) = - (nodes[a], nodes[b], nodes[c]) - { - nodes[i] = Node::Constant(op.eval(va, vb, vc)); - } - } - } -} - -/// Remove unused nodes -pub fn tree_shake(nodes: &mut Vec, outputs: &mut [usize]) { - assert_valid(nodes); - - // Mark all nodes that are used. - let mut used = vec![false; nodes.len()]; - for &i in outputs.iter() { - used[i] = true; - } - - // Work backwards from end as all references are backwards. - for i in (0..nodes.len()).rev() { - if used[i] { - if let Node::Op(_, a, b) = nodes[i] { - used[a] = true; - used[b] = true; - } - if let Node::UnoOp(_, a) = nodes[i] { - used[a] = true; - } - if let Node::TresOp(_, a, b, c) = nodes[i] { - used[a] = true; - used[b] = true; - used[c] = true; - } - } - } - - // Remove unused nodes - let n = nodes.len(); - let mut retain = used.iter(); - nodes.retain(|_| *retain.next().unwrap()); - - // Renumber references. - let mut renumber = vec![None; n]; - let mut index = 0; - for (i, &used) in used.iter().enumerate() { - if used { - renumber[i] = Some(index); - index += 1; - } - } - assert_eq!(index, nodes.len()); - for (&used, renumber) in used.iter().zip(renumber.iter()) { - assert_eq!(used, renumber.is_some()); - } - - // Renumber references. - for node in nodes.iter_mut() { - if let Node::Op(_, a, b) = node { - *a = renumber[*a].unwrap(); - *b = renumber[*b].unwrap(); - } - if let Node::UnoOp(_, a) = node { - *a = renumber[*a].unwrap(); - } - if let Node::TresOp(_, a, b, c) = node { - *a = renumber[*a].unwrap(); - *b = renumber[*b].unwrap(); - *c = renumber[*c].unwrap(); - } - } - for output in outputs.iter_mut() { - *output = renumber[*output].unwrap(); - } -} - -/// Randomly evaluate the graph -fn random_eval(nodes: &mut [Node]) -> Vec { - let mut rng = rand::thread_rng(); - let mut values = Vec::with_capacity(nodes.len()); - let mut inputs = HashMap::new(); - let mut prfs = HashMap::new(); - let mut prfs_uno = HashMap::new(); - let mut prfs_tres = HashMap::new(); - for node in nodes.iter() { - use Operation::*; - let value = match node { - // Constants evaluate to themselves - Node::Constant(c) => *c, - - Node::MontConstant(_) => unimplemented!("should not be used"), - - // Algebraic Ops are evaluated directly - // Since the field is large, by Swartz-Zippel if - // two values are the same then they are likely algebraically equal. - Node::Op(op @ (Add | Sub | Mul), a, b) => op.eval(values[*a], values[*b]), - - // Input and non-algebraic ops are random functions - // TODO: https://github.com/recmo/uint/issues/95 and use .gen_range(..M) - Node::Input(i) => *inputs.entry(*i).or_insert_with(|| rng.gen::() % M), - Node::Op(op, a, b) => *prfs - .entry((*op, values[*a], values[*b])) - .or_insert_with(|| rng.gen::() % M), - Node::UnoOp(op, a) => *prfs_uno - .entry((*op, values[*a])) - .or_insert_with(|| rng.gen::() % M), - Node::TresOp(op, a, b, c) => *prfs_tres - .entry((*op, values[*a], values[*b], values[*c])) - .or_insert_with(|| rng.gen::() % M), - }; - values.push(value); - } - values -} - -/// Value numbering -pub fn value_numbering(nodes: &mut [Node], outputs: &mut [usize]) { - assert_valid(nodes); - - // Evaluate the graph in random field elements. - let values = random_eval(nodes); - - // Find all nodes with the same value. - let mut value_map = HashMap::new(); - for (i, &value) in values.iter().enumerate() { - value_map.entry(value).or_insert_with(Vec::new).push(i); - } - - // For nodes that are the same, pick the first index. - let renumber: Vec<_> = values.into_iter().map(|v| value_map[&v][0]).collect(); - - // Renumber references. - for node in nodes.iter_mut() { - if let Node::Op(_, a, b) = node { - *a = renumber[*a]; - *b = renumber[*b]; - } - if let Node::UnoOp(_, a) = node { - *a = renumber[*a]; - } - if let Node::TresOp(_, a, b, c) = node { - *a = renumber[*a]; - *b = renumber[*b]; - *c = renumber[*c]; - } - } - for output in outputs.iter_mut() { - *output = renumber[*output]; - } -} - -/// Probabilistic constant determination -pub fn constants(nodes: &mut [Node]) { - assert_valid(nodes); - - // Evaluate the graph in random field elements. - let values_a = random_eval(nodes); - let values_b = random_eval(nodes); - - // Find all nodes with the same value. - for i in 0..nodes.len() { - if let Node::Constant(_) = nodes[i] { - continue; - } - if values_a[i] == values_b[i] { - nodes[i] = Node::Constant(values_a[i]); - } - } -} - -/// Convert to Montgomery form -pub fn montgomery_form(nodes: &mut [Node]) { - for node in nodes.iter_mut() { - use Node::*; - use Operation::*; - match node { - Constant(c) => *node = MontConstant(u256_to_fr(c)), - MontConstant(..) => (), - Input(..) => (), - Op( - Mul | Div | Add | Sub | Idiv | Mod | Eq | Neq | Lt | Gt | Leq | Geq | Land | Lor - | Shl | Shr | Bor | Band | Bxor, - .., - ) => (), - Op(op @ Pow, ..) => unimplemented!("Operators Montgomery form: {:?}", op), - UnoOp(UnoOperation::Neg, ..) => (), - UnoOp(op, ..) => unimplemented!("Uno Operators Montgomery form: {:?}", op), - TresOp(TresOperation::TernCond, ..) => (), - } - } -} - -fn shl(a: Fr, b: Fr) -> Fr { +fn shl(a: Fr, b: Fr) -> Result { if b.is_zero() { - return a; + return Ok(a); } if b.cmp(&Fr::from(Fr::MODULUS_BIT_SIZE)).is_ge() { - return Fr::zero(); + return Ok(Fr::zero()); } let n = b.into_bigint().0[0] as u32; let a = a.into_bigint(); - Fr::from_bigint(a << n).unwrap() + Fr::from_bigint(a << n).ok_or_else(|| "Failed to compute left shift".to_string()) } -fn shr(a: Fr, b: Fr) -> Fr { +fn shr(a: Fr, b: Fr) -> Result { if b.is_zero() { - return a; + return Ok(a); } match b.cmp(&Fr::from(254u64)) { - Ordering::Equal => return Fr::zero(), - Ordering::Greater => return Fr::zero(), + Ordering::Equal => return Ok(Fr::zero()), + Ordering::Greater => return Ok(Fr::zero()), _ => (), }; @@ -698,7 +308,7 @@ fn shr(a: Fr, b: Fr) -> Fr { } if n == 0 { - return Fr::from_bigint(result).unwrap(); + return Fr::from_bigint(result).ok_or_else(|| "Failed to compute right shift".to_string()); } let mask: u64 = (1 << n) - 1; @@ -709,10 +319,10 @@ fn shr(a: Fr, b: Fr) -> Fr { c[i] = (c[i] >> n) | (carrier << (64 - n)); carrier = new_carrier; } - Fr::from_bigint(result).unwrap() + Fr::from_bigint(result).ok_or_else(|| "Failed to compute right shift".to_string()) } -fn bit_and(a: Fr, b: Fr) -> Fr { +fn bit_and(a: Fr, b: Fr) -> Result { let a = a.into_bigint(); let b = b.into_bigint(); let c: [u64; 4] = [ @@ -726,10 +336,10 @@ fn bit_and(a: Fr, b: Fr) -> Fr { d.sub_with_borrow(&Fr::MODULUS); } - Fr::from_bigint(d).unwrap() + Fr::from_bigint(d).ok_or_else(|| "Failed to compute bitwise AND".to_string()) } -fn bit_or(a: Fr, b: Fr) -> Fr { +fn bit_or(a: Fr, b: Fr) -> Result { let a = a.into_bigint(); let b = b.into_bigint(); let c: [u64; 4] = [ @@ -743,10 +353,10 @@ fn bit_or(a: Fr, b: Fr) -> Fr { d.sub_with_borrow(&Fr::MODULUS); } - Fr::from_bigint(d).unwrap() + Fr::from_bigint(d).ok_or_else(|| "Failed to compute bitwise OR".to_string()) } -fn bit_xor(a: Fr, b: Fr) -> Fr { +fn bit_xor(a: Fr, b: Fr) -> Result { let a = a.into_bigint(); let b = b.into_bigint(); let c: [u64; 4] = [ @@ -760,7 +370,7 @@ fn bit_xor(a: Fr, b: Fr) -> Fr { d.sub_with_borrow(&Fr::MODULUS); } - Fr::from_bigint(d).unwrap() + Fr::from_bigint(d).ok_or_else(|| "Failed to compute bitwise XOR".to_string()) } // M / 2 @@ -827,14 +437,16 @@ mod test { fn test_ok() { let a = Fr::from(4u64); let b = Fr::from(2u64); - let c = shl(a, b); + let c = shl(a, b).unwrap(); assert_eq!(c.cmp(&Fr::from(16u64)), Ordering::Equal) } #[test] fn test_div() { assert_eq!( - Operation::Div.eval_fr(Fr::from(2u64), Fr::from(3u64)), + Operation::Div + .eval_fr(Fr::from(2u64), Fr::from(3u64)) + .unwrap(), Fr::from_str( "7296080957279758407415468581752425029516121466805344781232734728858602831873" ) @@ -842,12 +454,16 @@ mod test { ); assert_eq!( - Operation::Div.eval_fr(Fr::from(6u64), Fr::from(2u64)), + Operation::Div + .eval_fr(Fr::from(6u64), Fr::from(2u64)) + .unwrap(), Fr::from_str("3").unwrap() ); assert_eq!( - Operation::Div.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Operation::Div + .eval_fr(Fr::from(7u64), Fr::from(2u64)) + .unwrap(), Fr::from_str( "10944121435919637611123202872628637544274182200208017171849102093287904247812" ) @@ -858,17 +474,23 @@ mod test { #[test] fn test_idiv() { assert_eq!( - Operation::Idiv.eval_fr(Fr::from(2u64), Fr::from(3u64)), + Operation::Idiv + .eval_fr(Fr::from(2u64), Fr::from(3u64)) + .unwrap(), Fr::from_str("0").unwrap() ); assert_eq!( - Operation::Idiv.eval_fr(Fr::from(6u64), Fr::from(2u64)), + Operation::Idiv + .eval_fr(Fr::from(6u64), Fr::from(2u64)) + .unwrap(), Fr::from_str("3").unwrap() ); assert_eq!( - Operation::Idiv.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Operation::Idiv + .eval_fr(Fr::from(7u64), Fr::from(2u64)) + .unwrap(), Fr::from_str("3").unwrap() ); } @@ -876,12 +498,16 @@ mod test { #[test] fn test_fr_mod() { assert_eq!( - Operation::Mod.eval_fr(Fr::from(7u64), Fr::from(2u64)), + Operation::Mod + .eval_fr(Fr::from(7u64), Fr::from(2u64)) + .unwrap(), Fr::from_str("1").unwrap() ); assert_eq!( - Operation::Mod.eval_fr(Fr::from(7u64), Fr::from(9u64)), + Operation::Mod + .eval_fr(Fr::from(7u64), Fr::from(9u64)) + .unwrap(), Fr::from_str("7").unwrap() ); } diff --git a/rln/src/circuit/iden3calc/proto.rs b/rln/src/circuit/iden3calc/proto.rs index 9a71367..302cb9c 100644 --- a/rln/src/circuit/iden3calc/proto.rs +++ b/rln/src/circuit/iden3calc/proto.rs @@ -5,29 +5,29 @@ use std::collections::HashMap; #[derive(Clone, PartialEq, prost::Message)] -pub struct BigUInt { +pub(crate) struct BigUInt { #[prost(bytes = "vec", tag = "1")] pub value_le: Vec, } #[derive(Clone, Copy, PartialEq, prost::Message)] -pub struct InputNode { +pub(crate) struct InputNode { #[prost(uint32, tag = "1")] pub idx: u32, } #[derive(Clone, PartialEq, prost::Message)] -pub struct ConstantNode { +pub(crate) struct ConstantNode { #[prost(message, optional, tag = "1")] pub value: Option, } #[derive(Clone, Copy, PartialEq, prost::Message)] -pub struct UnoOpNode { +pub(crate) struct UnoOpNode { #[prost(enumeration = "UnoOp", tag = "1")] pub op: i32, #[prost(uint32, tag = "2")] pub a_idx: u32, } #[derive(Clone, Copy, PartialEq, prost::Message)] -pub struct DuoOpNode { +pub(crate) struct DuoOpNode { #[prost(enumeration = "DuoOp", tag = "1")] pub op: i32, #[prost(uint32, tag = "2")] @@ -36,7 +36,7 @@ pub struct DuoOpNode { pub b_idx: u32, } #[derive(Clone, Copy, PartialEq, prost::Message)] -pub struct TresOpNode { +pub(crate) struct TresOpNode { #[prost(enumeration = "TresOp", tag = "1")] pub op: i32, #[prost(uint32, tag = "2")] @@ -47,14 +47,14 @@ pub struct TresOpNode { pub c_idx: u32, } #[derive(Clone, PartialEq, prost::Message)] -pub struct Node { +pub(crate) struct Node { #[prost(oneof = "node::Node", tags = "1, 2, 3, 4, 5")] pub node: Option, } /// Nested message and enum types in `Node`. -pub mod node { +pub(crate) mod node { #[derive(Clone, PartialEq, prost::Oneof)] - pub enum Node { + pub(crate) enum Node { #[prost(message, tag = "1")] Input(super::InputNode), #[prost(message, tag = "2")] @@ -68,21 +68,21 @@ pub mod node { } } #[derive(Clone, Copy, PartialEq, prost::Message)] -pub struct SignalDescription { +pub(crate) struct SignalDescription { #[prost(uint32, tag = "1")] pub offset: u32, #[prost(uint32, tag = "2")] pub len: u32, } #[derive(Clone, PartialEq, prost::Message)] -pub struct GraphMetadata { +pub(crate) struct GraphMetadata { #[prost(uint32, repeated, tag = "1")] pub witness_signals: Vec, #[prost(map = "string, message", tag = "2")] pub inputs: HashMap, } #[derive(Clone, Copy, Debug, PartialEq, prost::Enumeration)] -pub enum DuoOp { +pub(crate) enum DuoOp { Mul = 0, Div = 1, Add = 2, @@ -106,12 +106,12 @@ pub enum DuoOp { } #[derive(Clone, Copy, Debug, PartialEq, prost::Enumeration)] -pub enum UnoOp { +pub(crate) enum UnoOp { Neg = 0, Id = 1, } #[derive(Clone, Copy, Debug, PartialEq, prost::Enumeration)] -pub enum TresOp { +pub(crate) enum TresOp { TernCond = 0, } diff --git a/rln/src/circuit/iden3calc/storage.rs b/rln/src/circuit/iden3calc/storage.rs index e965533..d864b50 100644 --- a/rln/src/circuit/iden3calc/storage.rs +++ b/rln/src/circuit/iden3calc/storage.rs @@ -7,54 +7,80 @@ use ark_ff::PrimeField; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use prost::Message; -use crate::circuit::{ - iden3calc::{ - graph, - graph::{Operation, TresOperation, UnoOperation}, - proto, InputSignalsInfo, - }, - Fr, +use super::{ + graph::{self, Operation, TresOperation, UnoOperation}, + proto, InputSignalsInfo, }; +use crate::circuit::Fr; -// format of the wtns.graph file: -// + magic line: wtns.graph.001 -// + 4 bytes unsigned LE 32-bit integer: number of nodes -// + series of protobuf serialized nodes. Each node prefixed by varint length -// + protobuf serialized GraphMetadata -// + 8 bytes unsigned LE 64-bit integer: offset of GraphMetadata message - +/// Format of the wtns.graph file: +/// + magic line: wtns.graph.001 +/// + 4 bytes unsigned LE 32-bit integer: number of nodes +/// + series of protobuf serialized nodes. Each node prefixed by varint length +/// + protobuf serialized GraphMetadata +/// + 8 bytes unsigned LE 64-bit integer: offset of GraphMetadata message const WITNESSCALC_GRAPH_MAGIC: &[u8] = b"wtns.graph.001"; const MAX_VARINT_LENGTH: usize = 10; -impl From for graph::Node { - fn from(value: proto::Node) -> Self { - match value.node.unwrap() { - proto::node::Node::Input(input_node) => graph::Node::Input(input_node.idx as usize), +impl TryFrom for graph::Node { + type Error = std::io::Error; + + fn try_from(value: proto::Node) -> Result { + let node = value.node.ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Proto::Node must have a node field", + ) + })?; + match node { + proto::node::Node::Input(input_node) => Ok(graph::Node::Input(input_node.idx as usize)), proto::node::Node::Constant(constant_node) => { - let i = constant_node.value.unwrap(); - graph::Node::MontConstant(Fr::from_le_bytes_mod_order(i.value_le.as_slice())) + let i = constant_node.value.ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Constant node must have a value", + ) + })?; + Ok(graph::Node::MontConstant(Fr::from_le_bytes_mod_order( + i.value_le.as_slice(), + ))) } proto::node::Node::UnoOp(uno_op_node) => { - let op = proto::UnoOp::try_from(uno_op_node.op).unwrap(); - graph::Node::UnoOp(op.into(), uno_op_node.a_idx as usize) + let op = proto::UnoOp::try_from(uno_op_node.op).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "UnoOp must be valid enum value", + ) + })?; + Ok(graph::Node::UnoOp(op.into(), uno_op_node.a_idx as usize)) } proto::node::Node::DuoOp(duo_op_node) => { - let op = proto::DuoOp::try_from(duo_op_node.op).unwrap(); - graph::Node::Op( + let op = proto::DuoOp::try_from(duo_op_node.op).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "DuoOp must be valid enum value", + ) + })?; + Ok(graph::Node::Op( op.into(), duo_op_node.a_idx as usize, duo_op_node.b_idx as usize, - ) + )) } proto::node::Node::TresOp(tres_op_node) => { - let op = proto::TresOp::try_from(tres_op_node.op).unwrap(); - graph::Node::TresOp( + let op = proto::TresOp::try_from(tres_op_node.op).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "TresOp must be valid enum value", + ) + })?; + Ok(graph::Node::TresOp( op.into(), tres_op_node.a_idx as usize, tres_op_node.b_idx as usize, tres_op_node.c_idx as usize, - ) + )) } } } @@ -140,14 +166,15 @@ impl From for graph::TresOperation { } } -pub fn serialize_witnesscalc_graph( +#[allow(dead_code)] +pub(crate) fn serialize_witnesscalc_graph( mut w: T, nodes: &Vec, witness_signals: &[usize], input_signals: &InputSignalsInfo, ) -> std::io::Result<()> { let mut ptr = 0usize; - w.write_all(WITNESSCALC_GRAPH_MAGIC).unwrap(); + w.write_all(WITNESSCALC_GRAPH_MAGIC)?; ptr += WITNESSCALC_GRAPH_MAGIC.len(); w.write_u64::(nodes.len() as u64)?; @@ -235,7 +262,7 @@ fn read_message( Ok(msg) } -pub fn deserialize_witnesscalc_graph( +pub(crate) fn deserialize_witnesscalc_graph( r: impl Read, ) -> std::io::Result<(Vec, Vec, InputSignalsInfo)> { let mut br = WriteBackReader::new(r); @@ -254,8 +281,7 @@ pub fn deserialize_witnesscalc_graph( let mut nodes = Vec::with_capacity(nodes_num as usize); for _ in 0..nodes_num { let n: proto::Node = read_message(&mut br)?; - let n2: graph::Node = n.into(); - nodes.push(n2); + nodes.push(n.try_into()?); } let md: proto::GraphMetadata = read_message(&mut br)?; diff --git a/rln/src/circuit/mod.rs b/rln/src/circuit/mod.rs index 51302b5..0b4cb31 100644 --- a/rln/src/circuit/mod.rs +++ b/rln/src/circuit/mod.rs @@ -1,8 +1,8 @@ // This crate provides interfaces for the zero-knowledge circuit and keys -pub mod error; -pub mod iden3calc; -pub mod qap; +pub(crate) mod error; +pub(crate) mod iden3calc; +pub(crate) mod qap; #[cfg(not(target_arch = "wasm32"))] use std::sync::LazyLock; @@ -18,7 +18,7 @@ use ark_groth16::{ use ark_relations::r1cs::ConstraintMatrices; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use crate::circuit::error::ZKeyReadError; +use self::error::ZKeyReadError; #[cfg(not(target_arch = "wasm32"))] const GRAPH_BYTES: &[u8] = include_bytes!("../../resources/tree_depth_20/graph.bin"); @@ -28,7 +28,7 @@ const ARKZKEY_BYTES: &[u8] = include_bytes!("../../resources/tree_depth_20/rln_f #[cfg(not(target_arch = "wasm32"))] static ARKZKEY: LazyLock = LazyLock::new(|| { - read_arkzkey_from_bytes_uncompressed(ARKZKEY_BYTES).expect("Failed to read arkzkey") + read_arkzkey_from_bytes_uncompressed(ARKZKEY_BYTES).expect("Default zkey must be valid") }); pub const DEFAULT_TREE_DEPTH: usize = 20; diff --git a/rln/src/error.rs b/rln/src/error.rs index 5966723..27595a7 100644 --- a/rln/src/error.rs +++ b/rln/src/error.rs @@ -3,30 +3,37 @@ use std::{array::TryFromSliceError, num::TryFromIntError}; use ark_relations::r1cs::SynthesisError; use num_bigint::{BigInt, ParseBigIntError}; use thiserror::Error; -use utils::error::{FromConfigError, ZerokitMerkleTreeError}; +use zerokit_utils::error::{FromConfigError, HashError, ZerokitMerkleTreeError}; -use crate::circuit::{error::ZKeyReadError, Fr}; +use crate::circuit::{ + error::{WitnessCalcError, ZKeyReadError}, + Fr, +}; +/// Errors that can occur during RLN utility operations (conversions, parsing, etc.) #[derive(Debug, thiserror::Error)] pub enum UtilsError { #[error("Expected radix 10 or 16")] WrongRadix, - #[error("{0}")] + #[error("Failed to parse big integer: {0}")] ParseBigInt(#[from] ParseBigIntError), - #[error("{0}")] + #[error("Failed to convert to usize: {0}")] ToUsize(#[from] TryFromIntError), - #[error("{0}")] + #[error("Failed to convert from slice: {0}")] FromSlice(#[from] TryFromSliceError), #[error("Input data too short: expected at least {expected} bytes, got {actual} bytes")] InsufficientData { expected: usize, actual: usize }, } +/// Errors that can occur during RLN protocol operations (proof generation, verification, etc.) #[derive(Debug, thiserror::Error)] pub enum ProtocolError { #[error("Error producing proof: {0}")] Synthesis(#[from] SynthesisError), - #[error("{0}")] + #[error("RLN utility error: {0}")] Utils(#[from] UtilsError), + #[error("Error calculating witness: {0}")] + WitnessCalc(#[from] WitnessCalcError), #[error("Expected to read {0} bytes but read only {1} bytes")] InvalidReadLen(usize, usize), #[error("Cannot convert bigint {0:?} to biguint")] @@ -39,8 +46,15 @@ pub enum ProtocolError { ExternalNullifierMismatch(Fr, Fr), #[error("Cannot recover secret: division by zero")] DivisionByZero, + #[error("Merkle tree operation error: {0}")] + MerkleTree(#[from] ZerokitMerkleTreeError), + #[error("Hash computation error: {0}")] + Hash(#[from] HashError), + #[error("Proof serialization error: {0}")] + SerializationError(#[from] ark_serialize::SerializationError), } +/// Errors that can occur during proof verification #[derive(Error, Debug)] pub enum VerifyError { #[error("Invalid proof provided")] @@ -51,16 +65,19 @@ pub enum VerifyError { InvalidSignal, } +/// Top-level RLN error type encompassing all RLN operations #[derive(Debug, thiserror::Error)] pub enum RLNError { - #[error("Config error: {0}")] + #[error("Configuration error: {0}")] Config(#[from] FromConfigError), #[error("Merkle tree error: {0}")] MerkleTree(#[from] ZerokitMerkleTreeError), + #[error("Hash error: {0}")] + Hash(#[from] HashError), #[error("ZKey error: {0}")] ZKey(#[from] ZKeyReadError), #[error("Protocol error: {0}")] Protocol(#[from] ProtocolError), - #[error("Verify error: {0}")] + #[error("Verification error: {0}")] Verify(#[from] VerifyError), } diff --git a/rln/src/ffi/ffi_rln.rs b/rln/src/ffi/ffi_rln.rs index bcedac1..66a7356 100644 --- a/rln/src/ffi/ffi_rln.rs +++ b/rln/src/ffi/ffi_rln.rs @@ -152,13 +152,35 @@ pub fn ffi_rln_proof_get_values( } #[ffi_export] -pub fn ffi_rln_proof_to_bytes_le(rln_proof: &repr_c::Box) -> repr_c::Vec { - rln_proof_to_bytes_le(&rln_proof.0).into() +pub fn ffi_rln_proof_to_bytes_le( + rln_proof: &repr_c::Box, +) -> CResult, repr_c::String> { + match rln_proof_to_bytes_le(&rln_proof.0) { + Ok(bytes) => CResult { + ok: Some(bytes.into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(err.to_string().into()), + }, + } } #[ffi_export] -pub fn ffi_rln_proof_to_bytes_be(rln_proof: &repr_c::Box) -> repr_c::Vec { - rln_proof_to_bytes_be(&rln_proof.0).into() +pub fn ffi_rln_proof_to_bytes_be( + rln_proof: &repr_c::Box, +) -> CResult, repr_c::String> { + match rln_proof_to_bytes_be(&rln_proof.0) { + Ok(bytes) => CResult { + ok: Some(bytes.into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(err.to_string().into()), + }, + } } #[ffi_export] @@ -376,9 +398,9 @@ pub fn ffi_bytes_le_to_rln_proof_values( ok: Some(Box_::new(FFI_RLNProofValues(pv))), err: None, }, - Err(e) => CResult { + Err(err) => CResult { ok: None, - err: Some(format!("{:?}", e).into()), + err: Some(format!("{:?}", err).into()), }, } } @@ -392,9 +414,9 @@ pub fn ffi_bytes_be_to_rln_proof_values( ok: Some(Box_::new(FFI_RLNProofValues(pv))), err: None, }, - Err(e) => CResult { + Err(err) => CResult { ok: None, - err: Some(format!("{:?}", e).into()), + err: Some(format!("{:?}", err).into()), }, } } diff --git a/rln/src/ffi/ffi_utils.rs b/rln/src/ffi/ffi_utils.rs index 9ff1991..5310192 100644 --- a/rln/src/ffi/ffi_utils.rs +++ b/rln/src/ffi/ffi_utils.rs @@ -93,9 +93,9 @@ pub fn ffi_bytes_le_to_cfr(bytes: &repr_c::Vec) -> CResult, ok: Some(CFr(cfr).into()), err: None, }, - Err(e) => CResult { + Err(err) => CResult { ok: None, - err: Some(format!("{:?}", e).into()), + err: Some(format!("{:?}", err).into()), }, } } @@ -107,9 +107,9 @@ pub fn ffi_bytes_be_to_cfr(bytes: &repr_c::Vec) -> CResult, ok: Some(CFr(cfr).into()), err: None, }, - Err(e) => CResult { + Err(err) => CResult { ok: None, - err: Some(format!("{:?}", e).into()), + err: Some(format!("{:?}", err).into()), }, } } @@ -286,58 +286,119 @@ pub fn ffi_vec_u8_free(v: repr_c::Vec) { // Utility APIs #[ffi_export] -pub fn ffi_hash_to_field_le(input: &repr_c::Vec) -> repr_c::Box { - let hash_result = hash_to_field_le(input); - CFr::from(hash_result).into() +pub fn ffi_hash_to_field_le(input: &repr_c::Vec) -> CResult, repr_c::String> { + match hash_to_field_le(input) { + Ok(hash_result) => CResult { + ok: Some(CFr::from(hash_result).into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_hash_to_field_be(input: &repr_c::Vec) -> repr_c::Box { - let hash_result = hash_to_field_be(input); - CFr::from(hash_result).into() +pub fn ffi_hash_to_field_be(input: &repr_c::Vec) -> CResult, repr_c::String> { + match hash_to_field_be(input) { + Ok(hash_result) => CResult { + ok: Some(CFr::from(hash_result).into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_poseidon_hash_pair(a: &CFr, b: &CFr) -> repr_c::Box { - let hash_result = poseidon_hash(&[a.0, b.0]); - CFr::from(hash_result).into() +pub fn ffi_poseidon_hash_pair(a: &CFr, b: &CFr) -> CResult, repr_c::String> { + match poseidon_hash(&[a.0, b.0]) { + Ok(hash_result) => CResult { + ok: Some(CFr::from(hash_result).into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_key_gen() -> repr_c::Vec { - let (identity_secret, id_commitment) = keygen(); - vec![CFr(*identity_secret), CFr(id_commitment)].into() +pub fn ffi_key_gen() -> CResult, repr_c::String> { + match keygen() { + Ok((identity_secret, id_commitment)) => CResult { + ok: Some(vec![CFr(*identity_secret), CFr(id_commitment)].into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_seeded_key_gen(seed: &repr_c::Vec) -> repr_c::Vec { - let (identity_secret, id_commitment) = seeded_keygen(seed); - vec![CFr(identity_secret), CFr(id_commitment)].into() +pub fn ffi_seeded_key_gen(seed: &repr_c::Vec) -> CResult, repr_c::String> { + match seeded_keygen(seed) { + Ok((identity_secret, id_commitment)) => CResult { + ok: Some(vec![CFr(identity_secret), CFr(id_commitment)].into()), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_extended_key_gen() -> repr_c::Vec { - let (identity_trapdoor, identity_nullifier, identity_secret, id_commitment) = extended_keygen(); - vec![ - CFr(identity_trapdoor), - CFr(identity_nullifier), - CFr(identity_secret), - CFr(id_commitment), - ] - .into() +pub fn ffi_extended_key_gen() -> CResult, repr_c::String> { + match extended_keygen() { + Ok((identity_trapdoor, identity_nullifier, identity_secret, id_commitment)) => CResult { + ok: Some( + vec![ + CFr(identity_trapdoor), + CFr(identity_nullifier), + CFr(identity_secret), + CFr(id_commitment), + ] + .into(), + ), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] -pub fn ffi_seeded_extended_key_gen(seed: &repr_c::Vec) -> repr_c::Vec { - let (identity_trapdoor, identity_nullifier, identity_secret, id_commitment) = - extended_seeded_keygen(seed); - vec![ - CFr(identity_trapdoor), - CFr(identity_nullifier), - CFr(identity_secret), - CFr(id_commitment), - ] - .into() +pub fn ffi_seeded_extended_key_gen( + seed: &repr_c::Vec, +) -> CResult, repr_c::String> { + match extended_seeded_keygen(seed) { + Ok((identity_trapdoor, identity_nullifier, identity_secret, id_commitment)) => CResult { + ok: Some( + vec![ + CFr(identity_trapdoor), + CFr(identity_nullifier), + CFr(identity_secret), + CFr(id_commitment), + ] + .into(), + ), + err: None, + }, + Err(err) => CResult { + ok: None, + err: Some(format!("{:?}", err).into()), + }, + } } #[ffi_export] diff --git a/rln/src/hashers.rs b/rln/src/hashers.rs index fa67ca7..c805a08 100644 --- a/rln/src/hashers.rs +++ b/rln/src/hashers.rs @@ -2,10 +2,11 @@ use once_cell::sync::Lazy; use tiny_keccak::{Hasher, Keccak}; -use utils::poseidon::Poseidon; +use zerokit_utils::{error::HashError, poseidon::Poseidon}; use crate::{ circuit::Fr, + error::UtilsError, utils::{bytes_be_to_fr, bytes_le_to_fr}, }; @@ -26,10 +27,9 @@ const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [ /// Poseidon Hash wrapper over above implementation. static POSEIDON: Lazy> = Lazy::new(|| Poseidon::::from(&ROUND_PARAMS)); -pub fn poseidon_hash(input: &[Fr]) -> Fr { - POSEIDON - .hash(input) - .expect("hash with fixed input size can't fail") +pub fn poseidon_hash(input: &[Fr]) -> Result { + let hash = POSEIDON.hash(input)?; + Ok(hash) } /// The zerokit RLN Merkle tree Hasher. @@ -37,20 +37,21 @@ pub fn poseidon_hash(input: &[Fr]) -> Fr { pub struct PoseidonHash; /// The default Hasher trait used by Merkle tree implementation in utils. -impl utils::merkle_tree::Hasher for PoseidonHash { +impl zerokit_utils::merkle_tree::Hasher for PoseidonHash { type Fr = Fr; + type Error = HashError; fn default_leaf() -> Self::Fr { Self::Fr::from(0) } - fn hash(inputs: &[Self::Fr]) -> Self::Fr { + fn hash(inputs: &[Self::Fr]) -> Result { poseidon_hash(inputs) } } /// Hashes arbitrary signal to the underlying prime field. -pub fn hash_to_field_le(signal: &[u8]) -> Fr { +pub fn hash_to_field_le(signal: &[u8]) -> Result { // We hash the input signal using Keccak256 let mut hash = [0; 32]; let mut hasher = Keccak::v256(); @@ -58,12 +59,13 @@ pub fn hash_to_field_le(signal: &[u8]) -> Fr { hasher.finalize(&mut hash); // We export the hash as a field element - let (el, _) = bytes_le_to_fr(hash.as_ref()).expect("Keccak256 hash is always 32 bytes"); - el + let (el, _) = bytes_le_to_fr(hash.as_ref())?; + + Ok(el) } /// Hashes arbitrary signal to the underlying prime field. -pub fn hash_to_field_be(signal: &[u8]) -> Fr { +pub fn hash_to_field_be(signal: &[u8]) -> Result { // We hash the input signal using Keccak256 let mut hash = [0; 32]; let mut hasher = Keccak::v256(); @@ -74,6 +76,7 @@ pub fn hash_to_field_be(signal: &[u8]) -> Fr { hash.reverse(); // We export the hash as a field element - let (el, _) = bytes_be_to_fr(hash.as_ref()).expect("Keccak256 hash is always 32 bytes"); - el + let (el, _) = bytes_be_to_fr(hash.as_ref())?; + + Ok(el) } diff --git a/rln/src/pm_tree_adapter.rs b/rln/src/pm_tree_adapter.rs index 2bdc2c6..525b53c 100644 --- a/rln/src/pm_tree_adapter.rs +++ b/rln/src/pm_tree_adapter.rs @@ -4,11 +4,14 @@ use std::{fmt::Debug, path::PathBuf, str::FromStr}; use serde_json::Value; use tempfile::Builder; -use utils::{ +use zerokit_utils::{ error::{FromConfigError, ZerokitMerkleTreeError}, - pmtree, - pmtree::{tree::Key, Database, Hasher, PmtreeErrorKind}, - Config, Mode, SledDB, ZerokitMerkleProof, ZerokitMerkleTree, + merkle_tree::{ZerokitMerkleProof, ZerokitMerkleTree}, + pm_tree::{ + pmtree, + pmtree::{tree::Key, Database, Hasher, PmtreeErrorKind}, + Config, Mode, SledDB, + }, }; use crate::{ @@ -43,7 +46,8 @@ impl Hasher for PoseidonHash { } fn deserialize(value: pmtree::Value) -> Self::Fr { - let (fr, _) = bytes_le_to_fr(&value).expect("pmtree value should be valid Fr bytes"); + // TODO: allow to handle error properly in pmtree Hasher trait + let (fr, _) = bytes_le_to_fr(&value).expect("Fr deserialization must be valid"); fr } @@ -52,17 +56,17 @@ impl Hasher for PoseidonHash { } fn hash(inputs: &[Self::Fr]) -> Self::Fr { - poseidon_hash(inputs) + // TODO: allow to handle error properly in pmtree Hasher trait + poseidon_hash(inputs).expect("Poseidon hash must be valid") } } -fn default_tmp_path() -> PathBuf { - Builder::new() +fn default_tmp_path() -> Result { + Ok(Builder::new() .prefix("pmtree-") - .tempfile() - .expect("Failed to create temp file") + .tempfile()? .into_temp_path() - .to_path_buf() + .to_path_buf()) } const DEFAULT_TEMPORARY: bool = true; @@ -130,7 +134,7 @@ impl PmtreeConfigBuilder { pub fn build(self) -> Result { let path = match (self.temporary, self.path) { - (true, None) => default_tmp_path(), + (true, None) => default_tmp_path()?, (false, None) => return Err(FromConfigError::MissingPath), (true, Some(path)) if path.exists() => return Err(FromConfigError::PathExists), (_, Some(path)) => path, @@ -180,9 +184,10 @@ impl FromStr for PmtreeConfig { } } + let default_tmp_path = default_tmp_path()?; let config = Config::new() .temporary(temporary.unwrap_or(DEFAULT_TEMPORARY)) - .path(path.unwrap_or(default_tmp_path())) + .path(path.unwrap_or(default_tmp_path)) .cache_capacity(cache_capacity.unwrap_or(DEFAULT_CACHE_CAPACITY)) .flush_every_ms(flush_every_ms) .mode(mode) @@ -195,9 +200,10 @@ impl Default for PmtreeConfig { fn default() -> Self { Self::builder() .build() - .expect("Default configuration should never fail") + .expect("Default PmtreeConfig must be valid") } } + impl Debug for PmtreeConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) @@ -393,11 +399,8 @@ impl ZerokitMerkleTree for PmTree { // if empty, try searching the db let data = self.tree.db.get(METADATA_KEY)?; - if data.is_none() { - // send empty Metadata - return Ok(Vec::new()); - } - Ok(data.unwrap()) + // Return empty metadata if not found, otherwise return the data + Ok(data.unwrap_or_default()) } fn close_db_connection(&mut self) -> Result<(), ZerokitMerkleTreeError> { @@ -413,8 +416,13 @@ type FrOfPmTreeHasher = FrOf; impl PmTree { fn remove_indices(&mut self, indices: &[usize]) -> Result<(), PmtreeErrorKind> { + if indices.is_empty() { + return Err(PmtreeErrorKind::TreeError( + pmtree::TreeErrorKind::InvalidKey, + )); + } let start = indices[0]; - let end = indices.last().unwrap() + 1; + let end = indices[indices.len() - 1] + 1; let new_leaves = (start..end).map(|_| PmTreeHasher::default_leaf()); @@ -432,7 +440,12 @@ impl PmTree { leaves: Vec, indices: &[usize], ) -> Result<(), PmtreeErrorKind> { - let min_index = *indices.first().unwrap(); + if indices.is_empty() { + return Err(PmtreeErrorKind::TreeError( + pmtree::TreeErrorKind::InvalidKey, + )); + } + let min_index = indices[0]; let max_index = start + leaves.len(); let mut set_values = vec![PmTreeHasher::default_leaf(); max_index - min_index]; @@ -480,8 +493,12 @@ impl ZerokitMerkleProof for PmTreeProof { fn get_path_index(&self) -> Vec { self.proof.get_path_index() } - fn compute_root_from(&self, leaf: &FrOf) -> FrOf { - self.proof.compute_root_from(leaf) + + fn compute_root_from( + &self, + leaf: &FrOf, + ) -> Result, ZerokitMerkleTreeError> { + Ok(self.proof.compute_root_from(leaf)) } } @@ -501,15 +518,15 @@ mod test { "use_compression": false }"#; - let _: PmtreeConfig = json.parse().expect("Failed to parse JSON config"); + let _: PmtreeConfig = json.parse().unwrap(); let _ = PmtreeConfig::builder() - .path(default_tmp_path()) + .path(default_tmp_path().unwrap()) .temporary(DEFAULT_TEMPORARY) .cache_capacity(DEFAULT_CACHE_CAPACITY) .mode(DEFAULT_MODE) .use_compression(DEFAULT_USE_COMPRESSION) .build() - .expect("Failed to build config"); + .unwrap(); } } diff --git a/rln/src/poseidon_tree.rs b/rln/src/poseidon_tree.rs index 860bee3..39032fa 100644 --- a/rln/src/poseidon_tree.rs +++ b/rln/src/poseidon_tree.rs @@ -10,13 +10,13 @@ use cfg_if::cfg_if; cfg_if! { if #[cfg(feature = "fullmerkletree")] { - use utils::{FullMerkleTree, FullMerkleProof}; + use zerokit_utils::{FullMerkleTree, FullMerkleProof}; use crate::hashers::PoseidonHash; pub type PoseidonTree = FullMerkleTree; pub type MerkleProof = FullMerkleProof; } else if #[cfg(feature = "optimalmerkletree")] { - use utils::{OptimalMerkleTree, OptimalMerkleProof}; + use zerokit_utils::{OptimalMerkleTree, OptimalMerkleProof}; use crate::hashers::PoseidonHash; pub type PoseidonTree = OptimalMerkleTree; diff --git a/rln/src/prelude.rs b/rln/src/prelude.rs index ff880c5..572adda 100644 --- a/rln/src/prelude.rs +++ b/rln/src/prelude.rs @@ -1,8 +1,5 @@ // This module re-exports the most commonly used types and functions from the RLN library -#[cfg(not(feature = "stateless"))] -pub use utils::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree}; - #[cfg(not(target_arch = "wasm32"))] pub use crate::circuit::{graph_from_folder, zkey_from_folder}; #[cfg(feature = "pmtree-ft")] diff --git a/rln/src/protocol/keygen.rs b/rln/src/protocol/keygen.rs index 4e48263..b5c7783 100644 --- a/rln/src/protocol/keygen.rs +++ b/rln/src/protocol/keygen.rs @@ -2,17 +2,18 @@ use ark_std::{rand::thread_rng, UniformRand}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use tiny_keccak::{Hasher as _, Keccak}; +use zerokit_utils::error::ZerokitMerkleTreeError; use crate::{circuit::Fr, hashers::poseidon_hash, utils::IdSecret}; /// Generates a random RLN identity using a cryptographically secure RNG. /// /// Returns `(identity_secret, id_commitment)` where the commitment is `PoseidonHash(identity_secret)`. -pub fn keygen() -> (IdSecret, Fr) { +pub fn keygen() -> Result<(IdSecret, Fr), ZerokitMerkleTreeError> { let mut rng = thread_rng(); let identity_secret = IdSecret::rand(&mut rng); - let id_commitment = poseidon_hash(&[*identity_secret.clone()]); - (identity_secret, id_commitment) + let id_commitment = poseidon_hash(&[*identity_secret.clone()])?; + Ok((identity_secret, id_commitment)) } /// Generates an extended RLN identity compatible with Semaphore. @@ -20,25 +21,25 @@ pub fn keygen() -> (IdSecret, Fr) { /// Returns `(identity_trapdoor, identity_nullifier, identity_secret, id_commitment)` where: /// - `identity_secret = PoseidonHash(identity_trapdoor, identity_nullifier)` /// - `id_commitment = PoseidonHash(identity_secret)` -pub fn extended_keygen() -> (Fr, Fr, Fr, Fr) { +pub fn extended_keygen() -> Result<(Fr, Fr, Fr, Fr), ZerokitMerkleTreeError> { let mut rng = thread_rng(); let identity_trapdoor = Fr::rand(&mut rng); let identity_nullifier = Fr::rand(&mut rng); - let identity_secret = poseidon_hash(&[identity_trapdoor, identity_nullifier]); - let id_commitment = poseidon_hash(&[identity_secret]); - ( + let identity_secret = poseidon_hash(&[identity_trapdoor, identity_nullifier])?; + let id_commitment = poseidon_hash(&[identity_secret])?; + Ok(( identity_trapdoor, identity_nullifier, identity_secret, id_commitment, - ) + )) } /// Generates a deterministic RLN identity from a seed. /// /// Uses ChaCha20 RNG seeded with Keccak-256 hash of the input. /// Returns `(identity_secret, id_commitment)`. Same input always produces the same identity. -pub fn seeded_keygen(signal: &[u8]) -> (Fr, Fr) { +pub fn seeded_keygen(signal: &[u8]) -> Result<(Fr, Fr), ZerokitMerkleTreeError> { // ChaCha20 requires a seed of exactly 32 bytes. // We first hash the input seed signal to a 32 bytes array and pass this as seed to ChaCha20 let mut seed = [0; 32]; @@ -48,8 +49,8 @@ pub fn seeded_keygen(signal: &[u8]) -> (Fr, Fr) { let mut rng = ChaCha20Rng::from_seed(seed); let identity_secret = Fr::rand(&mut rng); - let id_commitment = poseidon_hash(&[identity_secret]); - (identity_secret, id_commitment) + let id_commitment = poseidon_hash(&[identity_secret])?; + Ok((identity_secret, id_commitment)) } /// Generates a deterministic extended RLN identity from a seed, compatible with Semaphore. @@ -57,7 +58,7 @@ pub fn seeded_keygen(signal: &[u8]) -> (Fr, Fr) { /// Uses ChaCha20 RNG seeded with Keccak-256 hash of the input. /// Returns `(identity_trapdoor, identity_nullifier, identity_secret, id_commitment)`. /// Same input always produces the same identity. -pub fn extended_seeded_keygen(signal: &[u8]) -> (Fr, Fr, Fr, Fr) { +pub fn extended_seeded_keygen(signal: &[u8]) -> Result<(Fr, Fr, Fr, Fr), ZerokitMerkleTreeError> { // ChaCha20 requires a seed of exactly 32 bytes. // We first hash the input seed signal to a 32 bytes array and pass this as seed to ChaCha20 let mut seed = [0; 32]; @@ -68,12 +69,12 @@ pub fn extended_seeded_keygen(signal: &[u8]) -> (Fr, Fr, Fr, Fr) { let mut rng = ChaCha20Rng::from_seed(seed); let identity_trapdoor = Fr::rand(&mut rng); let identity_nullifier = Fr::rand(&mut rng); - let identity_secret = poseidon_hash(&[identity_trapdoor, identity_nullifier]); - let id_commitment = poseidon_hash(&[identity_secret]); - ( + let identity_secret = poseidon_hash(&[identity_trapdoor, identity_nullifier])?; + let id_commitment = poseidon_hash(&[identity_secret])?; + Ok(( identity_trapdoor, identity_nullifier, identity_secret, id_commitment, - ) + )) } diff --git a/rln/src/protocol/proof.rs b/rln/src/protocol/proof.rs index 17191ad..e484476 100644 --- a/rln/src/protocol/proof.rs +++ b/rln/src/protocol/proof.rs @@ -147,46 +147,40 @@ pub fn bytes_be_to_rln_proof_values( /// /// Note: The Groth16 proof is always serialized in LE format (arkworks behavior), /// while proof_values are serialized in LE format. -pub fn rln_proof_to_bytes_le(rln_proof: &RLNProof) -> Vec { +pub fn rln_proof_to_bytes_le(rln_proof: &RLNProof) -> Result, ProtocolError> { // Calculate capacity for Vec: // - 128 bytes for compressed Groth16 proof // - 5 field elements for proof values (root, external_nullifier, x, y, nullifier) let mut bytes = Vec::with_capacity(COMPRESS_PROOF_SIZE + FR_BYTE_SIZE * 5); // Serialize proof (always LE format from arkworks) - rln_proof - .proof - .serialize_compressed(&mut bytes) - .expect("serialization should not fail"); + rln_proof.proof.serialize_compressed(&mut bytes)?; // Serialize proof values in LE let proof_values_bytes = rln_proof_values_to_bytes_le(&rln_proof.proof_values); bytes.extend_from_slice(&proof_values_bytes); - bytes + Ok(bytes) } /// Serializes RLN proof to big-endian bytes. /// /// Note: The Groth16 proof is always serialized in LE format (arkworks behavior), /// while proof_values are serialized in BE format. This creates a mixed-endian format. -pub fn rln_proof_to_bytes_be(rln_proof: &RLNProof) -> Vec { +pub fn rln_proof_to_bytes_be(rln_proof: &RLNProof) -> Result, ProtocolError> { // Calculate capacity for Vec: // - 128 bytes for compressed Groth16 proof // - 5 field elements for proof values (root, external_nullifier, x, y, nullifier) let mut bytes = Vec::with_capacity(COMPRESS_PROOF_SIZE + FR_BYTE_SIZE * 5); // Serialize proof (always LE format from arkworks) - rln_proof - .proof - .serialize_compressed(&mut bytes) - .expect("serialization should not fail"); + rln_proof.proof.serialize_compressed(&mut bytes)?; // Serialize proof values in BE let proof_values_bytes = rln_proof_values_to_bytes_be(&rln_proof.proof_values); bytes.extend_from_slice(&proof_values_bytes); - bytes + Ok(bytes) } /// Deserializes RLN proof from little-endian bytes. @@ -198,8 +192,7 @@ pub fn bytes_le_to_rln_proof(bytes: &[u8]) -> Result<(RLNProof, usize), Protocol let mut read: usize = 0; // Deserialize proof (always LE from arkworks) - let proof = Proof::deserialize_compressed(&bytes[read..read + COMPRESS_PROOF_SIZE]) - .map_err(|_| ProtocolError::InvalidReadLen(bytes.len(), read + COMPRESS_PROOF_SIZE))?; + let proof = Proof::deserialize_compressed(&bytes[read..read + COMPRESS_PROOF_SIZE])?; read += COMPRESS_PROOF_SIZE; // Deserialize proof values @@ -226,8 +219,7 @@ pub fn bytes_be_to_rln_proof(bytes: &[u8]) -> Result<(RLNProof, usize), Protocol let mut read: usize = 0; // Deserialize proof (always LE from arkworks) - let proof = Proof::deserialize_compressed(&bytes[read..read + COMPRESS_PROOF_SIZE]) - .map_err(|_| ProtocolError::InvalidReadLen(bytes.len(), read + COMPRESS_PROOF_SIZE))?; + let proof = Proof::deserialize_compressed(&bytes[read..read + COMPRESS_PROOF_SIZE])?; read += COMPRESS_PROOF_SIZE; // Deserialize proof values @@ -306,7 +298,7 @@ pub fn generate_zk_proof( .into_iter() .map(|(name, values)| (name.to_string(), values)); - let full_assignment = calc_witness(inputs, graph_data); + let full_assignment = calc_witness(inputs, graph_data)?; // Random Values let mut rng = thread_rng(); diff --git a/rln/src/protocol/witness.rs b/rln/src/protocol/witness.rs index 32cbd0e..03f15c9 100644 --- a/rln/src/protocol/witness.rs +++ b/rln/src/protocol/witness.rs @@ -270,11 +270,11 @@ pub fn proof_values_from_witness( // y share let a_0 = &witness.identity_secret; let mut to_hash = [**a_0, witness.external_nullifier, witness.message_id]; - let a_1 = poseidon_hash(&to_hash); + let a_1 = poseidon_hash(&to_hash)?; let y = *(a_0.clone()) + witness.x * a_1; // Nullifier - let nullifier = poseidon_hash(&[a_1]); + let nullifier = poseidon_hash(&[a_1])?; to_hash[0].zeroize(); // Merkle tree root computations @@ -283,7 +283,7 @@ pub fn proof_values_from_witness( &witness.user_message_limit, &witness.path_elements, &witness.identity_path_index, - ); + )?; Ok(RLNProofValues { y, @@ -300,22 +300,22 @@ pub fn compute_tree_root( user_message_limit: &Fr, path_elements: &[Fr], identity_path_index: &[u8], -) -> Fr { +) -> Result { let mut to_hash = [*identity_secret.clone()]; - let id_commitment = poseidon_hash(&to_hash); + let id_commitment = poseidon_hash(&to_hash)?; to_hash[0].zeroize(); - let mut root = poseidon_hash(&[id_commitment, *user_message_limit]); + let mut root = poseidon_hash(&[id_commitment, *user_message_limit])?; for i in 0..identity_path_index.len() { if identity_path_index[i] == 0 { - root = poseidon_hash(&[root, path_elements[i]]); + root = poseidon_hash(&[root, path_elements[i]])?; } else { - root = poseidon_hash(&[path_elements[i], root]); + root = poseidon_hash(&[path_elements[i], root])?; } } - root + Ok(root) } /// Prepares inputs for witness calculation from RLN witness input. diff --git a/rln/src/public.rs b/rln/src/public.rs index d99f995..8e3b6cb 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -6,8 +6,10 @@ use num_bigint::BigInt; use { crate::poseidon_tree::PoseidonTree, std::str::FromStr, - utils::error::ZerokitMerkleTreeError, - utils::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree}, + zerokit_utils::{ + error::ZerokitMerkleTreeError, + merkle_tree::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree}, + }, }; #[cfg(not(target_arch = "wasm32"))] @@ -70,25 +72,25 @@ impl RLN { /// /// The `tree_config` parameter accepts: /// - JSON string: `"{\"path\": \"./database\"}"` - /// - Direct config (with pmtree feature): `PmtreeConfig::builder().path("./database").build()?` + /// - Direct config (with pmtree feature): `PmtreeConfigBuilder::new().path("./database").build()?` /// - Empty config for defaults: `""` /// /// Examples: /// ``` /// // Using default config - /// let rln = RLN::new(20, "").unwrap(); + /// let rln = RLN::new(20, "")?; /// /// // Using JSON string /// let config_json = r#"{"path": "./database", "cache_capacity": 1073741824}"#; - /// let rln = RLN::new(20, config_json).unwrap(); + /// let rln = RLN::new(20, config_json)?; /// /// // Using `"` for defaults - /// let rln = RLN::new(20, "").unwrap(); + /// let rln = RLN::new(20, "")?; /// ``` /// /// For advanced usage with builder pattern (pmtree feature): /// ``` - /// let config = PmtreeConfig::builder() + /// let config = PmtreeConfigBuilder::new() /// .path("./database") /// .cache_capacity(1073741824) /// .mode(Mode::HighThroughput) @@ -148,22 +150,22 @@ impl RLN { /// let mut resources: Vec> = Vec::new(); /// for filename in ["rln_final.arkzkey", "graph.bin"] { /// let fullpath = format!("{resources_folder}{filename}"); - /// let mut file = File::open(&fullpath).expect("no file found"); - /// let metadata = std::fs::metadata(&fullpath).expect("unable to read metadata"); + /// let mut file = File::open(&fullpath)?; + /// let metadata = std::fs::metadata(&fullpath)?; /// let mut buffer = vec![0; metadata.len() as usize]; - /// file.read_exact(&mut buffer).expect("buffer overflow"); + /// file.read_exact(&mut buffer)?; /// resources.push(buffer); /// } /// /// // Using default config - /// let rln = RLN::new_with_params(tree_depth, resources[0].clone(), resources[1].clone(), "").unwrap(); + /// let rln = RLN::new_with_params(tree_depth, resources[0].clone(), resources[1].clone(), "")?; /// /// // Using JSON config /// let config_json = r#"{"path": "./database"}"#; - /// let rln = RLN::new_with_params(tree_depth, resources[0].clone(), resources[1].clone(), config_json).unwrap(); + /// let rln = RLN::new_with_params(tree_depth, resources[0].clone(), resources[1].clone(), config_json)?; /// /// // Using builder pattern (with pmtree feature) - /// let config = PmtreeConfig::builder().path("./database").build()?; + /// let config = PmtreeConfigBuilder::new().path("./database").build()?; /// let rln = RLN::new_with_params(tree_depth, resources[0].clone(), resources[1].clone(), config)?; /// ``` #[cfg(all(not(target_arch = "wasm32"), not(feature = "stateless")))] @@ -204,10 +206,10 @@ impl RLN { /// let mut resources: Vec> = Vec::new(); /// for filename in ["rln_final.arkzkey", "graph.bin"] { /// let fullpath = format!("{resources_folder}{filename}"); - /// let mut file = File::open(&fullpath).expect("no file found"); - /// let metadata = std::fs::metadata(&fullpath).expect("unable to read metadata"); + /// let mut file = File::open(&fullpath)?; + /// let metadata = std::fs::metadata(&fullpath)?; /// let mut buffer = vec![0; metadata.len() as usize]; - /// file.read_exact(&mut buffer).expect("buffer overflow"); + /// file.read_exact(&mut buffer)?; /// resources.push(buffer); /// } /// @@ -232,10 +234,10 @@ impl RLN { /// ``` /// let zkey_path = "./resources/tree_depth_20/rln_final.arkzkey"; /// - /// let mut file = File::open(zkey_path).expect("Failed to open file"); - /// let metadata = std::fs::metadata(zkey_path).expect("Failed to read metadata"); + /// let mut file = File::open(zkey_path)?; + /// let metadata = std::fs::metadata(zkey_path)?; /// let mut zkey_data = vec![0; metadata.len() as usize]; - /// file.read_exact(&mut zkey_data).expect("Failed to read file"); + /// file.read_exact(&mut zkey_data)?; /// /// let mut rln = RLN::new_with_params(zkey_data)?; /// ``` @@ -273,7 +275,7 @@ impl RLN { /// let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); /// /// // Set the leaf directly - /// rln.set_leaf(leaf_index, rate_commitment).unwrap(); + /// rln.set_leaf(leaf_index, rate_commitment)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn set_leaf(&mut self, index: usize, leaf: Fr) -> Result<(), RLNError> { @@ -286,7 +288,7 @@ impl RLN { /// Example: /// ``` /// let leaf_index = 10; - /// let rate_commitment = rln.get_leaf(leaf_index).unwrap(); + /// let rate_commitment = rln.get_leaf(leaf_index)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn get_leaf(&self, index: usize) -> Result { @@ -315,7 +317,7 @@ impl RLN { /// } /// /// // We add leaves in a batch into the tree - /// rln.set_leaves_from(index, leaves).unwrap(); + /// rln.set_leaves_from(start_index, leaves)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn set_leaves_from(&mut self, index: usize, leaves: Vec) -> Result<(), RLNError> { @@ -327,10 +329,10 @@ impl RLN { /// Resets the tree state to default and sets multiple leaves starting from index 0. /// /// In contrast to [`set_leaves_from`](crate::public::RLN::set_leaves_from), this function resets to 0 the internal `next_index` value, before setting the input leaves values. + /// + /// This requires the tree to be initialized with the correct depth initially. #[cfg(not(feature = "stateless"))] pub fn init_tree_with_leaves(&mut self, leaves: Vec) -> Result<(), RLNError> { - // NOTE: this requires the tree to be initialized with the correct depth initially - // TODO: accept tree_depth as a parameter and initialize the tree with that depth self.set_tree(self.tree.depth())?; self.set_leaves_from(0, leaves) } @@ -365,7 +367,7 @@ impl RLN { /// } /// /// // We atomically add leaves and remove indices from the tree - /// rln.atomic_operation(index, leaves, indices).unwrap(); + /// rln.atomic_operation(start_index, leaves, indices)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn atomic_operation( @@ -396,7 +398,7 @@ impl RLN { /// let no_of_leaves = 256; /// /// // We reset the tree - /// rln.set_tree(tree_depth).unwrap(); + /// rln.set_tree(tree_depth)?; /// /// // Internal Merkle tree next_index value is now 0 /// @@ -410,7 +412,7 @@ impl RLN { /// } /// /// // We add leaves in a batch into the tree - /// rln.set_leaves_from(index, leaves).unwrap(); + /// rln.set_leaves_from(start_index, leaves)?; /// /// // We set 256 leaves starting from index 10: next_index value is now max(0, 256+10) = 266 /// @@ -418,7 +420,7 @@ impl RLN { /// // rate_commitment will be set at index 266 /// let (_, id_commitment) = keygen(); /// let rate_commitment = poseidon_hash(&[id_commitment, 1.into()]); - /// rln.set_next_leaf(rate_commitment).unwrap(); + /// rln.set_next_leaf(rate_commitment)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn set_next_leaf(&mut self, leaf: Fr) -> Result<(), RLNError> { @@ -434,7 +436,7 @@ impl RLN { /// ``` /// /// let index = 10; - /// rln.delete_leaf(index).unwrap(); + /// rln.delete_leaf(index)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn delete_leaf(&mut self, index: usize) -> Result<(), RLNError> { @@ -450,7 +452,7 @@ impl RLN { /// /// ``` /// let metadata = b"some metadata"; - /// rln.set_metadata(metadata).unwrap(); + /// rln.set_metadata(metadata)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), RLNError> { @@ -463,7 +465,7 @@ impl RLN { /// Example: /// /// ``` - /// let metadata = rln.get_metadata().unwrap(); + /// let metadata = rln.get_metadata()?; /// ``` #[cfg(not(feature = "stateless"))] pub fn get_metadata(&self) -> Result, RLNError> { @@ -488,7 +490,7 @@ impl RLN { /// ``` /// let level = 1; /// let index = 2; - /// let subroot = rln.get_subtree_root(level, index).unwrap(); + /// let subroot = rln.get_subtree_root(level, index)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn get_subtree_root(&self, level: usize, index: usize) -> Result { @@ -501,7 +503,7 @@ impl RLN { /// Example: /// ``` /// let index = 10; - /// let (path_elements, identity_path_index) = rln.get_merkle_proof(index).unwrap(); + /// let (path_elements, identity_path_index) = rln.get_merkle_proof(index)?; /// ``` #[cfg(not(feature = "stateless"))] pub fn get_merkle_proof(&self, index: usize) -> Result<(Vec, Vec), RLNError> { @@ -529,7 +531,7 @@ impl RLN { /// } /// /// // We add leaves in a batch into the tree - /// rln.set_leaves_from(index, leaves).unwrap(); + /// rln.set_leaves_from(start_index, leaves)?; /// /// // Get indices of first empty leaves upto start_index /// let idxs = rln.get_empty_leaves_indices(); @@ -560,7 +562,7 @@ impl RLN { /// let proof_values = proof_values_from_witness(&witness); /// /// // We compute a Groth16 proof - /// let zk_proof = rln.generate_zk_proof(&witness).unwrap(); + /// let zk_proof = rln.generate_zk_proof(&witness)?; /// ``` #[cfg(not(target_arch = "wasm32"))] pub fn generate_zk_proof(&self, witness: &RLNWitnessInput) -> Result { @@ -575,7 +577,7 @@ impl RLN { /// Example: /// ``` /// let witness = RLNWitnessInput::new(...); - /// let (proof, proof_values) = rln.generate_rln_proof(&witness).unwrap(); + /// let (proof, proof_values) = rln.generate_rln_proof(&witness)?; /// ``` #[cfg(not(target_arch = "wasm32"))] pub fn generate_rln_proof( @@ -595,7 +597,7 @@ impl RLN { /// ``` /// let witness = RLNWitnessInput::new(...); /// let calculated_witness: Vec = ...; // obtained from external witness calculator - /// let (proof, proof_values) = rln.generate_rln_proof_with_witness(calculated_witness, &witness).unwrap(); + /// let (proof, proof_values) = rln.generate_rln_proof_with_witness(calculated_witness, &witness)?; /// ``` pub fn generate_rln_proof_with_witness( &self, @@ -612,13 +614,13 @@ impl RLN { /// Example: /// ``` /// // We compute a Groth16 proof - /// let zk_proof = rln.generate_zk_proof(&witness).unwrap(); + /// let zk_proof = rln.generate_zk_proof(&witness)?; /// /// // We compute proof values directly from witness /// let proof_values = proof_values_from_witness(&witness); /// /// // We verify the proof - /// let verified = rln.verify_zk_proof(&zk_proof, &proof_values).unwrap(); + /// let verified = rln.verify_zk_proof(&zk_proof, &proof_values)?; /// /// assert!(verified); /// ``` diff --git a/rln/src/utils.rs b/rln/src/utils.rs index 1f3a458..b9985f9 100644 --- a/rln/src/utils.rs +++ b/rln/src/utils.rs @@ -283,12 +283,15 @@ pub fn bytes_le_to_vec_usize(input: &[u8]) -> Result, UtilsError> { actual: input.len(), }); } - let elements: Vec = input[8..] - .chunks(8) + input[8..] + .chunks_exact(8) .take(nof_elem) - .map(|ch| usize::from_le_bytes(ch[0..8].try_into().unwrap())) - .collect(); - Ok(elements) + .map(|ch| { + ch.try_into() + .map(usize::from_le_bytes) + .map_err(UtilsError::FromSlice) + }) + .collect() } } @@ -310,12 +313,15 @@ pub fn bytes_be_to_vec_usize(input: &[u8]) -> Result, UtilsError> { actual: input.len(), }); } - let elements: Vec = input[8..] - .chunks(8) + input[8..] + .chunks_exact(8) .take(nof_elem) - .map(|ch| usize::from_be_bytes(ch[0..8].try_into().unwrap())) - .collect(); - Ok(elements) + .map(|ch| { + ch.try_into() + .map(usize::from_be_bytes) + .map_err(UtilsError::FromSlice) + }) + .collect() } } diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index df2028f..88cce78 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -48,7 +48,17 @@ mod test { } fn identity_pair_gen() -> (IdSecret, Fr) { - let key_gen = ffi_key_gen(); + let key_gen = match ffi_key_gen() { + CResult { + ok: Some(keys), + err: None, + } => keys, + CResult { + ok: None, + err: Some(err), + } => panic!("key gen call failed: {}", err), + _ => unreachable!(), + }; let mut id_secret_fr = *key_gen[0]; let id_secret_hash = IdSecret::from(&mut id_secret_fr); let id_commitment = *key_gen[1]; @@ -328,13 +338,13 @@ mod test { let mut ffi_rln_instance = create_rln_instance(); // generate identity - let mut identity_secret_ = hash_to_field_le(b"test-merkle-proof"); + let mut identity_secret_ = hash_to_field_le(b"test-merkle-proof").unwrap(); let identity_secret = IdSecret::from(&mut identity_secret_); let mut to_hash = [*identity_secret.clone()]; - let id_commitment = poseidon_hash(&to_hash); + let id_commitment = poseidon_hash(&to_hash).unwrap(); to_hash[0].zeroize(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); // We prepare id_commitment and we set the leaf at provided index let result = ffi_set_leaf( @@ -400,7 +410,7 @@ mod test { "0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f", "0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca", ] - .map(|e| str_to_fr(e, 16).unwrap()) + .map(|str| str_to_fr(str, 16).unwrap()) .to_vec(); let expected_identity_path_index: Vec = @@ -415,7 +425,8 @@ mod test { &user_message_limit, &path_elements, &identity_path_index, - ); + ) + .unwrap(); assert_eq!(root, root_from_proof); } @@ -430,20 +441,16 @@ mod test { let root_rln_folder = get_tree_root(&ffi_rln_instance); let zkey_path = "./resources/tree_depth_20/rln_final.arkzkey"; - let mut zkey_file = File::open(zkey_path).expect("no file found"); - let metadata = std::fs::metadata(zkey_path).expect("unable to read metadata"); + let mut zkey_file = File::open(zkey_path).unwrap(); + let metadata = std::fs::metadata(zkey_path).unwrap(); let mut zkey_data = vec![0; metadata.len() as usize]; - zkey_file - .read_exact(&mut zkey_data) - .expect("buffer overflow"); + zkey_file.read_exact(&mut zkey_data).unwrap(); let graph_data = "./resources/tree_depth_20/graph.bin"; - let mut graph_file = File::open(graph_data).expect("no file found"); - let metadata = std::fs::metadata(graph_data).expect("unable to read metadata"); + let mut graph_file = File::open(graph_data).unwrap(); + let metadata = std::fs::metadata(graph_data).unwrap(); let mut graph_buffer = vec![0; metadata.len() as usize]; - graph_file - .read_exact(&mut graph_buffer) - .expect("buffer overflow"); + graph_file.read_exact(&mut graph_buffer).unwrap(); // Creating a RLN instance passing the raw data let tree_config = "".to_string(); @@ -479,7 +486,7 @@ mod test { // We generate a vector of random leaves let mut rng = thread_rng(); let leaves: Vec = (0..NO_OF_LEAVES) - .map(|_| poseidon_hash(&[Fr::rand(&mut rng), Fr::from(100)])) + .map(|_| poseidon_hash(&[Fr::rand(&mut rng), Fr::from(100)]).unwrap()) .collect(); // We create a RLN instance @@ -497,15 +504,15 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); // We set as leaf rate_commitment, its index would be equal to no_of_leaves let result = ffi_set_next_leaf(&mut ffi_rln_instance, &CFr::from(rate_commitment)); @@ -514,7 +521,7 @@ mod test { } // Hash the signal to get x - let x = hash_to_field_le(&signal); + let x = hash_to_field_le(&signal).unwrap(); let rln_proof = rln_proof_gen( &ffi_rln_instance, @@ -544,7 +551,7 @@ mod test { // We generate a new identity pair let (identity_secret, id_commitment) = identity_pair_gen(); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); let identity_index: usize = NO_OF_LEAVES; // We generate a random signal @@ -552,11 +559,11 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); @@ -567,7 +574,7 @@ mod test { } // Hash the signal to get x - let x = hash_to_field_le(&signal); + let x = hash_to_field_le(&signal).unwrap(); let rln_proof = rln_proof_gen( &ffi_rln_instance, @@ -639,7 +646,7 @@ mod test { let (identity_secret, id_commitment) = identity_pair_gen(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); // We set as leaf rate_commitment, its index would be equal to 0 since tree is empty let result = ffi_set_next_leaf(&mut ffi_rln_instance, &CFr::from(rate_commitment)); @@ -657,17 +664,17 @@ mod test { let signal2: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); // Hash the signals to get x - let x1 = hash_to_field_le(&signal1); - let x2 = hash_to_field_le(&signal2); + let x1 = hash_to_field_le(&signal1).unwrap(); + let x2 = hash_to_field_le(&signal2).unwrap(); // Generate proofs using witness-based API // We call generate_rln_proof for first proof values @@ -715,7 +722,7 @@ mod test { // We generate a new identity pair let (identity_secret_new, id_commitment_new) = identity_pair_gen(); - let rate_commitment_new = poseidon_hash(&[id_commitment_new, user_message_limit]); + let rate_commitment_new = poseidon_hash(&[id_commitment_new, user_message_limit]).unwrap(); // We set as leaf id_commitment, its index would be equal to 1 since at 0 there is id_commitment let result = ffi_set_next_leaf(&mut ffi_rln_instance, &CFr::from(rate_commitment_new)); @@ -727,7 +734,7 @@ mod test { // We generate a random signal let signal3: [u8; 32] = rng.gen(); - let x3 = hash_to_field_le(&signal3); + let x3 = hash_to_field_le(&signal3).unwrap(); let rln_proof3 = rln_proof_gen( &ffi_rln_instance, @@ -773,8 +780,17 @@ mod test { // We generate a new identity tuple from an input seed let seed_bytes: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let key_gen = ffi_seeded_extended_key_gen(&seed_bytes.into()); - assert_eq!(key_gen.len(), 4, "seeded extended key gen call failed"); + let key_gen = match ffi_seeded_extended_key_gen(&seed_bytes.into()) { + CResult { + ok: Some(keys), + err: None, + } => keys, + CResult { + ok: None, + err: Some(err), + } => panic!("seeded extended key gen call failed: {}", err), + _ => unreachable!(), + }; let id_commitment = *key_gen[3]; // We insert the id_commitment into the tree at a random index diff --git a/rln/tests/ffi_utils.rs b/rln/tests/ffi_utils.rs index fe98a5a..fd06c7b 100644 --- a/rln/tests/ffi_utils.rs +++ b/rln/tests/ffi_utils.rs @@ -8,8 +8,17 @@ mod test { fn test_seeded_keygen_ffi() { // We generate a new identity pair from an input seed let seed_bytes: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let res = ffi_seeded_key_gen(&seed_bytes.into()); - assert_eq!(res.len(), 2, "seeded key gen call failed"); + let res = match ffi_seeded_key_gen(&seed_bytes.into()) { + CResult { + ok: Some(vec_cfr), + err: None, + } => vec_cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_seeded_key_gen call failed: {}", err), + _ => unreachable!(), + }; let identity_secret = res.first().unwrap(); let id_commitment = res.get(1).unwrap(); @@ -34,8 +43,17 @@ mod test { fn test_seeded_extended_keygen_ffi() { // We generate a new identity tuple from an input seed let seed_bytes: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let key_gen = ffi_seeded_extended_key_gen(&seed_bytes.into()); - assert_eq!(key_gen.len(), 4, "seeded extended key gen call failed"); + let key_gen = match ffi_seeded_extended_key_gen(&seed_bytes.into()) { + CResult { + ok: Some(vec_cfr), + err: None, + } => vec_cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_seeded_extended_key_gen call failed: {}", err), + _ => unreachable!(), + }; let identity_trapdoor = *key_gen[0]; let identity_nullifier = *key_gen[1]; let identity_secret = *key_gen[2]; @@ -87,7 +105,17 @@ mod test { let cfr_debug_str = ffi_cfr_debug(Some(&cfr_int)); assert_eq!(cfr_debug_str.to_string(), "42"); - let key_gen = ffi_key_gen(); + let key_gen = match ffi_key_gen() { + CResult { + ok: Some(vec_cfr), + err: None, + } => vec_cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_key_gen call failed: {}", err), + _ => unreachable!(), + }; let mut id_secret_fr = *key_gen[0]; let id_secret_hash = IdSecret::from(&mut id_secret_fr); let id_commitment = *key_gen[1]; @@ -187,13 +215,33 @@ mod test { let signal_gen: [u8; 32] = rng.gen(); let signal: Vec = signal_gen.to_vec(); - let cfr_le_1 = ffi_hash_to_field_le(&signal.clone().into()); - let fr_le_2 = hash_to_field_le(&signal); + let cfr_le_1 = match ffi_hash_to_field_le(&signal.clone().into()) { + CResult { + ok: Some(cfr), + err: None, + } => cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_hash_to_field_le call failed: {}", err), + _ => unreachable!(), + }; + let fr_le_2 = hash_to_field_le(&signal).unwrap(); assert_eq!(*cfr_le_1, fr_le_2); - let cfr_be_1 = ffi_hash_to_field_be(&signal.clone().into()); - let fr_be_2 = hash_to_field_be(&signal); - assert_eq!(*cfr_be_1, fr_be_2); + let cfr_be_1 = match ffi_hash_to_field_be(&signal.clone().into()) { + CResult { + ok: Some(cfr), + err: None, + } => cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_hash_to_field_be call failed: {}", err), + _ => unreachable!(), + }; + let fr_be_2 = hash_to_field_be(&signal).unwrap(); + assert_eq!(*cfr_le_1, fr_be_2); assert_eq!(*cfr_le_1, *cfr_be_1); assert_eq!(fr_le_2, fr_be_2); @@ -222,8 +270,19 @@ mod test { let input_1 = Fr::from(42u8); let input_2 = Fr::from(99u8); - let expected_hash = poseidon_hash(&[input_1, input_2]); - let received_hash_cfr = ffi_poseidon_hash_pair(&CFr::from(input_1), &CFr::from(input_2)); + let expected_hash = poseidon_hash(&[input_1, input_2]).unwrap(); + let received_hash_cfr = + match ffi_poseidon_hash_pair(&CFr::from(input_1), &CFr::from(input_2)) { + CResult { + ok: Some(cfr), + err: None, + } => cfr, + CResult { + ok: None, + err: Some(err), + } => panic!("ffi_poseidon_hash_pair call failed: {}", err), + _ => unreachable!(), + }; assert_eq!(*received_hash_cfr, expected_hash); } } diff --git a/rln/tests/poseidon_tree.rs b/rln/tests/poseidon_tree.rs index e58b652..f1dfd89 100644 --- a/rln/tests/poseidon_tree.rs +++ b/rln/tests/poseidon_tree.rs @@ -5,7 +5,9 @@ #[cfg(test)] mod test { use rln::prelude::*; - use utils::{FullMerkleTree, OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree}; + use zerokit_utils::merkle_tree::{ + FullMerkleTree, OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree, + }; #[test] // The test checked correctness for `FullMerkleTree` and `OptimalMerkleTree` with Poseidon hash @@ -22,12 +24,12 @@ mod test { .take(sample_size.try_into().unwrap()) { tree_full.set(i, leave).unwrap(); - let proof = tree_full.proof(i).expect("index should be set"); + let proof = tree_full.proof(i).unwrap(); assert_eq!(proof.leaf_index(), i); tree_opt.set(i, leave).unwrap(); assert_eq!(tree_opt.root(), tree_full.root()); - let proof = tree_opt.proof(i).expect("index should be set"); + let proof = tree_opt.proof(i).unwrap(); assert_eq!(proof.leaf_index(), i); } @@ -68,7 +70,7 @@ mod test { let prev_r = tree.get_subtree_root(n, idx_r).unwrap(); let subroot = tree.get_subtree_root(n - 1, idx_sr).unwrap(); - assert_eq!(poseidon_hash(&[prev_l, prev_r]), subroot); + assert_eq!(poseidon_hash(&[prev_l, prev_r]).unwrap(), subroot); } } } diff --git a/rln/tests/protocol.rs b/rln/tests/protocol.rs index eefa391..6906194 100644 --- a/rln/tests/protocol.rs +++ b/rln/tests/protocol.rs @@ -4,7 +4,7 @@ mod test { use ark_ff::BigInt; use rln::prelude::*; - use utils::{ZerokitMerkleProof, ZerokitMerkleTree}; + use zerokit_utils::merkle_tree::{ZerokitMerkleProof, ZerokitMerkleTree}; type ConfigOf = ::Config; @@ -14,9 +14,9 @@ mod test { let leaf_index = 3; // generate identity - let identity_secret = hash_to_field_le(b"test-merkle-proof"); - let id_commitment = poseidon_hash(&[identity_secret]); - let rate_commitment = poseidon_hash(&[id_commitment, 100.into()]); + let identity_secret = hash_to_field_le(b"test-merkle-proof").unwrap(); + let id_commitment = poseidon_hash(&[identity_secret]).unwrap(); + let rate_commitment = poseidon_hash(&[id_commitment, 100.into()]).unwrap(); // generate merkle tree let default_leaf = Fr::from(0); @@ -42,7 +42,7 @@ mod test { .into() ); - let merkle_proof = tree.proof(leaf_index).expect("proof should exist"); + let merkle_proof = tree.proof(leaf_index).unwrap(); let path_elements = merkle_proof.get_path_elements(); let identity_path_index = merkle_proof.get_path_index(); @@ -69,7 +69,7 @@ mod test { "0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f", "0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca", ] - .map(|e| str_to_fr(e, 16).unwrap()) + .map(|str| str_to_fr(str, 16).unwrap()) .to_vec(); let expected_identity_path_index: Vec = @@ -85,9 +85,9 @@ mod test { fn get_test_witness() -> RLNWitnessInput { let leaf_index = 3; // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); //// generate merkle tree let default_leaf = Fr::from(0); @@ -99,15 +99,15 @@ mod test { .unwrap(); tree.set(leaf_index, rate_commitment).unwrap(); - let merkle_proof = tree.proof(leaf_index).expect("proof should exist"); + let merkle_proof = tree.proof(leaf_index).unwrap(); let signal = b"hey hey"; - let x = hash_to_field_le(signal); + let x = hash_to_field_le(signal).unwrap(); // We set the remaining values to random ones - let epoch = hash_to_field_le(b"test-epoch"); - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); let message_id = Fr::from(1); @@ -165,7 +165,7 @@ mod test { fn test_seeded_keygen() { // Generate identity pair using a seed phrase let seed_phrase: &str = "A seed phrase example"; - let (identity_secret, id_commitment) = seeded_keygen(seed_phrase.as_bytes()); + let (identity_secret, id_commitment) = seeded_keygen(seed_phrase.as_bytes()).unwrap(); // We check against expected values let expected_identity_secret_seed_phrase = str_to_fr( @@ -184,7 +184,7 @@ mod test { // Generate identity pair using an byte array let seed_bytes: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let (identity_secret, id_commitment) = seeded_keygen(seed_bytes); + let (identity_secret, id_commitment) = seeded_keygen(seed_bytes).unwrap(); // We check against expected values let expected_identity_secret_seed_bytes = str_to_fr( @@ -202,7 +202,7 @@ mod test { assert_eq!(id_commitment, expected_id_commitment_seed_bytes); // We check again if the identity pair generated with the same seed phrase corresponds to the previously generated one - let (identity_secret, id_commitment) = seeded_keygen(seed_phrase.as_bytes()); + let (identity_secret, id_commitment) = seeded_keygen(seed_phrase.as_bytes()).unwrap(); assert_eq!(identity_secret, expected_identity_secret_seed_phrase); assert_eq!(id_commitment, expected_id_commitment_seed_phrase); diff --git a/rln/tests/public.rs b/rln/tests/public.rs index f5802c7..df74a5e 100644 --- a/rln/tests/public.rs +++ b/rln/tests/public.rs @@ -46,21 +46,21 @@ mod test { let mut rng = thread_rng(); let identity_secret = IdSecret::rand(&mut rng); - let x = hash_to_field_le(&rng.gen::<[u8; 32]>()); - let epoch = hash_to_field_le(&rng.gen::<[u8; 32]>()); - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let x = hash_to_field_le(&rng.gen::<[u8; 32]>()).unwrap(); + let epoch = hash_to_field_le(&rng.gen::<[u8; 32]>()).unwrap(); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); let mut path_elements: Vec = Vec::new(); let mut identity_path_index: Vec = Vec::new(); for _ in 0..tree_depth { - path_elements.push(hash_to_field_le(&rng.gen::<[u8; 32]>())); + path_elements.push(hash_to_field_le(&rng.gen::<[u8; 32]>()).unwrap()); identity_path_index.push(rng.gen_range(0..2) as u8); } let user_message_limit = Fr::from(100); let message_id = Fr::from(1); - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); RLNWitnessInput::new( identity_secret, @@ -478,9 +478,7 @@ mod test { let root_empty = rln.get_root(); // We add leaves in a batch into the tree - #[allow(unused_must_use)] - rln.set_leaves_from(bad_index, leaves) - .expect_err("Should throw an error"); + assert!(rln.set_leaves_from(bad_index, leaves).is_err()); // We check if number of leaves set is consistent assert_eq!(rln.leaves_set(), 0); @@ -548,7 +546,7 @@ mod test { let mut rng = thread_rng(); for _ in 0..NO_OF_LEAVES { let id_commitment = Fr::rand(&mut rng); - let rate_commitment = poseidon_hash(&[id_commitment, Fr::from(100)]); + let rate_commitment = poseidon_hash(&[id_commitment, Fr::from(100)]).unwrap(); leaves.push(rate_commitment); } @@ -559,12 +557,12 @@ mod test { rln.init_tree_with_leaves(leaves.clone()).unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); // We set as leaf rate_commitment after storing its index let identity_index = rln.leaves_set(); let user_message_limit = Fr::from(65535); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); rln.set_next_leaf(rate_commitment).unwrap(); // We generate a random signal @@ -572,16 +570,16 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); // Hash the signal to get x - let x = hash_to_field_le(&signal); + let x = hash_to_field_le(&signal).unwrap(); // Get merkle proof for the identity let (path_elements, identity_path_index) = @@ -626,12 +624,12 @@ mod test { rln.init_tree_with_leaves(leaves.clone()).unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); // We set as leaf rate_commitment after storing its index let identity_index = rln.leaves_set(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); rln.set_next_leaf(rate_commitment).unwrap(); // We generate a random signal @@ -639,16 +637,16 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); // Hash the signal to get x - let x = hash_to_field_le(&signal); + let x = hash_to_field_le(&signal).unwrap(); // Get merkle proof for the identity let (path_elements, identity_path_index) = @@ -694,12 +692,12 @@ mod test { rln.init_tree_with_leaves(leaves.clone()).unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); // We set as leaf rate_commitment after storing its index let identity_index = rln.leaves_set(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); rln.set_next_leaf(rate_commitment).unwrap(); // We generate a random signal @@ -707,16 +705,16 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); // Hash the signal to get x - let x = hash_to_field_le(&signal); + let x = hash_to_field_le(&signal).unwrap(); // Get merkle proof for the identity let (path_elements, identity_path_index) = @@ -776,9 +774,9 @@ mod test { let mut rln = RLN::new(tree_depth, "").unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); // We set as leaf rate_commitment, its index would be equal to 0 since tree is empty let identity_index = rln.leaves_set(); @@ -792,17 +790,17 @@ mod test { let signal2: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); // We generate a external nullifier - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We choose a message_id satisfy 0 <= message_id < MESSAGE_LIMIT let message_id = Fr::from(1); // Hash the signals to get x values - let x1 = hash_to_field_le(&signal1); - let x2 = hash_to_field_le(&signal2); + let x1 = hash_to_field_le(&signal1).unwrap(); + let x2 = hash_to_field_le(&signal2).unwrap(); // Get merkle proof for the identity let (path_elements, identity_path_index) = @@ -847,8 +845,9 @@ mod test { // We now test that computing identity_secret is unsuccessful if shares computed from two different identity secret but within same epoch are passed // We generate a new identity pair - let (identity_secret_new, id_commitment_new) = keygen(); - let rate_commitment_new = poseidon_hash(&[id_commitment_new, user_message_limit]); + let (identity_secret_new, id_commitment_new) = keygen().unwrap(); + let rate_commitment_new = + poseidon_hash(&[id_commitment_new, user_message_limit]).unwrap(); // We add it to the tree let identity_index_new = rln.leaves_set(); @@ -856,7 +855,7 @@ mod test { // We generate a random signal let signal3: [u8; 32] = rng.gen(); - let x3 = hash_to_field_le(&signal3); + let x3 = hash_to_field_le(&signal3).unwrap(); // Get merkle proof for the new identity let (path_elements_new, identity_path_index_new) = @@ -930,7 +929,9 @@ mod test { protocol::*, public::RLN, }; - use utils::{OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree}; + use zerokit_utils::merkle_tree::{ + OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree, + }; use super::DEFAULT_TREE_DEPTH; @@ -950,12 +951,12 @@ mod test { .unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); // We set as leaf rate_commitment after storing its index let identity_index = tree.leaves_set(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); tree.update_next(rate_commitment).unwrap(); // We generate a random signal @@ -963,14 +964,14 @@ mod test { let signal: [u8; 32] = rng.gen(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // Hash the signal to get x - let x = hash_to_field_le(&signal); - let merkle_proof = tree.proof(identity_index).expect("proof should exist"); + let x = hash_to_field_le(&signal).unwrap(); + let merkle_proof = tree.proof(identity_index).unwrap(); let message_id = Fr::from(1); let rln_witness = RLNWitnessInput::new( @@ -1032,27 +1033,27 @@ mod test { .unwrap(); // Generate identity pair - let (identity_secret, id_commitment) = keygen(); + let (identity_secret, id_commitment) = keygen().unwrap(); let user_message_limit = Fr::from(100); - let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); tree.update_next(rate_commitment).unwrap(); // We generate a random epoch - let epoch = hash_to_field_le(b"test-epoch"); + let epoch = hash_to_field_le(b"test-epoch").unwrap(); // We generate a random rln_identifier - let rln_identifier = hash_to_field_le(b"test-rln-identifier"); - let external_nullifier = poseidon_hash(&[epoch, rln_identifier]); + let rln_identifier = hash_to_field_le(b"test-rln-identifier").unwrap(); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); // We generate a random signal let mut rng = thread_rng(); let signal1: [u8; 32] = rng.gen(); - let x1 = hash_to_field_le(&signal1); + let x1 = hash_to_field_le(&signal1).unwrap(); let signal2: [u8; 32] = rng.gen(); - let x2 = hash_to_field_le(&signal2); + let x2 = hash_to_field_le(&signal2).unwrap(); let identity_index = tree.leaves_set(); - let merkle_proof = tree.proof(identity_index).expect("proof should exist"); + let merkle_proof = tree.proof(identity_index).unwrap(); let message_id = Fr::from(1); let rln_witness1 = RLNWitnessInput::new( @@ -1091,15 +1092,16 @@ mod test { // We now test that computing identity_secret is unsuccessful if shares computed from two different identity secret but within same epoch are passed // We generate a new identity pair - let (identity_secret_new, id_commitment_new) = keygen(); - let rate_commitment_new = poseidon_hash(&[id_commitment_new, user_message_limit]); + let (identity_secret_new, id_commitment_new) = keygen().unwrap(); + let rate_commitment_new = + poseidon_hash(&[id_commitment_new, user_message_limit]).unwrap(); tree.update_next(rate_commitment_new).unwrap(); let signal3: [u8; 32] = rng.gen(); - let x3 = hash_to_field_le(&signal3); + let x3 = hash_to_field_le(&signal3).unwrap(); let identity_index_new = tree.leaves_set(); - let merkle_proof_new = tree.proof(identity_index_new).expect("proof should exist"); + let merkle_proof_new = tree.proof(identity_index_new).unwrap(); let rln_witness3 = RLNWitnessInput::new( identity_secret_new.clone(), diff --git a/utils/benches/merkle_tree_benchmark.rs b/utils/benches/merkle_tree_benchmark.rs index f50f128..16de04e 100644 --- a/utils/benches/merkle_tree_benchmark.rs +++ b/utils/benches/merkle_tree_benchmark.rs @@ -3,8 +3,11 @@ use std::{fmt::Display, str::FromStr, sync::LazyLock}; use criterion::{criterion_group, criterion_main, Criterion}; use tiny_keccak::{Hasher as _, Keccak}; use zerokit_utils::{ - FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, - ZerokitMerkleTree, + error::HashError, + merkle_tree::{ + FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, + ZerokitMerkleTree, + }, }; #[derive(Clone, Copy, Eq, PartialEq)] @@ -15,19 +18,20 @@ struct TestFr([u8; 32]); impl Hasher for Keccak256 { type Fr = TestFr; + type Error = HashError; fn default_leaf() -> Self::Fr { TestFr([0; 32]) } - fn hash(inputs: &[Self::Fr]) -> Self::Fr { + fn hash(inputs: &[Self::Fr]) -> Result { let mut output = [0; 32]; let mut hasher = Keccak::v256(); for element in inputs { hasher.update(element.0.as_slice()); } hasher.finalize(&mut output); - TestFr(output) + Ok(TestFr(output)) } } diff --git a/utils/benches/poseidon_benchmark.rs b/utils/benches/poseidon_benchmark.rs index 401e50e..e0ce233 100644 --- a/utils/benches/poseidon_benchmark.rs +++ b/utils/benches/poseidon_benchmark.rs @@ -2,7 +2,7 @@ use std::hint::black_box; use ark_bn254::Fr; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; -use zerokit_utils::Poseidon; +use zerokit_utils::poseidon::Poseidon; const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [ (2, 8, 56, 0), diff --git a/utils/src/error.rs b/utils/src/error.rs new file mode 100644 index 0000000..228b15f --- /dev/null +++ b/utils/src/error.rs @@ -0,0 +1,11 @@ +use super::poseidon::error::PoseidonError; +pub use crate::merkle_tree::{FromConfigError, ZerokitMerkleTreeError}; + +/// Errors that can occur during hashing operations. +#[derive(Debug, thiserror::Error)] +pub enum HashError { + #[error("Poseidon hash error: {0}")] + Poseidon(#[from] PoseidonError), + #[error("Generic hash error: {0}")] + Generic(String), +} diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 597f453..5585358 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,10 +1,4 @@ -pub mod poseidon; -pub use self::poseidon::*; - +pub mod error; pub mod merkle_tree; -pub use self::merkle_tree::*; - -#[cfg(feature = "pmtree-ft")] pub mod pm_tree; -#[cfg(feature = "pmtree-ft")] -pub use self::pm_tree::*; +pub mod poseidon; diff --git a/utils/src/merkle_tree/error.rs b/utils/src/merkle_tree/error.rs index c002c3a..658cdf7 100644 --- a/utils/src/merkle_tree/error.rs +++ b/utils/src/merkle_tree/error.rs @@ -1,8 +1,12 @@ -#[derive(thiserror::Error, Debug)] +use crate::error::HashError; + +/// Errors that can occur during Merkle tree operations +#[derive(Debug, thiserror::Error)] pub enum ZerokitMerkleTreeError { #[error("Invalid index")] InvalidIndex, - // InvalidProof, + #[error("Invalid indices")] + InvalidIndices, #[error("Leaf index out of bounds")] InvalidLeaf, #[error("Level exceeds tree depth")] @@ -20,8 +24,11 @@ pub enum ZerokitMerkleTreeError { #[cfg(feature = "pmtree-ft")] #[error("Pmtree error: {0}")] PmtreeErrorKind(#[from] pmtree::PmtreeErrorKind), + #[error("Hash error: {0}")] + HashError(#[from] HashError), } +/// Errors that can occur while creating Merkle tree from config #[derive(Debug, thiserror::Error)] pub enum FromConfigError { #[error("Error while reading pmtree config: {0}")] @@ -30,4 +37,6 @@ pub enum FromConfigError { MissingPath, #[error("Error while creating pmtree config: path already exists")] PathExists, + #[error("Error while creating pmtree default temp path: {0}")] + IoError(#[from] std::io::Error), } diff --git a/utils/src/merkle_tree/full_merkle_tree.rs b/utils/src/merkle_tree/full_merkle_tree.rs index e58f851..1ac2e84 100644 --- a/utils/src/merkle_tree/full_merkle_tree.rs +++ b/utils/src/merkle_tree/full_merkle_tree.rs @@ -7,9 +7,9 @@ use std::{ use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use crate::merkle_tree::{ +use super::{ error::{FromConfigError, ZerokitMerkleTreeError}, - FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES, + merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES}, }; // Full Merkle Tree Implementation @@ -40,7 +40,7 @@ where /// Element of a Merkle proof #[derive(Clone, Copy, PartialEq, Eq)] -pub enum FullMerkleBranch { +pub(crate) enum FullMerkleBranch { /// Left branch taken, value is the right sibling hash. Left(H::Fr), @@ -50,7 +50,7 @@ pub enum FullMerkleBranch { /// Merkle proof path, bottom to top. #[derive(Clone, PartialEq, Eq)] -pub struct FullMerkleProof(pub Vec>); +pub struct FullMerkleProof(Vec>); #[derive(Default)] pub struct FullMerkleConfig(()); @@ -87,7 +87,7 @@ where let mut cached_nodes: Vec = Vec::with_capacity(depth + 1); cached_nodes.push(default_leaf); for i in 0..depth { - cached_nodes.push(H::hash(&[cached_nodes[i]; 2])); + cached_nodes.push(H::hash(&[cached_nodes[i]; 2]).map_err(Into::into)?); } cached_nodes.reverse(); @@ -164,13 +164,14 @@ where let mut idx = self.capacity() + index - 1; let mut nd = self.depth; loop { - let parent = self.parent(idx).expect("parent should exist"); + let parent = self + .parent(idx) + .ok_or(ZerokitMerkleTreeError::InvalidIndex)?; nd -= 1; if nd == n { return Ok(self.nodes[parent]); } else { idx = parent; - continue; } } } @@ -225,7 +226,10 @@ where J: ExactSizeIterator, { let indices = indices.into_iter().collect::>(); - let min_index = *indices.first().expect("indices should not be empty"); + if indices.is_empty() { + return Err(ZerokitMerkleTreeError::InvalidIndices); + } + let min_index = indices[0]; let leaves_vec = leaves.into_iter().collect::>(); let max_index = start + leaves_vec.len(); @@ -291,7 +295,7 @@ where hash: &FrOf, merkle_proof: &FullMerkleProof, ) -> Result { - Ok(merkle_proof.compute_root_from(hash) == self.root()) + Ok(merkle_proof.compute_root_from(hash)? == self.root()) } fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), ZerokitMerkleTreeError> { @@ -351,17 +355,20 @@ where { // Use parallel processing when the number of pairs exceeds the threshold if end_parent - start_parent + 1 >= MIN_PARALLEL_NODES { - let updates: Vec<(usize, H::Fr)> = (start_parent..=end_parent) + #[allow(clippy::type_complexity)] + let updates: Result, ZerokitMerkleTreeError> = (start_parent + ..=end_parent) .into_par_iter() .map(|parent| { let left_child = self.first_child(parent); let right_child = left_child + 1; - let hash = H::hash(&[self.nodes[left_child], self.nodes[right_child]]); - (parent, hash) + let hash = H::hash(&[self.nodes[left_child], self.nodes[right_child]]) + .map_err(Into::into)?; + Ok((parent, hash)) }) .collect(); - for (parent, hash) in updates { + for (parent, hash) in updates? { self.nodes[parent] = hash; } } else { @@ -370,7 +377,8 @@ where let left_child = self.first_child(parent); let right_child = left_child + 1; self.nodes[parent] = - H::hash(&[self.nodes[left_child], self.nodes[right_child]]); + H::hash(&[self.nodes[left_child], self.nodes[right_child]]) + .map_err(Into::into)?; } } @@ -421,10 +429,13 @@ impl ZerokitMerkleProof for FullMerkleProof { } /// Computes the Merkle root corresponding by iteratively hashing a Merkle proof with a given input leaf - fn compute_root_from(&self, hash: &FrOf) -> FrOf { - self.0.iter().fold(*hash, |hash, branch| match branch { - FullMerkleBranch::Left(sibling) => H::hash(&[hash, *sibling]), - FullMerkleBranch::Right(sibling) => H::hash(&[*sibling, hash]), + fn compute_root_from( + &self, + hash: &FrOf, + ) -> Result, ZerokitMerkleTreeError> { + self.0.iter().try_fold(*hash, |hash, branch| match branch { + FullMerkleBranch::Left(sibling) => H::hash(&[hash, *sibling]).map_err(Into::into), + FullMerkleBranch::Right(sibling) => H::hash(&[*sibling, hash]).map_err(Into::into), }) } } diff --git a/utils/src/merkle_tree/merkle_tree.rs b/utils/src/merkle_tree/merkle_tree.rs index bd79bc7..a4bcc3f 100644 --- a/utils/src/merkle_tree/merkle_tree.rs +++ b/utils/src/merkle_tree/merkle_tree.rs @@ -7,33 +7,30 @@ // Merkle tree implementations are adapted from https://github.com/kilic/rln/blob/master/src/merkle.rs // and https://github.com/worldcoin/semaphore-rs/blob/d462a4372f1fd9c27610f2acfe4841fab1d396aa/src/merkle_tree.rs -//! -//! # TODO -//! -//! * Disk based storage backend (using mmaped files should be easy) -//! * Implement serialization for tree and Merkle proof - use std::{ fmt::{Debug, Display}, str::FromStr, }; -use crate::merkle_tree::error::ZerokitMerkleTreeError; +use super::error::ZerokitMerkleTreeError; /// Enables parallel hashing when there are at least 8 nodes (4 pairs to hash), justifying the overhead. pub const MIN_PARALLEL_NODES: usize = 8; -/// In the Hasher trait we define the node type, the default leaf -/// and the hash function used to initialize a Merkle Tree implementation +/// In the Hasher trait we define the node type, the default leaf, +/// and the hash function used to initialize a Merkle Tree implementation. pub trait Hasher { /// Type of the leaf and tree node type Fr: Clone + Copy + Eq + Default + Debug + Display + FromStr + Send + Sync; + /// Error type for hash operations - must be convertible to ZerokitMerkleTreeError + type Error: Into + std::error::Error + Send + Sync + 'static; + /// Returns the default tree leaf fn default_leaf() -> Self::Fr; /// Utility to compute the hash of an intermediate node - fn hash(input: &[Self::Fr]) -> Self::Fr; + fn hash(input: &[Self::Fr]) -> Result; } pub type FrOf = ::Fr; @@ -101,5 +98,8 @@ pub trait ZerokitMerkleProof { fn leaf_index(&self) -> usize; fn get_path_elements(&self) -> Vec>; fn get_path_index(&self) -> Vec; - fn compute_root_from(&self, leaf: &FrOf) -> FrOf; + fn compute_root_from( + &self, + leaf: &FrOf, + ) -> Result, ZerokitMerkleTreeError>; } diff --git a/utils/src/merkle_tree/mod.rs b/utils/src/merkle_tree/mod.rs index 3acec88..b0da34f 100644 --- a/utils/src/merkle_tree/mod.rs +++ b/utils/src/merkle_tree/mod.rs @@ -4,8 +4,7 @@ pub mod full_merkle_tree; pub mod merkle_tree; pub mod optimal_merkle_tree; -pub use self::{ - full_merkle_tree::{FullMerkleConfig, FullMerkleProof, FullMerkleTree}, - merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES}, - optimal_merkle_tree::{OptimalMerkleConfig, OptimalMerkleProof, OptimalMerkleTree}, -}; +pub use error::{FromConfigError, ZerokitMerkleTreeError}; +pub use full_merkle_tree::{FullMerkleConfig, FullMerkleProof, FullMerkleTree}; +pub use merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES}; +pub use optimal_merkle_tree::{OptimalMerkleConfig, OptimalMerkleProof, OptimalMerkleTree}; diff --git a/utils/src/merkle_tree/optimal_merkle_tree.rs b/utils/src/merkle_tree/optimal_merkle_tree.rs index 71387f9..ba05dd4 100644 --- a/utils/src/merkle_tree/optimal_merkle_tree.rs +++ b/utils/src/merkle_tree/optimal_merkle_tree.rs @@ -2,9 +2,9 @@ use std::{cmp::max, collections::HashMap, fmt::Debug, str::FromStr}; use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use crate::merkle_tree::{ +use super::{ error::{FromConfigError, ZerokitMerkleTreeError}, - FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES, + merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES}, }; // Optimal Merkle Tree Implementation @@ -79,7 +79,7 @@ where let mut cached_nodes: Vec = Vec::with_capacity(depth + 1); cached_nodes.push(default_leaf); for i in 0..depth { - cached_nodes.push(H::hash(&[cached_nodes[i]; 2])); + cached_nodes.push(H::hash(&[cached_nodes[i]; 2]).map_err(Into::into)?); } cached_nodes.reverse(); @@ -197,7 +197,10 @@ where J: ExactSizeIterator, { let indices = indices.into_iter().collect::>(); - let min_index = *indices.first().expect("indices should not be empty"); + if indices.is_empty() { + return Err(ZerokitMerkleTreeError::InvalidIndices); + } + let min_index = indices[0]; let leaves_vec = leaves.into_iter().collect::>(); let max_index = start + leaves_vec.len(); @@ -248,10 +251,7 @@ where let mut depth = self.depth; loop { i ^= 1; - witness.push(( - self.get_node(depth, i), - (1 - (i & 1)).try_into().expect("0 or 1 expected"), - )); + witness.push((self.get_node(depth, i), (1 - (i & 1)) as u8)); i >>= 1; depth -= 1; if depth == 0 { @@ -274,7 +274,7 @@ where if merkle_proof.length() != self.depth { return Err(ZerokitMerkleTreeError::InvalidMerkleProof); } - let expected_root = merkle_proof.compute_root_from(leaf); + let expected_root = merkle_proof.compute_root_from(leaf)?; Ok(expected_root.eq(&self.root())) } @@ -304,9 +304,9 @@ where /// Computes the hash of a node’s two children at the given depth. /// If the index is odd, it is rounded down to the nearest even index. - fn hash_couple(&self, depth: usize, index: usize) -> H::Fr { + fn hash_couple(&self, depth: usize, index: usize) -> Result { let b = index & !1; - H::hash(&[self.get_node(depth, b), self.get_node(depth, b + 1)]) + H::hash(&[self.get_node(depth, b), self.get_node(depth, b + 1)]).map_err(Into::into) } /// Updates parent hashes after modifying a range of leaf nodes. @@ -330,25 +330,29 @@ where // Use parallel processing when the number of pairs exceeds the threshold if current_index_max - current_index >= MIN_PARALLEL_NODES { - let updates: Vec<((usize, usize), H::Fr)> = (current_index..current_index_max) + #[allow(clippy::type_complexity)] + let updates: Result< + Vec<((usize, usize), H::Fr)>, + ZerokitMerkleTreeError, + > = (current_index..current_index_max) .step_by(2) .collect::>() .into_par_iter() .map(|index| { // Hash two child nodes at positions (current_depth, index) and (current_depth, index + 1) - let hash = self.hash_couple(current_depth, index); + let hash = self.hash_couple(current_depth, index)?; // Return the computed parent hash and its position at - ((parent_depth, index >> 1), hash) + Ok(((parent_depth, index >> 1), hash)) }) .collect(); - for (parent, hash) in updates { + for (parent, hash) in updates? { self.nodes.insert(parent, hash); } } else { // Otherwise, fallback to sequential update for small ranges for index in (current_index..current_index_max).step_by(2) { - let hash = self.hash_couple(current_depth, index); + let hash = self.hash_couple(current_depth, index)?; self.nodes.insert((parent_depth, index >> 1), hash); } } @@ -396,16 +400,16 @@ where } /// Computes the Merkle root corresponding by iteratively hashing a Merkle proof with a given input leaf - fn compute_root_from(&self, leaf: &H::Fr) -> H::Fr { + fn compute_root_from(&self, leaf: &H::Fr) -> Result { let mut acc: H::Fr = *leaf; for w in self.0.iter() { if w.1 == 0 { - acc = H::hash(&[acc, w.0]); + acc = H::hash(&[acc, w.0]).map_err(Into::into)?; } else { - acc = H::hash(&[w.0, acc]); + acc = H::hash(&[w.0, acc]).map_err(Into::into)?; } } - acc + Ok(acc) } } diff --git a/utils/src/pm_tree/mod.rs b/utils/src/pm_tree/mod.rs index 441d710..5c33061 100644 --- a/utils/src/pm_tree/mod.rs +++ b/utils/src/pm_tree/mod.rs @@ -1,5 +1,7 @@ +#![cfg(feature = "pmtree-ft")] + pub mod sled_adapter; + pub use pmtree; pub use sled::{Config, Mode}; - -pub use self::sled_adapter::SledDB; +pub use sled_adapter::SledDB; diff --git a/utils/src/poseidon/error.rs b/utils/src/poseidon/error.rs new file mode 100644 index 0000000..b802476 --- /dev/null +++ b/utils/src/poseidon/error.rs @@ -0,0 +1,8 @@ +/// Errors that can occur during Poseidon hash computations +#[derive(Debug, thiserror::Error)] +pub enum PoseidonError { + #[error("No parameters found for input length {0}")] + NoParametersForInputLength(usize), + #[error("Empty input provided")] + EmptyInput, +} diff --git a/utils/src/poseidon/mod.rs b/utils/src/poseidon/mod.rs index 7fe211d..f7384e4 100644 --- a/utils/src/poseidon/mod.rs +++ b/utils/src/poseidon/mod.rs @@ -1,4 +1,5 @@ -pub mod poseidon_hash; -pub use poseidon_hash::Poseidon; - +pub mod error; pub mod poseidon_constants; +pub mod poseidon_hash; + +pub use self::poseidon_hash::Poseidon; diff --git a/utils/src/poseidon/poseidon_constants.rs b/utils/src/poseidon/poseidon_constants.rs index b6cca33..499b921 100644 --- a/utils/src/poseidon/poseidon_constants.rs +++ b/utils/src/poseidon/poseidon_constants.rs @@ -12,14 +12,14 @@ use ark_ff::PrimeField; use num_bigint::BigUint; -pub struct PoseidonGrainLFSR { +struct PoseidonGrainLFSR { pub prime_num_bits: u64, pub state: [bool; 80], pub head: usize, } impl PoseidonGrainLFSR { - pub fn new( + fn new( is_field: u64, is_sbox_an_inverse: u64, prime_num_bits: u64, @@ -92,7 +92,7 @@ impl PoseidonGrainLFSR { res } - pub fn get_bits(&mut self, num_bits: usize) -> Vec { + fn get_bits(&mut self, num_bits: usize) -> Vec { let mut res = Vec::new(); for _ in 0..num_bits { @@ -114,10 +114,7 @@ impl PoseidonGrainLFSR { res } - pub fn get_field_elements_rejection_sampling( - &mut self, - num_elems: usize, - ) -> Vec { + fn get_field_elements_rejection_sampling(&mut self, num_elems: usize) -> Vec { assert_eq!(F::MODULUS_BIT_SIZE as u64, self.prime_num_bits); let modulus: BigUint = F::MODULUS.into(); @@ -151,7 +148,7 @@ impl PoseidonGrainLFSR { res } - pub fn get_field_elements_mod_p(&mut self, num_elems: usize) -> Vec { + fn get_field_elements_mod_p(&mut self, num_elems: usize) -> Vec { assert_eq!(F::MODULUS_BIT_SIZE as u64, self.prime_num_bits); let mut res = Vec::new(); @@ -253,7 +250,10 @@ pub fn find_poseidon_ark_and_mds( for i in 0..(rate) { for (j, ys_item) in ys.iter().enumerate().take(rate) { - mds[i][j] = (xs[i] + ys_item).inverse().unwrap(); + // Poseidon algorithm guarantees xs[i] + ys[j] != 0 + mds[i][j] = (xs[i] + ys_item) + .inverse() + .expect("MDS matrix inverse must be valid"); } } diff --git a/utils/src/poseidon/poseidon_hash.rs b/utils/src/poseidon/poseidon_hash.rs index c7d26fe..4c248ee 100644 --- a/utils/src/poseidon/poseidon_hash.rs +++ b/utils/src/poseidon/poseidon_hash.rs @@ -5,7 +5,7 @@ use ark_ff::PrimeField; -use crate::poseidon_constants::find_poseidon_ark_and_mds; +use super::{error::PoseidonError, poseidon_constants::find_poseidon_ark_and_mds}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct RoundParameters { @@ -20,6 +20,7 @@ pub struct RoundParameters { pub struct Poseidon { round_params: Vec>, } + impl Poseidon { // Loads round parameters and generates round constants // poseidon_params is a vector containing tuples (t, RF, RP, skip_matrices) @@ -93,18 +94,20 @@ impl Poseidon { } } - pub fn hash(&self, inp: &[F]) -> Result { + pub fn hash(&self, inp: &[F]) -> Result { // Note that the rate t becomes input length + 1; hence for length N we pick parameters with T = N + 1 let t = inp.len() + 1; - // We seek the index (Poseidon's round_params is an ordered vector) for the parameters corresponding to t - let param_index = self.round_params.iter().position(|el| el.t == t); - - if inp.is_empty() || param_index.is_none() { - return Err("No parameters found for inputs length".to_string()); + if inp.is_empty() { + return Err(PoseidonError::EmptyInput); } - let param_index = param_index.unwrap(); + // We seek the index (Poseidon's round_params is an ordered vector) for the parameters corresponding to t + let param_index = self + .round_params + .iter() + .position(|el| el.t == t) + .ok_or(PoseidonError::NoParametersForInputLength(inp.len()))?; let mut state = vec![F::ZERO; t]; let mut state_2 = state.clone(); diff --git a/utils/tests/merkle_tree.rs b/utils/tests/merkle_tree.rs index ec62d9c..4d3aa0b 100644 --- a/utils/tests/merkle_tree.rs +++ b/utils/tests/merkle_tree.rs @@ -6,8 +6,11 @@ mod test { use hex_literal::hex; use tiny_keccak::{Hasher as _, Keccak}; use zerokit_utils::{ - FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, - ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES, + error::HashError, + merkle_tree::{ + FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree, + ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES, + }, }; #[derive(Clone, Copy, Eq, PartialEq)] struct Keccak256; @@ -17,19 +20,20 @@ mod test { impl Hasher for Keccak256 { type Fr = TestFr; + type Error = HashError; fn default_leaf() -> Self::Fr { TestFr([0; 32]) } - fn hash(inputs: &[Self::Fr]) -> Self::Fr { + fn hash(inputs: &[Self::Fr]) -> Result { let mut output = [0; 32]; let mut hasher = Keccak::v256(); for element in inputs { hasher.update(element.0.as_slice()); } hasher.finalize(&mut output); - TestFr(output) + Ok(TestFr(output)) } } @@ -43,7 +47,7 @@ mod test { type Err = std::string::FromUtf8Error; fn from_str(s: &str) -> Result { - Ok(TestFr(s.as_bytes().try_into().expect("Invalid length"))) + Ok(TestFr(s.as_bytes().try_into().unwrap())) } } @@ -51,7 +55,7 @@ mod test { fn from(value: u32) -> Self { let mut bytes: Vec = vec![0; 28]; bytes.extend_from_slice(&value.to_be_bytes()); - TestFr(bytes.as_slice().try_into().expect("Invalid length")) + TestFr(bytes.as_slice().try_into().unwrap()) } } @@ -59,12 +63,12 @@ mod test { fn default_full_merkle_tree(depth: usize) -> FullMerkleTree { FullMerkleTree::::new(depth, TestFr([0; 32]), FullMerkleConfig::default()) - .expect("Failed to create FullMerkleTree") + .unwrap() } fn default_optimal_merkle_tree(depth: usize) -> OptimalMerkleTree { OptimalMerkleTree::::new(depth, TestFr([0; 32]), OptimalMerkleConfig::default()) - .expect("Failed to create OptimalMerkleTree") + .unwrap() } #[test] @@ -87,14 +91,14 @@ mod test { let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH); assert_eq!(tree_full.root(), default_tree_root); for i in 0..nof_leaves { - tree_full.set(i, leaves[i]).expect("Failed to set leaf"); + tree_full.set(i, leaves[i]).unwrap(); assert_eq!(tree_full.root(), roots[i]); } let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH); assert_eq!(tree_opt.root(), default_tree_root); for i in 0..nof_leaves { - tree_opt.set(i, leaves[i]).expect("Failed to set leaf"); + tree_opt.set(i, leaves[i]).unwrap(); assert_eq!(tree_opt.root(), roots[i]); } } @@ -106,17 +110,13 @@ mod test { let mut tree_full = default_full_merkle_tree(depth); let root_before = tree_full.root(); - tree_full - .set_range(0, leaves.iter().cloned()) - .expect("Failed to set leaves"); + tree_full.set_range(0, leaves.iter().cloned()).unwrap(); let root_after = tree_full.root(); assert_ne!(root_before, root_after); let mut tree_opt = default_optimal_merkle_tree(depth); let root_before = tree_opt.root(); - tree_opt - .set_range(0, leaves.iter().cloned()) - .expect("Failed to set leaves"); + tree_opt.set_range(0, leaves.iter().cloned()).unwrap(); let root_after = tree_opt.root(); assert_ne!(root_before, root_after); } @@ -128,10 +128,10 @@ mod test { for i in 0..4 { let leaf = TestFr::from(i as u32); - tree_full.update_next(leaf).expect("Failed to update leaf"); - tree_opt.update_next(leaf).expect("Failed to update leaf"); - assert_eq!(tree_full.get(i).expect("Failed to get leaf"), leaf); - assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), leaf); + tree_full.update_next(leaf).unwrap(); + tree_opt.update_next(leaf).unwrap(); + assert_eq!(tree_full.get(i).unwrap(), leaf); + assert_eq!(tree_opt.get(i).unwrap(), leaf); } assert_eq!(tree_full.leaves_set(), 4); @@ -145,38 +145,34 @@ mod test { let new_leaf = TestFr::from(99); let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH); - tree_full - .set(index, original_leaf) - .expect("Failed to set leaf"); + tree_full.set(index, original_leaf).unwrap(); let root_with_original = tree_full.root(); - tree_full.delete(index).expect("Failed to delete leaf"); + tree_full.delete(index).unwrap(); let root_after_delete = tree_full.root(); assert_ne!(root_with_original, root_after_delete); - tree_full.set(index, new_leaf).expect("Failed to set leaf"); + tree_full.set(index, new_leaf).unwrap(); let root_after_reset = tree_full.root(); assert_ne!(root_after_delete, root_after_reset); assert_ne!(root_with_original, root_after_reset); - assert_eq!(tree_full.get(index).expect("Failed to get leaf"), new_leaf); + assert_eq!(tree_full.get(index).unwrap(), new_leaf); let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH); - tree_opt - .set(index, original_leaf) - .expect("Failed to set leaf"); + tree_opt.set(index, original_leaf).unwrap(); let root_with_original = tree_opt.root(); - tree_opt.delete(index).expect("Failed to delete leaf"); + tree_opt.delete(index).unwrap(); let root_after_delete = tree_opt.root(); assert_ne!(root_with_original, root_after_delete); - tree_opt.set(index, new_leaf).expect("Failed to set leaf"); + tree_opt.set(index, new_leaf).unwrap(); let root_after_reset = tree_opt.root(); assert_ne!(root_after_delete, root_after_reset); assert_ne!(root_with_original, root_after_reset); - assert_eq!(tree_opt.get(index).expect("Failed to get leaf"), new_leaf); + assert_eq!(tree_opt.get(index).unwrap(), new_leaf); } #[test] @@ -207,24 +203,24 @@ mod test { // check situation when the number of items to insert is less than the number of items to delete tree_full .override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); // check if the indexes for write and delete are the same tree_full .override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), Vec::::new()); // check if indexes for deletion are before indexes for overwriting tree_full .override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1, 2, 3]); // check if the indices for write and delete do not overlap completely tree_full .override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1]); let mut tree_opt = default_optimal_merkle_tree(depth); @@ -246,24 +242,24 @@ mod test { // check situation when the number of items to insert is less than the number of items to delete tree_opt .override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); // check if the indexes for write and delete are the same tree_opt .override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), Vec::::new()); // check if indexes for deletion are before indexes for overwriting tree_opt .override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1, 2, 3]); // check if the indices for write and delete do not overlap completely tree_opt .override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter()) - .expect("Failed to override range"); + .unwrap(); assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1]); } @@ -279,19 +275,12 @@ mod test { for i in 0..nof_leaves { // check leaves assert_eq!( - tree_full.get(i).expect("Failed to get leaf"), - tree_full - .get_subtree_root(depth, i) - .expect("Failed to get subtree root") + tree_full.get(i).unwrap(), + tree_full.get_subtree_root(depth, i).unwrap() ); // check root - assert_eq!( - tree_full.root(), - tree_full - .get_subtree_root(0, i) - .expect("Failed to get subtree root") - ); + assert_eq!(tree_full.root(), tree_full.get_subtree_root(0, i).unwrap()); } // check intermediate nodes @@ -301,18 +290,12 @@ mod test { let idx_r = (i + 1) * (1 << (depth - n)); let idx_sr = idx_l; - let prev_l = tree_full - .get_subtree_root(n, idx_l) - .expect("Failed to get subtree root"); - let prev_r = tree_full - .get_subtree_root(n, idx_r) - .expect("Failed to get subtree root"); - let subroot = tree_full - .get_subtree_root(n - 1, idx_sr) - .expect("Failed to get subtree root"); + let prev_l = tree_full.get_subtree_root(n, idx_l).unwrap(); + let prev_r = tree_full.get_subtree_root(n, idx_r).unwrap(); + let subroot = tree_full.get_subtree_root(n - 1, idx_sr).unwrap(); // check intermediate nodes - assert_eq!(Keccak256::hash(&[prev_l, prev_r]), subroot); + assert_eq!(Keccak256::hash(&[prev_l, prev_r]).unwrap(), subroot); } } @@ -322,18 +305,11 @@ mod test { for i in 0..nof_leaves { // check leaves assert_eq!( - tree_opt.get(i).expect("Failed to get leaf"), - tree_opt - .get_subtree_root(depth, i) - .expect("Failed to get subtree root") + tree_opt.get(i).unwrap(), + tree_opt.get_subtree_root(depth, i).unwrap() ); // check root - assert_eq!( - tree_opt.root(), - tree_opt - .get_subtree_root(0, i) - .expect("Failed to get subtree root") - ); + assert_eq!(tree_opt.root(), tree_opt.get_subtree_root(0, i).unwrap()); } // check intermediate nodes @@ -343,18 +319,12 @@ mod test { let idx_r = (i + 1) * (1 << (depth - n)); let idx_sr = idx_l; - let prev_l = tree_opt - .get_subtree_root(n, idx_l) - .expect("Failed to get subtree root"); - let prev_r = tree_opt - .get_subtree_root(n, idx_r) - .expect("Failed to get subtree root"); - let subroot = tree_opt - .get_subtree_root(n - 1, idx_sr) - .expect("Failed to get subtree root"); + let prev_l = tree_opt.get_subtree_root(n, idx_l).unwrap(); + let prev_r = tree_opt.get_subtree_root(n, idx_r).unwrap(); + let subroot = tree_opt.get_subtree_root(n - 1, idx_sr).unwrap(); // check intermediate nodes - assert_eq!(Keccak256::hash(&[prev_l, prev_r]), subroot); + assert_eq!(Keccak256::hash(&[prev_l, prev_r]).unwrap(), subroot); } } } @@ -368,52 +338,54 @@ mod test { let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH); for i in 0..nof_leaves { // We set the leaves - tree_full.set(i, leaves[i]).expect("Failed to set leaf"); + tree_full.set(i, leaves[i]).unwrap(); // We compute a merkle proof - let proof = tree_full.proof(i).expect("Failed to compute proof"); + let proof = tree_full.proof(i).unwrap(); // We verify if the merkle proof corresponds to the right leaf index assert_eq!(proof.leaf_index(), i); // We verify the proof - assert!(tree_full - .verify(&leaves[i], &proof) - .expect("Failed to verify proof")); + assert!(tree_full.verify(&leaves[i], &proof).unwrap()); // We ensure that the Merkle proof and the leaf generate the same root as the tree - assert_eq!(proof.compute_root_from(&leaves[i]), tree_full.root()); + assert_eq!( + proof.compute_root_from(&leaves[i]).unwrap(), + tree_full.root() + ); // We check that the proof is not valid for another leaf assert!(!tree_full .verify(&leaves[(i + 1) % nof_leaves], &proof) - .expect("Failed to verify proof")); + .unwrap()); } // We test the OptimalMerkleTree implementation let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH); for i in 0..nof_leaves { // We set the leaves - tree_opt.set(i, leaves[i]).expect("Failed to set leaf"); + tree_opt.set(i, leaves[i]).unwrap(); // We compute a merkle proof - let proof = tree_opt.proof(i).expect("Failed to compute proof"); + let proof = tree_opt.proof(i).unwrap(); // We verify if the merkle proof corresponds to the right leaf index assert_eq!(proof.leaf_index(), i); // We verify the proof - assert!(tree_opt - .verify(&leaves[i], &proof) - .expect("Failed to verify proof")); + assert!(tree_opt.verify(&leaves[i], &proof).unwrap()); // We ensure that the Merkle proof and the leaf generate the same root as the tree - assert_eq!(proof.compute_root_from(&leaves[i]), tree_opt.root()); + assert_eq!( + proof.compute_root_from(&leaves[i]).unwrap(), + tree_opt.root() + ); // We check that the proof is not valid for another leaf assert!(!tree_opt .verify(&leaves[(i + 1) % nof_leaves], &proof) - .expect("Failed to verify proof")); + .unwrap()); } } @@ -424,16 +396,12 @@ mod test { let invalid_leaf = TestFr::from(12345); - let proof_full = tree_full.proof(0).expect("Failed to compute proof"); - let proof_opt = tree_opt.proof(0).expect("Failed to compute proof"); + let proof_full = tree_full.proof(0).unwrap(); + let proof_opt = tree_opt.proof(0).unwrap(); // Should fail because no leaf was set - assert!(!tree_full - .verify(&invalid_leaf, &proof_full) - .expect("Failed to verify proof")); - assert!(!tree_opt - .verify(&invalid_leaf, &proof_opt) - .expect("Failed to verify proof")); + assert!(!tree_full.verify(&invalid_leaf, &proof_full).unwrap()); + assert!(!tree_opt.verify(&invalid_leaf, &proof_opt).unwrap()); } #[test] @@ -450,9 +418,7 @@ mod test { let to_delete_indices: [usize; 2] = [0, 1]; let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH); - tree_full - .set_range(0, leaves.iter().cloned()) - .expect("Failed to set leaves"); + tree_full.set_range(0, leaves.iter().cloned()).unwrap(); tree_full .override_range( @@ -460,16 +426,14 @@ mod test { new_leaves.iter().cloned(), to_delete_indices.iter().cloned(), ) - .expect("Failed to override range"); + .unwrap(); for (i, &new_leaf) in new_leaves.iter().enumerate() { - assert_eq!(tree_full.get(i).expect("Failed to get leaf"), new_leaf); + assert_eq!(tree_full.get(i).unwrap(), new_leaf); } let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH); - tree_opt - .set_range(0, leaves.iter().cloned()) - .expect("Failed to set leaves"); + tree_opt.set_range(0, leaves.iter().cloned()).unwrap(); tree_opt .override_range( @@ -477,10 +441,10 @@ mod test { new_leaves.iter().cloned(), to_delete_indices.iter().cloned(), ) - .expect("Failed to override range"); + .unwrap(); for (i, &new_leaf) in new_leaves.iter().enumerate() { - assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), new_leaf); + assert_eq!(tree_opt.get(i).unwrap(), new_leaf); } } @@ -499,20 +463,20 @@ mod test { tree_full .override_range(0, leaves.iter().cloned(), indices.iter().cloned()) - .expect("Failed to override range"); + .unwrap(); for (i, &leaf) in leaves.iter().enumerate() { - assert_eq!(tree_full.get(i).expect("Failed to get leaf"), leaf); + assert_eq!(tree_full.get(i).unwrap(), leaf); } let mut tree_opt = default_optimal_merkle_tree(depth); tree_opt .override_range(0, leaves.iter().cloned(), indices.iter().cloned()) - .expect("Failed to override range"); + .unwrap(); for (i, &leaf) in leaves.iter().enumerate() { - assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), leaf); + assert_eq!(tree_opt.get(i).unwrap(), leaf); } } } diff --git a/utils/tests/poseidon_constants.rs b/utils/tests/poseidon_constants.rs index 02bb0ae..39cc25b 100644 --- a/utils/tests/poseidon_constants.rs +++ b/utils/tests/poseidon_constants.rs @@ -3,7 +3,7 @@ mod test { use ark_bn254::Fr; use num_bigint::BigUint; use num_traits::Num; - use zerokit_utils::poseidon_hash::Poseidon; + use zerokit_utils::poseidon::Poseidon; const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [ (2, 8, 56, 0), diff --git a/utils/tests/poseidon_hash_test.rs b/utils/tests/poseidon_hash_test.rs index 32d533a..20666cb 100644 --- a/utils/tests/poseidon_hash_test.rs +++ b/utils/tests/poseidon_hash_test.rs @@ -4,7 +4,7 @@ mod test { use ark_bn254::Fr; use ark_ff::{AdditiveGroup, Field}; - use zerokit_utils::poseidon_hash::Poseidon; + use zerokit_utils::poseidon::Poseidon; const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [ (2, 8, 56, 0),