Compare commits

...

1 Commits

Author SHA1 Message Date
sarah el kazdadi
b84c0672ce feat(ntt): impl crt/ntt backend 2023-09-19 14:57:09 +02:00
11 changed files with 1417 additions and 4 deletions

View File

@@ -59,6 +59,7 @@ rayon = { version = "1.5.0" }
bincode = { version = "1.3.3", optional = true }
concrete-fft = { version = "0.3.0", features = ["serde", "fft128"] }
pulp = "0.13"
concrete-ntt = "0.1.0"
aligned-vec = { version = "0.5", features = ["serde"] }
dyn-stack = { version = "0.9" }
paste = { version = "1.0.7", optional = true }
@@ -164,6 +165,11 @@ name = "pbs128-bench"
path = "benches/core_crypto/pbs128_bench.rs"
harness = false
[[bench]]
name = "pbs-crt-bench"
path = "benches/core_crypto/pbs_crt_bench.rs"
harness = false
[[bench]]
name = "boolean-bench"
path = "benches/boolean/bench.rs"

View File

@@ -0,0 +1,116 @@
use criterion::{criterion_group, criterion_main, Criterion};
use dyn_stack::PodStack;
fn criterion_bench(c: &mut Criterion) {
{
use tfhe::core_crypto::fft_impl::crt_ntt::crypto::bootstrap::{
bootstrap_scratch, CrtNttLweBootstrapKey,
};
use tfhe::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt64;
use tfhe::core_crypto::prelude::*;
type Scalar = u64;
let small_lwe_dimension = LweDimension(742);
let glwe_dimension = GlweDimension(1);
let polynomial_size = PolynomialSize(2048);
let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
let pbs_base_log = DecompositionBaseLog(23);
let pbs_level = DecompositionLevelCount(1);
// Request the best seeder possible, starting with hardware entropy sources and falling back
// to /dev/random on Unix systems if enabled via cargo features
let mut boxed_seeder = new_seeder();
// Get a mutable reference to the seeder as a trait object from the Box returned by
// new_seeder
let seeder = boxed_seeder.as_mut();
// Create a generator which uses a CSPRNG to generate secret keys
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
// Create a generator which uses two CSPRNGs to generate public masks and secret encryption
// noise
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
// Generate an LweSecretKey with binary coefficients
let small_lwe_sk =
LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator);
// Generate a GlweSecretKey with binary coefficients
let glwe_sk = GlweSecretKey::<Vec<Scalar>>::generate_new_binary(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
// Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
// Create the empty bootstrapping key in the NTT domain
let ntt_bsk = CrtNttLweBootstrapKey::new(
small_lwe_dimension,
polynomial_size,
glwe_dimension.to_glwe_size(),
pbs_base_log,
pbs_level,
);
let fft = CrtNtt64::new(polynomial_size);
let fft = fft.as_view();
// We don't need the standard bootstrapping key anymore
// Our 4 bits message space
let message_modulus: Scalar = 1 << 4;
// Our input message
let input_message: Scalar = 3;
// Delta used to encode 4 bits of message + a bit of padding on Scalar
let delta: Scalar = (1 << (Scalar::BITS - 1)) / message_modulus;
// Apply our encoding
let plaintext = Plaintext(input_message * delta);
// Allocate a new LweCiphertext and encrypt our plaintext
let lwe_ciphertext_in: LweCiphertextOwned<Scalar> = allocate_and_encrypt_new_lwe_ciphertext(
&small_lwe_sk,
plaintext,
lwe_modular_std_dev,
&mut encryption_generator,
);
let accumulator: GlweCiphertextOwned<Scalar> =
GlweCiphertextOwned::new(Scalar::ONE, glwe_dimension.to_glwe_size(), polynomial_size);
// Allocate the LweCiphertext to store the result of the PBS
let mut pbs_multiplication_ct: LweCiphertext<Vec<Scalar>> =
LweCiphertext::new(0, big_lwe_sk.lwe_dimension().to_lwe_size());
let mut buf = vec![
0u8;
bootstrap_scratch::<u32, 5, Scalar>(
ntt_bsk.glwe_size(),
ntt_bsk.polynomial_size(),
)
.unwrap()
.unaligned_bytes_required()
];
c.bench_function("pbs-crt-u64-u32x5", |b| {
b.iter(|| {
ntt_bsk.bootstrap(
&mut pbs_multiplication_ct,
&lwe_ciphertext_in,
&accumulator,
fft,
PodStack::new(&mut buf),
)
});
});
}
}
criterion_group!(benches, criterion_bench);
criterion_main!(benches);

View File

@@ -1,12 +1,11 @@
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
ModulusSwitchOffset, PolynomialSize,
};
use crate::core_crypto::commons::traits::Container;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::ContainerMut;
use crate::core_crypto::prelude::{ContainerMut, UnsignedInteger};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
/// This function switches modulus for a single coefficient of a ciphertext,
@@ -14,7 +13,7 @@ use dyn_stack::{PodStack, SizeOverflow, StackReq};
///
/// offset: the number of msb discarded
/// lut_count_log: the right padding
pub fn pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
pub fn pbs_modulus_switch<Scalar: UnsignedInteger + CastInto<usize>>(
input: Scalar,
poly_size: PolynomialSize,
offset: ModulusSwitchOffset,
@@ -287,3 +286,102 @@ pub mod tests {
);
}
}
/// Collects the items of `iter` into an array of length `N` without any heap allocation.
///
/// # Panics
///
/// Panics if `iter` does not yield exactly `N` items. The callers visible in this file always
/// pass iterators derived from arrays of length `N`, so this is treated as unreachable.
fn collect_to_array<const N: usize, I: Iterator>(iter: I) -> [I::Item; N] {
    let mut iter = iter;
    // `from_fn` fills the array in place, slot by slot, so no intermediate `Vec` is needed.
    let array = core::array::from_fn(|_| match iter.next() {
        Some(item) => item,
        None => unreachable!(),
    });
    // Keep rejecting over-long iterators, as the previous `Vec`-based conversion did.
    if iter.next().is_some() {
        unreachable!()
    }
    array
}
/// Builds an array of `N` values by calling `f` `N` times, threading a context through the
/// calls.
///
/// Each call receives the context returned by the previous call (the first one receives `ctx`)
/// and produces the next array element together with the next context. The context returned by
/// the last call is handed back next to the array; for `N == 0` this is `ctx` itself.
pub fn chain_array_with_context<Ctx, T, const N: usize>(
    ctx: Ctx,
    f: impl FnMut(Ctx) -> (T, Ctx),
) -> ([T; N], Ctx) {
    let mut step = f;
    // The context lives in an `Option` so it can be moved out of and back into the closure.
    let mut slot = Some(ctx);
    let values = core::array::from_fn(|_| {
        let current = slot.take().unwrap();
        let (value, next) = step(current);
        slot = Some(next);
        value
    });
    (values, slot.unwrap())
}
/// Turns a reference to an array into an array of shared references to its elements.
pub fn as_ref_array<T, const N: usize>(array: &[T; N]) -> [&T; N] {
    let element_refs = array.iter();
    collect_to_array(element_refs)
}
/// Turns a mutable reference to an array into an array of mutable references to its elements.
pub fn as_mut_array<T, const N: usize>(array: &mut [T; N]) -> [&mut T; N] {
    let element_refs = array.iter_mut();
    collect_to_array(element_refs)
}
/// Pairs the elements of two arrays of the same length, position by position.
pub fn zip_array<T, U, const N: usize>(first: [T; N], second: [U; N]) -> [(T, U); N] {
    let pairs = IntoIterator::into_iter(first).zip(second);
    collect_to_array(pairs)
}
/// Iterator adapter holding `N` iterators of the same type.
///
/// Built by `iter_array`; its iterator implementations advance all the inner iterators together
/// and yield their items grouped in an array.
#[derive(Clone, Debug)]
pub struct ArrayIter<I, const N: usize> {
    /// The inner iterators, in the order in which their items appear in each yielded array.
    iters: [I; N],
}
impl<I: Iterator, const N: usize> Iterator for ArrayIter<I, N> {
    type Item = [I::Item; N];

