From 3dcdf9043a0758d0831565a04f26b3df032b7b3f Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 26 Nov 2025 16:49:37 +0100 Subject: [PATCH] chore: add arithmetic right shift for split u128 values - this applies a shift that extends the sign of the value interpreted as an i128, this will be required for the fix of the specialized decomp code for u128 --- .../fft_impl/fft128_u128/math/fft/mod.rs | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs index d393e7d10..4c4ee5c9b 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs @@ -26,6 +26,37 @@ pub fn zeroing_shr(x: u64, shift: u64) -> u64 { } } +#[inline(always)] +/// Return the arithmetic shift of the u128 value represented by lo and hi interpreted as a signed +/// value, as (res_lo, res_hi). +pub fn arithmetic_shr_split_u128(lo: u64, hi: u64, shift: u64) -> (u64, u64) { + /// Should behave like the following intel intrinsics + /// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_srai_epi64 + fn arithmetic_shr(x: u64, shift: u64) -> u64 { + let signed_x = x as i64; + if shift >= 64 { + if signed_x >= 0 { + 0 + } else { + // All ones as if the shift extended the sign + u64::MAX + } + } else { + (signed_x >> shift) as u64 + } + } + + // This will zero out or fill with 1s depending on the sign bit + let res_hi = arithmetic_shr(hi, shift); + let res_lo = if shift < 64 { + zeroing_shl(hi, 64 - shift) | zeroing_shr(lo, shift) + } else { + arithmetic_shr(hi, shift - 64) + }; + + (res_lo, res_hi) +} + #[inline(always)] pub fn u128_to_f64((lo, hi): (u64, u64)) -> f64 { const A: f64 = (1u128 << 52) as f64; @@ -1380,4 +1411,39 @@ mod tests { assert!(b.1.abs() <= 0.5 * ulp(b.0)); } + + #[test] + fn test_arihtmetic_shr_split_u128() { + use rand::prelude::*; + + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + let positive = rng.gen_range(0i128..=i128::MAX); + let negative = rng.gen_range(i128::MIN..0); + + for shift in 0..127 { + for case in [positive, negative] { + let case_lo = case as u64; + let case_hi = (case >> 64) as u64; + + let case_shifted = case >> shift; + let (res_lo, res_hi) = arithmetic_shr_split_u128(case_lo, case_hi, shift); + let res_as_u128 = (res_lo as u128) | ((res_hi as u128) << 64); + assert_eq!(res_as_u128, case_shifted as u128); + } + } + + // Shift hardcoded as 128 + for case in [positive, negative] { + let expected = if case > 0 { 0u128 } else { u128::MAX }; + + let case_lo = case as u64; + let case_hi = (case >> 64) as u64; + + let (res_lo, res_hi) = arithmetic_shr_split_u128(case_lo, case_hi, 128); + let res_as_u128 = (res_lo as u128) | ((res_hi as u128) << 64); + assert_eq!(res_as_u128, expected); + } + } + } }