From 03f63ec202554455dcc5aed4156e9ccb5b1c444e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 6 Dec 2022 10:19:23 +0100 Subject: [PATCH] chore(tfhe): fix refactor TODOs --- .../core_crypto/algorithms/glwe_encryption.rs | 6 ++-- ...tional_packing_keyswitch_key_generation.rs | 30 +++++++++++++++---- .../algorithms/lwe_public_key_generation.rs | 4 ++- tfhe/src/core_crypto/algorithms/lwe_wopbs.rs | 1 - .../core_crypto/commons/traits/container.rs | 6 ++-- .../traits/contiguous_entity_container.rs | 13 ++++++-- .../core_crypto/fft_impl/crypto/bootstrap.rs | 4 +-- tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs | 4 +-- .../fft_impl/crypto/wop_pbs/mod.rs | 3 +- tfhe/src/core_crypto/fft_impl/math/fft/mod.rs | 10 ++++--- tfhe/src/shortint/engine/wopbs/mod.rs | 5 ++-- 11 files changed, 57 insertions(+), 29 deletions(-) diff --git a/tfhe/src/core_crypto/algorithms/glwe_encryption.rs b/tfhe/src/core_crypto/algorithms/glwe_encryption.rs index 960702961..c4731383d 100644 --- a/tfhe/src/core_crypto/algorithms/glwe_encryption.rs +++ b/tfhe/src/core_crypto/algorithms/glwe_encryption.rs @@ -248,8 +248,10 @@ pub fn trivially_encrypt_glwe_ciphertext( InputCont: Container, { assert!( - output.polynomial_size().0 == encoded.plaintext_count().0, - "TODO Error message" + encoded.plaintext_count().0 == output.polynomial_size().0, + "Mismatched input PlaintextCount {:?} and output PolynomialSize {:?}", + encoded.plaintext_count(), + output.polynomial_size() ); let (mut mask, mut body) = output.get_mut_mask_and_body(); diff --git a/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch_key_generation.rs b/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch_key_generation.rs index e2d02656f..a6cf5b7aa 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch_key_generation.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch_key_generation.rs @@ -37,15 +37,24 @@ pub fn generate_lwe_private_functional_packing_keyswitch_key< { assert!( input_lwe_secret_key.lwe_dimension() == lwe_pfpksk.input_lwe_key_dimension(), - "TODO error message" + "Mismatched LweDimension between input_lwe_secret_key {:?} and lwe_pfpksk input dimension \ + {:?}.", + input_lwe_secret_key.lwe_dimension(), + lwe_pfpksk.input_lwe_key_dimension() ); assert!( output_glwe_secret_key.glwe_dimension() == lwe_pfpksk.output_glwe_key_dimension(), - "TODO error message" + "Mismatched GlweDimension between output_glwe_secret_key {:?} and lwe_pfpksk output \ + dimension {:?}.", + output_glwe_secret_key.glwe_dimension(), + lwe_pfpksk.output_glwe_key_dimension() ); assert!( output_glwe_secret_key.polynomial_size() == lwe_pfpksk.output_polynomial_size(), - "TODO error message" + "Mismatched PolynomialSize between output_glwe_secret_key {:?} and lwe_pfpksk output \ + polynomial size {:?}.", + output_glwe_secret_key.polynomial_size(), + lwe_pfpksk.output_polynomial_size() ); // We instantiate a buffer @@ -137,15 +146,24 @@ pub fn par_generate_lwe_private_functional_packing_keyswitch_key< { assert!( input_lwe_secret_key.lwe_dimension() == lwe_pfpksk.input_lwe_key_dimension(), - "TODO error message" + "Mismatched LweDimension between input_lwe_secret_key {:?} and lwe_pfpksk input dimension \ + {:?}.", + input_lwe_secret_key.lwe_dimension(), + lwe_pfpksk.input_lwe_key_dimension() ); assert!( output_glwe_secret_key.glwe_dimension() == lwe_pfpksk.output_glwe_key_dimension(), - "TODO error message" + "Mismatched GlweDimension between output_glwe_secret_key {:?} and lwe_pfpksk output \ + dimension {:?}.", + output_glwe_secret_key.glwe_dimension(), + lwe_pfpksk.output_glwe_key_dimension() ); assert!( output_glwe_secret_key.polynomial_size() == lwe_pfpksk.output_polynomial_size(), - "TODO error message" + "Mismatched PolynomialSize between output_glwe_secret_key {:?} and lwe_pfpksk output \ + polynomial size {:?}.", + output_glwe_secret_key.polynomial_size(), + lwe_pfpksk.output_polynomial_size() ); // We retrieve decomposition arguments diff --git a/tfhe/src/core_crypto/algorithms/lwe_public_key_generation.rs b/tfhe/src/core_crypto/algorithms/lwe_public_key_generation.rs index 659e6f939..e53b79c56 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_public_key_generation.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_public_key_generation.rs @@ -68,7 +68,9 @@ pub fn par_generate_lwe_public_key( { assert!( lwe_secret_key.lwe_dimension() == output.lwe_size().to_lwe_dimension(), - "TODO error message" + "Mismatch LweDimension between lwe_secret_key {:?} and public key {:?}", + lwe_secret_key.lwe_dimension(), + output.lwe_size().to_lwe_dimension() ); let zeros = PlaintextListOwned::new( diff --git a/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs b/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs index 0fc6f0131..867863a69 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs @@ -333,7 +333,6 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list< ) } -// TODO big_lut_polynomial_count looks wrong #[allow(clippy::too_many_arguments)] pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_scracth( lwe_list_in_count: LweCiphertextCount, diff --git a/tfhe/src/core_crypto/commons/traits/container.rs b/tfhe/src/core_crypto/commons/traits/container.rs index abf2514fb..d93136c3e 100644 --- a/tfhe/src/core_crypto/commons/traits/container.rs +++ b/tfhe/src/core_crypto/commons/traits/container.rs @@ -48,13 +48,11 @@ impl Container for aligned_vec::AVec { impl ContainerMut for aligned_vec::AVec {} -// TODO REFACTOR -// Rework the fft traits -pub trait ContainerOwned: Container + AsMut<[Self::Element]> { +pub trait IntoContainerOwned: Container + AsMut<[Self::Element]> { fn collect>(iter: I) -> Self; } -impl ContainerOwned for aligned_vec::ABox<[T]> { +impl IntoContainerOwned for aligned_vec::ABox<[T]> { fn collect>(iter: I) -> Self { aligned_vec::AVec::::from_iter(0, iter).into_boxed_slice() } diff --git a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs index 11e256d4f..4d2d3409c 100644 --- a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs +++ b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs @@ -117,7 +117,11 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> { let entity_view_pod_size = self.get_entity_view_pod_size(); let entity_count = self.as_ref().len() / entity_view_pod_size; - assert!(entity_count % chunk_size == 0, "TODO Err message"); + assert!( + entity_count % chunk_size == 0, + "The current container has {entity_count} entities, which is not dividable by the \ + requested chunk_size: {chunk_size}, preventing chunks_exact from returning an iterator." + ); let pod_chunk_size = entity_view_pod_size * chunk_size; @@ -208,7 +212,12 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self: let entity_view_pod_size = self.get_entity_view_pod_size(); let entity_count = self.as_ref().len() / entity_view_pod_size; - assert!(entity_count % chunk_size == 0, "TODO Err message"); + assert!( + entity_count % chunk_size == 0, + "The current container has {entity_count} entities, which is not dividable by the \ + requested chunk_size: {chunk_size}, preventing chunks_exact_mut from returning an \ + iterator." + ); let pod_chunk_size = entity_view_pod_size * chunk_size; diff --git a/tfhe/src/core_crypto/fft_impl/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/crypto/bootstrap.rs index f719bf9d8..71d06c134 100644 --- a/tfhe/src/core_crypto/fft_impl/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/crypto/bootstrap.rs @@ -5,7 +5,7 @@ use crate::core_crypto::algorithms::polynomial_algorithms::*; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::CastInto; use crate::core_crypto::commons::traits::{ - Container, ContainerOwned, ContiguousEntityContainer, ContiguousEntityContainerMut, Split, + Container, ContiguousEntityContainer, ContiguousEntityContainerMut, IntoContainerOwned, Split, }; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; @@ -18,7 +18,7 @@ use concrete_fft::c64; use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq}; #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(bound(deserialize = "C: ContainerOwned"))] +#[serde(bound(deserialize = "C: IntoContainerOwned"))] pub struct FourierLweBootstrapKey> { fourier: FourierPolynomialList, input_lwe_dimension: LweDimension, diff --git a/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs index 1b0b73963..ab8cb52e4 100644 --- a/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs @@ -7,7 +7,7 @@ use super::super::{as_mut_uninit, assume_init_mut}; use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::traits::{ - Container, ContainerOwned, ContiguousEntityContainer, ContiguousEntityContainerMut, Split, + Container, ContiguousEntityContainer, ContiguousEntityContainerMut, IntoContainerOwned, Split, }; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; @@ -25,7 +25,7 @@ use core::arch::x86_64::*; /// A GGSW ciphertext in the Fourier domain. #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(bound(deserialize = "C: ContainerOwned"))] +#[serde(bound(deserialize = "C: IntoContainerOwned"))] pub struct FourierGgswCiphertext> { fourier: FourierPolynomialList, glwe_size: GlweSize, diff --git a/tfhe/src/core_crypto/fft_impl/crypto/wop_pbs/mod.rs b/tfhe/src/core_crypto/fft_impl/crypto/wop_pbs/mod.rs index b114c8d6e..3b25d0d28 100644 --- a/tfhe/src/core_crypto/fft_impl/crypto/wop_pbs/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/crypto/wop_pbs/mod.rs @@ -266,8 +266,7 @@ pub fn circuit_bootstrap_boolean>( debug_assert!( ggsw_out.glwe_size().0 == pfpksk_list.lwe_pfpksk_count().0, - "The input vector of fpksk needs to have {} (ggsw.glwe_size * \ - ggsw.decomposition_level_count) elements got {}", + "The input vector of pfpksk_list needs to have {} ggsw.glwe_size elements got {}", ggsw_out.glwe_size().0, pfpksk_list.lwe_pfpksk_count().0, ); diff --git a/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs index 2bafcf9ff..3ab724772 100644 --- a/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs @@ -5,7 +5,7 @@ use super::polynomial::{ }; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::CastInto; -use crate::core_crypto::commons::traits::{Container, ContainerOwned}; +use crate::core_crypto::commons::traits::{Container, IntoContainerOwned}; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; use crate::core_crypto::specification::parameters::PolynomialSize; @@ -571,12 +571,14 @@ impl> serde::Serialize for FourierPolynomialList } } -impl<'de, C: ContainerOwned> serde::Deserialize<'de> for FourierPolynomialList { +impl<'de, C: IntoContainerOwned> serde::Deserialize<'de> + for FourierPolynomialList +{ fn deserialize>(deserializer: D) -> Result { use std::marker::PhantomData; - struct SeqVisitor>(PhantomData C>); + struct SeqVisitor>(PhantomData C>); - impl<'de, C: ContainerOwned> serde::de::Visitor<'de> for SeqVisitor { + impl<'de, C: IntoContainerOwned> serde::de::Visitor<'de> for SeqVisitor { type Value = FourierPolynomialList; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { diff --git a/tfhe/src/shortint/engine/wopbs/mod.rs b/tfhe/src/shortint/engine/wopbs/mod.rs index 5939b78a5..cc8e80619 100644 --- a/tfhe/src/shortint/engine/wopbs/mod.rs +++ b/tfhe/src/shortint/engine/wopbs/mod.rs @@ -223,6 +223,7 @@ impl ShortintEngine { let output_lwe_size = fourier_bsk.output_lwe_dimension().to_lwe_size(); let mut output_cbs_vp_ct = LweCiphertextListOwned::new(0u64, output_lwe_size, count); + let lut = PolynomialListView::from_container(lut.as_ref(), fourier_bsk.polynomial_size()); let fft = Fft::new(fourier_bsk.polynomial_size()); let fft = fft.as_view(); @@ -231,7 +232,7 @@ impl ShortintEngine { extracted_bits.lwe_ciphertext_count(), output_cbs_vp_ct.lwe_ciphertext_count(), extracted_bits.lwe_size(), - PolynomialCount(lut.plaintext_count().0), + lut.polynomial_count(), fourier_bsk.output_lwe_dimension().to_lwe_size(), wopbs_key.cbs_pfpksk.output_polynomial_size(), fourier_bsk.glwe_size(), @@ -242,8 +243,6 @@ impl ShortintEngine { .unaligned_bytes_required(), ); - let lut = PolynomialListView::from_container(lut.as_ref(), fourier_bsk.polynomial_size()); - let stack = self.fft_buffers.stack(); circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list(