chore: bump dyn-stack to 0.13

Notable changes:
- StackReq methods no longer returns Result<StackReq, SizeOverflow>
  instead, StackReq contains the invalid state.
  Now, its when we create a PodBuffer that we can check/catch if the
  size req is invalid by catching errors when calling
  `PodBuffer::try_new`. Its also possible to manually check that
  `stack_req != StackReq::OVERFLOW`

- GlobalaPodBuffer is now PodBuffer
This commit is contained in:
Thomas Montaigu
2025-12-09 14:04:49 +01:00
committed by tmontaigu
parent 78d1ce18c1
commit d394af7f4d
41 changed files with 336 additions and 449 deletions

View File

@@ -27,7 +27,7 @@ rust-version = "1.91.1"
[workspace.dependencies]
aligned-vec = { version = "0.6", default-features = false }
bytemuck = "1.24"
dyn-stack = { version = "0.11", default-features = false }
dyn-stack = { version = "0.13", default-features = false }
itertools = "0.14"
num-complex = "0.4"
pulp = { version = "0.22", default-features = false }

View File

@@ -106,7 +106,6 @@ fn ks_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -196,7 +195,6 @@ fn ks_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
fourier_bsk.polynomial_size(),
fft.as_view(),
)
.unwrap()
.unaligned_bytes_required(),
);

View File

@@ -117,7 +117,6 @@ fn pbs_128(c: &mut Criterion) {
fourier_bsk.polynomial_size(),
fft
)
.unwrap()
.unaligned_bytes_required()
];

View File

