From 770ae22bb6daa38a52309b015ae337a64914628c Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Thu, 14 Nov 2024 17:05:09 +0100 Subject: [PATCH] refactor(zk): place asserts in proof behind a condition --- tfhe-zk-pok/src/four_squares.rs | 5 +++ tfhe-zk-pok/src/proofs/mod.rs | 28 ++++++++++++++++ tfhe-zk-pok/src/proofs/pke.rs | 24 ++++++++++++-- tfhe-zk-pok/src/proofs/pke_v2.rs | 55 ++++++++++++++++++++++---------- 4 files changed, 94 insertions(+), 18 deletions(-) diff --git a/tfhe-zk-pok/src/four_squares.rs b/tfhe-zk-pok/src/four_squares.rs index 0e1fa16a8..9e85a465a 100644 --- a/tfhe-zk-pok/src/four_squares.rs +++ b/tfhe-zk-pok/src/four_squares.rs @@ -214,6 +214,11 @@ impl Montgomery { pub fn four_squares(v: u128) -> [u64; 4] { let rng = &mut StdRng::seed_from_u64(0); + // In the extreme case where the noise is exactly at the bound, v is 0 + if v == 0 { + return [0; 4]; + } + let f = v % 4; if f == 2 { let b = isqrt(v as _) as u64; diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index f81d6cea8..5cd6f21b9 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -132,6 +132,34 @@ impl GroupElements { } } +/// Allows to compute proof with bad inputs for tests +#[derive(PartialEq, Eq)] +enum ProofSanityCheckMode { + Panic, +} + +/// Check the preconditions of the pke proof before computing it. Panic if one of the conditions +/// does not hold. +#[allow(clippy::too_many_arguments)] +fn assert_pke_proof_preconditions( + c1: &[i64], + e1: &[i64], + c2: &[i64], + e2: &[i64], + d: usize, + k_max: usize, + big_d: usize, + big_d_max: usize, +) { + assert_eq!(c1.len(), d); + assert_eq!(e1.len(), d); + + assert_eq!(c2.len(), e2.len()); + assert!(c2.len() <= k_max); + + assert!(big_d <= big_d_max); +} + /// q (modulus) is encoded on 64b, with 0 meaning 2^64. This converts the encoded q to its effective /// value for modular operations. fn decode_q(q: u64) -> u128 { diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index e40dab377..d1c945125 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -475,6 +475,24 @@ pub fn prove( metadata: &[u8], load: ComputeLoad, rng: &mut dyn RngCore, +) -> Proof { + prove_impl( + public, + private_commit, + metadata, + load, + rng, + ProofSanityCheckMode::Panic, + ) +} + +fn prove_impl( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + metadata: &[u8], + load: ComputeLoad, + rng: &mut dyn RngCore, + sanity_check_mode: ProofSanityCheckMode, ) -> Proof { let &PublicParams { ref g_lists, @@ -503,7 +521,6 @@ pub fn prove( let PrivateCommit { r, e1, m, e2, .. } = private_commit; let k = c2.len(); - assert!(k <= k_max); let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count; @@ -512,7 +529,10 @@ pub fn prove( let big_d = d + k * effective_t_for_decomposition.ilog2() as usize + (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize); - assert!(big_d <= big_d_max); + + if sanity_check_mode == ProofSanityCheckMode::Panic { + assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, big_d, big_d_max); + } // FIXME: div_round let delta = { diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index 3504c6652..d6a7a904f 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -654,6 +654,24 @@ pub fn prove( metadata: &[u8], load: ComputeLoad, rng: &mut dyn RngCore, +) -> Proof { + prove_impl( + public, + private_commit, + metadata, + load, + rng, + ProofSanityCheckMode::Panic, + ) +} + +fn prove_impl( + public: (&PublicParams, &PublicCommit), + private_commit: &PrivateCommit, + metadata: &[u8], + load: ComputeLoad, + rng: &mut dyn RngCore, + sanity_check_mode: ProofSanityCheckMode, ) -> Proof { _ = load; let ( @@ -689,7 +707,6 @@ pub fn prove( let PrivateCommit { r, e1, m, e2, .. } = private_commit; let k = c2.len(); - assert!(k <= k_max); let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count; @@ -698,7 +715,21 @@ pub fn prove( // Recompute the D for our case if k is smaller than the k max // formula in Prove_pp: 2. let D = d + k * effective_cleartext_t.ilog2() as usize; - assert!(D <= D_max); + + let e_sqr_norm = e1 + .iter() + .chain(e2) + .map(|x| sqr(x.unsigned_abs() as u128)) + .sum::(); + + if sanity_check_mode == ProofSanityCheckMode::Panic { + assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max); + assert!( + sqr(B as u128) >= e_sqr_norm, + "squared norm of error ({e_sqr_norm}) exceeds threshold ({})", + sqr(B as u128) + ); + } // FIXME: div_round let delta = { @@ -730,18 +761,6 @@ pub fn prove( ) .collect::>(); - let e_sqr_norm = e1 - .iter() - .chain(e2) - .map(|x| sqr(x.unsigned_abs() as u128)) - .sum::(); - - assert!( - sqr(B as u128) >= e_sqr_norm, - "squared norm of error ({e_sqr_norm}) exceeds threshold ({})", - sqr(B as u128) - ); - let v = four_squares(sqr(B as u128) - e_sqr_norm).map(|v| v as i64); let e1_zp = &*e1 @@ -872,7 +891,9 @@ pub fn prove( -1 => acc -= x as i128, _ => unreachable!(), }); - assert!(acc.unsigned_abs() <= B_bound as u128); + if sanity_check_mode == ProofSanityCheckMode::Panic { + assert!(acc.unsigned_abs() <= B_bound as u128); + } acc as i64 }) .collect::>(); @@ -970,7 +991,9 @@ pub fn prove( .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); - assert_eq!(y.len(), w_bin.len()); + if sanity_check_mode == ProofSanityCheckMode::Panic { + assert_eq!(y.len(), w_bin.len()); + } let scalars = y .iter() .zip(w_bin.iter())