From bfbf638fed69cdd253e3aa4e781add38654141f9 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Tue, 9 Sep 2025 17:30:04 +0200 Subject: [PATCH] fix(zk): add a size check for the public key --- tfhe-zk-pok/src/proofs/mod.rs | 11 +++ tfhe-zk-pok/src/proofs/pke.rs | 94 ++++++++++++++++++++- tfhe-zk-pok/src/proofs/pke_v2/mod.rs | 121 +++++++++++++++++++++++++-- 3 files changed, 220 insertions(+), 6 deletions(-) diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index 342361037..bf8c6fa12 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -154,6 +154,8 @@ pub(crate) enum ProofSanityCheckMode { /// does not hold. #[allow(clippy::too_many_arguments)] fn assert_pke_proof_preconditions( + a: &[i64], + b: &[i64], c1: &[i64], e1: &[i64], c2: &[i64], @@ -164,6 +166,8 @@ fn assert_pke_proof_preconditions( big_d_max: usize, ) { assert!(k_max <= d); + assert_eq!(a.len(), d); + assert_eq!(b.len(), d); assert_eq!(c1.len(), d); assert_eq!(e1.len(), d); @@ -442,6 +446,13 @@ mod test { No, } + #[derive(Clone, Copy, Eq, PartialEq)] + pub(super) enum InputSizeVariation { + Oversized, + Undersized, + Nominal, + } + pub(super) fn serialize_then_deserialize< Params: Compressible + Serialize + for<'de> Deserialize<'de>, >( diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index d3fa57c5c..b4851cd6f 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -607,7 +607,7 @@ fn prove_impl( + (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize); if sanity_check_mode == ProofSanityCheckMode::Panic { - assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, big_d, big_d_max); + assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, big_d, big_d_max); } // FIXME: div_round @@ -1095,6 +1095,10 @@ pub fn verify( return Err(()); } + if a.len() != d || b.len() != d { + return Err(()); + } + let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count; let big_d = d @@ -1787,6 +1791,94 @@ mod tests { } } + /// Test that the proof is rejected without panic if the public key elements are not of the + /// correct size + #[test] + fn test_pke_wrong_pk_size() { + let PkeTestParameters { + d, + k, + B, + q, + t, + msbs_zero_padding_bit_count, + } = PKEV1_TEST_PARAMS; + + let seed = thread_rng().gen(); + println!("pke_wrong_pk_size seed: {seed:x}"); + let rng = &mut StdRng::seed_from_u64(seed); + + let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS); + + let ct = testcase.encrypt(PKEV1_TEST_PARAMS); + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let crs = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + testcase.a.clone(), + testcase.b.clone(), + ct.c1.clone(), + ct.c2.clone(), + testcase.r.clone(), + testcase.e1.clone(), + testcase.m.clone(), + testcase.e2.clone(), + &crs, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&crs, &public_commit), + &private_commit, + &testcase.metadata, + load, + &seed.to_le_bytes(), + ); + + for (a_size_kind, b_size_kind) in itertools::iproduct!( + [ + InputSizeVariation::Oversized, + InputSizeVariation::Undersized, + InputSizeVariation::Nominal, + ], + [ + InputSizeVariation::Oversized, + InputSizeVariation::Undersized, + InputSizeVariation::Nominal, + ] + ) { + if a_size_kind == InputSizeVariation::Nominal + && b_size_kind == InputSizeVariation::Nominal + { + // This is the nominal case that is already tested + continue; + } + + let mut public_commit = public_commit.clone(); + + match a_size_kind { + InputSizeVariation::Oversized => public_commit.a.push(rng.gen()), + InputSizeVariation::Undersized => { + public_commit.a.pop(); + } + InputSizeVariation::Nominal => {} + }; + + match b_size_kind { + InputSizeVariation::Oversized => public_commit.b.push(rng.gen()), + InputSizeVariation::Undersized => { + public_commit.b.pop(); + } + InputSizeVariation::Nominal => {} + }; + + // Should not panic but return an error + assert!(verify(&proof, (&crs, &public_commit), &testcase.metadata).is_err()) + } + } + } + /// Test verification with modified ciphertexts #[test] fn test_bad_ct() { diff --git a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs index 59e821c70..7e269b253 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs @@ -893,7 +893,7 @@ fn prove_impl( .sum::(); if sanity_check_mode == ProofSanityCheckMode::Panic { - assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max); + assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, D, D_max); assert!( B_squared >= e_sqr_norm, "squared norm of error ({e_sqr_norm}) exceeds threshold ({B_squared})", @@ -1076,7 +1076,16 @@ fn prove_impl( let (theta, theta_hash) = t_hash.gen_theta(); let mut a_theta = vec![G::Zp::ZERO; D]; - compute_a_theta::(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta); + compute_a_theta::( + &mut a_theta, + &theta, + a, + d, + k, + b, + effective_cleartext_t, + delta, + ); let t_theta = theta .iter() @@ -1606,6 +1615,7 @@ fn compute_a_theta( a_theta: &mut [G::Zp], theta: &[G::Zp], a: &[i64], + d: usize, k: usize, b: &[i64], t: u64, @@ -1620,8 +1630,8 @@ fn compute_a_theta( // ... // delta g[log t].T theta2_k // ] - - let d = a.len(); + assert_eq!(a.len(), d); + assert!(theta.len() >= d); let theta1 = &theta[..d]; let theta2 = &theta[d..]; @@ -1770,6 +1780,10 @@ pub fn verify_impl( return Err(()); } + if a.len() != d || b.len() != d { + return Err(()); + } + let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count; let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k); let (_, D, _, m_bound) = compute_crs_params( @@ -1815,7 +1829,16 @@ pub fn verify_impl( let (theta, theta_hash) = t_hash.gen_theta(); let mut a_theta = vec![G::Zp::ZERO; D]; - compute_a_theta::(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta); + compute_a_theta::( + &mut a_theta, + &theta, + a, + d, + k, + b, + effective_cleartext_t, + delta, + ); let t_theta = theta .iter() @@ -2778,6 +2801,94 @@ mod tests { } } + /// Test that the proof is rejected without panic if the public key elements are not of the + /// correct size + #[test] + fn test_pke_wrong_pk_size() { + let PkeTestParameters { + d, + k, + B, + q, + t, + msbs_zero_padding_bit_count, + } = PKEV2_TEST_PARAMS; + + let seed = thread_rng().gen(); + println!("pke_wrong_pk_size seed: {seed:x}"); + let rng = &mut StdRng::seed_from_u64(seed); + + let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); + let ct = testcase.encrypt(PKEV2_TEST_PARAMS); + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let crs = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + testcase.a.clone(), + testcase.b.clone(), + ct.c1.clone(), + ct.c2.clone(), + testcase.r.clone(), + testcase.e1.clone(), + testcase.m.clone(), + testcase.e2.clone(), + &crs, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&crs, &public_commit), + &private_commit, + &testcase.metadata, + load, + &seed.to_le_bytes(), + ); + + for (a_size_kind, b_size_kind) in itertools::iproduct!( + [ + InputSizeVariation::Oversized, + InputSizeVariation::Undersized, + InputSizeVariation::Nominal, + ], + [ + InputSizeVariation::Oversized, + InputSizeVariation::Undersized, + InputSizeVariation::Nominal, + ] + ) { + if a_size_kind == InputSizeVariation::Nominal + && b_size_kind == InputSizeVariation::Nominal + { + // This is the nominal case that is already tested + continue; + } + + let mut public_commit = public_commit.clone(); + + match a_size_kind { + InputSizeVariation::Oversized => public_commit.a.push(rng.gen()), + InputSizeVariation::Undersized => { + public_commit.a.pop(); + } + InputSizeVariation::Nominal => {} + }; + + match b_size_kind { + InputSizeVariation::Oversized => public_commit.b.push(rng.gen()), + InputSizeVariation::Undersized => { + public_commit.b.pop(); + } + InputSizeVariation::Nominal => {} + }; + + // Should not panic but return an error + assert!(verify(&proof, (&crs, &public_commit), &testcase.metadata).is_err()) + } + } + } + /// Test verification with modified ciphertexts #[test] fn test_bad_ct() {