From adcf9bc1f3b5c5ff32b3c5fd2feb4dde7118c9c0 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Fri, 5 Sep 2025 12:04:14 +0200 Subject: [PATCH] fix(zk): handle limit cases in the four_squares algorithm --- tfhe-zk-pok/src/four_squares.rs | 102 ++++++++++++++++++++++----- tfhe-zk-pok/src/proofs/mod.rs | 2 +- tfhe-zk-pok/src/proofs/pke_v2/mod.rs | 2 +- 3 files changed, 85 insertions(+), 21 deletions(-) diff --git a/tfhe-zk-pok/src/four_squares.rs b/tfhe-zk-pok/src/four_squares.rs index a3b60db5c..0d5872ef4 100644 --- a/tfhe-zk-pok/src/four_squares.rs +++ b/tfhe-zk-pok/src/four_squares.rs @@ -1,6 +1,8 @@ use ark_ff::biginteger::arithmetic::widening_mul; use rand::prelude::*; +use crate::proofs::ProofSanityCheckMode; + /// Avoid overflows for squares of u64 pub fn sqr(x: u64) -> u128 { let x = x as u128; @@ -188,21 +190,24 @@ impl Montgomery { } } -pub fn four_squares(v: u128) -> [u64; 4] { +pub fn four_squares(v: u128, sanity_check_mode: ProofSanityCheckMode) -> [u64; 4] { let rng = &mut StdRng::seed_from_u64(0); - // In the extreme case where the noise is exactly at the bound, v is 0 - if v == 0 { - return [0; 4]; - } + // Handle limit cases that would trigger an infinite loop + match v { + 0 => return [0; 4], + 2 => return [1, 1, 0, 0], + 6 => return [2, 1, 1, 0], + _ => {} + }; let f = v % 4; if f == 2 { let b = v.isqrt() as u64; 'main_loop: loop { - let x = 2 + rng.gen::() % (b - 2); - let y = 2 + rng.gen::() % (b - 2); + let x: u64 = rng.gen_range(0..=b); + let y: u64 = rng.gen_range(0..=b); let (sum, o) = u128::overflowing_add(sqr(x), sqr(y)); if o || sum > v { @@ -270,21 +275,80 @@ pub fn four_squares(v: u128) -> [u64; 4] { return [x, y, z, w]; } } else if f == 0 { - four_squares(v / 4).map(|x| x + x) + four_squares(v / 4, sanity_check_mode).map(|x| x + x) } else { - let mut r = four_squares(2 * v); - r.sort_by_key(|&x| { - if x % 2 == 0 { - -1 - ((x / 2) as i64) - } else { - (x / 2) as i64 - } - }); + // v is odd, compute the four squares for 2*v and deduce the result for v + let double = match sanity_check_mode { + ProofSanityCheckMode::Panic => v.checked_mul(2).unwrap(), + #[cfg(test)] + ProofSanityCheckMode::Ignore => v.wrapping_mul(2), + }; + let mut r = four_squares(double, sanity_check_mode); + + // At this point we know that exactly 2 values of r are even and 2 are odd: + // r = [w, x, y, z] + // 2v = w² + x² + y² + z² + // We cannot have 4 even numbers because + // 2v = (2w')²+(2x')²+(2y')²+(2z')² = 4v', but v is odd + // We cannot have 4 odd numbers because + // 2v = (2w'+1)²+(2x'+1)²+(2y'+1)²+(2z'+1)² + // = (4w'²+4w'+1)+(4x'²+4x'+1)+(4y'²+4y'+1)+(4z'²+4z'+1) = 4v', same issue + // + // Since w² + x² + y² + z² is even we must have 2 of them odd and 2 of them even + + // Sort so that r[0], r[1] are even and r[2], r[3] are odd, + // with r[1] > r[0] and r[3] > r[2] + r.sort_by_key(|&x| (x % 2 != 0, x)); + [ - (r[0] + r[1]) / 2, - (r[0] - r[1]) / 2, - (r[3] + r[2]) / 2, + // divide by 2 before addition to avoid overflows + (r[1] / 2 + r[0] / 2), + (r[1] - r[0]) / 2, + (r[3] / 2 + r[2] / 2) + 1, (r[3] - r[2]) / 2, ] } } + +#[cfg(test)] +mod test { + use rand::rngs::StdRng; + use rand::{thread_rng, Rng, SeedableRng}; + + use super::*; + + fn assert_four_squares(value: u128) { + let squares = four_squares(value, ProofSanityCheckMode::Panic); + + let res = squares.iter().map(|x| sqr(*x)).sum(); + assert_eq!(value, res); + } + + #[test] + fn test_four_squares() { + const RAND_TESTS_COUNT: usize = 1000; + + let seed = thread_rng().gen(); + println!("four_squares seed: {seed:x}"); + let rng = &mut StdRng::seed_from_u64(seed); + + for val in 0..256 { + assert_four_squares(val); + } + + // If v % 4 = 1 or 3, v will be multiplied by 2 + assert_four_squares(u128::MAX / 2); + assert_four_squares(u128::MAX / 2 - 1); + assert_four_squares(u128::MAX / 2 - 2); + assert_four_squares(u128::MAX / 2 - 3); + + for i in 8..127 { + assert_four_squares((1u128 << i) + 1); + } + + for _ in 0..RAND_TESTS_COUNT { + let v: u128 = rng.gen_range(0..(u128::MAX / 2)); + assert_four_squares(v); + } + } +} diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index 6c099a73b..342361037 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -144,7 +144,7 @@ impl GroupElements { /// Allows to compute proof with bad inputs for tests #[derive(Copy, Clone, PartialEq, Eq)] -enum ProofSanityCheckMode { +pub(crate) enum ProofSanityCheckMode { Panic, #[cfg(test)] Ignore, diff --git a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs index 1d825d0f1..59e821c70 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2/mod.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2/mod.rs @@ -928,7 +928,7 @@ fn prove_impl( ) .collect::>(); - let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64); + let v = four_squares(B_squared - e_sqr_norm, sanity_check_mode).map(|v| v as i64); let e1_zp = &*e1 .iter()