@@ -92,7 +92,6 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -165,7 +164,6 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
fourier_bsk.polynomial_size(),
fft.as_view(),
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -321,7 +319,6 @@ fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize
CiphertextCount(count),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -404,7 +401,6 @@ fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize
fourier_bsk.polynomial_size(),
fft.as_view(),
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -772,9 +768,7 @@ fn mem_optimized_pbs_ntt(c: &mut Criterion) {
params.polynomial_size.unwrap(),
ntt,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
buffers.resize(stack_size);
@@ -844,9 +838,7 @@ fn mem_optimized_pbs_ntt(c: &mut Criterion) {
params.polynomial_size.unwrap(),
ntt.as_view(),
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
buffer.resize(stack_size);

View File

@@ -24,7 +24,7 @@ js-sys = "0.3"
default = ["std", "avx512"]
fft128 = []
avx512 = ["pulp/x86-v4"]
std = ["pulp/std"]
std = ["pulp/std", "dyn-stack/std", "dyn-stack/alloc"]
serde = ["dep:serde", "num-complex/serde"]
[dev-dependencies]
@@ -33,6 +33,7 @@ rand = { workspace = true }
bincode = "1.3"
more-asserts = "0.3.1"
serde_json = "1.0.96"
dyn-stack = { workspace = true, features = ["alloc"] }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"

View File

@@ -38,14 +38,14 @@ Additionally, an optional 128-bit negacyclic FFT module is provided.
```rust
use tfhe_fft::c64;
use tfhe_fft::ordered::{Method, Plan};
use dyn_stack::{GlobalPodBuffer, PodStack};
use dyn_stack::{PodBuffer, PodStack};
use num_complex::ComplexFloat;
use std::time::Duration;
fn main() {
const N: usize = 4;
let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
let mut scratch_memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
let stack = PodStack::new(&mut scratch_memory);
let data = [

View File

@@ -124,11 +124,12 @@ pub fn bench_ffts(c: &mut Criterion) {
1 << 15,
1 << 16,
] {
let mut mem = dyn_stack::GlobalPodBuffer::new(StackReq::all_of([
let mut mem = dyn_stack::PodBuffer::try_new(StackReq::all_of(&[
StackReq::new_aligned::<c64>(2 * n, 256), // scratch
StackReq::new_aligned::<c64>(n, 256), // src
StackReq::new_aligned::<c64>(n, 256), // dst
]));
]))
.unwrap();
let stack = PodStack::new(&mut mem);
let z = c64::new(0.0, 0.0);

View File

@@ -1025,7 +1025,9 @@ pub mod x86 {
mod tests {
use super::*;
use more_asserts::assert_le;
use rug::{ops::Pow, Float, Integer};
#[cfg(feature = "std")]
use rug::ops::Pow;
use rug::{Float, Integer};
const PREC: u32 = 1024;

View File

@@ -35,13 +35,13 @@
#![cfg_attr(not(feature = "std"), doc = "```ignore")]
//! use tfhe_fft::c64;
//! use tfhe_fft::ordered::{Plan, Method};
//! use dyn_stack::{PodStack, GlobalPodBuffer};
//! use dyn_stack::{PodStack, PodBuffer};
//! use num_complex::ComplexFloat;
//! use std::time::Duration;
//!
//! const N: usize = 4;
//! let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
//! let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
//! let mut scratch_memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
//! let stack = PodStack::new(&mut scratch_memory);
//!
//! let data = [

View File

@@ -16,8 +16,8 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
#[cfg(feature = "std")]
use core::time::Duration;
#[cfg(feature = "std")]
use dyn_stack::GlobalPodBuffer;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::PodBuffer;
use dyn_stack::{PodStack, StackReq};
/// Internal FFT algorithm.
///
@@ -245,12 +245,8 @@ impl Plan {
Method::UserProvided(algo) => algo,
#[cfg(feature = "std")]
Method::Measure(duration) => {
measure_fastest(
duration,
n,
PodStack::new(&mut GlobalPodBuffer::new(measure_fastest_scratch(n))),
)
.0
let mut buf = PodBuffer::try_new(measure_fastest_scratch(n)).unwrap();
measure_fastest(duration, n, PodStack::new(&mut buf)).0
}
};
@@ -313,10 +309,10 @@ impl Plan {
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
/// let scratch = plan.fft_scratch().unwrap();
/// let scratch = plan.fft_scratch();
/// ```
pub fn fft_scratch(&self) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<c64>(self.fft_size(), CACHELINE_ALIGN)
pub fn fft_scratch(&self) -> StackReq {
StackReq::new_aligned::<c64>(self.fft_size(), CACHELINE_ALIGN)
}
/// Performs a forward FFT in place, using the provided stack as scratch space.
@@ -326,12 +322,12 @@ impl Plan {
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
/// use tfhe_fft::c64;
/// use tfhe_fft::ordered::{Method, Plan};
/// use dyn_stack::{PodStack, GlobalPodBuffer};
/// use dyn_stack::{PodStack, PodBuffer};
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
///
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
/// let stack = PodStack::new(&mut memory);
///
/// let mut buf = [c64::default(); 4];
@@ -351,12 +347,12 @@ impl Plan {
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
/// use tfhe_fft::c64;
/// use tfhe_fft::ordered::{Method, Plan};
/// use dyn_stack::{PodStack, GlobalPodBuffer};
/// use dyn_stack::{PodStack, PodBuffer};
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
///
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
/// let stack = PodStack::new(&mut memory);
///
/// let mut buf = [c64::default(); 4];

View File

@@ -18,8 +18,8 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
#[cfg(feature = "std")]
use core::time::Duration;
#[cfg(feature = "std")]
use dyn_stack::GlobalPodBuffer;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::PodBuffer;
use dyn_stack::{PodStack, StackReq};
#[inline(always)]
fn fwd_butterfly_x2<c64xN: Pod>(
@@ -667,11 +667,8 @@ impl Plan {
#[cfg(feature = "std")]
Method::Measure(duration) => {
let (algo, base_n, _) = measure_fastest(
duration,
n,
PodStack::new(&mut GlobalPodBuffer::new(measure_fastest_scratch(n))),
);
let mut buf = PodBuffer::try_new(measure_fastest_scratch(n)).unwrap();
let (algo, base_n, _) = measure_fastest(duration, n, PodStack::new(&mut buf));
(algo, base_n)
}
};
@@ -788,10 +785,10 @@ impl Plan {
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
/// let scratch = plan.fft_scratch().unwrap();
/// let scratch = plan.fft_scratch();
/// ```
pub fn fft_scratch(&self) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<c64>(self.algo().1, CACHELINE_ALIGN)
pub fn fft_scratch(&self) -> StackReq {
StackReq::new_aligned::<c64>(self.algo().1, CACHELINE_ALIGN)
}
/// Performs a forward FFT in place, using the provided stack as scratch space.
@@ -807,12 +804,12 @@ impl Plan {
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
/// use tfhe_fft::c64;
/// use tfhe_fft::unordered::{Method, Plan};
/// use dyn_stack::{PodStack, GlobalPodBuffer};
/// use dyn_stack::{PodStack, PodBuffer};
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
///
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
/// let stack = PodStack::new(&mut memory);
///
/// let mut buf = [c64::default(); 4];
@@ -907,12 +904,12 @@ impl Plan {
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
/// use tfhe_fft::c64;
/// use tfhe_fft::unordered::{Method, Plan};
/// use dyn_stack::{PodStack, GlobalPodBuffer};
/// use dyn_stack::{PodStack, PodBuffer};
/// use core::time::Duration;
///
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
///
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
/// let stack = PodStack::new(&mut memory);
///
/// let mut buf = [c64::default(); 4];
@@ -1057,7 +1054,7 @@ fn bit_rev_twice_inv(nbits: u32, base_nbits: u32, i: usize) -> usize {
mod tests {
use super::*;
use alloc::vec;
use dyn_stack::GlobalPodBuffer;
use dyn_stack::PodBuffer;
use num_complex::ComplexFloat;
use rand::random;
@@ -1087,7 +1084,7 @@ mod tests {
},
);
let base_n = plan.algo().1;
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
plan.fwd(&mut z, stack);
@@ -1115,7 +1112,7 @@ mod tests {
base_n,
},
);
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
let mut z_target = z.clone();
@@ -1151,7 +1148,7 @@ mod tests {
base_n: 32,
},
);
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
plan.fwd(&mut z, stack);
plan.inv(&mut z, stack);
@@ -1189,7 +1186,7 @@ mod tests {
base_n: 32,
},
);
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
plan.fwd(&mut z, stack);
@@ -9395,7 +9392,7 @@ mod tests {
mod tests_serde {
use super::*;
use alloc::{vec, vec::Vec};
use dyn_stack::GlobalPodBuffer;
use dyn_stack::PodBuffer;
use num_complex::ComplexFloat;
use rand::random;
@@ -9429,12 +9426,7 @@ mod tests_serde {
},
);
let mut mem = GlobalPodBuffer::new(
plan1
.fft_scratch()
.unwrap()
.or(plan2.fft_scratch().unwrap()),
);
let mut mem = PodBuffer::try_new(plan1.fft_scratch().or(plan2.fft_scratch())).unwrap();
let stack = PodStack::new(&mut mem);
plan1.fwd(&mut z, stack);

View File

@@ -371,7 +371,6 @@ impl Bootstrapper {
let fft = fft.as_view();
self.computation_buffers.resize(
convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
@@ -467,7 +466,6 @@ impl Bootstrapper {
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = self.computation_buffers.stack();
@@ -509,7 +507,6 @@ impl Bootstrapper {
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = self.computation_buffers.stack();
@@ -556,7 +553,6 @@ impl Bootstrapper {
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = self.computation_buffers.stack();

View File

@@ -7,7 +7,7 @@ use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::fill_with_forward_fourier_scratch;
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
/// Convert a [`GGSW ciphertext`](`GgswCiphertext`) with standard coefficients to the Fourier
@@ -29,7 +29,6 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier<Scalar, InputCont, OutputCont
let mut buffers = ComputationBuffers::new();
buffers.resize(
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
@@ -63,6 +62,6 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized<Scalar, InputCo
/// Return the required memory for [`convert_standard_ggsw_ciphertext_to_fourier_mem_optimized`].
pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
fill_with_forward_fourier_scratch(fft)
}

View File

@@ -9,7 +9,7 @@ use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft128::math::fft::Fft128;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::fill_with_forward_fourier_scratch;
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use rayon::prelude::*;
use tfhe_fft::c64;
@@ -32,7 +32,6 @@ pub fn convert_standard_lwe_bootstrap_key_to_fourier<Scalar, InputCont, OutputCo
buffers.resize(
convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
@@ -152,7 +151,7 @@ pub fn par_convert_standard_lwe_bootstrap_key_to_fourier<Scalar, InputCont, Outp
/// Return the required memory for [`convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized`].
pub fn convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
fill_with_forward_fourier_scratch(fft)
}

View File

@@ -10,7 +10,7 @@ use crate::core_crypto::fft_impl::fft64::math::fft::{
par_convert_polynomials_list_to_fourier, Fft, FftView,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use rayon::prelude::*;
use tfhe_fft::c64;
@@ -33,7 +33,6 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier<Scalar, InputCont
buffers.resize(
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
@@ -80,7 +79,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized<
/// [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized`].
pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
fft.forward_scratch()
}

View File

@@ -577,7 +577,6 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -801,7 +800,6 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -1316,7 +1314,7 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender<usize>| {
let mut buffers = ComputationBuffers::new();
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
buffers.resize(fft.forward_scratch().unaligned_bytes_required());
let mut std_ggsw_buffer = GgswCiphertext::new(
Scalar::ZERO,
@@ -1413,7 +1411,6 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -1579,7 +1576,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let produce_multi_bit_fourier_ggsw = |thread_id| {
let mut buffers = ComputationBuffers::new();
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
buffers.resize(fft.forward_scratch().unaligned_bytes_required());
let mut std_ggsw_buffer = GgswCiphertext::new(
Scalar::ZERO,
@@ -1672,7 +1669,6 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -2091,7 +2087,6 @@ pub fn std_multi_bit_f128_deterministic_blind_rotate_assign<Scalar, OutputCont,
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -2513,7 +2508,6 @@ pub fn multi_bit_f128_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyC
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);

View File

@@ -9,7 +9,7 @@ use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft128::crypto::bootstrap::bootstrap_scratch as bootstrap_scratch_f128;
use crate::core_crypto::fft_impl::fft128::math::fft::{Fft128, Fft128View};
use crate::core_crypto::prelude::ModulusSwitchedLweCiphertext;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
/// Perform a programmable bootstrap given an input [`LWE ciphertext`](`LweCiphertext`), a
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
@@ -208,7 +208,6 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext<
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -259,7 +258,7 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext_mem_optimized_requirement<Scal
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
bootstrap_scratch_f128::<Scalar>(glwe_size, polynomial_size, fft)
}
@@ -408,7 +407,6 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext_mem_optimized_requirement<Scal
/// fourier_bsk.polynomial_size(),
/// fft,
/// )
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
///
@@ -465,7 +463,7 @@ pub fn blind_rotate_f128_lwe_ciphertext_mem_optimized_requirement<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
bootstrap_scratch_f128::<Scalar>(glwe_size, polynomial_size, fft)
}

View File

@@ -17,7 +17,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
};
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
use crate::core_crypto::prelude::ModulusSwitchedLweCiphertext;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
/// Perform a blind rotation given an input [`modulus switched LWE
@@ -208,7 +208,6 @@ pub fn blind_rotate_assign<OutputScalar, OutputCont, KeyCont>(
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -254,7 +253,7 @@ pub fn blind_rotate_assign_mem_optimized_requirement<OutputScalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
blind_rotate_assign_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}
@@ -290,7 +289,6 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
ggsw.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -382,12 +380,10 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
/// polynomial_size,
/// fft,
/// )
/// .unwrap()
/// .unaligned_bytes_required();
///
/// let buffer_size_req = buffer_size_req.max(
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
///
@@ -478,7 +474,7 @@ pub fn add_external_product_assign_mem_optimized_requirement<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
impl_add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
}
@@ -533,7 +529,6 @@ pub fn cmux_assign<Scalar, Cont0, Cont1, GgswCont>(
ggsw.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -646,12 +641,10 @@ pub fn cmux_assign<Scalar, Cont0, Cont1, GgswCont>(
///
/// let buffer_size_req =
/// cmux_assign_mem_optimized_requirement::<u64>(glwe_size, polynomial_size, fft)
/// .unwrap()
/// .unaligned_bytes_required();
///
/// let buffer_size_req = buffer_size_req.max(
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
///
@@ -768,7 +761,7 @@ pub fn cmux_assign_mem_optimized_requirement<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)
}
@@ -969,7 +962,6 @@ pub fn programmable_bootstrap_lwe_ciphertext<
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -1051,7 +1043,7 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputSca
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}
@@ -1139,7 +1131,7 @@ pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<Out
polynomial_size: PolynomialSize,
ciphertext_count: CiphertextCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
batch_bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, ciphertext_count, fft)
}

View File

@@ -11,7 +11,7 @@ use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{lwe_ciphertext_modulus_switch, ModulusSwitchedLweCiphertext};
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
@@ -190,7 +190,6 @@ pub fn blind_rotate_ntt64_bnf_assign<OutputCont, KeyCont>(
bsk.polynomial_size(),
ntt,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -452,7 +451,6 @@ pub fn programmable_bootstrap_ntt64_bnf_lwe_ciphertext<InputCont, OutputCont, Ac
bsk.polynomial_size(),
ntt,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -732,22 +730,18 @@ pub(crate) fn ntt64_bnf_add_external_product_assign_scratch(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
let decomp_sign_scratch =
StackReq::try_new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align)?;
let ntt_scratch = StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
let ntt_scratch_single = StackReq::try_new_aligned::<u64>(polynomial_size.0, align)?;
let standard_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
let decomp_sign_scratch = StackReq::new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align);
let ntt_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
let ntt_scratch_single = StackReq::new_aligned::<u64>(polynomial_size.0, align);
let _ = &ntt;
let substack2 = ntt_scratch_single;
let substack1 = substack2.try_and(standard_scratch)?;
let substack0 = substack1
.try_and(standard_scratch)?
.try_and(decomp_sign_scratch)?;
substack0.try_and(ntt_scratch)
let substack1 = substack2.and(standard_scratch);
let substack0 = substack1.and(standard_scratch).and(decomp_sign_scratch);
substack0.and(ntt_scratch)
}
/// Return the required memory for [`cmux_ntt64_bnf_assign`].
@@ -755,7 +749,7 @@ pub(crate) fn ntt64_bnf_cmux_scratch(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
ntt64_bnf_add_external_product_assign_scratch(glwe_size, polynomial_size, ntt)
}
@@ -764,9 +758,9 @@ pub fn blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
.try_and(ntt64_bnf_cmux_scratch(glwe_size, polynomial_size, ntt)?)
) -> StackReq {
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
.and(ntt64_bnf_cmux_scratch(glwe_size, polynomial_size, ntt))
}
/// Return the required memory for
@@ -775,10 +769,8 @@ pub fn programmable_bootstrap_ntt64_bnf_lwe_ciphertext_mem_optimized_requirement
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt)?
.try_and(StackReq::try_new_aligned::<u64>(
glwe_size.0 * polynomial_size.0,
CACHELINE_ALIGN,
)?)
) -> StackReq {
blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt).and(
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
)
}

View File

@@ -18,7 +18,7 @@ use crate::core_crypto::commons::traits::*;
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::entities::*;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
@@ -195,7 +195,6 @@ pub fn blind_rotate_ntt64_assign<InputCont, OutputCont, KeyCont>(
bsk.polynomial_size(),
ntt,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -465,7 +464,6 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext<InputCont, OutputCont, AccCon
bsk.polynomial_size(),
ntt,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -708,22 +706,18 @@ pub(crate) fn ntt64_add_external_product_assign_scratch(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
let decomp_sign_scratch =
StackReq::try_new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align)?;
let ntt_scratch = StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
let ntt_scratch_single = StackReq::try_new_aligned::<u64>(polynomial_size.0, align)?;
let standard_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
let decomp_sign_scratch = StackReq::new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align);
let ntt_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
let ntt_scratch_single = StackReq::new_aligned::<u64>(polynomial_size.0, align);
let _ = &ntt;
let substack2 = ntt_scratch_single;
let substack1 = substack2.try_and(standard_scratch)?;
let substack0 = substack1
.try_and(standard_scratch)?
.try_and(decomp_sign_scratch)?;
substack0.try_and(ntt_scratch)
let substack1 = substack2.and(standard_scratch);
let substack0 = substack1.and(standard_scratch).and(decomp_sign_scratch);
substack0.and(ntt_scratch)
}
/// Return the required memory for [`cmux_ntt64_assign`].
@@ -731,7 +725,7 @@ pub(crate) fn ntt64_cmux_scratch(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
ntt64_add_external_product_assign_scratch(glwe_size, polynomial_size, ntt)
}
@@ -740,9 +734,9 @@ pub fn blind_rotate_ntt64_assign_mem_optimized_requirement(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
.try_and(ntt64_cmux_scratch(glwe_size, polynomial_size, ntt)?)
) -> StackReq {
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
.and(ntt64_cmux_scratch(glwe_size, polynomial_size, ntt))
}
/// Return the required memory for [`programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized`].
@@ -750,8 +744,8 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized_requirement(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt: Ntt64View<'_>,
) -> Result<StackReq, SizeOverflow> {
blind_rotate_ntt64_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt)?.try_and(
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
) -> StackReq {
blind_rotate_ntt64_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt).and(
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
)
}

View File

@@ -11,7 +11,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::wop_pbs::{
extract_bits, extract_bits_scratch,
};
use crate::core_crypto::fft_impl::fft64::math::fft::FftView;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use rayon::prelude::*;
use tfhe_fft::c64;
@@ -365,7 +365,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
extract_bits_scratch::<Scalar>(
lwe_dimension,
ksk_output_key_lwe_dimension,
@@ -514,7 +514,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
///
/// let buffer_size_req =
/// convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
/// .unwrap()
/// .unaligned_bytes_required();
/// let buffer_size_req = buffer_size_req.max(
/// extract_bits_from_lwe_ciphertext_mem_optimized_requirement::<u64>(
@@ -524,7 +523,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
/// polynomial_size,
/// fft,
/// )
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
/// let buffer_size_req = buffer_size_req.max(
@@ -541,7 +539,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
/// cbs_level_count,
/// fft,
/// )
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
///
@@ -711,7 +708,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimi
fpksk_output_polynomial_size: PolynomialSize,
level_cbs: DecompositionLevelCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
circuit_bootstrap_boolean_vertical_packing_scratch::<Scalar>(
lwe_list_in_count,
lwe_list_out_count,

View File

@@ -268,7 +268,6 @@ where
CiphertextCount(ciphertext_count),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -460,7 +459,6 @@ where
polynomial_size,
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -798,9 +796,7 @@ fn lwe_encrypt_pbs_ntt64_decrypt_custom_mod(params: ClassicTestParams<u64>) {
polynomial_size,
ntt,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
buffers.resize(stack_size);
@@ -1097,9 +1093,7 @@ fn lwe_encrypt_pbs_ntt64_bnf_decrypt(params: ClassicTestParams<u64>) {
polynomial_size,
ntt,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
buffers.resize(stack_size);

View File

@@ -340,9 +340,7 @@ fn hpu_noise_distribution(params: HpuTestParams) {
polynomial_size,
ntt,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
buffers.resize(stack_size);

View File

@@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::fft64::math::polynomial::{
FourierPolynomialMutView, FourierPolynomialView,
};
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
/// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a
@@ -107,12 +107,10 @@ use tfhe_fft::c64;
///
/// let buffer_size_req =
/// glwe_fast_keyswitch_requirement::<u64>(glwe_size_out, polynomial_size, fft)
/// .unwrap()
/// .unaligned_bytes_required();
///
/// let buffer_size_req = buffer_size_req.max(
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
/// .unwrap()
/// .unaligned_bytes_required(),
/// );
///
@@ -320,21 +318,18 @@ pub fn glwe_fast_keyswitch_requirement<Scalar>(
glwe_size_out: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<Scalar>(glwe_size_out.0 * polynomial_size.0, align)?;
StackReq::new_aligned::<Scalar>(glwe_size_out.0 * polynomial_size.0, align);
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
let fourier_scratch =
StackReq::try_new_aligned::<c64>(glwe_size_out.0 * fourier_polynomial_size, align)?;
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
StackReq::new_aligned::<c64>(glwe_size_out.0 * fourier_polynomial_size, align);
let fourier_scratch_single = StackReq::new_aligned::<c64>(fourier_polynomial_size, align);
let substack3 = fft.forward_scratch()?;
let substack2 = substack3.try_and(fourier_scratch_single)?;
let substack1 = substack2.try_and(standard_scratch)?;
let substack0 = StackReq::try_any_of([
substack1.try_and(standard_scratch)?,
fft.backward_scratch()?,
])?;
substack0.try_and(fourier_scratch)
let substack3 = fft.forward_scratch();
let substack2 = substack3.and(fourier_scratch_single);
let substack1 = substack2.and(standard_scratch);
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
substack0.and(fourier_scratch)
}

View File

@@ -9,7 +9,7 @@ use crate::core_crypto::experimental::entities::fourier_pseudo_ggsw_ciphertext::
};
use crate::core_crypto::experimental::entities::pseudo_ggsw_ciphertext::PseudoGgswCiphertext;
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
/// Convert a [`pseudo GGSW ciphertext`](`PseudoGgswCiphertext`) with standard coefficients to the
@@ -31,7 +31,6 @@ pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier<Scalar, InputCont, Out
let mut buffers = ComputationBuffers::new();
buffers.resize(
convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
@@ -67,6 +66,6 @@ pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized<
/// [`convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized`].
pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized_requirement(
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
fill_with_forward_fourier_scratch(fft)
}

View File

@@ -147,21 +147,16 @@ fn lwe_encrypt_fast_ks_decrypt_custom_mod<
ks1_polynomial_size,
fft_ks,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
let ks_buffer_size_req = ks_buffer_size_req.max(
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft_ks)
.unwrap()
.unaligned_bytes_required(),
);
let pbs_buffer_size_req = programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<
Scalar,
>(glwe_dimension.to_glwe_size(), polynomial_size, fft_pbs)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
ks_buffers.resize(ks_buffer_size_req);
pbs_buffers.resize(pbs_buffer_size_req);

View File

@@ -11,7 +11,7 @@ use crate::core_crypto::fft_impl::fft64::math::decomposition::DecompositionLevel
use crate::core_crypto::fft_impl::fft64::math::fft::{FftView, FourierPolynomialList};
use crate::core_crypto::fft_impl::fft64::math::polynomial::FourierPolynomialMutView;
use aligned_vec::{avec, ABox};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
/// A pseudo GGSW ciphertext in the Fourier domain.
@@ -262,7 +262,7 @@ impl<'a> PseudoFourierGgswCiphertextView<'a> {
/// Return the required memory for
/// [`PseudoFourierGgswCiphertextMutView::fill_with_forward_fourier`].
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
fft.forward_scratch()
}

View File

@@ -5,7 +5,7 @@ use crate::core_crypto::commons::parameters::{
use crate::core_crypto::commons::traits::Container;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{CiphertextModulusLog, ContainerMut};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
pub fn modulus_switch<Scalar: UnsignedInteger>(
input: Scalar,
@@ -35,7 +35,7 @@ pub trait FourierBootstrapKey<Scalar: UnsignedInteger> {
decomposition_level_count: DecompositionLevelCount,
) -> Self;
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow>;
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq;
fn fill_with_forward_fourier<ContBsk>(
&mut self,
@@ -49,7 +49,7 @@ pub trait FourierBootstrapKey<Scalar: UnsignedInteger> {
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: &Self::Fft,
) -> Result<StackReq, SizeOverflow>;
) -> StackReq;
fn bootstrap<ContLweOut, ContLweIn, ContAcc>(
&self,
@@ -72,7 +72,7 @@ pub mod tests {
use crate::core_crypto::fft_impl::common::FourierBootstrapKey;
use crate::core_crypto::keycache::KeyCacheAccess;
use crate::core_crypto::prelude::*;
use dyn_stack::{GlobalPodBuffer, PodStack};
use dyn_stack::{PodBuffer, PodStack};
use serde::de::DeserializeOwned;
use serde::Serialize;
@@ -160,9 +160,9 @@ pub mod tests {
fourier_bsk.fill_with_forward_fourier(
&std_bootstrapping_key,
&fft,
PodStack::new(&mut GlobalPodBuffer::new(
K::fill_with_forward_fourier_scratch(&fft).unwrap(),
)),
PodStack::new(
&mut PodBuffer::try_new(K::fill_with_forward_fourier_scratch(&fft)).unwrap(),
),
);
// Our 4 bits message space
@@ -209,14 +209,14 @@ pub mod tests {
&lwe_ciphertext_in,
&accumulator,
&fft,
PodStack::new(&mut GlobalPodBuffer::new(
K::bootstrap_scratch(
PodStack::new(
&mut PodBuffer::try_new(K::bootstrap_scratch(
std_bootstrapping_key.glwe_size(),
std_bootstrapping_key.polynomial_size(),
&fft,
)
))
.unwrap(),
)),
),
);
// Decrypt the PBS result

View File

@@ -25,7 +25,7 @@ use crate::core_crypto::prelude::{
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use core::any::TypeId;
use core::mem::transmute;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_versionable::Versionize;
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize, Versionize)]
@@ -234,9 +234,9 @@ pub fn blind_rotate_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
.try_and(cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)?)
) -> StackReq {
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
.and(cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft))
}
/// Return the required memory for [`Fourier128LweBootstrapKey::bootstrap`].
@@ -244,10 +244,13 @@ pub fn bootstrap_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
blind_rotate_scratch::<Scalar>(glwe_size, polynomial_size, fft)?.try_and(
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
)
) -> StackReq {
blind_rotate_scratch::<Scalar>(glwe_size, polynomial_size, fft).and(StackReq::new_aligned::<
Scalar,
>(
glwe_size.0 * polynomial_size.0,
CACHELINE_ALIGN,
))
}
impl<Cont> Fourier128LweBootstrapKey<Cont>
@@ -470,7 +473,7 @@ where
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: &Self::Fft,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft.as_view())
}
@@ -489,9 +492,9 @@ where
self.bootstrap(lwe_out, lwe_in, accumulator, fft.as_view(), stack);
}
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow> {
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq {
let _ = fft;
Ok(StackReq::empty())
StackReq::empty()
}
}

View File

@@ -17,7 +17,7 @@ use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecomp
use crate::core_crypto::prelude::ContainerMut;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::fft128::f128;
use tfhe_versionable::Versionize;
@@ -380,24 +380,20 @@ pub fn add_external_product_assign_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
let fourier_scratch = StackReq::try_new_aligned::<f64>(
let standard_scratch = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
let fourier_scratch = StackReq::new_aligned::<f64>(
glwe_size.0 * polynomial_size.to_fourier_polynomial_size().0,
align,
)?;
);
let fourier_scratch_single =
StackReq::try_new_aligned::<f64>(polynomial_size.to_fourier_polynomial_size().0, align)?;
StackReq::new_aligned::<f64>(polynomial_size.to_fourier_polynomial_size().0, align);
let substack2 = StackReq::try_all_of([fourier_scratch_single; 4])?;
let substack1 = substack2.try_and(standard_scratch)?;
let substack0 = StackReq::try_any_of([
substack1.try_and(standard_scratch)?,
fft.backward_scratch()?,
])?;
substack0.try_and(StackReq::try_all_of([fourier_scratch; 4])?)
let substack2 = StackReq::all_of(&[fourier_scratch_single; 4]);
let substack1 = substack2.and(standard_scratch);
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
substack0.and(StackReq::all_of(&[fourier_scratch; 4]))
}
#[cfg_attr(feature = "__profiling", inline(never))]
@@ -798,7 +794,7 @@ pub fn cmux_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: Fft128View<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
}

View File

@@ -3,7 +3,7 @@ use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
use crate::core_crypto::commons::parameters::PolynomialSize;
use crate::core_crypto::commons::utils::izip_eq;
use core::any::TypeId;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
@@ -358,12 +358,12 @@ impl Fft128View<'_> {
}
/// Return the memory required for a backward negacyclic FFT.
pub fn backward_scratch(self) -> Result<StackReq, SizeOverflow> {
let one = StackReq::try_new_aligned::<f64>(
pub fn backward_scratch(self) -> StackReq {
let one = StackReq::new_aligned::<f64>(
self.polynomial_size().0 / 2,
aligned_vec::CACHELINE_ALIGN,
)?;
StackReq::try_all_of([one; 4])
);
StackReq::all_of(&[one; 4])
}
pub fn forward_as_torus<Scalar: UnsignedTorus>(

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator};
use aligned_vec::avec;
use dyn_stack::GlobalPodBuffer;
use dyn_stack::PodBuffer;
fn test_roundtrip<Scalar: UnsignedTorus>() {
let mut generator = new_random_generator();
@@ -23,7 +23,7 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
*x = generator.random_uniform();
}
let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
let mut mem = PodBuffer::try_new(fft.backward_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
fft.forward_as_torus(
@@ -110,7 +110,7 @@ fn test_product<Scalar: UnsignedTorus>() {
*y >>= Scalar::BITS - integer_magnitude;
}
let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
let mut mem = PodBuffer::try_new(fft.backward_scratch()).unwrap();
let stack = PodStack::new(&mut mem);
fft.forward_as_torus(

View File

@@ -7,7 +7,7 @@ use crate::core_crypto::fft_impl::common::tests::{
use crate::core_crypto::prelude::test::{TestResources, FFT128_U128_PARAMS};
use crate::core_crypto::prelude::*;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{GlobalPodBuffer, PodStack};
use dyn_stack::{PodBuffer, PodStack};
#[test]
fn test_split_external_product() {
@@ -83,14 +83,14 @@ fn test_split_external_product() {
&ggsw,
&glwe,
fft,
PodStack::new(&mut GlobalPodBuffer::new(
fft128::crypto::ggsw::add_external_product_assign_scratch::<u128>(
glwe_dimension.to_glwe_size(),
polynomial_size,
fft,
)
PodStack::new(
&mut PodBuffer::try_new(fft128::crypto::ggsw::add_external_product_assign_scratch::<
u128,
>(
glwe_dimension.to_glwe_size(), polynomial_size, fft
))
.unwrap(),
)),
),
);
let mut out_lo = GlweCiphertext::new(
@@ -113,14 +113,14 @@ fn test_split_external_product() {
&glwe_lo,
&glwe_hi,
fft,
PodStack::new(&mut GlobalPodBuffer::new(
fft128::crypto::ggsw::add_external_product_assign_scratch::<u128>(
glwe_dimension.to_glwe_size(),
polynomial_size,
fft,
)
PodStack::new(
&mut PodBuffer::try_new(fft128::crypto::ggsw::add_external_product_assign_scratch::<
u128,
>(
glwe_dimension.to_glwe_size(), polynomial_size, fft
))
.unwrap(),
)),
),
);
for ((lo, hi), val) in out_lo
@@ -170,14 +170,12 @@ fn test_split_pbs() {
ciphertext_modulus,
);
let mut mem = GlobalPodBuffer::new(
fft128::crypto::bootstrap::bootstrap_scratch::<u128>(
let mut mem = PodBuffer::try_new(fft128::crypto::bootstrap::bootstrap_scratch::<u128>(
glwe_dimension.to_glwe_size(),
polynomial_size,
fft,
)
.unwrap(),
);
))
.unwrap();
let stack = PodStack::new(&mut mem);
for _ in 0..20 {

View File

@@ -23,7 +23,7 @@ use crate::core_crypto::prelude::{
ModulusSwitchedLweCiphertext,
};
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
use tfhe_versionable::Versionize;
@@ -186,7 +186,7 @@ impl FourierLweBootstrapKey<ABox<[c64]>> {
}
/// Return the required memory for [`FourierLweBootstrapKeyMutView::fill_with_forward_fourier`].
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
fft.forward_scratch()
}
@@ -230,16 +230,16 @@ pub fn blind_rotate_assign_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_any_of([
) -> StackReq {
StackReq::any_of(&[
// tmp_poly allocation
StackReq::try_new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN)?,
StackReq::try_all_of([
StackReq::new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN),
StackReq::all_of(&[
// ct1 allocation
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
// external product
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
])?,
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
]),
])
}
@@ -248,9 +248,9 @@ pub fn bootstrap_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?.try_and(
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
) -> StackReq {
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft).and(
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
)
}
@@ -260,19 +260,19 @@ pub fn batch_blind_rotate_assign_scratch<Scalar>(
polynomial_size: PolynomialSize,
ciphertext_count: CiphertextCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_any_of([
) -> StackReq {
StackReq::any_of(&[
// tmp_poly allocation
StackReq::try_new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN)?,
StackReq::try_all_of([
StackReq::new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN),
StackReq::all_of(&[
// ct1 allocation
StackReq::try_new_aligned::<Scalar>(
StackReq::new_aligned::<Scalar>(
glwe_ciphertext_size(glwe_size, polynomial_size) * ciphertext_count.0,
CACHELINE_ALIGN,
)?,
),
// external product
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
])?,
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
]),
])
}
@@ -282,12 +282,12 @@ pub fn batch_bootstrap_scratch<Scalar>(
polynomial_size: PolynomialSize,
ciphertext_count: CiphertextCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
batch_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, ciphertext_count, fft)?
.try_and(StackReq::try_new_aligned::<Scalar>(
) -> StackReq {
batch_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, ciphertext_count, fft)
.and(StackReq::new_aligned::<Scalar>(
glwe_ciphertext_size(glwe_size, polynomial_size) * ciphertext_count.0,
CACHELINE_ALIGN,
)?)
))
}
impl FourierLweBootstrapKeyView<'_> {
@@ -596,7 +596,7 @@ where
)
}
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow> {
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq {
fill_with_forward_fourier_scratch(fft.as_view())
}
@@ -616,7 +616,7 @@ where
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: &Self::Fft,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft.as_view())
}

View File

@@ -16,7 +16,7 @@ use crate::core_crypto::entities::ggsw_ciphertext::{
};
use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertextMutView, GlweCiphertextView};
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
use tfhe_versionable::Versionize;
@@ -249,7 +249,7 @@ impl<'a> FourierGgswCiphertextView<'a> {
}
/// Return the required memory for [`FourierGgswCiphertextMutView::fill_with_forward_fourier`].
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
fft.forward_scratch()
}
@@ -463,23 +463,19 @@ pub fn add_external_product_assign_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let standard_scratch =
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
let standard_scratch = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
let fourier_scratch =
StackReq::try_new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align)?;
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
StackReq::new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align);
let fourier_scratch_single = StackReq::new_aligned::<c64>(fourier_polynomial_size, align);
let substack3 = fft.forward_scratch()?;
let substack2 = substack3.try_and(fourier_scratch_single)?;
let substack1 = substack2.try_and(standard_scratch)?;
let substack0 = StackReq::try_any_of([
substack1.try_and(standard_scratch)?,
fft.backward_scratch()?,
])?;
substack0.try_and(fourier_scratch)
let substack3 = fft.forward_scratch();
let substack2 = substack3.and(fourier_scratch_single);
let substack1 = substack2.and(standard_scratch);
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
substack0.and(fourier_scratch)
}
/// Perform the external product of `ggsw` and `glwe`, and adds the result to `out`.
@@ -763,7 +759,7 @@ pub fn cmux_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
}