    /// Advances every inner iterator once and yields the `N` items as an array.
    ///
    /// Returns `None` as soon as at least one inner iterator is exhausted; items produced by the
    /// other iterators during that call are dropped.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let candidates = as_mut_array(&mut self.iters).map(|inner| inner.next());
        if candidates.iter().any(Option::is_none) {
            return None;
        }
        Some(candidates.map(|candidate| candidate.unwrap()))
    }

    /// Combines the hints of the inner iterators the way `core::iter::Zip` does: the smallest
    /// lower bound, and the smallest upper bound among those that are known.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let mut lower = usize::MAX;
        let mut upper = None;
        for inner in self.iters.iter() {
            let (inner_lower, inner_upper) = inner.size_hint();
            lower = lower.min(inner_lower);
            upper = match (upper, inner_upper) {
                (Some(current), Some(other)) => Some(current.min(other)),
                (known, None) | (None, known) => known,
            };
        }
        (lower, upper)
    }
}
impl<I: DoubleEndedIterator, const N: usize> DoubleEndedIterator for ArrayIter<I, N> {
    /// Takes one item from the back of every inner iterator and yields them as an array, or
    /// `None` if at least one inner iterator has nothing left.
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        let candidates = as_mut_array(&mut self.iters).map(|inner| inner.next_back());
        if candidates.iter().any(Option::is_none) {
            return None;
        }
        Some(candidates.map(|candidate| candidate.unwrap()))
    }
}
impl<I: ExactSizeIterator, const N: usize> ExactSizeIterator for ArrayIter<I, N> {
    /// Returns the shortest remaining length among the inner iterators, or `usize::MAX` when
    /// there are no inner iterators (`N == 0`).
    #[inline]
    fn len(&self) -> usize {
        self.iters
            .iter()
            .map(|inner| inner.len())
            .min()
            .unwrap_or(usize::MAX)
    }
}
/// Converts every element of `iters` into an iterator and wraps them in an [`ArrayIter`], which
/// yields the items of all `N` iterators together, one array at a time.
pub fn iter_array<I: IntoIterator, const N: usize>(iters: [I; N]) -> ArrayIter<I::IntoIter, N> {
    let iters = iters.map(IntoIterator::into_iter);
    ArrayIter { iters }
}

View File

@@ -0,0 +1,452 @@
use super::ggsw::{cmux, cmux_scratch, CrtNttGgswCiphertext};
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
use crate::core_crypto::algorithms::polynomial_algorithms::*;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
ModulusSwitchOffset, MonomialDegree, PolynomialSize,
};
use crate::core_crypto::commons::traits::{
Container, ContiguousEntityContainer, ContiguousEntityContainerMut, Split,
};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::{
as_mut_array, as_ref_array, iter_array, pbs_modulus_switch, FourierBootstrapKey,
};
use crate::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt;
use crate::core_crypto::prelude::{CiphertextModulus, ContainerMut, UnsignedInteger};
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
/// An LWE bootstrapping key in the NTT domain.
///
/// The key owns `N_COMPONENTS` containers of identical length. Each container is expected to
/// hold `input_lwe_dimension * polynomial_size * decomposition_level_count * glwe_size^2`
/// elements, i.e. one GGSW ciphertext per input LWE mask element (this is checked by
/// `from_container`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CrtNttLweBootstrapKey<
    CrtScalar,
    const N_COMPONENTS: usize,
    C: Container<Element = CrtScalar>,
