mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
fix(zk): handle limit cases in the four_squares algorithm
This commit is contained in:
committed by
Nicolas Sarlin
parent
0a1651adf3
commit
adcf9bc1f3
@@ -1,6 +1,8 @@
|
||||
use ark_ff::biginteger::arithmetic::widening_mul;
|
||||
use rand::prelude::*;
|
||||
|
||||
use crate::proofs::ProofSanityCheckMode;
|
||||
|
||||
/// Avoid overflows for squares of u64
|
||||
pub fn sqr(x: u64) -> u128 {
|
||||
let x = x as u128;
|
||||
@@ -188,21 +190,24 @@ impl Montgomery {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn four_squares(v: u128) -> [u64; 4] {
|
||||
pub fn four_squares(v: u128, sanity_check_mode: ProofSanityCheckMode) -> [u64; 4] {
|
||||
let rng = &mut StdRng::seed_from_u64(0);
|
||||
|
||||
// In the extreme case where the noise is exactly at the bound, v is 0
|
||||
if v == 0 {
|
||||
return [0; 4];
|
||||
}
|
||||
// Handle limit cases that would trigger an infinite loop
|
||||
match v {
|
||||
0 => return [0; 4],
|
||||
2 => return [1, 1, 0, 0],
|
||||
6 => return [2, 1, 1, 0],
|
||||
_ => {}
|
||||
};
|
||||
|
||||
let f = v % 4;
|
||||
if f == 2 {
|
||||
let b = v.isqrt() as u64;
|
||||
|
||||
'main_loop: loop {
|
||||
let x = 2 + rng.gen::<u64>() % (b - 2);
|
||||
let y = 2 + rng.gen::<u64>() % (b - 2);
|
||||
let x: u64 = rng.gen_range(0..=b);
|
||||
let y: u64 = rng.gen_range(0..=b);
|
||||
|
||||
let (sum, o) = u128::overflowing_add(sqr(x), sqr(y));
|
||||
if o || sum > v {
|
||||
@@ -270,21 +275,80 @@ pub fn four_squares(v: u128) -> [u64; 4] {
|
||||
return [x, y, z, w];
|
||||
}
|
||||
} else if f == 0 {
|
||||
four_squares(v / 4).map(|x| x + x)
|
||||
four_squares(v / 4, sanity_check_mode).map(|x| x + x)
|
||||
} else {
|
||||
let mut r = four_squares(2 * v);
|
||||
r.sort_by_key(|&x| {
|
||||
if x % 2 == 0 {
|
||||
-1 - ((x / 2) as i64)
|
||||
} else {
|
||||
(x / 2) as i64
|
||||
}
|
||||
});
|
||||
// v is odd, compute the four squares for 2*v and deduce the result for v
|
||||
let double = match sanity_check_mode {
|
||||
ProofSanityCheckMode::Panic => v.checked_mul(2).unwrap(),
|
||||
#[cfg(test)]
|
||||
ProofSanityCheckMode::Ignore => v.wrapping_mul(2),
|
||||
};
|
||||
let mut r = four_squares(double, sanity_check_mode);
|
||||
|
||||
// At this point we know that exactly 2 values of r are even and 2 are odd:
|
||||
// r = [w, x, y, z]
|
||||
// 2v = w² + x² + y² + z²
|
||||
// We cannot have 4 even numbers because
|
||||
// 2v = (2w')²+(2x')²+(2y')²+(2z')² = 4v', but v is odd
|
||||
// We cannot have 4 odd numbers because
|
||||
// 2v = (2w'+1)²+(2x'+1)²+(2y'+1)²+(2z'+1)²
|
||||
// = (4w'²+4w'+1)+(4x'²+4x'+1)+(4y'²+4y'+1)+(4z'²+4z'+1) = 4v', same issue
|
||||
//
|
||||
// Since w² + x² + y² + z² is even we must have 2 of them odd and 2 of them even
|
||||
|
||||
// Sort so that r[0], r[1] are even and r[2], r[3] are odd,
|
||||
// with r[1] > r[0] and r[3] > r[2]
|
||||
r.sort_by_key(|&x| (x % 2 != 0, x));
|
||||
|
||||
[
|
||||
(r[0] + r[1]) / 2,
|
||||
(r[0] - r[1]) / 2,
|
||||
(r[3] + r[2]) / 2,
|
||||
// divide by 2 before addition to avoid overflows
|
||||
(r[1] / 2 + r[0] / 2),
|
||||
(r[1] - r[0]) / 2,
|
||||
(r[3] / 2 + r[2] / 2) + 1,
|
||||
(r[3] - r[2]) / 2,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{thread_rng, Rng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_four_squares(value: u128) {
|
||||
let squares = four_squares(value, ProofSanityCheckMode::Panic);
|
||||
|
||||
let res = squares.iter().map(|x| sqr(*x)).sum();
|
||||
assert_eq!(value, res);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_squares() {
|
||||
const RAND_TESTS_COUNT: usize = 1000;
|
||||
|
||||
let seed = thread_rng().gen();
|
||||
println!("four_squares seed: {seed:x}");
|
||||
let rng = &mut StdRng::seed_from_u64(seed);
|
||||
|
||||
for val in 0..256 {
|
||||
assert_four_squares(val);
|
||||
}
|
||||
|
||||
// If v % 4 = 1 or 3, v will be multiplied by 2
|
||||
assert_four_squares(u128::MAX / 2);
|
||||
assert_four_squares(u128::MAX / 2 - 1);
|
||||
assert_four_squares(u128::MAX / 2 - 2);
|
||||
assert_four_squares(u128::MAX / 2 - 3);
|
||||
|
||||
for i in 8..127 {
|
||||
assert_four_squares((1u128 << i) + 1);
|
||||
}
|
||||
|
||||
for _ in 0..RAND_TESTS_COUNT {
|
||||
let v: u128 = rng.gen_range(0..(u128::MAX / 2));
|
||||
assert_four_squares(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ impl<G: Curve> GroupElements<G> {
|
||||
|
||||
/// Allows to compute proof with bad inputs for tests
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
enum ProofSanityCheckMode {
|
||||
pub(crate) enum ProofSanityCheckMode {
|
||||
Panic,
|
||||
#[cfg(test)]
|
||||
Ignore,
|
||||
|
||||
@@ -928,7 +928,7 @@ fn prove_impl<G: Curve>(
|
||||
)
|
||||
.collect::<Box<[_]>>();
|
||||
|
||||
let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64);
|
||||
let v = four_squares(B_squared - e_sqr_norm, sanity_check_mode).map(|v| v as i64);
|
||||
|
||||
let e1_zp = &*e1
|
||||
.iter()
|
||||
|
||||
Reference in New Issue
Block a user