From 0c3919628f4466aa192f7ecf5c0545098323b86e Mon Sep 17 00:00:00 2001
From: Arthur Meyre
Date: Tue, 3 Oct 2023 15:48:54 +0200
Subject: [PATCH] refactor(core): use avx512 intrinsics when available for
 data conversions

- we use inline assembly for now, as Rust does not currently expose these
  intrinsics in the std or core arch crates
- add tests for the avx512 conversions

---
 .../fft_impl/fft64/math/fft/x86.rs | 162 +++++++++++-------
 1 file changed, 99 insertions(+), 63 deletions(-)

diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs
index a13796c9a..c0f9e59be 100644
--- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs
+++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs
@@ -81,57 +81,62 @@ pub fn mm256_cvtpd_epi64(simd: V3, x: __m256d) -> __m256i {
 }
 
 /// Convert a vector of f64 values to a vector of i64 values.
-/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version.
+/// This intrinsic is currently not available in Rust, so we provide our own implementation
+/// using inline assembly.
+///
+/// The name matches Intel's convention (reused by Rust for its intrinsics) without the leading
+/// `_`.
+///
+/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtt_roundpd_epi64`)
 #[cfg(feature = "nightly-avx512")]
 #[inline(always)]
-pub fn mm512_cvtpd_epi64(simd: V4, x: __m512d) -> __m512i {
-    let avx = simd.avx512f;
+pub fn mm512_cvtt_roundpd_epi64(simd: V4, x: __m512d) -> __m512i {
+    // avx512f (enabled below) is required for the zmm_reg notation
+    #[inline]
+    #[target_feature(enable = "sse")]
+    #[target_feature(enable = "sse2")]
+    #[target_feature(enable = "fxsr")]
+    #[target_feature(enable = "sse3")]
+    #[target_feature(enable = "ssse3")]
+    #[target_feature(enable = "sse4.1")]
+    #[target_feature(enable = "sse4.2")]
+    #[target_feature(enable = "popcnt")]
+    #[target_feature(enable = "avx")]
+    #[target_feature(enable = "avx2")]
+    #[target_feature(enable = "bmi1")]
+    #[target_feature(enable = "bmi2")]
+    #[target_feature(enable = "fma")]
+    #[target_feature(enable = "lzcnt")]
+    #[target_feature(enable = "avx512f")]
+    #[target_feature(enable = "avx512dq")]
+    unsafe fn implementation(x: __m512d) -> __m512i {
+        let mut as_i64x8: __m512i;
 
-    // reinterpret the bits as u64 values
-    let bits = avx._mm512_castpd_si512(x);
-    // mask that covers the first 52 bits
-    let mantissa_mask = avx._mm512_set1_epi64(0xFFFFFFFFFFFFF_u64 as i64);
-    // mask that covers the 53rd bit
-    let explicit_mantissa_bit = avx._mm512_set1_epi64(0x10000000000000_u64 as i64);
-    // mask that covers the first 11 bits
-    let exp_mask = avx._mm512_set1_epi64(0x7FF_u64 as i64);
+        // From Intel's documentation, the syntax for this intrinsic is
+        // Instruction: vcvttpd2qq zmm, zmm
+        // With Intel syntax, the left operand is the destination and the right one is the source
+        // For the asm! macro
+        // in: indicates an input register
+        // out: indicates an output register
+        // zmm_reg: the avx512 register class
+        // options: see https://doc.rust-lang.org/nightly/reference/inline-assembly.html#options
+        // pure: no observable side effects
+        // nomem: does not access memory (registers only)
+        // nostack: does not alter the state of the stack
+        core::arch::asm!(
+            "vcvttpd2qq {dst}, {src}",
+            src = in(zmm_reg) x,
+            dst = out(zmm_reg) as_i64x8,
+            options(pure, nomem, nostack)
+        );
 
-    // extract the first 52 bits and add the implicit bit
-    let mantissa = avx._mm512_or_si512(
-        avx._mm512_and_si512(bits, mantissa_mask),
-        explicit_mantissa_bit,
-    );
+        as_i64x8
+    }
+    let _ = simd.avx512dq;
 
-    // extract the 52nd to 63rd (excluded) bits for the biased exponent
-    let biased_exp = avx._mm512_and_si512(avx._mm512_srli_epi64::<52>(bits), exp_mask);
-
-    // extract the 63rd sign bit
-    let sign_is_negative_mask =
-        avx._mm512_cmpeq_epi64_mask(avx._mm512_srli_epi64::<63>(bits), avx._mm512_set1_epi64(1));
-
-    // we need to shift the mantissa by some value that may be negative, so we first shift it to
-    // the left by the maximum amount, then shift it to the right by our value plus the offset we
-    // just shifted by
-    //
-    // the 53rd bit is set to 1, so we shift to the left by 11 so the 63rd (last) bit is set.
-    let mantissa_lshift = avx._mm512_slli_epi64::<11>(mantissa);
-
-    // shift to the right and apply the exponent bias
-    // If biased_exp == 0 then we have 0 or a subnormal value which should return 0, here we will
-    // shift to the right by 1086 which will return 0 as we are shifting in 0s from the left, so
-    // subnormals are already covered
-    let mantissa_shift = avx._mm512_srlv_epi64(
-        mantissa_lshift,
-        avx._mm512_sub_epi64(avx._mm512_set1_epi64(1086), biased_exp),
-    );
-
-    // if the sign bit is unset, we keep our result
-    let value_if_positive = mantissa_shift;
-    // otherwise, we negate it
-    let value_if_negative = avx._mm512_sub_epi64(avx._mm512_setzero_si512(), value_if_positive);
-
-    // Select the value based on the sign mask
-    avx._mm512_mask_blend_epi64(sign_is_negative_mask, value_if_positive, value_if_negative)
+    // SAFETY: `simd` contains an instance of avx512dq, which matches the target features of
+    // `implementation`
+    unsafe { implementation(x) }
 }
 
 /// Convert a vector of i64 values to a vector of f64 values. Not sure how it works.
@@ -161,25 +166,56 @@ pub fn mm256_cvtepi64_pd(simd: V3, x: __m256i) -> __m256d {
 }
 
 /// Convert a vector of i64 values to a vector of f64 values.
+/// This intrinsic is currently not available in Rust, so we provide our own implementation
+/// using inline assembly.
+///
+/// The name matches Intel's convention (reused by Rust for its intrinsics) without the leading
+/// `_`.
+///
+/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtepi64_pd`)
 #[cfg(feature = "nightly-avx512")]
 #[inline(always)]
 pub fn mm512_cvtepi64_pd(simd: V4, x: __m512i) -> __m512d {
-    #[target_feature(enable = "avx512dq")]
+    // avx512f (enabled below) is required for the zmm_reg notation
     #[inline]
+    #[target_feature(enable = "sse")]
+    #[target_feature(enable = "sse2")]
+    #[target_feature(enable = "fxsr")]
+    #[target_feature(enable = "sse3")]
+    #[target_feature(enable = "ssse3")]
+    #[target_feature(enable = "sse4.1")]
+    #[target_feature(enable = "sse4.2")]
+    #[target_feature(enable = "popcnt")]
+    #[target_feature(enable = "avx")]
+    #[target_feature(enable = "avx2")]
+    #[target_feature(enable = "bmi1")]
+    #[target_feature(enable = "bmi2")]
+    #[target_feature(enable = "fma")]
+    #[target_feature(enable = "lzcnt")]
+    #[target_feature(enable = "avx512f")]
+    #[target_feature(enable = "avx512dq")]
     unsafe fn implementation(x: __m512i) -> __m512d {
-        // hopefully this compiles to vcvtqq2pd
-        let i64x8: [i64; 8] = core::mem::transmute(x);
-        let as_f64x8 = [
-            i64x8[0] as f64,
-            i64x8[1] as f64,
-            i64x8[2] as f64,
-            i64x8[3] as f64,
-            i64x8[4] as f64,
-            i64x8[5] as f64,
-            i64x8[6] as f64,
-            i64x8[7] as f64,
-        ];
-        core::mem::transmute(as_f64x8)
+        let mut as_f64x8: __m512d;
+
+        // From Intel's documentation, the syntax for this intrinsic is
+        // Instruction: vcvtqq2pd zmm, zmm
+        // With Intel syntax, the left operand is the destination and the right one is the source
+        // For the asm! macro
+        // in: indicates an input register
+        // out: indicates an output register
+        // zmm_reg: the avx512 register class
+        // options: see https://doc.rust-lang.org/nightly/reference/inline-assembly.html#options
+        // pure: no observable side effects
+        // nomem: does not access memory (registers only)
+        // nostack: does not alter the state of the stack
+        core::arch::asm!(
+            "vcvtqq2pd {dst}, {src}",
+            src = in(zmm_reg) x,
+            dst = out(zmm_reg) as_f64x8,
+            options(pure, nomem, nostack)
+        );
+
+        as_f64x8
     }
     let _ = simd.avx512dq;
@@ -758,9 +794,9 @@ pub fn convert_add_backward_torus_u64_v4(
         );
 
         // convert f64 to i64
-        let fract_re = mm512_cvtpd_epi64(simd, fract_re);
+        let fract_re = mm512_cvtt_roundpd_epi64(simd, fract_re);
         // convert f64 to i64
-        let fract_im = mm512_cvtpd_epi64(simd, fract_im);
+        let fract_im = mm512_cvtt_roundpd_epi64(simd, fract_im);
 
         // add to input and store
         *out_re = pulp::cast(avx512f._mm512_add_epi64(fract_re, pulp::cast(*out_re)));
@@ -1157,7 +1193,7 @@ mod tests {
             });
 
             let computed: [i64; 4] =
-                pulp::cast_lossy(mm512_cvtpd_epi64(simd, pulp::cast([v, v])));
+                pulp::cast_lossy(mm512_cvtt_roundpd_epi64(simd, pulp::cast([v, v])));
             assert_eq!(target, computed);
         }
     }
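
Usage sketch (reviewer note, not part of the patch): the two asm-backed
conversions can be sanity-checked as a round trip on integer-valued f64
lanes, which both instructions convert exactly. This sketch assumes the
`pulp` crate's runtime feature detection via `pulp::x86::V4::try_new`; the
test name `cvt_roundtrip_small_integers` is hypothetical.

#[cfg(feature = "nightly-avx512")]
#[test]
fn cvt_roundtrip_small_integers() {
    use pulp::x86::V4;

    // Skip silently when the CPU lacks the required avx512 features.
    if let Some(simd) = V4::try_new() {
        // Truncating f64 -> i64 (vcvttpd2qq) and converting back to f64
        // (vcvtqq2pd) must reproduce inputs that are already integers.
        let input: [f64; 8] = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
        let as_i64 = mm512_cvtt_roundpd_epi64(simd, pulp::cast(input));
        let roundtrip: [f64; 8] = pulp::cast(mm512_cvtepi64_pd(simd, as_i64));
        assert_eq!(input, roundtrip);
    }
}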