diff --git a/tfhe/src/core_crypto/backward_compatibility/entities/compressed_modulus_switched_lwe_ciphertext.rs b/tfhe/src/core_crypto/backward_compatibility/entities/compressed_modulus_switched_lwe_ciphertext.rs index 0dc3b4169..1cae14025 100644 --- a/tfhe/src/core_crypto/backward_compatibility/entities/compressed_modulus_switched_lwe_ciphertext.rs +++ b/tfhe/src/core_crypto/backward_compatibility/entities/compressed_modulus_switched_lwe_ciphertext.rs @@ -22,17 +22,17 @@ impl Upgrade Result, Self::Error> { - let packed_integers = PackedIntegers { - packed_coeffs: self.packed_coeffs, - log_modulus: self.log_modulus, - initial_len: self.lwe_dimension.to_lwe_size().0, - }; + let packed_integers = PackedIntegers::from_raw_parts( + self.packed_coeffs, + self.log_modulus, + self.lwe_dimension.to_lwe_size().0, + ); - Ok(CompressedModulusSwitchedLweCiphertext { + Ok(CompressedModulusSwitchedLweCiphertext::from_raw_parts( packed_integers, - lwe_dimension: self.lwe_dimension, - uncompressed_ciphertext_modulus: self.uncompressed_ciphertext_modulus, - }) + self.lwe_dimension, + self.uncompressed_ciphertext_modulus, + )) } } diff --git a/tfhe/src/core_crypto/commons/ciphertext_modulus.rs b/tfhe/src/core_crypto/commons/ciphertext_modulus.rs index e398a1a38..fde597844 100644 --- a/tfhe/src/core_crypto/commons/ciphertext_modulus.rs +++ b/tfhe/src/core_crypto/commons/ciphertext_modulus.rs @@ -10,6 +10,8 @@ use std::cmp::Ordering; use std::fmt::Display; use std::marker::PhantomData; +use super::parameters::CiphertextModulusLog; + #[derive(Clone, Copy, PartialEq, Eq)] /// Private enum to avoid end user instantiating a bad CiphertextModulus /// @@ -274,6 +276,15 @@ impl CiphertextModulus { } } + pub fn into_modulus_log(self) -> CiphertextModulusLog { + match self.inner { + CiphertextModulusInner::Native => CiphertextModulusLog(Scalar::BITS), + CiphertextModulusInner::Custom(custom_mod) => { + CiphertextModulusLog(custom_mod.get().ceil_ilog2() as usize) + } + } + } + pub fn get_custom_modulus_as_optional_scalar(&self) -> Option { match self.inner { CiphertextModulusInner::Native => None, @@ -354,6 +365,12 @@ impl CiphertextModulus { } } +impl From> for CiphertextModulusLog { + fn from(value: CiphertextModulus) -> Self { + value.into_modulus_log() + } +} + impl std::fmt::Display for CiphertextModulus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.inner { diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs index 21454b979..a04e43058 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs @@ -80,14 +80,73 @@ use crate::core_crypto::prelude::*; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedModulusSwitchedGlweCiphertextVersions)] pub struct CompressedModulusSwitchedGlweCiphertext { - pub(crate) packed_integers: PackedIntegers, - pub(crate) glwe_dimension: GlweDimension, - pub(crate) polynomial_size: PolynomialSize, - pub(crate) bodies_count: LweCiphertextCount, - pub(crate) uncompressed_ciphertext_modulus: CiphertextModulus, + packed_integers: PackedIntegers, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + bodies_count: LweCiphertextCount, + uncompressed_ciphertext_modulus: CiphertextModulus, } -impl CompressedModulusSwitchedGlweCiphertext { +impl CompressedModulusSwitchedGlweCiphertext { + pub fn from_raw_parts( + packed_integers: PackedIntegers, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + bodies_count: LweCiphertextCount, + uncompressed_ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert_eq!( + glwe_dimension.0 * polynomial_size.0 + bodies_count.0, + packed_integers.initial_len(), + "Packed integers list is not of the correct size for the uncompressed GLWE: expected {}, got {}", + glwe_dimension.0 * polynomial_size.0 + bodies_count.0, + packed_integers.initial_len(), + ); + + assert!( + packed_integers.log_modulus().0 + <= CiphertextModulusLog::from(uncompressed_ciphertext_modulus).0, + "Compressed modulus (={}) should be smaller than the uncompressed modulus (={})", + packed_integers.log_modulus().0, + CiphertextModulusLog::from(uncompressed_ciphertext_modulus).0, + ); + + Self { + packed_integers, + glwe_dimension, + polynomial_size, + bodies_count, + uncompressed_ciphertext_modulus, + } + } + + #[cfg(test)] + pub(crate) fn into_raw_parts( + self, + ) -> ( + PackedIntegers, + GlweDimension, + PolynomialSize, + LweCiphertextCount, + CiphertextModulus, + ) { + let Self { + packed_integers, + glwe_dimension, + polynomial_size, + bodies_count, + uncompressed_ciphertext_modulus, + } = self; + + ( + packed_integers, + glwe_dimension, + polynomial_size, + bodies_count, + uncompressed_ciphertext_modulus, + ) + } + pub fn glwe_dimension(&self) -> GlweDimension { self.glwe_dimension } @@ -101,6 +160,10 @@ impl CompressedModulusSwitchedGlweCiphertext { self.uncompressed_ciphertext_modulus } + pub fn packed_integers(&self) -> &PackedIntegers { + &self.packed_integers + } + /// Compresses a ciphertext by reducing its modulus /// This operation adds a lot of noise pub fn compress>( @@ -160,7 +223,7 @@ impl CompressedModulusSwitchedGlweCiphertext { /// The noise added during the compression stays in the output /// The output must got through a PBS to reduce the noise pub fn extract(&self) -> GlweCiphertextOwned { - let log_modulus = self.packed_integers.log_modulus.0; + let log_modulus = self.packed_integers.log_modulus().0; let number_bits_to_unpack = (self.glwe_dimension.0 * self.polynomial_size.0 + self.bodies_count.0) * log_modulus; @@ -168,9 +231,9 @@ impl CompressedModulusSwitchedGlweCiphertext { let len: usize = number_bits_to_unpack.div_ceil(Scalar::BITS); assert_eq!( - self.packed_integers.packed_coeffs.len(), len, + self.packed_integers.packed_coeffs().len(), len, "Mismatch between actual(={}) and maximal(={len}) CompressedModulusSwitchedGlweCiphertext packed_coeffs size", - self.packed_integers.packed_coeffs.len(), + self.packed_integers.packed_coeffs().len(), ); let container = self @@ -205,14 +268,14 @@ impl ParameterSetConformant bodies_count, uncompressed_ciphertext_modulus, } = self; - let log_modulus = packed_integers.log_modulus.0; + let log_modulus = packed_integers.log_modulus().0; let number_bits_to_unpack = (glwe_dimension.0 * polynomial_size.0 + bodies_count.0) * log_modulus; let len = number_bits_to_unpack.div_ceil(Scalar::BITS); - packed_integers.packed_coeffs.len() == len + packed_integers.packed_coeffs().len() == len && *glwe_dimension == lwe_ct_parameters.glwe_dim && *polynomial_size == lwe_ct_parameters.polynomial_size && lwe_ct_parameters.ct_modulus.is_power_of_two() @@ -222,8 +285,9 @@ impl ParameterSetConformant #[cfg(test)] mod test { + use rand::{Fill, Rng}; + use super::*; - use crate::core_crypto::prelude::test::TestResources; #[test] fn glwe_ms_compression_() { @@ -248,17 +312,20 @@ mod test { glwe_dimension: GlweDimension, polynomial_size: PolynomialSize, bodies_count: usize, - ) { - let mut rsc: TestResources = TestResources::new(); - + ) where + [Scalar]: Fill, + { let ciphertext_modulus = CiphertextModulus::new_native(); - let mut glwe = vec![Scalar::ZERO; (glwe_dimension.0 + 1) * polynomial_size.0]; + let mut glwe = GlweCiphertext::new( + Scalar::ZERO, + glwe_dimension.to_glwe_size(), + polynomial_size, + ciphertext_modulus, + ); - rsc.encryption_random_generator - .fill_slice_with_random_uniform_mask(&mut glwe); - - let glwe = GlweCiphertextOwned::from_container(glwe, polynomial_size, ciphertext_modulus); + // We don't care about the exact content here + rand::thread_rng().fill(glwe.as_mut()); let compressed = CompressedModulusSwitchedGlweCiphertext::compress( &glwe, @@ -286,4 +353,63 @@ mod test { ) } } + + #[test] + fn test_from_raw_parts() { + type Scalar = u64; + + let ciphertext_modulus = CiphertextModulus::new_native(); + let glwe_dimension = GlweDimension(1); + let polynomial_size = PolynomialSize(512); + let bodies_count = LweCiphertextCount(512); + let log_modulus = 12; + + let mut glwe = GlweCiphertext::new( + Scalar::ZERO, + glwe_dimension.to_glwe_size(), + polynomial_size, + ciphertext_modulus, + ); + + // We don't care about the exact content here + rand::thread_rng().fill(glwe.as_mut()); + + let compressed = CompressedModulusSwitchedGlweCiphertext::compress( + &glwe, + CiphertextModulusLog(log_modulus), + bodies_count, + ); + + let ( + packed_integers, + glwe_dimension, + polynomial_size, + bodies_count, + uncompressed_ciphertext_modulus, + ) = compressed.into_raw_parts(); + + let rebuilt = CompressedModulusSwitchedGlweCiphertext::from_raw_parts( + packed_integers, + glwe_dimension, + polynomial_size, + bodies_count, + uncompressed_ciphertext_modulus, + ); + + let glwe_ms_ed = rebuilt.extract().into_container(); + let glwe = glwe.into_container(); + + for (i, output) in glwe_ms_ed.into_iter().enumerate() { + assert_eq!( + output, + (output >> (Scalar::BITS as usize - log_modulus)) + << (Scalar::BITS as usize - log_modulus), + ); + + assert_eq!( + output >> (Scalar::BITS as usize - log_modulus), + modulus_switch(glwe[i], CiphertextModulusLog(log_modulus)) + ) + } + } } diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_lwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_lwe_ciphertext.rs index 9a386c110..65ef121c2 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_lwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_lwe_ciphertext.rs @@ -63,12 +63,50 @@ use crate::core_crypto::prelude::*; #[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedModulusSwitchedLweCiphertextVersions)] pub struct CompressedModulusSwitchedLweCiphertext { - pub(crate) packed_integers: PackedIntegers, - pub(crate) lwe_dimension: LweDimension, - pub(crate) uncompressed_ciphertext_modulus: CiphertextModulus, + packed_integers: PackedIntegers, + lwe_dimension: LweDimension, + uncompressed_ciphertext_modulus: CiphertextModulus, } -impl CompressedModulusSwitchedLweCiphertext { +impl CompressedModulusSwitchedLweCiphertext { + pub(crate) fn from_raw_parts( + packed_integers: PackedIntegers, + lwe_dimension: LweDimension, + uncompressed_ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert_eq!(packed_integers.initial_len(), lwe_dimension.to_lwe_size().0, + "Packed integers list is not of the correct size for the uncompressed LWE: expected {}, got {}", + lwe_dimension.to_lwe_size().0, + packed_integers.initial_len()); + + Self { + packed_integers, + lwe_dimension, + uncompressed_ciphertext_modulus, + } + } + + #[cfg(test)] + pub(crate) fn into_raw_parts( + self, + ) -> ( + PackedIntegers, + LweDimension, + CiphertextModulus, + ) { + let Self { + packed_integers, + lwe_dimension, + uncompressed_ciphertext_modulus, + } = self; + + ( + packed_integers, + lwe_dimension, + uncompressed_ciphertext_modulus, + ) + } + /// Compresses a ciphertext by reducing its modulus /// This operation adds a lot of noise pub fn compress>( @@ -117,17 +155,17 @@ impl CompressedModulusSwitchedLweCiphertext { pub fn extract(&self) -> LweCiphertextOwned { let lwe_size = self.lwe_dimension.to_lwe_size().0; - let log_modulus = self.packed_integers.log_modulus.0; + let log_modulus = self.packed_integers.log_modulus().0; let number_bits_to_unpack = lwe_size * log_modulus; let len = number_bits_to_unpack.div_ceil(Scalar::BITS); assert_eq!( - self.packed_integers.packed_coeffs.len(), + self.packed_integers.packed_coeffs().len(), len, "Mismatch between actual(={}) and expected(={len}) CompressedModulusSwitchedLweCiphertext packed_coeffs size", - self.packed_integers.packed_coeffs.len(), + self.packed_integers.packed_coeffs().len(), ); let container = self @@ -155,11 +193,11 @@ impl ParameterSetConformant let lwe_size = lwe_dimension.to_lwe_size().0; - let number_bits_to_pack = lwe_size * packed_integers.log_modulus.0; + let number_bits_to_pack = lwe_size * packed_integers.log_modulus().0; let len = number_bits_to_pack.div_ceil(Scalar::BITS); - packed_integers.packed_coeffs.len() == len + packed_integers.packed_coeffs().len() == len && *lwe_dimension == lwe_ct_parameters.lwe_dim && lwe_ct_parameters.ct_modulus.is_power_of_two() && *uncompressed_ciphertext_modulus == lwe_ct_parameters.ct_modulus @@ -172,8 +210,9 @@ impl ParameterSetConformant #[cfg(test)] mod test { + use rand::{Fill, Rng}; + use super::*; - use crate::core_crypto::prelude::test::TestResources; #[test] fn ms_compression_() { @@ -196,17 +235,15 @@ mod test { fn ms_compression + CastFrom>( log_modulus: usize, len: usize, - ) { - let mut rsc: TestResources = TestResources::new(); - + ) where + [Scalar]: Fill, + { let ciphertext_modulus = CiphertextModulus::new_native(); - let mut lwe = vec![Scalar::ZERO; len]; + let mut lwe = LweCiphertext::new(Scalar::ZERO, LweSize(len), ciphertext_modulus); - rsc.encryption_random_generator - .fill_slice_with_random_uniform_mask(&mut lwe); - - let lwe = LweCiphertextOwned::from_container(lwe, ciphertext_modulus); + // We don't care about the exact content here + rand::thread_rng().fill(lwe.as_mut()); let compressed = CompressedModulusSwitchedLweCiphertext::compress( &lwe, @@ -229,4 +266,49 @@ mod test { ) } } + + #[test] + fn test_from_raw_parts() { + type Scalar = u64; + + let len = 751; + let log_modulus = 12; + + let ciphertext_modulus = CiphertextModulus::new_native(); + + let mut lwe = LweCiphertext::new(Scalar::ZERO, LweSize(len), ciphertext_modulus); + + // We don't care about the exact content here + rand::thread_rng().fill(lwe.as_mut()); + + let compressed = CompressedModulusSwitchedLweCiphertext::compress( + &lwe, + CiphertextModulusLog(log_modulus), + ); + + let (packed_integers, lwe_dimension, uncompressed_ciphertext_modulus) = + compressed.into_raw_parts(); + + let rebuilt = CompressedModulusSwitchedLweCiphertext::from_raw_parts( + packed_integers, + lwe_dimension, + uncompressed_ciphertext_modulus, + ); + + let lwe_ms_ed = rebuilt.extract().into_container(); + let lwe = lwe.into_container(); + + for (i, output) in lwe_ms_ed.into_iter().enumerate() { + assert_eq!( + output, + (output >> (Scalar::BITS as usize - log_modulus)) + << (Scalar::BITS as usize - log_modulus), + ); + + assert_eq!( + output >> (Scalar::BITS as usize - log_modulus), + modulus_switch(lwe[i], CiphertextModulusLog(log_modulus)) + ) + } + } } diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_multi_bit_lwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_multi_bit_lwe_ciphertext.rs index a12b8421b..e982e047f 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_multi_bit_lwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_multi_bit_lwe_ciphertext.rs @@ -323,7 +323,7 @@ impl + CastFrom> let diffs = |a: usize| { self.packed_diffs.as_ref().map_or(0, |packed_diffs| { let diffs_two_complement: usize = diffs_two_complement[a]; - let used_space_log = packed_diffs.log_modulus.0; + let used_space_log = packed_diffs.log_modulus().0; // rebuild from two complement representation on used_space_log bits let used_space = 1 << used_space_log; if diffs_two_complement >= used_space / 2 { @@ -358,7 +358,7 @@ impl + CastFrom> } switched_modulus_input_mask_per_group - .push(monomial_degree % (1 << self.packed_mask.log_modulus.0)); + .push(monomial_degree % (1 << self.packed_mask.log_modulus().0)); } } @@ -416,7 +416,7 @@ impl + CastFrom> ParameterSetCo let lwe_dim = lwe_dimension.0; - body >> packed_mask.log_modulus.0 == 0 + body >> packed_mask.log_modulus().0 == 0 && packed_mask.is_conformant(&lwe_dim) && packed_diffs .as_ref() diff --git a/tfhe/src/core_crypto/entities/packed_integers.rs b/tfhe/src/core_crypto/entities/packed_integers.rs index 0df76d79f..2f86989d6 100644 --- a/tfhe/src/core_crypto/entities/packed_integers.rs +++ b/tfhe/src/core_crypto/entities/packed_integers.rs @@ -7,12 +7,35 @@ use crate::core_crypto::prelude::*; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(PackedIntegersVersions)] pub struct PackedIntegers { - pub(crate) packed_coeffs: Vec, - pub(crate) log_modulus: CiphertextModulusLog, - pub(crate) initial_len: usize, + packed_coeffs: Vec, + log_modulus: CiphertextModulusLog, + initial_len: usize, } impl PackedIntegers { + pub(crate) fn from_raw_parts( + packed_coeffs: Vec, + log_modulus: CiphertextModulusLog, + initial_len: usize, + ) -> Self { + let required_bits_packed = initial_len * log_modulus.0; + let expected_len = required_bits_packed.div_ceil(Scalar::BITS); + + assert_eq!( + packed_coeffs.len(), + expected_len, + "Invalid size for the packed coeffs, got {}, expected {}", + packed_coeffs.len(), + expected_len + ); + + Self { + packed_coeffs, + log_modulus, + initial_len, + } + } + pub fn pack(slice: &[Scalar], log_modulus: CiphertextModulusLog) -> Self { let log_modulus = log_modulus.0; @@ -166,6 +189,18 @@ impl PackedIntegers { } }) } + + pub fn log_modulus(&self) -> CiphertextModulusLog { + self.log_modulus + } + + pub fn packed_coeffs(&self) -> &[Scalar] { + &self.packed_coeffs + } + + pub fn initial_len(&self) -> usize { + self.initial_len + } } impl ParameterSetConformant for PackedIntegers { diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index edf0c82a0..0b81c75f3 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -217,17 +217,19 @@ impl CudaCompressedCiphertextList { let number_bits_to_pack = initial_len * storage_log_modulus.0; let len = number_bits_to_pack.div_ceil(u64::BITS as usize); let chunk_end = chunk_start + len; - modulus_switched_glwe_ciphertext_list.push(CompressedModulusSwitchedGlweCiphertext { - packed_integers: PackedIntegers { - packed_coeffs: flat_cpu_data[chunk_start..chunk_end].to_vec(), - log_modulus: storage_log_modulus, - initial_len, - }, - glwe_dimension, - polynomial_size, - bodies_count, - uncompressed_ciphertext_modulus: ciphertext_modulus, - }); + modulus_switched_glwe_ciphertext_list.push( + CompressedModulusSwitchedGlweCiphertext::from_raw_parts( + PackedIntegers::from_raw_parts( + flat_cpu_data[chunk_start..chunk_end].to_vec(), + storage_log_modulus, + initial_len, + ), + glwe_dimension, + polynomial_size, + bodies_count, + ciphertext_modulus, + ), + ); num_bodies_left = num_bodies_left.saturating_sub(lwe_per_glwe.0); chunk_start = chunk_end; } @@ -345,10 +347,10 @@ impl CompressedCiphertextList { &self.packed_list.modulus_switched_glwe_ciphertext_list; let first_ct = modulus_switched_glwe_ciphertext_list.first().unwrap(); - let storage_log_modulus = first_ct.packed_integers.log_modulus; + let storage_log_modulus = first_ct.packed_integers().log_modulus(); let initial_len = modulus_switched_glwe_ciphertext_list .iter() - .map(|glwe| glwe.packed_integers.initial_len) + .map(|glwe| glwe.packed_integers().initial_len()) .sum(); let message_modulus = self.packed_list.message_modulus; @@ -356,7 +358,7 @@ impl CompressedCiphertextList { let flat_cpu_data = modulus_switched_glwe_ciphertext_list .iter() - .flat_map(|ct| ct.packed_integers.packed_coeffs.clone()) + .flat_map(|ct| ct.packed_integers().packed_coeffs().to_vec()) .collect_vec(); let flat_gpu_data = unsafe { diff --git a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs index 2ee6a3d34..f291f63ef 100644 --- a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs @@ -25,7 +25,7 @@ impl CompressedCiphertextList { pub(crate) fn flat_len(&self) -> usize { self.modulus_switched_glwe_ciphertext_list .iter() - .map(|glwe| glwe.packed_integers.packed_coeffs.len()) + .map(|glwe| glwe.packed_integers().packed_coeffs().len()) .sum() } }