mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
chore: bump dyn-stack to 0.13
Notable changes: - StackReq methods no longer returns Result<StackReq, SizeOverflow> instead, StackReq contains the invalid state. Now, its when we create a PodBuffer that we can check/catch if the size req is invalid by catching errors when calling `PodBuffer::try_new`. Its also possible to manually check that `stack_req != StackReq::OVERFLOW` - GlobalaPodBuffer is now PodBuffer
This commit is contained in:
committed by
tmontaigu
parent
78d1ce18c1
commit
d394af7f4d
@@ -27,7 +27,7 @@ rust-version = "1.91.1"
|
||||
[workspace.dependencies]
|
||||
aligned-vec = { version = "0.6", default-features = false }
|
||||
bytemuck = "1.24"
|
||||
dyn-stack = { version = "0.11", default-features = false }
|
||||
dyn-stack = { version = "0.13", default-features = false }
|
||||
itertools = "0.14"
|
||||
num-complex = "0.4"
|
||||
pulp = { version = "0.22", default-features = false }
|
||||
|
||||
@@ -106,7 +106,6 @@ fn ks_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -196,7 +195,6 @@ fn ks_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
|
||||
@@ -117,7 +117,6 @@ fn pbs_128(c: &mut Criterion) {
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required()
|
||||
];
|
||||
|
||||
|
||||
@@ -92,7 +92,6 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -165,7 +164,6 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -321,7 +319,6 @@ fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize
|
||||
CiphertextCount(count),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -404,7 +401,6 @@ fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -772,9 +768,7 @@ fn mem_optimized_pbs_ntt(c: &mut Criterion) {
|
||||
params.polynomial_size.unwrap(),
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(stack_size);
|
||||
|
||||
@@ -844,9 +838,7 @@ fn mem_optimized_pbs_ntt(c: &mut Criterion) {
|
||||
params.polynomial_size.unwrap(),
|
||||
ntt.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffer.resize(stack_size);
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ js-sys = "0.3"
|
||||
default = ["std", "avx512"]
|
||||
fft128 = []
|
||||
avx512 = ["pulp/x86-v4"]
|
||||
std = ["pulp/std"]
|
||||
std = ["pulp/std", "dyn-stack/std", "dyn-stack/alloc"]
|
||||
serde = ["dep:serde", "num-complex/serde"]
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -33,6 +33,7 @@ rand = { workspace = true }
|
||||
bincode = "1.3"
|
||||
more-asserts = "0.3.1"
|
||||
serde_json = "1.0.96"
|
||||
dyn-stack = { workspace = true, features = ["alloc"] }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
@@ -38,14 +38,14 @@ Additionally, an optional 128-bit negacyclic FFT module is provided.
|
||||
```rust
|
||||
use tfhe_fft::c64;
|
||||
use tfhe_fft::ordered::{Method, Plan};
|
||||
use dyn_stack::{GlobalPodBuffer, PodStack};
|
||||
use dyn_stack::{PodBuffer, PodStack};
|
||||
use num_complex::ComplexFloat;
|
||||
use std::time::Duration;
|
||||
|
||||
fn main() {
|
||||
const N: usize = 4;
|
||||
let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
let mut scratch_memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut scratch_memory);
|
||||
|
||||
let data = [
|
||||
|
||||
@@ -124,11 +124,12 @@ pub fn bench_ffts(c: &mut Criterion) {
|
||||
1 << 15,
|
||||
1 << 16,
|
||||
] {
|
||||
let mut mem = dyn_stack::GlobalPodBuffer::new(StackReq::all_of([
|
||||
let mut mem = dyn_stack::PodBuffer::try_new(StackReq::all_of(&[
|
||||
StackReq::new_aligned::<c64>(2 * n, 256), // scratch
|
||||
StackReq::new_aligned::<c64>(n, 256), // src
|
||||
StackReq::new_aligned::<c64>(n, 256), // dst
|
||||
]));
|
||||
]))
|
||||
.unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
let z = c64::new(0.0, 0.0);
|
||||
|
||||
|
||||
@@ -1025,7 +1025,9 @@ pub mod x86 {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use more_asserts::assert_le;
|
||||
use rug::{ops::Pow, Float, Integer};
|
||||
#[cfg(feature = "std")]
|
||||
use rug::ops::Pow;
|
||||
use rug::{Float, Integer};
|
||||
|
||||
const PREC: u32 = 1024;
|
||||
|
||||
|
||||
@@ -35,13 +35,13 @@
|
||||
#![cfg_attr(not(feature = "std"), doc = "```ignore")]
|
||||
//! use tfhe_fft::c64;
|
||||
//! use tfhe_fft::ordered::{Plan, Method};
|
||||
//! use dyn_stack::{PodStack, GlobalPodBuffer};
|
||||
//! use dyn_stack::{PodStack, PodBuffer};
|
||||
//! use num_complex::ComplexFloat;
|
||||
//! use std::time::Duration;
|
||||
//!
|
||||
//! const N: usize = 4;
|
||||
//! let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
//! let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
//! let mut scratch_memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
//! let stack = PodStack::new(&mut scratch_memory);
|
||||
//!
|
||||
//! let data = [
|
||||
|
||||
@@ -16,8 +16,8 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
#[cfg(feature = "std")]
|
||||
use core::time::Duration;
|
||||
#[cfg(feature = "std")]
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::PodBuffer;
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
/// Internal FFT algorithm.
|
||||
///
|
||||
@@ -245,12 +245,8 @@ impl Plan {
|
||||
Method::UserProvided(algo) => algo,
|
||||
#[cfg(feature = "std")]
|
||||
Method::Measure(duration) => {
|
||||
measure_fastest(
|
||||
duration,
|
||||
n,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(measure_fastest_scratch(n))),
|
||||
)
|
||||
.0
|
||||
let mut buf = PodBuffer::try_new(measure_fastest_scratch(n)).unwrap();
|
||||
measure_fastest(duration, n, PodStack::new(&mut buf)).0
|
||||
}
|
||||
};
|
||||
|
||||
@@ -313,10 +309,10 @@ impl Plan {
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
/// let scratch = plan.fft_scratch().unwrap();
|
||||
/// let scratch = plan.fft_scratch();
|
||||
/// ```
|
||||
pub fn fft_scratch(&self) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<c64>(self.fft_size(), CACHELINE_ALIGN)
|
||||
pub fn fft_scratch(&self) -> StackReq {
|
||||
StackReq::new_aligned::<c64>(self.fft_size(), CACHELINE_ALIGN)
|
||||
}
|
||||
|
||||
/// Performs a forward FFT in place, using the provided stack as scratch space.
|
||||
@@ -326,12 +322,12 @@ impl Plan {
|
||||
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
|
||||
/// use tfhe_fft::c64;
|
||||
/// use tfhe_fft::ordered::{Method, Plan};
|
||||
/// use dyn_stack::{PodStack, GlobalPodBuffer};
|
||||
/// use dyn_stack::{PodStack, PodBuffer};
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
///
|
||||
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
/// let stack = PodStack::new(&mut memory);
|
||||
///
|
||||
/// let mut buf = [c64::default(); 4];
|
||||
@@ -351,12 +347,12 @@ impl Plan {
|
||||
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
|
||||
/// use tfhe_fft::c64;
|
||||
/// use tfhe_fft::ordered::{Method, Plan};
|
||||
/// use dyn_stack::{PodStack, GlobalPodBuffer};
|
||||
/// use dyn_stack::{PodStack, PodBuffer};
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
///
|
||||
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
/// let stack = PodStack::new(&mut memory);
|
||||
///
|
||||
/// let mut buf = [c64::default(); 4];
|
||||
|
||||
@@ -18,8 +18,8 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
#[cfg(feature = "std")]
|
||||
use core::time::Duration;
|
||||
#[cfg(feature = "std")]
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::PodBuffer;
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
#[inline(always)]
|
||||
fn fwd_butterfly_x2<c64xN: Pod>(
|
||||
@@ -667,11 +667,8 @@ impl Plan {
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
Method::Measure(duration) => {
|
||||
let (algo, base_n, _) = measure_fastest(
|
||||
duration,
|
||||
n,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(measure_fastest_scratch(n))),
|
||||
);
|
||||
let mut buf = PodBuffer::try_new(measure_fastest_scratch(n)).unwrap();
|
||||
let (algo, base_n, _) = measure_fastest(duration, n, PodStack::new(&mut buf));
|
||||
(algo, base_n)
|
||||
}
|
||||
};
|
||||
@@ -788,10 +785,10 @@ impl Plan {
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
/// let scratch = plan.fft_scratch().unwrap();
|
||||
/// let scratch = plan.fft_scratch();
|
||||
/// ```
|
||||
pub fn fft_scratch(&self) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<c64>(self.algo().1, CACHELINE_ALIGN)
|
||||
pub fn fft_scratch(&self) -> StackReq {
|
||||
StackReq::new_aligned::<c64>(self.algo().1, CACHELINE_ALIGN)
|
||||
}
|
||||
|
||||
/// Performs a forward FFT in place, using the provided stack as scratch space.
|
||||
@@ -807,12 +804,12 @@ impl Plan {
|
||||
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
|
||||
/// use tfhe_fft::c64;
|
||||
/// use tfhe_fft::unordered::{Method, Plan};
|
||||
/// use dyn_stack::{PodStack, GlobalPodBuffer};
|
||||
/// use dyn_stack::{PodStack, PodBuffer};
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
///
|
||||
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
/// let stack = PodStack::new(&mut memory);
|
||||
///
|
||||
/// let mut buf = [c64::default(); 4];
|
||||
@@ -907,12 +904,12 @@ impl Plan {
|
||||
#[cfg_attr(not(feature = "std"), doc = " ```ignore")]
|
||||
/// use tfhe_fft::c64;
|
||||
/// use tfhe_fft::unordered::{Method, Plan};
|
||||
/// use dyn_stack::{PodStack, GlobalPodBuffer};
|
||||
/// use dyn_stack::{PodStack, PodBuffer};
|
||||
/// use core::time::Duration;
|
||||
///
|
||||
/// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
|
||||
///
|
||||
/// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
/// let mut memory = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
/// let stack = PodStack::new(&mut memory);
|
||||
///
|
||||
/// let mut buf = [c64::default(); 4];
|
||||
@@ -1057,7 +1054,7 @@ fn bit_rev_twice_inv(nbits: u32, base_nbits: u32, i: usize) -> usize {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec;
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::PodBuffer;
|
||||
use num_complex::ComplexFloat;
|
||||
use rand::random;
|
||||
|
||||
@@ -1087,7 +1084,7 @@ mod tests {
|
||||
},
|
||||
);
|
||||
let base_n = plan.algo().1;
|
||||
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
plan.fwd(&mut z, stack);
|
||||
|
||||
@@ -1115,7 +1112,7 @@ mod tests {
|
||||
base_n,
|
||||
},
|
||||
);
|
||||
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
let mut z_target = z.clone();
|
||||
@@ -1151,7 +1148,7 @@ mod tests {
|
||||
base_n: 32,
|
||||
},
|
||||
);
|
||||
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
plan.fwd(&mut z, stack);
|
||||
plan.inv(&mut z, stack);
|
||||
@@ -1189,7 +1186,7 @@ mod tests {
|
||||
base_n: 32,
|
||||
},
|
||||
);
|
||||
let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(plan.fft_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
plan.fwd(&mut z, stack);
|
||||
|
||||
@@ -9395,7 +9392,7 @@ mod tests {
|
||||
mod tests_serde {
|
||||
use super::*;
|
||||
use alloc::{vec, vec::Vec};
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::PodBuffer;
|
||||
use num_complex::ComplexFloat;
|
||||
use rand::random;
|
||||
|
||||
@@ -9429,12 +9426,7 @@ mod tests_serde {
|
||||
},
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
plan1
|
||||
.fft_scratch()
|
||||
.unwrap()
|
||||
.or(plan2.fft_scratch().unwrap()),
|
||||
);
|
||||
let mut mem = PodBuffer::try_new(plan1.fft_scratch().or(plan2.fft_scratch())).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
plan1.fwd(&mut z, stack);
|
||||
|
||||
@@ -371,7 +371,6 @@ impl Bootstrapper {
|
||||
let fft = fft.as_view();
|
||||
self.computation_buffers.resize(
|
||||
convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -467,7 +466,6 @@ impl Bootstrapper {
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = self.computation_buffers.stack();
|
||||
@@ -509,7 +507,6 @@ impl Bootstrapper {
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = self.computation_buffers.stack();
|
||||
@@ -556,7 +553,6 @@ impl Bootstrapper {
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = self.computation_buffers.stack();
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::fill_with_forward_fourier_scratch;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// Convert a [`GGSW ciphertext`](`GgswCiphertext`) with standard coefficients to the Fourier
|
||||
@@ -29,7 +29,6 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier<Scalar, InputCont, OutputCont
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
buffers.resize(
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -63,6 +62,6 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized<Scalar, InputCo
|
||||
/// Return the required memory for [`convert_standard_ggsw_ciphertext_to_fourier_mem_optimized`].
|
||||
pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
fill_with_forward_fourier_scratch(fft)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft128::math::fft::Fft128;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::fill_with_forward_fourier_scratch;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use rayon::prelude::*;
|
||||
use tfhe_fft::c64;
|
||||
|
||||
@@ -32,7 +32,6 @@ pub fn convert_standard_lwe_bootstrap_key_to_fourier<Scalar, InputCont, OutputCo
|
||||
|
||||
buffers.resize(
|
||||
convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -152,7 +151,7 @@ pub fn par_convert_standard_lwe_bootstrap_key_to_fourier<Scalar, InputCont, Outp
|
||||
/// Return the required memory for [`convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized`].
|
||||
pub fn convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
fill_with_forward_fourier_scratch(fft)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::core_crypto::fft_impl::fft64::math::fft::{
|
||||
par_convert_polynomials_list_to_fourier, Fft, FftView,
|
||||
};
|
||||
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use rayon::prelude::*;
|
||||
use tfhe_fft::c64;
|
||||
|
||||
@@ -33,7 +33,6 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier<Scalar, InputCont
|
||||
|
||||
buffers.resize(
|
||||
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -80,7 +79,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized<
|
||||
/// [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized`].
|
||||
pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
fft.forward_scratch()
|
||||
}
|
||||
|
||||
|
||||
@@ -577,7 +577,6 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -801,7 +800,6 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -1316,7 +1314,7 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
|
||||
let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender<usize>| {
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
|
||||
buffers.resize(fft.forward_scratch().unaligned_bytes_required());
|
||||
|
||||
let mut std_ggsw_buffer = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
@@ -1413,7 +1411,6 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -1579,7 +1576,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
|
||||
let produce_multi_bit_fourier_ggsw = |thread_id| {
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
|
||||
buffers.resize(fft.forward_scratch().unaligned_bytes_required());
|
||||
|
||||
let mut std_ggsw_buffer = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
@@ -1672,7 +1669,6 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -2091,7 +2087,6 @@ pub fn std_multi_bit_f128_deterministic_blind_rotate_assign<Scalar, OutputCont,
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -2513,7 +2508,6 @@ pub fn multi_bit_f128_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyC
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft128::crypto::bootstrap::bootstrap_scratch as bootstrap_scratch_f128;
|
||||
use crate::core_crypto::fft_impl::fft128::math::fft::{Fft128, Fft128View};
|
||||
use crate::core_crypto::prelude::ModulusSwitchedLweCiphertext;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
/// Perform a programmable bootstrap given an input [`LWE ciphertext`](`LweCiphertext`), a
|
||||
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
@@ -208,7 +208,6 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext<
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -259,7 +258,7 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext_mem_optimized_requirement<Scal
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
bootstrap_scratch_f128::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -408,7 +407,6 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext_mem_optimized_requirement<Scal
|
||||
/// fourier_bsk.polynomial_size(),
|
||||
/// fft,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
@@ -465,7 +463,7 @@ pub fn blind_rotate_f128_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
bootstrap_scratch_f128::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use crate::core_crypto::prelude::ModulusSwitchedLweCiphertext;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// Perform a blind rotation given an input [`modulus switched LWE
|
||||
@@ -208,7 +208,6 @@ pub fn blind_rotate_assign<OutputScalar, OutputCont, KeyCont>(
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -254,7 +253,7 @@ pub fn blind_rotate_assign_mem_optimized_requirement<OutputScalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
blind_rotate_assign_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -290,7 +289,6 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
|
||||
ggsw.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -382,12 +380,10 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
|
||||
/// polynomial_size,
|
||||
/// fft,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required();
|
||||
///
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
@@ -478,7 +474,7 @@ pub fn add_external_product_assign_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
impl_add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -533,7 +529,6 @@ pub fn cmux_assign<Scalar, Cont0, Cont1, GgswCont>(
|
||||
ggsw.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -646,12 +641,10 @@ pub fn cmux_assign<Scalar, Cont0, Cont1, GgswCont>(
|
||||
///
|
||||
/// let buffer_size_req =
|
||||
/// cmux_assign_mem_optimized_requirement::<u64>(glwe_size, polynomial_size, fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required();
|
||||
///
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
@@ -768,7 +761,7 @@ pub fn cmux_assign_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -969,7 +962,6 @@ pub fn programmable_bootstrap_lwe_ciphertext<
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -1051,7 +1043,7 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputSca
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
@@ -1139,7 +1131,7 @@ pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<Out
|
||||
polynomial_size: PolynomialSize,
|
||||
ciphertext_count: CiphertextCount,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
batch_bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, ciphertext_count, fft)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::prelude::{lwe_ciphertext_modulus_switch, ModulusSwitchedLweCiphertext};
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
|
||||
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
@@ -190,7 +190,6 @@ pub fn blind_rotate_ntt64_bnf_assign<OutputCont, KeyCont>(
|
||||
bsk.polynomial_size(),
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -452,7 +451,6 @@ pub fn programmable_bootstrap_ntt64_bnf_lwe_ciphertext<InputCont, OutputCont, Ac
|
||||
bsk.polynomial_size(),
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -732,22 +730,18 @@ pub(crate) fn ntt64_bnf_add_external_product_assign_scratch(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let decomp_sign_scratch =
|
||||
StackReq::try_new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch = StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch_single = StackReq::try_new_aligned::<u64>(polynomial_size.0, align)?;
|
||||
let standard_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
|
||||
let decomp_sign_scratch = StackReq::new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align);
|
||||
let ntt_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
|
||||
let ntt_scratch_single = StackReq::new_aligned::<u64>(polynomial_size.0, align);
|
||||
let _ = &ntt;
|
||||
|
||||
let substack2 = ntt_scratch_single;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = substack1
|
||||
.try_and(standard_scratch)?
|
||||
.try_and(decomp_sign_scratch)?;
|
||||
substack0.try_and(ntt_scratch)
|
||||
let substack1 = substack2.and(standard_scratch);
|
||||
let substack0 = substack1.and(standard_scratch).and(decomp_sign_scratch);
|
||||
substack0.and(ntt_scratch)
|
||||
}
|
||||
|
||||
/// Return the required memory for [`cmux_ntt64_bnf_assign`].
|
||||
@@ -755,7 +749,7 @@ pub(crate) fn ntt64_bnf_cmux_scratch(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
ntt64_bnf_add_external_product_assign_scratch(glwe_size, polynomial_size, ntt)
|
||||
}
|
||||
|
||||
@@ -764,9 +758,9 @@ pub fn blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
|
||||
.try_and(ntt64_bnf_cmux_scratch(glwe_size, polynomial_size, ntt)?)
|
||||
) -> StackReq {
|
||||
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
|
||||
.and(ntt64_bnf_cmux_scratch(glwe_size, polynomial_size, ntt))
|
||||
}
|
||||
|
||||
/// Return the required memory for
|
||||
@@ -775,10 +769,8 @@ pub fn programmable_bootstrap_ntt64_bnf_lwe_ciphertext_mem_optimized_requirement
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt)?
|
||||
.try_and(StackReq::try_new_aligned::<u64>(
|
||||
glwe_size.0 * polynomial_size.0,
|
||||
CACHELINE_ALIGN,
|
||||
)?)
|
||||
) -> StackReq {
|
||||
blind_rotate_ntt64_bnf_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt).and(
|
||||
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
|
||||
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
@@ -195,7 +195,6 @@ pub fn blind_rotate_ntt64_assign<InputCont, OutputCont, KeyCont>(
|
||||
bsk.polynomial_size(),
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -465,7 +464,6 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext<InputCont, OutputCont, AccCon
|
||||
bsk.polynomial_size(),
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -708,22 +706,18 @@ pub(crate) fn ntt64_add_external_product_assign_scratch(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let decomp_sign_scratch =
|
||||
StackReq::try_new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch = StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch_single = StackReq::try_new_aligned::<u64>(polynomial_size.0, align)?;
|
||||
let standard_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
|
||||
let decomp_sign_scratch = StackReq::new_aligned::<u8>(glwe_size.0 * polynomial_size.0, align);
|
||||
let ntt_scratch = StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, align);
|
||||
let ntt_scratch_single = StackReq::new_aligned::<u64>(polynomial_size.0, align);
|
||||
let _ = &ntt;
|
||||
|
||||
let substack2 = ntt_scratch_single;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = substack1
|
||||
.try_and(standard_scratch)?
|
||||
.try_and(decomp_sign_scratch)?;
|
||||
substack0.try_and(ntt_scratch)
|
||||
let substack1 = substack2.and(standard_scratch);
|
||||
let substack0 = substack1.and(standard_scratch).and(decomp_sign_scratch);
|
||||
substack0.and(ntt_scratch)
|
||||
}
|
||||
|
||||
/// Return the required memory for [`cmux_ntt64_assign`].
|
||||
@@ -731,7 +725,7 @@ pub(crate) fn ntt64_cmux_scratch(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
ntt64_add_external_product_assign_scratch(glwe_size, polynomial_size, ntt)
|
||||
}
|
||||
|
||||
@@ -740,9 +734,9 @@ pub fn blind_rotate_ntt64_assign_mem_optimized_requirement(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
|
||||
.try_and(ntt64_cmux_scratch(glwe_size, polynomial_size, ntt)?)
|
||||
) -> StackReq {
|
||||
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
|
||||
.and(ntt64_cmux_scratch(glwe_size, polynomial_size, ntt))
|
||||
}
|
||||
|
||||
/// Return the required memory for [`programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized`].
|
||||
@@ -750,8 +744,8 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized_requirement(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt: Ntt64View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
blind_rotate_ntt64_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt)?.try_and(
|
||||
StackReq::try_new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
) -> StackReq {
|
||||
blind_rotate_ntt64_assign_mem_optimized_requirement(glwe_size, polynomial_size, ntt).and(
|
||||
StackReq::new_aligned::<u64>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::wop_pbs::{
|
||||
extract_bits, extract_bits_scratch,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::FftView;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use rayon::prelude::*;
|
||||
use tfhe_fft::c64;
|
||||
|
||||
@@ -365,7 +365,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
extract_bits_scratch::<Scalar>(
|
||||
lwe_dimension,
|
||||
ksk_output_key_lwe_dimension,
|
||||
@@ -514,7 +514,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
///
|
||||
/// let buffer_size_req =
|
||||
/// convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required();
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
/// extract_bits_from_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
@@ -524,7 +523,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
/// polynomial_size,
|
||||
/// fft,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
@@ -541,7 +539,6 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
/// cbs_level_count,
|
||||
/// fft,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
@@ -711,7 +708,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimi
|
||||
fpksk_output_polynomial_size: PolynomialSize,
|
||||
level_cbs: DecompositionLevelCount,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
circuit_bootstrap_boolean_vertical_packing_scratch::<Scalar>(
|
||||
lwe_list_in_count,
|
||||
lwe_list_out_count,
|
||||
|
||||
@@ -268,7 +268,6 @@ where
|
||||
CiphertextCount(ciphertext_count),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -460,7 +459,6 @@ where
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -798,9 +796,7 @@ fn lwe_encrypt_pbs_ntt64_decrypt_custom_mod(params: ClassicTestParams<u64>) {
|
||||
polynomial_size,
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(stack_size);
|
||||
|
||||
@@ -1097,9 +1093,7 @@ fn lwe_encrypt_pbs_ntt64_bnf_decrypt(params: ClassicTestParams<u64>) {
|
||||
polynomial_size,
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(stack_size);
|
||||
|
||||
|
||||
@@ -340,9 +340,7 @@ fn hpu_noise_distribution(params: HpuTestParams) {
|
||||
polynomial_size,
|
||||
ntt,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
buffers.resize(stack_size);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::fft64::math::polynomial::{
|
||||
FourierPolynomialMutView, FourierPolynomialView,
|
||||
};
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a
|
||||
@@ -107,12 +107,10 @@ use tfhe_fft::c64;
|
||||
///
|
||||
/// let buffer_size_req =
|
||||
/// glwe_fast_keyswitch_requirement::<u64>(glwe_size_out, polynomial_size, fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required();
|
||||
///
|
||||
/// let buffer_size_req = buffer_size_req.max(
|
||||
/// convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
/// .unwrap()
|
||||
/// .unaligned_bytes_required(),
|
||||
/// );
|
||||
///
|
||||
@@ -320,21 +318,18 @@ pub fn glwe_fast_keyswitch_requirement<Scalar>(
|
||||
glwe_size_out: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size_out.0 * polynomial_size.0, align)?;
|
||||
StackReq::new_aligned::<Scalar>(glwe_size_out.0 * polynomial_size.0, align);
|
||||
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
|
||||
let fourier_scratch =
|
||||
StackReq::try_new_aligned::<c64>(glwe_size_out.0 * fourier_polynomial_size, align)?;
|
||||
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
|
||||
StackReq::new_aligned::<c64>(glwe_size_out.0 * fourier_polynomial_size, align);
|
||||
let fourier_scratch_single = StackReq::new_aligned::<c64>(fourier_polynomial_size, align);
|
||||
|
||||
let substack3 = fft.forward_scratch()?;
|
||||
let substack2 = substack3.try_and(fourier_scratch_single)?;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = StackReq::try_any_of([
|
||||
substack1.try_and(standard_scratch)?,
|
||||
fft.backward_scratch()?,
|
||||
])?;
|
||||
substack0.try_and(fourier_scratch)
|
||||
let substack3 = fft.forward_scratch();
|
||||
let substack2 = substack3.and(fourier_scratch_single);
|
||||
let substack1 = substack2.and(standard_scratch);
|
||||
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
|
||||
substack0.and(fourier_scratch)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::core_crypto::experimental::entities::fourier_pseudo_ggsw_ciphertext::
|
||||
};
|
||||
use crate::core_crypto::experimental::entities::pseudo_ggsw_ciphertext::PseudoGgswCiphertext;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// Convert a [`pseudo GGSW ciphertext`](`PseudoGgswCiphertext`) with standard coefficients to the
|
||||
@@ -31,7 +31,6 @@ pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier<Scalar, InputCont, Out
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
buffers.resize(
|
||||
convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -67,6 +66,6 @@ pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized<
|
||||
/// [`convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized`].
|
||||
pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized_requirement(
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
fill_with_forward_fourier_scratch(fft)
|
||||
}
|
||||
|
||||
@@ -147,21 +147,16 @@ fn lwe_encrypt_fast_ks_decrypt_custom_mod<
|
||||
ks1_polynomial_size,
|
||||
fft_ks,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
let ks_buffer_size_req = ks_buffer_size_req.max(
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized_requirement(fft_ks)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let pbs_buffer_size_req = programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<
|
||||
Scalar,
|
||||
>(glwe_dimension.to_glwe_size(), polynomial_size, fft_pbs)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
ks_buffers.resize(ks_buffer_size_req);
|
||||
pbs_buffers.resize(pbs_buffer_size_req);
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::core_crypto::fft_impl::fft64::math::decomposition::DecompositionLevel
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{FftView, FourierPolynomialList};
|
||||
use crate::core_crypto::fft_impl::fft64::math::polynomial::FourierPolynomialMutView;
|
||||
use aligned_vec::{avec, ABox};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
/// A pseudo GGSW ciphertext in the Fourier domain.
|
||||
@@ -262,7 +262,7 @@ impl<'a> PseudoFourierGgswCiphertextView<'a> {
|
||||
|
||||
/// Return the required memory for
|
||||
/// [`PseudoFourierGgswCiphertextMutView::fill_with_forward_fourier`].
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
|
||||
fft.forward_scratch()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::core_crypto::commons::parameters::{
|
||||
use crate::core_crypto::commons::traits::Container;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::prelude::{CiphertextModulusLog, ContainerMut};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
pub fn modulus_switch<Scalar: UnsignedInteger>(
|
||||
input: Scalar,
|
||||
@@ -35,7 +35,7 @@ pub trait FourierBootstrapKey<Scalar: UnsignedInteger> {
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self;
|
||||
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow>;
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq;
|
||||
|
||||
fn fill_with_forward_fourier<ContBsk>(
|
||||
&mut self,
|
||||
@@ -49,7 +49,7 @@ pub trait FourierBootstrapKey<Scalar: UnsignedInteger> {
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: &Self::Fft,
|
||||
) -> Result<StackReq, SizeOverflow>;
|
||||
) -> StackReq;
|
||||
|
||||
fn bootstrap<ContLweOut, ContLweIn, ContAcc>(
|
||||
&self,
|
||||
@@ -72,7 +72,7 @@ pub mod tests {
|
||||
use crate::core_crypto::fft_impl::common::FourierBootstrapKey;
|
||||
use crate::core_crypto::keycache::KeyCacheAccess;
|
||||
use crate::core_crypto::prelude::*;
|
||||
use dyn_stack::{GlobalPodBuffer, PodStack};
|
||||
use dyn_stack::{PodBuffer, PodStack};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -160,9 +160,9 @@ pub mod tests {
|
||||
fourier_bsk.fill_with_forward_fourier(
|
||||
&std_bootstrapping_key,
|
||||
&fft,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(
|
||||
K::fill_with_forward_fourier_scratch(&fft).unwrap(),
|
||||
)),
|
||||
PodStack::new(
|
||||
&mut PodBuffer::try_new(K::fill_with_forward_fourier_scratch(&fft)).unwrap(),
|
||||
),
|
||||
);
|
||||
|
||||
// Our 4 bits message space
|
||||
@@ -209,14 +209,14 @@ pub mod tests {
|
||||
&lwe_ciphertext_in,
|
||||
&accumulator,
|
||||
&fft,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(
|
||||
K::bootstrap_scratch(
|
||||
PodStack::new(
|
||||
&mut PodBuffer::try_new(K::bootstrap_scratch(
|
||||
std_bootstrapping_key.glwe_size(),
|
||||
std_bootstrapping_key.polynomial_size(),
|
||||
&fft,
|
||||
)
|
||||
))
|
||||
.unwrap(),
|
||||
)),
|
||||
),
|
||||
);
|
||||
|
||||
// Decrypt the PBS result
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::core_crypto::prelude::{
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use core::any::TypeId;
|
||||
use core::mem::transmute;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
@@ -234,9 +234,9 @@ pub fn blind_rotate_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?
|
||||
.try_and(cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)?)
|
||||
) -> StackReq {
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)
|
||||
.and(cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft))
|
||||
}
|
||||
|
||||
/// Return the required memory for [`Fourier128LweBootstrapKey::bootstrap`].
|
||||
@@ -244,10 +244,13 @@ pub fn bootstrap_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
blind_rotate_scratch::<Scalar>(glwe_size, polynomial_size, fft)?.try_and(
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
)
|
||||
) -> StackReq {
|
||||
blind_rotate_scratch::<Scalar>(glwe_size, polynomial_size, fft).and(StackReq::new_aligned::<
|
||||
Scalar,
|
||||
>(
|
||||
glwe_size.0 * polynomial_size.0,
|
||||
CACHELINE_ALIGN,
|
||||
))
|
||||
}
|
||||
|
||||
impl<Cont> Fourier128LweBootstrapKey<Cont>
|
||||
@@ -470,7 +473,7 @@ where
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: &Self::Fft,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft.as_view())
|
||||
}
|
||||
|
||||
@@ -489,9 +492,9 @@ where
|
||||
self.bootstrap(lwe_out, lwe_in, accumulator, fft.as_view(), stack);
|
||||
}
|
||||
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow> {
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq {
|
||||
let _ = fft;
|
||||
Ok(StackReq::empty())
|
||||
StackReq::empty()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecomp
|
||||
use crate::core_crypto::prelude::ContainerMut;
|
||||
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::fft128::f128;
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
@@ -380,24 +380,20 @@ pub fn add_external_product_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let fourier_scratch = StackReq::try_new_aligned::<f64>(
|
||||
let standard_scratch = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
|
||||
let fourier_scratch = StackReq::new_aligned::<f64>(
|
||||
glwe_size.0 * polynomial_size.to_fourier_polynomial_size().0,
|
||||
align,
|
||||
)?;
|
||||
);
|
||||
let fourier_scratch_single =
|
||||
StackReq::try_new_aligned::<f64>(polynomial_size.to_fourier_polynomial_size().0, align)?;
|
||||
StackReq::new_aligned::<f64>(polynomial_size.to_fourier_polynomial_size().0, align);
|
||||
|
||||
let substack2 = StackReq::try_all_of([fourier_scratch_single; 4])?;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = StackReq::try_any_of([
|
||||
substack1.try_and(standard_scratch)?,
|
||||
fft.backward_scratch()?,
|
||||
])?;
|
||||
substack0.try_and(StackReq::try_all_of([fourier_scratch; 4])?)
|
||||
let substack2 = StackReq::all_of(&[fourier_scratch_single; 4]);
|
||||
let substack1 = substack2.and(standard_scratch);
|
||||
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
|
||||
substack0.and(StackReq::all_of(&[fourier_scratch; 4]))
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
@@ -798,7 +794,7 @@ pub fn cmux_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: Fft128View<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::PolynomialSize;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use core::any::TypeId;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
@@ -358,12 +358,12 @@ impl Fft128View<'_> {
|
||||
}
|
||||
|
||||
/// Return the memory required for a backward negacyclic FFT.
|
||||
pub fn backward_scratch(self) -> Result<StackReq, SizeOverflow> {
|
||||
let one = StackReq::try_new_aligned::<f64>(
|
||||
pub fn backward_scratch(self) -> StackReq {
|
||||
let one = StackReq::new_aligned::<f64>(
|
||||
self.polynomial_size().0 / 2,
|
||||
aligned_vec::CACHELINE_ALIGN,
|
||||
)?;
|
||||
StackReq::try_all_of([one; 4])
|
||||
);
|
||||
StackReq::all_of(&[one; 4])
|
||||
}
|
||||
|
||||
pub fn forward_as_torus<Scalar: UnsignedTorus>(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator};
|
||||
use aligned_vec::avec;
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::PodBuffer;
|
||||
|
||||
fn test_roundtrip<Scalar: UnsignedTorus>() {
|
||||
let mut generator = new_random_generator();
|
||||
@@ -23,7 +23,7 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
|
||||
*x = generator.random_uniform();
|
||||
}
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(fft.backward_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
fft.forward_as_torus(
|
||||
@@ -110,7 +110,7 @@ fn test_product<Scalar: UnsignedTorus>() {
|
||||
*y >>= Scalar::BITS - integer_magnitude;
|
||||
}
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap());
|
||||
let mut mem = PodBuffer::try_new(fft.backward_scratch()).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
fft.forward_as_torus(
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::core_crypto::fft_impl::common::tests::{
|
||||
use crate::core_crypto::prelude::test::{TestResources, FFT128_U128_PARAMS};
|
||||
use crate::core_crypto::prelude::*;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{GlobalPodBuffer, PodStack};
|
||||
use dyn_stack::{PodBuffer, PodStack};
|
||||
|
||||
#[test]
|
||||
fn test_split_external_product() {
|
||||
@@ -83,14 +83,14 @@ fn test_split_external_product() {
|
||||
&ggsw,
|
||||
&glwe,
|
||||
fft,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(
|
||||
fft128::crypto::ggsw::add_external_product_assign_scratch::<u128>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
PodStack::new(
|
||||
&mut PodBuffer::try_new(fft128::crypto::ggsw::add_external_product_assign_scratch::<
|
||||
u128,
|
||||
>(
|
||||
glwe_dimension.to_glwe_size(), polynomial_size, fft
|
||||
))
|
||||
.unwrap(),
|
||||
)),
|
||||
),
|
||||
);
|
||||
|
||||
let mut out_lo = GlweCiphertext::new(
|
||||
@@ -113,14 +113,14 @@ fn test_split_external_product() {
|
||||
&glwe_lo,
|
||||
&glwe_hi,
|
||||
fft,
|
||||
PodStack::new(&mut GlobalPodBuffer::new(
|
||||
fft128::crypto::ggsw::add_external_product_assign_scratch::<u128>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
PodStack::new(
|
||||
&mut PodBuffer::try_new(fft128::crypto::ggsw::add_external_product_assign_scratch::<
|
||||
u128,
|
||||
>(
|
||||
glwe_dimension.to_glwe_size(), polynomial_size, fft
|
||||
))
|
||||
.unwrap(),
|
||||
)),
|
||||
),
|
||||
);
|
||||
|
||||
for ((lo, hi), val) in out_lo
|
||||
@@ -170,14 +170,12 @@ fn test_split_pbs() {
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
fft128::crypto::bootstrap::bootstrap_scratch::<u128>(
|
||||
let mut mem = PodBuffer::try_new(fft128::crypto::bootstrap::bootstrap_scratch::<u128>(
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
))
|
||||
.unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
for _ in 0..20 {
|
||||
|
||||
@@ -23,7 +23,7 @@ use crate::core_crypto::prelude::{
|
||||
ModulusSwitchedLweCiphertext,
|
||||
};
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
@@ -186,7 +186,7 @@ impl FourierLweBootstrapKey<ABox<[c64]>> {
|
||||
}
|
||||
|
||||
/// Return the required memory for [`FourierLweBootstrapKeyMutView::fill_with_forward_fourier`].
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
|
||||
fft.forward_scratch()
|
||||
}
|
||||
|
||||
@@ -230,16 +230,16 @@ pub fn blind_rotate_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_any_of([
|
||||
) -> StackReq {
|
||||
StackReq::any_of(&[
|
||||
// tmp_poly allocation
|
||||
StackReq::try_new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
StackReq::try_all_of([
|
||||
StackReq::new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN),
|
||||
StackReq::all_of(&[
|
||||
// ct1 allocation
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
// external product
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
|
||||
])?,
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -248,9 +248,9 @@ pub fn bootstrap_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?.try_and(
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
) -> StackReq {
|
||||
blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft).and(
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -260,19 +260,19 @@ pub fn batch_blind_rotate_assign_scratch<Scalar>(
|
||||
polynomial_size: PolynomialSize,
|
||||
ciphertext_count: CiphertextCount,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_any_of([
|
||||
) -> StackReq {
|
||||
StackReq::any_of(&[
|
||||
// tmp_poly allocation
|
||||
StackReq::try_new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
StackReq::try_all_of([
|
||||
StackReq::new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN),
|
||||
StackReq::all_of(&[
|
||||
// ct1 allocation
|
||||
StackReq::try_new_aligned::<Scalar>(
|
||||
StackReq::new_aligned::<Scalar>(
|
||||
glwe_ciphertext_size(glwe_size, polynomial_size) * ciphertext_count.0,
|
||||
CACHELINE_ALIGN,
|
||||
)?,
|
||||
),
|
||||
// external product
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
|
||||
])?,
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -282,12 +282,12 @@ pub fn batch_bootstrap_scratch<Scalar>(
|
||||
polynomial_size: PolynomialSize,
|
||||
ciphertext_count: CiphertextCount,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
batch_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, ciphertext_count, fft)?
|
||||
.try_and(StackReq::try_new_aligned::<Scalar>(
|
||||
) -> StackReq {
|
||||
batch_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, ciphertext_count, fft)
|
||||
.and(StackReq::new_aligned::<Scalar>(
|
||||
glwe_ciphertext_size(glwe_size, polynomial_size) * ciphertext_count.0,
|
||||
CACHELINE_ALIGN,
|
||||
)?)
|
||||
))
|
||||
}
|
||||
|
||||
impl FourierLweBootstrapKeyView<'_> {
|
||||
@@ -596,7 +596,7 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> Result<StackReq, SizeOverflow> {
|
||||
fn fill_with_forward_fourier_scratch(fft: &Self::Fft) -> StackReq {
|
||||
fill_with_forward_fourier_scratch(fft.as_view())
|
||||
}
|
||||
|
||||
@@ -616,7 +616,7 @@ where
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: &Self::Fft,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft.as_view())
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ use crate::core_crypto::entities::ggsw_ciphertext::{
|
||||
};
|
||||
use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertextMutView, GlweCiphertextView};
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
@@ -249,7 +249,7 @@ impl<'a> FourierGgswCiphertextView<'a> {
|
||||
}
|
||||
|
||||
/// Return the required memory for [`FourierGgswCiphertextMutView::fill_with_forward_fourier`].
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
|
||||
pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> StackReq {
|
||||
fft.forward_scratch()
|
||||
}
|
||||
|
||||
@@ -463,23 +463,19 @@ pub fn add_external_product_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let standard_scratch = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
|
||||
let fourier_polynomial_size = polynomial_size.to_fourier_polynomial_size().0;
|
||||
let fourier_scratch =
|
||||
StackReq::try_new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align)?;
|
||||
let fourier_scratch_single = StackReq::try_new_aligned::<c64>(fourier_polynomial_size, align)?;
|
||||
StackReq::new_aligned::<c64>(glwe_size.0 * fourier_polynomial_size, align);
|
||||
let fourier_scratch_single = StackReq::new_aligned::<c64>(fourier_polynomial_size, align);
|
||||
|
||||
let substack3 = fft.forward_scratch()?;
|
||||
let substack2 = substack3.try_and(fourier_scratch_single)?;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = StackReq::try_any_of([
|
||||
substack1.try_and(standard_scratch)?,
|
||||
fft.backward_scratch()?,
|
||||
])?;
|
||||
substack0.try_and(fourier_scratch)
|
||||
let substack3 = fft.forward_scratch();
|
||||
let substack2 = substack3.and(fourier_scratch_single);
|
||||
let substack1 = substack2.and(standard_scratch);
|
||||
let substack0 = StackReq::any_of(&[substack1.and(standard_scratch), fft.backward_scratch()]);
|
||||
substack0.and(fourier_scratch)
|
||||
}
|
||||
|
||||
/// Perform the external product of `ggsw` and `glwe`, and adds the result to `out`.
|
||||
@@ -763,7 +759,7 @@ pub fn cmux_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use tfhe_fft::c64;
|
||||
|
||||
pub fn extract_bits_scratch<Scalar>(
|
||||
@@ -24,34 +24,32 @@ pub fn extract_bits_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
|
||||
let lwe_in_buffer =
|
||||
StackReq::try_new_aligned::<Scalar>(input_lwe_dimension.to_lwe_size().0, align)?;
|
||||
let lwe_in_buffer = StackReq::new_aligned::<Scalar>(input_lwe_dimension.to_lwe_size().0, align);
|
||||
let lwe_out_ks_buffer =
|
||||
StackReq::try_new_aligned::<Scalar>(ksk_after_key_size.to_lwe_size().0, align)?;
|
||||
let pbs_accumulator =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let lwe_out_pbs_buffer = StackReq::try_new_aligned::<Scalar>(
|
||||
StackReq::new_aligned::<Scalar>(ksk_after_key_size.to_lwe_size().0, align);
|
||||
let pbs_accumulator = StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align);
|
||||
let lwe_out_pbs_buffer = StackReq::new_aligned::<Scalar>(
|
||||
glwe_size
|
||||
.to_glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(polynomial_size)
|
||||
.to_lwe_size()
|
||||
.0,
|
||||
align,
|
||||
)?;
|
||||
);
|
||||
let lwe_bit_left_shift_buffer = lwe_in_buffer;
|
||||
let bootstrap_scratch = bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft)?;
|
||||
let bootstrap_scratch = bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft);
|
||||
|
||||
lwe_in_buffer
|
||||
.try_and(lwe_out_ks_buffer)?
|
||||
.try_and(pbs_accumulator)?
|
||||
.try_and(lwe_out_pbs_buffer)?
|
||||
.try_and(StackReq::try_any_of([
|
||||
.and(lwe_out_ks_buffer)
|
||||
.and(pbs_accumulator)
|
||||
.and(lwe_out_pbs_buffer)
|
||||
.and(StackReq::any_of(&[
|
||||
lwe_bit_left_shift_buffer,
|
||||
bootstrap_scratch,
|
||||
])?)
|
||||
]))
|
||||
}
|
||||
|
||||
/// Function to extract `number_of_bits_to_extract` from an [`LweCiphertext`] starting at the bit
|
||||
@@ -227,9 +225,9 @@ pub fn circuit_bootstrap_boolean_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<Scalar>(bsk_output_lwe_size.0, CACHELINE_ALIGN)?.try_and(
|
||||
homomorphic_shift_boolean_scratch::<Scalar>(lwe_in_size, glwe_size, polynomial_size, fft)?,
|
||||
) -> StackReq {
|
||||
StackReq::new_aligned::<Scalar>(bsk_output_lwe_size.0, CACHELINE_ALIGN).and(
|
||||
homomorphic_shift_boolean_scratch::<Scalar>(lwe_in_size, glwe_size, polynomial_size, fft),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -345,18 +343,14 @@ pub fn homomorphic_shift_boolean_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let align = CACHELINE_ALIGN;
|
||||
StackReq::try_new_aligned::<Scalar>(lwe_in_size.0, align)?
|
||||
.try_and(StackReq::try_new_aligned::<Scalar>(
|
||||
StackReq::new_aligned::<Scalar>(lwe_in_size.0, align)
|
||||
.and(StackReq::new_aligned::<Scalar>(
|
||||
polynomial_size.0 * glwe_size.0,
|
||||
align,
|
||||
)?)?
|
||||
.try_and(bootstrap_scratch::<Scalar>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?)
|
||||
))
|
||||
.and(bootstrap_scratch::<Scalar>(glwe_size, polynomial_size, fft))
|
||||
}
|
||||
|
||||
/// Homomorphic shift for LWE without padding bit
|
||||
@@ -445,18 +439,18 @@ pub fn cmux_tree_memory_optimized_scratch<Scalar>(
|
||||
polynomial_size: PolynomialSize,
|
||||
nb_layer: usize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
let t_scratch = StackReq::try_new_aligned::<Scalar>(
|
||||
) -> StackReq {
|
||||
let t_scratch = StackReq::new_aligned::<Scalar>(
|
||||
polynomial_size.0 * glwe_size.0 * nb_layer,
|
||||
CACHELINE_ALIGN,
|
||||
)?;
|
||||
);
|
||||
|
||||
StackReq::try_all_of([
|
||||
StackReq::all_of(&[
|
||||
t_scratch, // t_0
|
||||
t_scratch, // t_1
|
||||
StackReq::try_new::<usize>(nb_layer)?, // t_fill
|
||||
StackReq::new::<usize>(nb_layer), // t_fill
|
||||
t_scratch, // diff
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
|
||||
add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -477,10 +471,12 @@ pub fn cmux_tree_memory_optimized<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
let polynomial_size = ggsw_list.polynomial_size();
|
||||
let nb_layer = ggsw_list.count();
|
||||
|
||||
debug_assert!(stack.can_hold(
|
||||
cmux_tree_memory_optimized_scratch::<Scalar>(glwe_size, polynomial_size, nb_layer, fft)
|
||||
.unwrap()
|
||||
));
|
||||
debug_assert!(stack.can_hold(cmux_tree_memory_optimized_scratch::<Scalar>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
nb_layer,
|
||||
fft
|
||||
)));
|
||||
|
||||
// These are accumulator that will be used to propagate the result from layer to layer
|
||||
// At index 0 you have the lut that will be loaded, and then the result for each layer gets
|
||||
@@ -596,40 +592,40 @@ pub fn circuit_bootstrap_boolean_vertical_packing_scratch<Scalar>(
|
||||
fpksk_output_polynomial_size: PolynomialSize,
|
||||
level_cbs: DecompositionLevelCount,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
// We deduce the number of luts in the vec_lut from the number of cipherxtexts in lwe_list_out
|
||||
let number_of_luts = lwe_list_out_count.0;
|
||||
let small_lut_size = PolynomialCount(big_lut_polynomial_count.0 / number_of_luts);
|
||||
|
||||
StackReq::try_all_of([
|
||||
StackReq::try_new_aligned::<c64>(
|
||||
StackReq::all_of(&[
|
||||
StackReq::new_aligned::<c64>(
|
||||
lwe_list_in_count.0 * fpksk_output_polynomial_size.0 / 2
|
||||
* glwe_size.0
|
||||
* glwe_size.0
|
||||
* level_cbs.0,
|
||||
CACHELINE_ALIGN,
|
||||
)?,
|
||||
StackReq::try_new_aligned::<Scalar>(
|
||||
),
|
||||
StackReq::new_aligned::<Scalar>(
|
||||
fpksk_output_polynomial_size.0 * glwe_size.0 * glwe_size.0 * level_cbs.0,
|
||||
CACHELINE_ALIGN,
|
||||
)?,
|
||||
StackReq::try_any_of([
|
||||
),
|
||||
StackReq::any_of(&[
|
||||
circuit_bootstrap_boolean_scratch::<Scalar>(
|
||||
lwe_in_size,
|
||||
bsk_output_lwe_size,
|
||||
glwe_size,
|
||||
fpksk_output_polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
),
|
||||
fill_with_forward_fourier_scratch(fft),
|
||||
vertical_packing_scratch::<Scalar>(
|
||||
glwe_size,
|
||||
fpksk_output_polynomial_size,
|
||||
small_lut_size,
|
||||
lwe_list_in_count.0,
|
||||
fft,
|
||||
)?,
|
||||
])?,
|
||||
),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -662,7 +658,6 @@ pub fn circuit_bootstrap_boolean_vertical_packing<Scalar: UnsignedTorus + CastIn
|
||||
level_cbs,
|
||||
fft
|
||||
)
|
||||
.unwrap()
|
||||
));
|
||||
debug_assert!(
|
||||
lwe_list_out.lwe_size().to_lwe_dimension() == fourier_bsk.output_lwe_dimension(),
|
||||
@@ -742,7 +737,7 @@ pub fn vertical_packing_scratch<Scalar>(
|
||||
lut_polynomial_count: PolynomialCount,
|
||||
ggsw_list_count: usize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
) -> StackReq {
|
||||
let bits = core::mem::size_of::<Scalar>() * 8;
|
||||
|
||||
// Get the base 2 logarithm (rounded down) of the number of polynomials in the list i.e. if
|
||||
@@ -757,18 +752,18 @@ pub fn vertical_packing_scratch<Scalar>(
|
||||
log_lut_number
|
||||
};
|
||||
|
||||
StackReq::try_all_of([
|
||||
StackReq::all_of(&[
|
||||
// cmux_tree_lut_res
|
||||
StackReq::try_new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?,
|
||||
StackReq::try_any_of([
|
||||
wop_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
|
||||
StackReq::new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN),
|
||||
StackReq::any_of(&[
|
||||
wop_blind_rotate_assign_scratch::<Scalar>(glwe_size, polynomial_size, fft),
|
||||
cmux_tree_memory_optimized_scratch::<Scalar>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
log_number_of_luts_for_cmux_tree,
|
||||
fft,
|
||||
)?,
|
||||
])?,
|
||||
),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -833,10 +828,10 @@ pub fn wop_blind_rotate_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
fft: FftView<'_>,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_all_of([
|
||||
StackReq::try_new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?,
|
||||
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft)?,
|
||||
) -> StackReq {
|
||||
StackReq::all_of(&[
|
||||
StackReq::new_aligned::<Scalar>(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN),
|
||||
cmux_scratch::<Scalar>(glwe_size, polynomial_size, fft),
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::core_crypto::prelude::test::{
|
||||
TestResources, FFT_WOPBS_N1024_PARAMS, FFT_WOPBS_N2048_PARAMS, FFT_WOPBS_N512_PARAMS,
|
||||
FFT_WOPBS_PARAMS,
|
||||
};
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::PodBuffer;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -159,19 +159,19 @@ pub fn test_extract_bits() {
|
||||
let input_lwe_dimension = lwe_big_sk.lwe_dimension();
|
||||
|
||||
let req = || {
|
||||
StackReq::try_any_of([
|
||||
fill_with_forward_fourier_scratch(fft)?,
|
||||
StackReq::any_of(&[
|
||||
fill_with_forward_fourier_scratch(fft),
|
||||
extract_bits_scratch::<u64>(
|
||||
input_lwe_dimension,
|
||||
ksk_lwe_big_to_small.output_key_lwe_dimension(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)?,
|
||||
),
|
||||
])
|
||||
};
|
||||
let req = req().unwrap();
|
||||
let mut mem = GlobalPodBuffer::new(req);
|
||||
let req = req();
|
||||
let mut mem = PodBuffer::try_new(req).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
fourier_bsk
|
||||
@@ -348,16 +348,14 @@ fn test_circuit_bootstrapping_binary() {
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_scratch::<u64>(
|
||||
let mut mem = PodBuffer::try_new(circuit_bootstrap_boolean_scratch::<u64>(
|
||||
lwe_in.lwe_size(),
|
||||
fourier_bsk.output_lwe_dimension().to_lwe_size(),
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
))
|
||||
.unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
// Execute the CBS
|
||||
circuit_bootstrap_boolean(
|
||||
@@ -539,7 +537,7 @@ pub fn test_cmux_tree() {
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap());
|
||||
let mut mem = PodBuffer::try_new(fill_with_forward_fourier_scratch(fft)).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
fourier_ggsw
|
||||
.as_mut_view()
|
||||
@@ -548,10 +546,13 @@ pub fn test_cmux_tree() {
|
||||
|
||||
let mut result_cmux_tree =
|
||||
GlweCiphertext::new(0_u64, glwe_size, polynomial_size, ciphertext_modulus);
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
cmux_tree_memory_optimized_scratch::<u64>(glwe_size, polynomial_size, nb_ggsw, fft)
|
||||
.unwrap(),
|
||||
);
|
||||
let mut mem = PodBuffer::try_new(cmux_tree_memory_optimized_scratch::<u64>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
nb_ggsw,
|
||||
fft,
|
||||
))
|
||||
.unwrap();
|
||||
cmux_tree_memory_optimized(
|
||||
result_cmux_tree.as_mut_view(),
|
||||
lut.as_view(),
|
||||
@@ -656,16 +657,14 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
extract_bits_scratch::<u64>(
|
||||
let mut mem = PodBuffer::try_new(extract_bits_scratch::<u64>(
|
||||
input_lwe_dimension,
|
||||
ksk_lwe_big_to_small.output_key_lwe_dimension(),
|
||||
fourier_bsk.glwe_size(),
|
||||
polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
))
|
||||
.unwrap();
|
||||
extract_bits(
|
||||
extracted_bits_lwe_list.as_mut_view(),
|
||||
lwe_in.as_view(),
|
||||
@@ -729,8 +728,8 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
|
||||
);
|
||||
|
||||
// Perform circuit bootstrap + vertical packing
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
|
||||
let mut mem =
|
||||
PodBuffer::try_new(circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
|
||||
extracted_bits_lwe_list.lwe_ciphertext_count(),
|
||||
vertical_packing_lwe_list_out.lwe_ciphertext_count(),
|
||||
extracted_bits_lwe_list.lwe_size(),
|
||||
@@ -740,9 +739,8 @@ pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() {
|
||||
vec_pfpksk.output_polynomial_size(),
|
||||
level_cbs,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
))
|
||||
.unwrap();
|
||||
circuit_bootstrap_boolean_vertical_packing(
|
||||
lut_poly_list.as_view(),
|
||||
fourier_bsk.as_view(),
|
||||
@@ -864,8 +862,7 @@ fn test_wop_add_one(params: FftWopPbsTestParams<u64>) {
|
||||
let lut_as_polynomial_list = PolynomialList::from_container(lut, polynomial_size);
|
||||
|
||||
// Perform circuit bootstrap + vertical packing
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
|
||||
let mut mem = PodBuffer::new(circuit_bootstrap_boolean_vertical_packing_scratch::<u64>(
|
||||
extracted_bits.lwe_ciphertext_count(),
|
||||
output_cbs_vp.lwe_ciphertext_count(),
|
||||
extracted_bits.lwe_size(),
|
||||
@@ -875,9 +872,7 @@ fn test_wop_add_one(params: FftWopPbsTestParams<u64>) {
|
||||
cbs_pfpksk.output_polynomial_size(),
|
||||
level_cbs,
|
||||
fft,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
));
|
||||
circuit_bootstrap_boolean_vertical_packing(
|
||||
lut_as_polynomial_list.as_view(),
|
||||
fourier_bsk.as_view(),
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainer
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
use aligned_vec::{avec, ABox};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
use rayon::prelude::*;
|
||||
use std::any::TypeId;
|
||||
use std::collections::hash_map::Entry;
|
||||
@@ -380,18 +380,16 @@ impl FftView<'_> {
|
||||
}
|
||||
|
||||
/// Return the memory required for a forward negacyclic FFT.
|
||||
pub fn forward_scratch(self) -> Result<StackReq, SizeOverflow> {
|
||||
pub fn forward_scratch(self) -> StackReq {
|
||||
self.plan.fft_scratch()
|
||||
}
|
||||
|
||||
/// Return the memory required for a backward negacyclic FFT.
|
||||
pub fn backward_scratch(self) -> Result<StackReq, SizeOverflow> {
|
||||
self.plan
|
||||
.fft_scratch()?
|
||||
.try_and(StackReq::try_new_aligned::<c64>(
|
||||
pub fn backward_scratch(self) -> StackReq {
|
||||
self.plan.fft_scratch().and(StackReq::new_aligned::<c64>(
|
||||
self.polynomial_size().to_fourier_polynomial_size().0,
|
||||
aligned_vec::CACHELINE_ALIGN,
|
||||
)?)
|
||||
))
|
||||
}
|
||||
|
||||
/// Perform a negacyclic real FFT of `standard`, viewed as torus elements, and stores the
|
||||
@@ -793,11 +791,7 @@ pub fn par_convert_polynomials_list_to_fourier<Scalar: UnsignedTorus>(
|
||||
dest.par_chunks_mut(chunk_size * f_polynomial_size)
|
||||
.zip_eq(origin.par_chunks(chunk_size * polynomial_size.0))
|
||||
.for_each(|(fourier_poly_chunk, standard_poly_chunk)| {
|
||||
let stack_len = fft
|
||||
.forward_scratch()
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
let stack_len = fft.forward_scratch().unaligned_bytes_required();
|
||||
let mut mem = vec![0; stack_len];
|
||||
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator};
|
||||
use aligned_vec::avec;
|
||||
use dyn_stack::GlobalPodBuffer;
|
||||
use dyn_stack::PodBuffer;
|
||||
|
||||
fn test_roundtrip<Scalar: UnsignedTorus>() {
|
||||
let mut generator = new_random_generator();
|
||||
@@ -23,11 +23,8 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
|
||||
*x = generator.random_uniform();
|
||||
}
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
fft.forward_scratch()
|
||||
.unwrap()
|
||||
.and(fft.backward_scratch().unwrap()),
|
||||
);
|
||||
let mut mem =
|
||||
PodBuffer::try_new(fft.forward_scratch().and(fft.backward_scratch())).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
// Simple roundtrip
|
||||
@@ -125,11 +122,8 @@ fn test_product<Scalar: UnsignedTorus>() {
|
||||
*y >>= Scalar::BITS - integer_magnitude;
|
||||
}
|
||||
|
||||
let mut mem = GlobalPodBuffer::new(
|
||||
fft.forward_scratch()
|
||||
.unwrap()
|
||||
.and(fft.backward_scratch().unwrap()),
|
||||
);
|
||||
let mut mem =
|
||||
PodBuffer::try_new(fft.forward_scratch().and(fft.backward_scratch())).unwrap();
|
||||
let stack = PodStack::new(&mut mem);
|
||||
|
||||
fft.forward_as_torus(fourier0.as_mut_view(), poly0.as_view(), stack);
|
||||
|
||||
@@ -1510,7 +1510,6 @@ pub(crate) fn apply_standard_blind_rotate<OutputScalar, OutputCont>(
|
||||
poly_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -1588,9 +1587,7 @@ pub(crate) fn apply_programmable_bootstrap_128<InputScalar, InputCont, OutputSca
|
||||
bsk_polynomial_size,
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap();
|
||||
.unaligned_bytes_required();
|
||||
|
||||
let br_input_modulus_log = bsk.polynomial_size().to_blind_rotation_input_modulus_log();
|
||||
let lwe_ciphertext_to_squash_noise = modulus_switch_noise_reduction_key
|
||||
|
||||
@@ -720,7 +720,6 @@ mod experimental {
|
||||
bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
@@ -821,7 +820,6 @@ mod experimental {
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
@@ -927,7 +925,6 @@ mod experimental {
|
||||
self.param.cbs_level,
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user