> {
    /// One container per component.
    data: [C; N_COMPONENTS],
    polynomial_size: PolynomialSize,
    input_lwe_dimension: LweDimension,
    glwe_size: GlweSize,
    decomposition_base_log: DecompositionBaseLog,
    decomposition_level_count: DecompositionLevelCount,
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
    CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, C>
{
    /// Wraps existing containers, one per component, as a bootstrapping key.
    ///
    /// # Panics
    ///
    /// Panics if any container does not hold exactly
    /// `input_lwe_dimension * polynomial_size * decomposition_level_count * glwe_size^2`
    /// elements.
    pub fn from_container(
        data: [C; N_COMPONENTS],
        polynomial_size: PolynomialSize,
        input_lwe_dimension: LweDimension,
        glwe_size: GlweSize,
        decomposition_base_log: DecompositionBaseLog,
        decomposition_level_count: DecompositionLevelCount,
    ) -> Self {
        let expected_len = input_lwe_dimension.0
            * polynomial_size.0
            * decomposition_level_count.0
            * glwe_size.0
            * glwe_size.0;
        for component in &data {
            assert_eq!(component.container_len(), expected_len);
        }

        Self {
            data,
            polynomial_size,
            input_lwe_dimension,
            glwe_size,
            decomposition_base_log,
            decomposition_level_count,
        }
    }

    /// Return an iterator over the GGSW ciphertexts composing the key.
    ///
    /// Every component container is split into `input_lwe_dimension` parts, and the parts at the
    /// same position are grouped into one GGSW ciphertext.
    pub fn into_ggsw_iter(
        self,
    ) -> impl DoubleEndedIterator<Item = CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, C>>
    where
        C: Split,
    {
        let Self {
            data,
            polynomial_size,
            input_lwe_dimension,
            glwe_size,
            decomposition_base_log,
            decomposition_level_count,
        } = self;

        let split_components = data.map(|component| component.split_into(input_lwe_dimension.0));
        iter_array(split_components).map(move |ggsw_data| {
            CrtNttGgswCiphertext::from_container(
                ggsw_data,
                polynomial_size,
                glwe_size,
                decomposition_base_log,
                decomposition_level_count,
            )
        })
    }

    /// Dimension of the LWE ciphertexts this key takes as input.
    pub fn input_lwe_dimension(&self) -> LweDimension {
        self.input_lwe_dimension
    }

    /// Polynomial size of the GGSW ciphertexts of the key.
    pub fn polynomial_size(&self) -> PolynomialSize {
        self.polynomial_size
    }

    /// GLWE size of the GGSW ciphertexts of the key.
    pub fn glwe_size(&self) -> GlweSize {
        self.glwe_size
    }

    /// Base log of the gadget decomposition.
    pub fn decomposition_base_log(&self) -> DecompositionBaseLog {
        self.decomposition_base_log
    }

    /// Number of levels of the gadget decomposition.
    pub fn decomposition_level_count(&self) -> DecompositionLevelCount {
        self.decomposition_level_count
    }

    /// Dimension of the LWE ciphertexts produced with this key:
    /// `(glwe_size - 1) * polynomial_size`.
    pub fn output_lwe_dimension(&self) -> LweDimension {
        let glwe_dimension = self.glwe_size.0 - 1;
        LweDimension(glwe_dimension * self.polynomial_size().0)
    }

    /// Consumes the key and returns its component containers.
    pub fn data(self) -> [C; N_COMPONENTS] {
        self.data
    }

    /// Returns a key with the same parameters that borrows the data as shared slices.
    pub fn as_view(&self) -> CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[C::Element]> {
        let data = as_ref_array(&self.data).map(|component| component.as_ref());
        CrtNttLweBootstrapKey {
            data,
            polynomial_size: self.polynomial_size,
            input_lwe_dimension: self.input_lwe_dimension,
            glwe_size: self.glwe_size,
            decomposition_base_log: self.decomposition_base_log,
            decomposition_level_count: self.decomposition_level_count,
        }
    }

    /// Returns a key with the same parameters that borrows the data as mutable slices.
    pub fn as_mut_view(
        &mut self,
    ) -> CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &mut [C::Element]>
    where
        C: AsMut<[C::Element]>,
    {
        let data = as_mut_array(&mut self.data).map(|component| component.as_mut());
        CrtNttLweBootstrapKey {
            data,
            polynomial_size: self.polynomial_size,
            input_lwe_dimension: self.input_lwe_dimension,
            glwe_size: self.glwe_size,
            decomposition_base_log: self.decomposition_base_log,
            decomposition_level_count: self.decomposition_level_count,
        }
    }
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize>
    CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, ABox<[CrtScalar]>>
{
    /// Allocates a bootstrapping key whose `N_COMPONENTS` containers are all filled with
    /// `CrtScalar::ZERO`.
    ///
    /// Each container holds
    /// `polynomial_size * input_lwe_dimension * decomposition_level_count * glwe_size^2`
    /// elements.
    pub fn new(
        input_lwe_dimension: LweDimension,
        polynomial_size: PolynomialSize,
        glwe_size: GlweSize,
        decomposition_base_log: DecompositionBaseLog,
        decomposition_level_count: DecompositionLevelCount,
    ) -> Self {
        let container_len = polynomial_size.0
            * input_lwe_dimension.0
            * decomposition_level_count.0
            * glwe_size.0
            * glwe_size.0;

        // One zero-initialized aligned buffer per component.
        let data: [ABox<[CrtScalar]>; N_COMPONENTS] =
            core::array::from_fn(|_| avec![CrtScalar::ZERO; container_len].into_boxed_slice());

        Self::from_container(
            data,
            polynomial_size,
            input_lwe_dimension,
            glwe_size,
            decomposition_base_log,
            decomposition_level_count,
        )
    }
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Cont>
    CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, Cont>
where
    Cont: ContainerMut<Element = CrtScalar>,
{
    /// Fill a bootstrapping key with the NTT of a bootstrapping key in the standard
    /// domain.
    ///
    /// The GGSW ciphertexts of `self` and of `coef_bsk` are visited in lockstep, and each
    /// NTT-domain GGSW is filled from its standard-domain counterpart using `ntt_plan`.
    pub fn fill_with_forward_ntt<Scalar, ContBsk>(
        &mut self,
        coef_bsk: &LweBootstrapKey<ContBsk>,
        ntt_plan: Scalar::PlanView<'_>,
    ) where
        Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
        ContBsk: Container<Element = Scalar>,
    {
        // Slice-based worker, so the conversion loop does not depend on the container types.
        fn fill_from_views<
            CrtScalar: UnsignedInteger,
            const N_COMPONENTS: usize,
            Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
        >(
            ntt_bsk: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &mut [CrtScalar]>,
            standard_bsk: LweBootstrapKey<&[Scalar]>,
            ntt_plan: Scalar::PlanView<'_>,
        ) {
            let ggsw_pairs = izip!(ntt_bsk.into_ggsw_iter(), standard_bsk.iter());
            for (mut ntt_ggsw, standard_ggsw) in ggsw_pairs {
                ntt_ggsw.fill_with_forward_ntt(&standard_ggsw, ntt_plan);
            }
        }

        let ntt_view = self.as_mut_view();
        let standard_view = coef_bsk.as_view();
        fill_from_views(ntt_view, standard_view, ntt_plan)
    }
}
/// Return the required memory for [`CrtNttLweBootstrapKey::blind_rotate_assign`].
///
/// This is a cache-line aligned buffer of `glwe_size * polynomial_size` `Scalar`s (one GLWE
/// ciphertext) followed by the memory required by a cmux.
pub fn blind_rotate_scratch<CrtScalar, const N_COMPONENTS: usize, Scalar>(
    glwe_size: GlweSize,
    polynomial_size: PolynomialSize,
) -> Result<StackReq, SizeOverflow>
where
    CrtScalar: UnsignedInteger,
    Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
{
    let glwe_buffer =
        StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?;
    let cmux_req = cmux_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)?;
    glwe_buffer.try_and(cmux_req)
}
/// Return the required memory for [`CrtNttLweBootstrapKey::bootstrap`].
///
/// This is the memory required by the blind rotation followed by a cache-line aligned buffer of
/// `glwe_size * polynomial_size` `Scalar`s (one GLWE ciphertext).
pub fn bootstrap_scratch<CrtScalar, const N_COMPONENTS: usize, Scalar>(
    glwe_size: GlweSize,
    polynomial_size: PolynomialSize,
) -> Result<StackReq, SizeOverflow>
where
    CrtScalar: UnsignedInteger,
    Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
{
    let blind_rotate_req =
        blind_rotate_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)?;
    let glwe_buffer =
        StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?;
    blind_rotate_req.try_and(glwe_buffer)
}
/// Blind rotation and bootstrap using a bootstrapping key in the NTT domain.
impl<CrtScalar, const N_COMPONENTS: usize, Cont>
CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, Cont>
where
CrtScalar: UnsignedInteger,
Cont: Container<Element = CrtScalar>,
{
/// Blind-rotates `lut` in place according to the ciphertext `lwe`.
///
/// Every polynomial of `lut` is first divided by the monomial whose degree is the
/// modulus-switched LWE body. Then, for every non-zero LWE mask element (paired with one GGSW
/// ciphertext of the key), a copy of the accumulator is multiplied by the monomial whose degree
/// is the modulus-switched mask element, and both are passed to [`cmux`], which leaves its
/// result in the accumulator.
///
/// `ntt_plan` is forwarded to [`cmux`]. `stack` provides the scratch memory; see
/// [`blind_rotate_scratch`] for the required size.
///
/// # Panics
///
/// Panics if `lwe` contains no element, since its body is obtained with
/// `split_last().unwrap()`.
// CastInto required for PBS modulus switch which returns a usize
pub fn blind_rotate_assign<Scalar, ContLut, ContLwe>(
&self,
lut: &mut GlweCiphertext<ContLut>,
lwe: &LweCiphertext<ContLwe>,
ntt_plan: Scalar::PlanView<'_>,
stack: PodStack<'_>,
) where
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
ContLut: ContainerMut<Element = Scalar>,
ContLwe: Container<Element = Scalar>,
{
// Worker operating on slice views only; the public method converts its arguments to views
// and forwards them here.
fn implementation<
CrtScalar: UnsignedInteger,
const N_COMPONENTS: usize,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
>(
this: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
mut lut: GlweCiphertext<&mut [Scalar]>,
lwe: LweCiphertext<&[Scalar]>,
ntt_plan: Scalar::PlanView<'_>,
mut stack: PodStack<'_>,
) {
// The last element of the LWE ciphertext is the body, the others form the mask.
let lwe = lwe.as_ref();
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();
let lut_poly_size = lut.polynomial_size();
// Modulus switch of the body, with no offset and no extra LUT padding.
let monomial_degree = pbs_modulus_switch(
*lwe_body,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
// Divide every polynomial of the accumulator by X^{monomial_degree}.
lut.as_mut_polynomial_list()
.iter_mut()
.for_each(|mut poly| {
polynomial_wrapping_monic_monomial_div_assign(
&mut poly,
MonomialDegree(monomial_degree),
)
});
// We initialize the ct_0 used for the successive cmuxes
let mut ct0 = lut;
for (lwe_mask_element, bootstrap_key_ggsw) in
izip!(lwe_mask.iter(), this.into_ggsw_iter())
{
// Mask elements equal to zero are skipped entirely.
if *lwe_mask_element != Scalar::ZERO {
// Reborrow the stack so that the scratch memory can be reused by the next iteration.
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
let mut ct1 = GlweCiphertextMutView::from_container(
&mut *ct1,
ct0.polynomial_size(),
CiphertextModulus::new_native(),
);
// We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat}
for mut poly in ct1.as_mut_polynomial_list().iter_mut() {
polynomial_wrapping_monic_monomial_mul_assign(
&mut poly,
MonomialDegree(pbs_modulus_switch(
*lwe_mask_element,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
)),
);
}
// ct1 is rebuilt from ct0 on every iteration; cmux receives mutable borrows of both
// ciphertexts together with the remaining scratch memory.
cmux(&mut ct0, &mut ct1, &bootstrap_key_ggsw, ntt_plan, stack);
}
}
}
implementation(
self.as_view(),
lut.as_mut_view(),
lwe.as_view(),
ntt_plan,
stack,
)
}
/// Bootstraps `lwe_in` with the lookup table `accumulator`, writing the result to `lwe_out`.
///
/// The accumulator is copied into scratch memory, the copy is blind-rotated with
/// [`Self::blind_rotate_assign`], and the LWE sample at monomial degree 0 of the rotated copy
/// is extracted into `lwe_out`. `accumulator` itself is left untouched.
///
/// `stack` provides the scratch memory; see [`bootstrap_scratch`] for the required size.
pub fn bootstrap<Scalar, ContLweOut, ContLweIn, ContAcc>(
&self,
lwe_out: &mut LweCiphertext<ContLweOut>,
lwe_in: &LweCiphertext<ContLweIn>,
accumulator: &GlweCiphertext<ContAcc>,
ntt_plan: Scalar::PlanView<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
ContLweOut: ContainerMut<Element = Scalar>,
ContLweIn: Container<Element = Scalar>,
ContAcc: Container<Element = Scalar>,
{
// Worker operating on slice views only; the public method converts its arguments to views
// and forwards them here.
fn implementation<
CrtScalar: UnsignedInteger,
const N_COMPONENTS: usize,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
>(
this: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
mut lwe_out: LweCiphertext<&mut [Scalar]>,
lwe_in: LweCiphertext<&[Scalar]>,
accumulator: GlweCiphertext<&[Scalar]>,
ntt_plan: Scalar::PlanView<'_>,
stack: PodStack<'_>,
) {
// Work on a scratch copy of the accumulator so that the caller's lookup table is preserved.
let (mut local_accumulator_data, stack) =
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
let mut local_accumulator = GlweCiphertextMutView::from_container(
&mut *local_accumulator_data,
accumulator.polynomial_size(),
CiphertextModulus::new_native(),
);
this.blind_rotate_assign(
&mut local_accumulator.as_mut_view(),
&lwe_in,
ntt_plan,
stack,
);
// The output is the LWE sample at monomial degree 0 of the rotated accumulator.
extract_lwe_sample_from_glwe_ciphertext(
&local_accumulator,
&mut lwe_out,
MonomialDegree(0),
);
}
implementation(
self.as_view(),
lwe_out.as_mut_view(),
lwe_in.as_view(),
accumulator.as_view(),
ntt_plan,
stack,
)
}
}
/// Exposes the CRT/NTT bootstrapping key through the generic `FourierBootstrapKey` interface.
///
/// Every method is a thin adapter around the inherent methods and free functions of this module;
/// the "FFT" object of the trait is the NTT plan of `Scalar`.
impl<CrtScalar, const N_COMPONENTS: usize, Scalar> FourierBootstrapKey<Scalar>
for CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, ABox<[CrtScalar]>>
where
CrtScalar: UnsignedInteger,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
{
// The NTT plan plays the role of the FFT object.
type Fft = Scalar::Plan;
/// Creates the NTT plan for polynomials of size `polynomial_size`.
fn new_fft(polynomial_size: PolynomialSize) -> Self::Fft {
Scalar::new_plan(polynomial_size)
}
/// Allocates a new bootstrapping key with the given parameters.
fn new(
input_lwe_dimension: LweDimension,
polynomial_size: PolynomialSize,
glwe_size: GlweSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
) -> Self {
// NOTE(review): this relies on `Self::new` resolving to the inherent constructor rather than
// to this trait method (which would recurse) — confirm when touching either `new`.
Self::new(
input_lwe_dimension,
polynomial_size,
glwe_size,
decomposition_base_log,
decomposition_level_count,
)
}
/// Fills the key with the forward NTT of the standard-domain key `coef_bsk`.
///
/// `stack` is not used by this implementation.
fn fill_with_forward_fourier<ContBsk>(
&mut self,
coef_bsk: &LweBootstrapKey<ContBsk>,
ntt_plan: &Self::Fft,
stack: PodStack<'_>,
) where
ContBsk: Container<Element = Scalar>,
{
// The conversion needs no scratch memory.
let _ = stack;
let ntt_plan = Scalar::plan_as_view(ntt_plan);
self.fill_with_forward_ntt(coef_bsk, ntt_plan);
}
/// Returns the scratch memory required by `bootstrap`; `ntt_plan` is not used.
fn bootstrap_scratch(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ntt_plan: &Self::Fft,
) -> Result<StackReq, SizeOverflow> {
let _ = ntt_plan;
// Resolves to the free function `bootstrap_scratch` of this module.
bootstrap_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)
}
/// Bootstraps `lwe_in` with `accumulator` into `lwe_out`, using a view of `ntt_plan`.
fn bootstrap<ContLweOut, ContLweIn, ContAcc>(
&self,
lwe_out: &mut LweCiphertext<ContLweOut>,
lwe_in: &LweCiphertext<ContLweIn>,
accumulator: &GlweCiphertext<ContAcc>,
ntt_plan: &Self::Fft,
stack: PodStack<'_>,
) where
ContLweOut: ContainerMut<Element = Scalar>,
ContLweIn: Container<Element = Scalar>,
ContAcc: Container<Element = Scalar>,
{
let ntt_plan = Scalar::plan_as_view(ntt_plan);
// NOTE(review): expected to dispatch to the inherent `bootstrap` method — confirm.
self.bootstrap(lwe_out, lwe_in, accumulator, ntt_plan, stack)
}
/// Returns an empty requirement: filling the key uses no scratch memory.
fn fill_with_forward_fourier_scratch(ntt_plan: &Self::Fft) -> Result<StackReq, SizeOverflow> {
let _ = ntt_plan;
Ok(StackReq::empty())
}
}

View File

@@ -0,0 +1,540 @@
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize,
};
use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainerMut;
use crate::core_crypto::commons::traits::{Container, ContiguousEntityContainer, Split};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::{
as_mut_array, as_ref_array, chain_array_with_context, iter_array,
};
use crate::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt;
use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecompositionLendingIter;
use crate::core_crypto::prelude::{CiphertextModulus, ContainerMut, UnsignedInteger};
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
/// A GGSW ciphertext in the NTT domain.
///
/// The data is stored as `N_COMPONENTS` containers of identical length, each expected to hold
/// `polynomial_size * glwe_size^2 * decomposition_level_count` elements (checked by
/// `from_container`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CrtNttGgswCiphertext<CrtScalar, const N_COMPONENTS: usize, C>
where
    C: Container<Element = CrtScalar>,
{
    /// One container per component.
    data: [C; N_COMPONENTS],
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    decomposition_base_log: DecompositionBaseLog,
    decomposition_level_count: DecompositionLevelCount,
}
/// A matrix containing a single level of gadget decomposition, in the NTT domain.
///
/// The data is stored as `N_COMPONENTS` containers of identical length, each expected to hold
/// `polynomial_size * glwe_size * row_count` elements (checked by `from_container`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CrtNttGgswLevelMatrix<
    CrtScalar,
    const N_COMPONENTS: usize,
    C: Container<Element = CrtScalar>,
> {
    /// One container per component.
    data: [C; N_COMPONENTS],
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    /// Number of rows of the matrix.
    row_count: usize,
    /// Decomposition level this matrix corresponds to.
    decomposition_level: DecompositionLevel,
}
/// A row of a GGSW level matrix, in the NTT domain.
///
/// The data is stored as `N_COMPONENTS` containers of identical length, each expected to hold
/// `polynomial_size * glwe_size` elements (checked by `from_container`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CrtNttGgswLevelRow<CrtScalar, const N_COMPONENTS: usize, C>
where
    C: Container<Element = CrtScalar>,
{
    /// One container per component.
    data: [C; N_COMPONENTS],
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    /// Decomposition level of the matrix this row belongs to.
    decomposition_level: DecompositionLevel,
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
    CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, C>
{
    /// Wraps existing containers, one per component, as a GGSW ciphertext.
    ///
    /// # Panics
    ///
    /// Panics if any container does not hold exactly
    /// `polynomial_size * glwe_size^2 * decomposition_level_count` elements.
    pub fn from_container(
        data: [C; N_COMPONENTS],
        polynomial_size: PolynomialSize,
        glwe_size: GlweSize,
        decomposition_base_log: DecompositionBaseLog,
        decomposition_level_count: DecompositionLevelCount,
    ) -> Self {
        let expected_len =
            polynomial_size.0 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0;
        for component in &data {
            assert_eq!(component.container_len(), expected_len);
        }

        Self {
            data,
            polynomial_size,
            glwe_size,
            decomposition_base_log,
            decomposition_level_count,
        }
    }

    /// Polynomial size of the ciphertext.
    pub fn polynomial_size(&self) -> PolynomialSize {
        self.polynomial_size
    }

    /// GLWE size of the ciphertext.
    pub fn glwe_size(&self) -> GlweSize {
        self.glwe_size
    }

    /// Base log of the gadget decomposition.
    pub fn decomposition_base_log(&self) -> DecompositionBaseLog {
        self.decomposition_base_log
    }

    /// Number of levels of the gadget decomposition.
    pub fn decomposition_level_count(&self) -> DecompositionLevelCount {
        self.decomposition_level_count
    }

    /// Consumes the ciphertext and returns its component containers.
    pub fn data(self) -> [C; N_COMPONENTS] {
        self.data
    }

    /// Returns a ciphertext with the same parameters that borrows the data as shared slices.
    pub fn as_view(&self) -> CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[C::Element]> {
        let data = as_ref_array(&self.data).map(|component| component.as_ref());
        CrtNttGgswCiphertext {
            data,
            polynomial_size: self.polynomial_size,
            glwe_size: self.glwe_size,
            decomposition_base_log: self.decomposition_base_log,
            decomposition_level_count: self.decomposition_level_count,
        }
    }

    /// Returns a ciphertext with the same parameters that borrows the data as mutable slices.
    pub fn as_mut_view(
        &mut self,
    ) -> CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &mut [C::Element]>
    where
        C: AsMut<[C::Element]>,
    {
        let data = as_mut_array(&mut self.data).map(|component| component.as_mut());
        CrtNttGgswCiphertext {
            data,
            polynomial_size: self.polynomial_size,
            glwe_size: self.glwe_size,
            decomposition_base_log: self.decomposition_base_log,
            decomposition_level_count: self.decomposition_level_count,
        }
    }

    /// Return an iterator over the level matrices.
    ///
    /// Every component container is split into `decomposition_level_count` parts; the matrix
    /// built from the parts at position `i` is tagged with decomposition level `i + 1` and has
    /// `glwe_size` rows.
    pub fn into_levels(
        self,
    ) -> impl DoubleEndedIterator<Item = CrtNttGgswLevelMatrix<CrtScalar, N_COMPONENTS, C>>
    where
        C: Split,
    {
        let Self {
            data,
            polynomial_size,
            glwe_size,
            decomposition_level_count,
            ..
        } = self;

        let split_components =
            data.map(|component| component.split_into(decomposition_level_count.0));
        iter_array(split_components)
            .enumerate()
            .map(move |(index, level_data)| {
                CrtNttGgswLevelMatrix::from_container(
                    level_data,
                    polynomial_size,
                    glwe_size,
                    glwe_size.0,
                    DecompositionLevel(index + 1),
                )
            })
    }
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
    CrtNttGgswLevelMatrix<CrtScalar, N_COMPONENTS, C>
{
    /// Wraps existing containers, one per component, as a GGSW level matrix.
    ///
    /// # Panics
    ///
    /// Panics if any container does not hold exactly `polynomial_size * glwe_size * row_count`
    /// elements.
    pub fn from_container(
        data: [C; N_COMPONENTS],
        polynomial_size: PolynomialSize,
        glwe_size: GlweSize,
        row_count: usize,
        decomposition_level: DecompositionLevel,
    ) -> Self {
        let expected_len = polynomial_size.0 * glwe_size.0 * row_count;
        for component in &data {
            assert_eq!(component.container_len(), expected_len);
        }

        Self {
            data,
            polynomial_size,
            glwe_size,
            row_count,
            decomposition_level,
        }
    }

    /// Return an iterator over the rows of the level matrix.
    ///
    /// Every component container is split into `row_count` parts, and the parts at the same
    /// position are grouped into one row.
    pub fn into_rows(
        self,
    ) -> impl DoubleEndedIterator<Item = CrtNttGgswLevelRow<CrtScalar, N_COMPONENTS, C>>
    where
        C: Split,
    {
        let Self {
            data,
            polynomial_size,
            glwe_size,
            row_count,
            decomposition_level,
        } = self;

        let split_components = data.map(|component| component.split_into(row_count));
        iter_array(split_components).map(move |row_data| {
            CrtNttGgswLevelRow::from_container(
                row_data,
                polynomial_size,
                glwe_size,
                decomposition_level,
            )
        })
    }

    /// Polynomial size of the matrix.
    pub fn polynomial_size(&self) -> PolynomialSize {
        self.polynomial_size
    }

    /// GLWE size of the matrix.
    pub fn glwe_size(&self) -> GlweSize {
        self.glwe_size
    }

    /// Number of rows of the matrix.
    pub fn row_count(&self) -> usize {
        self.row_count
    }

    /// Decomposition level of the matrix.
    pub fn decomposition_level(&self) -> DecompositionLevel {
        self.decomposition_level
    }

    /// Consumes the matrix and returns its component containers.
    pub fn data(self) -> [C; N_COMPONENTS] {
        self.data
    }
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
    CrtNttGgswLevelRow<CrtScalar, N_COMPONENTS, C>
{
    /// Wraps existing containers, one per component, as a row of a GGSW level matrix.
    ///
    /// # Panics
    ///
    /// Panics if any container does not hold exactly `polynomial_size * glwe_size` elements.
    pub fn from_container(
        data: [C; N_COMPONENTS],
        polynomial_size: PolynomialSize,
        glwe_size: GlweSize,
        decomposition_level: DecompositionLevel,
    ) -> Self {
        let expected_len = polynomial_size.0 * glwe_size.0;
        for component in &data {
            assert_eq!(component.container_len(), expected_len);
        }

        Self {
            data,
            polynomial_size,
            glwe_size,
            decomposition_level,
        }
    }

    /// Polynomial size of the row.
    pub fn polynomial_size(&self) -> PolynomialSize {
        self.polynomial_size
    }

    /// GLWE size of the row.
    pub fn glwe_size(&self) -> GlweSize {
        self.glwe_size
    }

    /// Decomposition level of the matrix this row belongs to.
    pub fn decomposition_level(&self) -> DecompositionLevel {
        self.decomposition_level
    }

    /// Consumes the row and returns its component containers.
    pub fn data(self) -> [C; N_COMPONENTS] {
        self.data
    }
}
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Cont>
CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, Cont>
where
Cont: ContainerMut<Element = CrtScalar>,
{
/// Fill a GGSW ciphertext with the NTT of a GGSW ciphertext in the standard domain.
pub fn fill_with_forward_ntt<Scalar, ContGgsw>(
&mut self,
coef_ggsw: &GgswCiphertext<ContGgsw>,
ntt_plan: Scalar::PlanView<'_>,
) where
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
ContGgsw: Container<Element = Scalar>,
{
fn implementation<
CrtScalar: UnsignedInteger,
const N_COMPONENTS: usize,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
>(
this: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &mut [CrtScalar]>,
coef_ggsw: GgswCiphertext<&[Scalar]>,
ntt_plan: Scalar::PlanView<'_>,
) {
assert_eq!(coef_ggsw.polynomial_size(), this.polynomial_size());
let poly_size = this.polynomial_size().0;
let data = this.data();
for (ntt, coef_poly) in izip!(
iter_array(data.map(|data| data.into_chunks(poly_size))),
coef_ggsw.as_polynomial_list().iter(),
) {
Scalar::forward_normalized(ntt_plan, ntt, coef_poly.as_ref());
}
}
implementation(self.as_mut_view(), coef_ggsw.as_view(), ntt_plan)
}
}
/// Return the required memory for [`add_external_product_assign`].
///
/// The requirement is made of `N_COMPONENTS` NTT-domain buffers of `glwe_size * polynomial_size`
/// `CrtScalar`s, plus whichever is larger of:
/// - two buffers of `glwe_size * polynomial_size` `Scalar`s and `N_COMPONENTS` buffers of
///   `polynomial_size` `CrtScalar`s;
/// - the memory reported by `Scalar::add_backward_scratch`.
///
/// All buffers are cache-line aligned.
pub fn add_external_product_assign_scratch<
    CrtScalar: UnsignedInteger,
    const N_COMPONENTS: usize,
    Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
>(
    glwe_size: GlweSize,
    polynomial_size: PolynomialSize,
) -> Result<StackReq, SizeOverflow> {
    let align = CACHELINE_ALIGN;
    let glwe_len = glwe_size.0 * polynomial_size.0;

    // Elementary buffers.
    let standard_glwe = StackReq::try_new_aligned::<Scalar>(glwe_len, align)?;
    let ntt_glwe = StackReq::try_new_aligned::<CrtScalar>(glwe_len, align)?;
    let ntt_poly = StackReq::try_new_aligned::<CrtScalar>(polynomial_size.0, align)?;

    // Innermost level: one NTT polynomial per component.
    let per_row = StackReq::try_all_of([ntt_poly; N_COMPONENTS])?;
    // One level up: plus one standard-domain GLWE buffer.
    let per_level = per_row.try_and(standard_glwe)?;
    // The accumulation phase (with a second standard-domain GLWE buffer) and the backward
    // transform share the same memory, so only the larger of the two counts.
    let accumulation = per_level.try_and(standard_glwe)?;
    let backward = Scalar::add_backward_scratch(polynomial_size)?;
    let shared = StackReq::try_any_of([accumulation, backward])?;

    // The NTT-domain output buffers are needed on top of the shared memory.
    let output_buffers = StackReq::try_all_of([ntt_glwe; N_COMPONENTS])?;
    shared.try_and(output_buffers)
}
#[cfg_attr(__profiling, inline(never))]
pub fn add_external_product_assign<
CrtScalar,
const N_COMPONENTS: usize,
Scalar,
ContOut,
ContGgsw,
ContGlwe,
>(
out: &mut GlweCiphertext<ContOut>,
ggsw: &CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, ContGgsw>,
glwe: &GlweCiphertext<ContGlwe>,
ntt_plan: Scalar::PlanView<'_>,
stack: PodStack<'_>,
) where
CrtScalar: UnsignedInteger,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
ContOut: ContainerMut<Element = Scalar>,
ContGgsw: Container<Element = CrtScalar>,
ContGlwe: Container<Element = Scalar>,
{
fn implementation<
CrtScalar: UnsignedInteger,
const N_COMPONENTS: usize,
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
>(
mut out: GlweCiphertext<&mut [Scalar]>,
ggsw: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
glwe: GlweCiphertext<&[Scalar]>,
ntt_plan: Scalar::PlanView<'_>,
stack: PodStack<'_>,
) {
// we check that the polynomial sizes match
assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size());
assert_eq!(ggsw.polynomial_size(), out.polynomial_size());
// we check that the glwe sizes match
assert_eq!(ggsw.glwe_size(), glwe.glwe_size());
assert_eq!(ggsw.glwe_size(), out.glwe_size());
let align = CACHELINE_ALIGN;
let poly_size = ggsw.polynomial_size().0;
// we round the input mask and body
let decomposer = SignedDecomposer::<Scalar>::new(
ggsw.decomposition_base_log(),
ggsw.decomposition_level_count(),
);
let (mut output_ntt_buffer, mut substack0) =
chain_array_with_context::<_, _, N_COMPONENTS>(stack, |stack| {
stack.make_aligned_with(poly_size * ggsw.glwe_size().0, align, |_| CrtScalar::ZERO)
});
let mut output_ntt_buffer =
as_mut_array(&mut output_ntt_buffer).map(|buffer| &mut **buffer);
{
// ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER
// DOMAIN In this section, we perform the external product in the NTT
// domain, and accumulate the result in the output_ntt_buffer variable.
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
);
// We loop through the levels (we reverse to match the order of the decomposition
// iterator.)
for ggsw_decomp_matrix in ggsw.into_levels().rev() {
// We retrieve the decomposition of this level.
let (glwe_level, glwe_decomp_term, mut substack2) =
collect_next_term(&mut decomposition, &mut substack1, align);
let glwe_decomp_term = GlweCiphertextView::from_container(
&*glwe_decomp_term,
ggsw.polynomial_size(),
CiphertextModulus::new_native(),
);
debug_assert_eq!(ggsw_decomp_matrix.decomposition_level(), glwe_level);
// For each level we have to add the result of the vector-matrix product between the
// decomposition of the glwe, and the ggsw level matrix to the output. To do so, we
// iteratively add to the output, the product between every line of the matrix, and
// the corresponding (scalar) polynomial in the glwe decomposition:
//
// ggsw_mat ggsw_mat
// glwe_dec | - - - - | < glwe_dec | - - - - |
// | - - - | x | - - - - | | - - - | x | - - - - | <
// ^ | - - - - | ^ | - - - - |
//
// t = 1 t = 2 ...
for (ggsw_row, glwe_poly) in izip!(
ggsw_decomp_matrix.into_rows(),
glwe_decomp_term.as_polynomial_list().iter()
) {
let len = poly_size;
let stack = substack2.rb_mut();
let (mut ntt, _) =
chain_array_with_context::<_, _, N_COMPONENTS>(stack, |stack| {
stack.make_aligned_raw::<CrtScalar>(len, align)
});
let mut ntt = as_mut_array(&mut ntt).map(|ntt| &mut **ntt);
// We perform the forward NTT for the glwe polynomial
Scalar::forward(
ntt_plan,
as_mut_array(&mut ntt).map(|buf| &mut **buf),
glwe_poly.as_ref(),
);
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.
for (output_ntt, ggsw_poly) in izip!(
iter_array(
as_mut_array(&mut output_ntt_buffer)
.map(|buf| (&mut **buf).into_chunks(poly_size))
),
iter_array(
as_ref_array(&ggsw_row.data).map(|buf| (&**buf).into_chunks(poly_size))
),
) {
Scalar::mul_accumulate(
ntt_plan,
output_ntt,
ggsw_poly,
as_ref_array(&ntt).map(|buf| &**buf),
);
}
}
}
}
// -------------------------------------------- TRANSFORMATION OF RESULT TO STANDARD DOMAIN
// In this section, we bring the result from the NTT domain, back to the standard
// domain, and add it to the output.
//
// We iterate over the polynomials in the output.
for (mut out, ntt) in izip!(
out.as_mut_polynomial_list().iter_mut(),
iter_array(output_ntt_buffer.map(|buf| buf.into_chunks(poly_size))),
) {
Scalar::add_backward(ntt_plan, out.as_mut(), ntt, substack0.rb_mut());
}
}
implementation(
out.as_mut_view(),
ggsw.as_view(),
glwe.as_view(),
ntt_plan,
stack,
)
}
/// Pulls the next term out of `decomposition` and collects it into a stack allocation.
///
/// Returns, in order: the decomposition level of the term, the collected term (allocated
/// from `substack1` with alignment `align`), and the part of the stack that is still free.
///
/// # Panics
///
/// Panics if `decomposition.next_term()` yields `None`.
fn collect_next_term<'a, Scalar: UnsignedInteger>(
    decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
    substack1: &'a mut PodStack,
    align: usize,
) -> (
    DecompositionLevel,
    dyn_stack::DynArray<'a, Scalar>,
    PodStack<'a>,
) {
    // The middle element of the yielded triple is not needed by the caller.
    let (level, _, term_values) = decomposition.next_term().unwrap();
    let (collected_term, remaining_stack) =
        substack1.rb_mut().collect_aligned(align, term_values);
    (level, collected_term, remaining_stack)
}
/// Return the required memory for [`cmux`].
///
/// The returned value is whatever [`add_external_product_assign_scratch`] reports for the
/// same `glwe_size` and `polynomial_size`; nothing is added on top of it.
///
/// # Errors
///
/// Forwards the [`SizeOverflow`] error of [`add_external_product_assign_scratch`], if any.
pub fn cmux_scratch<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Scalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
) -> Result<StackReq, SizeOverflow>
where
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
{
// `cmux` hands its whole stack to `add_external_product_assign`, so the two requirements
// are the same.
add_external_product_assign_scratch::<CrtScalar, N_COMPONENTS, Scalar>(
glwe_size,
polynomial_size,
)
}
/// Controlled multiplexer between `ct0` and `ct1`, driven by `ggsw`.
///
/// Both ciphertexts are modified:
/// - `ct1` is overwritten with the coefficient-wise wrapping difference `ct1 - ct0`;
/// - the external product of `ggsw` with that difference is then added to `ct0` through
///   [`add_external_product_assign`].
///
/// The result is in `ct0` once the function returns. `stack` provides the scratch memory;
/// see [`cmux_scratch`] for the required size.
pub fn cmux<CrtScalar, const N_COMPONENTS: usize, Scalar, ContCt0, ContCt1, ContGgsw>(
    ct0: &mut GlweCiphertext<ContCt0>,
    ct1: &mut GlweCiphertext<ContCt1>,
    ggsw: &CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, ContGgsw>,
    ntt_plan: Scalar::PlanView<'_>,
    stack: PodStack<'_>,
) where
    CrtScalar: UnsignedInteger,
    Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
    ContCt0: ContainerMut<Element = Scalar>,
    ContCt1: ContainerMut<Element = Scalar>,
    ContGgsw: Container<Element = CrtScalar>,
{
    // Worker operating on slice-backed views only, so it does not depend on the container
    // types of the public entry point.
    fn implementation<
        CrtScalar: UnsignedInteger,
        const N_COMPONENTS: usize,
        Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
    >(
        mut ct0: GlweCiphertext<&mut [Scalar]>,
        mut ct1: GlweCiphertext<&mut [Scalar]>,
        ggsw: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
        ntt_plan: Scalar::PlanView<'_>,
        stack: PodStack<'_>,
    ) {
        // ct1 <- ct1 - ct0
        izip!(ct1.as_mut(), ct0.as_ref()).for_each(|(diff, base)| {
            *diff = diff.wrapping_sub(*base);
        });

        // ct0 <- ct0 + ggsw (external product) ct1
        add_external_product_assign(&mut ct0, &ggsw, &ct1, ntt_plan, stack);
    }

    let ct0_view = ct0.as_mut_view();
    let ct1_view = ct1.as_mut_view();
    let ggsw_view = ggsw.as_view();
    implementation(ct0_view, ct1_view, ggsw_view, ntt_plan, stack)
}

