mirror of https://github.com/zama-ai/tfhe-rs.git
refactor(core): use avx512 intrinsics when available for data conversions
- we use inline assembly for now, as rust does not expose these intrinsics in the std or core arch crates at the moment
- add tests for the avx512 conversions
@@ -81,57 +81,62 @@ pub fn mm256_cvtpd_epi64(simd: V3, x: __m256d) -> __m256i {
 }
 
 /// Convert a vector of f64 values to a vector of i64 values.
 /// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version.
+/// This intrinsic is currently not available in rust, so we have our own implementation using
+/// inline assembly.
+///
+/// The name matches Intel's convention (re-used by rust in their intrinsics) without the leading
+/// `_`.
+///
+/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtt_roundpd_epi64`)
 #[cfg(feature = "nightly-avx512")]
 #[inline(always)]
-pub fn mm512_cvtpd_epi64(simd: V4, x: __m512d) -> __m512i {
-    let avx = simd.avx512f;
+pub fn mm512_cvtt_roundpd_epi64(simd: V4, x: __m512d) -> __m512i {
+    // This first one is required for the zmm_reg notation
+    #[inline]
+    #[target_feature(enable = "sse")]
+    #[target_feature(enable = "sse2")]
+    #[target_feature(enable = "fxsr")]
+    #[target_feature(enable = "sse3")]
+    #[target_feature(enable = "ssse3")]
+    #[target_feature(enable = "sse4.1")]
+    #[target_feature(enable = "sse4.2")]
+    #[target_feature(enable = "popcnt")]
+    #[target_feature(enable = "avx")]
+    #[target_feature(enable = "avx2")]
+    #[target_feature(enable = "bmi1")]
+    #[target_feature(enable = "bmi2")]
+    #[target_feature(enable = "fma")]
+    #[target_feature(enable = "lzcnt")]
+    #[target_feature(enable = "avx512f")]
+    #[target_feature(enable = "avx512dq")]
+    unsafe fn implementation(x: __m512d) -> __m512i {
+        let mut as_i64x8: __m512i;
 
-    // reinterpret the bits as u64 values
-    let bits = avx._mm512_castpd_si512(x);
-    // mask that covers the first 52 bits
-    let mantissa_mask = avx._mm512_set1_epi64(0xFFFFFFFFFFFFF_u64 as i64);
-    // mask that covers the 53rd bit
-    let explicit_mantissa_bit = avx._mm512_set1_epi64(0x10000000000000_u64 as i64);
-    // mask that covers the first 11 bits
-    let exp_mask = avx._mm512_set1_epi64(0x7FF_u64 as i64);
+        // From Intel's documentation, the syntax to use this intrinsic is
+        // Instruction: vcvttpd2qq zmm, zmm
+        // With Intel syntax, the left operand is the destination and the right operand is the source
+        // For the asm! macro
+        // in: indicates an input register
+        // out: indicates an output register
+        // zmm_reg: the avx512 register type
+        // options: see https://doc.rust-lang.org/nightly/reference/inline-assembly.html#options
+        // pure: no side effects
+        // nomem: does not reference RAM (only registers)
+        // nostack: does not alter the state of the stack
+        core::arch::asm!(
+            "vcvttpd2qq {dst}, {src}",
+            src = in(zmm_reg) x,
+            dst = out(zmm_reg) as_i64x8,
+            options(pure, nomem, nostack)
+        );
 
-    // extract the first 52 bits and add the implicit bit
-    let mantissa = avx._mm512_or_si512(
-        avx._mm512_and_si512(bits, mantissa_mask),
-        explicit_mantissa_bit,
-    );
+        as_i64x8
+    }
+    let _ = simd.avx512dq;
 
-    // extract the 52nd to 63rd (excluded) bits for the biased exponent
-    let biased_exp = avx._mm512_and_si512(avx._mm512_srli_epi64::<52>(bits), exp_mask);
-
-    // extract the 63rd sign bit
-    let sign_is_negative_mask =
-        avx._mm512_cmpeq_epi64_mask(avx._mm512_srli_epi64::<63>(bits), avx._mm512_set1_epi64(1));
-
-    // we need to shift the mantissa by some value that may be negative, so we first shift it to
-    // the left by the maximum amount, then shift it to the right by our value plus the offset we
-    // just shifted by
-    //
-    // the 53rd bit is set to 1, so we shift to the left by 11 so that the 63rd (last) bit is set
-    let mantissa_lshift = avx._mm512_slli_epi64::<11>(mantissa);
-
-    // shift to the right and apply the exponent bias
-    // if biased_exp == 0 then we have 0 or a subnormal value, which should return 0; here we
-    // shift to the right by 1086, which returns 0 as we are shifting in 0s from the left, so
-    // subnormals are already covered
-    let mantissa_shift = avx._mm512_srlv_epi64(
-        mantissa_lshift,
-        avx._mm512_sub_epi64(avx._mm512_set1_epi64(1086), biased_exp),
-    );
-
-    // if the sign bit is unset, we keep our result
-    let value_if_positive = mantissa_shift;
-    // otherwise, we negate it
-    let value_if_negative = avx._mm512_sub_epi64(avx._mm512_setzero_si512(), value_if_positive);
-
-    // Select the value based on the sign mask
-    avx._mm512_mask_blend_epi64(sign_is_negative_mask, value_if_positive, value_if_negative)
+    // SAFETY: simd contains an instance of avx512dq, which matches the target feature of
+    // `implementation`
+    unsafe { implementation(x) }
 }
 
 /// Convert a vector of i64 values to a vector of f64 values. Not sure how it works.
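For reference, the bit manipulation that the removed SIMD code performed can be written out in scalar form. The sketch below is ours (the crate's own scalar version is `f64_to_i64_bit_twiddles` in `fft/tests.rs`); it follows the comments of the removed lines and is illustrative only:

fn f64_to_i64_bit_twiddle(x: f64) -> i64 {
    let bits = x.to_bits();
    // the low 52 bits hold the fraction; OR in the implicit 53rd mantissa bit
    let mantissa = (bits & 0xF_FFFF_FFFF_FFFF) | 0x10_0000_0000_0000;
    // bits 52..=62 hold the biased exponent
    let biased_exp = ((bits >> 52) & 0x7FF) as i64;
    // bit 63 holds the sign
    let sign_is_negative = (bits >> 63) == 1;

    // move the explicit mantissa bit up to bit 63, then shift right by
    // (1086 - biased_exp); zeros, subnormals, and out-of-range exponents all
    // produce a count outside 0..64 and map to 0, mirroring the vectorized
    // srlv shift, which yields 0 for any count of 64 or more
    let shift = 1086 - biased_exp;
    let magnitude = if (0..64).contains(&shift) {
        ((mantissa << 11) >> shift) as i64
    } else {
        0
    };
    if sign_is_negative { -magnitude } else { magnitude }
}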
@@ -161,25 +166,56 @@ pub fn mm256_cvtepi64_pd(simd: V3, x: __m256i) -> __m256d {
 }
 
 /// Convert a vector of i64 values to a vector of f64 values.
+/// This intrinsic is currently not available in rust, so we have our own implementation using
+/// inline assembly.
+///
+/// The name matches Intel's convention (re-used by rust in their intrinsics) without the leading
+/// `_`.
+///
+/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtepi64_pd`)
 #[cfg(feature = "nightly-avx512")]
 #[inline(always)]
 pub fn mm512_cvtepi64_pd(simd: V4, x: __m512i) -> __m512d {
-    #[target_feature(enable = "avx512dq")]
+    // This first one is required for the zmm_reg notation
+    #[inline]
+    #[target_feature(enable = "sse")]
+    #[target_feature(enable = "sse2")]
+    #[target_feature(enable = "fxsr")]
+    #[target_feature(enable = "sse3")]
+    #[target_feature(enable = "ssse3")]
+    #[target_feature(enable = "sse4.1")]
+    #[target_feature(enable = "sse4.2")]
+    #[target_feature(enable = "popcnt")]
+    #[target_feature(enable = "avx")]
+    #[target_feature(enable = "avx2")]
+    #[target_feature(enable = "bmi1")]
+    #[target_feature(enable = "bmi2")]
+    #[target_feature(enable = "fma")]
+    #[target_feature(enable = "lzcnt")]
+    #[target_feature(enable = "avx512f")]
+    #[target_feature(enable = "avx512dq")]
     unsafe fn implementation(x: __m512i) -> __m512d {
-        // hopefully this compiles to vcvtqq2pd
-        let i64x8: [i64; 8] = core::mem::transmute(x);
-        let as_f64x8 = [
-            i64x8[0] as f64,
-            i64x8[1] as f64,
-            i64x8[2] as f64,
-            i64x8[3] as f64,
-            i64x8[4] as f64,
-            i64x8[5] as f64,
-            i64x8[6] as f64,
-            i64x8[7] as f64,
-        ];
-        core::mem::transmute(as_f64x8)
+        let mut as_f64x8: __m512d;
+
+        // From Intel's documentation, the syntax to use this intrinsic is
+        // Instruction: vcvtqq2pd zmm, zmm
+        // With Intel syntax, the left operand is the destination and the right operand is the source
+        // For the asm! macro
+        // in: indicates an input register
+        // out: indicates an output register
+        // zmm_reg: the avx512 register type
+        // options: see https://doc.rust-lang.org/nightly/reference/inline-assembly.html#options
+        // pure: no side effects
+        // nomem: does not reference RAM (only registers)
+        // nostack: does not alter the state of the stack
+        core::arch::asm!(
+            "vcvtqq2pd {dst}, {src}",
+            src = in(zmm_reg) x,
+            dst = out(zmm_reg) as_f64x8,
+            options(pure, nomem, nostack)
+        );
+
+        as_f64x8
     }
     let _ = simd.avx512dq;
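Outside of zmm registers, the same `asm!` plumbing (an `in` operand, an `out` operand, and the `pure`/`nomem`/`nostack` options) can be exercised on a conversion instruction that is baseline on every x86_64 CPU. A minimal sketch; the function name is ours and not part of the crate:

#[cfg(target_arch = "x86_64")]
pub fn cvttsd2si_demo(x: f64) -> i64 {
    let dst: i64;
    // SAFETY: cvttsd2si is an SSE2 instruction, and SSE2 is part of the
    // x86_64 baseline, so no runtime feature check is needed here
    unsafe {
        core::arch::asm!(
            // Intel syntax: destination on the left, source on the right
            "cvttsd2si {dst}, {src}",
            src = in(xmm_reg) x,
            dst = out(reg) dst,
            options(pure, nomem, nostack),
        );
    }
    dst
}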
@@ -758,9 +794,9 @@ pub fn convert_add_backward_torus_u64_v4(
     );
 
     // convert f64 to i64
-    let fract_re = mm512_cvtpd_epi64(simd, fract_re);
+    let fract_re = mm512_cvtt_roundpd_epi64(simd, fract_re);
     // convert f64 to i64
-    let fract_im = mm512_cvtpd_epi64(simd, fract_im);
+    let fract_im = mm512_cvtt_roundpd_epi64(simd, fract_im);
 
     // add to input and store
     *out_re = pulp::cast(avx512f._mm512_add_epi64(fract_re, pulp::cast(*out_re)));
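In scalar terms, each lane of the two converted values above goes through something like this sketch (the helper is hypothetical; `wrapping_add` stands in for `_mm512_add_epi64`, which also wraps on overflow):

fn convert_add_scalar(out: &mut u64, fract: f64) {
    // truncate toward zero, as vcvttpd2qq does (note: an out-of-range `as`
    // cast saturates in rust, unlike the hardware instruction)
    let fract = fract as i64;
    // modular addition into the torus value
    *out = out.wrapping_add(fract as u64);
}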
@@ -1157,7 +1193,7 @@ mod tests {
         });
 
         let computed: [i64; 4] =
-            pulp::cast_lossy(mm512_cvtpd_epi64(simd, pulp::cast([v, v])));
+            pulp::cast_lossy(mm512_cvtt_roundpd_epi64(simd, pulp::cast([v, v])));
         assert_eq!(target, computed);
     }
 }
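The `simd: V4` argument threaded through these kernels and tests is the runtime proof that the required CPU features are present. A sketch of how a caller might obtain it, assuming pulp's `V4::try_new` constructor, which returns `Some` only on CPUs that support the whole feature set `implementation` is compiled against:

#[cfg(feature = "nightly-avx512")]
fn truncate_f64x8(x: __m512d) -> Option<__m512i> {
    // hypothetical wrapper: detect avx512 support, then call the conversion
    let simd = pulp::x86::V4::try_new()?;
    Some(mm512_cvtt_roundpd_epi64(simd, x))
}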