feat(fft): update concrete-fft to 0.2.1

This commit is contained in:
sarah el kazdadi
2023-03-27 00:34:38 +02:00
committed by Arthur Meyre
parent 475b838943
commit 10174cdac6
19 changed files with 197 additions and 420 deletions

View File

@@ -40,9 +40,9 @@ lazy_static = { version = "1.4.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
rayon = { version = "1.5.0" }
bincode = { version = "1.3.3", optional = true }
concrete-fft = { version = "0.1", features = ["serde"] }
concrete-fft = { version = "0.2.1", features = ["serde"] }
aligned-vec = { version = "0.5", features = ["serde"] }
dyn-stack = { version = "0.8" }
dyn-stack = { version = "0.9" }
once_cell = "1.13"
paste = "1.0.7"
fs2 = { version = "0.4.3", optional = true }
@@ -57,6 +57,7 @@ js-sys = { version = "0.3", optional = true }
console_error_panic_hook = { version = "0.1.7", optional = true }
serde-wasm-bindgen = { version = "0.4", optional = true }
getrandom = { version = "0.2.8", optional = true }
bytemuck = "1.13.1"
[features]
boolean = []
@@ -73,15 +74,7 @@ __c_api = ["cbindgen", "bincode"]
boolean-c-api = ["boolean", "__c_api"]
shortint-c-api = ["shortint", "__c_api"]
__wasm_api = [
"wasm-bindgen",
"js-sys",
"console_error_panic_hook",
"serde-wasm-bindgen",
"getrandom",
"getrandom/js",
"bincode",
]
__wasm_api = ["wasm-bindgen", "js-sys", "console_error_panic_hook", "serde-wasm-bindgen", "getrandom", "getrandom/js", "bincode"]
boolean-client-js-wasm-api = ["boolean", "__wasm_api"]
shortint-client-js-wasm-api = ["shortint", "__wasm_api"]

View File

@@ -10,7 +10,7 @@ use crate::core_crypto::fft_impl::crypto::ggsw::{
};
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
use concrete_fft::c64;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
/// Convert a [`GGSW ciphertext`](`GgswCiphertext`) with standard coefficients to the Fourier
/// domain.
@@ -51,7 +51,7 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized<Scalar, InputCo
input_ggsw: &GgswCiphertext<InputCont>,
output_ggsw: &mut FourierGgswCiphertext<OutputCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,

View File

@@ -10,7 +10,7 @@ use crate::core_crypto::fft_impl::crypto::bootstrap::{
};
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
use concrete_fft::c64;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
/// Convert an [`LWE bootstrap key`](`LweBootstrapKey`) with standard coefficients to the Fourier
/// domain.
@@ -47,7 +47,7 @@ pub fn convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized<Scalar, Input
input_bsk: &LweBootstrapKey<InputCont>,
output_bsk: &mut FourierLweBootstrapKey<OutputCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,

View File

@@ -7,7 +7,7 @@ use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
use concrete_fft::c64;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
/// Convert an [`LWE multi_bit bootstrap key`](`LweMultiBitBootstrapKey`) with standard
/// coefficients to the Fourier domain.
@@ -48,7 +48,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized<
input_bsk: &LweMultiBitBootstrapKey<InputCont>,
output_bsk: &mut FourierLweMultiBitBootstrapKey<OutputCont>,
fft: FftView<'_>,
mut stack: DynStack<'_>,
mut stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
@@ -67,11 +67,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized<
.zip(input_bsk_as_polynomial_list.iter())
{
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
fft.forward_as_torus(
unsafe { fourier_poly.into_uninit() },
coef_poly,
stack.rb_mut(),
);
fft.forward_as_torus(fourier_poly, coef_poly, stack.rb_mut());
}
}

View File

@@ -4,7 +4,6 @@ use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::as_mut_uninit;
use crate::core_crypto::fft_impl::crypto::bootstrap::pbs_modulus_switch;
use crate::core_crypto::fft_impl::crypto::ggsw::{
add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd,
@@ -359,8 +358,7 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
.data()
.copy_from_slice(ggsw_a_none.as_view().data());
let multi_bit_fourier_ggsw =
unsafe { as_mut_uninit(fourier_ggsw_buffer.as_mut_view().data()) };
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() {
// We already processed the first ggsw, advance the index by 1
@@ -394,20 +392,18 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
);
fft.forward_as_integer(
unsafe { fourier_a_monomial.as_mut_view().into_uninit() },
fourier_a_monomial.as_mut_view(),
a_monomial.as_view(),
buffers.stack(),
);
unsafe {
update_with_fmadd(
multi_bit_fourier_ggsw,
fourier_ggsw.as_view().data(),
fourier_a_monomial.as_view().data,
false,
lut_poly_size.to_fourier_polynomial_size().0,
);
}
update_with_fmadd(
multi_bit_fourier_ggsw,
fourier_ggsw.as_view().data(),
fourier_a_monomial.as_view().data,
false,
lut_poly_size.to_fourier_polynomial_size().0,
);
}
// Drop the lock before we wake other threads

View File

@@ -14,7 +14,7 @@ use crate::core_crypto::fft_impl::crypto::ggsw::{
use crate::core_crypto::fft_impl::crypto::wop_pbs::blind_rotate_assign_scratch;
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
use concrete_fft::c64;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
@@ -242,14 +242,14 @@ pub fn blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
}
/// Memory optimized version of [`blind_rotate_assign`], the caller must provide
/// a properly configured [`FftView`] object and a `DynStack` used as a memory buffer having a
/// a properly configured [`FftView`] object and a `PodStack` used as a memory buffer having a
/// capacity at least as large as the result of [`blind_rotate_assign_mem_optimized_requirement`].
pub fn blind_rotate_assign_mem_optimized<Scalar, InputCont, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
lut: &mut GlweCiphertext<OutputCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
@@ -309,7 +309,7 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
}
/// Memory optimized version of [`add_external_product_assign`], the caller must provide a properly
/// configured [`FftView`] object and a `DynStack` used as a memory buffer having a capacity at
/// configured [`FftView`] object and a `PodStack` used as a memory buffer having a capacity at
/// least as large as the result of [`add_external_product_assign_mem_optimized_requirement`].
///
/// Compute the external product of `ggsw` and `glwe`, and add the result to `out`.
@@ -450,7 +450,7 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
ggsw: &FourierGgswCiphertext<GgswCont>,
glwe: &GlweCiphertext<InputGlweCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
OutputGlweCont: ContainerMut<Element = Scalar>,
@@ -532,7 +532,7 @@ pub fn cmux_assign<Scalar, Cont0, Cont1, GgswCont>(
}
/// Memory optimized version of [`cmux_assign`], the caller must provide a properly configured
/// [`FftView`] object and a `DynStack` used as a memory buffer having a capacity at least as large
/// [`FftView`] object and a `PodStack` used as a memory buffer having a capacity at least as large
/// as the result of [`cmux_assign_mem_optimized_requirement`].
///
/// # Example
@@ -717,7 +717,7 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
ct1: &mut GlweCiphertext<Cont1>,
ggsw: &FourierGgswCiphertext<GgswCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
Cont0: ContainerMut<Element = Scalar>,
@@ -975,7 +975,7 @@ pub fn programmable_bootstrap_lwe_ciphertext<Scalar, InputCont, OutputCont, AccC
}
/// Memory optimized version of [`programmable_bootstrap_lwe_ciphertext`], the caller must provide
/// a properly configured [`FftView`] object and a `DynStack` used as a memory buffer having a
/// a properly configured [`FftView`] object and a `PodStack` used as a memory buffer having a
/// capacity at least as large as the result of
/// [`programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`].
pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
@@ -990,7 +990,7 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized<
accumulator: &GlweCiphertext<AccCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,