View File

@@ -0,0 +1,5 @@
pub mod bootstrap;
pub mod ggsw;
#[cfg(test)]
pub mod tests;

View File

@@ -0,0 +1,12 @@
use super::bootstrap::CrtNttLweBootstrapKey;
use crate::core_crypto::fft_impl::common::tests::test_bootstrap_generic;
use crate::core_crypto::prelude::*;
use aligned_vec::ABox;
/// Runs the generic bootstrap test with `u64` scalars and a CRT/NTT bootstrap key stored as
/// five `u32` components.
#[test]
fn test_crt_bootstrap_u64() {
    type BootstrapKey = CrtNttLweBootstrapKey<u32, 5, ABox<[u32]>>;

    // The two standard deviations are forwarded to the generic test in this order.
    // NOTE(review): presumably the LWE noise followed by the GLWE noise; confirm against
    // `test_bootstrap_generic`.
    let first_std_dev = StandardDev(0.000007069849454709433);
    let second_std_dev = StandardDev(0.00000000000000029403601535432533);

    test_bootstrap_generic::<u64, BootstrapKey>(first_std_dev, second_std_dev);
}

View File

@@ -0,0 +1 @@
pub mod ntt;

View File

@@ -0,0 +1,179 @@
use crate::core_crypto::commons::parameters::PolynomialSize;
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::prelude::UnsignedInteger;
use aligned_vec::CACHELINE_ALIGN;
use concrete_ntt::native64::Plan32;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
/// Newtype around [`Plan32`].
///
/// It dereferences to the wrapped plan and carries a hand-written `Debug` implementation
/// that prints a placeholder instead of the plan's contents.
#[derive(Clone)]
pub(crate) struct PlanWrapper(Plan32);
impl core::ops::Deref for PlanWrapper {
type Target = Plan32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
// Opaque `Debug` implementation: the wrapped plan itself is never printed.
impl core::fmt::Debug for PlanWrapper {
    /// Writes the placeholder `"[?]"` using the `Debug` formatting of a string slice, i.e.
    /// including the surrounding double quotes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Fully qualified call: the method form `"[?]".fmt(f)` only resolves when the
        // `Debug` trait is imported into the module, which this file's `use` list does not do.
        core::fmt::Debug::fmt("[?]", f)
    }
}
/// Owned handle to the plan used as [`CrtNtt::Plan`] for `u64` scalars.
///
/// The plan is held behind an [`Arc`], so cloning the handle clones the pointer rather than
/// the plan. Use [`CrtNtt64::as_view`] to get the borrowed form.
#[derive(Clone, Debug)]
pub struct CrtNtt64 {
// Shared plan; `CrtNtt64::new` obtains it from the global plan map.
inner: Arc<PlanWrapper>,
}
/// View type for [`CrtNtt64`].
///
/// A `Copy` borrow of the plan, as returned by [`CrtNtt64::as_view`]; this is the
/// [`CrtNtt::PlanView`] of the `u64` implementation.
#[derive(Clone, Copy, Debug)]
pub struct CrtNtt64View<'a> {
// Borrowed plan, valid for as long as the `CrtNtt64` it was taken from.
pub(crate) inner: &'a PlanWrapper,
}
/// Plan cache shared by the whole process, keyed by polynomial size.
///
/// Every entry is a shared `OnceLock` cell, so the plan for a given size can be initialized
/// after the map lock has been released.
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<PlanWrapper>>>>>;

