mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
1 Commits
al/pfail_g
...
ntt-experi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b84c0672ce |
@@ -59,6 +59,7 @@ rayon = { version = "1.5.0" }
|
||||
bincode = { version = "1.3.3", optional = true }
|
||||
concrete-fft = { version = "0.3.0", features = ["serde", "fft128"] }
|
||||
pulp = "0.13"
|
||||
concrete-ntt = "0.1.0"
|
||||
aligned-vec = { version = "0.5", features = ["serde"] }
|
||||
dyn-stack = { version = "0.9" }
|
||||
paste = { version = "1.0.7", optional = true }
|
||||
@@ -164,6 +165,11 @@ name = "pbs128-bench"
|
||||
path = "benches/core_crypto/pbs128_bench.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "pbs-crt-bench"
|
||||
path = "benches/core_crypto/pbs_crt_bench.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "boolean-bench"
|
||||
path = "benches/boolean/bench.rs"
|
||||
|
||||
116
tfhe/benches/core_crypto/pbs_crt_bench.rs
Normal file
116
tfhe/benches/core_crypto/pbs_crt_bench.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use dyn_stack::PodStack;
|
||||
|
||||
fn criterion_bench(c: &mut Criterion) {
|
||||
{
|
||||
use tfhe::core_crypto::fft_impl::crt_ntt::crypto::bootstrap::{
|
||||
bootstrap_scratch, CrtNttLweBootstrapKey,
|
||||
};
|
||||
use tfhe::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt64;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
type Scalar = u64;
|
||||
|
||||
let small_lwe_dimension = LweDimension(742);
|
||||
let glwe_dimension = GlweDimension(1);
|
||||
let polynomial_size = PolynomialSize(2048);
|
||||
let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
|
||||
let pbs_base_log = DecompositionBaseLog(23);
|
||||
let pbs_level = DecompositionLevelCount(1);
|
||||
|
||||
// Request the best seeder possible, starting with hardware entropy sources and falling back
|
||||
// to /dev/random on Unix systems if enabled via cargo features
|
||||
let mut boxed_seeder = new_seeder();
|
||||
// Get a mutable reference to the seeder as a trait object from the Box returned by
|
||||
// new_seeder
|
||||
let seeder = boxed_seeder.as_mut();
|
||||
|
||||
// Create a generator which uses a CSPRNG to generate secret keys
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
// Create a generator which uses two CSPRNGs to generate public masks and secret encryption
|
||||
// noise
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
|
||||
// Generate an LweSecretKey with binary coefficients
|
||||
let small_lwe_sk =
|
||||
LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator);
|
||||
|
||||
// Generate a GlweSecretKey with binary coefficients
|
||||
let glwe_sk = GlweSecretKey::<Vec<Scalar>>::generate_new_binary(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
// Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey
|
||||
let big_lwe_sk = glwe_sk.into_lwe_secret_key();
|
||||
|
||||
// Create the empty bootstrapping key in the NTT domain
|
||||
let ntt_bsk = CrtNttLweBootstrapKey::new(
|
||||
small_lwe_dimension,
|
||||
polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
);
|
||||
|
||||
let fft = CrtNtt64::new(polynomial_size);
|
||||
let fft = fft.as_view();
|
||||
|
||||
// We don't need the standard bootstrapping key anymore
|
||||
|
||||
// Our 4 bits message space
|
||||
let message_modulus: Scalar = 1 << 4;
|
||||
|
||||
// Our input message
|
||||
let input_message: Scalar = 3;
|
||||
|
||||
// Delta used to encode 4 bits of message + a bit of padding on Scalar
|
||||
let delta: Scalar = (1 << (Scalar::BITS - 1)) / message_modulus;
|
||||
|
||||
// Apply our encoding
|
||||
let plaintext = Plaintext(input_message * delta);
|
||||
|
||||
// Allocate a new LweCiphertext and encrypt our plaintext
|
||||
let lwe_ciphertext_in: LweCiphertextOwned<Scalar> = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&small_lwe_sk,
|
||||
plaintext,
|
||||
lwe_modular_std_dev,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let accumulator: GlweCiphertextOwned<Scalar> =
|
||||
GlweCiphertextOwned::new(Scalar::ONE, glwe_dimension.to_glwe_size(), polynomial_size);
|
||||
|
||||
// Allocate the LweCiphertext to store the result of the PBS
|
||||
let mut pbs_multiplication_ct: LweCiphertext<Vec<Scalar>> =
|
||||
LweCiphertext::new(0, big_lwe_sk.lwe_dimension().to_lwe_size());
|
||||
|
||||
let mut buf = vec![
|
||||
0u8;
|
||||
bootstrap_scratch::<u32, 5, Scalar>(
|
||||
ntt_bsk.glwe_size(),
|
||||
ntt_bsk.polynomial_size(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required()
|
||||
];
|
||||
|
||||
c.bench_function("pbs-crt-u64-u32x5", |b| {
|
||||
b.iter(|| {
|
||||
ntt_bsk.bootstrap(
|
||||
&mut pbs_multiplication_ct,
|
||||
&lwe_ciphertext_in,
|
||||
&accumulator,
|
||||
fft,
|
||||
PodStack::new(&mut buf),
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_bench);
|
||||
criterion_main!(benches);
|
||||
@@ -1,12 +1,11 @@
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::numeric::CastInto;
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
|
||||
ModulusSwitchOffset, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::Container;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::prelude::ContainerMut;
|
||||
use crate::core_crypto::prelude::{ContainerMut, UnsignedInteger};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
|
||||
/// This function switches modulus for a single coefficient of a ciphertext,
|
||||
@@ -14,7 +13,7 @@ use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
///
|
||||
/// offset: the number of msb discarded
|
||||
/// lut_count_log: the right padding
|
||||
pub fn pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
|
||||
pub fn pbs_modulus_switch<Scalar: UnsignedInteger + CastInto<usize>>(
|
||||
input: Scalar,
|
||||
poly_size: PolynomialSize,
|
||||
offset: ModulusSwitchOffset,
|
||||
@@ -287,3 +286,102 @@ pub mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_to_array<const N: usize, I: Iterator>(iter: I) -> [I::Item; N] {
|
||||
// TODO: avoid allocating here
|
||||
match iter.collect::<Vec<_>>().try_into() {
|
||||
Ok(arr) => arr,
|
||||
Err(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chain_array_with_context<Ctx, T, const N: usize>(
|
||||
ctx: Ctx,
|
||||
f: impl FnMut(Ctx) -> (T, Ctx),
|
||||
) -> ([T; N], Ctx) {
|
||||
let mut ctx = Some(ctx);
|
||||
let mut f = f;
|
||||
(
|
||||
[(); N].map(|()| {
|
||||
let local_ctx = ctx.take().unwrap();
|
||||
let (val, local_ctx) = f(local_ctx);
|
||||
ctx = Some(local_ctx);
|
||||
val
|
||||
}),
|
||||
ctx.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn as_ref_array<T, const N: usize>(array: &[T; N]) -> [&T; N] {
|
||||
collect_to_array(array.iter())
|
||||
}
|
||||
pub fn as_mut_array<T, const N: usize>(array: &mut [T; N]) -> [&mut T; N] {
|
||||
collect_to_array(array.iter_mut())
|
||||
}
|
||||
pub fn zip_array<T, U, const N: usize>(first: [T; N], second: [U; N]) -> [(T, U); N] {
|
||||
collect_to_array(core::iter::zip(first, second))
|
||||
}
|
||||
|
||||
pub struct ArrayIter<I, const N: usize> {
|
||||
iters: [I; N],
|
||||
}
|
||||
|
||||
impl<I: Iterator, const N: usize> Iterator for ArrayIter<I, N> {
|
||||
type Item = [I::Item; N];
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let items = as_mut_array(&mut self.iters).map(|iter| iter.next());
|
||||
if items.iter().all(|item| item.is_some()) {
|
||||
Some(items.map(|item| item.unwrap()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
// taken from core::iter::Zip's size_hint impl
|
||||
self.iters.iter().fold((usize::MAX, None), |acc, iter| {
|
||||
let (a_lower, a_upper) = acc;
|
||||
let (b_lower, b_upper) = iter.size_hint();
|
||||
|
||||
let lower = core::cmp::min(a_lower, b_lower);
|
||||
|
||||
let upper = match (a_upper, b_upper) {
|
||||
(Some(x), Some(y)) => Some(core::cmp::min(x, y)),
|
||||
(Some(x), None) => Some(x),
|
||||
(None, Some(y)) => Some(y),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
(lower, upper)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: DoubleEndedIterator, const N: usize> DoubleEndedIterator for ArrayIter<I, N> {
|
||||
#[inline]
|
||||
fn next_back(&mut self) -> Option<Self::Item> {
|
||||
let items = as_mut_array(&mut self.iters).map(|iter| iter.next_back());
|
||||
if items.iter().all(|item| item.is_some()) {
|
||||
Some(items.map(|item| item.unwrap()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: ExactSizeIterator, const N: usize> ExactSizeIterator for ArrayIter<I, N> {
|
||||
#[inline]
|
||||
fn len(&self) -> usize {
|
||||
self.iters
|
||||
.iter()
|
||||
.fold(usize::MAX, |acc, iter| acc.min(iter.len()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter_array<I: IntoIterator, const N: usize>(iters: [I; N]) -> ArrayIter<I::IntoIter, N> {
|
||||
ArrayIter {
|
||||
iters: iters.map(|iter| iter.into_iter()),
|
||||
}
|
||||
}
|
||||
|
||||
452
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/bootstrap.rs
Normal file
452
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/bootstrap.rs
Normal file
@@ -0,0 +1,452 @@
|
||||
use super::ggsw::{cmux, cmux_scratch, CrtNttGgswCiphertext};
|
||||
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
use crate::core_crypto::commons::numeric::CastInto;
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
|
||||
ModulusSwitchOffset, MonomialDegree, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::{
|
||||
Container, ContiguousEntityContainer, ContiguousEntityContainerMut, Split,
|
||||
};
|
||||
use crate::core_crypto::commons::utils::izip;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::common::{
|
||||
as_mut_array, as_ref_array, iter_array, pbs_modulus_switch, FourierBootstrapKey,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt;
|
||||
use crate::core_crypto::prelude::{CiphertextModulus, ContainerMut, UnsignedInteger};
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
|
||||
|
||||
pub struct CrtNttLweBootstrapKey<
|
||||
CrtScalar,
|
||||
const N_COMPONENTS: usize,
|
||||
C: Container<Element = CrtScalar>,
|
||||
> {
|
||||
data: [C; N_COMPONENTS],
|
||||
polynomial_size: PolynomialSize,
|
||||
input_lwe_dimension: LweDimension,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
|
||||
CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, C>
|
||||
{
|
||||
pub fn from_container(
|
||||
data: [C; N_COMPONENTS],
|
||||
polynomial_size: PolynomialSize,
|
||||
input_lwe_dimension: LweDimension,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
let container_len = input_lwe_dimension.0
|
||||
* polynomial_size.0
|
||||
* decomposition_level_count.0
|
||||
* glwe_size.0
|
||||
* glwe_size.0;
|
||||
data.iter()
|
||||
.for_each(|data| assert_eq!(data.container_len(), container_len));
|
||||
Self {
|
||||
data,
|
||||
polynomial_size,
|
||||
input_lwe_dimension,
|
||||
glwe_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return an iterator over the GGSW ciphertexts composing the key.
|
||||
pub fn into_ggsw_iter(
|
||||
self,
|
||||
) -> impl DoubleEndedIterator<Item = CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, C>>
|
||||
where
|
||||
C: Split,
|
||||
{
|
||||
iter_array(
|
||||
self.data
|
||||
.map(|data| data.split_into(self.input_lwe_dimension.0)),
|
||||
)
|
||||
.map(move |data| {
|
||||
CrtNttGgswCiphertext::from_container(
|
||||
data,
|
||||
self.polynomial_size,
|
||||
self.glwe_size,
|
||||
self.decomposition_base_log,
|
||||
self.decomposition_level_count,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn input_lwe_dimension(&self) -> LweDimension {
|
||||
self.input_lwe_dimension
|
||||
}
|
||||
|
||||
pub fn polynomial_size(&self) -> PolynomialSize {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
pub fn glwe_size(&self) -> GlweSize {
|
||||
self.glwe_size
|
||||
}
|
||||
|
||||
pub fn decomposition_base_log(&self) -> DecompositionBaseLog {
|
||||
self.decomposition_base_log
|
||||
}
|
||||
|
||||
pub fn decomposition_level_count(&self) -> DecompositionLevelCount {
|
||||
self.decomposition_level_count
|
||||
}
|
||||
|
||||
pub fn output_lwe_dimension(&self) -> LweDimension {
|
||||
LweDimension((self.glwe_size.0 - 1) * self.polynomial_size().0)
|
||||
}
|
||||
|
||||
pub fn data(self) -> [C; N_COMPONENTS] {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub fn as_view(&self) -> CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[C::Element]> {
|
||||
CrtNttLweBootstrapKey {
|
||||
data: as_ref_array(&self.data).map(|data| data.as_ref()),
|
||||
polynomial_size: self.polynomial_size,
|
||||
input_lwe_dimension: self.input_lwe_dimension,
|
||||
glwe_size: self.glwe_size,
|
||||
decomposition_base_log: self.decomposition_base_log,
|
||||
decomposition_level_count: self.decomposition_level_count,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut_view(
|
||||
&mut self,
|
||||
) -> CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &mut [C::Element]>
|
||||
where
|
||||
C: AsMut<[C::Element]>,
|
||||
{
|
||||
CrtNttLweBootstrapKey {
|
||||
data: as_mut_array(&mut self.data).map(|data| data.as_mut()),
|
||||
polynomial_size: self.polynomial_size,
|
||||
input_lwe_dimension: self.input_lwe_dimension,
|
||||
glwe_size: self.glwe_size,
|
||||
decomposition_base_log: self.decomposition_base_log,
|
||||
decomposition_level_count: self.decomposition_level_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize>
|
||||
CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, ABox<[CrtScalar]>>
|
||||
{
|
||||
pub fn new(
|
||||
input_lwe_dimension: LweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, ABox<[CrtScalar]>> {
|
||||
let container_len = polynomial_size.0
|
||||
* input_lwe_dimension.0
|
||||
* decomposition_level_count.0
|
||||
* glwe_size.0
|
||||
* glwe_size.0;
|
||||
|
||||
let boxed =
|
||||
[(); N_COMPONENTS].map(|()| avec![CrtScalar::ZERO; container_len].into_boxed_slice());
|
||||
|
||||
CrtNttLweBootstrapKey::from_container(
|
||||
boxed,
|
||||
polynomial_size,
|
||||
input_lwe_dimension,
|
||||
glwe_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Cont>
|
||||
CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, Cont>
|
||||
where
|
||||
Cont: ContainerMut<Element = CrtScalar>,
|
||||
{
|
||||
/// Fill a bootstrapping key with the NTT of a bootstrapping key in the standard
|
||||
/// domain.
|
||||
pub fn fill_with_forward_ntt<Scalar, ContBsk>(
|
||||
&mut self,
|
||||
coef_bsk: &LweBootstrapKey<ContBsk>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
) where
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
ContBsk: Container<Element = Scalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
>(
|
||||
this: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &mut [CrtScalar]>,
|
||||
coef_bsk: LweBootstrapKey<&[Scalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
) {
|
||||
for (mut ntt_ggsw, standard_ggsw) in izip!(this.into_ggsw_iter(), coef_bsk.iter()) {
|
||||
ntt_ggsw.fill_with_forward_ntt(&standard_ggsw, ntt_plan);
|
||||
}
|
||||
}
|
||||
implementation(self.as_mut_view(), coef_bsk.as_view(), ntt_plan)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the required memory for [`CrtNttLweBootstrapKey::blind_rotate_assign`].
|
||||
pub fn blind_rotate_scratch<CrtScalar, const N_COMPONENTS: usize, Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> Result<StackReq, SizeOverflow>
|
||||
where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
{
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?.try_and(
|
||||
cmux_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)?,
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the required memory for [`CrtNttLweBootstrapKey::bootstrap`].
|
||||
pub fn bootstrap_scratch<CrtScalar, const N_COMPONENTS: usize, Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> Result<StackReq, SizeOverflow>
|
||||
where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
{
|
||||
blind_rotate_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)?.try_and(
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?,
|
||||
)
|
||||
}
|
||||
|
||||
impl<CrtScalar, const N_COMPONENTS: usize, Cont>
|
||||
CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, Cont>
|
||||
where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Cont: Container<Element = CrtScalar>,
|
||||
{
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
pub fn blind_rotate_assign<Scalar, ContLut, ContLwe>(
|
||||
&self,
|
||||
lut: &mut GlweCiphertext<ContLut>,
|
||||
lwe: &LweCiphertext<ContLwe>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
|
||||
ContLut: ContainerMut<Element = Scalar>,
|
||||
ContLwe: Container<Element = Scalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
|
||||
>(
|
||||
this: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
|
||||
mut lut: GlweCiphertext<&mut [Scalar]>,
|
||||
lwe: LweCiphertext<&[Scalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
mut stack: PodStack<'_>,
|
||||
) {
|
||||
let lwe = lwe.as_ref();
|
||||
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
let monomial_degree = pbs_modulus_switch(
|
||||
*lwe_body,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
lut.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
});
|
||||
|
||||
// We initialize the ct_0 used for the successive cmuxes
|
||||
let mut ct0 = lut;
|
||||
|
||||
for (lwe_mask_element, bootstrap_key_ggsw) in
|
||||
izip!(lwe_mask.iter(), this.into_ggsw_iter())
|
||||
{
|
||||
if *lwe_mask_element != Scalar::ZERO {
|
||||
let stack = stack.rb_mut();
|
||||
// We copy ct_0 to ct_1
|
||||
let (mut ct1, stack) =
|
||||
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
|
||||
let mut ct1 = GlweCiphertextMutView::from_container(
|
||||
&mut *ct1,
|
||||
ct0.polynomial_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
// We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat}
|
||||
for mut poly in ct1.as_mut_polynomial_list().iter_mut() {
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(pbs_modulus_switch(
|
||||
*lwe_mask_element,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
// ct1 is re-created each loop it can be moved, ct0 is already a view, but
|
||||
// as_mut_view is required to keep borrow rules consistent
|
||||
cmux(&mut ct0, &mut ct1, &bootstrap_key_ggsw, ntt_plan, stack);
|
||||
}
|
||||
}
|
||||
}
|
||||
implementation(
|
||||
self.as_view(),
|
||||
lut.as_mut_view(),
|
||||
lwe.as_view(),
|
||||
ntt_plan,
|
||||
stack,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bootstrap<Scalar, ContLweOut, ContLweIn, ContAcc>(
|
||||
&self,
|
||||
lwe_out: &mut LweCiphertext<ContLweOut>,
|
||||
lwe_in: &LweCiphertext<ContLweIn>,
|
||||
accumulator: &GlweCiphertext<ContAcc>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
|
||||
ContLweOut: ContainerMut<Element = Scalar>,
|
||||
ContLweIn: Container<Element = Scalar>,
|
||||
ContAcc: Container<Element = Scalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
|
||||
>(
|
||||
this: CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
|
||||
mut lwe_out: LweCiphertext<&mut [Scalar]>,
|
||||
lwe_in: LweCiphertext<&[Scalar]>,
|
||||
accumulator: GlweCiphertext<&[Scalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
let (mut local_accumulator_data, stack) =
|
||||
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
|
||||
let mut local_accumulator = GlweCiphertextMutView::from_container(
|
||||
&mut *local_accumulator_data,
|
||||
accumulator.polynomial_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
this.blind_rotate_assign(
|
||||
&mut local_accumulator.as_mut_view(),
|
||||
&lwe_in,
|
||||
ntt_plan,
|
||||
stack,
|
||||
);
|
||||
extract_lwe_sample_from_glwe_ciphertext(
|
||||
&local_accumulator,
|
||||
&mut lwe_out,
|
||||
MonomialDegree(0),
|
||||
);
|
||||
}
|
||||
|
||||
implementation(
|
||||
self.as_view(),
|
||||
lwe_out.as_mut_view(),
|
||||
lwe_in.as_view(),
|
||||
accumulator.as_view(),
|
||||
ntt_plan,
|
||||
stack,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar, const N_COMPONENTS: usize, Scalar> FourierBootstrapKey<Scalar>
|
||||
for CrtNttLweBootstrapKey<CrtScalar, N_COMPONENTS, ABox<[CrtScalar]>>
|
||||
where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS> + CastInto<usize>,
|
||||
{
|
||||
type Fft = Scalar::Plan;
|
||||
|
||||
fn new_fft(polynomial_size: PolynomialSize) -> Self::Fft {
|
||||
Scalar::new_plan(polynomial_size)
|
||||
}
|
||||
|
||||
fn new(
|
||||
input_lwe_dimension: LweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
input_lwe_dimension,
|
||||
polynomial_size,
|
||||
glwe_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
)
|
||||
}
|
||||
|
||||
fn fill_with_forward_fourier<ContBsk>(
|
||||
&mut self,
|
||||
coef_bsk: &LweBootstrapKey<ContBsk>,
|
||||
ntt_plan: &Self::Fft,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
ContBsk: Container<Element = Scalar>,
|
||||
{
|
||||
let _ = stack;
|
||||
let ntt_plan = Scalar::plan_as_view(ntt_plan);
|
||||
self.fill_with_forward_ntt(coef_bsk, ntt_plan);
|
||||
}
|
||||
|
||||
fn bootstrap_scratch(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
ntt_plan: &Self::Fft,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
let _ = ntt_plan;
|
||||
bootstrap_scratch::<CrtScalar, N_COMPONENTS, Scalar>(glwe_size, polynomial_size)
|
||||
}
|
||||
|
||||
fn bootstrap<ContLweOut, ContLweIn, ContAcc>(
|
||||
&self,
|
||||
lwe_out: &mut LweCiphertext<ContLweOut>,
|
||||
lwe_in: &LweCiphertext<ContLweIn>,
|
||||
accumulator: &GlweCiphertext<ContAcc>,
|
||||
ntt_plan: &Self::Fft,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
ContLweOut: ContainerMut<Element = Scalar>,
|
||||
ContLweIn: Container<Element = Scalar>,
|
||||
ContAcc: Container<Element = Scalar>,
|
||||
{
|
||||
let ntt_plan = Scalar::plan_as_view(ntt_plan);
|
||||
self.bootstrap(lwe_out, lwe_in, accumulator, ntt_plan, stack)
|
||||
}
|
||||
|
||||
fn fill_with_forward_fourier_scratch(ntt_plan: &Self::Fft) -> Result<StackReq, SizeOverflow> {
|
||||
let _ = ntt_plan;
|
||||
Ok(StackReq::empty())
|
||||
}
|
||||
}
|
||||
540
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/ggsw.rs
Normal file
540
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/ggsw.rs
Normal file
@@ -0,0 +1,540 @@
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainerMut;
|
||||
use crate::core_crypto::commons::traits::{Container, ContiguousEntityContainer, Split};
|
||||
use crate::core_crypto::commons::utils::izip;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::common::{
|
||||
as_mut_array, as_ref_array, chain_array_with_context, iter_array,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::crt_ntt::math::ntt::CrtNtt;
|
||||
use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecompositionLendingIter;
|
||||
use crate::core_crypto::prelude::{CiphertextModulus, ContainerMut, UnsignedInteger};
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
|
||||
|
||||
/// A GGSW ciphertext in the NTT domain.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct CrtNttGgswCiphertext<
|
||||
CrtScalar,
|
||||
const N_COMPONENTS: usize,
|
||||
C: Container<Element = CrtScalar>,
|
||||
> {
|
||||
data: [C; N_COMPONENTS],
|
||||
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
}
|
||||
|
||||
/// A matrix containing a single level of gadget decomposition, in the NTT domain.
|
||||
pub struct CrtNttGgswLevelMatrix<
|
||||
CrtScalar,
|
||||
const N_COMPONENTS: usize,
|
||||
C: Container<Element = CrtScalar>,
|
||||
> {
|
||||
data: [C; N_COMPONENTS],
|
||||
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
row_count: usize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
}
|
||||
|
||||
/// A row of a GGSW level matrix, in the NTT domain.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct CrtNttGgswLevelRow<
|
||||
CrtScalar,
|
||||
const N_COMPONENTS: usize,
|
||||
C: Container<Element = CrtScalar>,
|
||||
> {
|
||||
data: [C; N_COMPONENTS],
|
||||
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
|
||||
CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, C>
|
||||
{
|
||||
pub fn from_container(
|
||||
data: [C; N_COMPONENTS],
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
let container_len =
|
||||
polynomial_size.0 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0;
|
||||
|
||||
data.iter()
|
||||
.for_each(|data| assert_eq!(data.container_len(), container_len));
|
||||
|
||||
Self {
|
||||
data,
|
||||
polynomial_size,
|
||||
glwe_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_size(&self) -> PolynomialSize {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
pub fn glwe_size(&self) -> GlweSize {
|
||||
self.glwe_size
|
||||
}
|
||||
|
||||
pub fn decomposition_base_log(&self) -> DecompositionBaseLog {
|
||||
self.decomposition_base_log
|
||||
}
|
||||
|
||||
pub fn decomposition_level_count(&self) -> DecompositionLevelCount {
|
||||
self.decomposition_level_count
|
||||
}
|
||||
|
||||
pub fn data(self) -> [C; N_COMPONENTS] {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub fn as_view(&self) -> CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[C::Element]> {
|
||||
CrtNttGgswCiphertext {
|
||||
data: as_ref_array(&self.data).map(|data| data.as_ref()),
|
||||
polynomial_size: self.polynomial_size,
|
||||
glwe_size: self.glwe_size,
|
||||
decomposition_base_log: self.decomposition_base_log,
|
||||
decomposition_level_count: self.decomposition_level_count,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut_view(
|
||||
&mut self,
|
||||
) -> CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &mut [C::Element]>
|
||||
where
|
||||
C: AsMut<[C::Element]>,
|
||||
{
|
||||
CrtNttGgswCiphertext {
|
||||
data: as_mut_array(&mut self.data).map(|data| data.as_mut()),
|
||||
polynomial_size: self.polynomial_size,
|
||||
glwe_size: self.glwe_size,
|
||||
decomposition_base_log: self.decomposition_base_log,
|
||||
decomposition_level_count: self.decomposition_level_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return an iterator over the level matrices.
|
||||
pub fn into_levels(
|
||||
self,
|
||||
) -> impl DoubleEndedIterator<Item = CrtNttGgswLevelMatrix<CrtScalar, N_COMPONENTS, C>>
|
||||
where
|
||||
C: Split,
|
||||
{
|
||||
iter_array(
|
||||
self.data
|
||||
.map(|data| data.split_into(self.decomposition_level_count.0)),
|
||||
)
|
||||
.enumerate()
|
||||
.map(move |(i, data)| {
|
||||
CrtNttGgswLevelMatrix::from_container(
|
||||
data,
|
||||
self.polynomial_size,
|
||||
self.glwe_size,
|
||||
self.glwe_size.0,
|
||||
DecompositionLevel(i + 1),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
|
||||
CrtNttGgswLevelMatrix<CrtScalar, N_COMPONENTS, C>
|
||||
{
|
||||
pub fn from_container(
|
||||
data: [C; N_COMPONENTS],
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
row_count: usize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
) -> Self {
|
||||
let container_len = polynomial_size.0 * glwe_size.0 * row_count;
|
||||
|
||||
data.iter()
|
||||
.for_each(|data| assert_eq!(data.container_len(), container_len));
|
||||
|
||||
Self {
|
||||
data,
|
||||
polynomial_size,
|
||||
glwe_size,
|
||||
row_count,
|
||||
decomposition_level,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_rows(
|
||||
self,
|
||||
) -> impl DoubleEndedIterator<Item = CrtNttGgswLevelRow<CrtScalar, N_COMPONENTS, C>>
|
||||
where
|
||||
C: Split,
|
||||
{
|
||||
iter_array(self.data.map(|data| data.split_into(self.row_count))).map(move |data| {
|
||||
CrtNttGgswLevelRow::from_container(
|
||||
data,
|
||||
self.polynomial_size,
|
||||
self.glwe_size,
|
||||
self.decomposition_level,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn polynomial_size(&self) -> PolynomialSize {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
pub fn glwe_size(&self) -> GlweSize {
|
||||
self.glwe_size
|
||||
}
|
||||
|
||||
pub fn row_count(&self) -> usize {
|
||||
self.row_count
|
||||
}
|
||||
|
||||
pub fn decomposition_level(&self) -> DecompositionLevel {
|
||||
self.decomposition_level
|
||||
}
|
||||
|
||||
pub fn data(self) -> [C; N_COMPONENTS] {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, C: Container<Element = CrtScalar>>
|
||||
CrtNttGgswLevelRow<CrtScalar, N_COMPONENTS, C>
|
||||
{
|
||||
pub fn from_container(
|
||||
data: [C; N_COMPONENTS],
|
||||
polynomial_size: PolynomialSize,
|
||||
glwe_size: GlweSize,
|
||||
decomposition_level: DecompositionLevel,
|
||||
) -> Self {
|
||||
let container_len = polynomial_size.0 * glwe_size.0;
|
||||
|
||||
data.iter()
|
||||
.for_each(|data| assert_eq!(data.container_len(), container_len));
|
||||
|
||||
Self {
|
||||
data,
|
||||
polynomial_size,
|
||||
glwe_size,
|
||||
decomposition_level,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_size(&self) -> PolynomialSize {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
pub fn glwe_size(&self) -> GlweSize {
|
||||
self.glwe_size
|
||||
}
|
||||
|
||||
pub fn decomposition_level(&self) -> DecompositionLevel {
|
||||
self.decomposition_level
|
||||
}
|
||||
|
||||
pub fn data(self) -> [C; N_COMPONENTS] {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Cont>
|
||||
CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, Cont>
|
||||
where
|
||||
Cont: ContainerMut<Element = CrtScalar>,
|
||||
{
|
||||
/// Fill a GGSW ciphertext with the NTT of a GGSW ciphertext in the standard domain.
|
||||
|
||||
pub fn fill_with_forward_ntt<Scalar, ContGgsw>(
|
||||
&mut self,
|
||||
coef_ggsw: &GgswCiphertext<ContGgsw>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
) where
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
ContGgsw: Container<Element = Scalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
>(
|
||||
this: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &mut [CrtScalar]>,
|
||||
coef_ggsw: GgswCiphertext<&[Scalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
) {
|
||||
assert_eq!(coef_ggsw.polynomial_size(), this.polynomial_size());
|
||||
let poly_size = this.polynomial_size().0;
|
||||
let data = this.data();
|
||||
|
||||
for (ntt, coef_poly) in izip!(
|
||||
iter_array(data.map(|data| data.into_chunks(poly_size))),
|
||||
coef_ggsw.as_polynomial_list().iter(),
|
||||
) {
|
||||
Scalar::forward_normalized(ntt_plan, ntt, coef_poly.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
implementation(self.as_mut_view(), coef_ggsw.as_view(), ntt_plan)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_external_product_assign_scratch<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> Result<StackReq, SizeOverflow> {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let standard_scratch =
|
||||
StackReq::try_new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch =
|
||||
StackReq::try_new_aligned::<CrtScalar>(glwe_size.0 * polynomial_size.0, align)?;
|
||||
let ntt_scratch_single = StackReq::try_new_aligned::<CrtScalar>(polynomial_size.0, align)?;
|
||||
|
||||
let substack2 = StackReq::try_all_of([ntt_scratch_single; N_COMPONENTS])?;
|
||||
let substack1 = substack2.try_and(standard_scratch)?;
|
||||
let substack0 = StackReq::try_any_of([
|
||||
substack1.try_and(standard_scratch)?,
|
||||
Scalar::add_backward_scratch(polynomial_size)?,
|
||||
])?;
|
||||
substack0.try_and(StackReq::try_all_of([ntt_scratch; N_COMPONENTS])?)
|
||||
}
|
||||
|
||||
#[cfg_attr(__profiling, inline(never))]
|
||||
pub fn add_external_product_assign<
|
||||
CrtScalar,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar,
|
||||
ContOut,
|
||||
ContGgsw,
|
||||
ContGlwe,
|
||||
>(
|
||||
out: &mut GlweCiphertext<ContOut>,
|
||||
ggsw: &CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, ContGgsw>,
|
||||
glwe: &GlweCiphertext<ContGlwe>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
ContOut: ContainerMut<Element = Scalar>,
|
||||
ContGgsw: Container<Element = CrtScalar>,
|
||||
ContGlwe: Container<Element = Scalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
>(
|
||||
mut out: GlweCiphertext<&mut [Scalar]>,
|
||||
ggsw: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
|
||||
glwe: GlweCiphertext<&[Scalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
// we check that the polynomial sizes match
|
||||
assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size());
|
||||
assert_eq!(ggsw.polynomial_size(), out.polynomial_size());
|
||||
// we check that the glwe sizes match
|
||||
assert_eq!(ggsw.glwe_size(), glwe.glwe_size());
|
||||
assert_eq!(ggsw.glwe_size(), out.glwe_size());
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
let poly_size = ggsw.polynomial_size().0;
|
||||
|
||||
// we round the input mask and body
|
||||
let decomposer = SignedDecomposer::<Scalar>::new(
|
||||
ggsw.decomposition_base_log(),
|
||||
ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
let (mut output_ntt_buffer, mut substack0) =
|
||||
chain_array_with_context::<_, _, N_COMPONENTS>(stack, |stack| {
|
||||
stack.make_aligned_with(poly_size * ggsw.glwe_size().0, align, |_| CrtScalar::ZERO)
|
||||
});
|
||||
|
||||
let mut output_ntt_buffer =
|
||||
as_mut_array(&mut output_ntt_buffer).map(|buffer| &mut **buffer);
|
||||
{
|
||||
// ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER
|
||||
// DOMAIN In this section, we perform the external product in the NTT
|
||||
// domain, and accumulate the result in the output_ntt_buffer variable.
|
||||
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
|
||||
glwe.as_ref()
|
||||
.iter()
|
||||
.map(|s| decomposer.closest_representable(*s)),
|
||||
DecompositionBaseLog(decomposer.base_log),
|
||||
DecompositionLevelCount(decomposer.level_count),
|
||||
substack0.rb_mut(),
|
||||
);
|
||||
|
||||
// We loop through the levels (we reverse to match the order of the decomposition
|
||||
// iterator.)
|
||||
for ggsw_decomp_matrix in ggsw.into_levels().rev() {
|
||||
// We retrieve the decomposition of this level.
|
||||
let (glwe_level, glwe_decomp_term, mut substack2) =
|
||||
collect_next_term(&mut decomposition, &mut substack1, align);
|
||||
let glwe_decomp_term = GlweCiphertextView::from_container(
|
||||
&*glwe_decomp_term,
|
||||
ggsw.polynomial_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
debug_assert_eq!(ggsw_decomp_matrix.decomposition_level(), glwe_level);
|
||||
|
||||
// For each level we have to add the result of the vector-matrix product between the
|
||||
// decomposition of the glwe, and the ggsw level matrix to the output. To do so, we
|
||||
// iteratively add to the output, the product between every line of the matrix, and
|
||||
// the corresponding (scalar) polynomial in the glwe decomposition:
|
||||
//
|
||||
// ggsw_mat ggsw_mat
|
||||
// glwe_dec | - - - - | < glwe_dec | - - - - |
|
||||
// | - - - | x | - - - - | | - - - | x | - - - - | <
|
||||
// ^ | - - - - | ^ | - - - - |
|
||||
//
|
||||
// t = 1 t = 2 ...
|
||||
|
||||
for (ggsw_row, glwe_poly) in izip!(
|
||||
ggsw_decomp_matrix.into_rows(),
|
||||
glwe_decomp_term.as_polynomial_list().iter()
|
||||
) {
|
||||
let len = poly_size;
|
||||
let stack = substack2.rb_mut();
|
||||
let (mut ntt, _) =
|
||||
chain_array_with_context::<_, _, N_COMPONENTS>(stack, |stack| {
|
||||
stack.make_aligned_raw::<CrtScalar>(len, align)
|
||||
});
|
||||
let mut ntt = as_mut_array(&mut ntt).map(|ntt| &mut **ntt);
|
||||
|
||||
// We perform the forward NTT for the glwe polynomial
|
||||
Scalar::forward(
|
||||
ntt_plan,
|
||||
as_mut_array(&mut ntt).map(|buf| &mut **buf),
|
||||
glwe_poly.as_ref(),
|
||||
);
|
||||
// Now we loop through the polynomials of the output, and add the
|
||||
// corresponding product of polynomials.
|
||||
for (output_ntt, ggsw_poly) in izip!(
|
||||
iter_array(
|
||||
as_mut_array(&mut output_ntt_buffer)
|
||||
.map(|buf| (&mut **buf).into_chunks(poly_size))
|
||||
),
|
||||
iter_array(
|
||||
as_ref_array(&ggsw_row.data).map(|buf| (&**buf).into_chunks(poly_size))
|
||||
),
|
||||
) {
|
||||
Scalar::mul_accumulate(
|
||||
ntt_plan,
|
||||
output_ntt,
|
||||
ggsw_poly,
|
||||
as_ref_array(&ntt).map(|buf| &**buf),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------- TRANSFORMATION OF RESULT TO STANDARD DOMAIN
|
||||
// In this section, we bring the result from the fourier domain, back to the standard
|
||||
// domain, and add it to the output.
|
||||
//
|
||||
// We iterate over the polynomials in the output.
|
||||
for (mut out, ntt) in izip!(
|
||||
out.as_mut_polynomial_list().iter_mut(),
|
||||
iter_array(output_ntt_buffer.map(|buf| buf.into_chunks(poly_size))),
|
||||
) {
|
||||
Scalar::add_backward(ntt_plan, out.as_mut(), ntt, substack0.rb_mut());
|
||||
}
|
||||
}
|
||||
|
||||
implementation(
|
||||
out.as_mut_view(),
|
||||
ggsw.as_view(),
|
||||
glwe.as_view(),
|
||||
ntt_plan,
|
||||
stack,
|
||||
)
|
||||
}
|
||||
|
||||
fn collect_next_term<'a, Scalar: UnsignedInteger>(
|
||||
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
|
||||
substack1: &'a mut PodStack,
|
||||
align: usize,
|
||||
) -> (
|
||||
DecompositionLevel,
|
||||
dyn_stack::DynArray<'a, Scalar>,
|
||||
PodStack<'a>,
|
||||
) {
|
||||
let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap();
|
||||
let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
|
||||
(glwe_level, glwe_decomp_term, substack2)
|
||||
}
|
||||
|
||||
/// Return the required memory for [`cmux`].
|
||||
pub fn cmux_scratch<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize, Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> Result<StackReq, SizeOverflow>
|
||||
where
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
{
|
||||
add_external_product_assign_scratch::<CrtScalar, N_COMPONENTS, Scalar>(
|
||||
glwe_size,
|
||||
polynomial_size,
|
||||
)
|
||||
}
|
||||
|
||||
/// This cmux mutates both ct1 and ct0. The result is in ct0 after the method was called.
|
||||
pub fn cmux<CrtScalar, const N_COMPONENTS: usize, Scalar, ContCt0, ContCt1, ContGgsw>(
|
||||
ct0: &mut GlweCiphertext<ContCt0>,
|
||||
ct1: &mut GlweCiphertext<ContCt1>,
|
||||
ggsw: &CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, ContGgsw>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
CrtScalar: UnsignedInteger,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
ContCt0: ContainerMut<Element = Scalar>,
|
||||
ContCt1: ContainerMut<Element = Scalar>,
|
||||
ContGgsw: Container<Element = CrtScalar>,
|
||||
{
|
||||
fn implementation<
|
||||
CrtScalar: UnsignedInteger,
|
||||
const N_COMPONENTS: usize,
|
||||
Scalar: CrtNtt<CrtScalar, N_COMPONENTS>,
|
||||
>(
|
||||
mut ct0: GlweCiphertext<&mut [Scalar]>,
|
||||
mut ct1: GlweCiphertext<&mut [Scalar]>,
|
||||
ggsw: CrtNttGgswCiphertext<CrtScalar, N_COMPONENTS, &[CrtScalar]>,
|
||||
ntt_plan: Scalar::PlanView<'_>,
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
for (c1, c0) in izip!(ct1.as_mut(), ct0.as_ref(),) {
|
||||
*c1 = c1.wrapping_sub(*c0);
|
||||
}
|
||||
add_external_product_assign(&mut ct0, &ggsw, &ct1, ntt_plan, stack);
|
||||
}
|
||||
|
||||
implementation(
|
||||
ct0.as_mut_view(),
|
||||
ct1.as_mut_view(),
|
||||
ggsw.as_view(),
|
||||
ntt_plan,
|
||||
stack,
|
||||
)
|
||||
}
|
||||
5
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/mod.rs
Normal file
5
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod bootstrap;
|
||||
pub mod ggsw;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
12
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/tests.rs
Normal file
12
tfhe/src/core_crypto/fft_impl/crt_ntt/crypto/tests.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use super::bootstrap::CrtNttLweBootstrapKey;
|
||||
use crate::core_crypto::fft_impl::common::tests::test_bootstrap_generic;
|
||||
use crate::core_crypto::prelude::*;
|
||||
use aligned_vec::ABox;
|
||||
|
||||
#[test]
|
||||
fn test_crt_bootstrap_u64() {
|
||||
test_bootstrap_generic::<u64, CrtNttLweBootstrapKey<u32, 5, ABox<[u32]>>>(
|
||||
StandardDev(0.000007069849454709433),
|
||||
StandardDev(0.00000000000000029403601535432533),
|
||||
);
|
||||
}
|
||||
1
tfhe/src/core_crypto/fft_impl/crt_ntt/math/mod.rs
Normal file
1
tfhe/src/core_crypto/fft_impl/crt_ntt/math/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod ntt;
|
||||
179
tfhe/src/core_crypto/fft_impl/crt_ntt/math/ntt/mod.rs
Normal file
179
tfhe/src/core_crypto/fft_impl/crt_ntt/math/ntt/mod.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use crate::core_crypto::commons::parameters::PolynomialSize;
|
||||
use crate::core_crypto::commons::utils::izip;
|
||||
use crate::core_crypto::prelude::UnsignedInteger;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use concrete_ntt::native64::Plan32;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct PlanWrapper(Plan32);
|
||||
impl core::ops::Deref for PlanWrapper {
|
||||
type Target = Plan32;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for PlanWrapper {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
"[?]".fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CrtNtt64 {
|
||||
inner: Arc<PlanWrapper>,
|
||||
}
|
||||
|
||||
/// View type for [`CrtNtt64`].
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct CrtNtt64View<'a> {
|
||||
pub(crate) inner: &'a PlanWrapper,
|
||||
}
|
||||
|
||||
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<PlanWrapper>>>>>;
|
||||
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
|
||||
fn plans() -> &'static PlanMap {
|
||||
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
impl CrtNtt64 {
|
||||
/// Polynomial of size `size`.
|
||||
pub fn new(size: PolynomialSize) -> Self {
|
||||
let global_plans = plans();
|
||||
|
||||
let n = size.0;
|
||||
let get_plan = || {
|
||||
let plans = global_plans.read().unwrap();
|
||||
let plan = plans.get(&n).cloned();
|
||||
drop(plans);
|
||||
|
||||
plan.map(|p| {
|
||||
p.get_or_init(|| Arc::new(PlanWrapper(Plan32::try_new(n).unwrap())))
|
||||
.clone()
|
||||
})
|
||||
};
|
||||
|
||||
// could not find a plan of the given size, we lock the map again and try to insert it
|
||||
let mut plans = global_plans.write().unwrap();
|
||||
if let Entry::Vacant(v) = plans.entry(n) {
|
||||
v.insert(Arc::new(OnceLock::new()));
|
||||
}
|
||||
|
||||
drop(plans);
|
||||
|
||||
Self {
|
||||
inner: get_plan().unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_view(&self) -> CrtNtt64View<'_> {
|
||||
CrtNtt64View { inner: &self.inner }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CrtNtt<CrtScalar: UnsignedInteger, const N_COMPONENTS: usize>: UnsignedInteger {
|
||||
type Plan;
|
||||
type PlanView<'a>: Copy;
|
||||
|
||||
fn new_plan(size: PolynomialSize) -> Self::Plan;
|
||||
fn plan_as_view(plan: &Self::Plan) -> Self::PlanView<'_>;
|
||||
|
||||
fn forward(plan: Self::PlanView<'_>, ntt: [&mut [CrtScalar]; N_COMPONENTS], standard: &[Self]);
|
||||
fn forward_normalized(
|
||||
plan: Self::PlanView<'_>,
|
||||
ntt: [&mut [CrtScalar]; N_COMPONENTS],
|
||||
standard: &[Self],
|
||||
);
|
||||
|
||||
fn add_backward(
|
||||
plan: Self::PlanView<'_>,
|
||||
standard: &mut [Self],
|
||||
ntt: [&mut [CrtScalar]; N_COMPONENTS],
|
||||
stack: PodStack<'_>,
|
||||
);
|
||||
|
||||
fn add_backward_scratch(polynomial_size: PolynomialSize) -> Result<StackReq, SizeOverflow>;
|
||||
|
||||
fn mul_accumulate(
|
||||
plan: Self::PlanView<'_>,
|
||||
acc: [&mut [CrtScalar]; N_COMPONENTS],
|
||||
lhs: [&[CrtScalar]; N_COMPONENTS],
|
||||
rhs: [&[CrtScalar]; N_COMPONENTS],
|
||||
);
|
||||
}
|
||||
|
||||
impl CrtNtt<u32, 5> for u64 {
|
||||
type Plan = CrtNtt64;
|
||||
type PlanView<'a> = CrtNtt64View<'a>;
|
||||
|
||||
fn new_plan(size: PolynomialSize) -> Self::Plan {
|
||||
Self::Plan::new(size)
|
||||
}
|
||||
fn plan_as_view(plan: &Self::Plan) -> Self::PlanView<'_> {
|
||||
plan.as_view()
|
||||
}
|
||||
|
||||
fn forward(
|
||||
plan: Self::PlanView<'_>,
|
||||
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
|
||||
standard: &[Self],
|
||||
) {
|
||||
plan.inner.0.fwd(standard, ntt0, ntt1, ntt2, ntt3, ntt4);
|
||||
}
|
||||
|
||||
fn forward_normalized(
|
||||
plan: Self::PlanView<'_>,
|
||||
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
|
||||
standard: &[Self],
|
||||
) {
|
||||
plan.inner.0.fwd(standard, ntt0, ntt1, ntt2, ntt3, ntt4);
|
||||
plan.inner.0.ntt_0().normalize(ntt0);
|
||||
plan.inner.0.ntt_1().normalize(ntt1);
|
||||
plan.inner.0.ntt_2().normalize(ntt2);
|
||||
plan.inner.0.ntt_3().normalize(ntt3);
|
||||
plan.inner.0.ntt_4().normalize(ntt4);
|
||||
}
|
||||
|
||||
fn add_backward(
|
||||
plan: Self::PlanView<'_>,
|
||||
standard: &mut [Self],
|
||||
[ntt0, ntt1, ntt2, ntt3, ntt4]: [&mut [u32]; 5],
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
let n = standard.len();
|
||||
let (mut tmp, _) = stack.make_aligned_raw::<u64>(n, CACHELINE_ALIGN);
|
||||
plan.inner.0.inv(&mut tmp, ntt0, ntt1, ntt2, ntt3, ntt4);
|
||||
|
||||
// autovectorize
|
||||
pulp::Arch::new().dispatch(
|
||||
#[inline(always)]
|
||||
|| {
|
||||
for (out, inp) in izip!(standard, &*tmp) {
|
||||
*out = u64::wrapping_add(*out, *inp);
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn add_backward_scratch(polynomial_size: PolynomialSize) -> Result<StackReq, SizeOverflow> {
|
||||
StackReq::try_new_aligned::<u64>(polynomial_size.0, CACHELINE_ALIGN)
|
||||
}
|
||||
|
||||
fn mul_accumulate(
|
||||
plan: Self::PlanView<'_>,
|
||||
[acc0, acc1, acc2, acc3, acc4]: [&mut [u32]; 5],
|
||||
[lhs0, lhs1, lhs2, lhs3, lhs4]: [&[u32]; 5],
|
||||
[rhs0, rhs1, rhs2, rhs3, rhs4]: [&[u32]; 5],
|
||||
) {
|
||||
plan.inner.ntt_0().mul_accumulate(acc0, lhs0, rhs0);
|
||||
plan.inner.ntt_1().mul_accumulate(acc1, lhs1, rhs1);
|
||||
plan.inner.ntt_2().mul_accumulate(acc2, lhs2, rhs2);
|
||||
plan.inner.ntt_3().mul_accumulate(acc3, lhs3, rhs3);
|
||||
plan.inner.ntt_4().mul_accumulate(acc4, lhs4, rhs4);
|
||||
}
|
||||
}
|
||||
2
tfhe/src/core_crypto/fft_impl/crt_ntt/mod.rs
Normal file
2
tfhe/src/core_crypto/fft_impl/crt_ntt/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod crypto;
|
||||
pub mod math;
|
||||
@@ -7,3 +7,5 @@ pub mod fft64;
|
||||
|
||||
pub mod fft128;
|
||||
mod fft128_u128;
|
||||
|
||||
pub mod crt_ntt;
|
||||
|
||||
Reference in New Issue
Block a user