fix(zk): handle limit cases in the four_squares algorithm

This commit is contained in:
Nicolas Sarlin
2025-09-05 12:04:14 +02:00
committed by Nicolas Sarlin
parent 0a1651adf3
commit adcf9bc1f3
3 changed files with 85 additions and 21 deletions

View File

@@ -1,6 +1,8 @@
use ark_ff::biginteger::arithmetic::widening_mul;
use rand::prelude::*;
use crate::proofs::ProofSanityCheckMode;
/// Avoid overflows for squares of u64
pub fn sqr(x: u64) -> u128 {
let x = x as u128;
@@ -188,21 +190,24 @@ impl Montgomery {
}
}
pub fn four_squares(v: u128) -> [u64; 4] {
pub fn four_squares(v: u128, sanity_check_mode: ProofSanityCheckMode) -> [u64; 4] {
let rng = &mut StdRng::seed_from_u64(0);
// In the extreme case where the noise is exactly at the bound, v is 0
if v == 0 {
return [0; 4];
}
// Handle limit cases that would trigger an infinite loop
match v {
0 => return [0; 4],
2 => return [1, 1, 0, 0],
6 => return [2, 1, 1, 0],
_ => {}
};
let f = v % 4;
if f == 2 {
let b = v.isqrt() as u64;
'main_loop: loop {
let x = 2 + rng.gen::<u64>() % (b - 2);
let y = 2 + rng.gen::<u64>() % (b - 2);
let x: u64 = rng.gen_range(0..=b);
let y: u64 = rng.gen_range(0..=b);
let (sum, o) = u128::overflowing_add(sqr(x), sqr(y));
if o || sum > v {
@@ -270,21 +275,80 @@ pub fn four_squares(v: u128) -> [u64; 4] {
return [x, y, z, w];
}
} else if f == 0 {
four_squares(v / 4).map(|x| x + x)
four_squares(v / 4, sanity_check_mode).map(|x| x + x)
} else {
let mut r = four_squares(2 * v);
r.sort_by_key(|&x| {
if x % 2 == 0 {
-1 - ((x / 2) as i64)
} else {
(x / 2) as i64
}
});
// v is odd, compute the four squares for 2*v and deduce the result for v
let double = match sanity_check_mode {
ProofSanityCheckMode::Panic => v.checked_mul(2).unwrap(),
#[cfg(test)]
ProofSanityCheckMode::Ignore => v.wrapping_mul(2),
};
let mut r = four_squares(double, sanity_check_mode);
// At this point we know that exactly 2 values of r are even and 2 are odd:
// r = [w, x, y, z]
// 2v = w² + x² + y² + z²
// We cannot have 4 even numbers because
// 2v = (2w')²+(2x')²+(2y')²+(2z')² = 4v', but v is odd
// We cannot have 4 odd numbers because
// 2v = (2w'+1)²+(2x'+1)²+(2y'+1)²+(2z'+1)²
// = (4w'²+4w'+1)+(4x'²+4x'+1)+(4y'²+4y'+1)+(4z'²+4z'+1) = 4v', same issue
//
// Since w² + x² + y² + z² is even we must have 2 of them odd and 2 of them even
// Sort so that r[0], r[1] are even and r[2], r[3] are odd,
// with r[1] > r[0] and r[3] > r[2]
r.sort_by_key(|&x| (x % 2 != 0, x));
[
(r[0] + r[1]) / 2,
(r[0] - r[1]) / 2,
(r[3] + r[2]) / 2,
// divide by 2 before addition to avoid overflows
(r[1] / 2 + r[0] / 2),
(r[1] - r[0]) / 2,
(r[3] / 2 + r[2] / 2) + 1,
(r[3] - r[2]) / 2,
]
}
}
#[cfg(test)]
mod test {
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use super::*;
fn assert_four_squares(value: u128) {
let squares = four_squares(value, ProofSanityCheckMode::Panic);
let res = squares.iter().map(|x| sqr(*x)).sum();
assert_eq!(value, res);
}
#[test]
fn test_four_squares() {
const RAND_TESTS_COUNT: usize = 1000;
let seed = thread_rng().gen();
println!("four_squares seed: {seed:x}");
let rng = &mut StdRng::seed_from_u64(seed);
for val in 0..256 {
assert_four_squares(val);
}
// If v % 4 = 1 or 3, v will be multiplied by 2
assert_four_squares(u128::MAX / 2);
assert_four_squares(u128::MAX / 2 - 1);
assert_four_squares(u128::MAX / 2 - 2);
assert_four_squares(u128::MAX / 2 - 3);
for i in 8..127 {
assert_four_squares((1u128 << i) + 1);
}
for _ in 0..RAND_TESTS_COUNT {
let v: u128 = rng.gen_range(0..(u128::MAX / 2));
assert_four_squares(v);
}
}
}

View File

@@ -144,7 +144,7 @@ impl<G: Curve> GroupElements<G> {
/// Allows to compute proof with bad inputs for tests
#[derive(Copy, Clone, PartialEq, Eq)]
enum ProofSanityCheckMode {
pub(crate) enum ProofSanityCheckMode {
Panic,
#[cfg(test)]
Ignore,

View File

@@ -928,7 +928,7 @@ fn prove_impl<G: Curve>(
)
.collect::<Box<[_]>>();
let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64);
let v = four_squares(B_squared - e_sqr_norm, sanity_check_mode).map(|v| v as i64);
let e1_zp = &*e1
.iter()