View File

@@ -15,7 +15,7 @@ use crate::core_crypto::commons::traits::*;
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::entities::*;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use tfhe_fft::c64;
pub fn extract_bits_scratch<Scalar>(
@@ -24,34 +24,32 @@ pub fn extract_bits_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
let lwe_in_buffer =
StackReq::try_new_aligned::<Scalar>(input_lwe_dimension.to_lwe_size().0, align)?;
let lwe_in_buffer = StackReq::new_aligned::<Scalar>(input_lwe_dimension.to_lwe_size().0, align);
let lwe_out_ks_buffer =
StackReq::try_new_aligned::<Scalar>(ksk_after_key_size.to_lwe_size().0, align)?;
let pbs_accumulator =
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
let lwe_out_pbs_buffer = StackReq::try_new_aligned::<Scalar>(
StackReq::new_aligned::<Scalar>(ksk_after_key_size.to_lwe_size().0, align);
let pbs_accumulator = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
let lwe_out_pbs_buffer = StackReq::new_aligned::<Scalar>(
glwe_size
.to_glwe_dimension()
.to_equivalent_lwe_dimension(polynomial_size)
.to_lwe_size()
.0,
align,
)?;
);
let lwe_bit_left_shift_buffer = lwe_in_buffer;
let bootstrap_scratch = bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft)?;
let bootstrap_scratch = bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft);
lwe_in_buffer
.try_and(lwe_out_ks_buffer)?
.try_and(pbs_accumulator)?
.try_and(lwe_out_pbs_buffer)?
.try_and(StackReq::try_any_of([
.and(lwe_out_ks_buffer)
.and(pbs_accumulator)
.and(lwe_out_pbs_buffer)
.and(StackReq::any_of(&[
lwe_bit_left_shift_buffer,
bootstrap_scratch,
])?)
]))
}
/// Function to extract `number_of_bits_to_extract` from an [`LweCiphertext`] starting at the bit
@@ -227,9 +225,9 @@ pub fn circuit_bootstrap_boolean_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<Scalar>(bsk_output_lwe_size.0, CACHELINE_ALIGN)?.try_and(
homomorphic_shift_boolean_scratch::<Scalar>(lwe_in_size, glwe_size, polynomial_size, fft)?,
) -> StackReq {
StackReq::new_aligned::<Scalar>(bsk_output_lwe_size.0, CACHELINE_ALIGN).and(
homomorphic_shift_boolean_scratch::<Scalar>(lwe_in_size, glwe_size, polynomial_size, fft),
)
}
@@ -345,18 +343,14 @@ pub fn homomorphic_shift_boolean_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let align = CACHELINE_ALIGN;
StackReq::try_new_aligned::<Scalar>(lwe_in_size.0, align)?
.try_and(StackReq::try_new_aligned::<Scalar>(
StackReq::new_aligned::<Scalar>(lwe_in_size.0, align)
.and(StackReq::new_aligned::<Scalar>(
polynomial_size.0 * glwe_size.0,
align,
)?)?
.try_and(bootstrap_scratch::<Scalar>(
glwe_size,
polynomial_size,
fft,
)?)
))
.and(bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft))
}
/// Homomorphic shift for LWE without padding bit
@@ -445,18 +439,18 @@ pub fn cmux_tree_memory_optimized_scratch<Scalar>(
polynomial_size: PolynomialSize,
nb_layer: usize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
let t_scratch = StackReq::try_new_aligned::<Scalar>(
) -> StackReq {
let t_scratch = StackReq::new_aligned::<Scalar>(
polynomial_size.0 * glwe_size.0 * nb_layer,
CACHELINE_ALIGN,
)?;
);
StackReq::try_all_of([
StackReq::all_of(&[
t_scratch, // t_0
t_scratch, // t_1
StackReq::try_new::<usize>(nb_layer)?, // t_fill
StackReq::new::<usize>(nb_layer), // t_fill
t_scratch, // diff
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
])
}
@@ -477,10 +471,12 @@ pub fn cmux_tree_memory_optimized<Scalar: UnsignedTorus + CastInto<usize>>(
let polynomial_size = ggsw_list.polynomial_size();
let nb_layer = ggsw_list.count();
debug_assert!(stack.can_hold(
cmux_tree_memory_optimized_scratch::<Scalar>(glwe_size, polynomial_size, nb_layer, fft)
.unwrap()
));
debug_assert!(stack.can_hold(cmux_tree_memory_optimized_scratch::<Scalar>(
glwe_size,
polynomial_size,
nb_layer,
fft
)));
// These are accumulator that will be used to propagate the result from layer to layer
// At index 0 you have the lut that will be loaded, and then the result for each layer gets
@@ -596,40 +592,40 @@ pub fn circuit_bootstrap_boolean_vertical_packing_scratch<Scalar>(
fpksk_output_polynomial_size: PolynomialSize,
level_cbs: DecompositionLevelCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
// We deduce the number of luts in the vec_lut from the number of cipherxtexts in lwe_list_out
let number_of_luts = lwe_list_out_count.0;
let small_lut_size = PolynomialCount(big_lut_polynomial_count.0 / number_of_luts);
StackReq::try_all_of([
StackReq::try_new_aligned::<c64>(
StackReq::all_of(&[
StackReq::new_aligned::<c64>(
lwe_list_in_count.0 * fpksk_output_polynomial_size.0 / 2
* glwe_size.0
* glwe_size.0
* level_cbs.0,
CACHELINE_ALIGN,
)?,
StackReq::try_new_aligned::<Scalar>(
),
StackReq::new_aligned::<Scalar>(
fpksk_output_polynomial_size.0 * glwe_size.0 * glwe_size.0 * level_cbs.0,
CACHELINE_ALIGN,
)?,
StackReq::try_any_of([
),
StackReq::any_of(&[
circuit_bootstrap_boolean_scratch::<Scalar>(
lwe_in_size,
bsk_output_lwe_size,
glwe_size,
fpksk_output_polynomial_size,
fft,
)?,
fill_with_forward_fourier_scratch(fft)?,
),
fill_with_forward_fourier_scratch(fft),
vertical_packing_scratch::<Scalar>(
glwe_size,
fpksk_output_polynomial_size,
small_lut_size,
lwe_list_in_count.0,
fft,
)?,
])?,
),
]),
])
}
@@ -662,7 +658,6 @@ pub fn circuit_bootstrap_boolean_vertical_packing<Scalar: UnsignedTorus + CastIn
level_cbs,
fft
)
.unwrap()
));
debug_assert!(
lwe_list_out.lwe_size().to_lwe_dimension() == fourier_bsk.output_lwe_dimension(),
@@ -742,7 +737,7 @@ pub fn vertical_packing_scratch<Scalar>(
lut_polynomial_count: PolynomialCount,
ggsw_list_count: usize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
) -> StackReq {
let bits = core::mem::size_of::<Scalar>() * 8;
// Get the base 2 logarithm (rounded down) of the number of polynomials in the list i.e. if
@@ -757,18 +752,18 @@ pub fn vertical_packing_scratch<Scalar>(
log_lut_number
};
StackReq::try_all_of([
StackReq::all_of(&[
// cmux_tree_lut_res
StackReq::try_new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?,
StackReq::try_any_of([
wop_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
StackReq::new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN),
StackReq::any_of(&[
wop_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
cmux_tree_memory_optimized_scratch::<Scalar>(
glwe_size,
polynomial_size,
log_number_of_luts_for_cmux_tree,
fft,
)?,
])?,
),
]),
])
}
@@ -833,10 +828,10 @@ pub fn wop_blind_rotate_assign_scratch<Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
StackReq::try_all_of([
StackReq::try_new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?,
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
) -> StackReq {
StackReq::all_of(&[
StackReq::new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN),
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft),
])
}

