From 55b00fd6539d7cbd1950a9d18fe167774e753a2a Mon Sep 17 00:00:00 2001 From: tyshko-rostyslav <122977916+tyshko-rostyslav@users.noreply.github.com> Date: Mon, 27 Feb 2023 07:16:16 +0100 Subject: [PATCH] Code quality (#114) * to color_eyre::Result 1st part * tests and seconds batch * third batch * rln fixes + multiplier * rln-wasm, assert rln, multiplier * io to color_eyre * fmt + clippy * fix lint * temporary fix of `ark-circom` * fix ci after merge * fmt * fix rln tests * minor * fix tests * imports * requested change * report + commented line + requested change * requested changes * fix build * lint fixes * better comments --------- Co-authored-by: tyshkor --- multiplier/src/ffi.rs | 12 +- multiplier/src/main.rs | 15 +- multiplier/src/public.rs | 37 +++-- multiplier/tests/public.rs | 3 +- rln-wasm/Cargo.toml | 1 + rln-wasm/src/lib.rs | 50 ++++--- rln/src/circuit.rs | 171 +++++++++++---------- rln/src/ffi.rs | 20 ++- rln/src/protocol.rs | 162 ++++++++++---------- rln/src/public.rs | 214 +++++++++++---------------- rln/src/utils.rs | 74 +++++---- rln/tests/ffi.rs | 101 ++++++++----- rln/tests/protocol.rs | 93 ++++++++---- rln/tests/public.rs | 89 +++++++---- semaphore/src/protocol.rs | 4 +- utils/Cargo.toml | 1 + utils/src/merkle_tree/merkle_tree.rs | 94 ++++++------ 17 files changed, 606 insertions(+), 535 deletions(-) diff --git a/multiplier/src/ffi.rs b/multiplier/src/ffi.rs index 228cc0b..b865a90 100644 --- a/multiplier/src/ffi.rs +++ b/multiplier/src/ffi.rs @@ -31,12 +31,12 @@ impl<'a> From<&Buffer> for &'a [u8] { #[allow(clippy::not_unsafe_ptr_arg_deref)] #[no_mangle] pub extern "C" fn new_circuit(ctx: *mut *mut Multiplier) -> bool { - println!("multiplier ffi: new"); - let mul = Multiplier::new(); - - unsafe { *ctx = Box::into_raw(Box::new(mul)) }; - - true + if let Ok(mul) = Multiplier::new() { + unsafe { *ctx = Box::into_raw(Box::new(mul)) }; + true + } else { + false + } } #[allow(clippy::not_unsafe_ptr_arg_deref)] diff --git a/multiplier/src/main.rs b/multiplier/src/main.rs index c3f6176..30c1657 100644 --- a/multiplier/src/main.rs +++ b/multiplier/src/main.rs @@ -1,6 +1,6 @@ use ark_circom::{CircomBuilder, CircomConfig}; use ark_std::rand::thread_rng; -use color_eyre::Result; +use color_eyre::{Report, Result}; use ark_bn254::Bn254; use ark_groth16::{ @@ -25,17 +25,18 @@ fn groth16_proof_example() -> Result<()> { let circom = builder.build()?; - let inputs = circom.get_public_inputs().unwrap(); + let inputs = circom + .get_public_inputs() + .ok_or(Report::msg("no public inputs"))?; let proof = prove(circom, ¶ms, &mut rng)?; let pvk = prepare_verifying_key(¶ms.vk); - let verified = verify_proof(&pvk, &proof, &inputs)?; - - assert!(verified); - - Ok(()) + match verify_proof(&pvk, &proof, &inputs) { + Ok(_) => Ok(()), + Err(_) => Err(Report::msg("not verified")), + } } fn main() { diff --git a/multiplier/src/public.rs b/multiplier/src/public.rs index 950c4f6..4312e7a 100644 --- a/multiplier/src/public.rs +++ b/multiplier/src/public.rs @@ -7,9 +7,8 @@ use ark_groth16::{ Proof, ProvingKey, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -// , SerializationError}; - -use std::io::{self, Read, Write}; +use color_eyre::{Report, Result}; +use std::io::{Read, Write}; pub struct Multiplier { circom: CircomCircuit, @@ -18,12 +17,11 @@ pub struct Multiplier { impl Multiplier { // TODO Break this apart here - pub fn new() -> Multiplier { + pub fn new() -> Result { let cfg = CircomConfig::::new( "./resources/circom2_multiplier2.wasm", "./resources/circom2_multiplier2.r1cs", - ) - .unwrap(); + )?; let mut builder = CircomBuilder::new(cfg); builder.push_input("a", 3); @@ -34,40 +32,41 @@ impl Multiplier { let mut rng = thread_rng(); - let params = generate_random_parameters::(circom, &mut rng).unwrap(); + let params = generate_random_parameters::(circom, &mut rng)?; - let circom = builder.build().unwrap(); + let circom = builder.build()?; - //let inputs = circom.get_public_inputs().unwrap(); - - Multiplier { circom, params } + Ok(Multiplier { circom, params }) } // TODO Input Read - pub fn prove(&self, result_data: W) -> io::Result<()> { + pub fn prove(&self, result_data: W) -> Result<()> { let mut rng = thread_rng(); // XXX: There's probably a better way to do this let circom = self.circom.clone(); let params = self.params.clone(); - let proof = prove(circom, ¶ms, &mut rng).unwrap(); + let proof = prove(circom, ¶ms, &mut rng)?; // XXX: Unclear if this is different from other serialization(s) - proof.serialize(result_data).unwrap(); + proof.serialize(result_data)?; Ok(()) } - pub fn verify(&self, input_data: R) -> io::Result { - let proof = Proof::deserialize(input_data).unwrap(); + pub fn verify(&self, input_data: R) -> Result { + let proof = Proof::deserialize(input_data)?; let pvk = prepare_verifying_key(&self.params.vk); // XXX Part of input data? - let inputs = self.circom.get_public_inputs().unwrap(); + let inputs = self + .circom + .get_public_inputs() + .ok_or(Report::msg("no public inputs"))?; - let verified = verify_proof(&pvk, &proof, &inputs).unwrap(); + let verified = verify_proof(&pvk, &proof, &inputs)?; Ok(verified) } @@ -75,6 +74,6 @@ impl Multiplier { impl Default for Multiplier { fn default() -> Self { - Self::new() + Self::new().unwrap() } } diff --git a/multiplier/tests/public.rs b/multiplier/tests/public.rs index 220ce0d..f955e71 100644 --- a/multiplier/tests/public.rs +++ b/multiplier/tests/public.rs @@ -4,8 +4,7 @@ mod tests { #[test] fn multiplier_proof() { - let mul = Multiplier::new(); - //let inputs = mul.circom.get_public_inputs().unwrap(); + let mul = Multiplier::new().unwrap(); let mut output_data: Vec = Vec::new(); let _ = mul.prove(&mut output_data); diff --git a/rln-wasm/Cargo.toml b/rln-wasm/Cargo.toml index 6b61812..107947c 100644 --- a/rln-wasm/Cargo.toml +++ b/rln-wasm/Cargo.toml @@ -20,6 +20,7 @@ wasm-bindgen = "0.2.63" serde-wasm-bindgen = "0.4" js-sys = "0.3.59" serde_json = "1.0.85" +anyhow = "1.0.69" # The `console_error_panic_hook` crate provides better debugging of panics by # logging them with `console.error`. This is great for development, but requires diff --git a/rln-wasm/src/lib.rs b/rln-wasm/src/lib.rs index 32660ce..27c8e25 100644 --- a/rln-wasm/src/lib.rs +++ b/rln-wasm/src/lib.rs @@ -22,21 +22,30 @@ pub struct RLNWrapper { #[allow(clippy::not_unsafe_ptr_arg_deref)] #[wasm_bindgen(js_name = newRLN)] -pub fn wasm_new(tree_height: usize, zkey: Uint8Array, vk: Uint8Array) -> *mut RLNWrapper { - let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec()); +pub fn wasm_new( + tree_height: usize, + zkey: Uint8Array, + vk: Uint8Array, +) -> Result<*mut RLNWrapper, String> { + let instance = RLN::new_with_params(tree_height, zkey.to_vec(), vk.to_vec()) + .map_err(|err| format!("{:#?}", err))?; let wrapper = RLNWrapper { instance }; - Box::into_raw(Box::new(wrapper)) + Ok(Box::into_raw(Box::new(wrapper))) } #[allow(clippy::not_unsafe_ptr_arg_deref)] #[wasm_bindgen(js_name = getSerializedRLNWitness)] -pub fn wasm_get_serialized_rln_witness(ctx: *mut RLNWrapper, input: Uint8Array) -> Uint8Array { +pub fn wasm_get_serialized_rln_witness( + ctx: *mut RLNWrapper, + input: Uint8Array, +) -> Result { let wrapper = unsafe { &mut *ctx }; let rln_witness = wrapper .instance - .get_serialized_rln_witness(&input.to_vec()[..]); + .get_serialized_rln_witness(&input.to_vec()[..]) + .map_err(|err| format!("{:#?}", err))?; - Uint8Array::from(&rln_witness[..]) + Ok(Uint8Array::from(&rln_witness[..])) } #[allow(clippy::not_unsafe_ptr_arg_deref)] @@ -86,16 +95,18 @@ pub fn wasm_init_tree_with_leaves(ctx: *mut RLNWrapper, input: Uint8Array) -> Re #[allow(clippy::not_unsafe_ptr_arg_deref)] #[wasm_bindgen(js_name = RLNWitnessToJson)] -pub fn rln_witness_to_json(ctx: *mut RLNWrapper, serialized_witness: Uint8Array) -> Object { +pub fn rln_witness_to_json( + ctx: *mut RLNWrapper, + serialized_witness: Uint8Array, +) -> Result { let wrapper = unsafe { &mut *ctx }; let inputs = wrapper .instance .get_rln_witness_json(&serialized_witness.to_vec()[..]) - .unwrap(); + .map_err(|err| err.to_string())?; - let js_value = serde_wasm_bindgen::to_value(&inputs).unwrap(); - let obj = Object::from_entries(&js_value); - obj.unwrap() + let js_value = serde_wasm_bindgen::to_value(&inputs).map_err(|err| err.to_string())?; + Object::from_entries(&js_value).map_err(|err| format!("{:#?}", err)) } #[allow(clippy::not_unsafe_ptr_arg_deref)] @@ -107,17 +118,18 @@ pub fn generate_rln_proof_with_witness( ) -> Result { let wrapper = unsafe { &mut *ctx }; - let witness_vec: Vec = calculated_witness - .iter() - .map(|v| { + let mut witness_vec: Vec = vec![]; + + for v in calculated_witness { + witness_vec.push( v.to_string(10) - .unwrap() + .map_err(|err| format!("{:#?}", err))? .as_string() - .unwrap() + .ok_or("not a string error")? .parse::() - .unwrap() - }) - .collect(); + .map_err(|err| format!("{:#?}", err))?, + ); + } let mut output_data: Vec = Vec::new(); diff --git a/rln/src/circuit.rs b/rln/src/circuit.rs index 1e082de..b11b8f3 100644 --- a/rln/src/circuit.rs +++ b/rln/src/circuit.rs @@ -8,10 +8,11 @@ use ark_circom::read_zkey; use ark_groth16::{ProvingKey, VerifyingKey}; use ark_relations::r1cs::ConstraintMatrices; use cfg_if::cfg_if; +use color_eyre::{Report, Result}; use num_bigint::BigUint; use serde_json::Value; use std::fs::File; -use std::io::{Cursor, Error, ErrorKind, Result}; +use std::io::Cursor; use std::path::Path; use std::str::FromStr; @@ -57,7 +58,7 @@ pub fn zkey_from_raw(zkey_data: &Vec) -> Result<(ProvingKey, Constrai let proving_key_and_matrices = read_zkey(&mut c)?; Ok(proving_key_and_matrices) } else { - Err(Error::new(ErrorKind::NotFound, "No proving key found!")) + Err(Report::msg("No proving key found!")) } } @@ -71,7 +72,7 @@ pub fn zkey_from_folder( let proving_key_and_matrices = read_zkey(&mut file)?; Ok(proving_key_and_matrices) } else { - Err(Error::new(ErrorKind::NotFound, "No proving key found!")) + Err(Report::msg("No proving key found!")) } } @@ -80,17 +81,14 @@ pub fn vk_from_raw(vk_data: &Vec, zkey_data: &Vec) -> Result; if !vk_data.is_empty() { - verifying_key = vk_from_vector(vk_data); + verifying_key = vk_from_vector(vk_data)?; Ok(verifying_key) } else if !zkey_data.is_empty() { let (proving_key, _matrices) = zkey_from_raw(zkey_data)?; verifying_key = proving_key.vk; Ok(verifying_key) } else { - Err(Error::new( - ErrorKind::NotFound, - "No proving/verification key found!", - )) + Err(Report::msg("No proving/verification key found!")) } } @@ -102,17 +100,13 @@ pub fn vk_from_folder(resources_folder: &str) -> Result> { let verifying_key: VerifyingKey; if Path::new(&vk_path).exists() { - verifying_key = vk_from_json(&vk_path); - Ok(verifying_key) + vk_from_json(&vk_path) } else if Path::new(&zkey_path).exists() { let (proving_key, _matrices) = zkey_from_folder(resources_folder)?; verifying_key = proving_key.vk; Ok(verifying_key) } else { - Err(Error::new( - ErrorKind::NotFound, - "No proving/verification key found!", - )) + Err(Report::msg("No proving/verification key found!")) } } @@ -121,129 +115,146 @@ static WITNESS_CALCULATOR: OnceCell> = OnceCell::new(); // Initializes the witness calculator using a bytes vector #[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_raw(wasm_buffer: Vec) -> &'static Mutex { - WITNESS_CALCULATOR.get_or_init(|| { +pub fn circom_from_raw(wasm_buffer: Vec) -> Result<&'static Mutex> { + WITNESS_CALCULATOR.get_or_try_init(|| { let store = Store::default(); - let module = Module::new(&store, wasm_buffer).unwrap(); - let result = - WitnessCalculator::from_module(module).expect("Failed to create witness calculator"); - Mutex::new(result) + let module = Module::new(&store, wasm_buffer)?; + let result = WitnessCalculator::from_module(module)?; + Ok::, Report>(Mutex::new(result)) }) } // Initializes the witness calculator #[cfg(not(target_arch = "wasm32"))] -pub fn circom_from_folder(resources_folder: &str) -> &'static Mutex { +pub fn circom_from_folder(resources_folder: &str) -> Result<&'static Mutex> { // We read the wasm file let wasm_path = format!("{resources_folder}{WASM_FILENAME}"); - let wasm_buffer = std::fs::read(wasm_path).unwrap(); + let wasm_buffer = std::fs::read(wasm_path)?; circom_from_raw(wasm_buffer) } // The following function implementations are taken/adapted from https://github.com/gakonst/ark-circom/blob/1732e15d6313fe176b0b1abb858ac9e095d0dbd7/src/zkey.rs // Utilities to convert a json verification key in a groth16::VerificationKey -fn fq_from_str(s: &str) -> Fq { - Fq::try_from(BigUint::from_str(s).unwrap()).unwrap() +fn fq_from_str(s: &str) -> Result { + Ok(Fq::try_from(BigUint::from_str(s)?)?) } // Extracts the element in G1 corresponding to its JSON serialization -fn json_to_g1(json: &Value, key: &str) -> G1Affine { +fn json_to_g1(json: &Value, key: &str) -> Result { let els: Vec = json .get(key) - .unwrap() + .ok_or(Report::msg("no json value"))? .as_array() - .unwrap() + .ok_or(Report::msg("value not an array"))? .iter() - .map(|i| i.as_str().unwrap().to_string()) - .collect(); - G1Affine::from(G1Projective::new( - fq_from_str(&els[0]), - fq_from_str(&els[1]), - fq_from_str(&els[2]), - )) + .map(|i| i.as_str().ok_or(Report::msg("element is not a string"))) + .map(|x| x.map(|v| v.to_owned())) + .collect::>>()?; + + Ok(G1Affine::from(G1Projective::new( + fq_from_str(&els[0])?, + fq_from_str(&els[1])?, + fq_from_str(&els[2])?, + ))) } // Extracts the vector of G1 elements corresponding to its JSON serialization -fn json_to_g1_vec(json: &Value, key: &str) -> Vec { +fn json_to_g1_vec(json: &Value, key: &str) -> Result> { let els: Vec> = json .get(key) - .unwrap() + .ok_or(Report::msg("no json value"))? .as_array() - .unwrap() + .ok_or(Report::msg("value not an array"))? .iter() .map(|i| { i.as_array() - .unwrap() - .iter() - .map(|x| x.as_str().unwrap().to_string()) - .collect::>() + .ok_or(Report::msg("element is not an array")) + .and_then(|array| { + array + .iter() + .map(|x| x.as_str().ok_or(Report::msg("element is not a string"))) + .map(|x| x.map(|v| v.to_owned())) + .collect::>>() + }) }) - .collect(); + .collect::>>>()?; - els.iter() - .map(|coords| { - G1Affine::from(G1Projective::new( - fq_from_str(&coords[0]), - fq_from_str(&coords[1]), - fq_from_str(&coords[2]), - )) - }) - .collect() + let mut res = vec![]; + for coords in els { + res.push(G1Affine::from(G1Projective::new( + fq_from_str(&coords[0])?, + fq_from_str(&coords[1])?, + fq_from_str(&coords[2])?, + ))) + } + + Ok(res) } // Extracts the element in G2 corresponding to its JSON serialization -fn json_to_g2(json: &Value, key: &str) -> G2Affine { +fn json_to_g2(json: &Value, key: &str) -> Result { let els: Vec> = json .get(key) - .unwrap() + .ok_or(Report::msg("no json value"))? .as_array() - .unwrap() + .ok_or(Report::msg("value not an array"))? .iter() .map(|i| { i.as_array() - .unwrap() - .iter() - .map(|x| x.as_str().unwrap().to_string()) - .collect::>() + .ok_or(Report::msg("element is not an array")) + .and_then(|array| { + array + .iter() + .map(|x| x.as_str().ok_or(Report::msg("element is not a string"))) + .map(|x| x.map(|v| v.to_owned())) + .collect::>>() + }) }) - .collect(); + .collect::>>>()?; - let x = Fq2::new(fq_from_str(&els[0][0]), fq_from_str(&els[0][1])); - let y = Fq2::new(fq_from_str(&els[1][0]), fq_from_str(&els[1][1])); - let z = Fq2::new(fq_from_str(&els[2][0]), fq_from_str(&els[2][1])); - G2Affine::from(G2Projective::new(x, y, z)) + let x = Fq2::new(fq_from_str(&els[0][0])?, fq_from_str(&els[0][1])?); + let y = Fq2::new(fq_from_str(&els[1][0])?, fq_from_str(&els[1][1])?); + let z = Fq2::new(fq_from_str(&els[2][0])?, fq_from_str(&els[2][1])?); + Ok(G2Affine::from(G2Projective::new(x, y, z))) } // Converts JSON to a VerifyingKey -fn to_verifying_key(json: serde_json::Value) -> VerifyingKey { - VerifyingKey { - alpha_g1: json_to_g1(&json, "vk_alpha_1"), - beta_g2: json_to_g2(&json, "vk_beta_2"), - gamma_g2: json_to_g2(&json, "vk_gamma_2"), - delta_g2: json_to_g2(&json, "vk_delta_2"), - gamma_abc_g1: json_to_g1_vec(&json, "IC"), - } +fn to_verifying_key(json: serde_json::Value) -> Result> { + Ok(VerifyingKey { + alpha_g1: json_to_g1(&json, "vk_alpha_1")?, + beta_g2: json_to_g2(&json, "vk_beta_2")?, + gamma_g2: json_to_g2(&json, "vk_gamma_2")?, + delta_g2: json_to_g2(&json, "vk_delta_2")?, + gamma_abc_g1: json_to_g1_vec(&json, "IC")?, + }) } // Computes the verification key from its JSON serialization -fn vk_from_json(vk_path: &str) -> VerifyingKey { - let json = std::fs::read_to_string(vk_path).unwrap(); - let json: Value = serde_json::from_str(&json).unwrap(); +fn vk_from_json(vk_path: &str) -> Result> { + let json = std::fs::read_to_string(vk_path)?; + let json: Value = serde_json::from_str(&json)?; to_verifying_key(json) } // Computes the verification key from a bytes vector containing its JSON serialization -fn vk_from_vector(vk: &[u8]) -> VerifyingKey { - let json = String::from_utf8(vk.to_vec()).expect("Found invalid UTF-8"); - let json: Value = serde_json::from_str(&json).unwrap(); +fn vk_from_vector(vk: &[u8]) -> Result> { + let json = String::from_utf8(vk.to_vec())?; + let json: Value = serde_json::from_str(&json)?; to_verifying_key(json) } // Checks verification key to be correct with respect to proving key -pub fn check_vk_from_zkey(resources_folder: &str, verifying_key: VerifyingKey) { - let (proving_key, _matrices) = zkey_from_folder(resources_folder).unwrap(); - assert_eq!(proving_key.vk, verifying_key); +pub fn check_vk_from_zkey( + resources_folder: &str, + verifying_key: VerifyingKey, +) -> Result<()> { + let (proving_key, _matrices) = zkey_from_folder(resources_folder)?; + if proving_key.vk == verifying_key { + Ok(()) + } else { + Err(Report::msg("verifying_keys are not equal")) + } } diff --git a/rln/src/ffi.rs b/rln/src/ffi.rs index 5ac3768..58f2a91 100644 --- a/rln/src/ffi.rs +++ b/rln/src/ffi.rs @@ -171,9 +171,12 @@ impl<'a> From<&Buffer> for &'a [u8] { #[allow(clippy::not_unsafe_ptr_arg_deref)] #[no_mangle] pub extern "C" fn new(tree_height: usize, input_buffer: *const Buffer, ctx: *mut *mut RLN) -> bool { - let rln = RLN::new(tree_height, input_buffer.process()); - unsafe { *ctx = Box::into_raw(Box::new(rln)) }; - true + if let Ok(rln) = RLN::new(tree_height, input_buffer.process()) { + unsafe { *ctx = Box::into_raw(Box::new(rln)) }; + true + } else { + false + } } #[allow(clippy::not_unsafe_ptr_arg_deref)] @@ -185,14 +188,17 @@ pub extern "C" fn new_with_params( vk_buffer: *const Buffer, ctx: *mut *mut RLN, ) -> bool { - let rln = RLN::new_with_params( + if let Ok(rln) = RLN::new_with_params( tree_height, circom_buffer.process().to_vec(), zkey_buffer.process().to_vec(), vk_buffer.process().to_vec(), - ); - unsafe { *ctx = Box::into_raw(Box::new(rln)) }; - true + ) { + unsafe { *ctx = Box::into_raw(Box::new(rln)) }; + true + } else { + false + } } //////////////////////////////////////////////////////// diff --git a/rln/src/protocol.rs b/rln/src/protocol.rs index 4a56269..240dc9c 100644 --- a/rln/src/protocol.rs +++ b/rln/src/protocol.rs @@ -8,7 +8,7 @@ use ark_groth16::{ use ark_relations::r1cs::ConstraintMatrices; use ark_relations::r1cs::SynthesisError; use ark_std::{rand::thread_rng, UniformRand}; -use color_eyre::Result; +use color_eyre::{Report, Result}; use num_bigint::BigInt; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; @@ -91,29 +91,29 @@ pub fn deserialize_identity_tuple(serialized: Vec) -> (Fr, Fr, Fr, Fr) { ) } -pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Vec { +pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result> { let mut serialized: Vec = Vec::new(); serialized.append(&mut fr_to_bytes_le(&rln_witness.identity_secret)); - serialized.append(&mut vec_fr_to_bytes_le(&rln_witness.path_elements)); - serialized.append(&mut vec_u8_to_bytes_le(&rln_witness.identity_path_index)); + serialized.append(&mut vec_fr_to_bytes_le(&rln_witness.path_elements)?); + serialized.append(&mut vec_u8_to_bytes_le(&rln_witness.identity_path_index)?); serialized.append(&mut fr_to_bytes_le(&rln_witness.x)); serialized.append(&mut fr_to_bytes_le(&rln_witness.epoch)); serialized.append(&mut fr_to_bytes_le(&rln_witness.rln_identifier)); - serialized + Ok(serialized) } -pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) { +pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize)> { let mut all_read: usize = 0; let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]); all_read += read; - let (path_elements, read) = bytes_le_to_vec_fr(&serialized[all_read..]); + let (path_elements, read) = bytes_le_to_vec_fr(&serialized[all_read..])?; all_read += read; - let (identity_path_index, read) = bytes_le_to_vec_u8(&serialized[all_read..]); + let (identity_path_index, read) = bytes_le_to_vec_u8(&serialized[all_read..])?; all_read += read; let (x, read) = bytes_le_to_fr(&serialized[all_read..]); @@ -126,9 +126,11 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) { all_read += read; // TODO: check rln_identifier against public::RLN_IDENTIFIER - assert_eq!(serialized.len(), all_read); + if serialized.len() != all_read { + return Err(Report::msg("serialized length is not equal to all_read")); + } - ( + Ok(( RLNWitnessInput { identity_secret, path_elements, @@ -138,7 +140,7 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) { rln_identifier, }, all_read, - ) + )) } // This function deserializes input for kilic's rln generate_proof public API @@ -148,19 +150,19 @@ pub fn deserialize_witness(serialized: &[u8]) -> (RLNWitnessInput, usize) { pub fn proof_inputs_to_rln_witness( tree: &mut PoseidonTree, serialized: &[u8], -) -> (RLNWitnessInput, usize) { +) -> Result<(RLNWitnessInput, usize)> { let mut all_read: usize = 0; let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]); all_read += read; - let id_index = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()); + let id_index = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?); all_read += 8; let (epoch, read) = bytes_le_to_fr(&serialized[all_read..]); all_read += read; - let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()); + let signal_len = u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?); all_read += 8; let signal: Vec = serialized[all_read..all_read + (signal_len as usize)].to_vec(); @@ -173,7 +175,7 @@ pub fn proof_inputs_to_rln_witness( let rln_identifier = hash_to_field(RLN_IDENTIFIER); - ( + Ok(( RLNWitnessInput { identity_secret, path_elements, @@ -183,45 +185,48 @@ pub fn proof_inputs_to_rln_witness( rln_identifier, }, all_read, - ) + )) } -pub fn rln_witness_from_json(input_json_str: &str) -> RLNWitnessInput { +pub fn rln_witness_from_json(input_json_str: &str) -> Result { let input_json: serde_json::Value = serde_json::from_str(input_json_str).expect("JSON was not well-formatted"); - let identity_secret = str_to_fr(&input_json["identity_secret"].to_string(), 10); + let identity_secret = str_to_fr(&input_json["identity_secret"].to_string(), 10)?; let path_elements = input_json["path_elements"] .as_array() - .unwrap() + .ok_or(Report::msg("not an array"))? .iter() .map(|v| str_to_fr(&v.to_string(), 10)) - .collect(); + .collect::>()?; - let identity_path_index = input_json["identity_path_index"] + let identity_path_index_array = input_json["identity_path_index"] .as_array() - .unwrap() - .iter() - .map(|v| v.as_u64().unwrap() as u8) - .collect(); + .ok_or(Report::msg("not an arrray"))?; - let x = str_to_fr(&input_json["x"].to_string(), 10); + let mut identity_path_index: Vec = vec![]; - let epoch = str_to_fr(&input_json["epoch"].to_string(), 16); + for v in identity_path_index_array { + identity_path_index.push(v.as_u64().ok_or(Report::msg("not a u64 value"))? as u8); + } - let rln_identifier = str_to_fr(&input_json["rln_identifier"].to_string(), 10); + let x = str_to_fr(&input_json["x"].to_string(), 10)?; + + let epoch = str_to_fr(&input_json["epoch"].to_string(), 16)?; + + let rln_identifier = str_to_fr(&input_json["rln_identifier"].to_string(), 10)?; // TODO: check rln_identifier against public::RLN_IDENTIFIER - RLNWitnessInput { + Ok(RLNWitnessInput { identity_secret, path_elements, identity_path_index, x, epoch, rln_identifier, - } + }) } pub fn rln_witness_from_values( @@ -353,8 +358,8 @@ pub fn prepare_prove_input( id_index: usize, epoch: Fr, signal: &[u8], -) -> Vec { - let signal_len = u64::try_from(signal.len()).unwrap(); +) -> Result> { + let signal_len = u64::try_from(signal.len())?; let mut serialized: Vec = Vec::new(); @@ -364,12 +369,12 @@ pub fn prepare_prove_input( serialized.append(&mut signal_len.to_le_bytes().to_vec()); serialized.append(&mut signal.to_vec()); - serialized + Ok(serialized) } #[allow(clippy::redundant_clone)] -pub fn prepare_verify_input(proof_data: Vec, signal: &[u8]) -> Vec { - let signal_len = u64::try_from(signal.len()).unwrap(); +pub fn prepare_verify_input(proof_data: Vec, signal: &[u8]) -> Result> { + let signal_len = u64::try_from(signal.len())?; let mut serialized: Vec = Vec::new(); @@ -377,7 +382,7 @@ pub fn prepare_verify_input(proof_data: Vec, signal: &[u8]) -> Vec { serialized.append(&mut signal_len.to_le_bytes().to_vec()); serialized.append(&mut signal.to_vec()); - serialized + Ok(serialized) } /////////////////////////////////////////////////////// @@ -533,9 +538,9 @@ pub fn compute_id_secret( #[derive(Error, Debug)] pub enum ProofError { #[error("Error reading circuit key: {0}")] - CircuitKeyError(#[from] std::io::Error), + CircuitKeyError(#[from] Report), #[error("Error producing witness: {0}")] - WitnessError(color_eyre::Report), + WitnessError(Report), #[error("Error producing proof: {0}")] SynthesisError(#[from] SynthesisError), } @@ -546,20 +551,21 @@ fn calculate_witness_element(witness: Vec) -> // convert it to field elements use num_traits::Signed; - let witness = witness - .into_iter() - .map(|w| { - let w = if w.sign() == num_bigint::Sign::Minus { - // Need to negate the witness element if negative - modulus.into() - w.abs().to_biguint().unwrap() - } else { - w.to_biguint().unwrap() - }; - E::Fr::from(w) - }) - .collect::>(); + let mut witness_vec = vec![]; + for w in witness.into_iter() { + let w = if w.sign() == num_bigint::Sign::Minus { + // Need to negate the witness element if negative + modulus.into() + - w.abs() + .to_biguint() + .ok_or(Report::msg("not a biguint value"))? + } else { + w.to_biguint().ok_or(Report::msg("not a biguint value"))? + }; + witness_vec.push(E::Fr::from(w)) + } - Ok(witness) + Ok(witness_vec) } pub fn generate_proof_with_witness( @@ -570,9 +576,8 @@ pub fn generate_proof_with_witness( #[cfg(debug_assertions)] let now = Instant::now(); - let full_assignment = calculate_witness_element::(witness) - .map_err(ProofError::WitnessError) - .unwrap(); + let full_assignment = + calculate_witness_element::(witness).map_err(ProofError::WitnessError)?; #[cfg(debug_assertions)] println!("witness generation took: {:.2?}", now.elapsed()); @@ -594,8 +599,7 @@ pub fn generate_proof_with_witness( proving_key.1.num_instance_variables, proving_key.1.num_constraints, full_assignment.as_slice(), - ) - .unwrap(); + )?; #[cfg(debug_assertions)] println!("proof generation took: {:.2?}", now.elapsed()); @@ -603,14 +607,16 @@ pub fn generate_proof_with_witness( Ok(proof) } -pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str, Vec); 6] { +pub fn inputs_for_witness_calculation( + rln_witness: &RLNWitnessInput, +) -> Result<[(&str, Vec); 6]> { // We confert the path indexes to field elements // TODO: check if necessary let mut path_elements = Vec::new(); - rln_witness - .path_elements - .iter() - .for_each(|v| path_elements.push(to_bigint(v))); + + for v in rln_witness.path_elements.iter() { + path_elements.push(to_bigint(v)?); + } let mut identity_path_index = Vec::new(); rln_witness @@ -618,20 +624,20 @@ pub fn inputs_for_witness_calculation(rln_witness: &RLNWitnessInput) -> [(&str, .iter() .for_each(|v| identity_path_index.push(BigInt::from(*v))); - [ + Ok([ ( "identity_secret", - vec![to_bigint(&rln_witness.identity_secret)], + vec![to_bigint(&rln_witness.identity_secret)?], ), ("path_elements", path_elements), ("identity_path_index", identity_path_index), - ("x", vec![to_bigint(&rln_witness.x)]), - ("epoch", vec![to_bigint(&rln_witness.epoch)]), + ("x", vec![to_bigint(&rln_witness.x)?]), + ("epoch", vec![to_bigint(&rln_witness.epoch)?]), ( "rln_identifier", - vec![to_bigint(&rln_witness.rln_identifier)], + vec![to_bigint(&rln_witness.rln_identifier)?], ), - ] + ]) } /// Generates a RLN proof @@ -645,7 +651,7 @@ pub fn generate_proof( proving_key: &(ProvingKey, ConstraintMatrices), rln_witness: &RLNWitnessInput, ) -> Result, ProofError> { - let inputs = inputs_for_witness_calculation(rln_witness) + let inputs = inputs_for_witness_calculation(rln_witness)? .into_iter() .map(|(name, values)| (name.to_string(), values)); @@ -736,12 +742,12 @@ pub fn verify_proof( /// /// Returns a JSON object containing the inputs necessary to calculate /// the witness with CIRCOM on javascript -pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value { +pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> Result { let mut path_elements = Vec::new(); - rln_witness - .path_elements - .iter() - .for_each(|v| path_elements.push(to_bigint(v).to_str_radix(10))); + + for v in rln_witness.path_elements.iter() { + path_elements.push(to_bigint(v)?.to_str_radix(10)); + } let mut identity_path_index = Vec::new(); rln_witness @@ -750,13 +756,13 @@ pub fn get_json_inputs(rln_witness: &RLNWitnessInput) -> serde_json::Value { .for_each(|v| identity_path_index.push(BigInt::from(*v).to_str_radix(10))); let inputs = serde_json::json!({ - "identity_secret": to_bigint(&rln_witness.identity_secret).to_str_radix(10), + "identity_secret": to_bigint(&rln_witness.identity_secret)?.to_str_radix(10), "path_elements": path_elements, "identity_path_index": identity_path_index, - "x": to_bigint(&rln_witness.x).to_str_radix(10), - "epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)), - "rln_identifier": to_bigint(&rln_witness.rln_identifier).to_str_radix(10), + "x": to_bigint(&rln_witness.x)?.to_str_radix(10), + "epoch": format!("0x{:064x}", to_bigint(&rln_witness.epoch)?), + "rln_identifier": to_bigint(&rln_witness.rln_identifier)?.to_str_radix(10), }); - inputs + Ok(inputs) } diff --git a/rln/src/public.rs b/rln/src/public.rs index 92ba43e..8989433 100644 --- a/rln/src/public.rs +++ b/rln/src/public.rs @@ -10,9 +10,9 @@ use ark_groth16::{ProvingKey, VerifyingKey}; use ark_relations::r1cs::ConstraintMatrices; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write}; use cfg_if::cfg_if; +use color_eyre::Result; use num_bigint::BigInt; use std::io::Cursor; -use std::io::{self, Result}; cfg_if! { if #[cfg(not(target_arch = "wasm32"))] { @@ -36,8 +36,8 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809"; /// /// I/O is mostly done using writers and readers implementing `std::io::Write` and `std::io::Read`, respectively. pub struct RLN<'a> { - proving_key: Result<(ProvingKey, ConstraintMatrices)>, - verification_key: Result>, + proving_key: (ProvingKey, ConstraintMatrices), + verification_key: VerifyingKey, tree: PoseidonTree, // The witness calculator can't be loaded in zerokit. Since this struct @@ -67,29 +67,29 @@ impl RLN<'_> { /// let mut rln = RLN::new(tree_height, resources); /// ``` #[cfg(not(target_arch = "wasm32"))] - pub fn new(tree_height: usize, mut input_data: R) -> RLN<'static> { + pub fn new(tree_height: usize, mut input_data: R) -> Result> { // We read input let mut input: Vec = Vec::new(); - input_data.read_to_end(&mut input).unwrap(); + input_data.read_to_end(&mut input)?; - let resources_folder = String::from_utf8(input).expect("Found invalid UTF-8"); + let resources_folder = String::from_utf8(input)?; - let witness_calculator = circom_from_folder(&resources_folder); + let witness_calculator = circom_from_folder(&resources_folder)?; - let proving_key = zkey_from_folder(&resources_folder); - let verification_key = vk_from_folder(&resources_folder); + let proving_key = zkey_from_folder(&resources_folder)?; + let verification_key = vk_from_folder(&resources_folder)?; // We compute a default empty tree let tree = PoseidonTree::default(tree_height); - RLN { + Ok(RLN { witness_calculator, proving_key, verification_key, tree, #[cfg(target_arch = "wasm32")] _marker: PhantomData, - } + }) } /// Creates a new RLN object by passing circuit resources as byte vectors. @@ -130,17 +130,17 @@ impl RLN<'_> { #[cfg(not(target_arch = "wasm32"))] circom_vec: Vec, zkey_vec: Vec, vk_vec: Vec, - ) -> RLN<'static> { + ) -> Result> { #[cfg(not(target_arch = "wasm32"))] - let witness_calculator = circom_from_raw(circom_vec); + let witness_calculator = circom_from_raw(circom_vec)?; - let proving_key = zkey_from_raw(&zkey_vec); - let verification_key = vk_from_raw(&vk_vec, &zkey_vec); + let proving_key = zkey_from_raw(&zkey_vec)?; + let verification_key = vk_from_raw(&vk_vec, &zkey_vec)?; // We compute a default empty tree let tree = PoseidonTree::default(tree_height); - RLN { + Ok(RLN { #[cfg(not(target_arch = "wasm32"))] witness_calculator, proving_key, @@ -148,7 +148,7 @@ impl RLN<'_> { tree, #[cfg(target_arch = "wasm32")] _marker: PhantomData, - } + }) } //////////////////////////////////////////////////////// @@ -160,7 +160,7 @@ impl RLN<'_> { /// /// Input values are: /// - `tree_height`: the height of the Merkle tree. - pub fn set_tree(&mut self, tree_height: usize) -> io::Result<()> { + pub fn set_tree(&mut self, tree_height: usize) -> Result<()> { // We compute a default empty tree of desired height self.tree = PoseidonTree::default(tree_height); @@ -187,7 +187,7 @@ impl RLN<'_> { /// let mut buffer = Cursor::new(serialize_field_element(id_commitment)); /// rln.set_leaf(id_index, &mut buffer).unwrap(); /// ``` - pub fn set_leaf(&mut self, index: usize, mut input_data: R) -> io::Result<()> { + pub fn set_leaf(&mut self, index: usize, mut input_data: R) -> Result<()> { // We read input let mut leaf_byte: Vec = Vec::new(); input_data.read_to_end(&mut leaf_byte)?; @@ -229,12 +229,12 @@ impl RLN<'_> { /// let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); /// rln.set_leaves_from(index, &mut buffer).unwrap(); /// ``` - pub fn set_leaves_from(&mut self, index: usize, mut input_data: R) -> io::Result<()> { + pub fn set_leaves_from(&mut self, index: usize, mut input_data: R) -> Result<()> { // We read input let mut leaves_byte: Vec = Vec::new(); input_data.read_to_end(&mut leaves_byte)?; - let (leaves, _) = bytes_le_to_vec_fr(&leaves_byte); + let (leaves, _) = bytes_le_to_vec_fr(&leaves_byte)?; // We set the leaves self.tree.set_range(index, leaves) @@ -246,7 +246,7 @@ impl RLN<'_> { /// /// Input values are: /// - `input_data`: a reader for the serialization of multiple leaf values (serialization done with [`rln::utils::vec_fr_to_bytes_le`](crate::utils::vec_fr_to_bytes_le)) - pub fn init_tree_with_leaves(&mut self, input_data: R) -> io::Result<()> { + pub fn init_tree_with_leaves(&mut self, input_data: R) -> Result<()> { // reset the tree // NOTE: this requires the tree to be initialized with the correct height initially // TODO: accept tree_height as a parameter and initialize the tree with that height @@ -295,7 +295,7 @@ impl RLN<'_> { /// let mut buffer = Cursor::new(fr_to_bytes_le(&id_commitment)); /// rln.set_next_leaf(&mut buffer).unwrap(); /// ``` - pub fn set_next_leaf(&mut self, mut input_data: R) -> io::Result<()> { + pub fn set_next_leaf(&mut self, mut input_data: R) -> Result<()> { // We read input let mut leaf_byte: Vec = Vec::new(); input_data.read_to_end(&mut leaf_byte)?; @@ -320,7 +320,7 @@ impl RLN<'_> { /// let index = 10; /// rln.delete_leaf(index).unwrap(); /// ``` - pub fn delete_leaf(&mut self, index: usize) -> io::Result<()> { + pub fn delete_leaf(&mut self, index: usize) -> Result<()> { self.tree.delete(index)?; Ok(()) } @@ -338,7 +338,7 @@ impl RLN<'_> { /// rln.get_root(&mut buffer).unwrap(); /// let (root, _) = bytes_le_to_fr(&buffer.into_inner()); /// ``` - pub fn get_root(&self, mut output_data: W) -> io::Result<()> { + pub fn get_root(&self, mut output_data: W) -> Result<()> { let root = self.tree.root(); output_data.write_all(&fr_to_bytes_le(&root))?; @@ -366,13 +366,13 @@ impl RLN<'_> { /// let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner); /// let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec()); /// ``` - pub fn get_proof(&self, index: usize, mut output_data: W) -> io::Result<()> { + pub fn get_proof(&self, index: usize, mut output_data: W) -> Result<()> { let merkle_proof = self.tree.proof(index).expect("proof should exist"); let path_elements = merkle_proof.get_path_elements(); let identity_path_index = merkle_proof.get_path_index(); - output_data.write_all(&vec_fr_to_bytes_le(&path_elements))?; - output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index))?; + output_data.write_all(&vec_fr_to_bytes_le(&path_elements)?)?; + output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index)?)?; Ok(()) } @@ -406,11 +406,11 @@ impl RLN<'_> { &mut self, mut input_data: R, mut output_data: W, - ) -> io::Result<()> { + ) -> Result<()> { // We read input RLN witness and we deserialize it let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; - let (rln_witness, _) = deserialize_witness(&serialized); + let (rln_witness, _) = deserialize_witness(&serialized)?; /* if self.witness_calculator.is_none() { @@ -418,15 +418,10 @@ impl RLN<'_> { } */ - let proof = generate_proof( - self.witness_calculator, - self.proving_key.as_ref().unwrap(), - &rln_witness, - ) - .unwrap(); + let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof - proof.serialize(&mut output_data).unwrap(); + proof.serialize(&mut output_data)?; Ok(()) } @@ -466,22 +461,17 @@ impl RLN<'_> { /// /// assert!(verified); /// ``` - pub fn verify(&self, mut input_data: R) -> io::Result { + pub fn verify(&self, mut input_data: R) -> Result { // Input data is serialized for Curve as: // serialized_proof (compressed, 4*32 bytes) || serialized_proof_values (6*32 bytes), i.e. // [ proof<128> | root<32> | epoch<32> | share_x<32> | share_y<32> | nullifier<32> | rln_identifier<32> ] let mut input_byte: Vec = Vec::new(); input_data.read_to_end(&mut input_byte)?; - let proof = ArkProof::deserialize(&mut Cursor::new(&input_byte[..128])).unwrap(); + let proof = ArkProof::deserialize(&mut Cursor::new(&input_byte[..128]))?; let (proof_values, _) = deserialize_proof_values(&input_byte[128..]); - let verified = verify_proof( - self.verification_key.as_ref().unwrap(), - &proof, - &proof_values, - ) - .unwrap(); + let verified = verify_proof(&self.verification_key, &proof, &proof_values)?; Ok(verified) } @@ -537,23 +527,18 @@ impl RLN<'_> { &mut self, mut input_data: R, mut output_data: W, - ) -> io::Result<()> { + ) -> Result<()> { // We read input RLN witness and we deserialize it let mut witness_byte: Vec = Vec::new(); input_data.read_to_end(&mut witness_byte)?; - let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte); + let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?; let proof_values = proof_values_from_witness(&rln_witness); - let proof = generate_proof( - self.witness_calculator, - self.proving_key.as_ref().unwrap(), - &rln_witness, - ) - .unwrap(); + let proof = generate_proof(self.witness_calculator, &self.proving_key, &rln_witness)?; // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long - proof.serialize(&mut output_data).unwrap(); + proof.serialize(&mut output_data)?; output_data.write_all(&serialize_proof_values(&proof_values))?; Ok(()) @@ -570,17 +555,15 @@ impl RLN<'_> { calculated_witness: Vec, rln_witness_vec: Vec, mut output_data: W, - ) -> io::Result<()> { - let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..]); + ) -> Result<()> { + let (rln_witness, _) = deserialize_witness(&rln_witness_vec[..])?; let proof_values = proof_values_from_witness(&rln_witness); - let proof = - generate_proof_with_witness(calculated_witness, self.proving_key.as_ref().unwrap()) - .unwrap(); + let proof = generate_proof_with_witness(calculated_witness, &self.proving_key).unwrap(); // Note: we export a serialization of ark-groth16::Proof not semaphore::Proof // This proof is compressed, i.e. 128 bytes long - proof.serialize(&mut output_data).unwrap(); + proof.serialize(&mut output_data)?; output_data.write_all(&serialize_proof_values(&proof_values))?; Ok(()) } @@ -612,27 +595,22 @@ impl RLN<'_> { /// /// assert!(verified); /// ``` - pub fn verify_rln_proof(&self, mut input_data: R) -> io::Result { + pub fn verify_rln_proof(&self, mut input_data: R) -> Result { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; let mut all_read = 0; - let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec())).unwrap(); + let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec()))?; all_read += 128; let (proof_values, read) = deserialize_proof_values(&serialized[all_read..]); all_read += read; let signal_len = - u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize; + u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?) as usize; all_read += 8; let signal: Vec = serialized[all_read..all_read + signal_len].to_vec(); - let verified = verify_proof( - self.verification_key.as_ref().unwrap(), - &proof, - &proof_values, - ) - .unwrap(); + let verified = verify_proof(&self.verification_key, &proof, &proof_values)?; // Consistency checks to counter proof tampering let x = hash_to_field(&signal); @@ -693,31 +671,22 @@ impl RLN<'_> { /// /// assert!(verified); /// ``` - pub fn verify_with_roots( - &self, - mut input_data: R, - mut roots_data: R, - ) -> io::Result { + pub fn verify_with_roots(&self, mut input_data: R, mut roots_data: R) -> Result { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; let mut all_read = 0; - let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec())).unwrap(); + let proof = ArkProof::deserialize(&mut Cursor::new(&serialized[..128].to_vec()))?; all_read += 128; let (proof_values, read) = deserialize_proof_values(&serialized[all_read..]); all_read += read; let signal_len = - u64::from_le_bytes(serialized[all_read..all_read + 8].try_into().unwrap()) as usize; + u64::from_le_bytes(serialized[all_read..all_read + 8].try_into()?) as usize; all_read += 8; let signal: Vec = serialized[all_read..all_read + signal_len].to_vec(); - let verified = verify_proof( - self.verification_key.as_ref().unwrap(), - &proof, - &proof_values, - ) - .unwrap(); + let verified = verify_proof(&self.verification_key, &proof, &proof_values)?; // First consistency checks to counter proof tampering let x = hash_to_field(&signal); @@ -783,7 +752,7 @@ impl RLN<'_> { /// // We deserialize the keygen output /// let (identity_secret_hash, id_commitment) = deserialize_identity_pair(buffer.into_inner()); /// ``` - pub fn key_gen(&self, mut output_data: W) -> io::Result<()> { + pub fn key_gen(&self, mut output_data: W) -> Result<()> { let (identity_secret_hash, id_commitment) = keygen(); output_data.write_all(&fr_to_bytes_le(&identity_secret_hash))?; output_data.write_all(&fr_to_bytes_le(&id_commitment))?; @@ -813,7 +782,7 @@ impl RLN<'_> { /// // We deserialize the keygen output /// let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) = deserialize_identity_tuple(buffer.into_inner()); /// ``` - pub fn extended_key_gen(&self, mut output_data: W) -> io::Result<()> { + pub fn extended_key_gen(&self, mut output_data: W) -> Result<()> { let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) = extended_keygen(); output_data.write_all(&fr_to_bytes_le(&identity_trapdoor))?; @@ -852,7 +821,7 @@ impl RLN<'_> { &self, mut input_data: R, mut output_data: W, - ) -> io::Result<()> { + ) -> Result<()> { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; @@ -895,7 +864,7 @@ impl RLN<'_> { &self, mut input_data: R, mut output_data: W, - ) -> io::Result<()> { + ) -> Result<()> { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; @@ -946,7 +915,7 @@ impl RLN<'_> { mut input_proof_data_1: R, mut input_proof_data_2: R, mut output_data: W, - ) -> io::Result<()> { + ) -> Result<()> { // We deserialize the two proofs and we get the corresponding RLNProofValues objects let mut serialized: Vec = Vec::new(); input_proof_data_1.read_to_end(&mut serialized)?; @@ -990,11 +959,11 @@ impl RLN<'_> { /// - `input_data`: a reader for the serialization of `[ identity_secret<32> | id_index<8> | epoch<32> | signal_len<8> | signal ]` /// /// The function returns the corresponding [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object serialized using [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness)). - pub fn get_serialized_rln_witness(&mut self, mut input_data: R) -> Vec { + pub fn get_serialized_rln_witness(&mut self, mut input_data: R) -> Result> { // We read input RLN witness and we deserialize it let mut witness_byte: Vec = Vec::new(); - input_data.read_to_end(&mut witness_byte).unwrap(); - let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte); + input_data.read_to_end(&mut witness_byte)?; + let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?; serialize_witness(&rln_witness) } @@ -1005,12 +974,9 @@ impl RLN<'_> { /// - `serialized_witness`: the byte serialization of a [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object (serialization done with [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness)). /// /// The function returns the corresponding JSON encoding of the input [`RLNWitnessInput`](crate::protocol::RLNWitnessInput) object. - pub fn get_rln_witness_json( - &mut self, - serialized_witness: &[u8], - ) -> io::Result { - let (rln_witness, _) = deserialize_witness(serialized_witness); - Ok(get_json_inputs(&rln_witness)) + pub fn get_rln_witness_json(&mut self, serialized_witness: &[u8]) -> Result { + let (rln_witness, _) = deserialize_witness(serialized_witness)?; + get_json_inputs(&rln_witness) } } @@ -1019,7 +985,7 @@ impl Default for RLN<'_> { fn default() -> Self { let tree_height = TEST_TREE_HEIGHT; let buffer = Cursor::new(TEST_RESOURCES_FOLDER); - Self::new(tree_height, buffer) + Self::new(tree_height, buffer).unwrap() } } @@ -1045,7 +1011,7 @@ impl Default for RLN<'_> { /// // We deserialize the keygen output /// let field_element = deserialize_field_element(output_buffer.into_inner()); /// ``` -pub fn hash(mut input_data: R, mut output_data: W) -> io::Result<()> { +pub fn hash(mut input_data: R, mut output_data: W) -> Result<()> { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; @@ -1078,11 +1044,11 @@ pub fn hash(mut input_data: R, mut output_data: W) -> io::Res /// // We deserialize the hash output /// let hash_result = deserialize_field_element(output_buffer.into_inner()); /// ``` -pub fn poseidon_hash(mut input_data: R, mut output_data: W) -> io::Result<()> { +pub fn poseidon_hash(mut input_data: R, mut output_data: W) -> Result<()> { let mut serialized: Vec = Vec::new(); input_data.read_to_end(&mut serialized)?; - let (inputs, _) = bytes_le_to_vec_fr(&serialized); + let (inputs, _) = bytes_le_to_vec_fr(&serialized)?; let hash = utils_poseidon_hash(inputs.as_ref()); output_data.write_all(&fr_to_bytes_le(&hash))?; @@ -1110,7 +1076,7 @@ mod test { // We create a new tree let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // We first add leaves one by one specifying the index for (i, leaf) in leaves.iter().enumerate() { @@ -1149,7 +1115,7 @@ mod test { rln.set_tree(tree_height).unwrap(); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // We check if number of leaves set is consistent @@ -1205,10 +1171,10 @@ mod test { // We create a new tree let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // We check if number of leaves set is consistent @@ -1222,11 +1188,11 @@ mod test { // `init_tree_with_leaves` resets the tree to the height it was initialized with, using `set_tree` // We add leaves in a batch starting from index 0..set_index - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index])); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index]).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // We add the remaining n leaves in a batch starting from index m - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..])); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..]).unwrap()); rln.set_leaves_from(set_index, &mut buffer).unwrap(); // We check if number of leaves set is consistent @@ -1259,6 +1225,7 @@ mod test { assert_eq!(root_batch_with_init, root_single_additions); } + #[allow(unused_must_use)] #[test] // This test checks if `set_leaves_from` throws an error when the index is out of bounds fn test_set_leaves_bad_index() { @@ -1275,7 +1242,7 @@ mod test { // We create a new tree let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // Get root of empty tree let mut buffer = Cursor::new(Vec::::new()); @@ -1283,7 +1250,7 @@ mod test { let (root_empty, _) = bytes_le_to_fr(&buffer.into_inner()); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.set_leaves_from(bad_index, &mut buffer) .expect_err("Should throw an error"); @@ -1304,25 +1271,21 @@ mod test { let tree_height = TEST_TREE_HEIGHT; let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // Note: we only test Groth16 proof generation, so we ignore setting the tree in the RLN object let rln_witness = random_rln_witness(tree_height); let proof_values = proof_values_from_witness(&rln_witness); // We compute a Groth16 proof - let mut input_buffer = Cursor::new(serialize_witness(&rln_witness)); + let mut input_buffer = Cursor::new(serialize_witness(&rln_witness).unwrap()); let mut output_buffer = Cursor::new(Vec::::new()); rln.prove(&mut input_buffer, &mut output_buffer).unwrap(); let serialized_proof = output_buffer.into_inner(); // Before checking public verify API, we check that the (deserialized) proof generated by prove is actually valid let proof = ArkProof::deserialize(&mut Cursor::new(&serialized_proof)).unwrap(); - let verified = verify_proof( - &rln.verification_key.as_ref().unwrap(), - &proof, - &proof_values, - ); + let verified = verify_proof(&rln.verification_key, &proof, &proof_values); assert!(verified.unwrap()); // We prepare the input to prove API, consisting of serialized_proof (compressed, 4*32 bytes) || serialized_proof_values (6*32 bytes) @@ -1352,10 +1315,10 @@ mod test { // We create a new RLN instance let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // Generate identity pair @@ -1417,10 +1380,10 @@ mod test { // We create a new RLN instance let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // Generate identity pair @@ -1453,12 +1416,13 @@ mod test { // We read input RLN witness and we deserialize it let mut witness_byte: Vec = Vec::new(); input_buffer.read_to_end(&mut witness_byte).unwrap(); - let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte); + let (rln_witness, _) = proof_inputs_to_rln_witness(&mut rln.tree, &witness_byte).unwrap(); - let serialized_witness = serialize_witness(&rln_witness); + let serialized_witness = serialize_witness(&rln_witness).unwrap(); // Calculate witness outside zerokit (simulating what JS is doing) let inputs = inputs_for_witness_calculation(&rln_witness) + .unwrap() .into_iter() .map(|(name, values)| (name.to_string(), values)); let calculated_witness = rln @@ -1471,7 +1435,7 @@ mod test { let calculated_witness_vec: Vec = calculated_witness .into_iter() - .map(|v| to_bigint(&v)) + .map(|v| to_bigint(&v).unwrap()) .collect(); // Generating the proof @@ -1513,10 +1477,10 @@ mod test { // We create a new RLN instance let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // We add leaves in a batch into the tree - let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves)); + let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap()); rln.init_tree_with_leaves(&mut buffer).unwrap(); // Generate identity pair @@ -1600,7 +1564,7 @@ mod test { // We create a new RLN instance let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // Generate identity pair let (identity_secret_hash, id_commitment) = keygen(); diff --git a/rln/src/utils.rs b/rln/src/utils.rs index faae931..9eba031 100644 --- a/rln/src/utils.rs +++ b/rln/src/utils.rs @@ -2,13 +2,14 @@ use crate::circuit::Fr; use ark_ff::PrimeField; +use color_eyre::{Report, Result}; use num_bigint::{BigInt, BigUint}; use num_traits::Num; use std::iter::Extend; -pub fn to_bigint(el: &Fr) -> BigInt { - let res: BigUint = (*el).try_into().unwrap(); - res.try_into().unwrap() +pub fn to_bigint(el: &Fr) -> Result { + let res: BigUint = (*el).try_into()?; + Ok(res.into()) } pub fn fr_byte_size() -> usize { @@ -16,8 +17,10 @@ pub fn fr_byte_size() -> usize { (mbs + 64 - (mbs % 64)) / 8 } -pub fn str_to_fr(input: &str, radix: u32) -> Fr { - assert!((radix == 10) || (radix == 16)); +pub fn str_to_fr(input: &str, radix: u32) -> Result { + if !(radix == 10 || radix == 16) { + return Err(Report::msg("wrong radix")); + } // We remove any quote present and we trim let single_quote: char = '\"'; @@ -25,16 +28,10 @@ pub fn str_to_fr(input: &str, radix: u32) -> Fr { input_clean = input_clean.trim().to_string(); if radix == 10 { - BigUint::from_str_radix(&input_clean, radix) - .unwrap() - .try_into() - .unwrap() + Ok(BigUint::from_str_radix(&input_clean, radix)?.try_into()?) } else { input_clean = input_clean.replace("0x", ""); - BigUint::from_str_radix(&input_clean, radix) - .unwrap() - .try_into() - .unwrap() + Ok(BigUint::from_str_radix(&input_clean, radix)?.try_into()?) } } @@ -75,72 +72,73 @@ pub fn fr_to_bytes_be(input: &Fr) -> Vec { res } -pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec { +pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Result> { let mut bytes: Vec = Vec::new(); //We store the vector length - bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec()); + bytes.extend(u64::try_from(input.len())?.to_le_bytes().to_vec()); // We store each element input.iter().for_each(|el| bytes.extend(fr_to_bytes_le(el))); - bytes + Ok(bytes) } -pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Vec { +pub fn vec_fr_to_bytes_be(input: &[Fr]) -> Result> { let mut bytes: Vec = Vec::new(); //We store the vector length - bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec()); + bytes.extend(u64::try_from(input.len())?.to_be_bytes().to_vec()); // We store each element input.iter().for_each(|el| bytes.extend(fr_to_bytes_be(el))); - bytes + Ok(bytes) } -pub fn vec_u8_to_bytes_le(input: &[u8]) -> Vec { +pub fn vec_u8_to_bytes_le(input: &[u8]) -> Result> { let mut bytes: Vec = Vec::new(); //We store the vector length - bytes.extend(u64::try_from(input.len()).unwrap().to_le_bytes().to_vec()); + bytes.extend(u64::try_from(input.len())?.to_le_bytes().to_vec()); bytes.extend(input); - bytes + + Ok(bytes) } -pub fn vec_u8_to_bytes_be(input: Vec) -> Vec { - let mut bytes: Vec = Vec::new(); +pub fn vec_u8_to_bytes_be(input: Vec) -> Result> { //We store the vector length - bytes.extend(u64::try_from(input.len()).unwrap().to_be_bytes().to_vec()); + let mut bytes: Vec = u64::try_from(input.len())?.to_be_bytes().to_vec(); bytes.extend(input); - bytes + + Ok(bytes) } -pub fn bytes_le_to_vec_u8(input: &[u8]) -> (Vec, usize) { +pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec, usize)> { let mut read: usize = 0; - let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize; + let len = u64::from_le_bytes(input[0..8].try_into()?) as usize; read += 8; let res = input[8..8 + len].to_vec(); read += res.len(); - (res, read) + Ok((res, read)) } -pub fn bytes_be_to_vec_u8(input: &[u8]) -> (Vec, usize) { +pub fn bytes_be_to_vec_u8(input: &[u8]) -> Result<(Vec, usize)> { let mut read: usize = 0; - let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize; + let len = u64::from_be_bytes(input[0..8].try_into()?) as usize; read += 8; let res = input[8..8 + len].to_vec(); read += res.len(); - (res, read) + Ok((res, read)) } -pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec, usize) { +pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec, usize)> { let mut read: usize = 0; let mut res: Vec = Vec::new(); - let len = u64::from_le_bytes(input[0..8].try_into().unwrap()) as usize; + let len = u64::from_le_bytes(input[0..8].try_into()?) as usize; read += 8; let el_size = fr_byte_size(); @@ -150,14 +148,14 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> (Vec, usize) { read += el_size; } - (res, read) + Ok((res, read)) } -pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec, usize) { +pub fn bytes_be_to_vec_fr(input: &[u8]) -> Result<(Vec, usize)> { let mut read: usize = 0; let mut res: Vec = Vec::new(); - let len = u64::from_be_bytes(input[0..8].try_into().unwrap()) as usize; + let len = u64::from_be_bytes(input[0..8].try_into()?) as usize; read += 8; let el_size = fr_byte_size(); @@ -167,7 +165,7 @@ pub fn bytes_be_to_vec_fr(input: &[u8]) -> (Vec, usize) { read += el_size; } - (res, read) + Ok((res, read)) } /* Old conversion utilities between different libraries data types diff --git a/rln/tests/ffi.rs b/rln/tests/ffi.rs index cffe0f1..fc09edf 100644 --- a/rln/tests/ffi.rs +++ b/rln/tests/ffi.rs @@ -78,7 +78,7 @@ mod test { assert!(success, "set tree call failed"); // We add leaves in a batch into the tree - let leaves_ser = vec_fr_to_bytes_le(&leaves); + let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap(); let input_buffer = &Buffer::from(leaves_ser.as_ref()); let success = init_tree_with_leaves(rln_pointer, input_buffer); assert!(success, "init tree with leaves call failed"); @@ -153,7 +153,7 @@ mod test { let set_index = rng.gen_range(0..no_of_leaves) as usize; // We add leaves in a batch into the tree - let leaves_ser = vec_fr_to_bytes_le(&leaves); + let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap(); let input_buffer = &Buffer::from(leaves_ser.as_ref()); let success = init_tree_with_leaves(rln_pointer, input_buffer); assert!(success, "init tree with leaves call failed"); @@ -170,13 +170,13 @@ mod test { // `init_tree_with_leaves` resets the tree to the height it was initialized with, using `set_tree` // We add leaves in a batch starting from index 0..set_index - let leaves_m = vec_fr_to_bytes_le(&leaves[0..set_index]); + let leaves_m = vec_fr_to_bytes_le(&leaves[0..set_index]).unwrap(); let buffer = &Buffer::from(leaves_m.as_ref()); let success = init_tree_with_leaves(rln_pointer, buffer); assert!(success, "init tree with leaves call failed"); // We add the remaining n leaves in a batch starting from index set_index - let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]); + let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]).unwrap(); let buffer = &Buffer::from(leaves_n.as_ref()); let success = set_leaves_from(rln_pointer, set_index, buffer); assert!(success, "set leaves from call failed"); @@ -248,7 +248,7 @@ mod test { let (root_empty, _) = bytes_le_to_fr(&result_data); // We add leaves in a batch into the tree - let leaves = vec_fr_to_bytes_le(&leaves); + let leaves = vec_fr_to_bytes_le(&leaves).unwrap(); let buffer = &Buffer::from(leaves.as_ref()); let success = set_leaves_from(rln_pointer, bad_index, buffer); assert!(!success, "set leaves from call succeeded"); @@ -303,71 +303,86 @@ mod test { let output_buffer = unsafe { output_buffer.assume_init() }; let result_data = <&[u8]>::from(&output_buffer).to_vec(); - let (path_elements, read) = bytes_le_to_vec_fr(&result_data); - let (identity_path_index, _) = bytes_le_to_vec_u8(&result_data[read..].to_vec()); + let (path_elements, read) = bytes_le_to_vec_fr(&result_data).unwrap(); + let (identity_path_index, _) = bytes_le_to_vec_u8(&result_data[read..].to_vec()).unwrap(); // We check correct computation of the path and indexes let mut expected_path_elements = vec![ str_to_fr( "0x0000000000000000000000000000000000000000000000000000000000000000", 16, - ), + ) + .unwrap(), str_to_fr( "0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864", 16, - ), + ) + .unwrap(), str_to_fr( "0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1", 16, - ), + ) + .unwrap(), str_to_fr( "0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238", 16, - ), + ) + .unwrap(), str_to_fr( "0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a", 16, - ), + ) + .unwrap(), str_to_fr( "0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55", 16, - ), + ) + .unwrap(), str_to_fr( "0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78", 16, - ), + ) + .unwrap(), str_to_fr( "0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d", 16, - ), + ) + .unwrap(), str_to_fr( "0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61", 16, - ), + ) + .unwrap(), str_to_fr( "0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747", 16, - ), + ) + .unwrap(), str_to_fr( "0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2", 16, - ), + ) + .unwrap(), str_to_fr( "0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636", 16, - ), + ) + .unwrap(), str_to_fr( "0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a", 16, - ), + ) + .unwrap(), str_to_fr( "0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0", 16, - ), + ) + .unwrap(), str_to_fr( "0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c", 16, - ), + ) + .unwrap(), ]; let mut expected_identity_path_index: Vec = @@ -379,19 +394,23 @@ mod test { str_to_fr( "0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92", 16, - ), + ) + .unwrap(), str_to_fr( "0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323", 16, - ), + ) + .unwrap(), str_to_fr( "0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992", 16, - ), + ) + .unwrap(), str_to_fr( "0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f", 16, - ), + ) + .unwrap(), ]); expected_identity_path_index.append(&mut vec![0, 0, 0, 0]); } @@ -400,7 +419,8 @@ mod test { expected_path_elements.append(&mut vec![str_to_fr( "0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca", 16, - )]); + ) + .unwrap()]); expected_identity_path_index.append(&mut vec![0]); } @@ -439,7 +459,7 @@ mod test { let proof_values = proof_values_from_witness(&rln_witness); // We prepare id_commitment and we set the leaf at provided index - let rln_witness_ser = serialize_witness(&rln_witness); + let rln_witness_ser = serialize_witness(&rln_witness).unwrap(); let input_buffer = &Buffer::from(rln_witness_ser.as_ref()); let mut output_buffer = MaybeUninit::::uninit(); let now = Instant::now(); @@ -569,7 +589,7 @@ mod test { let rln_pointer = unsafe { &mut *rln_pointer.assume_init() }; // We add leaves in a batch into the tree - let leaves_ser = vec_fr_to_bytes_le(&leaves); + let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap(); let input_buffer = &Buffer::from(leaves_ser.as_ref()); let success = init_tree_with_leaves(rln_pointer, input_buffer); assert!(success, "init tree with leaves call failed"); @@ -654,7 +674,7 @@ mod test { let rln_pointer = unsafe { &mut *rln_pointer.assume_init() }; // We add leaves in a batch into the tree - let leaves_ser = vec_fr_to_bytes_le(&leaves); + let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap(); let input_buffer = &Buffer::from(leaves_ser.as_ref()); let success = init_tree_with_leaves(rln_pointer, input_buffer); assert!(success, "set leaves call failed"); @@ -957,16 +977,15 @@ mod test { assert_eq!( identity_secret_hash, - expected_identity_secret_hash_seed_bytes + expected_identity_secret_hash_seed_bytes.unwrap() ); - assert_eq!(id_commitment, expected_id_commitment_seed_bytes); + assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap()); } #[test] // Tests hash to field using FFI APIs fn test_seeded_extended_keygen_ffi() { let tree_height = TEST_TREE_HEIGHT; - // We create a RLN instance let mut rln_pointer = MaybeUninit::<*mut RLN>::uninit(); let input_buffer = &Buffer::from(TEST_RESOURCES_FOLDER.as_bytes()); @@ -1004,13 +1023,19 @@ mod test { 16, ); - assert_eq!(identity_trapdoor, expected_identity_trapdoor_seed_bytes); - assert_eq!(identity_nullifier, expected_identity_nullifier_seed_bytes); + assert_eq!( + identity_trapdoor, + expected_identity_trapdoor_seed_bytes.unwrap() + ); + assert_eq!( + identity_nullifier, + expected_identity_nullifier_seed_bytes.unwrap() + ); assert_eq!( identity_secret_hash, - expected_identity_secret_hash_seed_bytes + expected_identity_secret_hash_seed_bytes.unwrap() ); - assert_eq!(id_commitment, expected_id_commitment_seed_bytes); + assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap()); } #[test] @@ -1045,7 +1070,7 @@ mod test { for _ in 0..number_of_inputs { inputs.push(Fr::rand(&mut rng)); } - let inputs_ser = vec_fr_to_bytes_le(&inputs); + let inputs_ser = vec_fr_to_bytes_le(&inputs).unwrap(); let input_buffer = &Buffer::from(inputs_ser.as_ref()); let expected_hash = utils_poseidon_hash(inputs.as_ref()); diff --git a/rln/tests/protocol.rs b/rln/tests/protocol.rs index cef5118..7e69689 100644 --- a/rln/tests/protocol.rs +++ b/rln/tests/protocol.rs @@ -184,6 +184,7 @@ mod test { "0x1984f2e01184aef5cb974640898a5f5c25556554e2b06d99d4841badb8b198cd", 16 ) + .unwrap() ); } else if TEST_TREE_HEIGHT == 19 { assert_eq!( @@ -192,6 +193,7 @@ mod test { "0x219ceb53f2b1b7a6cf74e80d50d44d68ecb4a53c6cc65b25593c8d56343fb1fe", 16 ) + .unwrap() ); } else if TEST_TREE_HEIGHT == 20 { assert_eq!( @@ -200,6 +202,7 @@ mod test { "0x21947ffd0bce0c385f876e7c97d6a42eec5b1fe935aab2f01c1f8a8cbcc356d2", 16 ) + .unwrap() ); } @@ -213,63 +216,78 @@ mod test { str_to_fr( "0x0000000000000000000000000000000000000000000000000000000000000000", 16, - ), + ) + .unwrap(), str_to_fr( "0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864", 16, - ), + ) + .unwrap(), str_to_fr( "0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1", 16, - ), + ) + .unwrap(), str_to_fr( "0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238", 16, - ), + ) + .unwrap(), str_to_fr( "0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a", 16, - ), + ) + .unwrap(), str_to_fr( "0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55", 16, - ), + ) + .unwrap(), str_to_fr( "0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78", 16, - ), + ) + .unwrap(), str_to_fr( "0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d", 16, - ), + ) + .unwrap(), str_to_fr( "0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61", 16, - ), + ) + .unwrap(), str_to_fr( "0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747", 16, - ), + ) + .unwrap(), str_to_fr( "0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2", 16, - ), + ) + .unwrap(), str_to_fr( "0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636", 16, - ), + ) + .unwrap(), str_to_fr( "0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a", 16, - ), + ) + .unwrap(), str_to_fr( "0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0", 16, - ), + ) + .unwrap(), str_to_fr( "0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c", 16, - ), + ) + .unwrap(), ]; let mut expected_identity_path_index: Vec = @@ -281,19 +299,23 @@ mod test { str_to_fr( "0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92", 16, - ), + ) + .unwrap(), str_to_fr( "0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323", 16, - ), + ) + .unwrap(), str_to_fr( "0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992", 16, - ), + ) + .unwrap(), str_to_fr( "0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f", 16, - ), + ) + .unwrap(), ]); expected_identity_path_index.append(&mut vec![0, 0, 0, 0]); } @@ -302,7 +324,8 @@ mod test { expected_path_elements.append(&mut vec![str_to_fr( "0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca", 16, - )]); + ) + .unwrap()]); expected_identity_path_index.append(&mut vec![0]); } @@ -319,7 +342,7 @@ mod test { // We generate all relevant keys let proving_key = zkey_from_folder(TEST_RESOURCES_FOLDER).unwrap(); let verification_key = vk_from_folder(TEST_RESOURCES_FOLDER).unwrap(); - let builder = circom_from_folder(TEST_RESOURCES_FOLDER); + let builder = circom_from_folder(TEST_RESOURCES_FOLDER).unwrap(); // We compute witness from the json input example let mut witness_json: &str = ""; @@ -334,10 +357,12 @@ mod test { let rln_witness = rln_witness_from_json(witness_json); - // Let's generate a zkSNARK proof - let proof = generate_proof(builder, &proving_key, &rln_witness).unwrap(); + let rln_witness_unwrapped = rln_witness.unwrap(); - let proof_values = proof_values_from_witness(&rln_witness); + // Let's generate a zkSNARK proof + let proof = generate_proof(builder, &proving_key, &rln_witness_unwrapped).unwrap(); + + let proof_values = proof_values_from_witness(&rln_witness_unwrapped); // Let's verify the proof let verified = verify_proof(&verification_key, &proof, &proof_values); @@ -378,7 +403,7 @@ mod test { // We generate all relevant keys let proving_key = zkey_from_folder(TEST_RESOURCES_FOLDER).unwrap(); let verification_key = vk_from_folder(TEST_RESOURCES_FOLDER).unwrap(); - let builder = circom_from_folder(TEST_RESOURCES_FOLDER); + let builder = circom_from_folder(TEST_RESOURCES_FOLDER).unwrap(); // Let's generate a zkSNARK proof let proof = generate_proof(builder, &proving_key, &rln_witness).unwrap(); @@ -404,10 +429,10 @@ mod test { witness_json = WITNESS_JSON_20; } - let rln_witness = rln_witness_from_json(witness_json); + let rln_witness = rln_witness_from_json(witness_json).unwrap(); - let ser = serialize_witness(&rln_witness); - let (deser, _) = deserialize_witness(&ser); + let ser = serialize_witness(&rln_witness).unwrap(); + let (deser, _) = deserialize_witness(&ser).unwrap(); assert_eq!(rln_witness, deser); // We test Proof values serialization @@ -429,11 +454,13 @@ mod test { let expected_identity_secret_hash_seed_phrase = str_to_fr( "0x20df38f3f00496f19fe7c6535492543b21798ed7cb91aebe4af8012db884eda3", 16, - ); + ) + .unwrap(); let expected_id_commitment_seed_phrase = str_to_fr( "0x1223a78a5d66043a7f9863e14507dc80720a5602b2a894923e5b5147d5a9c325", 16, - ); + ) + .unwrap(); assert_eq!( identity_secret_hash, @@ -449,11 +476,13 @@ mod test { let expected_identity_secret_hash_seed_bytes = str_to_fr( "0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716", 16, - ); + ) + .unwrap(); let expected_id_commitment_seed_bytes = str_to_fr( "0xbf16d2b5c0d6f9d9d561e05bfca16a81b4b873bb063508fae360d8c74cef51f", 16, - ); + ) + .unwrap(); assert_eq!( identity_secret_hash, diff --git a/rln/tests/public.rs b/rln/tests/public.rs index 4826bf5..e2c8be8 100644 --- a/rln/tests/public.rs +++ b/rln/tests/public.rs @@ -16,7 +16,7 @@ mod test { let leaf_index = 3; let input_buffer = Cursor::new(TEST_RESOURCES_FOLDER); - let mut rln = RLN::new(tree_height, input_buffer); + let mut rln = RLN::new(tree_height, input_buffer).unwrap(); // generate identity let identity_secret_hash = hash_to_field(b"test-merkle-proof"); @@ -38,6 +38,7 @@ mod test { "0x1984f2e01184aef5cb974640898a5f5c25556554e2b06d99d4841badb8b198cd", 16 ) + .unwrap() ); } else if TEST_TREE_HEIGHT == 19 { assert_eq!( @@ -46,6 +47,7 @@ mod test { "0x219ceb53f2b1b7a6cf74e80d50d44d68ecb4a53c6cc65b25593c8d56343fb1fe", 16 ) + .unwrap() ); } else if TEST_TREE_HEIGHT == 20 { assert_eq!( @@ -54,6 +56,7 @@ mod test { "0x21947ffd0bce0c385f876e7c97d6a42eec5b1fe935aab2f01c1f8a8cbcc356d2", 16 ) + .unwrap() ); } @@ -62,71 +65,86 @@ mod test { rln.get_proof(leaf_index, &mut buffer).unwrap(); let buffer_inner = buffer.into_inner(); - let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner); - let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec()); + let (path_elements, read) = bytes_le_to_vec_fr(&buffer_inner).unwrap(); + let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec()).unwrap(); // We check correct computation of the path and indexes let mut expected_path_elements = vec![ str_to_fr( "0x0000000000000000000000000000000000000000000000000000000000000000", 16, - ), + ) + .unwrap(), str_to_fr( "0x2098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864", 16, - ), + ) + .unwrap(), str_to_fr( "0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1", 16, - ), + ) + .unwrap(), str_to_fr( "0x18f43331537ee2af2e3d758d50f72106467c6eea50371dd528d57eb2b856d238", 16, - ), + ) + .unwrap(), str_to_fr( "0x07f9d837cb17b0d36320ffe93ba52345f1b728571a568265caac97559dbc952a", 16, - ), + ) + .unwrap(), str_to_fr( "0x2b94cf5e8746b3f5c9631f4c5df32907a699c58c94b2ad4d7b5cec1639183f55", 16, - ), + ) + .unwrap(), str_to_fr( "0x2dee93c5a666459646ea7d22cca9e1bcfed71e6951b953611d11dda32ea09d78", 16, - ), + ) + .unwrap(), str_to_fr( "0x078295e5a22b84e982cf601eb639597b8b0515a88cb5ac7fa8a4aabe3c87349d", 16, - ), + ) + .unwrap(), str_to_fr( "0x2fa5e5f18f6027a6501bec864564472a616b2e274a41211a444cbe3a99f3cc61", 16, - ), + ) + .unwrap(), str_to_fr( "0x0e884376d0d8fd21ecb780389e941f66e45e7acce3e228ab3e2156a614fcd747", 16, - ), + ) + .unwrap(), str_to_fr( "0x1b7201da72494f1e28717ad1a52eb469f95892f957713533de6175e5da190af2", 16, - ), + ) + .unwrap(), str_to_fr( "0x1f8d8822725e36385200c0b201249819a6e6e1e4650808b5bebc6bface7d7636", 16, - ), + ) + .unwrap(), str_to_fr( "0x2c5d82f66c914bafb9701589ba8cfcfb6162b0a12acf88a8d0879a0471b5f85a", 16, - ), + ) + .unwrap(), str_to_fr( "0x14c54148a0940bb820957f5adf3fa1134ef5c4aaa113f4646458f270e0bfbfd0", 16, - ), + ) + .unwrap(), str_to_fr( "0x190d33b12f986f961e10c0ee44d8b9af11be25588cad89d416118e4bf4ebe80c", 16, - ), + ) + .unwrap(), ]; let mut expected_identity_path_index: Vec = @@ -138,19 +156,23 @@ mod test { str_to_fr( "0x22f98aa9ce704152ac17354914ad73ed1167ae6596af510aa5b3649325e06c92", 16, - ), + ) + .unwrap(), str_to_fr( "0x2a7c7c9b6ce5880b9f6f228d72bf6a575a526f29c66ecceef8b753d38bba7323", 16, - ), + ) + .unwrap(), str_to_fr( "0x2e8186e558698ec1c67af9c14d463ffc470043c9c2988b954d75dd643f36b992", 16, - ), + ) + .unwrap(), str_to_fr( "0x0f57c5571e9a4eab49e2c8cf050dae948aef6ead647392273546249d1c1ff10f", 16, - ), + ) + .unwrap(), ]); expected_identity_path_index.append(&mut vec![0, 0, 0, 0]); } @@ -159,7 +181,8 @@ mod test { expected_path_elements.append(&mut vec![str_to_fr( "0x1830ee67b5fb554ad5f63d4388800e1cfe78e310697d46e43c9ce36134f72cca", 16, - )]); + ) + .unwrap()]); expected_identity_path_index.append(&mut vec![0]); } @@ -193,11 +216,13 @@ mod test { let expected_identity_secret_hash_seed_bytes = str_to_fr( "0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716", 16, - ); + ) + .unwrap(); let expected_id_commitment_seed_bytes = str_to_fr( "0xbf16d2b5c0d6f9d9d561e05bfca16a81b4b873bb063508fae360d8c74cef51f", 16, - ); + ) + .unwrap(); assert_eq!( identity_secret_hash, @@ -226,19 +251,23 @@ mod test { let expected_identity_trapdoor_seed_bytes = str_to_fr( "0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716", 16, - ); + ) + .unwrap(); let expected_identity_nullifier_seed_bytes = str_to_fr( "0x1f18714c7bc83b5bca9e89d404cf6f2f585bc4c0f7ed8b53742b7e2b298f50b4", 16, - ); + ) + .unwrap(); let expected_identity_secret_hash_seed_bytes = str_to_fr( "0x2aca62aaa7abaf3686fff2caf00f55ab9462dc12db5b5d4bcf3994e671f8e521", 16, - ); + ) + .unwrap(); let expected_id_commitment_seed_bytes = str_to_fr( "0x68b66aa0a8320d2e56842581553285393188714c48f9b17acd198b4f1734c5c", 16, - ); + ) + .unwrap(); assert_eq!(identity_trapdoor, expected_identity_trapdoor_seed_bytes); assert_eq!(identity_nullifier, expected_identity_nullifier_seed_bytes); @@ -276,7 +305,7 @@ mod test { } let expected_hash = utils_poseidon_hash(&inputs); - let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs)); + let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs).unwrap()); let mut output_buffer = Cursor::new(Vec::::new()); public_poseidon_hash(&mut input_buffer, &mut output_buffer).unwrap(); diff --git a/semaphore/src/protocol.rs b/semaphore/src/protocol.rs index 724e96d..783e28b 100644 --- a/semaphore/src/protocol.rs +++ b/semaphore/src/protocol.rs @@ -12,7 +12,7 @@ use ark_groth16::{ }; use ark_relations::r1cs::SynthesisError; use ark_std::UniformRand; -use color_eyre::Result; +use color_eyre::{Report, Result}; use ethers_core::types::U256; use rand::{thread_rng, Rng}; use semaphore::{ @@ -89,7 +89,7 @@ pub enum ProofError { #[error("Error reading circuit key: {0}")] CircuitKeyError(#[from] std::io::Error), #[error("Error producing witness: {0}")] - WitnessError(color_eyre::Report), + WitnessError(Report), #[error("Error producing proof: {0}")] SynthesisError(#[from] SynthesisError), #[error("Error converting public input: {0}")] diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 9d64bfe..209b963 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] ark-ff = { version = "0.3.0", default-features = false, features = ["asm"] } num-bigint = { version = "0.4.3", default-features = false, features = ["rand"] } +color-eyre = "0.6.1" [dev-dependencies] ark-bn254 = { version = "0.3.0" } diff --git a/utils/src/merkle_tree/merkle_tree.rs b/utils/src/merkle_tree/merkle_tree.rs index 3e1c89f..daa0abb 100644 --- a/utils/src/merkle_tree/merkle_tree.rs +++ b/utils/src/merkle_tree/merkle_tree.rs @@ -16,13 +16,14 @@ #![allow(dead_code)] use std::collections::HashMap; -use std::io; use std::{ cmp::max, fmt::Debug, iter::{once, repeat, successors}, }; +use color_eyre::{Report, Result}; + /// In the Hasher trait we define the node type, the default leaf /// and the hash function used to initialize a Merkle Tree implementation pub trait Hasher { @@ -114,15 +115,12 @@ impl OptimalMerkleTree { } // Sets a leaf at the specified tree index - pub fn set(&mut self, index: usize, leaf: H::Fr) -> io::Result<()> { + pub fn set(&mut self, index: usize, leaf: H::Fr) -> Result<()> { if index >= self.capacity() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "index exceeds set size", - )); + return Err(Report::msg("index exceeds set size")); } self.nodes.insert((self.depth, index), leaf); - self.recalculate_from(index); + self.recalculate_from(index)?; self.next_index = max(self.next_index, index + 1); Ok(()) } @@ -132,31 +130,28 @@ impl OptimalMerkleTree { &mut self, start: usize, leaves: I, - ) -> io::Result<()> { + ) -> Result<()> { let leaves = leaves.into_iter().collect::>(); // check if the range is valid if start + leaves.len() > self.capacity() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "provided range exceeds set size", - )); + return Err(Report::msg("provided range exceeds set size")); } for (i, leaf) in leaves.iter().enumerate() { self.nodes.insert((self.depth, start + i), *leaf); - self.recalculate_from(start + i); + self.recalculate_from(start + i)?; } self.next_index = max(self.next_index, start + leaves.len()); Ok(()) } // Sets a leaf at the next available index - pub fn update_next(&mut self, leaf: H::Fr) -> io::Result<()> { + pub fn update_next(&mut self, leaf: H::Fr) -> Result<()> { self.set(self.next_index, leaf)?; Ok(()) } // Deletes a leaf at a certain index by setting it to its default value (next_index is not updated) - pub fn delete(&mut self, index: usize) -> io::Result<()> { + pub fn delete(&mut self, index: usize) -> Result<()> { // We reset the leaf only if we previously set a leaf at that index if index < self.next_index { self.set(index, H::default_leaf())?; @@ -165,12 +160,9 @@ impl OptimalMerkleTree { } // Computes a merkle proof the the leaf at the specified index - pub fn proof(&self, index: usize) -> io::Result> { + pub fn proof(&self, index: usize) -> Result> { if index >= self.capacity() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "index exceeds set size", - )); + return Err(Report::msg("index exceeds set size")); } let mut witness = Vec::<(H::Fr, u8)>::with_capacity(self.depth); let mut i = index; @@ -184,17 +176,17 @@ impl OptimalMerkleTree { break; } } - assert_eq!(i, 0); - Ok(OptimalMerkleProof(witness)) + if i != 0 { + Err(Report::msg("i != 0")) + } else { + Ok(OptimalMerkleProof(witness)) + } } // Verifies a Merkle proof with respect to the input leaf and the tree root - pub fn verify(&self, leaf: &H::Fr, witness: &OptimalMerkleProof) -> io::Result { + pub fn verify(&self, leaf: &H::Fr, witness: &OptimalMerkleProof) -> Result { if witness.length() != self.depth { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "witness length doesn't match tree depth", - )); + return Err(Report::msg("witness length doesn't match tree depth")); } let expected_root = witness.compute_root_from(leaf); Ok(expected_root.eq(&self.root())) @@ -219,7 +211,7 @@ impl OptimalMerkleTree { H::hash(&[self.get_node(depth, b), self.get_node(depth, b + 1)]) } - fn recalculate_from(&mut self, index: usize) { + fn recalculate_from(&mut self, index: usize) -> Result<()> { let mut i = index; let mut depth = self.depth; loop { @@ -231,8 +223,13 @@ impl OptimalMerkleTree { break; } } - assert_eq!(depth, 0); - assert_eq!(i, 0); + if depth != 0 { + return Err(Report::msg("did not reach the depth")); + } + if i != 0 { + return Err(Report::msg("did not go through all indexes")); + } + Ok(()) } } @@ -387,7 +384,7 @@ impl FullMerkleTree { } // Sets a leaf at the specified tree index - pub fn set(&mut self, leaf: usize, hash: H::Fr) -> io::Result<()> { + pub fn set(&mut self, leaf: usize, hash: H::Fr) -> Result<()> { self.set_range(leaf, once(hash))?; self.next_index = max(self.next_index, leaf + 1); Ok(()) @@ -395,41 +392,34 @@ impl FullMerkleTree { // Sets tree nodes, starting from start index // Function proper of FullMerkleTree implementation - fn set_range>( - &mut self, - start: usize, - hashes: I, - ) -> io::Result<()> { + fn set_range>(&mut self, start: usize, hashes: I) -> Result<()> { let index = self.capacity() + start - 1; let mut count = 0; // first count number of hashes, and check that they fit in the tree // then insert into the tree let hashes = hashes.into_iter().collect::>(); if hashes.len() + start > self.capacity() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "provided hashes do not fit in the tree", - )); + return Err(Report::msg("provided hashes do not fit in the tree")); } hashes.into_iter().for_each(|hash| { self.nodes[index + count] = hash; count += 1; }); if count != 0 { - self.update_nodes(index, index + (count - 1)); + self.update_nodes(index, index + (count - 1))?; self.next_index = max(self.next_index, start + count); } Ok(()) } // Sets a leaf at the next available index - pub fn update_next(&mut self, leaf: H::Fr) -> io::Result<()> { + pub fn update_next(&mut self, leaf: H::Fr) -> Result<()> { self.set(self.next_index, leaf)?; Ok(()) } // Deletes a leaf at a certain index by setting it to its default value (next_index is not updated) - pub fn delete(&mut self, index: usize) -> io::Result<()> { + pub fn delete(&mut self, index: usize) -> Result<()> { // We reset the leaf only if we previously set a leaf at that index if index < self.next_index { self.set(index, H::default_leaf())?; @@ -438,12 +428,9 @@ impl FullMerkleTree { } // Computes a merkle proof the the leaf at the specified index - pub fn proof(&self, leaf: usize) -> io::Result> { + pub fn proof(&self, leaf: usize) -> Result> { if leaf >= self.capacity() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "index exceeds set size", - )); + return Err(Report::msg("index exceeds set size")); } let mut index = self.capacity() + leaf - 1; let mut path = Vec::with_capacity(self.depth + 1); @@ -460,7 +447,7 @@ impl FullMerkleTree { } // Verifies a Merkle proof with respect to the input leaf and the tree root - pub fn verify(&self, hash: &H::Fr, proof: &FullMerkleProof) -> io::Result { + pub fn verify(&self, hash: &H::Fr, proof: &FullMerkleProof) -> Result { Ok(proof.compute_root_from(hash) == self.root()) } @@ -487,15 +474,18 @@ impl FullMerkleTree { (index + 2).next_power_of_two().trailing_zeros() as usize - 1 } - fn update_nodes(&mut self, start: usize, end: usize) { - debug_assert_eq!(self.levels(start), self.levels(end)); + fn update_nodes(&mut self, start: usize, end: usize) -> Result<()> { + if self.levels(start) != self.levels(end) { + return Err(Report::msg("self.levels(start) != self.levels(end)")); + } if let (Some(start), Some(end)) = (self.parent(start), self.parent(end)) { for parent in start..=end { let child = self.first_child(parent); self.nodes[parent] = H::hash(&[self.nodes[child], self.nodes[child + 1]]); } - self.update_nodes(start, end); + self.update_nodes(start, end)?; } + Ok(()) } }