chore(fft): rename nightly feature flag to avx512

This commit is contained in:
Nicolas Sarlin
2025-11-19 15:21:49 +01:00
committed by Nicolas Sarlin
parent 8d1f6d4d06
commit 851bd01873
21 changed files with 72 additions and 108 deletions

View File

@@ -21,9 +21,9 @@ serde = { workspace = true, optional = true }
js-sys = "0.3"
[features]
default = ["std"]
default = ["std", "avx512"]
fft128 = []
nightly = ["pulp/x86-v4"]
avx512 = ["pulp/x86-v4"]
std = ["pulp/std"]
serde = ["dep:serde", "num-complex/serde"]

View File

@@ -23,10 +23,8 @@ Additionally, an optional 128-bit negacyclic FFT module is provided.
choose the fastest one at runtime.
- `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
[`fft128`] module.
- `nightly`: This enables unstable Rust features to further speed up the FFT,
by enabling AVX512F instructions on CPUs that support them. This feature
requires a nightly Rust
toolchain.
- `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
speed up the FFT.
- `serde`: This enables serialization and deserialization functions for the
unordered plan. These allow for data in the Fourier domain to be serialized
from the permuted order to the standard order, and deserialized from the

View File

@@ -292,7 +292,7 @@ pub fn bench_fft128(c: &mut Criterion) {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
let bench_id = format!("tfhe-fft128-avx512-fwd-{n}");
c.bench_function(&bench_id, |bench| {

View File

@@ -932,7 +932,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 16 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -333,7 +333,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 4 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -492,7 +492,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 8 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -871,7 +871,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 16 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -310,7 +310,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 4 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -457,7 +457,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
if n >= 8 * simd.lane_count() {
return fft_impl(simd).make_fn_ptr(n);

View File

@@ -621,7 +621,7 @@ impl f128 {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))]
pub mod x86 {
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
use pulp::{b8, f64x8, x86::V4};
use pulp::{f64x4, x86::V3, Simd};
@@ -654,7 +654,7 @@ pub mod x86 {
(p, simd.mul_sub_f64x4(a, b, p))
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[inline(always)]
pub(crate) fn quick_two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
let s = simd.add_f64x8(a, b);
@@ -662,7 +662,7 @@ pub mod x86 {
}
#[inline(always)]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
pub(crate) fn two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
let sign_bit = simd.splat_f64x8(-0.0);
let cmp = simd.cmp_gt_f64x8(
@@ -675,19 +675,19 @@ pub mod x86 {
}
#[inline(always)]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
pub(crate) fn two_diff_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
two_sum_f64x8(simd, a, simd.neg_f64s(b))
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[inline(always)]
pub(crate) fn two_prod_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
let p = simd.mul_f64x8(a, b);
(p, simd.mul_sub_f64x8(a, b, p))
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[inline(always)]
pub(crate) fn quick_two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
let s = simd.add_f64x16(a, b);
@@ -695,7 +695,7 @@ pub mod x86 {
}
#[inline(always)]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
pub(crate) fn two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
let sign_bit = simd.splat_f64x16(-0.0);
let cmp = simd.cmp_gt_f64x16(
@@ -708,7 +708,7 @@ pub mod x86 {
}
#[inline(always)]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
pub(crate) fn two_diff_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
two_sum_f64x16(
simd,
@@ -720,14 +720,14 @@ pub mod x86 {
)
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[inline(always)]
pub(crate) fn two_prod_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
let p = simd.mul_f64x16(a, b);
(p, simd.mul_sub_f64x16(a, b, p))
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct f64x16 {
@@ -735,7 +735,7 @@ pub mod x86 {
pub hi: f64x8,
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct b16 {
@@ -743,9 +743,9 @@ pub mod x86 {
pub hi: b8,
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
unsafe impl bytemuck::Zeroable for f64x16 {}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
unsafe impl bytemuck::Pod for f64x16 {}
pub trait V3F128Ext {
@@ -756,7 +756,7 @@ pub mod x86 {
fn mul_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
pub trait V4F128Ext {
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
fn sub_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
@@ -841,7 +841,7 @@ pub mod x86 {
}
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
impl V4F128Ext for V4 {
#[inline(always)]
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8) {

View File

@@ -16,10 +16,10 @@ use crate::fft128::f128_ops::x86::V3F128Ext;
use pulp::{f64x4, x86::V3};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
use crate::fft128::f128_ops::x86::{f64x16, V4F128Ext};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
use pulp::{f64x8, x86::V4};
trait FftSimdF128: Copy {
@@ -41,7 +41,7 @@ trait V3InterleaveExt {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
trait V4InterleaveExt {
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2];
fn permute4_f64x8(self, w: [f64; 2]) -> f64x8;
@@ -90,7 +90,7 @@ impl V3InterleaveExt for V3 {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
impl V4InterleaveExt for V4 {
#[inline(always)]
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2] {
@@ -217,7 +217,7 @@ impl FftSimdF128 for V3 {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
impl FftSimdF128 for V4 {
type Reg = f64x8;
@@ -246,12 +246,12 @@ impl FftSimdF128 for V4 {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[derive(Copy, Clone, Debug)]
pub struct V4x2(pub V4);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
impl FftSimdF128 for V4x2 {
type Reg = f64x16;
@@ -664,7 +664,7 @@ pub fn negacyclic_fwd_fft_avxfma(
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[doc(hidden)]
pub fn negacyclic_fwd_fft_avx512(
simd: V4,
@@ -1052,7 +1052,7 @@ pub fn negacyclic_fwd_fft(
) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = V4::try_new() {
return negacyclic_fwd_fft_avx512(
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
@@ -1084,7 +1084,7 @@ pub fn negacyclic_inv_fft(
) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = V4::try_new() {
return negacyclic_inv_fft_avx512(
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
@@ -1406,7 +1406,7 @@ pub fn negacyclic_inv_fft_avxfma(
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[doc(hidden)]
pub fn negacyclic_inv_fft_avx512(
simd: V4,
@@ -2304,7 +2304,7 @@ mod tests {
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[test]
fn test_product_avx512() {
if let Some(simd) = V4::try_new() {
@@ -2465,7 +2465,7 @@ mod x86_tests {
}
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[test]
fn test_interleaves_and_permutes_f64x8() {
if let Some(simd) = V4::try_new() {

View File

@@ -7,7 +7,7 @@ use core::{f64, fmt::Debug, marker::PhantomData};
pub struct c64x2(c64, c64);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct c64x4(c64, c64, c64, c64);
@@ -23,7 +23,7 @@ const __ASSERT_POD: () = {
// no padding
assert!(core::mem::size_of::<c64x2>() == core::mem::size_of::<c64>() * 2);
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
assert!(core::mem::size_of::<c64x4>() == core::mem::size_of::<c64>() * 4);
};
@@ -32,7 +32,7 @@ const __ASSERT_POD: () = {
unsafe impl bytemuck::Zeroable for c64x2 {}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
unsafe impl bytemuck::Zeroable for c64x4 {}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@@ -40,7 +40,7 @@ unsafe impl bytemuck::Zeroable for c64x4 {}
unsafe impl bytemuck::Pod for c64x2 {}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
unsafe impl bytemuck::Pod for c64x4 {}
pub trait Pod: Copy + Debug + bytemuck::Pod {}

View File

@@ -21,9 +21,8 @@
//! an FFT plan that measures the various implementations to choose the fastest one at runtime.
//! - `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
//! `fft128` module.
//! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling
//! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust
//! toolchain.
//! - `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
//! speed up the FFT.
//! - `serde`: This enables serialization and deserialization functions for the unordered plan.
//! These allow for data in the Fourier domain to be serialized from the permuted order to the
//! standard order, and deserialized from the standard order to the permuted order.

View File

@@ -457,7 +457,7 @@ mod tests {
if let Some(simd) = pulp::x86::V3::try_new() {
test_fft_simd(simd);
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
test_fft_simd(simd);
}

View File

@@ -297,7 +297,7 @@ macro_rules! dispatcher {
fn $name() -> fn(&mut [c64], &[c64]) {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if pulp::x86::V4::try_new().is_some() {
return |z, w| {
let simd = pulp::x86::V4::try_new().unwrap();
@@ -335,7 +335,7 @@ dispatcher!(get_inv_process_x8, inv_process_x8);
fn get_complex_per_reg() -> usize {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
if let Some(simd) = pulp::x86::V4::try_new() {
return simd.lane_count();
}

View File

@@ -73,7 +73,7 @@ impl FftSimd<c64x2> for V3 {
}
}
#[cfg(feature = "nightly")]
#[cfg(feature = "avx512")]
impl FftSimd<c64x4> for V4 {
#[inline(always)]
fn try_new() -> Option<Self> {