View File

@@ -13,7 +13,7 @@ use crate::core_crypto::fft_impl::crypto::wop_pbs::{
};
use crate::core_crypto::fft_impl::math::fft::FftView;
use concrete_fft::c64;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use rayon::prelude::*;
/// Allocate a new [`list of LWE private functional packing keyswitch
@@ -275,7 +275,7 @@ pub fn par_generate_circuit_bootstrap_lwe_pfpksk_list<
/// ciphertext`](`LweCiphertext`), containing the encryption of the bit scaled by q/2 (i.e., the
/// most significant bit in the plaintext representation).
///
/// The caller must provide a properly configured [`FftView`] object and a `DynStack` used as a
/// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a
/// memory buffer having a capacity at least as large as the result of
/// [`extract_bits_from_lwe_ciphertext_mem_optimized_requirement`].
///
@@ -302,7 +302,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized<
delta_log: DeltaLog,
number_of_bits_to_extract: ExtractedBitsCount,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,
@@ -378,7 +378,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
/// |[ polynomial 1 ]|...|[ polynomial 1 ]|
/// ```
///
/// The caller must provide a properly configured [`FftView`] object and a `DynStack` used as a
/// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a
/// memory buffer having a capacity at least as large as the result of
/// [`circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimized_requirement`].
///
@@ -617,7 +617,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimi
base_log_cbs: DecompositionBaseLog,
level_cbs: DecompositionLevelCount,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,

View File

@@ -1,13 +1,12 @@
//! Module containing primitives to manage computations buffers for memory optimized fft primitives.
use core::mem::MaybeUninit;
use dyn_stack::DynStack;
use dyn_stack::PodStack;
#[derive(Default)]
/// Struct containing a resizable buffer that can be used with a `DynStack` to provide memory
/// Struct containing a resizable buffer that can be used with a `PodStack` to provide memory
/// buffers for memory optimized fft primitives.
pub struct ComputationBuffers {
memory: Vec<MaybeUninit<u8>>,
memory: Vec<u8>,
}
impl ComputationBuffers {
@@ -19,12 +18,12 @@ impl ComputationBuffers {
/// Resize the underlying memory buffer, reallocating memory when capacity exceeds the current
/// buffer capacity.
pub fn resize(&mut self, capacity: usize) {
self.memory.resize_with(capacity, MaybeUninit::uninit);
self.memory.resize(capacity, 0);
}
/// Return a `DynStack` borrowing from the managed memory buffer for use with optimized fft
/// primitives or other functions using `DynStack` to manage temporary memory.
pub fn stack(&mut self) -> DynStack<'_> {
DynStack::new(&mut self.memory)
/// Return a `PodStack` borrowing from the managed memory buffer for use with optimized fft
/// primitives or other functions using `PodStack` to manage temporary memory.
pub fn stack(&mut self) -> PodStack<'_> {
PodStack::new(&mut self.memory)
}
}

View File

@@ -21,7 +21,7 @@ mod signed;
mod unsigned;
/// A trait implemented by any generic numeric type suitable for computations.
pub trait Numeric: Sized + Copy + PartialEq + PartialOrd + 'static {
pub trait Numeric: Sized + Copy + PartialEq + PartialOrd + bytemuck::Pod + 'static {
/// The size of the type in bits.
const BITS: usize;

View File

