diff --git a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs index ae8acb33a..816460dae 100644 --- a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs @@ -64,11 +64,7 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_bitnot_assign_async( + pub fn unchecked_bitnot_assign( &self, ct: &mut T, streams: &CudaStreams, @@ -83,9 +79,13 @@ impl CudaServerKey { let shift_plaintext = self.encoding().encode(Cleartext(u64::from(scalar))).0; let scalar_vector = vec![shift_plaintext; ct_blocks]; - let mut d_decomposed_scalar = - CudaVec::::new_async(ct.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0); - d_decomposed_scalar.copy_from_cpu_async(scalar_vector.as_slice(), streams, 0); + + let mut d_decomposed_scalar = unsafe { + CudaVec::::new_async(ct.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0) + }; + unsafe { + d_decomposed_scalar.copy_from_cpu_async(scalar_vector.as_slice(), streams, 0); + } cuda_lwe_ciphertext_plaintext_add_assign( &mut ct.as_mut().d_blocks, @@ -95,7 +95,7 @@ impl CudaServerKey { ct.as_mut().info = ct.as_ref().info.after_bitnot(); } - pub(crate) unsafe fn unchecked_boolean_bitnot_assign_async( + pub fn unchecked_boolean_bitnot_assign( &self, ct: &mut CudaBooleanBlock, streams: &CudaStreams, @@ -108,9 +108,12 @@ impl CudaServerKey { let shift_plaintext = self.encoding().encode(Cleartext(1u64)).0; let scalar_vector = vec![shift_plaintext; ct_blocks]; - let mut d_decomposed_scalar = - CudaVec::::new_async(ct.0.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0); - d_decomposed_scalar.copy_from_cpu_async(scalar_vector.as_slice(), streams, 0); + let mut d_decomposed_scalar = unsafe { + CudaVec::::new_async(ct.0.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0) + }; + unsafe { + d_decomposed_scalar.copy_from_cpu_async(scalar_vector.as_slice(), streams, 0); + } cuda_lwe_ciphertext_plaintext_add_assign( &mut ct.0.as_mut().d_blocks, @@ -120,17 +123,6 @@ impl CudaServerKey { // Neither noise level nor the degree changes } - pub fn unchecked_bitnot_assign( - &self, - ct: &mut T, - streams: &CudaStreams, - ) { - unsafe { - self.unchecked_bitnot_assign_async(ct, streams); - } - streams.synchronize(); - } - /// Computes homomorphically bitand between two ciphertexts encrypting integer values. /// /// This function computes the operation without checking if it exceeds the capacity of the @@ -185,11 +177,7 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn unchecked_bitop_assign_async( + pub(crate) fn unchecked_bitop_assign( &self, ct_left: &mut T, ct_right: &T, @@ -207,62 +195,64 @@ impl CudaServerKey { let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count(); - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_bitop_assign( - streams, - ct_left.as_mut(), - ct_right.as_ref(), - &d_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - op, - lwe_ciphertext_count.0 as u32, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_bitop_assign( - streams, - ct_left.as_mut(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - op, - lwe_ciphertext_count.0 as u32, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } } } @@ -385,10 +375,7 @@ impl CudaServerKey { ct_right: &T, streams: &CudaStreams, ) { - unsafe { - self.unchecked_bitop_assign_async(ct_left, ct_right, BitOpType::And, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(ct_left, ct_right, BitOpType::And, streams); } /// Computes homomorphically bitor between two ciphertexts encrypting integer values. @@ -451,10 +438,7 @@ impl CudaServerKey { ct_right: &T, streams: &CudaStreams, ) { - unsafe { - self.unchecked_bitop_assign_async(ct_left, ct_right, BitOpType::Or, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(ct_left, ct_right, BitOpType::Or, streams); } /// Computes homomorphically bitxor between two ciphertexts encrypting integer values. @@ -517,10 +501,7 @@ impl CudaServerKey { ct_right: &T, streams: &CudaStreams, ) { - unsafe { - self.unchecked_bitop_assign_async(ct_left, ct_right, BitOpType::Xor, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(ct_left, ct_right, BitOpType::Xor, streams); } /// Computes homomorphically bitand between two ciphertexts encrypting integer values. @@ -577,11 +558,7 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn bitand_assign_async( + pub fn bitand_assign( &self, ct_left: &mut T, ct_right: &T, @@ -611,19 +588,7 @@ impl CudaServerKey { (ct_left, &tmp_rhs) } }; - self.unchecked_bitop_assign_async(lhs, rhs, BitOpType::And, streams); - } - - pub fn bitand_assign( - &self, - ct_left: &mut T, - ct_right: &T, - streams: &CudaStreams, - ) { - unsafe { - self.bitand_assign_async(ct_left, ct_right, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(lhs, rhs, BitOpType::And, streams); } /// Computes homomorphically bitor between two ciphertexts encrypting integer values. @@ -680,11 +645,7 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn bitor_assign_async( + pub fn bitor_assign( &self, ct_left: &mut T, ct_right: &T, @@ -715,19 +676,7 @@ impl CudaServerKey { } }; - self.unchecked_bitop_assign_async(lhs, rhs, BitOpType::Or, streams); - } - - pub fn bitor_assign( - &self, - ct_left: &mut T, - ct_right: &T, - streams: &CudaStreams, - ) { - unsafe { - self.bitor_assign_async(ct_left, ct_right, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(lhs, rhs, BitOpType::Or, streams); } /// Computes homomorphically bitxor between two ciphertexts encrypting integer values. @@ -784,11 +733,7 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn bitxor_assign_async( + pub fn bitxor_assign( &self, ct_left: &mut T, ct_right: &T, @@ -819,19 +764,7 @@ impl CudaServerKey { } }; - self.unchecked_bitop_assign_async(lhs, rhs, BitOpType::Xor, streams); - } - - pub fn bitxor_assign( - &self, - ct_left: &mut T, - ct_right: &T, - streams: &CudaStreams, - ) { - unsafe { - self.bitxor_assign_async(ct_left, ct_right, streams); - } - streams.synchronize(); + self.unchecked_bitop_assign(lhs, rhs, BitOpType::Xor, streams); } /// Computes homomorphically bitnot for an encrypted integer value. @@ -880,28 +813,14 @@ impl CudaServerKey { result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn bitnot_assign_async( - &self, - ct: &mut T, - streams: &CudaStreams, - ) { + pub fn bitnot_assign(&self, ct: &mut T, streams: &CudaStreams) { if !ct.block_carries_are_empty() { self.full_propagate_assign(ct, streams); } - self.unchecked_bitnot_assign_async(ct, streams); + self.unchecked_bitnot_assign(ct, streams); } - pub fn bitnot_assign(&self, ct: &mut T, streams: &CudaStreams) { - unsafe { - self.bitnot_assign_async(ct, streams); - } - streams.synchronize(); - } pub fn get_bitand_size_on_gpu( &self, ct_left: &T, @@ -910,6 +829,7 @@ impl CudaServerKey { ) -> u64 { self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::And, streams) } + pub fn get_bitor_size_on_gpu( &self, ct_left: &T, @@ -918,6 +838,7 @@ impl CudaServerKey { ) -> u64 { self.get_bitop_size_on_gpu(ct_left, ct_right, BitOpType::Or, streams) } + pub fn get_bitxor_size_on_gpu( &self, ct_left: &T, diff --git a/tfhe/src/integer/gpu/server_key/radix/cmux.rs b/tfhe/src/integer/gpu/server_key/radix/cmux.rs index 360a24b66..b43b77850 100644 --- a/tfhe/src/integer/gpu/server_key/radix/cmux.rs +++ b/tfhe/src/integer/gpu/server_key/radix/cmux.rs @@ -9,11 +9,7 @@ use crate::integer::gpu::{ }; impl CudaServerKey { - /// # Safety - /// - /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until stream is synchronised - pub unsafe fn unchecked_if_then_else_async( + pub fn unchecked_if_then_else( &self, condition: &CudaBooleanBlock, true_ct: &T, @@ -89,19 +85,6 @@ impl CudaServerKey { result } - pub fn unchecked_if_then_else( - &self, - condition: &CudaBooleanBlock, - true_ct: &T, - false_ct: &T, - stream: &CudaStreams, - ) -> T { - let result = - unsafe { self.unchecked_if_then_else_async(condition, true_ct, false_ct, stream) }; - stream.synchronize(); - result - } - pub fn if_then_else( &self, condition: &CudaBooleanBlock, @@ -130,6 +113,7 @@ impl CudaServerKey { self.unchecked_if_then_else(condition, true_ct, false_ct, stream) } + pub fn get_if_then_else_size_on_gpu( &self, _condition: &CudaBooleanBlock, diff --git a/tfhe/src/integer/gpu/server_key/radix/comparison.rs b/tfhe/src/integer/gpu/server_key/radix/comparison.rs index b992a1f2a..902886489 100644 --- a/tfhe/src/integer/gpu/server_key/radix/comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/comparison.rs @@ -12,11 +12,7 @@ use crate::integer::gpu::{ use crate::shortint::ciphertext::Degree; impl CudaServerKey { - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_comparison_async( + pub fn unchecked_comparison( &self, ct_left: &T, ct_right: &T, @@ -49,127 +45,72 @@ impl CudaServerKey { let mut result = CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info)); - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut().as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - op, - T::IS_SIGNED, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut().as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - op, - T::IS_SIGNED, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } } result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_eq_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::EQ, streams) - } - - /// Compares for equality 2 ciphertexts - /// - /// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0 - /// - /// Requires carry bits to be empty - /// - /// # Example - /// - /// ```rust - /// use tfhe::core_crypto::gpu::CudaStreams; - /// use tfhe::core_crypto::gpu::vec::GpuIndex; - /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; - /// use tfhe::integer::gpu::gen_keys_radix_gpu; - /// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; - /// - /// let gpu_index = 0; - /// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index)); - /// - /// let size = 4; - /// // Generate the client key and the server key: - /// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, size, &streams); - /// - /// let msg1 = 14u64; - /// let msg2 = 97u64; - /// - /// let ct1 = cks.encrypt(msg1); - /// let ct2 = cks.encrypt(msg2); - /// - /// // Copy to GPU - /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); - /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); - /// - /// let d_ct_res = sks.unchecked_eq(&d_ct1, &d_ct2, &streams); - /// - /// // Copy back to CPU - /// let ct_res = d_ct_res.to_boolean_block(&streams); - /// - /// // Decrypt: - /// let dec_result = cks.decrypt_bool(&ct_res); - /// assert_eq!(dec_result, msg1 == msg2); - /// ``` pub fn unchecked_eq( &self, ct_left: &T, @@ -179,68 +120,9 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.unchecked_eq_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + self.unchecked_comparison(ct_left, ct_right, ComparisonType::EQ, streams) } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_ne_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::NE, streams) - } - - /// Compares for equality 2 ciphertexts - /// - /// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0 - /// - /// Requires carry bits to be empty - /// - /// # Example - /// - /// ```rust - /// use tfhe::core_crypto::gpu::CudaStreams; - /// use tfhe::core_crypto::gpu::vec::GpuIndex; - /// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; - /// use tfhe::integer::gpu::gen_keys_radix_gpu; - /// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; - /// - /// let gpu_index = 0; - /// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index)); - /// - /// let size = 4; - /// // Generate the client key and the server key: - /// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, size, &streams); - /// - /// let msg1 = 14u64; - /// let msg2 = 97u64; - /// - /// let ct1 = cks.encrypt(msg1); - /// let ct2 = cks.encrypt(msg2); - /// - /// // Copy to GPU - /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); - /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); - /// - /// let d_ct_res = sks.unchecked_ne(&d_ct1, &d_ct2, &streams); - /// - /// // Copy back to CPU - /// let ct_res = d_ct_res.to_boolean_block(&streams); - /// - /// // Decrypt: - /// let dec_result = cks.decrypt_bool(&ct_res); - /// assert_eq!(dec_result, msg1 != msg2); - /// ``` pub fn unchecked_ne( &self, ct_left: &T, @@ -250,53 +132,7 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.unchecked_ne_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn eq_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let mut tmp_lhs; - let mut tmp_rhs; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.duplicate(streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (ct_left, &tmp_rhs) - } - (false, true) => { - tmp_lhs = ct_left.duplicate(streams); - self.full_propagate_assign(&mut tmp_lhs, streams); - (&tmp_lhs, ct_right) - } - (false, false) => { - tmp_lhs = ct_left.duplicate(streams); - tmp_rhs = ct_right.duplicate(streams); - - self.full_propagate_assign(&mut tmp_lhs, streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (&tmp_lhs, &tmp_rhs) - } - }; - - self.unchecked_eq_async(lhs, rhs, streams) + self.unchecked_comparison(ct_left, ct_right, ComparisonType::NE, streams) } /// Compares for equality 2 ciphertexts @@ -344,9 +180,35 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.eq_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + let mut tmp_lhs; + let mut tmp_rhs; + + let (lhs, rhs) = match ( + ct_left.block_carries_are_empty(), + ct_right.block_carries_are_empty(), + ) { + (true, true) => (ct_left, ct_right), + (true, false) => { + tmp_rhs = ct_right.duplicate(streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (ct_left, &tmp_rhs) + } + (false, true) => { + tmp_lhs = ct_left.duplicate(streams); + self.full_propagate_assign(&mut tmp_lhs, streams); + (&tmp_lhs, ct_right) + } + (false, false) => { + tmp_lhs = ct_left.duplicate(streams); + tmp_rhs = ct_right.duplicate(streams); + + self.full_propagate_assign(&mut tmp_lhs, streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (&tmp_lhs, &tmp_rhs) + } + }; + + self.unchecked_eq(lhs, rhs, streams) } pub(crate) fn get_comparison_size_on_gpu( @@ -465,51 +327,7 @@ impl CudaServerKey { actual_full_prop_mem.max(comparison_mem) } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn ne_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let mut tmp_lhs; - let mut tmp_rhs; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.duplicate(streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (ct_left, &tmp_rhs) - } - (false, true) => { - tmp_lhs = ct_left.duplicate(streams); - self.full_propagate_assign(&mut tmp_lhs, streams); - (&tmp_lhs, ct_right) - } - (false, false) => { - tmp_lhs = ct_left.duplicate(streams); - tmp_rhs = ct_right.duplicate(streams); - - self.full_propagate_assign(&mut tmp_lhs, streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (&tmp_lhs, &tmp_rhs) - } - }; - - self.unchecked_ne_async(lhs, rhs, streams) - } - - /// Compares for equality 2 ciphertexts + /// Compares for non equality 2 ciphertexts /// /// Returns a ciphertext containing 1 if lhs != rhs, otherwise 0 /// @@ -554,16 +372,38 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.ne_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + let mut tmp_lhs; + let mut tmp_rhs; + + let (lhs, rhs) = match ( + ct_left.block_carries_are_empty(), + ct_right.block_carries_are_empty(), + ) { + (true, true) => (ct_left, ct_right), + (true, false) => { + tmp_rhs = ct_right.duplicate(streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (ct_left, &tmp_rhs) + } + (false, true) => { + tmp_lhs = ct_left.duplicate(streams); + self.full_propagate_assign(&mut tmp_lhs, streams); + (&tmp_lhs, ct_right) + } + (false, false) => { + tmp_lhs = ct_left.duplicate(streams); + tmp_rhs = ct_right.duplicate(streams); + + self.full_propagate_assign(&mut tmp_lhs, streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (&tmp_lhs, &tmp_rhs) + } + }; + + self.unchecked_ne(lhs, rhs, streams) } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_gt_async( + pub fn unchecked_gt( &self, ct_left: &T, ct_right: &T, @@ -572,7 +412,19 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::GT, streams) + self.unchecked_comparison(ct_left, ct_right, ComparisonType::GT, streams) + } + + pub fn unchecked_ge( + &self, + ct_left: &T, + ct_right: &T, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + { + self.unchecked_comparison(ct_left, ct_right, ComparisonType::GE, streams) } /// Compares if lhs is strictly greater than rhs @@ -605,7 +457,7 @@ impl CudaServerKey { /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); /// - /// let d_ct_res = sks.unchecked_gt(&d_ct1, &d_ct2, &streams); + /// let d_ct_res = sks.gt(&d_ct1, &d_ct2, &streams); /// /// // Copy back to CPU /// let ct_res = d_ct_res.to_boolean_block(&streams); @@ -614,46 +466,7 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 > msg2); /// ``` - pub fn unchecked_gt( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.unchecked_gt_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_ge_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::GE, streams) - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn gt_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock + pub fn gt(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock where T: CudaIntegerRadixCiphertext, { @@ -685,16 +498,7 @@ impl CudaServerKey { } }; - self.unchecked_gt_async(lhs, rhs, streams) - } - - pub fn gt(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.gt_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + self.unchecked_gt(lhs, rhs, streams) } /// Compares if lhs is greater or equal than rhs @@ -729,7 +533,7 @@ impl CudaServerKey { /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); /// - /// let d_ct_res = sks.unchecked_ge(&d_ct1, &d_ct2, &streams); + /// let d_ct_res = sks.ge(&d_ct1, &d_ct2, &streams); /// /// // Copy back to CPU /// let ct_res = d_ct_res.to_boolean_block(&streams); @@ -738,30 +542,7 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 >= msg2); /// ``` - pub fn unchecked_ge( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.unchecked_ge_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn ge_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock + pub fn ge(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock where T: CudaIntegerRadixCiphertext, { @@ -793,23 +574,10 @@ impl CudaServerKey { } }; - self.unchecked_ge_async(lhs, rhs, streams) + self.unchecked_ge(lhs, rhs, streams) } - pub fn ge(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.ge_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_lt_async( + pub fn unchecked_lt( &self, ct_left: &T, ct_right: &T, @@ -818,7 +586,7 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::LT, streams) + self.unchecked_comparison(ct_left, ct_right, ComparisonType::LT, streams) } /// Compares if lhs is lower than rhs @@ -853,7 +621,7 @@ impl CudaServerKey { /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); /// - /// let d_ct_res = sks.unchecked_lt(&d_ct1, &d_ct2, &streams); + /// let d_ct_res = sks.lt(&d_ct1, &d_ct2, &streams); /// /// // Copy back to CPU /// let ct_res = d_ct_res.to_boolean_block(&streams); @@ -862,30 +630,7 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 < msg2); /// ``` - pub fn unchecked_lt( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.unchecked_lt_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn lt_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock + pub fn lt(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock where T: CudaIntegerRadixCiphertext, { @@ -917,23 +662,10 @@ impl CudaServerKey { } }; - self.unchecked_lt_async(lhs, rhs, streams) + self.unchecked_lt(lhs, rhs, streams) } - pub fn lt(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.lt_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_le_async( + pub fn unchecked_le( &self, ct_left: &T, ct_right: &T, @@ -942,7 +674,7 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - self.unchecked_comparison_async(ct_left, ct_right, ComparisonType::LE, streams) + self.unchecked_comparison(ct_left, ct_right, ComparisonType::LE, streams) } /// Compares if lhs is lower or equal than rhs @@ -977,7 +709,7 @@ impl CudaServerKey { /// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); /// let d_ct2 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct2, &streams); /// - /// let d_ct_res = sks.unchecked_le(&d_ct1, &d_ct2, &streams); + /// let d_ct_res = sks.le(&d_ct1, &d_ct2, &streams); /// /// // Copy back to CPU /// let ct_res = d_ct_res.to_boolean_block(&streams); @@ -986,30 +718,7 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 < msg2); /// ``` - pub fn unchecked_le( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.unchecked_le_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn le_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> CudaBooleanBlock + pub fn le(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock where T: CudaIntegerRadixCiphertext, { @@ -1041,16 +750,7 @@ impl CudaServerKey { } }; - self.unchecked_le_async(lhs, rhs, streams) - } - - pub fn le(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> CudaBooleanBlock - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.le_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + self.unchecked_le(lhs, rhs, streams) } pub fn get_eq_size_on_gpu( @@ -1107,112 +807,7 @@ impl CudaServerKey { self.get_comparison_size_on_gpu(ct_left, ct_right, ComparisonType::LE, streams) } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_max_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> T - where - T: CudaIntegerRadixCiphertext, - { - assert_eq!( - ct_left.as_ref().d_blocks.lwe_dimension(), - ct_right.as_ref().d_blocks.lwe_dimension() - ); - assert_eq!( - ct_left.as_ref().d_blocks.lwe_ciphertext_count(), - ct_right.as_ref().d_blocks.lwe_ciphertext_count() - ); - - let mut result = ct_left.duplicate(streams); - - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - ComparisonType::MAX, - T::IS_SIGNED, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - ComparisonType::MAX, - T::IS_SIGNED, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); - } - } - result - } - pub fn unchecked_max(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.unchecked_max_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_min_async( - &self, - ct_left: &T, - ct_right: &T, - streams: &CudaStreams, - ) -> T where T: CudaIntegerRadixCiphertext, { @@ -1227,64 +822,66 @@ impl CudaServerKey { let mut result = ct_left.duplicate(streams); - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - ComparisonType::MIN, - T::IS_SIGNED, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - ComparisonType::MIN, - T::IS_SIGNED, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + ComparisonType::MAX, + T::IS_SIGNED, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + ComparisonType::MAX, + T::IS_SIGNED, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } } result @@ -1294,63 +891,117 @@ impl CudaServerKey { where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.unchecked_min_async(ct_left, ct_right, streams) }; - streams.synchronize(); + assert_eq!( + ct_left.as_ref().d_blocks.lwe_dimension(), + ct_right.as_ref().d_blocks.lwe_dimension() + ); + assert_eq!( + ct_left.as_ref().d_blocks.lwe_ciphertext_count(), + ct_right.as_ref().d_blocks.lwe_ciphertext_count() + ); + + let mut result = ct_left.duplicate(streams); + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + ComparisonType::MIN, + T::IS_SIGNED, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + ComparisonType::MIN, + T::IS_SIGNED, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } + } result } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn max_async(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T - where - T: CudaIntegerRadixCiphertext, - { - let mut tmp_lhs; - let mut tmp_rhs; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.duplicate(streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (ct_left, &tmp_rhs) - } - (false, true) => { - tmp_lhs = ct_left.duplicate(streams); - self.full_propagate_assign(&mut tmp_lhs, streams); - (&tmp_lhs, ct_right) - } - (false, false) => { - tmp_lhs = ct_left.duplicate(streams); - tmp_rhs = ct_right.duplicate(streams); - - self.full_propagate_assign(&mut tmp_lhs, streams); - self.full_propagate_assign(&mut tmp_rhs, streams); - (&tmp_lhs, &tmp_rhs) - } - }; - self.unchecked_max_async(lhs, rhs, streams) - } - pub fn max(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T where T: CudaIntegerRadixCiphertext, { - let result = unsafe { self.max_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result + let mut tmp_lhs; + let mut tmp_rhs; + + let (lhs, rhs) = match ( + ct_left.block_carries_are_empty(), + ct_right.block_carries_are_empty(), + ) { + (true, true) => (ct_left, ct_right), + (true, false) => { + tmp_rhs = ct_right.duplicate(streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (ct_left, &tmp_rhs) + } + (false, true) => { + tmp_lhs = ct_left.duplicate(streams); + self.full_propagate_assign(&mut tmp_lhs, streams); + (&tmp_lhs, ct_right) + } + (false, false) => { + tmp_lhs = ct_left.duplicate(streams); + tmp_rhs = ct_right.duplicate(streams); + + self.full_propagate_assign(&mut tmp_lhs, streams); + self.full_propagate_assign(&mut tmp_rhs, streams); + (&tmp_lhs, &tmp_rhs) + } + }; + self.unchecked_max(lhs, rhs, streams) } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn min_async(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T + pub fn min(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T where T: CudaIntegerRadixCiphertext, { @@ -1381,17 +1032,9 @@ impl CudaServerKey { (&tmp_lhs, &tmp_rhs) } }; - self.unchecked_min_async(lhs, rhs, streams) + self.unchecked_min(lhs, rhs, streams) } - pub fn min(&self, ct_left: &T, ct_right: &T, streams: &CudaStreams) -> T - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.min_async(ct_left, ct_right, streams) }; - streams.synchronize(); - result - } pub fn get_max_size_on_gpu( &self, ct_left: &T, diff --git a/tfhe/src/integer/gpu/server_key/radix/neg.rs b/tfhe/src/integer/gpu/server_key/radix/neg.rs index 145b72348..b8257a561 100644 --- a/tfhe/src/integer/gpu/server_key/radix/neg.rs +++ b/tfhe/src/integer/gpu/server_key/radix/neg.rs @@ -51,33 +51,20 @@ impl CudaServerKey { &self, ctxt: &T, streams: &CudaStreams, - ) -> T { - let result = unsafe { self.unchecked_neg_async(ctxt, streams) }; - streams.synchronize(); - result - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronised - pub unsafe fn unchecked_neg_async( - &self, - ctxt: &T, - streams: &CudaStreams, ) -> T { let mut ciphertext_out = ctxt.duplicate(streams); let info = ctxt.as_ref().info.blocks.first().unwrap(); - cuda_backend_unchecked_negate( - streams, - ciphertext_out.as_mut(), - ctxt.as_ref(), - info.message_modulus.0 as u32, - info.carry_modulus.0 as u32, - ); - + unsafe { + cuda_backend_unchecked_negate( + streams, + ciphertext_out.as_mut(), + ctxt.as_ref(), + info.message_modulus.0 as u32, + info.carry_modulus.0 as u32, + ); + } ciphertext_out } @@ -121,28 +108,6 @@ impl CudaServerKey { /// assert_eq!(modulus - msg, dec); /// ``` pub fn neg(&self, ctxt: &T, streams: &CudaStreams) -> T { - let result = unsafe { self.neg_async(ctxt, streams) }; - streams.synchronize(); - result - } - - pub fn get_neg_size_on_gpu( - &self, - ctxt: &T, - streams: &CudaStreams, - ) -> u64 { - self.get_scalar_add_size_on_gpu(ctxt, streams) - } - - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn neg_async( - &self, - ctxt: &T, - streams: &CudaStreams, - ) -> T { let mut tmp_ctxt; let ct = if ctxt.block_carries_are_empty() { @@ -153,20 +118,20 @@ impl CudaServerKey { &mut tmp_ctxt }; - let mut res = self.unchecked_neg_async(ct, streams); + let mut res = self.unchecked_neg(ct, streams); let _carry = self.propagate_single_carry_assign(&mut res, streams, None, OutputFlag::None); res } - /// # Safety - /// - /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must - /// not be dropped until streams is synchronized - pub unsafe fn overflowing_neg_async( + pub fn get_neg_size_on_gpu( &self, ctxt: &T, streams: &CudaStreams, - ) -> (T, CudaBooleanBlock) + ) -> u64 { + self.get_scalar_add_size_on_gpu(ctxt, streams) + } + + pub fn overflowing_neg(&self, ctxt: &T, streams: &CudaStreams) -> (T, CudaBooleanBlock) where T: CudaIntegerRadixCiphertext, { @@ -178,7 +143,7 @@ impl CudaServerKey { ct }; - self.bitnot_assign_async(&mut ct, streams); + self.bitnot_assign(&mut ct, streams); if T::IS_SIGNED { let tmp = CudaSignedRadixCiphertext { @@ -192,18 +157,9 @@ impl CudaServerKey { ciphertext: ct.into_inner(), }; let mut overflowed = self.unsigned_overflowing_scalar_add_assign(&mut tmp, 1, streams); - self.unchecked_boolean_bitnot_assign_async(&mut overflowed, streams); + self.unchecked_boolean_bitnot_assign(&mut overflowed, streams); let result = T::from(tmp.into_inner()); (result, overflowed) } } - - pub fn overflowing_neg(&self, ctxt: &T, streams: &CudaStreams) -> (T, CudaBooleanBlock) - where - T: CudaIntegerRadixCiphertext, - { - let result = unsafe { self.overflowing_neg_async(ctxt, streams) }; - streams.synchronize(); - result - } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index 55910cd89..e331c693b 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -297,7 +297,7 @@ impl CudaServerKey { ) } else { let scalar_as_trivial = self.create_trivial_radix(scalar, num_blocks, streams); - self.unchecked_comparison_async(ct, &scalar_as_trivial, op, streams) + self.unchecked_comparison(ct, &scalar_as_trivial, op, streams) } } else { // Unsigned