mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
chore(fft): rename nightly feature flag to avx512
This commit is contained in:
committed by
Nicolas Sarlin
parent
8d1f6d4d06
commit
851bd01873
36
.github/workflows/cargo_test_fft.yml
vendored
36
.github/workflows/cargo_test_fft.yml
vendored
@@ -67,44 +67,24 @@ jobs:
|
|||||||
toolchain: stable
|
toolchain: stable
|
||||||
override: true
|
override: true
|
||||||
|
|
||||||
- name: Test debug
|
- name: Test avx2
|
||||||
run: |
|
run: |
|
||||||
make test_fft
|
make test_fft
|
||||||
|
|
||||||
- name: Test serialization
|
- name: Test serialization
|
||||||
run: make test_fft_serde
|
run: make test_fft_serde
|
||||||
|
|
||||||
- name: Test no-std
|
- name: Test no-std avx2
|
||||||
run: |
|
run: |
|
||||||
make test_fft_no_std
|
make test_fft_no_std
|
||||||
|
|
||||||
cargo-tests-fft-nightly:
|
- name: Test avx512
|
||||||
name: cargo_test_fft/cargo-tests-fft-nightly
|
|
||||||
needs: should-run
|
|
||||||
if: needs.should-run.outputs.fft_test == 'true'
|
|
||||||
runs-on: ${{ matrix.runner_type }}
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
runner_type: [ ubuntu-latest, macos-latest, windows-latest ]
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8
|
|
||||||
with:
|
|
||||||
persist-credentials: 'false'
|
|
||||||
token: ${{ env.CHECKOUT_TOKEN }}
|
|
||||||
|
|
||||||
- name: Install Rust
|
|
||||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af
|
|
||||||
with:
|
|
||||||
toolchain: nightly
|
|
||||||
override: true
|
|
||||||
|
|
||||||
- name: Test nightly
|
|
||||||
run: |
|
run: |
|
||||||
make test_fft_nightly
|
make test_fft_avx512
|
||||||
|
|
||||||
- name: Test no-std nightly
|
- name: Test no-std avx512
|
||||||
run: |
|
run: |
|
||||||
make test_fft_no_std_nightly
|
make test_fft_no_std_avx512
|
||||||
|
|
||||||
cargo-tests-fft-node-js:
|
cargo-tests-fft-node-js:
|
||||||
name: cargo_test_fft/cargo-tests-fft-node-js
|
name: cargo_test_fft/cargo-tests-fft-node-js
|
||||||
@@ -124,7 +104,7 @@ jobs:
|
|||||||
|
|
||||||
cargo-tests-fft-successful:
|
cargo-tests-fft-successful:
|
||||||
name: cargo_test_fft/cargo-tests-fft-successful (bpr)
|
name: cargo_test_fft/cargo-tests-fft-successful (bpr)
|
||||||
needs: [ should-run, cargo-tests-fft, cargo-tests-fft-nightly, cargo-tests-fft-node-js ]
|
needs: [ should-run, cargo-tests-fft, cargo-tests-fft-node-js ]
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
@@ -136,7 +116,6 @@ jobs:
|
|||||||
- name: Check all tests passed
|
- name: Check all tests passed
|
||||||
if: needs.should-run.outputs.fft_test == 'true' &&
|
if: needs.should-run.outputs.fft_test == 'true' &&
|
||||||
needs.cargo-tests-fft.result == 'success' &&
|
needs.cargo-tests-fft.result == 'success' &&
|
||||||
needs.cargo-tests-fft-nightly.result == 'success' &&
|
|
||||||
needs.cargo-tests-fft-node-js.result == 'success'
|
needs.cargo-tests-fft-node-js.result == 'success'
|
||||||
run: |
|
run: |
|
||||||
echo "All tfhe-fft test passed"
|
echo "All tfhe-fft test passed"
|
||||||
@@ -144,7 +123,6 @@ jobs:
|
|||||||
- name: Check tests failure
|
- name: Check tests failure
|
||||||
if: needs.should-run.outputs.fft_test == 'true' &&
|
if: needs.should-run.outputs.fft_test == 'true' &&
|
||||||
(needs.cargo-tests-fft.result != 'success' ||
|
(needs.cargo-tests-fft.result != 'success' ||
|
||||||
needs.cargo-tests-fft-nightly.result != 'success' ||
|
|
||||||
needs.cargo-tests-fft-node-js.result != 'success')
|
needs.cargo-tests-fft-node-js.result != 'success')
|
||||||
run: |
|
run: |
|
||||||
echo "Some tfhe-fft tests failed"
|
echo "Some tfhe-fft tests failed"
|
||||||
|
|||||||
2
.github/workflows/m1_tests.yml
vendored
2
.github/workflows/m1_tests.yml
vendored
@@ -67,9 +67,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
make test_fft
|
make test_fft
|
||||||
make test_fft_serde
|
make test_fft_serde
|
||||||
make test_fft_nightly
|
|
||||||
make test_fft_no_std
|
make test_fft_no_std
|
||||||
make test_fft_no_std_nightly
|
|
||||||
# we don't run the js stuff here as it's causing issues with the M1 config
|
# we don't run the js stuff here as it's causing issues with the M1 config
|
||||||
|
|
||||||
- name: Run pcc NTT checks
|
- name: Run pcc NTT checks
|
||||||
|
|||||||
35
Makefile
35
Makefile
@@ -1919,30 +1919,24 @@ build_fft_no_std: install_rs_build_toolchain
|
|||||||
##### Tests #####
|
##### Tests #####
|
||||||
|
|
||||||
.PHONY: test_fft
|
.PHONY: test_fft
|
||||||
test_fft: install_rs_build_toolchain
|
test_fft:
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft
|
RUSTFLAGS="$(RUSTFLAGS)" cargo test --release -p tfhe-fft \
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft \
|
--no-default-features \
|
||||||
--features=fft128
|
--features=std,fft128
|
||||||
|
|
||||||
.PHONY: test_fft_serde
|
.PHONY: test_fft_serde
|
||||||
test_fft_serde: install_rs_build_toolchain
|
test_fft_serde:
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft \
|
RUSTFLAGS="$(RUSTFLAGS)" cargo test --release -p tfhe-fft \
|
||||||
--features=serde
|
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft \
|
|
||||||
--features=serde,fft128
|
--features=serde,fft128
|
||||||
|
|
||||||
.PHONY: test_fft_nightly
|
.PHONY: test_fft_avx512
|
||||||
test_fft_nightly: install_rs_check_toolchain
|
test_fft_avx512:
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --release -p tfhe-fft \
|
RUSTFLAGS="$(RUSTFLAGS)" cargo test --release -p tfhe-fft \
|
||||||
--features=nightly
|
--features=avx512,fft128
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --release -p tfhe-fft \
|
|
||||||
--features=nightly,fft128
|
|
||||||
|
|
||||||
.PHONY: test_fft_no_std
|
.PHONY: test_fft_no_std
|
||||||
test_fft_no_std: install_rs_build_toolchain
|
test_fft_no_std:
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft \
|
RUSTFLAGS="$(RUSTFLAGS)" cargo test --release -p tfhe-fft \
|
||||||
--no-default-features
|
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release -p tfhe-fft \
|
|
||||||
--no-default-features \
|
--no-default-features \
|
||||||
--features=fft128
|
--features=fft128
|
||||||
|
|
||||||
@@ -1950,10 +1944,7 @@ test_fft_no_std: install_rs_build_toolchain
|
|||||||
test_fft_no_std_nightly: install_rs_check_toolchain
|
test_fft_no_std_nightly: install_rs_check_toolchain
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --release -p tfhe-fft \
|
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --release -p tfhe-fft \
|
||||||
--no-default-features \
|
--no-default-features \
|
||||||
--features=nightly
|
--features=avx512,fft128
|
||||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --release -p tfhe-fft \
|
|
||||||
--no-default-features \
|
|
||||||
--features=nightly,fft128
|
|
||||||
|
|
||||||
.PHONY: test_fft_node_js
|
.PHONY: test_fft_node_js
|
||||||
test_fft_node_js: install_rs_build_toolchain install_build_wasm32_target install_wasm_bindgen_cli
|
test_fft_node_js: install_rs_build_toolchain install_build_wasm32_target install_wasm_bindgen_cli
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ serde = { workspace = true, optional = true }
|
|||||||
js-sys = "0.3"
|
js-sys = "0.3"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["std"]
|
default = ["std", "avx512"]
|
||||||
fft128 = []
|
fft128 = []
|
||||||
nightly = ["pulp/x86-v4"]
|
avx512 = ["pulp/x86-v4"]
|
||||||
std = ["pulp/std"]
|
std = ["pulp/std"]
|
||||||
serde = ["dep:serde", "num-complex/serde"]
|
serde = ["dep:serde", "num-complex/serde"]
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,8 @@ Additionally, an optional 128-bit negacyclic FFT module is provided.
|
|||||||
choose the fastest one at runtime.
|
choose the fastest one at runtime.
|
||||||
- `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
- `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
||||||
[`fft128`] module.
|
[`fft128`] module.
|
||||||
- `nightly`: This enables unstable Rust features to further speed up the FFT,
|
- `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
|
||||||
by enabling AVX512F instructions on CPUs that support them. This feature
|
speed up the FFT.
|
||||||
requires a nightly Rust
|
|
||||||
toolchain.
|
|
||||||
- `serde`: This enables serialization and deserialization functions for the
|
- `serde`: This enables serialization and deserialization functions for the
|
||||||
unordered plan. These allow for data in the Fourier domain to be serialized
|
unordered plan. These allow for data in the Fourier domain to be serialized
|
||||||
from the permuted order to the standard order, and deserialized from the
|
from the permuted order to the standard order, and deserialized from the
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ pub fn bench_fft128(c: &mut Criterion) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
let bench_id = format!("tfhe-fft128-avx512-fwd-{n}");
|
let bench_id = format!("tfhe-fft128-avx512-fwd-{n}");
|
||||||
c.bench_function(&bench_id, |bench| {
|
c.bench_function(&bench_id, |bench| {
|
||||||
|
|||||||
@@ -932,7 +932,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 16 * simd.lane_count() {
|
if n >= 16 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 4 * simd.lane_count() {
|
if n >= 4 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -492,7 +492,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 8 * simd.lane_count() {
|
if n >= 8 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -871,7 +871,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 16 * simd.lane_count() {
|
if n >= 16 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 4 * simd.lane_count() {
|
if n >= 4 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ pub(crate) fn fft_impl<c64xN: Pod>(simd: impl FftSimd<c64xN>) -> crate::FftImpl
|
|||||||
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
pub fn fft_impl_dispatch(n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
if n >= 8 * simd.lane_count() {
|
if n >= 8 * simd.lane_count() {
|
||||||
return fft_impl(simd).make_fn_ptr(n);
|
return fft_impl(simd).make_fn_ptr(n);
|
||||||
|
|||||||
@@ -621,7 +621,7 @@ impl f128 {
|
|||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))]
|
#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))]
|
||||||
pub mod x86 {
|
pub mod x86 {
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
use pulp::{b8, f64x8, x86::V4};
|
use pulp::{b8, f64x8, x86::V4};
|
||||||
use pulp::{f64x4, x86::V3, Simd};
|
use pulp::{f64x4, x86::V3, Simd};
|
||||||
|
|
||||||
@@ -654,7 +654,7 @@ pub mod x86 {
|
|||||||
(p, simd.mul_sub_f64x4(a, b, p))
|
(p, simd.mul_sub_f64x4(a, b, p))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn quick_two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
pub(crate) fn quick_two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||||
let s = simd.add_f64x8(a, b);
|
let s = simd.add_f64x8(a, b);
|
||||||
@@ -662,7 +662,7 @@ pub mod x86 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
pub(crate) fn two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
pub(crate) fn two_sum_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||||
let sign_bit = simd.splat_f64x8(-0.0);
|
let sign_bit = simd.splat_f64x8(-0.0);
|
||||||
let cmp = simd.cmp_gt_f64x8(
|
let cmp = simd.cmp_gt_f64x8(
|
||||||
@@ -675,19 +675,19 @@ pub mod x86 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
pub(crate) fn two_diff_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
pub(crate) fn two_diff_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||||
two_sum_f64x8(simd, a, simd.neg_f64s(b))
|
two_sum_f64x8(simd, a, simd.neg_f64s(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn two_prod_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
pub(crate) fn two_prod_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) {
|
||||||
let p = simd.mul_f64x8(a, b);
|
let p = simd.mul_f64x8(a, b);
|
||||||
(p, simd.mul_sub_f64x8(a, b, p))
|
(p, simd.mul_sub_f64x8(a, b, p))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn quick_two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
pub(crate) fn quick_two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||||
let s = simd.add_f64x16(a, b);
|
let s = simd.add_f64x16(a, b);
|
||||||
@@ -695,7 +695,7 @@ pub mod x86 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
pub(crate) fn two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
pub(crate) fn two_sum_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||||
let sign_bit = simd.splat_f64x16(-0.0);
|
let sign_bit = simd.splat_f64x16(-0.0);
|
||||||
let cmp = simd.cmp_gt_f64x16(
|
let cmp = simd.cmp_gt_f64x16(
|
||||||
@@ -708,7 +708,7 @@ pub mod x86 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
pub(crate) fn two_diff_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
pub(crate) fn two_diff_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||||
two_sum_f64x16(
|
two_sum_f64x16(
|
||||||
simd,
|
simd,
|
||||||
@@ -720,14 +720,14 @@ pub mod x86 {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn two_prod_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
pub(crate) fn two_prod_f64x16(simd: V4, a: f64x16, b: f64x16) -> (f64x16, f64x16) {
|
||||||
let p = simd.mul_f64x16(a, b);
|
let p = simd.mul_f64x16(a, b);
|
||||||
(p, simd.mul_sub_f64x16(a, b, p))
|
(p, simd.mul_sub_f64x16(a, b, p))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct f64x16 {
|
pub struct f64x16 {
|
||||||
@@ -735,7 +735,7 @@ pub mod x86 {
|
|||||||
pub hi: f64x8,
|
pub hi: f64x8,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct b16 {
|
pub struct b16 {
|
||||||
@@ -743,9 +743,9 @@ pub mod x86 {
|
|||||||
pub hi: b8,
|
pub hi: b8,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
unsafe impl bytemuck::Zeroable for f64x16 {}
|
unsafe impl bytemuck::Zeroable for f64x16 {}
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
unsafe impl bytemuck::Pod for f64x16 {}
|
unsafe impl bytemuck::Pod for f64x16 {}
|
||||||
|
|
||||||
pub trait V3F128Ext {
|
pub trait V3F128Ext {
|
||||||
@@ -756,7 +756,7 @@ pub mod x86 {
|
|||||||
fn mul_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
|
fn mul_f128x4(self, a0: f64x4, a1: f64x4, b0: f64x4, b1: f64x4) -> (f64x4, f64x4);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
pub trait V4F128Ext {
|
pub trait V4F128Ext {
|
||||||
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
||||||
fn sub_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
fn sub_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8);
|
||||||
@@ -841,7 +841,7 @@ pub mod x86 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
impl V4F128Ext for V4 {
|
impl V4F128Ext for V4 {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8) {
|
fn add_estimate_f128x8(self, a0: f64x8, a1: f64x8, b0: f64x8, b1: f64x8) -> (f64x8, f64x8) {
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ use crate::fft128::f128_ops::x86::V3F128Ext;
|
|||||||
use pulp::{f64x4, x86::V3};
|
use pulp::{f64x4, x86::V3};
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
use crate::fft128::f128_ops::x86::{f64x16, V4F128Ext};
|
use crate::fft128::f128_ops::x86::{f64x16, V4F128Ext};
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
use pulp::{f64x8, x86::V4};
|
use pulp::{f64x8, x86::V4};
|
||||||
|
|
||||||
trait FftSimdF128: Copy {
|
trait FftSimdF128: Copy {
|
||||||
@@ -41,7 +41,7 @@ trait V3InterleaveExt {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
trait V4InterleaveExt {
|
trait V4InterleaveExt {
|
||||||
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2];
|
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2];
|
||||||
fn permute4_f64x8(self, w: [f64; 2]) -> f64x8;
|
fn permute4_f64x8(self, w: [f64; 2]) -> f64x8;
|
||||||
@@ -90,7 +90,7 @@ impl V3InterleaveExt for V3 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
impl V4InterleaveExt for V4 {
|
impl V4InterleaveExt for V4 {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2] {
|
fn interleave4_f64x8(self, z0z0z0z0z1z1z1z1: [f64x8; 2]) -> [f64x8; 2] {
|
||||||
@@ -217,7 +217,7 @@ impl FftSimdF128 for V3 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
impl FftSimdF128 for V4 {
|
impl FftSimdF128 for V4 {
|
||||||
type Reg = f64x8;
|
type Reg = f64x8;
|
||||||
|
|
||||||
@@ -246,12 +246,12 @@ impl FftSimdF128 for V4 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
pub struct V4x2(pub V4);
|
pub struct V4x2(pub V4);
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
impl FftSimdF128 for V4x2 {
|
impl FftSimdF128 for V4x2 {
|
||||||
type Reg = f64x16;
|
type Reg = f64x16;
|
||||||
|
|
||||||
@@ -664,7 +664,7 @@ pub fn negacyclic_fwd_fft_avxfma(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub fn negacyclic_fwd_fft_avx512(
|
pub fn negacyclic_fwd_fft_avx512(
|
||||||
simd: V4,
|
simd: V4,
|
||||||
@@ -1052,7 +1052,7 @@ pub fn negacyclic_fwd_fft(
|
|||||||
) {
|
) {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = V4::try_new() {
|
if let Some(simd) = V4::try_new() {
|
||||||
return negacyclic_fwd_fft_avx512(
|
return negacyclic_fwd_fft_avx512(
|
||||||
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
||||||
@@ -1084,7 +1084,7 @@ pub fn negacyclic_inv_fft(
|
|||||||
) {
|
) {
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = V4::try_new() {
|
if let Some(simd) = V4::try_new() {
|
||||||
return negacyclic_inv_fft_avx512(
|
return negacyclic_inv_fft_avx512(
|
||||||
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0,
|
||||||
@@ -1406,7 +1406,7 @@ pub fn negacyclic_inv_fft_avxfma(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub fn negacyclic_inv_fft_avx512(
|
pub fn negacyclic_inv_fft_avx512(
|
||||||
simd: V4,
|
simd: V4,
|
||||||
@@ -2304,7 +2304,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_product_avx512() {
|
fn test_product_avx512() {
|
||||||
if let Some(simd) = V4::try_new() {
|
if let Some(simd) = V4::try_new() {
|
||||||
@@ -2465,7 +2465,7 @@ mod x86_tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_interleaves_and_permutes_f64x8() {
|
fn test_interleaves_and_permutes_f64x8() {
|
||||||
if let Some(simd) = V4::try_new() {
|
if let Some(simd) = V4::try_new() {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use core::{f64, fmt::Debug, marker::PhantomData};
|
|||||||
pub struct c64x2(c64, c64);
|
pub struct c64x2(c64, c64);
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct c64x4(c64, c64, c64, c64);
|
pub struct c64x4(c64, c64, c64, c64);
|
||||||
@@ -23,7 +23,7 @@ const __ASSERT_POD: () = {
|
|||||||
|
|
||||||
// no padding
|
// no padding
|
||||||
assert!(core::mem::size_of::<c64x2>() == core::mem::size_of::<c64>() * 2);
|
assert!(core::mem::size_of::<c64x2>() == core::mem::size_of::<c64>() * 2);
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
assert!(core::mem::size_of::<c64x4>() == core::mem::size_of::<c64>() * 4);
|
assert!(core::mem::size_of::<c64x4>() == core::mem::size_of::<c64>() * 4);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ const __ASSERT_POD: () = {
|
|||||||
unsafe impl bytemuck::Zeroable for c64x2 {}
|
unsafe impl bytemuck::Zeroable for c64x2 {}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
unsafe impl bytemuck::Zeroable for c64x4 {}
|
unsafe impl bytemuck::Zeroable for c64x4 {}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
@@ -40,7 +40,7 @@ unsafe impl bytemuck::Zeroable for c64x4 {}
|
|||||||
unsafe impl bytemuck::Pod for c64x2 {}
|
unsafe impl bytemuck::Pod for c64x2 {}
|
||||||
|
|
||||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
unsafe impl bytemuck::Pod for c64x4 {}
|
unsafe impl bytemuck::Pod for c64x4 {}
|
||||||
|
|
||||||
pub trait Pod: Copy + Debug + bytemuck::Pod {}
|
pub trait Pod: Copy + Debug + bytemuck::Pod {}
|
||||||
|
|||||||
@@ -21,9 +21,8 @@
|
|||||||
//! an FFT plan that measures the various implementations to choose the fastest one at runtime.
|
//! an FFT plan that measures the various implementations to choose the fastest one at runtime.
|
||||||
//! - `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
//! - `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
|
||||||
//! `fft128` module.
|
//! `fft128` module.
|
||||||
//! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling
|
//! - `avx512` (default): This enables AVX512F instructions on CPUs that support them to further
|
||||||
//! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust
|
//! speed up the FFT.
|
||||||
//! toolchain.
|
|
||||||
//! - `serde`: This enables serialization and deserialization functions for the unordered plan.
|
//! - `serde`: This enables serialization and deserialization functions for the unordered plan.
|
||||||
//! These allow for data in the Fourier domain to be serialized from the permuted order to the
|
//! These allow for data in the Fourier domain to be serialized from the permuted order to the
|
||||||
//! standard order, and deserialized from the standard order to the permuted order.
|
//! standard order, and deserialized from the standard order to the permuted order.
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ mod tests {
|
|||||||
if let Some(simd) = pulp::x86::V3::try_new() {
|
if let Some(simd) = pulp::x86::V3::try_new() {
|
||||||
test_fft_simd(simd);
|
test_fft_simd(simd);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
test_fft_simd(simd);
|
test_fft_simd(simd);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ macro_rules! dispatcher {
|
|||||||
fn $name() -> fn(&mut [c64], &[c64]) {
|
fn $name() -> fn(&mut [c64], &[c64]) {
|
||||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if pulp::x86::V4::try_new().is_some() {
|
if pulp::x86::V4::try_new().is_some() {
|
||||||
return |z, w| {
|
return |z, w| {
|
||||||
let simd = pulp::x86::V4::try_new().unwrap();
|
let simd = pulp::x86::V4::try_new().unwrap();
|
||||||
@@ -335,7 +335,7 @@ dispatcher!(get_inv_process_x8, inv_process_x8);
|
|||||||
fn get_complex_per_reg() -> usize {
|
fn get_complex_per_reg() -> usize {
|
||||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||||
{
|
{
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
if let Some(simd) = pulp::x86::V4::try_new() {
|
if let Some(simd) = pulp::x86::V4::try_new() {
|
||||||
return simd.lane_count();
|
return simd.lane_count();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ impl FftSimd<c64x2> for V3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "nightly")]
|
#[cfg(feature = "avx512")]
|
||||||
impl FftSimd<c64x4> for V4 {
|
impl FftSimd<c64x4> for V4 {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn try_new() -> Option<Self> {
|
fn try_new() -> Option<Self> {
|
||||||
|
|||||||
@@ -17,8 +17,8 @@
|
|||||||
//! # Features
|
//! # Features
|
||||||
//!
|
//!
|
||||||
//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
|
//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
|
||||||
//! - `avx512` (default): This enables unstable Rust features to further speed up the NTT, by
|
//! - `avx512` (default): This enables AVX512 instructions on CPUs that support them to further
|
||||||
//! enabling AVX512 instructions on CPUs that support them.
|
//! speed up the NTT.
|
||||||
//!
|
//!
|
||||||
//! # Example
|
//! # Example
|
||||||
//!
|
//!
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ high-level-client-js-wasm-api = [
|
|||||||
]
|
]
|
||||||
parallel-wasm-api = ["dep:wasm-bindgen-rayon"]
|
parallel-wasm-api = ["dep:wasm-bindgen-rayon"]
|
||||||
|
|
||||||
nightly-avx512 = ["tfhe-fft/nightly", "tfhe-ntt/avx512", "pulp/x86-v4"]
|
nightly-avx512 = ["tfhe-fft/avx512", "tfhe-ntt/avx512", "pulp/x86-v4"]
|
||||||
|
|
||||||
# Private features
|
# Private features
|
||||||
__profiling = []
|
__profiling = []
|
||||||
|
|||||||
Reference in New Issue
Block a user