mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix: make sure computations don't overflow for certain primes for 32 bits
- The original code seemed to assume that the Barrett reduction would not overflow if p <= 2^31; this is incorrect, although failures are rare. - The correctness constraint has a bound much smaller than 2^31. Some primes bigger than the derived threshold can still use the fast code when a certain criterion is respected, which corresponds to a "lucky" case of the Barrett reduction. The new code now handles this case; the maths is explained in https://blog.zksecurity.xyz/posts/barrett-tighter-bound/ and copiously in comments in the code.
This commit is contained in:
@@ -597,6 +597,36 @@ fn mul_accumulate_scalar(
|
||||
}
|
||||
}
|
||||
|
||||
/// Constants derived from a 32-bit modulus for the Barrett reduction.
///
/// See https://blog.zksecurity.xyz/posts/barrett-tighter-bound/ for the analysis the
/// single-step criterion is derived from.
struct BarrettInit32 {
    /// Bit length of the modulus, i.e. `modulus.ilog2() + 1`.
    big_q: u32,
    /// `2^(big_q + 31) / modulus` (integer division), truncated to 32 bits.
    p_barrett: u32,
    /// `true` when the modulus only ever triggers a single Barrett reduction step with our
    /// implementation.
    requires_single_reduction_step: bool,
}

