fix: make sure computations don't overflow for certain primes for 32 bits

- The original code seemed to assume that the Barrett reduction would not
overflow if p <= 2^31, this is incorrect but rare
- The correctness constraint gives a bound much smaller than 2^31. Some
primes bigger than the derived threshold can still use the fast code,
provided they satisfy a criterion which corresponds to a "lucky" case
of the Barrett reduction; the new code now handles this

maths explained in https://blog.zksecurity.xyz/posts/barrett-tighter-bound/
and copiously in comments in the code
This commit is contained in:
Arthur Meyre
2025-08-08 19:17:54 +02:00
parent 1dcc3c8c89
commit a4841036b7

View File

@@ -597,6 +597,36 @@ fn mul_accumulate_scalar(
}
}
/// Barrett reduction parameters precomputed for a 32-bit modulus.
struct BarrettInit32 {
    /// Bit length of the modulus, i.e. `modulus.ilog2() + 1`.
    big_q: u32,
    /// `floor(2^(big_q + 31) / modulus)`, truncated to 32 bits.
    p_barrett: u32,
    /// `true` when `2^(big_q + 31) mod modulus <= modulus - 2^(big_q - 1)`.
    requires_single_reduction_step: bool,
}

impl BarrettInit32 {
    /// Derives the Barrett parameters for `modulus`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero, since `ilog2` is undefined for zero.
    pub fn new(modulus: u32) -> Self {
        let big_q = modulus.ilog2() + 1;
        let wide_modulus = u64::from(modulus);

        // 2^(big_q + 31) plays the role of 2^{2k} from the zk security blog.
        let numerator = 1u64 << (big_q + 31);
        let p_barrett = (numerator / wide_modulus) as u32;
        let beta = numerator % wide_modulus;

        // Criterion under which this implementation only ever needs a single Barrett
        // reduction step. When two steps may be needed, there are cases where it is not
        // possible to decide whether a reduction is required, yielding wrong results.
        // Formula derived with https://blog.zksecurity.xyz/posts/barrett-tighter-bound/
        let top_bit_of_modulus = 1u64 << (big_q - 1);
        let requires_single_reduction_step = beta <= wide_modulus - top_bit_of_modulus;

        Self {
            big_q,
            p_barrett,
            requires_single_reduction_step,
        }
    }
}
/// Negacyclic NTT plan for 32bit primes.
#[derive(Clone)]
pub struct Plan {
@@ -606,6 +636,8 @@ pub struct Plan {
inv_twid_shoup: ABox<[u32]>,
p: u32,
p_div: Div32,
/// If true can use non generic code and optimized reduction algorithms.
can_use_fast_reduction_code: bool,
// used for elementwise product
p_barrett: u32,
@@ -666,9 +698,55 @@ impl Plan {
let n_inv_mod_p = crate::prime::exp_mod32(p_div, polynomial_size as u32, modulus - 2);
let n_inv_mod_p_shoup = (((n_inv_mod_p as u64) << 32) / modulus as u64) as u32;
let big_q = modulus.ilog2() + 1;
let big_l = big_q + 31;
let p_barrett = ((1u64 << big_l) / modulus as u64) as u32;
let BarrettInit32 {
big_q,
p_barrett,
requires_single_reduction_step,
} = BarrettInit32::new(modulus);
// The mul_accumulate_scalar code (and derived SIMD code) can have an overflow issue due
// to how Barrett reduction works. It can also overflow during the following
// accumulations. We'll derive the requirements for the mul_accumulate_scalar code to be
// correct and use that to determine whether a given prime is a "fast" prime for our
// code.
//
// You can check https://blog.zksecurity.xyz/posts/barrett-tighter-bound/ for a readable
// Barrett reduction analysis/breakdown.
//
// Barrett reduction gives an approximation q_approx of the quotient q needed to compute
// the reduction. It can undershoot q by at most 2: q_approx in {q - 2, q - 1, q}. Given
// we subtract q_approx * p from the value to reduce to start the reduction, we have:
// prod = to_reduce - q_approx * p
// prod = true_prod + c * p with c in {0, 1, 2}
// We need to make sure that prod does not overflow 32 bits
// We have true_prod which is reduced mod p, so true_prod <= p - 1
// prod <= 2^32 - 1 <=>
// true_prod + 2 * p <= 2^32 - 1 <=>
// 3 * p - 1 <= 2^32 - 1 <=>
// p <= 2^32 / 3 = 1431655765.3333333... <=> p < 1431655766
//
// After the computation of prod a first reduction step is performed, meaning we now
// have:
// prod = true_prod + c * p with c in {0, 1}
// We are accumulating in acc which is already reduced, so acc <= p - 1
// The accumulation yields:
// We need acc + prod <= 2^32 - 1 <=>
// (p - 1) + (p - 1) + p <= 2^32 - 1 <=>
// 3p - 2 <= 2^32 - 1 <=>
// 3p <= 2^32 + 1 <=>
// p <= (2^32 + 1) / 3 = 1431655765.6666667... <=> p < 1431655766
//
// It is the same criterion.
// Now for cases where moduli are known to yield a Barrett reduction with a single step
// of reduction required (see blog post again) the conditions become:
// true_prod + p <= 2^32 - 1 <=>
// 2p - 1 <= 2^32 - 1 <=>
// p <= 2^31
let can_use_fast_reduction_code =
(modulus < 1431655766) || (requires_single_reduction_step && modulus <= (1 << 31));
Some(Self {
twid,
@@ -677,6 +755,7 @@ impl Plan {
inv_twid,
p: modulus,
p_div,
can_use_fast_reduction_code,
n_inv_mod_p,
n_inv_mod_p_shoup,
p_barrett,
@@ -701,6 +780,15 @@ impl Plan {
self.p
}
/// Returns whether the negacyclic NTT plan can use fast reduction code.
///
/// The flag is computed once when the plan is built: it is `true` when the modulus is
/// below 1431655766, or when the modulus is at most 2^31 and its Barrett reduction only
/// needs a single reduction step. `mul_assign_normalize`, `normalize` and `mul_accumulate`
/// branch on this flag.
///
/// To avoid correctness issues linked to overflows the code has to make performance sacrifices
/// for certain primes and will not yield the best performance possible for them.
#[inline]
pub fn can_use_fast_reduction_code(&self) -> bool {
self.can_use_fast_reduction_code
}
/// Applies a forward negacyclic NTT transform in place to the given buffer.
///
/// # Note
@@ -810,7 +898,7 @@ impl Plan {
/// Computes the elementwise product of `lhs` and `rhs`, multiplied by the inverse of the
/// polynomial modulo the NTT modulus, and stores the result in `lhs`.
pub fn mul_assign_normalize(&self, lhs: &mut [u32], rhs: &[u32]) {
if self.p < (1 << 31) {
if self.can_use_fast_reduction_code {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4::try_new() {
@@ -866,7 +954,7 @@ impl Plan {
/// Multiplies the values by the inverse of the polynomial modulo the NTT modulus, and stores
/// the result in `values`.
pub fn normalize(&self, values: &mut [u32]) {
if self.p < (1 << 31) {
if self.can_use_fast_reduction_code {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4::try_new() {
@@ -903,7 +991,7 @@ impl Plan {
/// Computes the elementwise product of `lhs` and `rhs` and accumulates the result to `acc`.
pub fn mul_accumulate(&self, acc: &mut [u32], lhs: &[u32], rhs: &[u32]) {
if self.p < (1 << 31) {
if self.can_use_fast_reduction_code {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4::try_new() {
@@ -1246,6 +1334,67 @@ pub mod tests {
assert_eq!(val, val_target);
}
}
/// Checks which moduli are classified as usable with the fast reduction code.
#[test]
fn test_plan_can_use_fast_reduction_code() {
    use crate::primes32::{P0, P1, P2, P3, P4, P5, P6, P7, P8, P9};

    const POLYNOMIAL_SIZE: usize = 32;

    // 1062862849 is below the 1431655766 bound.
    // 1431669377 is above it, but satisfies the single reduction condition for Barrett.
    // The library-provided primes are meant for performant code, so they must stay fast.
    let fast_primes = [
        1062862849, 1431669377, P0, P1, P2, P3, P4, P5, P6, P7, P8, P9,
    ];
    for prime in fast_primes {
        let fast_plan = Plan::try_new(POLYNOMIAL_SIZE, prime).unwrap();
        assert!(fast_plan.can_use_fast_reduction_code);
    }

    // This prime is above the bound and does not satisfy the single reduction condition
    // for Barrett, so it must be rejected by the fast path.
    let slow_plan = Plan::try_new(POLYNOMIAL_SIZE, 0x7fe0_1001).unwrap();
    assert!(!slow_plan.can_use_fast_reduction_code);
}
/// Non-regression test: `mul_accumulate` must match a plain 64-bit modular product for a
/// modulus/operand pair chosen to exercise the Barrett reduction overflow this change fixes.
#[test]
fn test_barret_invalid_reduction_non_regression() {
    const POLYNOMIAL_SIZE: usize = 32;

    let p: u32 = 0x7fe0_1001;
    let value: u32 = 0x6e63593a;
    let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();

    // Operand is [value, 0, 0, ...]
    let mut input = [0u32; POLYNOMIAL_SIZE];
    input[0] = value;

    let mut acc = [0u32; POLYNOMIAL_SIZE];
    plan.mul_accumulate(&mut acc, &input, &input);

    // Reference result computed in 64 bits, where the product cannot overflow.
    let squared = u64::from(value) * u64::from(value);
    let expected = (squared % u64::from(p)) as u32;
    assert_eq!(acc[0], expected);
}
/// Non-regression test: `mul_assign_normalize` must match a plain 128-bit modular product
/// (including the `n_inv_mod_p` factor) for a modulus/operand pair chosen to exercise the
/// Barrett reduction overflow this change fixes.
#[test]
fn test_barrett_invalid_reduction_normalize_non_regression() {
    const POLYNOMIAL_SIZE: usize = 32;

    let p: u32 = 0x7fe0_1001;
    let value: u32 = 0x6e63593a;
    let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();

    // Operand is [value, 0, 0, ...]
    let mut input = [0u32; POLYNOMIAL_SIZE];
    input[0] = value;

    // The left-hand side is overwritten with the normalized product.
    let mut acc = input;
    plan.mul_assign_normalize(&mut acc, &input);

    // Reference result computed in 128 bits, where the triple product cannot overflow.
    let squared = u128::from(value) * u128::from(value);
    let expected = (squared * u128::from(plan.n_inv_mod_p) % u128::from(p)) as u32;
    assert_eq!(acc[0], expected);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]