mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 06:38:06 -05:00
chore: add arithmetic right shift for split u128 values
- this applies a shift that extends the sign of the value interpreted as an i128, this will be required for the fix of the specialized decomp code for u128
This commit is contained in:
committed by
IceTDrinker
parent
dfcceefa83
commit
3dcdf9043a
@@ -26,6 +26,37 @@ pub fn zeroing_shr(x: u64, shift: u64) -> u64 {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Return the arithmetic shift of the u128 value represented by lo and hi interpreted as a signed
|
||||
/// value, as (res_lo, res_hi).
|
||||
pub fn arithmetic_shr_split_u128(lo: u64, hi: u64, shift: u64) -> (u64, u64) {
|
||||
/// Should behave like the following intel intrinsics
|
||||
/// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_srai_epi64
|
||||
fn arithmetic_shr(x: u64, shift: u64) -> u64 {
|
||||
let signed_x = x as i64;
|
||||
if shift >= 64 {
|
||||
if signed_x >= 0 {
|
||||
0
|
||||
} else {
|
||||
// All ones as if the shift extended the sign
|
||||
u64::MAX
|
||||
}
|
||||
} else {
|
||||
(signed_x >> shift) as u64
|
||||
}
|
||||
}
|
||||
|
||||
// This will zero out or fill with 1s depending on the sign bit
|
||||
let res_hi = arithmetic_shr(hi, shift);
|
||||
let res_lo = if shift < 64 {
|
||||
zeroing_shl(hi, 64 - shift) | zeroing_shr(lo, shift)
|
||||
} else {
|
||||
arithmetic_shr(hi, shift - 64)
|
||||
};
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn u128_to_f64((lo, hi): (u64, u64)) -> f64 {
|
||||
const A: f64 = (1u128 << 52) as f64;
|
||||
@@ -1380,4 +1411,39 @@ mod tests {
|
||||
|
||||
assert!(b.1.abs() <= 0.5 * ulp(b.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arihtmetic_shr_split_u128() {
|
||||
use rand::prelude::*;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..1000 {
|
||||
let positive = rng.gen_range(0i128..=i128::MAX);
|
||||
let negative = rng.gen_range(i128::MIN..0);
|
||||
|
||||
for shift in 0..127 {
|
||||
for case in [positive, negative] {
|
||||
let case_lo = case as u64;
|
||||
let case_hi = (case >> 64) as u64;
|
||||
|
||||
let case_shifted = case >> shift;
|
||||
let (res_lo, res_hi) = arithmetic_shr_split_u128(case_lo, case_hi, shift);
|
||||
let res_as_u128 = (res_lo as u128) | ((res_hi as u128) << 64);
|
||||
assert_eq!(res_as_u128, case_shifted as u128);
|
||||
}
|
||||
}
|
||||
|
||||
// Shift hardcoded as 128
|
||||
for case in [positive, negative] {
|
||||
let expected = if case > 0 { 0u128 } else { u128::MAX };
|
||||
|
||||
let case_lo = case as u64;
|
||||
let case_hi = (case >> 64) as u64;
|
||||
|
||||
let (res_lo, res_hi) = arithmetic_shr_split_u128(case_lo, case_hi, 128);
|
||||
let res_as_u128 = (res_lo as u128) | ((res_hi as u128) << 64);
|
||||
assert_eq!(res_as_u128, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user