impl BarrettInit32 {
    /// Computes the Barrett constants for `modulus`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero, as `ilog2` is undefined for zero.
    pub fn new(modulus: u32) -> Self {
        let bit_len = modulus.ilog2() + 1;
        let modulus_wide = u64::from(modulus);

        // Equivalent to 2^{2k} from the zk security blog.
        let numerator = 1u64 << (bit_len + 31);
        let quotient = numerator / modulus_wide;
        let remainder = numerator % modulus_wide;

        // A prime needing two reduction steps can hit cases where it is not possible to decide
        // whether a reduction is required, yielding wrong results. A single step is guaranteed
        // when the remainder does not exceed `modulus - 2^(bit_len - 1)`.
        // Formula derived with https://blog.zksecurity.xyz/posts/barrett-tighter-bound/
        let highest_bit = 1u64 << (bit_len - 1);

        Self {
            big_q: bit_len,
            p_barrett: quotient as u32,
            requires_single_reduction_step: remainder <= modulus_wide - highest_bit,
        }
    }
}
|
||||
|
||||
/// Negacyclic NTT plan for 32bit primes.
|
||||
#[derive(Clone)]
|
||||
pub struct Plan {
|
||||
@@ -606,6 +636,8 @@ pub struct Plan {
|
||||
inv_twid_shoup: ABox<[u32]>,
|
||||
p: u32,
|
||||
p_div: Div32,
|
||||
/// If true can use non generic code and optimized reduction algorithms.
|
||||
can_use_fast_reduction_code: bool,
|
||||
|
||||
// used for elementwise product
|
||||
p_barrett: u32,
|
||||
@@ -666,9 +698,55 @@ impl Plan {
|
||||
|
||||
let n_inv_mod_p = crate::prime::exp_mod32(p_div, polynomial_size as u32, modulus - 2);
|
||||
let n_inv_mod_p_shoup = (((n_inv_mod_p as u64) << 32) / modulus as u64) as u32;
|
||||
let big_q = modulus.ilog2() + 1;
|
||||
let big_l = big_q + 31;
|
||||
let p_barrett = ((1u64 << big_l) / modulus as u64) as u32;
|
||||
|
||||
let BarrettInit32 {
|
||||
big_q,
|
||||
p_barrett,
|
||||
requires_single_reduction_step,
|
||||
} = BarrettInit32::new(modulus);
|
||||
|
||||
// The mul_accumulate_scalar code (and derived SIMD code) can have an overflow issue due
|
||||
// to how Barrett reduction works. It can also overflow during the following
|
||||
// accumulations. We'll derive the requirements for the mul_accumulate_scalar code to be
|
||||
// correct and use that to determine whether a given prime is a "fast" prime for our
|
||||
// code.
|
||||
//
|
||||
// You can check https://blog.zksecurity.xyz/posts/barrett-tighter-bound/ for a readable
|
||||
// Barrett reduction analysis/breakdown.
|
||||
//
|
||||
// Barrett reduction computes an approximation q_approx of the true quotient q needed
|
||||
// for the reduction.
|
||||
// q_approx can be smaller than q by at most 2: q_approx in {q - 2, q - 1, q}. Since we
|
||||
// subtract q_approx * p from the value to reduce to start the reduction we have:
|
||||
// prod = to_reduce - q_approx * p
|
||||
// prod = true_prod + c * p with c in {0, 1, 2}
|
||||
// We need to make sure that prod does not overflow 32 bits
|
||||
// We have true_prod which is reduced mod p, so true_prod <= p - 1
|
||||
// prod <= 2^32 - 1 <=>
|
||||
// true_prod + 2 * p <= 2^32 - 1 <=>
|
||||
// 3 * p - 1 <= 2^32 - 1 <=>
|
||||
// p <= 2^32 / 3 <= 1431655765.3333333 < 1431655766
|
||||
//
|
||||
// After the computation of prod a first reduction step is performed, meaning we now
|
||||
// have:
|
||||
// prod = true_prod + c * p with c in {0, 1}
|
||||
// We are accumulating in acc which is already reduced, so acc <= p - 1
|
||||
// The accumulation yields:
|
||||
// We need acc + prod <= 2^32 - 1 <=>
|
||||
// (p - 1) + (p - 1) + p <= 2^32 - 1 <=>
|
||||
// 3p - 2 <= 2^32 - 1 <=>
|
||||
// 3p <= 2^32 + 1 <=>
|
||||
// p <= (2^32 + 1) / 3 <= 1431655765.6666667 < 1431655766
|
||||
//
|
||||
// It is the same criterion.
|
||||
// Now for cases where moduli are known to yield a Barrett reduction with a single step
|
||||
// of reduction required (see blog post again) the conditions become:
|
||||
// true_prod + p <= 2^32 - 1 <=>
|
||||
// 2p - 1 <= 2^32 - 1 <=>
|
||||
// p <= 2^31
|
||||
|
||||
let can_use_fast_reduction_code =
|
||||
(modulus < 1431655766) || (requires_single_reduction_step && modulus <= (1 << 31));
|
||||
|
||||
Some(Self {
|
||||
twid,
|
||||
@@ -677,6 +755,7 @@ impl Plan {
|
||||
inv_twid,
|
||||
p: modulus,
|
||||
p_div,
|
||||
can_use_fast_reduction_code,
|
||||
n_inv_mod_p,
|
||||
n_inv_mod_p_shoup,
|
||||
p_barrett,
|
||||
@@ -701,6 +780,15 @@ impl Plan {
|
||||
self.p
|
||||
}
|
||||
|
||||
/// Returns whether the negacyclic NTT plan can use fast reduction code.
|
||||
///
|
||||
/// To avoid correctness issues linked to overflows the code has to make performance sacrifices
|
||||
/// for certain primes and will not yield the best performance possible for them.
|
||||
#[inline]
|
||||
pub fn can_use_fast_reduction_code(&self) -> bool {
|
||||
self.can_use_fast_reduction_code
|
||||
}
|
||||
|
||||
/// Applies a forward negacyclic NTT transform in place to the given buffer.
|
||||
///
|
||||
/// # Note
|
||||
@@ -810,7 +898,7 @@ impl Plan {
|
||||
/// Computes the elementwise product of `lhs` and `rhs`, multiplied by the inverse of the
|
||||
/// polynomial modulo the NTT modulus, and stores the result in `lhs`.
|
||||
pub fn mul_assign_normalize(&self, lhs: &mut [u32], rhs: &[u32]) {
|
||||
if self.p < (1 << 31) {
|
||||
if self.can_use_fast_reduction_code {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
if let Some(simd) = crate::V4::try_new() {
|
||||
@@ -866,7 +954,7 @@ impl Plan {
|
||||
/// Multiplies the values by the inverse of the polynomial modulo the NTT modulus, and stores
|
||||
/// the result in `values`.
|
||||
pub fn normalize(&self, values: &mut [u32]) {
|
||||
if self.p < (1 << 31) {
|
||||
if self.can_use_fast_reduction_code {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
if let Some(simd) = crate::V4::try_new() {
|
||||
@@ -903,7 +991,7 @@ impl Plan {
|
||||
|
||||
/// Computes the elementwise product of `lhs` and `rhs` and accumulates the result to `acc`.
|
||||
pub fn mul_accumulate(&self, acc: &mut [u32], lhs: &[u32], rhs: &[u32]) {
|
||||
if self.p < (1 << 31) {
|
||||
if self.can_use_fast_reduction_code {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
if let Some(simd) = crate::V4::try_new() {
|
||||
@@ -1246,6 +1334,67 @@ pub mod tests {
|
||||
assert_eq!(val, val_target);
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks which primes a `Plan` flags as able to use the fast reduction code.
#[test]
fn test_plan_can_use_fast_reduction_code() {
    use crate::primes32::{P0, P1, P2, P3, P4, P5, P6, P7, P8, P9};

    const POLYNOMIAL_SIZE: usize = 32;

    // First prime is smaller than 1431655766
    // Second is larger, but satisfies the single reduction condition for Barrett
    // The other ones can be used for performant code, we want those to be fast
    let fast_primes = [
        1062862849, 1431669377, P0, P1, P2, P3, P4, P5, P6, P7, P8, P9,
    ];
    for p in fast_primes {
        let plan = Plan::try_new(POLYNOMIAL_SIZE, p).unwrap();

        // Name the prime in the message: the loop covers many values and a bare assertion
        // would not tell which one failed.
        assert!(
            plan.can_use_fast_reduction_code,
            "prime {} should be able to use the fast reduction code",
            p
        );
    }

    // Prime is bigger than threshold and does not satisfy the single reduction condition for
    // Barrett
    let slow_prime = 0x7fe0_1001;
    let plan = Plan::try_new(POLYNOMIAL_SIZE, slow_prime).unwrap();
    assert!(
        !plan.can_use_fast_reduction_code,
        "prime {:#x} should not be able to use the fast reduction code",
        slow_prime
    );
}
|
||||
|
||||
/// Non regression test: `mul_accumulate` must return the exact modular product for a prime
/// and operand pair tied to an invalid Barrett reduction.
#[test]
fn test_barret_invalid_reduction_non_regression() {
    const POLYNOMIAL_SIZE: usize = 32;

    let modulus: u32 = 0x7fe0_1001;
    let operand: u32 = 0x6e63593a;
    let plan = Plan::try_new(POLYNOMIAL_SIZE, modulus).unwrap();

    // Essentially = [operand, 0, 0, ...]
    let mut input = [0u32; POLYNOMIAL_SIZE];
    input[0] = operand;

    let mut acc = [0u32; POLYNOMIAL_SIZE];
    plan.mul_accumulate(&mut acc, &input, &input);

    // Reference result computed in 64 bits where the product cannot overflow.
    let wide = u64::from(operand);
    let expected = (wide * wide % u64::from(modulus)) as u32;
    assert_eq!(acc[0], expected);
}
|
||||
|
||||
/// Non regression test: `mul_assign_normalize` must return the exact normalized modular
/// product for a prime and operand pair tied to an invalid Barrett reduction.
#[test]
fn test_barrett_invalid_reduction_normalize_non_regression() {
    const POLYNOMIAL_SIZE: usize = 32;

    let modulus: u32 = 0x7fe0_1001;
    let operand: u32 = 0x6e63593a;
    let plan = Plan::try_new(POLYNOMIAL_SIZE, modulus).unwrap();

    // Essentially = [operand, 0, 0, ...]
    let mut input = [0u32; POLYNOMIAL_SIZE];
    input[0] = operand;

    let mut acc = input;
    plan.mul_assign_normalize(&mut acc, &input);

    // Reference result computed in 128 bits where the triple product cannot overflow.
    let wide = u128::from(operand);
    let expected =
        (wide * wide * u128::from(plan.n_inv_mod_p) % u128::from(modulus)) as u32;
    assert_eq!(acc[0], expected);
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
|
||||
Reference in New Issue
Block a user