mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
chore(fft): rename nightly feature flag to avx512
This commit is contained in:
committed by
Nicolas Sarlin
parent
8d1f6d4d06
commit
851bd01873
@@ -21,9 +21,9 @@ serde = { workspace = true, optional = true }
|
||||
js-sys = "0.3"
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
default = ["std", "avx512"]
|
||||
fft128 = []
|
||||
nightly = ["pulp/x86-v4"]
|
||||
avx512 = ["pulp/x86-v4"]
|
||||
std = ["pulp/std"]
|
||||
serde = ["dep:serde", "num-complex/serde"]
|
||||
|
||||
|
||||
@@ -23,10 +23,8 @@ Additionally, an optional 128-bit negacyclic FFT module is provided.
|
||||
choose the fastest one at runtime.
|
||||
- `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
||||
[`fft128`] module.
|
||||
- `nightly`: This enables unstable Rust features to further speed up the FFT,
|
||||
by enabling AVX512F instructions on CPUs that support them. This feature
|
||||
requires a nightly Rust
|
||||
toolchain.
|
||||
- `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
|
||||
speed up the FFT.
|
||||
- `serde`: This enables serialization and deserialization functions for the
|
||||
unordered plan. These allow for data in the Fourier domain to be serialized
|
||||
from the permuted order to the standard order, and deserialized from the
|
||||
|
||||
@@ -292,7 +292,7 @@ pub fn bench_fft128(c: &mut Criterion) {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
let bench_id = format!("tfhe-fft128-avx512-fwd-{n}");
|
||||
c.bench_function(&bench_id, |bench| {
|
||||
|
||||
@@ -932,7 +932,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 16 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -333,7 +333,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 4 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -492,7 +492,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 8 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -871,7 +871,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 16 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -310,7 +310,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 4 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -457,7 +457,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
if n >= 8 * simd.lane_count() {
|
||||
return fft_impl(simd).make_fn_ptr(n);
|
||||
|
||||
@@ -621,7 +621,7 @@ impl f128 {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))]
|
||||
pub mod x86 {
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
use pulp::{b8, f64x8, x86::V4};
|
||||
use pulp::{f64x4, x86::V3, Simd};
|
||||
|
||||
@@ -654,7 +654,7 @@ pub mod x86 {
|
||||
(p, simd.mul_sub_f64x4(a, b, p))
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[inline(always)]
|
||||
pub(crate) fn quick_two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||
let s = simd.add_f64x8(a, b);
|
||||
@@ -662,7 +662,7 @@ pub mod x86 {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
pub(crate) fn two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||
let sign_bit = simd.splat_f64x8(-0.0);
|
||||
let cmp = simd.cmp_gt_f64x8(
|
||||
@@ -675,19 +675,19 @@ pub mod x86 {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
pub(crate) fn two_diff_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||
two_sum_f64x8(simd, a, simd.neg_f64s(b))
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[inline(always)]
|
||||
pub(crate) fn two_prod_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||
let p = simd.mul_f64x8(a, b);
|
||||
(p, simd.mul_sub_f64x8(a, b, p))
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[inline(always)]
|
||||
pub(crate) fn quick_two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||
let s = simd.add_f64x16(a, b);
|
||||
@@ -695,7 +695,7 @@ pub mod x86 {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
pub(crate) fn two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||
let sign_bit = simd.splat_f64x16(-0.0);
|
||||
let cmp = simd.cmp_gt_f64x16(
|
||||
@@ -708,7 +708,7 @@ pub mod x86 {
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
pub(crate) fn two_diff_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||
two_sum_f64x16(
|
||||
simd,
|
||||
@@ -720,14 +720,14 @@ pub mod x86 {
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[inline(always)]
|
||||
pub(crate) fn two_prod_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||
let p = simd.mul_f64x16(a, b);
|
||||
(p, simd.mul_sub_f64x16(a, b, p))
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(C)]
|
||||
pub struct f64x16 {
|
||||
@@ -735,7 +735,7 @@ pub mod x86 {
|
||||
pub hi: f64x8,
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(C)]
|
||||
pub struct b16 {
|
||||
@@ -743,9 +743,9 @@ pub mod x86 {
|
||||
pub hi: b8,
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
unsafe impl bytemuck::Zeroable for f64x16 {}
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
unsafe impl bytemuck::Pod for f64x16 {}
|
||||
|
||||
pub trait V3F128Ext {
|
||||
@@ -756,7 +756,7 @@ pub mod x86 {
|
||||
fn mul_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
pub trait V4F128Ext {
|
||||
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
||||
fn sub_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
||||
@@ -841,7 +841,7 @@ pub mod x86 {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
impl V4F128Ext for V4 {
|
||||
#[inline(always)]
|
||||
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8) {
|
||||
|
||||
@@ -16,10 +16,10 @@ use crate::fft128::f128_ops::x86::V3F128Ext;
|
||||
use pulp::{f64x4, x86::V3};
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
use crate::fft128::f128_ops::x86::{f64x16, V4F128Ext};
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
use pulp::{f64x8, x86::V4};
|
||||
|
||||
trait FftSimdF128: Copy {
|
||||
@@ -41,7 +41,7 @@ trait V3InterleaveExt {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
trait V4InterleaveExt {
|
||||
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2];
|
||||
fn permute4_f64x8(self, w: [f64; 2]) -> f64x8;
|
||||
@@ -90,7 +90,7 @@ impl V3InterleaveExt for V3 {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
impl V4InterleaveExt for V4 {
|
||||
#[inline(always)]
|
||||
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2] {
|
||||
@@ -217,7 +217,7 @@ impl FftSimdF128 for V3 {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
impl FftSimdF128 for V4 {
|
||||
type Reg = f64x8;
|
||||
|
||||
@@ -246,12 +246,12 @@ impl FftSimdF128 for V4 {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct V4x2(pub V4);
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
impl FftSimdF128 for V4x2 {
|
||||
type Reg = f64x16;
|
||||
|
||||
@@ -664,7 +664,7 @@ pub fn negacyclic_fwd_fft_avxfma(
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[doc(hidden)]
|
||||
pub fn negacyclic_fwd_fft_avx512(
|
||||
simd: V4,
|
||||
@@ -1052,7 +1052,7 @@ pub fn negacyclic_fwd_fft(
|
||||
) {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = V4::try_new() {
|
||||
return negacyclic_fwd_fft_avx512(
|
||||
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
||||
@@ -1084,7 +1084,7 @@ pub fn negacyclic_inv_fft(
|
||||
) {
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = V4::try_new() {
|
||||
return negacyclic_inv_fft_avx512(
|
||||
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
||||
@@ -1406,7 +1406,7 @@ pub fn negacyclic_inv_fft_avxfma(
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[doc(hidden)]
|
||||
pub fn negacyclic_inv_fft_avx512(
|
||||
simd: V4,
|
||||
@@ -2304,7 +2304,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[test]
|
||||
fn test_product_avx512() {
|
||||
if let Some(simd) = V4::try_new() {
|
||||
@@ -2465,7 +2465,7 @@ mod x86_tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[test]
|
||||
fn test_interleaves_and_permutes_f64x8() {
|
||||
if let Some(simd) = V4::try_new() {
|
||||
|
||||
@@ -7,7 +7,7 @@ use core::{f64, fmt::Debug, marker::PhantomData};
|
||||
pub struct c64x2(c64, c64);
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(C)]
|
||||
pub struct c64x4(c64, c64, c64, c64);
|
||||
@@ -23,7 +23,7 @@ const __ASSERT_POD: () = {
|
||||
|
||||
// no padding
|
||||
assert!(core::mem::size_of::<c64x2>() == core::mem::size_of::<c64>() * 2);
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
assert!(core::mem::size_of::<c64x4>() == core::mem::size_of::<c64>() * 4);
|
||||
};
|
||||
|
||||
@@ -32,7 +32,7 @@ const __ASSERT_POD: () = {
|
||||
unsafe impl bytemuck::Zeroable for c64x2 {}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
unsafe impl bytemuck::Zeroable for c64x4 {}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
@@ -40,7 +40,7 @@ unsafe impl bytemuck::Zeroable for c64x4 {}
|
||||
unsafe impl bytemuck::Pod for c64x2 {}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
unsafe impl bytemuck::Pod for c64x4 {}
|
||||
|
||||
pub trait Pod: Copy + Debug + bytemuck::Pod {}
|
||||
|
||||
@@ -21,9 +21,8 @@
|
||||
//! an FFT plan that measures the various implementations to choose the fastest one at runtime.
|
||||
//! - `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
||||
//! `fft128` module.
|
||||
//! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling
|
||||
//! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust
|
||||
//! toolchain.
|
||||
//! - `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
|
||||
//! speed up the FFT.
|
||||
//! - `serde`: This enables serialization and deserialization functions for the unordered plan.
|
||||
//! These allow for data in the Fourier domain to be serialized from the permuted order to the
|
||||
//! standard order, and deserialized from the standard order to the permuted order.
|
||||
|
||||
@@ -457,7 +457,7 @@ mod tests {
|
||||
if let Some(simd) = pulp::x86::V3::try_new() {
|
||||
test_fft_simd(simd);
|
||||
}
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
test_fft_simd(simd);
|
||||
}
|
||||
|
||||
@@ -297,7 +297,7 @@ macro_rules! dispatcher {
|
||||
fn $name() -> fn(&mut [c64], &[c64]) {
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if pulp::x86::V4::try_new().is_some() {
|
||||
return |z, w| {
|
||||
let simd = pulp::x86::V4::try_new().unwrap();
|
||||
@@ -335,7 +335,7 @@ dispatcher!(get_inv_process_x8, inv_process_x8);
|
||||
fn get_complex_per_reg() -> usize {
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
{
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||
return simd.lane_count();
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ impl FftSimd<c64x2> for V3 {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
#[cfg(feature = "avx512")]
|
||||
impl FftSimd<c64x4> for V4 {
|
||||
#[inline(always)]
|
||||
fn try_new() -> Option<Self> {
|
||||
|
||||
Reference in New Issue
Block a user