/// Backing storage of the plan cache; it is created on first access through [`plans`].
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();

/// Returns the global plan cache, creating an empty one on the first call.
fn plans() -> &'static PlanMap {
    PLANS.get_or_init(PlanMap::default)
}
impl CrtNtt64 {
    /// Returns the plan for polynomials of size `size`.
    ///
    /// Plans are cached in the global map returned by `plans()`, keyed by `size.0`: the first
    /// call for a given size builds the plan with `Plan32::try_new`, later calls share it.
    ///
    /// When the size is already cached only the read lock of the map is taken; the write lock
    /// is needed solely to register a size that has never been requested before.
    ///
    /// # Panics
    ///
    /// Panics if the lock of the global map is poisoned, or if `Plan32::try_new` does not
    /// produce a plan for this size.
    pub fn new(size: PolynomialSize) -> Self {
        let global_plans = plans();
        let n = size.0;

        // Fetches the cell registered for `n` under the read lock, releases the lock, and
        // only then initializes the plan if nobody did it yet: building a plan never keeps
        // the map locked.
        let get_plan = || {
            let plans = global_plans.read().unwrap();
            let plan = plans.get(&n).cloned();
            drop(plans);

            plan.map(|p| {
                p.get_or_init(|| Arc::new(PlanWrapper(Plan32::try_new(n).unwrap())))
                    .clone()
            })
        };

        // Fast path: the size is already registered, no write lock required.
        let inner = get_plan().unwrap_or_else(|| {
            // could not find a plan of the given size, we lock the map again and try to
            // insert it
            let mut plans = global_plans.write().unwrap();
            if let Entry::Vacant(v) = plans.entry(n) {
                v.insert(Arc::new(OnceLock::new()));
            }
            drop(plans);

            // This function never removes entries, so the cell registered above is found.
            get_plan().expect("a cell for this size was registered just above")
        });

        Self { inner }
    }

