diff --git a/rln/src/circuit/iden3calc.rs b/rln/src/circuit/iden3calc.rs index 0389e29..f309aed 100644 --- a/rln/src/circuit/iden3calc.rs +++ b/rln/src/circuit/iden3calc.rs @@ -107,3 +107,34 @@ fn get_inputs_buffer(size: usize) -> Vec { inputs[0] = U256::from(1); inputs } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use super::*; + + #[test] + fn test_populate_inputs_missing() { + let mut input_list: HashMap> = HashMap::new(); + input_list.insert("missing".to_string(), vec![U256::from(1u64)]); + + let input_info: InputSignalsInfo = HashMap::new(); + let mut buffer = vec![U256::ZERO; 2]; + let err = populate_inputs(&input_list, &input_info, &mut buffer).unwrap_err(); + assert!(matches!(err, WitnessCalcError::MissingInput(_))); + } + + #[test] + fn test_populate_inputs_length_mismatch() { + let mut input_list: HashMap> = HashMap::new(); + input_list.insert("sig".to_string(), vec![U256::from(1u64)]); + + let mut input_info: InputSignalsInfo = HashMap::new(); + input_info.insert("sig".to_string(), (0, 2)); + + let mut buffer = vec![U256::ZERO; 2]; + let err = populate_inputs(&input_list, &input_info, &mut buffer).unwrap_err(); + assert!(matches!(err, WitnessCalcError::InvalidInputLength { .. })); + } +} diff --git a/rln/src/circuit/iden3calc/graph.rs b/rln/src/circuit/iden3calc/graph.rs index a1e3b10..2987442 100644 --- a/rln/src/circuit/iden3calc/graph.rs +++ b/rln/src/circuit/iden3calc/graph.rs @@ -430,6 +430,7 @@ mod test { use std::{ops::Div, str::FromStr}; use ruint::uint; + use serde_json; use super::*; @@ -566,6 +567,105 @@ mod test { assert_eq!(result, uint!(1_U256)); } + #[test] + fn test_serde_mont_constant_roundtrip() { + let node = Node::MontConstant(Fr::from(42u64)); + let encoded = serde_json::to_vec(&node).unwrap(); + let decoded: Node = serde_json::from_slice(&encoded).unwrap(); + assert_eq!(node, decoded); + } + + #[test] + fn test_eval_zero_divisors() { + let zero = Fr::zero(); + let a = Fr::from(7u64); + assert_eq!(Operation::Div.eval_fr(a, zero).unwrap(), Fr::zero()); + assert_eq!(Operation::Idiv.eval_fr(a, zero).unwrap(), Fr::zero()); + assert_eq!(Operation::Mod.eval_fr(a, zero).unwrap(), Fr::zero()); + } + + #[test] + fn test_eval_pow_and_comparisons() { + let a = Fr::from(2u64); + let b = Fr::from(5u64); + assert_eq!(Operation::Pow.eval_fr(a, b).unwrap(), Fr::from(32u64)); + + let a = Fr::from(2u64); + let b = Fr::from(3u64); + assert_eq!(Operation::Eq.eval_fr(a, b).unwrap(), Fr::zero()); + assert_eq!(Operation::Neq.eval_fr(a, b).unwrap(), Fr::one()); + assert_eq!(Operation::Lt.eval_fr(a, b).unwrap(), Fr::one()); + assert_eq!(Operation::Gt.eval_fr(a, b).unwrap(), Fr::zero()); + assert_eq!(Operation::Leq.eval_fr(a, b).unwrap(), Fr::one()); + assert_eq!(Operation::Geq.eval_fr(a, b).unwrap(), Fr::zero()); + + let zero = Fr::zero(); + let one = Fr::one(); + assert_eq!(Operation::Land.eval_fr(zero, one).unwrap(), Fr::zero()); + assert_eq!(Operation::Lor.eval_fr(zero, one).unwrap(), Fr::one()); + } + + #[test] + fn test_shifts_edges() { + let a = Fr::from(5u64); + let b = Fr::zero(); + assert_eq!(shl(a, b).unwrap(), a); + assert_eq!( + shl(a, Fr::from(Fr::MODULUS_BIT_SIZE as u64)).unwrap(), + Fr::zero() + ); + + let b = Fr::zero(); + assert_eq!(shr(a, b).unwrap(), a); + assert_eq!(shr(a, Fr::from(254u64)).unwrap(), Fr::zero()); + + let a = Fr::from(1u64); + assert_eq!(shr(a, Fr::from(64u64)).unwrap(), Fr::zero()); + } + + #[test] + fn test_uno_id_error() { + let err = UnoOperation::Id.eval_fr(Fr::from(1u64)).unwrap_err(); + assert!(err.contains("not implemented")); + } + + #[test] + fn test_evaluate_u256_to_fr_error() { + let nodes = vec![Node::Input(0)]; + let bad = U256::from_limbs(Fr::MODULUS.0); + let inputs = vec![bad]; + let outputs = vec![0usize]; + let err = evaluate(&nodes, &inputs, &outputs).unwrap_err(); + assert!(err.contains("Failed to convert U256 to Fr")); + } + + #[test] + fn test_u_comparisons_sign() { + let pos = uint!(1_U256); + let neg = HALF_M + uint!(1_U256); + let neg2 = HALF_M + uint!(2_U256); + + assert_eq!(u_lt(&pos, &neg), uint!(0_U256)); + assert_eq!(u_gt(&pos, &neg), uint!(1_U256)); + assert_eq!(u_lte(&pos, &neg), uint!(0_U256)); + assert_eq!(u_gte(&pos, &neg), uint!(1_U256)); + + assert_eq!(u_lt(&neg, &pos), uint!(1_U256)); + assert_eq!(u_gt(&neg, &pos), uint!(0_U256)); + + assert_eq!(u_lt(&neg2, &neg), uint!(0_U256)); + assert_eq!(u_gt(&neg2, &neg), uint!(1_U256)); + } + + #[test] + fn test_bitwise_ops_basic() { + let a = Fr::from(5u64); + let b = Fr::from(3u64); + assert_eq!(bit_or(a, b).unwrap(), Fr::from(7u64)); + assert_eq!(bit_xor(a, b).unwrap(), Fr::from(6u64)); + assert_eq!(bit_and(a, b).unwrap(), Fr::from(1u64)); + } + #[test] fn test_x() { let x = M.div(uint!(2_U256)); diff --git a/rln/src/circuit/iden3calc/storage.rs b/rln/src/circuit/iden3calc/storage.rs index d864b50..f549fc8 100644 --- a/rln/src/circuit/iden3calc/storage.rs +++ b/rln/src/circuit/iden3calc/storage.rs @@ -525,4 +525,87 @@ mod test { assert_eq!(metadata, metadata_want); } + + #[test] + fn test_try_from_errors() { + let node = proto::Node { node: None }; + let err = graph::Node::try_from(node).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + + let node = proto::Node { + node: Some(proto::node::Node::Constant(proto::ConstantNode { + value: None, + })), + }; + let err = graph::Node::try_from(node).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + + let node = proto::Node { + node: Some(proto::node::Node::UnoOp(proto::UnoOpNode { + op: 999, + a_idx: 0, + })), + }; + let err = graph::Node::try_from(node).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + + let node = proto::Node { + node: Some(proto::node::Node::DuoOp(proto::DuoOpNode { + op: 999, + a_idx: 0, + b_idx: 1, + })), + }; + let err = graph::Node::try_from(node).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + + let node = proto::Node { + node: Some(proto::node::Node::TresOp(proto::TresOpNode { + op: 999, + a_idx: 0, + b_idx: 1, + c_idx: 2, + })), + }; + let err = graph::Node::try_from(node).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + } + + #[test] + #[should_panic( + expected = "We are not supposed to write Constant to the witnesscalc graph. All Constant should be converted to MontConstant." + )] + fn test_proto_node_from_constant_panics() { + let _ = proto::node::Node::from(&graph::Node::Constant(ruint::aliases::U256::from(1u64))); + } + + #[test] + fn test_read_message_errors() { + let mut rw = WriteBackReader::new(std::io::Cursor::new(&[])); + let err = read_message_length(&mut rw).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); + + let mut buf = Vec::new(); + prost::encode_length_delimiter(5, &mut buf).unwrap(); + buf.extend_from_slice(&[1u8, 2]); + let mut rw = WriteBackReader::new(std::io::Cursor::new(&buf)); + let err = read_message::<_, proto::Node>(&mut rw).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); + } + + #[test] + fn test_deserialize_invalid_magic() { + let bad = vec![b'x'; WITNESSCALC_GRAPH_MAGIC.len()]; + let err = deserialize_witnesscalc_graph(std::io::Cursor::new(&bad)).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + } + + #[test] + fn test_write_back_reader_empty_read_and_flush() { + let mut rw = WriteBackReader::new(std::io::Cursor::new(&[])); + let mut buf = []; + let n = rw.read(&mut buf).unwrap(); + assert_eq!(n, 0); + rw.flush().unwrap(); + } } diff --git a/rln/src/circuit/mod.rs b/rln/src/circuit/mod.rs index c473367..d66a260 100644 --- a/rln/src/circuit/mod.rs +++ b/rln/src/circuit/mod.rs @@ -200,3 +200,28 @@ fn read_arkzkey_from_bytes_uncompressed(arkzkey_data: &[u8]) -> Result { + num_instance_variables: 1, + num_witness_variables: 1, + num_constraints: 1, + a_num_non_zero: 1, + b_num_non_zero: 1, + c_num_non_zero: 0, + a: vec![vec![(Fr::one(), 0)]], + b: vec![vec![(Fr::one(), 1)]], + c: vec![vec![]], + }; + + let full_assignment = vec![Fr::one(), Fr::from(3u64)]; + let res = CircomReduction::witness_map_from_matrices::>( + &matrices, + 1, + 1, + &full_assignment, + ) + .unwrap(); + + assert_eq!(res.len(), 2); + assert!(res.iter().all(|v| v.is_zero())); + } + + #[test] + fn test_h_query_scalars_length() { + let max_power = 2usize; + let domain = GeneralEvaluationDomain::::new(2 * max_power + 1).expect("valid domain"); + let res = CircomReduction::h_query_scalars::>( + 2, + Fr::from(5u64), + Fr::from(1u64), + Fr::from(3u64), + ) + .unwrap(); + + assert_eq!(res.len(), domain.size() / 2); + } +} diff --git a/rln/tests/protocol.rs b/rln/tests/protocol.rs index 10d6218..a310098 100644 --- a/rln/tests/protocol.rs +++ b/rln/tests/protocol.rs @@ -82,7 +82,7 @@ mod test { assert!(tree.verify(&rate_commitment, &merkle_proof).unwrap()); } - fn get_test_witness() -> RLNWitnessInput { + fn get_test_witness_and_root() -> (RLNWitnessInput, Fr) { let leaf_index = 3; // Generate identity pair let (identity_secret, id_commitment) = keygen().unwrap(); @@ -98,6 +98,7 @@ mod test { ) .unwrap(); tree.set(leaf_index, rate_commitment).unwrap(); + let root = tree.root(); let merkle_proof = tree.proof(leaf_index).unwrap(); @@ -111,6 +112,57 @@ mod test { let message_id = Fr::from(1); + RLNWitnessInput::new( + identity_secret, + user_message_limit, + message_id, + merkle_proof.get_path_elements(), + merkle_proof.get_path_index(), + x, + external_nullifier, + ) + .map(|witness| (witness, root)) + .unwrap() + } + + fn get_test_witness() -> RLNWitnessInput { + get_test_witness_and_root().0 + } + + fn get_test_witness_with_params( + signal: &[u8], + epoch: &[u8], + rln_identifier: &[u8], + message_id: u64, + user_message_limit: u64, + ) -> RLNWitnessInput { + let leaf_index = 3; + // Generate identity pair + let (identity_secret, id_commitment) = keygen().unwrap(); + let user_message_limit = Fr::from(user_message_limit); + let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]).unwrap(); + + //// generate merkle tree + let default_leaf = Fr::from(0); + let mut tree = PoseidonTree::new( + DEFAULT_TREE_DEPTH, + default_leaf, + ConfigOf::::default(), + ) + .unwrap(); + tree.set(leaf_index, rate_commitment).unwrap(); + + let merkle_proof = tree.proof(leaf_index).unwrap(); + + let x = hash_to_field_le(signal).unwrap(); + + // We set the remaining values to random ones + let epoch = hash_to_field_le(epoch).unwrap(); + let rln_identifier = hash_to_field_le(rln_identifier).unwrap(); + let external_nullifier = poseidon_hash(&[epoch, rln_identifier]).unwrap(); + + let message_id = Fr::from(message_id); + RLNWitnessInput::new( identity_secret, user_message_limit, @@ -290,4 +342,241 @@ mod test { assert_eq!(identity_secret, expected_identity_secret_seed_phrase); assert_eq!(id_commitment, expected_id_commitment_seed_phrase); } + + #[test] + fn test_extended_keygen_relations() { + let (trapdoor, nullifier, identity_secret, id_commitment) = extended_keygen().unwrap(); + + let expected_identity_secret = poseidon_hash(&[trapdoor, nullifier]).unwrap(); + let expected_id_commitment = poseidon_hash(&[identity_secret]).unwrap(); + + assert_eq!(identity_secret, expected_identity_secret); + assert_eq!(id_commitment, expected_id_commitment); + } + + #[test] + fn test_extended_seeded_keygen_determinism() { + let seed = b"test-seed-extended"; + let first = extended_seeded_keygen(seed).unwrap(); + let second = extended_seeded_keygen(seed).unwrap(); + + assert_eq!(first, second); + + let (trapdoor, nullifier, identity_secret, id_commitment) = first; + let expected_identity_secret = poseidon_hash(&[trapdoor, nullifier]).unwrap(); + let expected_id_commitment = poseidon_hash(&[identity_secret]).unwrap(); + + assert_eq!(identity_secret, expected_identity_secret); + assert_eq!(id_commitment, expected_id_commitment); + } + + #[test] + fn test_witness_serialization_be_roundtrip_and_length_check() { + // Test with default witness + let witness = get_test_witness(); + let ser = rln_witness_to_bytes_be(&witness).unwrap(); + let (deser, _) = bytes_be_to_rln_witness(&ser).unwrap(); + assert_eq!(witness, deser); + + // Test with varied witness + let witness2 = get_test_witness_with_params( + b"different signal", + b"another epoch", + b"alt rln id", + 42, + 200, + ); + let ser2 = rln_witness_to_bytes_be(&witness2).unwrap(); + let (deser2, _) = bytes_be_to_rln_witness(&ser2).unwrap(); + assert_eq!(witness2, deser2); + + // Test with extreme values (large message_id and limit) + let witness3 = get_test_witness_with_params( + b"extreme signal", + b"extreme epoch", + b"extreme id", + 1000000, + 2000000, + ); + let ser3 = rln_witness_to_bytes_be(&witness3).unwrap(); + let (deser3, _) = bytes_be_to_rln_witness(&ser3).unwrap(); + assert_eq!(witness3, deser3); + + let mut bad = ser.clone(); + bad.push(0); + assert!(matches!( + bytes_be_to_rln_witness(&bad), + Err(ProtocolError::InvalidReadLen(_, _)) + )); + } + + #[test] + fn test_proof_values_serialization_be_roundtrip() { + // Test with default witness + let witness = get_test_witness(); + let proof_values = proof_values_from_witness(&witness).unwrap(); + + let ser = rln_proof_values_to_bytes_be(&proof_values); + let (deser, _) = bytes_be_to_rln_proof_values(&ser).unwrap(); + + assert_eq!(proof_values, deser); + + // Test with varied witness + let witness2 = get_test_witness_with_params(b"another signal", b"epoch2", b"id2", 10, 150); + let proof_values2 = proof_values_from_witness(&witness2).unwrap(); + + let ser2 = rln_proof_values_to_bytes_be(&proof_values2); + let (deser2, _) = bytes_be_to_rln_proof_values(&ser2).unwrap(); + + assert_eq!(proof_values2, deser2); + } + + #[test] + fn test_rln_proof_serialization_be_roundtrip() { + let witness = get_test_witness(); + let proving_key = zkey_from_folder(); + let graph_data = graph_from_folder(); + let proof = generate_zk_proof(proving_key, &witness, graph_data).unwrap(); + let proof_values = proof_values_from_witness(&witness).unwrap(); + + let rln_proof = RLNProof { + proof: proof.clone(), + proof_values, + }; + + let ser = rln_proof_to_bytes_be(&rln_proof).unwrap(); + let (deser, _) = bytes_be_to_rln_proof(&ser).unwrap(); + + assert_eq!(rln_proof.proof, deser.proof); + assert_eq!(rln_proof.proof_values, deser.proof_values); + } + + #[test] + fn test_verify_zk_proof_with_modified_public_value_fails() { + let witness = get_test_witness(); + let proving_key = zkey_from_folder(); + let graph_data = graph_from_folder(); + let proof = generate_zk_proof(proving_key, &witness, graph_data).unwrap(); + let mut proof_values = proof_values_from_witness(&witness).unwrap(); + + proof_values.root += Fr::from(1u64); + + let verified = verify_zk_proof(&proving_key.0.vk, &proof, &proof_values).unwrap(); + assert!(!verified); + } + + #[test] + fn test_compute_tree_root_matches_merkle_tree_root() { + // Test with default witness + let (witness, root) = get_test_witness_and_root(); + + let computed_root = compute_tree_root( + witness.identity_secret(), + witness.user_message_limit(), + witness.path_elements(), + witness.identity_path_index(), + ) + .unwrap(); + + assert_eq!(computed_root, root); + + // Test with varied witness + let witness2 = + get_test_witness_with_params(b"root test signal", b"root epoch", b"root id", 25, 300); + let leaf_index = 3; + let id_commitment = poseidon_hash(&[**witness2.identity_secret()]).unwrap(); + let rate_commitment = + poseidon_hash(&[id_commitment, *witness2.user_message_limit()]).unwrap(); + let default_leaf = Fr::from(0); + let mut tree = PoseidonTree::new( + DEFAULT_TREE_DEPTH, + default_leaf, + ConfigOf::::default(), + ) + .unwrap(); + tree.set(leaf_index, rate_commitment).unwrap(); + let root2 = tree.root(); + + let computed_root2 = compute_tree_root( + witness2.identity_secret(), + witness2.user_message_limit(), + witness2.path_elements(), + witness2.identity_path_index(), + ) + .unwrap(); + + assert_eq!(computed_root2, root2); + } + + #[test] + fn test_rln_witness_to_bigint_json_fields() { + // Test with default witness + let witness = get_test_witness(); + let json = rln_witness_to_bigint_json(&witness).unwrap(); + + assert_eq!( + json["identitySecret"].as_str().unwrap(), + to_bigint(witness.identity_secret()).to_str_radix(10) + ); + assert_eq!( + json["userMessageLimit"].as_str().unwrap(), + to_bigint(witness.user_message_limit()).to_str_radix(10) + ); + assert_eq!( + json["messageId"].as_str().unwrap(), + to_bigint(witness.message_id()).to_str_radix(10) + ); + assert_eq!( + json["x"].as_str().unwrap(), + to_bigint(witness.x()).to_str_radix(10) + ); + assert_eq!( + json["externalNullifier"].as_str().unwrap(), + to_bigint(witness.external_nullifier()).to_str_radix(10) + ); + + assert_eq!( + json["pathElements"].as_array().unwrap().len(), + witness.path_elements().len() + ); + assert_eq!( + json["identityPathIndex"].as_array().unwrap().len(), + witness.identity_path_index().len() + ); + + // Test with varied witness + let witness2 = + get_test_witness_with_params(b"json test signal", b"json epoch", b"json id", 99, 500); + let json2 = rln_witness_to_bigint_json(&witness2).unwrap(); + + assert_eq!( + json2["identitySecret"].as_str().unwrap(), + to_bigint(witness2.identity_secret()).to_str_radix(10) + ); + assert_eq!( + json2["userMessageLimit"].as_str().unwrap(), + to_bigint(witness2.user_message_limit()).to_str_radix(10) + ); + assert_eq!( + json2["messageId"].as_str().unwrap(), + to_bigint(witness2.message_id()).to_str_radix(10) + ); + assert_eq!( + json2["x"].as_str().unwrap(), + to_bigint(witness2.x()).to_str_radix(10) + ); + assert_eq!( + json2["externalNullifier"].as_str().unwrap(), + to_bigint(witness2.external_nullifier()).to_str_radix(10) + ); + + assert_eq!( + json2["pathElements"].as_array().unwrap().len(), + witness2.path_elements().len() + ); + assert_eq!( + json2["identityPathIndex"].as_array().unwrap().len(), + witness2.identity_path_index().len() + ); + } }