View File

@@ -11,7 +11,7 @@ use crate::core_crypto::prelude::test::{
TestResources, FFT_WOPBS_N1024_PARAMS, FFT_WOPBS_N2048_PARAMS, FFT_WOPBS_N512_PARAMS,
FFT_WOPBS_PARAMS,
};
use dyn_stack::GlobalPodBuffer;
use dyn_stack::PodBuffer;
use serde::de::DeserializeOwned;
use serde::Serialize;
@@ -159,19 +159,19 @@ pub fn test_extract_bits() {
let input_lwe_dimension = lwe_big_sk.lwe_dimension();
let req = || {
StackReq::try_any_of([
fill_with_forward_fourier_scratch(fft)?,
StackReq::any_of(&[
fill_with_forward_fourier_scratch(fft),
extract_bits_scratch::<u64>(
input_lwe_dimension,
ksk_lwe_big_to_small.output_key_lwe_dimension(),
glwe_dimension.to_glwe_size(),
polynomial_size,
fft,
)?,
),
])
};
let req = req().unwrap();
let mut mem = GlobalPodBuffer::new(req);
let req = req();
let mut mem = PodBuffer::try_new(req).unwrap();
let stack = PodStack::new(&mut mem);
fourier_bsk
@@ -348,16 +348,14 @@ fn test_circuit_bootstrapping_binary() {
ciphertext_modulus,
);
let mut mem = GlobalPodBuffer::new(
circuit_bootstrap_boolean_scratch::<u64>(
let mut mem = PodBuffer::try_new(circuit_bootstrap_boolean_scratch::<u64>(
lwe_in.lwe_size(),
fourier_bsk.output_lwe_dimension().to_lwe_size(),
glwe_dimension.to_glwe_size(),
polynomial_size,
fft,
)
.unwrap(),
);
))
.unwrap();
let stack = PodStack::new(&mut mem);
// Execute the CBS
circuit_bootstrap_boolean(
@@ -539,7 +537,7 @@ pub fn test_cmux_tree() {
&mut rsc.encryption_random_generator,
);
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
let mut mem = PodBuffer::try_new(fill_with_forward_fourier_scratch(fft)).unwrap();
let stack = PodStack::new(&mut mem);
fourier_ggsw
.as_mut_view()
@@ -548,10 +546,13 @@ pub fn test_cmux_tree() {
let mut result_cmux_tree =
GlweCiphertext::new(0_u64, glwe_size, polynomial_size, ciphertext_modulus);
let mut mem = GlobalPodBuffer::new(
cmux_tree_memory_optimized_scratch::<u64>(glwe_size, polynomial_size, nb_ggsw, fft)
.unwrap(),
);
let mut mem = PodBuffer::try_new(cmux_tree_memory_optimized_scratch::<u64>(
glwe_size,
polynomial_size,
nb_ggsw,
fft,
))
.unwrap();
cmux_tree_memory_optimized(
result_cmux_tree.as_mut_view(),
lut.as_view(),
@@ -656,16 +657,14 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
ciphertext_modulus,
);
let mut mem = GlobalPodBuffer::new(
extract_bits_scratch::<u64>(
let mut mem = PodBuffer::try_new(extract_bits_scratch::<u64>(
input_lwe_dimension,
ksk_lwe_big_to_small.output_key_lwe_dimension(),
fourier_bsk.glwe_size(),
polynomial_size,
fft,
)
.unwrap(),
);
))
.unwrap();
extract_bits(
extracted_bits_lwe_list.as_mut_view(),
lwe_in.as_view(),
@@ -729,8 +728,8 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
);
// Perform circuit bootstrap + vertical packing
let mut mem = GlobalPodBuffer::new(
circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
let mut mem =
PodBuffer::try_new(circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
extracted_bits_lwe_list.lwe_ciphertext_count(),
vertical_packing_lwe_list_out.lwe_ciphertext_count(),
extracted_bits_lwe_list.lwe_size(),
@@ -740,9 +739,8 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
vec_pfpksk.output_polynomial_size(),
level_cbs,
fft,
)
.unwrap(),
);
))
.unwrap();
circuit_bootstrap_boolean_vertical_packing(
lut_poly_list.as_view(),
fourier_bsk.as_view(),
@@ -864,8 +862,7 @@ fn test_wop_add_one(params: FftWopPbsTestParams<u64>) {
let lut_as_polynomial_list = PolynomialList::from_container(lut, polynomial_size);
// Perform circuit bootstrap + vertical packing
let mut mem = GlobalPodBuffer::new(
circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
let mut mem = PodBuffer::new(circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
extracted_bits.lwe_ciphertext_count(),
output_cbs_vp.lwe_ciphertext_count(),
extracted_bits.lwe_size(),
@@ -875,9 +872,7 @@ fn test_wop_add_one(params: FftWopPbsTestParams<u64>) {
cbs_pfpksk.output_polynomial_size(),
level_cbs,
fft,
)
.unwrap(),
);
));
circuit_bootstrap_boolean_vertical_packing(
lut_as_polynomial_list.as_view(),
fourier_bsk.as_view(),

View File

@@ -10,7 +10,7 @@ use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainer
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::entities::*;
use aligned_vec::{avec, ABox};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use dyn_stack::{PodStack, StackReq};
use rayon::prelude::*;
use std::any::TypeId;
use std::collections::hash_map::Entry;
@@ -380,18 +380,16 @@ impl FftView<'_> {
}
/// Return the memory required for a forward negacyclic FFT.
pub fn forward_scratch(self) -> Result<StackReq, SizeOverflow> {
pub fn forward_scratch(self) -> StackReq {
self.plan.fft_scratch()
}
/// Return the memory required for a backward negacyclic FFT.
pub fn backward_scratch(self) -> Result<StackReq, SizeOverflow> {
self.plan
.fft_scratch()?
.try_and(StackReq::try_new_aligned::<c64>(
pub fn backward_scratch(self) -> StackReq {
self.plan.fft_scratch().and(StackReq::new_aligned::<c64>(
self.polynomial_size().to_fourier_polynomial_size().0,
aligned_vec::CACHELINE_ALIGN,
)?)
))
}
/// Perform a negacyclic real FFT of `standard`, viewed as torus elements, and stores the
@@ -793,11 +791,7 @@ pub fn par_convert_polynomials_list_to_fourier<Scalar: UnsignedTorus>(
dest.par_chunks_mut(chunk_size * f_polynomial_size)
.zip_eq(origin.par_chunks(chunk_size * polynomial_size.0))
.for_each(|(fourier_poly_chunk, standard_poly_chunk)| {
let stack_len = fft
.forward_scratch()
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
let stack_len = fft.forward_scratch().unaligned_bytes_required();
let mut mem = vec![0; stack_len];
let stack = PodStack::new(&mut mem);

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator};
use aligned_vec::avec;
use dyn_stack::GlobalPodBuffer;
use dyn_stack::PodBuffer;
fn test_roundtrip<Scalar: UnsignedTorus>() {
let mut generator = new_random_generator();
@@ -23,11 +23,8 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
*x = generator.random_uniform();
}
let mut mem = GlobalPodBuffer::new(
fft.forward_scratch()
.unwrap()
.and(fft.backward_scratch().unwrap()),
);
let mut mem =
PodBuffer::try_new(fft.forward_scratch().and(fft.backward_scratch())).unwrap();
let stack = PodStack::new(&mut mem);
// Simple roundtrip
@@ -125,11 +122,8 @@ fn test_product<Scalar: UnsignedTorus>() {
*y >>= Scalar::BITS - integer_magnitude;
}
let mut mem = GlobalPodBuffer::new(
fft.forward_scratch()
.unwrap()
.and(fft.backward_scratch().unwrap()),
);
let mut mem =
PodBuffer::try_new(fft.forward_scratch().and(fft.backward_scratch())).unwrap();
let stack = PodStack::new(&mut mem);
fft.forward_as_torus(fourier0.as_mut_view(), poly0.as_view(), stack);

View File

@@ -1510,7 +1510,6 @@ pub(crate) fn apply_standard_blind_rotate<OutputScalar, OutputCont>(
poly_size,
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -1588,9 +1587,7 @@ pub(crate) fn apply_programmable_bootstrap_128<InputScalar, InputCont, OutputSca
bsk_polynomial_size,
fft,
)
.unwrap()
.try_unaligned_bytes_required()
.unwrap();
.unaligned_bytes_required();
let br_input_modulus_log = bsk.polynomial_size().to_blind_rotation_input_modulus_log();
let lwe_ciphertext_to_squash_noise = modulus_switch_noise_reduction_key

View File

@@ -720,7 +720,6 @@ mod experimental {
bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
@@ -821,7 +820,6 @@ mod experimental {
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
@@ -927,7 +925,6 @@ mod experimental {
self.param.cbs_level,
fft,
)
.unwrap()
.unaligned_bytes_required(),
);