    /// Borrows this plan as a [`CrtNtt64View`].
    pub fn as_view(&self) -> CrtNtt64View<'_> {
        CrtNtt64View { inner: &self.inner }
    }
}
/// Transform operations between a polynomial with coefficients of type `Self` (the
/// "standard" form) and a form made of `N_COMPONENTS` buffers of `CrtScalar` (the "ntt"
/// form).
///
/// NOTE(review): presumably each of the `N_COMPONENTS` buffers holds the transform modulo
/// one prime of a CRT basis; confirm against the plan implementation.
pub trait CrtNtt<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize>: UnsignedInteger {
/// Owned plan type.
type Plan;
/// Cheap, `Copy` borrowed form of [`Self::Plan`], taken by every transform function.
type PlanView<'a>: Copy;
/// Creates the plan for polynomials of size `size`.
fn new_plan(size: PolynomialSize) -> Self::Plan;
/// Borrows `plan` as a [`Self::PlanView`].
fn plan_as_view(plan: &Self::Plan) -> Self::PlanView<'_>;
/// Transforms `standard` into the `N_COMPONENTS` output buffers of `ntt`.
fn forward(plan: Self::PlanView<'_>, ntt: [&mut [CrtScalar]; N_COMPONENTS], standard: &[Self]);
/// Same as [`Self::forward`], followed by a normalization of the output buffers (the `u64`
/// implementation calls `normalize` on every component after the forward transform).
fn forward_normalized(
plan: Self::PlanView<'_>,
ntt: [&mut [CrtScalar]; N_COMPONENTS],
standard: &[Self],
);
/// Transforms `ntt` back to the standard form and adds the result to `standard` (the `u64`
/// implementation uses a wrapping addition).
///
/// `stack` provides the temporary memory, sized by [`Self::add_backward_scratch`]. The
/// `ntt` buffers are taken mutably.
/// NOTE(review): they may be overwritten by the inverse transform; confirm before reusing
/// them after the call.
fn add_backward(
plan: Self::PlanView<'_>,
standard: &mut [Self],
ntt: [&mut [CrtScalar]; N_COMPONENTS],
stack: PodStack<'_>,
);
/// Returns the stack requirement of [`Self::add_backward`] for `polynomial_size`.
fn add_backward_scratch(polynomial_size: PolynomialSize) -> Result<StackReq, SizeOverflow>;
/// Component-wise multiply-accumulate on the ntt form.
///
/// NOTE(review): presumably `acc += lhs * rhs` for every component; the `u64`
/// implementation forwards to the `mul_accumulate` of each component plan.
fn mul_accumulate(
plan: Self::PlanView<'_>,
acc: [&mut [CrtScalar]; N_COMPONENTS],
lhs: [&[CrtScalar]; N_COMPONENTS],
rhs: [&[CrtScalar]; N_COMPONENTS],
);
}
impl CrtNtt<u32, 5> for u64 {
type Plan = CrtNtt64;
type PlanView<'a> = CrtNtt64View<'a>;
fn new_plan(size: PolynomialSize) -> Self::Plan {
Self::Plan::new(size)
}
fn plan_as_view(plan: &Self::Plan) -> Self::PlanView<'_> {
plan.as_view()
}
fn forward(
plan: Self::PlanView<'_>,
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
standard: &[Self],
) {
plan.inner.0.fwd(standard, ntt0, ntt1, ntt2, ntt3, ntt4);
}
fn forward_normalized(
plan: Self::PlanView<'_>,
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
standard: &[Self],
) {
plan.inner.0.fwd(standard, ntt0, ntt1, ntt2, ntt3, ntt4);
plan.inner.0.ntt_0().normalize(ntt0);
plan.inner.0.ntt_1().normalize(ntt1);
plan.inner.0.ntt_2().normalize(ntt2);
plan.inner.0.ntt_3().normalize(ntt3);
plan.inner.0.ntt_4().normalize(ntt4);
}
fn add_backward(
plan: Self::PlanView<'_>,
standard: &mut [Self],
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
stack: PodStack<'_>,
) {
let n = standard.len();
let (mut tmp, _) = stack.make_aligned_raw::<u64>(n, CACHELINE_ALIGN);
plan.inner.0.inv(&mut tmp, ntt0, ntt1, ntt2, ntt3, ntt4);
// autovectorize
pulp::Arch::new().dispatch(
#[inline(always)]
|| {
for (out, inp) in izip!(standard, &*tmp) {
*out = u64::wrapping_add(*out, *inp);
}
},
)
}
fn add_backward_scratch(polynomial_size: PolynomialSize) -> Result<StackReq, SizeOverflow> {
StackReq::try_new_aligned::<u64>(polynomial_size.0, CACHELINE_ALIGN)
}
fn mul_accumulate(
plan: Self::PlanView<'_>,
[acc0, acc1, acc2, acc3, acc4]: [&mut [u32]; 5],
[lhs0, lhs1, lhs2, lhs3, lhs4]: [&[u32]; 5],
[rhs0, rhs1, rhs2, rhs3, rhs4]: [&[u32]; 5],
) {
plan.inner.ntt_0().mul_accumulate(acc0, lhs0, rhs0);
plan.inner.ntt_1().mul_accumulate(acc1, lhs1, rhs1);
plan.inner.ntt_2().mul_accumulate(acc2, lhs2, rhs2);
plan.inner.ntt_3().mul_accumulate(acc3, lhs3, rhs3);
plan.inner.ntt_4().mul_accumulate(acc4, lhs4, rhs4);
}
}

View File

@@ -0,0 +1,2 @@
pub mod crypto;
pub mod math;

View File

@@ -7,3 +7,5 @@ pub mod fft64;
pub mod fft128;
mod fft128_u128;
pub mod crt_ntt;