From a9d213742467296ccfa193ad0fe8188c540ee279 Mon Sep 17 00:00:00 2001
From: Arthur Meyre
Date: Wed, 26 Nov 2025 17:52:31 +0100
Subject: [PATCH] chore: add test and fix for specialized decomposition code

- add a dedicated u128 edge case test
- fix all implementations that can be selected at runtime and run each of
  them against the same u128 edge case
---
 .../fft_impl/fft128_u128/crypto/ggsw.rs       |  43 ++++--
 .../fft_impl/fft128_u128/crypto/tests.rs      | 145 ++++++++++++++++++
 .../fft_impl/fft128_u128/math/fft/mod.rs      |  91 +++++++++++
 3 files changed, 262 insertions(+), 17 deletions(-)

diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs
index b8527185a..0f56d0fe5 100644
--- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs
+++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs
@@ -1,4 +1,6 @@
-use super::super::math::fft::{wrapping_add, wrapping_sub, zeroing_shl, zeroing_shr, Fft128View};
+use super::super::math::fft::{
+    arithmetic_shr_split_u128, wrapping_add, wrapping_sub, zeroing_shl, zeroing_shr, Fft128View,
+};
 use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
 use crate::core_crypto::commons::traits::container::Split;
 use crate::core_crypto::commons::traits::contiguous_entity_container::{
@@ -235,7 +237,7 @@ pub fn add_external_product_assign_split(
 }
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[cfg(feature = "avx512")]
-fn collect_next_term_split_avx512(
+pub(crate) fn collect_next_term_split_avx512(
     simd: pulp::x86::V4,
     glwe_decomp_term_lo: &mut [u64],
@@ -295,6 +297,7 @@
         let mod_b_mask_lo = simd.splat_u64x8(mod_b_mask_lo);
         let mod_b_mask_hi = simd.splat_u64x8(mod_b_mask_hi);
 
+        let base_log_gt_64 = pulp::b8(if base_log > 64 { u8::MAX } else { 0 });
         let shift_minus_64 = simd.splat_u64x8(shift.wrapping_sub(64));
         let shift_complement = simd.splat_u64x8(64u64.wrapping_sub(shift));
         let shift = simd.splat_u64x8(shift);
@@ -305,14 +308,15 @@
         let res_lo = simd.and_u64x8(vstate_lo, mod_b_mask_lo);
         let res_hi = simd.and_u64x8(vstate_hi, mod_b_mask_hi);
 
-        vstate_lo = simd.or_u64x8(
-            simd.shr_dyn_u64x8(vstate_hi, base_log_minus_64),
+        vstate_lo = simd.select_u64x8(
+            base_log_gt_64,
+            pulp::cast(simd.shr_dyn_i64x8(pulp::cast(vstate_hi), base_log_minus_64)),
             simd.or_u64x8(
                 simd.shl_dyn_u64x8(vstate_hi, base_log_complement),
                 simd.shr_dyn_u64x8(vstate_lo, base_log),
             ),
         );
-        vstate_hi = simd.shr_dyn_u64x8(vstate_hi, base_log);
+        vstate_hi = pulp::cast(simd.shr_dyn_i64x8(pulp::cast(vstate_hi), base_log));
 
         let res_sub1_lo = simd.wrapping_sub_u64x8(res_lo, simd.splat_u64x8(1));
         let overflow =
@@ -367,7 +371,7 @@
 }
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-fn collect_next_term_split_avx2(
+pub(crate) fn collect_next_term_split_avx2(
     simd: pulp::x86::V3,
     glwe_decomp_term_lo: &mut [u64],
     glwe_decomp_term_hi: &mut [u64],
@@ -392,7 +396,9 @@ fn collect_next_term_split_avx2(
 
     #[inline(always)]
     fn call(self) -> Self::Output {
-        use super::super::math::fft::{wrapping_add_avx2, wrapping_sub_avx2};
+        use super::super::math::fft::{
+            mm256_sra_epi64_avx2, wrapping_add_avx2, wrapping_sub_avx2,
+        };
 
         let Self {
             simd,
             glwe_decomp_term_lo,
             glwe_decomp_term_hi,
             decomposition_states_lo,
             decomposition_states_hi,
             mod_b_mask_lo,
             mod_b_mask_hi,
             base_log,
         } = self;
@@ -418,6 +424,7 @@
         let mod_b_mask_lo = simd.splat_u64x4(mod_b_mask_lo);
         let mod_b_mask_hi = simd.splat_u64x4(mod_b_mask_hi);
 
+        let base_log_gt_64 = simd.splat_u64x4(if base_log > 64 { u64::MAX } else { 0 });
         let shift_minus_64 = simd.splat_u64x4(shift.wrapping_sub(64));
         let shift_complement = simd.splat_u64x4(64u64.wrapping_sub(shift));
         let shift = simd.splat_u64x4(shift);
@@ -438,13 +445,20 @@
         let res_hi = simd.and_u64x4(vstate_hi, mod_b_mask_hi);
 
         vstate_lo = simd.or_u64x4(
-            simd.shr_dyn_u64x4(vstate_hi, base_log_minus_64),
+            simd.and_u64x4(
+                base_log_gt_64,
+                pulp::cast(mm256_sra_epi64_avx2(
+                    simd,
+                    pulp::cast(vstate_hi),
+                    base_log_minus_64,
+                )),
+            ),
             simd.or_u64x4(
                 simd.shl_dyn_u64x4(vstate_hi, base_log_complement),
                 simd.shr_dyn_u64x4(vstate_lo, base_log),
             ),
         );
-        vstate_hi = simd.shr_dyn_u64x4(vstate_hi, base_log);
+        vstate_hi = pulp::cast(mm256_sra_epi64_avx2(simd, pulp::cast(vstate_hi), base_log));
 
         let res_sub1_lo = simd.wrapping_sub_u64x4(res_lo, simd.splat_u64x4(1));
         let overflow = pulp::cast(simd.cmp_eq_u64x4(res_lo, simd.splat_u64x4(0)));
@@ -497,7 +511,7 @@
     });
 }
 
-fn collect_next_term_split_scalar(
+pub(crate) fn collect_next_term_split_scalar(
     glwe_decomp_term_lo: &mut [u64],
     glwe_decomp_term_hi: &mut [u64],
     decomposition_states_lo: &mut [u64],
@@ -518,13 +532,8 @@
         let res_hi = *state_hi & mod_b_mask_hi;
 
         let base_log = base_log as u64;
-        if base_log < 64 {
-            *state_lo = zeroing_shl(*state_hi, 64 - base_log) | zeroing_shr(*state_lo, base_log);
-            *state_hi = zeroing_shr(*state_hi, base_log);
-        } else {
-            *state_lo = zeroing_shr(*state_hi, base_log - 64);
-            *state_hi = 0;
-        }
+        (*state_lo, *state_hi) = arithmetic_shr_split_u128(*state_lo, *state_hi, base_log);
+
         let (res_sub1_lo, overflow) = res_lo.overflowing_sub(1);
         let res_sub1_hi = res_hi.wrapping_sub(overflow as u64);
 
diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs
index 00797d59f..d71adfd9b 100644
--- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs
+++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs
@@ -1,5 +1,6 @@
 use super::super::super::{fft128, fft128_u128};
 use super::super::math::fft::Fft128View;
+use super::ggsw::collect_next_term_split_scalar;
 use crate::core_crypto::fft_impl::common::tests::{
     gen_keys_or_get_from_cache_if_enabled, generate_keys,
 };
@@ -254,3 +255,147 @@ fn test_split_pbs() {
         assert_eq!(lwe_out_split, lwe_out_non_split);
     }
 }
+
+#[test]
+fn test_decomposition_edge_case_sign_handling_split_u128() {
+    let decomposer = SignedDecomposer::new(DecompositionBaseLog(40), DecompositionLevelCount(3));
+    // This value triggers a negative state at the start of the decomposition. Invalid code
+    // using a logical shift will wrongly compute an intermediate value because it does not
+    // keep the sign of the state on the last level: if base_log * (level_count + 1) >
+    // Scalar::BITS, the logical shift shifts in 0s instead of the 1s needed to keep the sign
+    // information.
+    let val = 170141183460604905165246226680529368983u128;
+    let base_log = decomposer.base_log;
+
+    let expected = [-421613125320i128, 482008863255, -549755813888];
+
+    let decomp_state = decomposer.init_decomposer_state(val);
+    let mut decomp_state_lo = decomp_state as u64;
+    let mut decomp_state_hi = (decomp_state >> 64) as u64;
+
+    let mod_b_mask = (1u128 << decomposer.base_log) - 1;
+    let mod_b_mask_lo = mod_b_mask as u64;
+    let mod_b_mask_hi = (mod_b_mask >> 64) as u64;
+
+    for expect in expected {
+        let mut decomp_term_lo = 0u64;
+        let mut decomp_term_hi = 0u64;
+
+        collect_next_term_split_scalar(
+            core::slice::from_mut(&mut decomp_term_lo),
+            core::slice::from_mut(&mut decomp_term_hi),
+            core::slice::from_mut(&mut decomp_state_lo),
+            core::slice::from_mut(&mut decomp_state_hi),
+            mod_b_mask_lo,
+            mod_b_mask_hi,
+            base_log,
+        );
+
+        let term_value_u128 = ((decomp_term_hi as u128) << 64) | decomp_term_lo as u128;
+        let term_value_i128 = term_value_u128 as i128;
+
+        assert_eq!(term_value_i128, expect);
+    }
+}
+
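The scalar path above now delegates to arithmetic_shr_split_u128, whose body is not part of this diff; only its signature can be read off the call site. A minimal sketch of the semantics the caller relies on, assuming the (lo, hi) pair is the split two's complement 128-bit state (the function name and the clamp are illustrative, not the library's code):

    // Sketch: arithmetic shift right of a 128-bit value split into two
    // 64-bit halves. The value is reinterpreted as signed so the shift
    // replicates the sign bit instead of filling with zeros.
    fn arithmetic_shr_split_u128_sketch(lo: u64, hi: u64, shift: u64) -> (u64, u64) {
        let value = (((hi as u128) << 64) | lo as u128) as i128;
        // `>>` on i128 is an arithmetic shift; clamp to avoid shift overflow
        let shifted = value >> shift.min(127);
        (shifted as u64, ((shifted as u128) >> 64) as u64)
    }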
+#[test]
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+fn test_decomposition_edge_case_sign_handling_split_u128_avx2() {
+    use super::ggsw::collect_next_term_split_avx2;
+
+    let Some(simd) = pulp::x86::V3::try_new() else {
+        return;
+    };
+
+    let decomposer = SignedDecomposer::new(DecompositionBaseLog(40), DecompositionLevelCount(3));
+    // This value triggers a negative state at the start of the decomposition. Invalid code
+    // using a logical shift will wrongly compute an intermediate value because it does not
+    // keep the sign of the state on the last level: if base_log * (level_count + 1) >
+    // Scalar::BITS, the logical shift shifts in 0s instead of the 1s needed to keep the sign
+    // information.
+    let val = 170141183460604905165246226680529368983u128;
+    let base_log = decomposer.base_log;
+
+    let expected = [-421613125320i128, 482008863255, -549755813888];
+
+    let decomp_state = decomposer.init_decomposer_state(val);
+    let mut decomp_state_lo = [decomp_state as u64; 4];
+    let mut decomp_state_hi = [(decomp_state >> 64) as u64; 4];
+
+    let mod_b_mask = (1u128 << decomposer.base_log) - 1;
+    let mod_b_mask_lo = mod_b_mask as u64;
+    let mod_b_mask_hi = (mod_b_mask >> 64) as u64;
+
+    for expect in expected {
+        let mut decomp_term_lo = [0u64; 4];
+        let mut decomp_term_hi = [0u64; 4];
+
+        collect_next_term_split_avx2(
+            simd,
+            &mut decomp_term_lo,
+            &mut decomp_term_hi,
+            &mut decomp_state_lo,
+            &mut decomp_state_hi,
+            mod_b_mask_lo,
+            mod_b_mask_hi,
+            base_log,
+        );
+
+        for (decomp_term_hi, decomp_term_lo) in decomp_term_hi.into_iter().zip(decomp_term_lo) {
+            let term_value_u128 = ((decomp_term_hi as u128) << 64) | decomp_term_lo as u128;
+            let term_value_i128 = term_value_u128 as i128;
+
+            assert_eq!(term_value_i128, expect);
+        }
+    }
+}
+
+#[test]
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[cfg(feature = "avx512")]
+fn test_decomposition_edge_case_sign_handling_split_u128_avx512() {
+    use super::ggsw::collect_next_term_split_avx512;
+
+    let Some(simd) = pulp::x86::V4::try_new() else {
+        return;
+    };
+
+    let decomposer = SignedDecomposer::new(DecompositionBaseLog(40), DecompositionLevelCount(3));
+    // This value triggers a negative state at the start of the decomposition. Invalid code
+    // using a logical shift will wrongly compute an intermediate value because it does not
+    // keep the sign of the state on the last level: if base_log * (level_count + 1) >
+    // Scalar::BITS, the logical shift shifts in 0s instead of the 1s needed to keep the sign
+    // information.
+    let val = 170141183460604905165246226680529368983u128;
+    let base_log = decomposer.base_log;
+
+    let expected = [-421613125320i128, 482008863255, -549755813888];
+
+    let decomp_state = decomposer.init_decomposer_state(val);
+    let mut decomp_state_lo = [decomp_state as u64; 8];
+    let mut decomp_state_hi = [(decomp_state >> 64) as u64; 8];
+
+    let mod_b_mask = (1u128 << decomposer.base_log) - 1;
+    let mod_b_mask_lo = mod_b_mask as u64;
+    let mod_b_mask_hi = (mod_b_mask >> 64) as u64;
+
+    for expect in expected {
+        let mut decomp_term_lo = [0u64; 8];
+        let mut decomp_term_hi = [0u64; 8];
+
+        collect_next_term_split_avx512(
+            simd,
+            &mut decomp_term_lo,
+            &mut decomp_term_hi,
+            &mut decomp_state_lo,
+            &mut decomp_state_hi,
+            mod_b_mask_lo,
+            mod_b_mask_hi,
+            base_log,
+        );
+
+        for (decomp_term_hi, decomp_term_lo) in decomp_term_hi.into_iter().zip(decomp_term_lo) {
+            let term_value_u128 = ((decomp_term_hi as u128) << 64) | decomp_term_lo as u128;
+            let term_value_i128 = term_value_u128 as i128;
+
+            assert_eq!(term_value_i128, expect);
+        }
+    }
+}
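All three tests above pin the same sign-handling behavior across the scalar, AVX2 and AVX-512 paths. The failure mode they guard against can be reproduced with plain integer code; this standalone sketch (values illustrative, not taken from the tests) shows how a logical shift diverges from an arithmetic shift once the shift amount reaches the sign-extension region of a negative state:

    fn main() {
        // A small negative state: in two's complement, all high bits are 1s
        let state: i128 = -0x0123_4567_89ab_cdef;
        // With base_log = 40 and level_count = 3, 120 bits of state get
        // consumed, so the bits shifted in from the top become observable
        let logical = ((state as u128) >> 120) as i128; // shifts in 0s, drops the sign
        let arithmetic = state >> 120; // shifts in 1s, keeps the sign
        assert_eq!(logical, 0xff); // looks like a positive value: wrong
        assert_eq!(arithmetic, -1); // still negative: correct
    }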
diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs
index 4c4ee5c9b..dd3aa092b 100644
--- a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs
+++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs
@@ -1373,6 +1373,56 @@ impl Fft128View<'_> {
     }
 }
 
+/// Workaround implementation of the arithmetic shift right on 64-bit integers for AVX2
+#[inline(always)]
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+pub(crate) fn mm256_sra_epi64_avx2(
+    simd: pulp::x86::V3,
+    input: pulp::i64x4,
+    shift: pulp::u64x4,
+) -> pulp::i64x4 {
+    struct Sra64 {
+        simd: pulp::x86::V3,
+        input: pulp::i64x4,
+        shift: pulp::u64x4,
+    }
+
+    impl pulp::NullaryFnOnce for Sra64 {
+        type Output = pulp::i64x4;
+
+        fn call(self) -> Self::Output {
+            let Self { simd, input, shift } = self;
+
+            // Proposed algorithm:
+            // take the top bit (sign)
+            // turn it into a mask m (0 -> 0, 1 -> u64::MAX)
+            // compute x = a ^ m
+            // compute y = shr_logical(x, b)
+            // compute result = y ^ m
+
+            let zero = simd.splat_u64x4(0);
+
+            let input = pulp::cast(input);
+            // Get the MSB, giving the sign
+            let sign = simd.shr_const_u64x4::<63>(input);
+            // 0 if input >= 0, else -1 == 0xFFFF_FFFF_FFFF_FFFF == u64::MAX
+            let sign_mask = simd.wrapping_sub_u64x4(zero, sign);
+            // If sign_mask == 0 the value stays the same,
+            // else all bits are inverted; the nice property is that a top bit of 1 becomes
+            // a 0 (see the shr comment below as to why that is what we want)
+            let masked_input = simd.xor_u64x4(input, sign_mask);
+            // If sign_mask == 0, we are shifting in 0s and the last inversion won't change
+            // anything, so it works as expected
+            // If sign_mask == u64::MAX then the 0s we shift in are inverted at the last
+            // step, so the logical shift acts as an arithmetic shift
+            let shifted = simd.shr_dyn_u64x4(masked_input, shift);
+            pulp::cast(simd.xor_u64x4(shifted, sign_mask))
+        }
+    }
+
+    simd.vectorize(Sra64 { simd, input, shift })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1446,4 +1496,45 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+    fn test_sra_avx2() {
+        use rand::prelude::*;
+
+        let Some(simd) = pulp::x86::V3::try_new() else {
+            return;
+        };
+
+        let mut rng = rand::thread_rng();
+        for _ in 0..1000 {
+            for shift in 0..64 {
+                let shift = [shift as u64; 4];
+                for range in [0i64..=i64::MAX, i64::MIN..=-1] {
+                    let input: [i64; 4] = core::array::from_fn(|_| rng.gen_range(range.clone()));
+
+                    let res_i64 = mm256_sra_epi64_avx2(simd, pulp::cast(input), pulp::cast(shift));
+                    let res_as_array: [i64; 4] = pulp::cast(res_i64);
+
+                    let expected: [i64; 4] = core::array::from_fn(|idx| input[idx] >> shift[idx]);
+
+                    assert_eq!(res_as_array, expected);
+                }
+            }
+
+            // Shift hardcoded as 64: only the sign information remains
+            for range in [0i64..=i64::MAX, i64::MIN..=-1] {
+                let shift = [64u64; 4];
+                let input: [i64; 4] = core::array::from_fn(|_| rng.gen_range(range.clone()));
+
+                let res_i64 = mm256_sra_epi64_avx2(simd, pulp::cast(input), pulp::cast(shift));
+                let res_as_array: [i64; 4] = pulp::cast(res_i64);
+
+                let expected: [i64; 4] =
+                    core::array::from_fn(|idx| if input[idx] >= 0 { 0 } else { -1 });
+
+                assert_eq!(res_as_array, expected);
+            }
+        }
+    }
 }
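The xor/shift/xor identity behind mm256_sra_epi64_avx2 is easy to sanity-check lane-wise in scalar code; a minimal sketch (function names hypothetical; the min(63) clamp stands in for the vector shift's zeroing behavior at shifts >= 64, which yields the same result after the final xor):

    // Emulate a 64-bit arithmetic shift right using only logical operations
    fn sra_u64_sketch(a: u64, b: u64) -> u64 {
        let sign = a >> 63; // 0 for non-negative, 1 for negative
        let mask = 0u64.wrapping_sub(sign); // 0 or u64::MAX
        // Invert negative inputs so the top bit is 0, shift in 0s, invert back:
        // the final xor turns the shifted-in 0s into 1s exactly when a is negative
        ((a ^ mask) >> b.min(63)) ^ mask
    }

    fn main() {
        for a in [0u64, 1, 12345, u64::MAX, 1 << 63, (-9876i64) as u64] {
            for b in 0..64u64 {
                assert_eq!(sra_u64_sketch(a, b) as i64, (a as i64) >> b);
            }
        }
    }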