@@ -15,7 +15,7 @@ use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use concrete_fft::c64;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "C: IntoContainerOwned"))]
@@ -179,7 +179,7 @@ impl<'a> FourierLweBootstrapKeyMutView<'a> {
mut self,
coef_bsk: LweBootstrapKey<&'_ [Scalar]>,
fft: FftView<'_>,
mut stack: DynStack<'_>,
mut stack: PodStack<'_>,
) {
for (fourier_ggsw, standard_ggsw) in
izip!(self.as_mut_view().into_ggsw_iter(), coef_bsk.iter())
@@ -217,7 +217,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
mut lut: GlweCiphertextMutView<'_, Scalar>,
lwe: &[Scalar],
fft: FftView<'_>,
mut stack: DynStack<'_>,
mut stack: PodStack<'_>,
) {
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();
@@ -277,7 +277,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
lwe_in: &[Scalar],
accumulator: GlweCiphertextView<'_, Scalar>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize>,

View File

@@ -1,9 +1,6 @@
use core::mem::MaybeUninit;
use super::super::math::decomposition::TensorSignedDecompositionLendingIter;
use super::super::math::fft::{FftView, FourierPolynomialList};
use super::super::math::polynomial::{FourierPolynomialUninitMutView, FourierPolynomialView};
use super::super::{as_mut_uninit, assume_init_mut};
use super::super::math::polynomial::{FourierPolynomialMutView, FourierPolynomialView};
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
@@ -16,7 +13,7 @@ use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use concrete_fft::c64;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
@@ -261,7 +258,7 @@ impl<'a> FourierGgswCiphertextMutView<'a> {
self,
coef_ggsw: GgswCiphertextView<'_, Scalar>,
fft: FftView<'_>,
mut stack: DynStack<'_>,
mut stack: PodStack<'_>,
) {
debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size());
let fourier_poly_size = coef_ggsw.polynomial_size().to_fourier_polynomial_size().0;
@@ -270,11 +267,8 @@ impl<'a> FourierGgswCiphertextMutView<'a> {
self.data().into_chunks(fourier_poly_size),
coef_ggsw.as_polynomial_list().iter()
) {
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
fft.forward_as_torus(
FourierPolynomialUninitMutView {
data: unsafe { as_mut_uninit(fourier_poly) },
},
FourierPolynomialMutView { data: fourier_poly },
coef_poly,
stack.rb_mut(),
);
@@ -342,7 +336,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
ggsw: FourierGgswCiphertextView<'_>,
glwe: GlweCiphertext<InputGlweCont>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) where
Scalar: UnsignedTorus,
InputGlweCont: Container<Element = Scalar>,
@@ -364,7 +358,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
);
let (mut output_fft_buffer, mut substack0) =
stack.make_aligned_uninit::<c64>(fourier_poly_size * ggsw.glwe_size().0, align);
stack.make_aligned_raw::<c64>(fourier_poly_size * ggsw.glwe_size().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
@@ -412,11 +406,11 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
.for_each(|(ggsw_row, glwe_poly)| {
let (mut fourier, substack3) = substack2
.rb_mut()
.make_aligned_uninit::<c64>(fourier_poly_size, align);
.make_aligned_raw::<c64>(fourier_poly_size, align);
// We perform the forward fft transform for the glwe polynomial
let fourier = fft
.forward_as_integer(
FourierPolynomialUninitMutView { data: &mut fourier },
FourierPolynomialMutView { data: &mut fourier },
glwe_poly,
substack3,
)
@@ -424,16 +418,13 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.
// SAFETY: see comment above definition of `output_fft_buffer`
unsafe {
update_with_fmadd(
output_fft_buffer,
ggsw_row.data(),
fourier,
is_output_uninit,
fourier_poly_size,
)
};
update_with_fmadd(
output_fft_buffer,
ggsw_row.data(),
fourier,
is_output_uninit,
fourier_poly_size,
);
// we initialized `output_fft_buffer`, so we can set this to false
is_output_uninit = false;
@@ -447,8 +438,6 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
//
// We iterate over the polynomials in the output.
if !is_output_uninit {
// SAFETY: output_fft_buffer is initialized, since `is_output_uninit` is false
let output_fft_buffer = &*unsafe { assume_init_mut(output_fft_buffer) };
izip!(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer
@@ -464,25 +453,20 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
#[cfg_attr(__profiling, inline(never))]
fn collect_next_term<'a, Scalar: UnsignedTorus>(
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
substack1: &'a mut DynStack,
substack1: &'a mut PodStack,
align: usize,
) -> (
DecompositionLevel,
dyn_stack::DynArray<'a, Scalar>,
DynStack<'a>,
PodStack<'a>,
) {
let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap();
let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
(glwe_level, glwe_decomp_term, substack2)
}
/// # Note
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
/// - `is_x86_feature_detected!("avx512f")` must be true.
#[cfg(all(
feature = "nightly-avx512",
@@ -490,7 +474,7 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>(
))]
#[target_feature(enable = "avx512f")]
unsafe fn update_with_fmadd_avx512(
output_fourier: &mut [MaybeUninit<c64>],
output_fourier: &mut [c64],
ggsw_poly: &[c64],
fourier: &[c64],
is_output_uninit: bool,
@@ -540,18 +524,13 @@ unsafe fn update_with_fmadd_avx512(
}
}
/// # Note
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
/// - `is_x86_feature_detected!("fma")` must be true.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "fma")]
unsafe fn update_with_fmadd_fma(
output_fourier: &mut [MaybeUninit<c64>],
output_fourier: &mut [c64],
ggsw_poly: &[c64],
fourier: &[c64],
is_output_uninit: bool,
@@ -601,15 +580,11 @@ unsafe fn update_with_fmadd_fma(
}
}
/// # Note
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
unsafe fn update_with_fmadd_scalar(
output_fourier: &mut [MaybeUninit<c64>],
output_fourier: &mut [c64],
ggsw_poly: &[c64],
fourier: &[c64],
is_output_uninit: bool,
@@ -618,34 +593,27 @@ unsafe fn update_with_fmadd_scalar(
// we're writing to output_fft_buffer for the first time
// so its contents are uninitialized
izip!(output_fourier, ggsw_poly, fourier).for_each(|(out_fourier, lhs, rhs)| {
out_fourier.write(lhs * rhs);
*out_fourier = lhs * rhs;
});
} else {
// we already wrote to output_fft_buffer, so we can assume its contents are
// initialized.
izip!(output_fourier, ggsw_poly, fourier).for_each(|(out_fourier, lhs, rhs)| {
*{ out_fourier.assume_init_mut() } += lhs * rhs;
*out_fourier += lhs * rhs;
});
}
}
/// # Note
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
#[cfg_attr(__profiling, inline(never))]
pub(crate) unsafe fn update_with_fmadd(
output_fft_buffer: &mut [MaybeUninit<c64>],
pub(crate) fn update_with_fmadd(
output_fft_buffer: &mut [c64],
lhs_polynomial_list: &[c64],
fourier: &[c64],
is_output_uninit: bool,
fourier_poly_size: usize,
) {
#[allow(clippy::type_complexity)]
let ptr_fn = || -> unsafe fn(&mut [MaybeUninit<c64>], &[c64], &[c64], bool) {
let ptr_fn = || -> unsafe fn(&mut [c64], &[c64], &[c64], bool) {
#[cfg(all(
feature = "nightly-avx512",
any(target_arch = "x86_64", target_arch = "x86")
@@ -667,8 +635,8 @@ pub(crate) unsafe fn update_with_fmadd(
output_fft_buffer.into_chunks(fourier_poly_size),
lhs_polynomial_list.into_chunks(fourier_poly_size)
)
.for_each(|(output_fourier, poly_from_list)| {
ptr(output_fourier, poly_from_list, fourier, is_output_uninit);
.for_each(|(output_fourier, ggsw_poly)| {
unsafe { ptr(output_fourier, ggsw_poly, fourier, is_output_uninit) };
});
}
@@ -687,7 +655,7 @@ pub fn cmux<Scalar: UnsignedTorus>(
mut ct1: GlweCiphertextMutView<'_, Scalar>,
ggsw: FourierGgswCiphertextView<'_>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
izip!(ct1.as_mut(), ct0.as_ref(),).for_each(|(c1, c0)| {
*c1 = c1.wrapping_sub(*c0);

View File

@@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)]
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
use super::super::math::fft::{FftView, FourierPolynomialList};
use super::bootstrap::{bootstrap_scratch, FourierLweBootstrapKeyView};
@@ -64,7 +64,7 @@ pub fn extract_bits<Scalar: UnsignedTorus + CastInto<usize>>(
delta_log: DeltaLog,
number_of_bits_to_extract: ExtractedBitsCount,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
let ciphertext_n_bits = Scalar::BITS;
let number_of_bits_to_extract = number_of_bits_to_extract.0;
@@ -217,7 +217,7 @@ pub fn circuit_bootstrap_boolean<Scalar: UnsignedTorus + CastInto<usize>>(
delta_log: DeltaLog,
pfpksk_list: LwePrivateFunctionalPackingKeyswitchKeyList<&[Scalar]>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
let level_cbs = ggsw_out.decomposition_level_count();
let base_log_cbs = ggsw_out.decomposition_base_log();
@@ -337,7 +337,7 @@ pub fn homomorphic_shift_boolean<Scalar: UnsignedTorus + CastInto<usize>>(
base_log_cbs: DecompositionBaseLog,
delta_log: DeltaLog,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
let ciphertext_n_bits = Scalar::BITS;
let lwe_in_size = lwe_in.lwe_size();
@@ -644,7 +644,7 @@ pub fn cmux_tree_memory_optimized<Scalar: UnsignedTorus + CastInto<usize>>(
lut_per_layer: PolynomialList<&[Scalar]>,
ggsw_list: FourierGgswCiphertextListView<'_>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
debug_assert!(lut_per_layer.polynomial_count().0 == 1 << ggsw_list.count());
@@ -818,7 +818,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing<Scalar: UnsignedTorus + CastIn
level_cbs: DecompositionLevelCount,
base_log_cbs: DecompositionBaseLog,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
debug_assert!(stack.can_hold(
circuit_bootstrap_boolean_vertical_packing_scratch::<Scalar>(
@@ -945,7 +945,7 @@ pub fn vertical_packing<Scalar: UnsignedTorus + CastInto<usize>>(
mut lwe_out: LweCiphertext<&mut [Scalar]>,
ggsw_list: FourierGgswCiphertextListView<'_>,
fft: FftView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
let polynomial_size = ggsw_list.polynomial_size();
let glwe_size = ggsw_list.glwe_size();
@@ -1015,7 +1015,7 @@ pub fn blind_rotate_assign<Scalar: UnsignedTorus + CastInto<usize>>(
mut lut: GlweCiphertext<&mut [Scalar]>,
ggsw_list: FourierGgswCiphertextListView<'_>,
fft: FftView<'_>,
mut stack: DynStack<'_>,
mut stack: PodStack<'_>,
) {
let mut monomial_degree = MonomialDegree(1);

View File

@@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::math::fft::Fft;
use crate::core_crypto::seeders::new_seeder;
use concrete_csprng::generators::SoftwareRandomGenerator;
use concrete_fft::c64;
use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut, StackReq};
use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut, StackReq};
// Extract all the bits of a LWE
#[test]
@@ -96,8 +96,8 @@ pub fn test_extract_bits() {
])
};
let req = req().unwrap();
let mut mem = GlobalMemBuffer::new(req);
let mut stack = DynStack::new(&mut mem);
let mut mem = GlobalPodBuffer::new(req);
let mut stack = PodStack::new(&mut mem);
fourier_bsk
.as_mut_view()
@@ -225,8 +225,8 @@ fn test_circuit_bootstrapping_binary() {
let fft = Fft::new(polynomial_size);
let fft = fft.as_view();
let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let stack = DynStack::new(&mut mem);
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let stack = PodStack::new(&mut mem);
fourier_bsk
.as_mut_view()
.fill_with_forward_fourier(std_bsk.as_view(), fft, stack);
@@ -269,7 +269,7 @@ fn test_circuit_bootstrapping_binary() {
level_count_cbs,
);
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
circuit_bootstrap_boolean_scratch::<u64>(
lwe_in.lwe_size(),
fourier_bsk.output_lwe_dimension().to_lwe_size(),
@@ -279,7 +279,7 @@ fn test_circuit_bootstrapping_binary() {
)
.unwrap(),
);
let stack = DynStack::new(&mut mem);
let stack = PodStack::new(&mut mem);
// Execute the CBS
circuit_bootstrap_boolean(
fourier_bsk.as_view(),
@@ -464,15 +464,15 @@ pub fn test_cmux_tree() {
&mut encryption_generator,
);
let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let stack = DynStack::new(&mut mem);
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let stack = PodStack::new(&mut mem);
fourier_ggsw
.as_mut_view()
.fill_with_forward_fourier(ggsw.as_view(), fft, stack);
}
let mut result_cmux_tree = GlweCiphertextOwned::new(0_u64, glwe_size, polynomial_size);
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
cmux_tree_memory_optimized_scratch::<u64>(glwe_size, polynomial_size, nb_ggsw, fft)
.unwrap(),
);
@@ -481,7 +481,7 @@ pub fn test_cmux_tree() {
lut.as_view(),
ggsw_list.as_view(),
fft,
DynStack::new(&mut mem),
PodStack::new(&mut mem),
);
let mut decrypted_result =
PlaintextListOwned::new(0u64, PlaintextCount(glwe_sk.polynomial_size().0));
@@ -568,11 +568,11 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
let fft = Fft::new(polynomial_size);
let fft = fft.as_view();
let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
fourier_bsk.as_mut_view().fill_with_forward_fourier(
std_bsk.as_view(),
fft,
DynStack::new(&mut mem),
PodStack::new(&mut mem),
);
let ksk_lwe_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
@@ -630,7 +630,7 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
LweCiphertextCount(number_of_values_to_extract.0),
);
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
extract_bits_scratch::<u64>(
lwe_dimension,
ksk_lwe_big_to_small.output_key_lwe_dimension(),
@@ -648,7 +648,7 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
delta_log,
number_of_values_to_extract,
fft,
DynStack::new(&mut mem),
PodStack::new(&mut mem),
);
// Decrypt all extracted bit for checking purposes in case of problems
@@ -700,7 +700,7 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
);
// Perform circuit bootstrap + vertical packing
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
extracted_bits_lwe_list.lwe_ciphertext_count(),
vertical_packing_lwe_list_out.lwe_ciphertext_count(),
@@ -723,7 +723,7 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
level_cbs,
base_log_cbs,
fft,
DynStack::new(&mut mem),
PodStack::new(&mut mem),
);
// We have a single output ct

View File

@@ -1,7 +1,7 @@
pub use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use dyn_stack::{DynArray, DynStack};
use dyn_stack::{DynArray, PodStack};
use std::iter::Map;
use std::slice::IterMut;
@@ -28,8 +28,8 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu
input: impl Iterator<Item = Scalar>,
base_log: DecompositionBaseLog,
level: DecompositionLevelCount,
stack: DynStack<'buffers>,
) -> (Self, DynStack<'buffers>) {
stack: PodStack<'buffers>,
) -> (Self, PodStack<'buffers>) {
let shift = Scalar::BITS - base_log.0 * level.0;
let (states, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift));

View File

@@ -1,8 +1,4 @@
use super::super::assume_init_mut;
use super::polynomial::{
FourierPolynomialMutView, FourierPolynomialUninitMutView, FourierPolynomialView,
PolynomialUninitMutView,
};
use super::polynomial::{FourierPolynomialMutView, FourierPolynomialView};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize};
@@ -12,12 +8,12 @@ use crate::core_crypto::entities::*;
use aligned_vec::{avec, ABox};
use concrete_fft::c64;
use concrete_fft::unordered::{Method, Plan};
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use once_cell::sync::OnceCell;
use std::any::TypeId;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::mem::{align_of, size_of, MaybeUninit};
use std::mem::{align_of, size_of};
use std::sync::{Arc, RwLock};
use std::time::Duration;
@@ -182,7 +178,7 @@ impl Fft {
#[cfg_attr(__profiling, inline(never))]
fn convert_forward_torus<Scalar: UnsignedTorus>(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[Scalar],
in_im: &[Scalar],
twisties: TwistiesView<'_>,
@@ -193,21 +189,19 @@ fn convert_forward_torus<Scalar: UnsignedTorus>(
|(out, in_re, in_im, w_re, w_im)| {
let in_re: f64 = in_re.into_signed().cast_into() * normalization;
let in_im: f64 = in_im.into_signed().cast_into() * normalization;
out.write(
c64 {
re: in_re,
im: in_im,
} * c64 {
re: *w_re,
im: *w_im,
},
);
*out = c64 {
re: in_re,
im: in_im,
} * c64 {
re: *w_re,
im: *w_im,
};
},
);
}
fn convert_forward_integer_scalar<Scalar: UnsignedTorus>(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[Scalar],
in_im: &[Scalar],
twisties: TwistiesView<'_>,
@@ -216,22 +210,20 @@ fn convert_forward_integer_scalar<Scalar: UnsignedTorus>(
|(out, in_re, in_im, w_re, w_im)| {
let in_re: f64 = in_re.into_signed().cast_into();
let in_im: f64 = in_im.into_signed().cast_into();
out.write(
c64 {
re: in_re,
im: in_im,
} * c64 {
re: *w_re,
im: *w_im,
},
);
*out = c64 {
re: in_re,
im: in_im,
} * c64 {
re: *w_re,
im: *w_im,
};
},
);
}
#[cfg_attr(__profiling, inline(never))]
fn convert_forward_integer<Scalar: UnsignedTorus>(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[Scalar],
in_im: &[Scalar],
twisties: TwistiesView<'_>,
@@ -247,15 +239,14 @@ fn convert_forward_integer<Scalar: UnsignedTorus>(
}
}
// SAFETY: same as above
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
convert_forward_integer_scalar::<Scalar>(out, in_re, in_im, twisties)
}
#[cfg_attr(__profiling, inline(never))]
fn convert_backward_torus<Scalar: UnsignedTorus>(
out_re: &mut [MaybeUninit<Scalar>],
out_im: &mut [MaybeUninit<Scalar>],
out_re: &mut [Scalar],
out_im: &mut [Scalar],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -268,20 +259,15 @@ fn convert_backward_torus<Scalar: UnsignedTorus>(
im: -*w_im,
} * normalization);
out_re.write(Scalar::from_torus(tmp.re));
out_im.write(Scalar::from_torus(tmp.im));
*out_re = Scalar::from_torus(tmp.re);
*out_im = Scalar::from_torus(tmp.im);
},
);
}
/// See [`convert_add_backward_torus`].
///
/// # Safety
///
/// - Same preconditions as [`convert_add_backward_torus`].
unsafe fn convert_add_backward_torus_scalar<Scalar: UnsignedTorus>(
out_re: &mut [MaybeUninit<Scalar>],
out_im: &mut [MaybeUninit<Scalar>],
fn convert_add_backward_torus_scalar<Scalar: UnsignedTorus>(
out_re: &mut [Scalar],
out_im: &mut [Scalar],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -294,27 +280,16 @@ unsafe fn convert_add_backward_torus_scalar<Scalar: UnsignedTorus>(
im: -*w_im,
} * normalization);
let out_re = out_re.assume_init_mut();
let out_im = out_im.assume_init_mut();
*out_re = Scalar::wrapping_add(*out_re, Scalar::from_torus(tmp.re));
*out_im = Scalar::wrapping_add(*out_im, Scalar::from_torus(tmp.im));
},
);
}
/// # Warning
///
/// This function is actually unsafe, but can't be marked as such since we need it to implement
/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait.
///
/// # Safety
///
/// - `out_re` and `out_im` must not hold any uninitialized values.
#[cfg_attr(__profiling, inline(never))]
fn convert_add_backward_torus<Scalar: UnsignedTorus>(
out_re: &mut [MaybeUninit<Scalar>],
out_im: &mut [MaybeUninit<Scalar>],
out_re: &mut [Scalar],
out_im: &mut [Scalar],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -329,11 +304,8 @@ fn convert_add_backward_torus<Scalar: UnsignedTorus>(
}
}
// SAFETY: same as above
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
unsafe {
convert_add_backward_torus_scalar::<Scalar>(out_re, out_im, inp, twisties)
};
convert_add_backward_torus_scalar::<Scalar>(out_re, out_im, inp, twisties);
}
impl<'a> FftView<'a> {
@@ -388,12 +360,11 @@ impl<'a> FftView<'a> {
/// have size equal to that amount divided by two.
pub fn forward_as_torus<'out, Scalar: UnsignedTorus>(
self,
fourier: FourierPolynomialUninitMutView<'out>,
fourier: FourierPolynomialMutView<'out>,
standard: PolynomialView<'_, Scalar>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) -> FourierPolynomialMutView<'out> {
// SAFETY: `convert_forward_torus` initializes the output slice that is passed to it
unsafe { self.forward_with_conv(fourier, standard, convert_forward_torus, stack) }
self.forward_with_conv(fourier, standard, convert_forward_torus, stack)
}
/// Perform a negacyclic real FFT of `standard`, viewed as integers, and stores the result in
@@ -409,12 +380,11 @@ impl<'a> FftView<'a> {
/// have size equal to that amount divided by two.
pub fn forward_as_integer<'out, Scalar: UnsignedTorus>(
self,
fourier: FourierPolynomialUninitMutView<'out>,
fourier: FourierPolynomialMutView<'out>,
standard: PolynomialView<'_, Scalar>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) -> FourierPolynomialMutView<'out> {
// SAFETY: `convert_forward_integer` initializes the output slice that is passed to it
unsafe { self.forward_with_conv(fourier, standard, convert_forward_integer, stack) }
self.forward_with_conv(fourier, standard, convert_forward_integer, stack)
}
/// Perform an inverse negacyclic real FFT of `fourier` and stores the result in `standard`,
@@ -429,12 +399,11 @@ impl<'a> FftView<'a> {
/// See [`Self::forward_as_torus`]
pub fn backward_as_torus<Scalar: UnsignedTorus>(
self,
standard: PolynomialUninitMutView<'_, Scalar>,
standard: PolynomialMutView<'_, Scalar>,
fourier: FourierPolynomialView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
// SAFETY: `convert_backward_torus` initializes the output slices that are passed to it
unsafe { self.backward_with_conv(standard, fourier, convert_backward_torus, stack) }
self.backward_with_conv(standard, fourier, convert_backward_torus, stack)
}
/// Perform an inverse negacyclic real FFT of `fourier` and adds the result to `standard`,
@@ -451,32 +420,21 @@ impl<'a> FftView<'a> {
self,
standard: PolynomialMutView<'_, Scalar>,
fourier: FourierPolynomialView<'_>,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
// SAFETY: `convert_add_backward_torus` initializes the output slices that are passed to it
unsafe {
self.backward_with_conv(
standard.into_uninit(),
fourier,
convert_add_backward_torus,
stack,
)
}
self.backward_with_conv(standard, fourier, convert_add_backward_torus, stack)
}
/// # Safety
///
/// `conv_fn` must initialize the entirety of the mutable slice that it receives.
unsafe fn forward_with_conv<
fn forward_with_conv<
'out,
Scalar: UnsignedTorus,
F: Fn(&mut [MaybeUninit<c64>], &[Scalar], &[Scalar], TwistiesView<'_>),
F: Fn(&mut [c64], &[Scalar], &[Scalar], TwistiesView<'_>),
>(
self,
fourier: FourierPolynomialUninitMutView<'out>,
fourier: FourierPolynomialMutView<'out>,
standard: PolynomialView<'_, Scalar>,
conv_fn: F,
stack: DynStack<'_>,
stack: PodStack<'_>,
) -> FourierPolynomialMutView<'out> {
let fourier = fourier.data;
let standard = standard.as_ref();
@@ -484,23 +442,19 @@ impl<'a> FftView<'a> {
debug_assert_eq!(n, 2 * fourier.len());
let (standard_re, standard_im) = standard.split_at(n / 2);
conv_fn(fourier, standard_re, standard_im, self.twisties);
let fourier = assume_init_mut(fourier);
self.plan.fwd(fourier, stack);
FourierPolynomialMutView { data: fourier }
}
/// # Safety
///
/// `conv_fn` must initialize the entirety of the mutable slices that it receives.
unsafe fn backward_with_conv<
fn backward_with_conv<
Scalar: UnsignedTorus,
F: Fn(&mut [MaybeUninit<Scalar>], &mut [MaybeUninit<Scalar>], &[c64], TwistiesView<'_>),
F: Fn(&mut [Scalar], &mut [Scalar], &[c64], TwistiesView<'_>),
>(
self,
mut standard: PolynomialUninitMutView<'_, Scalar>,
mut standard: PolynomialMutView<'_, Scalar>,
fourier: FourierPolynomialView<'_>,
conv_fn: F,
stack: DynStack<'_>,
stack: PodStack<'_>,
) {
let fourier = fourier.data;
let standard = standard.as_mut();

View File

@@ -1,4 +1,4 @@
use dyn_stack::{GlobalMemBuffer, ReborrowMut};
use dyn_stack::{GlobalPodBuffer, ReborrowMut};
use super::super::polynomial::FourierPolynomial;
use super::*;
@@ -33,23 +33,15 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
*x = generator.random_uniform();
}
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
fft.forward_scratch()
.unwrap()
.and(fft.backward_scratch().unwrap()),
);
let mut stack = DynStack::new(&mut mem);
let mut stack = PodStack::new(&mut mem);
fft.forward_as_torus(
unsafe { fourier.as_mut_view().into_uninit() },
poly.as_view(),
stack.rb_mut(),
);
fft.backward_as_torus(
unsafe { roundtrip.as_mut_view().into_uninit() },
fourier.as_view(),
stack.rb_mut(),
);
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut());
for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
if Scalar::BITS == 32 {
@@ -118,33 +110,22 @@ fn test_product<Scalar: UnsignedTorus>() {
}
}
let mut mem = GlobalMemBuffer::new(
let mut mem = GlobalPodBuffer::new(
fft.forward_scratch()
.unwrap()
.and(fft.backward_scratch().unwrap()),
);
let mut stack = DynStack::new(&mut mem);
let mut stack = PodStack::new(&mut mem);
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
fft.forward_as_torus(
unsafe { fourier0.as_mut_view().into_uninit() },
poly0.as_view(),
stack.rb_mut(),
);
// SAFETY: forward_as_integer doesn't write any uninitialized values into its output
fft.forward_as_integer(
unsafe { fourier1.as_mut_view().into_uninit() },
poly1.as_view(),
stack.rb_mut(),
);
fft.forward_as_torus(fourier0.as_mut_view(), poly0.as_view(), stack.rb_mut());
fft.forward_as_integer(fourier1.as_mut_view(), poly1.as_view(), stack.rb_mut());
for (f0, f1) in izip!(&mut *fourier0.data, &*fourier1.data) {
*f0 *= *f1;
}
// SAFETY: backward_as_torus doesn't write any uninitialized values into its output
fft.backward_as_torus(
unsafe { convolution_from_fft.as_mut_view().into_uninit() },
convolution_from_fft.as_mut_view(),
fourier0.as_view(),
stack.rb_mut(),
);

View File

@@ -16,7 +16,6 @@ use core::arch::x86_64::*;
use super::super::super::c64;
use super::TwistiesView;
use std::mem::MaybeUninit;
/// Convert a vector of f64 values to a vector of i64 values.
/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version.
@@ -196,7 +195,7 @@ pub unsafe fn mm512_cvtepi64_pd(x: __m512i) -> __m512d {
#[cfg(feature = "nightly-avx512")]
#[target_feature(enable = "avx512f")]
pub unsafe fn convert_forward_integer_u32_avx512f(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u32],
in_im: &[u32],
twisties: TwistiesView<'_>,
@@ -265,7 +264,7 @@ pub unsafe fn convert_forward_integer_u32_avx512f(
#[cfg(feature = "nightly-avx512")]
#[target_feature(enable = "avx512f,avx512dq")]
pub unsafe fn convert_forward_integer_u64_avx512f_avx512dq(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u64],
in_im: &[u64],
twisties: TwistiesView<'_>,
@@ -332,7 +331,7 @@ pub unsafe fn convert_forward_integer_u64_avx512f_avx512dq(
/// - `is_x86_feature_detected!("fma")` must be true.
#[target_feature(enable = "avx,fma")]
pub unsafe fn convert_forward_integer_u32_fma(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u32],
in_im: &[u32],
twisties: TwistiesView<'_>,
@@ -399,7 +398,7 @@ pub unsafe fn convert_forward_integer_u32_fma(
/// - `is_x86_feature_detected!("fma")` must be true.
#[target_feature(enable = "avx,avx2,fma")]
pub unsafe fn convert_forward_integer_u64_avx2_fma(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u64],
in_im: &[u64],
twisties: TwistiesView<'_>,
@@ -523,13 +522,12 @@ pub unsafe fn convert_torus_prologue_avx512f(
///
/// # Safety
///
/// - Same preconditions as [`convert_add_backward_torus`].
/// - `is_x86_feature_detected!("avx512f")` must be true.
#[cfg(feature = "nightly-avx512")]
#[target_feature(enable = "avx512f")]
pub unsafe fn convert_add_backward_torus_u32_avx512f(
out_re: &mut [MaybeUninit<u32>],
out_im: &mut [MaybeUninit<u32>],
out_re: &mut [u32],
out_im: &mut [u32],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -577,13 +575,12 @@ pub unsafe fn convert_add_backward_torus_u32_avx512f(
///
/// # Safety
///
/// - Same preconditions as [`convert_add_backward_torus`].
/// - `is_x86_feature_detected!("avx512f")` must be true.
#[cfg(feature = "nightly-avx512")]
#[target_feature(enable = "avx512f")]
pub unsafe fn convert_add_backward_torus_u64_avx512f(
out_re: &mut [MaybeUninit<u64>],
out_im: &mut [MaybeUninit<u64>],
out_re: &mut [u64],
out_im: &mut [u64],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -693,12 +690,11 @@ pub unsafe fn convert_torus_prologue_fma(
///
/// # Safety
///
/// - Same preconditions as [`convert_add_backward_torus`].
/// - `is_x86_feature_detected!("fma")` must be true.
#[target_feature(enable = "avx,fma")]
pub unsafe fn convert_add_backward_torus_u32_fma(
out_re: &mut [MaybeUninit<u32>],
out_im: &mut [MaybeUninit<u32>],
out_re: &mut [u32],
out_im: &mut [u32],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -742,17 +738,14 @@ pub unsafe fn convert_add_backward_torus_u32_fma(
}
}
/// See [`convert_add_backward_torus`].
///
/// # Safety
///
/// - Same preconditions as [`convert_add_backward_torus`].
/// - `is_x86_feature_detected!("avx2")` must be true.
/// - `is_x86_feature_detected!("fma")` must be true.
#[target_feature(enable = "avx2,fma")]
pub unsafe fn convert_add_backward_torus_u64_fma(
out_re: &mut [MaybeUninit<u64>],
out_im: &mut [MaybeUninit<u64>],
out_re: &mut [u64],
out_im: &mut [u64],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
@@ -797,14 +790,14 @@ pub unsafe fn convert_add_backward_torus_u64_fma(
}
pub fn convert_forward_integer_u32(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u32],
in_im: &[u32],
twisties: TwistiesView<'_>,
) {
// this is a function that returns a function pointer to the right simd function
#[allow(clippy::type_complexity)]
let ptr_fn = || -> unsafe fn(&mut [MaybeUninit<c64>], &[u32], &[u32], TwistiesView<'_>) {
let ptr_fn = || -> unsafe fn(&mut [c64], &[u32], &[u32], TwistiesView<'_>) {
#[cfg(feature = "nightly-avx512")]
if is_x86_feature_detected!("avx512f") {
return convert_forward_integer_u32_avx512f;
@@ -819,21 +812,19 @@ pub fn convert_forward_integer_u32(
// we call it to get the function pointer to the right simd function
let ptr = ptr_fn();
// SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im`
// do not hold any uninitialized values since that is a precondition of calling this
// function
// SAFETY: the target x86 feature availability was checked
unsafe { ptr(out, in_re, in_im, twisties) }
}
pub fn convert_forward_integer_u64(
out: &mut [MaybeUninit<c64>],
out: &mut [c64],
in_re: &[u64],
in_im: &[u64],
twisties: TwistiesView<'_>,
) {
#[allow(clippy::type_complexity)]
// this is a function that returns a function pointer to the right simd function
let ptr_fn = || -> unsafe fn(&mut [MaybeUninit<c64>], &[u64], &[u64], TwistiesView<'_>) {
let ptr_fn = || -> unsafe fn(&mut [c64], &[u64], &[u64], TwistiesView<'_>) {
#[cfg(feature = "nightly-avx512")]
if is_x86_feature_detected!("avx512f") & is_x86_feature_detected!("avx512dq") {
return convert_forward_integer_u64_avx512f_avx512dq;
@@ -848,34 +839,19 @@ pub fn convert_forward_integer_u64(
// we call it to get the function pointer to the right simd function
let ptr = ptr_fn();
// SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im`
// do not hold any uninitialized values since that is a precondition of calling this
// function
// SAFETY: the target x86 feature availability was checked
unsafe { ptr(out, in_re, in_im, twisties) }
}
/// # Warning
///
/// This function is actually unsafe, but can't be marked as such since we need it to implement
/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait.
///
/// # Safety
///
/// - `out_re` and `out_im` must not hold any uninitialized values.
pub fn convert_add_backward_torus_u32(
out_re: &mut [MaybeUninit<u32>],
out_im: &mut [MaybeUninit<u32>],
out_re: &mut [u32],
out_im: &mut [u32],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
// this is a function that returns a function pointer to the right simd function
#[allow(clippy::type_complexity)]
let ptr_fn = || -> unsafe fn (
&mut [MaybeUninit<u32>],
&mut [MaybeUninit<u32>],
&[c64],
TwistiesView<'_>,
) {
let ptr_fn = || -> unsafe fn(&mut [u32], &mut [u32], &[c64], TwistiesView<'_>) {
#[cfg(feature = "nightly-avx512")]
if is_x86_feature_detected!("avx512f") {
return convert_add_backward_torus_u32_avx512f;
@@ -890,34 +866,19 @@ pub fn convert_add_backward_torus_u32(
// we call it to get the function pointer to the right simd function
let ptr = ptr_fn();
// SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im`
// do not hold any uninitialized values since that is a precondition of calling this
// function
// SAFETY: the target x86 feature availability was checked
unsafe { ptr(out_re, out_im, inp, twisties) }
}
/// # Warning
///
/// This function is actually unsafe, but can't be marked as such since we need it to implement
/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait.
///
/// # Safety
///
/// - `out_re` and `out_im` must not hold any uninitialized values.
pub fn convert_add_backward_torus_u64(
out_re: &mut [MaybeUninit<u64>],
out_im: &mut [MaybeUninit<u64>],
out_re: &mut [u64],
out_im: &mut [u64],
inp: &[c64],
twisties: TwistiesView<'_>,
) {
// this is a function that returns a function pointer to the right simd function
#[allow(clippy::type_complexity)]
let ptr_fn = || -> unsafe fn (
&mut [MaybeUninit<u64>],
&mut [MaybeUninit<u64>],
&[c64],
TwistiesView<'_>,
) {
let ptr_fn = || -> unsafe fn(&mut [u64], &mut [u64], &[c64], TwistiesView<'_>) {
#[cfg(feature = "nightly-avx512")]
if is_x86_feature_detected!("avx512f") {
return convert_add_backward_torus_u64_avx512f;
@@ -932,18 +893,14 @@ pub fn convert_add_backward_torus_u64(
// we call it to get the function pointer to the right simd function
let ptr = ptr_fn();
// SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im`
// do not hold any uninitialized values since that is a precondition of calling this
// function
// SAFETY: the target x86 feature availability was checked
unsafe { ptr(out_re, out_im, inp, twisties) }
}
#[cfg(test)]
mod tests {
use std::mem::transmute;
use crate::core_crypto::fft_impl::as_mut_uninit;
use crate::core_crypto::fft_impl::math::fft::{convert_add_backward_torus_scalar, Twisties};
use std::mem::transmute;
use super::*;
@@ -998,15 +955,15 @@ mod tests {
unsafe {
convert_add_backward_torus_u64_fma(
as_mut_uninit(&mut out_fma_re),
as_mut_uninit(&mut out_fma_im),
&mut out_fma_re,
&mut out_fma_im,
&input,
twisties.as_view(),
);
convert_add_backward_torus_scalar(
as_mut_uninit(&mut out_scalar_re),
as_mut_uninit(&mut out_scalar_im),
&mut out_scalar_re,
&mut out_scalar_im,
&input,
twisties.as_view(),
);
@@ -1035,15 +992,15 @@ mod tests {
unsafe {
convert_add_backward_torus_u64_avx512f(
as_mut_uninit(&mut out_avx_re),
as_mut_uninit(&mut out_avx_im),
&mut out_avx_re,
&mut out_avx_im,
&input,
twisties.as_view(),
);
convert_add_backward_torus_scalar(
as_mut_uninit(&mut out_scalar_re),
as_mut_uninit(&mut out_scalar_im),
&mut out_scalar_re,
&mut out_scalar_im,
&input,
twisties.as_view(),
);

View File

@@ -1,7 +1,5 @@
use super::super::as_mut_uninit;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::Polynomial;
use aligned_vec::{avec, ABox};
use concrete_fft::c64;
@@ -37,24 +35,6 @@ impl FourierPolynomial<ABox<[c64]>> {
}
}
/// Polynomial in the standard domain, with possibly uninitialized coefficients.
///
/// This is used for the Fourier transforms to avoid the cost of initializing the output buffer,
/// which can be non negligible.
pub type PolynomialUninitMutView<'a, Scalar> = Polynomial<&'a mut [core::mem::MaybeUninit<Scalar>]>;
/// Polynomial in the Fourier domain, with possibly uninitialized coefficients.
///
/// This is used for the Fourier transforms to avoid the cost of initializing the output buffer,
/// which can be non negligible.
///
/// # Note
///
/// Polynomials in the Fourier domain have half the size of the corresponding polynomials in
/// the standard domain.
pub type FourierPolynomialUninitMutView<'a> =
FourierPolynomial<&'a mut [core::mem::MaybeUninit<c64>]>;
impl<C: Container<Element = c64>> FourierPolynomial<C> {
pub fn as_view(&self) -> FourierPolynomialView<'_> {
FourierPolynomial {
@@ -75,23 +55,3 @@ impl<C: Container<Element = c64>> FourierPolynomial<C> {
PolynomialSize(self.data.container_len() * 2)
}
}
impl<'a, Scalar> Polynomial<&'a mut [Scalar]> {
/// # Safety
///
/// No uninitialized values must be written into the output buffer when the borrow ends
pub unsafe fn into_uninit(self) -> PolynomialUninitMutView<'a, Scalar> {
PolynomialUninitMutView::from_container(as_mut_uninit(self.into_container()))
}
}
impl<'a> FourierPolynomialMutView<'a> {
/// # Safety
///
/// No uninitialized values must be written into the output buffer when the borrow ends
pub unsafe fn into_uninit(self) -> FourierPolynomialUninitMutView<'a> {
FourierPolynomialUninitMutView {
data: as_mut_uninit(self.data),
}
}
}

View File

@@ -1,32 +1,5 @@
#![doc(hidden)]
pub use concrete_fft::c64;
use core::mem::MaybeUninit;
pub mod crypto;
pub mod math;
/// Convert a mutable slice reference to an uninitialized mutable slice reference.
///
/// # Safety
///
/// No uninitialized values must be written into the output slice by the time the borrow ends
#[inline]
pub unsafe fn as_mut_uninit<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
let len = slice.len();
let ptr = slice.as_mut_ptr();
// SAFETY: T and MaybeUninit<T> have the same layout
core::slice::from_raw_parts_mut(ptr as *mut _, len)
}
/// Convert an uninitialized mutable slice reference to an initialized mutable slice reference.
///
/// # Safety
///
/// All the elements of the input slice must be initialized and in a valid state.
#[inline]
pub unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
let len = slice.len();
let ptr = slice.as_mut_ptr();
// SAFETY: T and MaybeUninit<T> have the same layout
core::slice::from_raw_parts_mut(ptr as *mut _, len)
}