From 32b1a7ab1d6f701e8533e5b2aa68f3255a5bafd2 Mon Sep 17 00:00:00 2001
From: Enzo Di Maria
Date: Thu, 20 Nov 2025 16:00:16 +0100
Subject: [PATCH] refactor(gpu): unchecked_match_value_or to backend

---
 .../cuda/include/integer/cast.h               |  52 +++
 .../cuda/include/integer/integer.h            |  44 +++
 .../cuda/include/integer/vector_find.h        |  96 ++++++
 .../cuda/src/integer/cast.cu                  |  52 +++
 .../cuda/src/integer/cast.cuh                 |  69 ++++
 .../cuda/src/integer/vector_find.cu           |  48 +++
 .../cuda/src/integer/vector_find.cuh          |  45 +++
 backends/tfhe-cuda-backend/src/bindings.rs    |  90 +++++
 .../high_level_api/integers/unsigned/base.rs  |  57 ++++
 .../integers/unsigned/tests/cpu.rs            |   6 +
 .../integers/unsigned/tests/gpu.rs            |  39 +++
 .../integers/unsigned/tests/mod.rs            |  49 +++
 tfhe/src/integer/gpu/ciphertext/info.rs       |  12 -
 tfhe/src/integer/gpu/mod.rs                   | 321 ++++++++++++++++++
 tfhe/src/integer/gpu/server_key/radix/mod.rs  | 135 ++++----
 .../server_key/radix/tests_long_run/mod.rs    |  40 +++
 .../tests_long_run/test_random_op_sequence.rs |  25 +-
 .../gpu/server_key/radix/vector_find.rs       | 242 +++++++++++--
 .../radix_parallel/tests_long_run/mod.rs      |  15 +-
 .../tests_long_run/test_random_op_sequence.rs |  58 +++-
 20 files changed, 1396 insertions(+), 99 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/cast.h b/backends/tfhe-cuda-backend/cuda/include/integer/cast.h
index fda0300e1..94be1e6cc 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer/cast.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer/cast.h
@@ -75,3 +75,55 @@ template <typename Torus> struct int_extend_radix_with_sign_msb_buffer {
     cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
   }
 };
+
+template <typename Torus> struct int_cast_to_unsigned_buffer {
+  int_radix_params params;
+  bool allocate_gpu_memory;
+
+  bool requires_full_propagate;
+  bool requires_sign_extension;
+
+  int_fullprop_buffer<Torus> *prop_buffer;
+  int_extend_radix_with_sign_msb_buffer<Torus> *extend_buffer;
+
+  int_cast_to_unsigned_buffer(CudaStreams streams, int_radix_params params,
+                              uint32_t num_input_blocks,
+                              uint32_t target_num_blocks, bool input_is_signed,
+                              bool requires_full_propagate,
+                              bool allocate_gpu_memory,
+                              uint64_t &size_tracker) {
+    this->params = params;
+    this->allocate_gpu_memory = allocate_gpu_memory;
+    this->requires_full_propagate = requires_full_propagate;
+
+    this->prop_buffer = nullptr;
+    this->extend_buffer = nullptr;
+
+    if (requires_full_propagate) {
+      this->prop_buffer = new int_fullprop_buffer<Torus>(
+          streams, params, allocate_gpu_memory, size_tracker);
+    }
+
+    this->requires_sign_extension =
+        (target_num_blocks > num_input_blocks) && input_is_signed;
+
+    if (this->requires_sign_extension) {
+      uint32_t num_blocks_to_add = target_num_blocks - num_input_blocks;
+      this->extend_buffer = new int_extend_radix_with_sign_msb_buffer<Torus>(
+          streams, params, num_input_blocks, num_blocks_to_add,
+          allocate_gpu_memory, size_tracker);
+    }
+  }
+
+  void release(CudaStreams streams) {
+    if (this->prop_buffer) {
+      this->prop_buffer->release(streams);
+      delete this->prop_buffer;
+    }
+    if (this->extend_buffer) {
+      this->extend_buffer->release(streams);
+      delete this->extend_buffer;
+    }
+    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
+  }
+};
diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h
index ec28f7af3..65b2c8f9e 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h
@@ -569,6 +569,10 @@ void
 trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
                          CudaRadixCiphertextFFI const *input,
                          CudaStreamsFFI streams);
 
+void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
+                              CudaRadixCiphertextFFI const *input,
+                              CudaStreamsFFI streams);
+
 uint64_t scratch_cuda_apply_noise_squashing(
     CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -850,6 +854,46 @@ void cuda_unchecked_match_value_64(
 
 void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
                                            int8_t **mem_ptr_void);
+
+uint64_t scratch_cuda_cast_to_unsigned_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t big_lwe_dimension,
+    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
+    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
+    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
+    bool requires_full_propagate, uint32_t message_modulus,
+    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type);
+
+void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
+                              CudaRadixCiphertextFFI *output,
+                              CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
+                              uint32_t target_num_blocks, bool input_is_signed,
+                              void *const *bsks, void *const *ksks);
+
+void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
+                                      int8_t **mem_ptr_void);
+
+uint64_t scratch_cuda_unchecked_match_value_or_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t big_lwe_dimension,
+    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
+    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
+    uint32_t num_matches, uint32_t num_input_blocks,
+    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
+    uint32_t max_output_is_zero, uint32_t message_modulus,
+    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type);
+
+void cuda_unchecked_match_value_or_64(
+    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
+    CudaRadixCiphertextFFI const *lwe_array_in_ct,
+    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
+    const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
+    void *const *ksks);
+
+void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
+                                              int8_t **mem_ptr_void);
 } // extern C
 
 #endif // CUDA_INTEGER_H
diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h b/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h
index b597a073f..6a5e846a3 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "cast.h"
 #include "integer/comparison.h"
 #include "integer/radix_ciphertext.cuh"
 #include "integer_utilities.h"
@@ -593,3 +594,98 @@ template <typename Torus> struct int_unchecked_match_buffer {
     delete this->packed_selectors_ct;
   }
 };
+
+template <typename Torus> struct int_unchecked_match_value_or_buffer {
+  int_radix_params params;
+  bool allocate_gpu_memory;
+
+  uint32_t num_matches;
+  uint32_t num_input_blocks;
+  uint32_t num_match_packed_blocks;
+  uint32_t num_final_blocks;
+  bool max_output_is_zero;
+
+  int_unchecked_match_buffer<Torus> *match_buffer;
+  int_cmux_buffer<Torus> *cmux_buffer;
+
+  CudaRadixCiphertextFFI *tmp_match_result;
+  CudaRadixCiphertextFFI *tmp_match_bool;
+  CudaRadixCiphertextFFI *tmp_or_value;
+
+  Torus *d_or_value;
+
+  int_unchecked_match_value_or_buffer(
+      CudaStreams streams, int_radix_params params, uint32_t num_matches,
+      uint32_t num_input_blocks, uint32_t num_match_packed_blocks,
+      uint32_t num_final_blocks, bool max_output_is_zero,
+      bool allocate_gpu_memory, uint64_t &size_tracker) {
+    this->params = params;
+    this->allocate_gpu_memory = allocate_gpu_memory;
+    this->num_matches = num_matches;
+    this->num_input_blocks = num_input_blocks;
+    this->num_match_packed_blocks = num_match_packed_blocks;
+    this->num_final_blocks = num_final_blocks;
+    this->max_output_is_zero = max_output_is_zero;
+
+    this->match_buffer = new int_unchecked_match_buffer<Torus>(
+        streams, params, num_matches, num_input_blocks,
+        num_match_packed_blocks, max_output_is_zero, allocate_gpu_memory,
+        size_tracker);
+
+    this->cmux_buffer = new int_cmux_buffer<Torus>(
+        streams, [](Torus x) -> Torus { return x == 1; }, params,
+        num_final_blocks, allocate_gpu_memory, size_tracker);
+
+    this->tmp_match_result = new CudaRadixCiphertextFFI;
+    this->tmp_match_bool = new CudaRadixCiphertextFFI;
+    this->tmp_or_value = new CudaRadixCiphertextFFI;
+
+    this->d_or_value = (Torus *)cuda_malloc_with_size_tracking_async(
+        num_final_blocks * sizeof(Torus), streams.stream(0),
+        streams.gpu_index(0), size_tracker, allocate_gpu_memory);
+
+    if (!max_output_is_zero) {
+      create_zero_radix_ciphertext_async<Torus>(
+          streams.stream(0), streams.gpu_index(0), this->tmp_match_result,
+          num_final_blocks, params.big_lwe_dimension, size_tracker,
+          allocate_gpu_memory);
+    }
+
+    create_zero_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), this->tmp_match_bool, 1,
+        params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
+
+    create_zero_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), this->tmp_or_value,
+        num_final_blocks, params.big_lwe_dimension, size_tracker,
+        allocate_gpu_memory);
+  }
+
+  void release(CudaStreams streams) {
+    this->match_buffer->release(streams);
+    delete this->match_buffer;
+
+    this->cmux_buffer->release(streams);
+    delete this->cmux_buffer;
+
+    if (!max_output_is_zero) {
+      release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                     this->tmp_match_result,
+                                     this->allocate_gpu_memory);
+    }
+    delete this->tmp_match_result;
+
+    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                   this->tmp_match_bool,
+                                   this->allocate_gpu_memory);
+    delete this->tmp_match_bool;
+
+    release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+                                   this->tmp_or_value,
+                                   this->allocate_gpu_memory);
+    delete this->tmp_or_value;
+
+    cuda_drop_async(this->d_or_value, streams.stream(0), streams.gpu_index(0));
+
+    cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
+  }
+};
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu
index f9042b23e..f63d6c872 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu
@@ -18,6 +18,15 @@ void trim_radix_blocks_lsb_64(CudaRadixCiphertextFFI *output,
   cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
 }
 
+void trim_radix_blocks_msb_64(CudaRadixCiphertextFFI *output,
+                              CudaRadixCiphertextFFI const *input,
+                              CudaStreamsFFI streams) {
+
+  auto cuda_streams = CudaStreams(streams);
+  host_trim_radix_blocks_msb<uint64_t>(output, input, cuda_streams);
+  cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
+}
+
 uint64_t scratch_cuda_extend_radix_with_sign_msb_64(
     CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
     uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level,
@@ -64,3 +73,46 @@ void cleanup_cuda_extend_radix_with_sign_msb_64(CudaStreamsFFI streams,
   delete mem_ptr;
   *mem_ptr_void = nullptr;
 }
+
+uint64_t scratch_cuda_cast_to_unsigned_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t big_lwe_dimension,
+    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
+    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
+    uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed,
+    bool requires_full_propagate, uint32_t message_modulus,
+    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type) {
+
+  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
+                          big_lwe_dimension, small_lwe_dimension, ks_level,
+                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
+                          message_modulus, carry_modulus, noise_reduction_type);
+
+  return scratch_cuda_cast_to_unsigned<uint64_t>(
+      CudaStreams(streams), (int_cast_to_unsigned_buffer<uint64_t> **)mem_ptr,
+      params, num_input_blocks, target_num_blocks, input_is_signed,
+      requires_full_propagate, allocate_gpu_memory);
+}
+
+void cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
+                              CudaRadixCiphertextFFI *output,
+                              CudaRadixCiphertextFFI *input, int8_t *mem_ptr,
+                              uint32_t target_num_blocks, bool input_is_signed,
+                              void *const *bsks, void *const *ksks) {
+
+  host_cast_to_unsigned<uint64_t>(
+      CudaStreams(streams), output, input,
+      (int_cast_to_unsigned_buffer<uint64_t> *)mem_ptr, target_num_blocks,
+      input_is_signed, bsks, (uint64_t **)ksks);
+}
+
+void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams,
+                                      int8_t **mem_ptr_void) {
+  int_cast_to_unsigned_buffer<uint64_t> *mem_ptr =
+      (int_cast_to_unsigned_buffer<uint64_t> *)(*mem_ptr_void);
+
+  mem_ptr->release(CudaStreams(streams));
+  delete mem_ptr;
+  *mem_ptr_void = nullptr;
+}
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh
index b08297d9c..404914be8 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh
@@ -36,6 +36,23 @@ __host__ void host_trim_radix_blocks_lsb(CudaRadixCiphertextFFI *output,
                                     input->num_radix_blocks);
 }
 
+template <typename Torus>
+__host__ void
+host_trim_radix_blocks_msb(CudaRadixCiphertextFFI *output_radix,
+                           const CudaRadixCiphertextFFI *input_radix,
+                           CudaStreams streams) {
+
+  PANIC_IF_FALSE(input_radix->num_radix_blocks >=
+                     output_radix->num_radix_blocks,
+                 "Cuda error: input radix ciphertext has fewer blocks than "
+                 "required to keep");
+
+  copy_radix_ciphertext_slice_async<Torus>(
+      streams.stream(0), streams.gpu_index(0), output_radix, 0,
+      output_radix->num_radix_blocks, input_radix, 0,
+      output_radix->num_radix_blocks);
+}
+
 template <typename Torus>
 __host__ uint64_t scratch_extend_radix_with_sign_msb(
     CudaStreams streams, int_extend_radix_with_sign_msb_buffer<Torus> **mem_ptr,
@@ -91,4 +108,56 @@ __host__ void host_extend_radix_with_sign_msb(
   POP_RANGE()
 }
 
+template <typename Torus>
+uint64_t scratch_cuda_cast_to_unsigned(
+    CudaStreams streams, int_cast_to_unsigned_buffer<Torus> **mem_ptr,
+    int_radix_params params, uint32_t num_input_blocks,
+    uint32_t target_num_blocks, bool input_is_signed,
+    bool requires_full_propagate, bool allocate_gpu_memory) {
+
+  uint64_t size_tracker = 0;
+  *mem_ptr = new int_cast_to_unsigned_buffer<Torus>(
+      streams, params, num_input_blocks, target_num_blocks, input_is_signed,
+      requires_full_propagate, allocate_gpu_memory, size_tracker);
+
+  return size_tracker;
+}
+
+template <typename Torus>
+__host__ void
+host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output,
+                      CudaRadixCiphertextFFI *input,
+                      int_cast_to_unsigned_buffer<Torus> *mem_ptr,
+                      uint32_t target_num_blocks, bool input_is_signed,
+                      void *const *bsks, Torus *const *ksks) {
+
+  uint32_t current_num_blocks = input->num_radix_blocks;
+
+  if (mem_ptr->requires_full_propagate) {
+    host_full_propagate_inplace(streams, input, mem_ptr->prop_buffer,
+                                ksks, bsks, current_num_blocks);
+  }
+
+  if (target_num_blocks > current_num_blocks) {
+    uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks;
+
+    if (input_is_signed) {
+      host_extend_radix_with_sign_msb(
+          streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add,
+          bsks, (Torus **)ksks);
+    } else {
+      host_extend_radix_with_trivial_zero_blocks_msb<Torus>(output, input,
+                                                            streams);
+    }
+
+  } else if (target_num_blocks < current_num_blocks) {
+    host_trim_radix_blocks_msb<Torus>(output, input, streams);
+
+  } else {
+    copy_radix_ciphertext_slice_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), output, 0, current_num_blocks,
+        input, 0, current_num_blocks);
+  }
+}
+
 #endif
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu
index 89e257dc6..f953e5a74 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu
@@ -173,3 +173,51 @@ void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
   delete mem_ptr;
   *mem_ptr_void = nullptr;
 }
+
+uint64_t scratch_cuda_unchecked_match_value_or_64(
+    CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t big_lwe_dimension,
+    uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
+    uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
+    uint32_t num_matches, uint32_t num_input_blocks,
+    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
+    uint32_t max_output_is_zero, uint32_t message_modulus,
+    uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
+    PBS_MS_REDUCTION_T noise_reduction_type) {
+
+  int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
+                          big_lwe_dimension, small_lwe_dimension, ks_level,
+                          ks_base_log, pbs_level, pbs_base_log, grouping_factor,
+                          message_modulus, carry_modulus, noise_reduction_type);
+
+  return scratch_cuda_unchecked_match_value_or<uint64_t>(
+      CudaStreams(streams),
+      (int_unchecked_match_value_or_buffer<uint64_t> **)mem_ptr, params,
+      num_matches, num_input_blocks, num_match_packed_blocks, num_final_blocks,
+      max_output_is_zero, allocate_gpu_memory);
+}
+
+void cuda_unchecked_match_value_or_64(
+    CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
+    CudaRadixCiphertextFFI const *lwe_array_in_ct,
+    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
+    const uint64_t *h_or_value, int8_t *mem, void *const *bsks,
+    void *const *ksks) {
+
+  host_unchecked_match_value_or<uint64_t>(
+      CudaStreams(streams), lwe_array_out, lwe_array_in_ct, h_match_inputs,
+      h_match_outputs, h_or_value,
+      (int_unchecked_match_value_or_buffer<uint64_t> *)mem, bsks,
+      (uint64_t *const *)ksks);
+}
+
+void cleanup_cuda_unchecked_match_value_or_64(CudaStreamsFFI streams,
+                                              int8_t **mem_ptr_void) {
+  int_unchecked_match_value_or_buffer<uint64_t> *mem_ptr =
+      (int_unchecked_match_value_or_buffer<uint64_t> *)(*mem_ptr_void);
+
+  mem_ptr->release(CudaStreams(streams));
+
+  delete mem_ptr;
+  *mem_ptr_void = nullptr;
+}
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh
index a5df18556..ea52d1788 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include "integer/cast.cuh"
+#include "integer/cmux.cuh"
 #include "integer/comparison.cuh"
 #include "integer/integer.cuh"
 #include "integer/radix_ciphertext.cuh"
@@ -453,3 +455,46 @@ uint64_t scratch_cuda_unchecked_match_value(
 
   return size_tracker;
 }
+
+template <typename Torus>
+uint64_t scratch_cuda_unchecked_match_value_or(
+    CudaStreams streams, int_unchecked_match_value_or_buffer<Torus> **mem_ptr,
+    int_radix_params params, uint32_t num_matches, uint32_t num_input_blocks,
+    uint32_t num_match_packed_blocks, uint32_t num_final_blocks,
+    bool max_output_is_zero, bool allocate_gpu_memory) {
+
+  uint64_t size_tracker = 0;
+  *mem_ptr = new int_unchecked_match_value_or_buffer<Torus>(
+      streams, params, num_matches, num_input_blocks, num_match_packed_blocks,
+      num_final_blocks, max_output_is_zero, allocate_gpu_memory, size_tracker);
+
+  return size_tracker;
+}
+
+template <typename Torus>
+__host__ void host_unchecked_match_value_or(
+    CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
+    CudaRadixCiphertextFFI const *lwe_array_in_ct,
+    const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
+    const uint64_t *h_or_value,
+    int_unchecked_match_value_or_buffer<Torus> *mem_ptr, void *const *bsks,
+    Torus *const *ksks) {
+
+  host_unchecked_match_value(streams, mem_ptr->tmp_match_result,
+                             mem_ptr->tmp_match_bool, lwe_array_in_ct,
+                             h_match_inputs, h_match_outputs,
+                             mem_ptr->match_buffer, bsks, ksks);
+
+  cuda_memcpy_async_to_gpu(mem_ptr->d_or_value, h_or_value,
+                           mem_ptr->num_final_blocks * sizeof(Torus),
+                           streams.stream(0), streams.gpu_index(0));
+
+  set_trivial_radix_ciphertext_async<Torus>(
+      streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_or_value,
+      mem_ptr->d_or_value, (Torus *)h_or_value, mem_ptr->num_final_blocks,
+      mem_ptr->params.message_modulus, mem_ptr->params.carry_modulus);
+
+  host_cmux<Torus>(streams, lwe_array_out, mem_ptr->tmp_match_bool,
+                   mem_ptr->tmp_match_result, mem_ptr->tmp_or_value,
+                   mem_ptr->cmux_buffer, bsks, (Torus **)ksks);
+}
diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs
index e77a2cfd8..de14e045e 100644
--- a/backends/tfhe-cuda-backend/src/bindings.rs
+++ b/backends/tfhe-cuda-backend/src/bindings.rs
@@ -1270,6 +1270,13 @@ unsafe extern "C" {
         streams: CudaStreamsFFI,
     );
 }
+unsafe extern "C" {
+    pub fn trim_radix_blocks_msb_64(
+        output: *mut CudaRadixCiphertextFFI,
+        input: *const CudaRadixCiphertextFFI,
+        streams: CudaStreamsFFI,
+    );
+}
 unsafe extern "C" {
     pub fn scratch_cuda_apply_noise_squashing(
         streams: CudaStreamsFFI,
@@ -1872,6 +1879,89 @@ unsafe extern "C" {
         mem_ptr_void: *mut *mut i8,
     );
 }
+unsafe extern "C" {
+    pub fn scratch_cuda_cast_to_unsigned_64(
+        streams: CudaStreamsFFI,
+        mem_ptr: *mut *mut i8,
+        glwe_dimension: u32,
+        polynomial_size: u32,
+        big_lwe_dimension: u32,
+        small_lwe_dimension: u32,
+        ks_level: u32,
+        ks_base_log: u32,
+        pbs_level: u32,
+        pbs_base_log: u32,
+        grouping_factor: u32,
+        num_input_blocks: u32,
+        target_num_blocks: u32,
+        input_is_signed: bool,
+        requires_full_propagate: bool,
+        message_modulus: u32,
+        carry_modulus: u32,
+        pbs_type: PBS_TYPE,
+        allocate_gpu_memory: bool,
+        noise_reduction_type: PBS_MS_REDUCTION_T,
+    ) -> u64;
+}
+unsafe extern "C" {
+    pub fn cuda_cast_to_unsigned_64(
streams: CudaStreamsFFI, + output: *mut CudaRadixCiphertextFFI, + input: *mut CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + target_num_blocks: u32, + input_is_signed: bool, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); +} +unsafe extern "C" { + pub fn scratch_cuda_unchecked_match_value_or_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_matches: u32, + num_input_blocks: u32, + num_match_packed_blocks: u32, + num_final_blocks: u32, + max_output_is_zero: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_unchecked_match_value_or_64( + streams: CudaStreamsFFI, + lwe_array_out: *mut CudaRadixCiphertextFFI, + lwe_array_in_ct: *const CudaRadixCiphertextFFI, + h_match_inputs: *const u64, + h_match_outputs: *const u64, + h_or_value: *const u64, + mem: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_unchecked_match_value_or_64( + streams: CudaStreamsFFI, + mem_ptr_void: *mut *mut i8, + ); +} unsafe extern "C" { pub fn scratch_cuda_integer_compress_radix_ciphertext_64( streams: CudaStreamsFFI, diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index a56f04578..caeba5e56 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -1363,6 +1363,63 @@ where }) } + /// Returns the estimated memory usage (in bytes) required on the GPU to perform + /// the `match_value_or` operation. + /// + /// This is useful to check if the operation fits in the GPU memory before attempting execution. 
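+    /// The returned size is an estimate: it comes from the same scratch-buffer size
+    /// tracking the CUDA backend uses when it actually allocates the temporary
+    /// buffers for this operation.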
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::prelude::*;
+    /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16, MatchValues};
+    ///
+    /// let config = ConfigBuilder::default().build();
+    /// let (client_key, server_key) = generate_keys(config);
+    /// set_server_key(server_key);
+    ///
+    /// let a = FheUint16::encrypt(17u16, &client_key);
+    ///
+    /// let match_values = MatchValues::new(vec![(0u16, 3u16), (17u16, 25u16)]).unwrap();
+    ///
+    /// #[cfg(feature = "gpu")]
+    /// {
+    ///     let size_bytes = a
+    ///         .get_match_value_or_size_on_gpu(&match_values, 55u16)
+    ///         .unwrap();
+    ///     println!("Memory required on GPU: {} bytes", size_bytes);
+    ///     assert!(size_bytes > 0);
+    /// }
+    /// ```
+    #[cfg(feature = "gpu")]
+    pub fn get_match_value_or_size_on_gpu<Clear>(
+        &self,
+        matches: &MatchValues<Clear>,
+        or_value: Clear,
+    ) -> crate::Result<u64>
+    where
+        Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<u64>,
+    {
+        global_state::with_internal_keys(|key| match key {
+            InternalServerKey::Cpu(_) => Err(crate::Error::new(
+                "This function is only available when using the CUDA backend".to_string(),
+            )),
+            InternalServerKey::Cuda(cuda_key) => {
+                let streams = &cuda_key.streams;
+                let ct_on_gpu = self.ciphertext.on_gpu(streams);
+
+                let size = cuda_key.key.key.get_unchecked_match_value_or_size_on_gpu(
+                    &ct_on_gpu, matches, or_value, streams,
+                );
+                Ok(size)
+            }
+            #[cfg(feature = "hpu")]
+            InternalServerKey::Hpu(_device) => {
+                panic!("Hpu does not support this operation.")
+            }
+        })
+    }
+
     /// Reverse the bit of the unsigned integer
     ///
     /// # Example
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
index ad04472a8..94cc7b22f 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
@@ -653,3 +653,9 @@ fn test_compressed_cpk_encrypt_cast_compute_hl() {
     let clear: u64 = mul.decrypt(&client_key);
     assert_eq!(clear, (input_msg * multiplier) % modulus);
 }
+
+#[test]
+fn test_match_value_or() {
+    let client_key = setup_default_cpu();
+    super::test_case_match_value_or(&client_key);
+}
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
index b2dade8d2..938ec6f2a 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
@@ -960,3 +960,42 @@ fn test_gpu_get_match_value_size_on_gpu() {
         assert!(memory_size > 0);
     }
 }
+
+#[test]
+fn test_match_value_or_gpu() {
+    let client_key = setup_classical_gpu();
+    super::test_case_match_value_or(&client_key);
+}
+
+#[test]
+fn test_match_value_or_gpu_multibit() {
+    let client_key = setup_multibit_gpu();
+    super::test_case_match_value_or(&client_key);
+}
+
+#[test]
+fn test_gpu_get_match_value_or_size_on_gpu() {
+    for setup_fn in GPU_SETUP_FN {
+        let cks = setup_fn();
+        let mut rng = rand::thread_rng();
+        let clear_a = rng.gen::<u32>();
+        let or_value = rng.gen::<u32>();
+
+        let mut a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
+        a.move_to_current_device();
+
+        let match_values = MatchValues::new(vec![
+            (0u32, 10u32),
+            (1u32, 20u32),
+            (clear_a, 30u32),
+            (u32::MAX, 40u32),
+        ])
+        .unwrap();
+
+        let memory_size = a
+            .get_match_value_or_size_on_gpu(&match_values, or_value)
+            .unwrap();
+        check_valid_cuda_malloc_assert_oom(memory_size, GpuIndex::new(0));
+        assert!(memory_size > 0);
+    }
+}
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
index 17ecaac45..c42ed0b8a 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
@@ -839,3 +839,52 @@ fn test_case_match_value(cks: &ClientKey) {
         }
     }
 }
+
+fn test_case_match_value_or(cks: &ClientKey) {
+    let mut rng = thread_rng();
+
+    for _ in 0..5 {
+        let clear_in = rng.gen::<u8>();
+        let ct = FheUint8::encrypt(clear_in, cks);
+        let clear_or_value = rng.gen::<u8>();
+
+        let should_match = rng.gen_bool(0.5);
+
+        let mut map: HashMap<u8, u8> = HashMap::new();
+        let mut pairs = Vec::new();
+
+        let expected_value = if should_match {
+            let val = rng.gen::<u8>();
+            map.insert(clear_in, val);
+            pairs.push((clear_in, val));
+            val
+        } else {
+            clear_or_value
+        };
+
+        let num_entries = rng.gen_range(1..10);
+        for _ in 0..num_entries {
+            let mut k = rng.gen::<u8>();
+            while !should_match && k == clear_in {
+                k = rng.gen::<u8>();
+            }
+
+            if let std::collections::hash_map::Entry::Vacant(e) = map.entry(k) {
+                let v = rng.gen::<u8>();
+                e.insert(v);
+                pairs.push((k, v));
+            }
+        }
+
+        let matches = MatchValues::new(pairs).unwrap();
+
+        let result: FheUint8 = ct.match_value_or(&matches, clear_or_value).unwrap();
+
+        let dec_result: u8 = result.decrypt(cks);
+
+        assert_eq!(
+            dec_result, expected_value,
+            "Mismatch on result value for input {clear_in}. Should match: {should_match}"
+        );
+    }
+}
diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs
index e8ffe7f53..0d69307ba 100644
--- a/tfhe/src/integer/gpu/ciphertext/info.rs
+++ b/tfhe/src/integer/gpu/ciphertext/info.rs
@@ -71,18 +71,6 @@ impl CudaRadixCiphertextInfo {
         new_block_info
     }
 
-    pub(crate) fn after_trim_radix_blocks_msb(&self, num_blocks: usize) -> Self {
-        assert!(num_blocks > 0);
-
-        let mut new_block_info = Self {
-            blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)),
-        };
-        new_block_info
-            .blocks
-            .extend(self.blocks[..num_blocks].iter().copied());
-        new_block_info
-    }
-
     pub fn duplicate(&self) -> Self {
         Self {
             blocks: self
diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs
index 0cbc7c0a1..a6e65b5ee 100644
--- a/tfhe/src/integer/gpu/mod.rs
+++ b/tfhe/src/integer/gpu/mod.rs
@@ -7504,6 +7504,34 @@ pub(crate) unsafe fn cuda_backend_trim_radix_blocks_lsb(
     update_noise_degree(output, &cuda_ffi_output);
 }
 
+/// # Safety
+///
+/// - The data must not be moved or dropped while being used by the CUDA kernel.
+/// - This function assumes exclusive access to the passed data; violating this may lead to
+///   undefined behavior.
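+///
+/// Keeps the `output.num_radix_blocks` least significant blocks of `input` and
+/// drops the most significant ones on the device side (see
+/// `host_trim_radix_blocks_msb` in `cast.cuh`).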
+pub(crate) unsafe fn cuda_backend_trim_radix_blocks_msb(
+    output: &mut CudaRadixCiphertext,
+    input: &CudaRadixCiphertext,
+    streams: &CudaStreams,
+) {
+    let mut input_degrees = input.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut input_noise_levels = input.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let mut output_degrees = output.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut output_noise_levels = output.info.blocks.iter().map(|b| b.noise_level.0).collect();
+
+    let mut cuda_ffi_output =
+        prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
+
+    let cuda_ffi_input = prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
+
+    trim_radix_blocks_msb_64(
+        &raw mut cuda_ffi_output,
+        &raw const cuda_ffi_input,
+        streams.ffi(),
+    );
+    update_noise_degree(output, &cuda_ffi_output);
+}
+
 /// # Safety
 ///
 /// - The data must not be moved or dropped while being used by the CUDA kernel.
@@ -8518,6 +8546,11 @@ pub(crate) unsafe fn cuda_backend_compute_equality_selectors(
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn cuda_backend_get_unchecked_match_value_or_size_on_gpu(
+    streams: &CudaStreams,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    big_lwe_dimension: LweDimension,
+    small_lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    grouping_factor: LweBskGroupingFactor,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    pbs_type: PBSType,
+    num_matches: u32,
+    num_input_blocks: u32,
+    num_match_packed_blocks: u32,
+    num_output_blocks: u32,
+    max_output_is_zero: bool,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) -> u64 {
+    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+
+    let size_tracker = unsafe {
+        scratch_cuda_unchecked_match_value_or_64(
+            streams.ffi(),
+            std::ptr::addr_of_mut!(mem_ptr),
+            glwe_dimension.0 as u32,
+            polynomial_size.0 as u32,
+            big_lwe_dimension.0 as u32,
+            small_lwe_dimension.0 as u32,
+            ks_level.0 as u32,
+            ks_base_log.0 as u32,
+            pbs_level.0 as u32,
+            pbs_base_log.0 as u32,
+            grouping_factor.0 as u32,
+            num_matches,
+            num_input_blocks,
+            num_match_packed_blocks,
+            num_output_blocks,
+            max_output_is_zero as u32,
+            message_modulus.0 as u32,
+            carry_modulus.0 as u32,
+            pbs_type as u32,
+            false,
+            noise_reduction_type as u32,
+        )
+    };
+
+    unsafe {
+        cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr))
+    };
+
+    size_tracker
+}
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - The data must not be moved or dropped while being used by the CUDA kernel.
+/// - This function assumes exclusive access to the passed data; violating this may lead to
+///   undefined behavior.
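+///
+/// Single entry point for the cast on the backend: it optionally runs a full
+/// carry propagation on `input`, then sign- or zero-extends (or trims) MSB
+/// blocks so that `output` ends up with exactly `target_num_blocks` blocks
+/// (see `host_cast_to_unsigned` in `cast.cuh`).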
+pub(crate) unsafe fn cuda_backend_cast_to_unsigned<B: Numeric>(
+    streams: &CudaStreams,
+    output: &mut CudaRadixCiphertext,
+    input: &mut CudaRadixCiphertext,
+    input_is_signed: bool,
+    requires_full_propagate: bool,
+    target_num_blocks: u32,
+    bootstrapping_key: &CudaVec<B>,
+    keyswitch_key: &CudaVec<u64>,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    big_lwe_dimension: LweDimension,
+    small_lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    pbs_type: PBSType,
+    grouping_factor: LweBskGroupingFactor,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) {
+    let message_modulus = input.info.blocks.first().unwrap().message_modulus;
+    let carry_modulus = input.info.blocks.first().unwrap().carry_modulus;
+    let num_input_blocks = input.d_blocks.lwe_ciphertext_count().0 as u32;
+
+    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
+
+    let mut input_degrees: Vec<u64> = input.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut input_noise_levels: Vec<u64> =
+        input.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let mut cuda_ffi_input =
+        prepare_cuda_radix_ffi(input, &mut input_degrees, &mut input_noise_levels);
+
+    let mut output_degrees: Vec<u64> = output.info.blocks.iter().map(|b| b.degree.0).collect();
+    let mut output_noise_levels: Vec<u64> =
+        output.info.blocks.iter().map(|b| b.noise_level.0).collect();
+    let mut cuda_ffi_output =
+        prepare_cuda_radix_ffi(output, &mut output_degrees, &mut output_noise_levels);
+
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+
+    scratch_cuda_cast_to_unsigned_64(
+        streams.ffi(),
+        std::ptr::addr_of_mut!(mem_ptr),
+        glwe_dimension.0 as u32,
+        polynomial_size.0 as u32,
+        big_lwe_dimension.0 as u32,
+        small_lwe_dimension.0 as u32,
+        ks_level.0 as u32,
+        ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        num_input_blocks,
+        target_num_blocks,
+        input_is_signed,
+        requires_full_propagate,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        true,
+        noise_reduction_type as u32,
+    );
+
+    cuda_cast_to_unsigned_64(
+        streams.ffi(),
+        &raw mut cuda_ffi_output,
+        &raw mut cuda_ffi_input,
+        mem_ptr,
+        target_num_blocks,
+        input_is_signed,
+        bootstrapping_key.ptr.as_ptr(),
+        keyswitch_key.ptr.as_ptr(),
+    );
+
+    cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
+
+    update_noise_degree(output, &cuda_ffi_output);
+    if requires_full_propagate {
+        update_noise_degree(input, &cuda_ffi_input);
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+/// # Safety
+///
+/// - The data must not be moved or dropped while being used by the CUDA kernel.
+/// - This function assumes exclusive access to the passed data; violating this may lead to
+///   undefined behavior.
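+///
+/// Backend flow (see `host_unchecked_match_value_or` in `vector_find.cuh`): an
+/// unchecked match produces the matched result and a boolean selector, the
+/// clear `h_or_value` is uploaded as a trivial radix ciphertext, and a final
+/// cmux picks between the two.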
+pub(crate) unsafe fn cuda_backend_unchecked_match_value_or<
+    T: UnsignedInteger,
+    B: Numeric,
+    R: CudaIntegerRadixCiphertext,
+>(
+    streams: &CudaStreams,
+    lwe_array_out: &mut R,
+    lwe_array_in_ct: &CudaRadixCiphertext,
+    h_match_inputs: &[u64],
+    h_match_outputs: &[u64],
+    h_or_value: &[u64],
+    num_matches: u32,
+    num_input_blocks: u32,
+    num_match_packed_blocks: u32,
+    num_final_blocks: u32,
+    max_output_is_zero: bool,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
+    bootstrapping_key: &CudaVec<B>,
+    keyswitch_key: &CudaVec<u64>,
+    glwe_dimension: GlweDimension,
+    polynomial_size: PolynomialSize,
+    big_lwe_dimension: LweDimension,
+    small_lwe_dimension: LweDimension,
+    ks_level: DecompositionLevelCount,
+    ks_base_log: DecompositionBaseLog,
+    pbs_level: DecompositionLevelCount,
+    pbs_base_log: DecompositionBaseLog,
+    pbs_type: PBSType,
+    grouping_factor: LweBskGroupingFactor,
+    ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
+) {
+    let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
+
+    let mut ffi_out_degrees: Vec<u64> = lwe_array_out
+        .as_ref()
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.degree.get())
+        .collect();
+    let mut ffi_out_noise_levels: Vec<u64> = lwe_array_out
+        .as_ref()
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.noise_level.0)
+        .collect();
+    let mut ffi_out_struct = prepare_cuda_radix_ffi(
+        lwe_array_out.as_ref(),
+        &mut ffi_out_degrees,
+        &mut ffi_out_noise_levels,
+    );
+
+    let mut ffi_in_ct_degrees: Vec<u64> = lwe_array_in_ct
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.degree.get())
+        .collect();
+    let mut ffi_in_ct_noise_levels: Vec<u64> = lwe_array_in_ct
+        .info
+        .blocks
+        .iter()
+        .map(|b| b.noise_level.0)
+        .collect();
+    let ffi_in_ct_struct = prepare_cuda_radix_ffi(
+        lwe_array_in_ct,
+        &mut ffi_in_ct_degrees,
+        &mut ffi_in_ct_noise_levels,
+    );
+
+    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
+
+    scratch_cuda_unchecked_match_value_or_64(
+        streams.ffi(),
+        std::ptr::addr_of_mut!(mem_ptr),
+        glwe_dimension.0 as u32,
+        polynomial_size.0 as u32,
+        big_lwe_dimension.0 as u32,
+        small_lwe_dimension.0 as u32,
+        ks_level.0 as u32,
+        ks_base_log.0 as u32,
+        pbs_level.0 as u32,
+        pbs_base_log.0 as u32,
+        grouping_factor.0 as u32,
+        num_matches,
+        num_input_blocks,
+        num_match_packed_blocks,
+        num_final_blocks,
+        max_output_is_zero as u32,
+        message_modulus.0 as u32,
+        carry_modulus.0 as u32,
+        pbs_type as u32,
+        true,
+        noise_reduction_type as u32,
+    );
+
+    cuda_unchecked_match_value_or_64(
+        streams.ffi(),
+        &raw mut ffi_out_struct,
+        &raw const ffi_in_ct_struct,
+        h_match_inputs.as_ptr(),
+        h_match_outputs.as_ptr(),
+        h_or_value.as_ptr(),
+        mem_ptr,
+        bootstrapping_key.ptr.as_ptr(),
+        keyswitch_key.ptr.as_ptr(),
+    );
+
+    cleanup_cuda_unchecked_match_value_or_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
+
+    update_noise_degree(lwe_array_out.as_mut(), &ffi_out_struct);
+}
diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs
index b4ccf6958..e76ae5b1b 100644
--- a/tfhe/src/integer/gpu/server_key/radix/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs
@@ -17,11 +17,11 @@ use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
 use crate::integer::gpu::server_key::CudaBootstrappingKey;
 use crate::integer::gpu::{
     cuda_backend_apply_bivariate_lut, cuda_backend_apply_many_univariate_lut,
-    cuda_backend_apply_univariate_lut, cuda_backend_compute_prefix_sum_hillis_steele,
-    cuda_backend_extend_radix_with_sign_msb,
+    cuda_backend_apply_univariate_lut, cuda_backend_cast_to_unsigned,
+    cuda_backend_compute_prefix_sum_hillis_steele, cuda_backend_extend_radix_with_sign_msb,
     cuda_backend_extend_radix_with_trivial_zero_blocks_msb, cuda_backend_full_propagate_assign,
     cuda_backend_noise_squashing, cuda_backend_propagate_single_carry_assign,
-    cuda_backend_trim_radix_blocks_lsb, CudaServerKey, PBSType,
+    cuda_backend_trim_radix_blocks_lsb, cuda_backend_trim_radix_blocks_msb, CudaServerKey, PBSType,
 };
 use crate::integer::server_key::radix_parallel::OutputFlag;
 use crate::shortint::ciphertext::{Degree, NoiseLevel};
@@ -584,31 +584,19 @@
         num_blocks: usize,
         streams: &CudaStreams,
     ) -> T {
+        let total_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
+
         if num_blocks == 0 {
             return ct.duplicate(streams);
         }
-        let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 - num_blocks;
-        let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus();
-        let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size();
-        let shift = new_num_blocks * lwe_size.0;
+        let new_num_blocks = total_blocks - num_blocks;
+        let mut trimmed_ct: T = self.create_trivial_zero_radix(new_num_blocks, streams);
 
-        let mut trimmed_ct_vec = CudaVec::new(new_num_blocks * lwe_size.0, streams, 0);
         unsafe {
-            trimmed_ct_vec.copy_src_range_gpu_to_gpu_async(
-                0..shift,
-                &ct.as_ref().d_blocks.0.d_vec,
-                streams,
-                0,
-            );
-            streams.synchronize();
+            cuda_backend_trim_radix_blocks_msb(trimmed_ct.as_mut(), ct.as_ref(), streams);
         }
-        let trimmed_ct_list = CudaLweCiphertextList::from_cuda_vec(
-            trimmed_ct_vec,
-            LweCiphertextCount(new_num_blocks),
-            ciphertext_modulus,
-        );
-        let trimmed_ct_info = ct.as_ref().info.after_trim_radix_blocks_msb(new_num_blocks);
-        T::from(CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info))
+
+        trimmed_ct
     }
 
     pub(crate) fn generate_lookup_table<F>(&self, f: F) -> LookupTableOwned
     where
@@ -1339,48 +1327,71 @@ where
         T: CudaIntegerRadixCiphertext,
     {
-        if !source.block_carries_are_empty() {
-            self.full_propagate_assign(&mut source, streams);
-        }
-        let current_num_blocks = source.as_ref().info.blocks.len();
-        if T::IS_SIGNED {
-            // Casting from signed to unsigned
-            // We have to trim or sign extend first
-            if target_num_blocks > current_num_blocks {
-                let num_blocks_to_add = target_num_blocks - current_num_blocks;
-                let signed_res: T =
-                    self.extend_radix_with_sign_msb(&source, num_blocks_to_add, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    signed_res.into_inner(),
-                )
-            } else {
-                let num_blocks_to_remove = current_num_blocks - target_num_blocks;
-                let signed_res = self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    signed_res.into_inner(),
-                )
-            }
-        } else {
-            // Casting from unsigned to unsigned, this is just about trimming/extending with zeros
-            if target_num_blocks > current_num_blocks {
-                let num_blocks_to_add = target_num_blocks - current_num_blocks;
-                let unsigned_res = self.extend_radix_with_trivial_zero_blocks_msb(
-                    &source,
-                    num_blocks_to_add,
-                    streams,
-                );
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    unsigned_res.into_inner(),
-                )
-            } else {
-                let num_blocks_to_remove = current_num_blocks - target_num_blocks;
-                let unsigned_res =
-                    self.trim_radix_blocks_msb(&source, num_blocks_to_remove, streams);
-                <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
-                    unsigned_res.into_inner(),
-                )
-            }
-        }
+        let mut result: CudaUnsignedRadixCiphertext =
+            self.create_trivial_zero_radix(target_num_blocks, streams);
+
+        let requires_full_propagate = !source.block_carries_are_empty();
+
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_cast_to_unsigned(
+                        streams,
+                        result.as_mut(),
+                        source.as_mut(),
+                        T::IS_SIGNED,
+                        requires_full_propagate,
+                        target_num_blocks as u32,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_cast_to_unsigned(
+                        streams,
+                        result.as_mut(),
+                        source.as_mut(),
+                        T::IS_SIGNED,
+                        requires_full_propagate,
+                        target_num_blocks as u32,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        None,
+                    );
+                }
+            }
+        }
+
+        result
     }
 
     /// Cast a `CudaUnsignedRadixCiphertext` or `CudaSignedRadixCiphertext` to a
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs
index 7b818a6ed..764d27bc1 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs
@@ -494,6 +494,46 @@ where
     }
 }
 
+/// For match_value_or operation
+impl<'a, F>
+    OpSequenceFunctionExecutor<(&'a RadixCiphertext, &'a MatchValues<u64>, u64), RadixCiphertext>
+    for OpSequenceGpuMultiDeviceFunctionExecutor<F>
+where
+    F: Fn(
+        &CudaServerKey,
+        &CudaUnsignedRadixCiphertext,
+        &MatchValues<u64>,
+        u64,
+        &CudaStreams,
+    ) -> CudaUnsignedRadixCiphertext,
+{
+    fn setup(
+        &mut self,
+        cks: &RadixClientKey,
+        sks: &CompressedServerKey,
+        seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
+    ) {
+        self.setup_from_gpu_keys(cks, sks, seeder);
+    }
+
+    fn execute(
+        &mut self,
+        input: (&'a RadixCiphertext, &'a MatchValues<u64>, u64),
+    ) -> RadixCiphertext {
+        let context = self
+            .context
+            .as_ref()
+            .expect("setup was not properly called");
+
+        let d_ctxt_1: CudaUnsignedRadixCiphertext =
+            CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
+
+        let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, input.2, &context.streams);
+
+        d_res.to_radix_ciphertext(&context.streams)
+    }
+}
+
 impl<'a, F>
     OpSequenceFunctionExecutor<
         (&'a RadixCiphertext, &'a RadixCiphertext),
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs
index e02a15ee7..c16ce276e 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs
@@ -5,9 +5,10 @@ use crate::integer::gpu::CudaServerKey;
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::{
     random_op_sequence_test, BinaryOpExecutor, ComparisonOpExecutor, DivRemOpExecutor,
-    Log2OpExecutor, MatchValueExecutor, OprfBoundedExecutor, OprfCustomRangeExecutor, OprfExecutor,
-    OverflowingOpExecutor, ScalarBinaryOpExecutor, ScalarComparisonOpExecutor,
-    ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor, SelectOpExecutor, UnaryOpExecutor,
+    Log2OpExecutor, MatchValueExecutor, MatchValueOrExecutor, OprfBoundedExecutor,
+    OprfCustomRangeExecutor, OprfExecutor, OverflowingOpExecutor, ScalarBinaryOpExecutor,
+    ScalarComparisonOpExecutor, ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor,
+    SelectOpExecutor, UnaryOpExecutor,
 };
 use crate::integer::server_key::radix_parallel::tests_long_run::{
     get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
@@ -54,6 +55,7 @@ pub(crate) fn random_op_sequence_test_init_gpu<P>(
     )],
     log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
     match_value_ops: &mut [(MatchValueExecutor, String)],
+    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
     oprf_ops: &mut [(OprfExecutor, String)],
     oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
     oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -88,6 +90,7 @@ where
         + scalar_div_rem_op.len()
         + log2_ops.len()
         + match_value_ops.len()
+        + match_value_or_ops.len()
         + oprf_ops.len()
         + oprf_bounded_ops.len()
         + oprf_custom_range_ops.len();
@@ -142,6 +145,9 @@ where
     for x in match_value_ops.iter_mut() {
        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
     }
+    for x in match_value_or_ops.iter_mut() {
+        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
+    }
     for x in oprf_ops.iter_mut() {
         x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
     }
@@ -317,6 +323,7 @@ where
     )];
 
     let mut match_value_ops: Vec<(MatchValueExecutor, String)> = vec![];
+    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![];
     let mut oprf_ops: Vec<(OprfExecutor, String)> = vec![];
     let mut oprf_bounded_ops: Vec<(OprfBoundedExecutor, String)> = vec![];
@@ -336,6 +343,7 @@ where
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
@@ -357,6 +365,7 @@ where
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
@@ -781,6 +790,14 @@ where
     let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
         vec![(Box::new(match_value_executor), "match_value".to_string())];
 
+    // Match Value Or Executor
+    let match_value_or_executor =
+        OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::match_value_or);
+    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
+        Box::new(match_value_or_executor),
+        "match_value_or".to_string(),
+    )];
+
     // OPRF Executors
     let oprf_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
         &CudaServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -816,6 +833,7 @@
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
@@ -837,6 +855,7 @@
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
diff --git a/tfhe/src/integer/gpu/server_key/radix/vector_find.rs b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs
index 16feebd1f..9c45e5944 100644
--- a/tfhe/src/integer/gpu/server_key/radix/vector_find.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs
@@ -9,8 +9,9 @@ use crate::integer::gpu::server_key::radix::CudaRadixCiphertext;
 use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
 use crate::integer::gpu::{
     cuda_backend_aggregate_one_hot_vector, cuda_backend_compute_equality_selectors,
-    cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_size_on_gpu,
-    cuda_backend_unchecked_match_value, PBSType,
+    cuda_backend_create_possible_results, cuda_backend_get_unchecked_match_value_or_size_on_gpu,
+    cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_match_value,
+    cuda_backend_unchecked_match_value_or, PBSType,
 };
 pub use crate::integer::server_key::radix_parallel::MatchValues;
 use crate::prelude::CastInto;
@@ -410,26 +411,229 @@ impl CudaServerKey {
         Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<u64>,
     {
         if matches.get_values().is_empty() {
-            let ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(
-                or_value,
-                self.num_blocks_to_represent_unsigned_value(or_value),
-                streams,
-            );
+            let num_blocks = self.num_blocks_to_represent_unsigned_value(or_value);
+            let ct: CudaUnsignedRadixCiphertext =
+                self.create_trivial_radix(or_value, num_blocks, streams);
             return ct;
         }
 
-        let (result, selected) = self.unchecked_match_value(ct, matches, streams);
-        // The result must have as many block to represent either the result of the match or the
-        // or_value
-        let num_blocks_to_represent_or_value =
-            self.num_blocks_to_represent_unsigned_value(or_value);
-        let num_blocks = (result.as_ref().d_blocks.lwe_ciphertext_count().0)
-            .max(num_blocks_to_represent_or_value);
-        let or_value: CudaUnsignedRadixCiphertext =
-            self.create_trivial_radix(or_value, num_blocks, streams);
-        let casted_result = self.cast_to_unsigned(result, num_blocks, streams);
-        // Note, this could be slightly faster when we have scalar if then_else
-        self.unchecked_if_then_else(&selected, &casted_result, &or_value, streams)
+        let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
+        let num_bits_in_message = self.message_modulus.0.ilog2();
+
+        let h_match_inputs: Vec<u64> = matches
+            .get_values()
+            .par_iter()
+            .map(|(input, _output)| *input)
+            .flat_map(|input_value| {
+                BlockDecomposer::new(input_value, num_bits_in_message)
+                    .take(num_input_blocks as usize)
+                    .map(|block_value| block_value.cast_into())
+                    .collect::<Vec<u64>>()
+            })
+            .collect::<Vec<u64>>();
+
+        let max_output_value_match = matches
+            .get_values()
+            .iter()
+            .copied()
+            .max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
+            .expect("luts is not empty at this point")
+            .1;
+
+        let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
+        let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
+        let final_num_blocks = num_blocks_match.max(num_blocks_or);
+
+        let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;
+
+        let h_match_outputs: Vec<u64> = matches
+            .get_values()
+            .par_iter()
+            .map(|(_input, output)| *output)
+            .flat_map(|output_value| {
+                BlockDecomposer::new(output_value, 2 * num_bits_in_message)
+                    .take(num_match_packed_blocks as usize)
+                    .map(|block_value| block_value.cast_into())
+                    .collect::<Vec<u64>>()
+            })
+            .collect::<Vec<u64>>();
+
+        let h_or_value: Vec<u64> = BlockDecomposer::new(or_value, num_bits_in_message)
+            .take(final_num_blocks)
+            .map(|block_value| block_value.cast_into())
+            .collect();
+
+        let mut result: CudaUnsignedRadixCiphertext =
+            self.create_trivial_zero_radix(final_num_blocks, streams);
+
+        let max_output_is_zero = max_output_value_match == Clear::ZERO;
+        let num_matches = matches.get_values().len() as u32;
+
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_unchecked_match_value_or(
+                        streams,
+                        &mut result,
+                        ct.as_ref(),
+                        &h_match_inputs,
+                        &h_match_outputs,
+                        &h_or_value,
+                        num_matches,
+                        num_input_blocks,
+                        num_match_packed_blocks,
+                        final_num_blocks as u32,
+                        max_output_is_zero,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_unchecked_match_value_or(
+                        streams,
+                        &mut result,
+                        ct.as_ref(),
+                        &h_match_inputs,
+                        &h_match_outputs,
+                        &h_or_value,
+                        num_matches,
+                        num_input_blocks,
+                        num_match_packed_blocks,
+                        final_num_blocks as u32,
+                        max_output_is_zero,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        self.key_switching_key
+                            .input_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        None,
+                    );
+                }
+            }
+        }
+
+        result
+    }
+
+    pub fn get_unchecked_match_value_or_size_on_gpu<Clear>(
+        &self,
+        ct: &CudaUnsignedRadixCiphertext,
+        matches: &MatchValues<Clear>,
+        or_value: Clear,
+        streams: &CudaStreams,
+    ) -> u64
+    where
+        Clear: UnsignedInteger + DecomposableInto<u64> + CastInto<u64>,
+    {
+        if matches.get_values().is_empty() {
+            return 0;
+        }
+
+        let num_input_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
+
+        let max_output_value_match = matches
+            .get_values()
+            .iter()
+            .copied()
+            .max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr))
+            .expect("luts is not empty at this point")
+            .1;
+
+        let num_blocks_match = self.num_blocks_to_represent_unsigned_value(max_output_value_match);
+        let num_blocks_or = self.num_blocks_to_represent_unsigned_value(or_value);
+        let final_num_blocks = num_blocks_match.max(num_blocks_or);
+
+        let num_match_packed_blocks = num_blocks_match.div_ceil(2) as u32;
+
+        let max_output_is_zero = max_output_value_match == Clear::ZERO;
+        let num_matches = matches.get_values().len() as u32;
+
+        match &self.bootstrapping_key {
+            CudaBootstrappingKey::Classic(d_bsk) => {
+                cuda_backend_get_unchecked_match_value_or_size_on_gpu(
+                    streams,
+                    d_bsk.glwe_dimension,
+                    d_bsk.polynomial_size,
+                    self.key_switching_key
+                        .input_key_lwe_size()
+                        .to_lwe_dimension(),
+                    self.key_switching_key
+                        .output_key_lwe_size()
+                        .to_lwe_dimension(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_bsk.decomp_level_count,
+                    d_bsk.decomp_base_log,
+                    LweBskGroupingFactor(0),
+                    self.message_modulus,
+                    self.carry_modulus,
+                    PBSType::Classical,
+                    num_matches,
+                    num_input_blocks,
+                    num_match_packed_blocks,
+                    final_num_blocks as u32,
+                    max_output_is_zero,
+                    d_bsk.ms_noise_reduction_configuration.as_ref(),
+                )
+            }
+            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                cuda_backend_get_unchecked_match_value_or_size_on_gpu(
+                    streams,
+                    d_multibit_bsk.glwe_dimension,
+                    d_multibit_bsk.polynomial_size,
+                    self.key_switching_key
+                        .input_key_lwe_size()
+                        .to_lwe_dimension(),
+                    self.key_switching_key
+                        .output_key_lwe_size()
+                        .to_lwe_dimension(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_multibit_bsk.decomp_level_count,
+                    d_multibit_bsk.decomp_base_log,
+                    d_multibit_bsk.grouping_factor,
+                    self.message_modulus,
+                    self.carry_modulus,
+                    PBSType::MultiBit,
+                    num_matches,
+                    num_input_blocks,
+                    num_match_packed_blocks,
+                    final_num_blocks as u32,
+                    max_output_is_zero,
+                    None,
+                )
+            }
+        }
+    }
 
     /// `match` an input value to an output value
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs
index 9af101154..a5fd6bacb 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs
@@ -267,9 +267,7 @@ impl<
         }
     }
 }
-}
 
-impl RandomOpSequenceDataGenerator {
     #[allow(clippy::manual_is_multiple_of)]
     pub(crate) fn gen_match_values(&mut self, key_to_match: u64) -> (MatchValues<u64>, u64, bool) {
         let mut pairings = Vec::new();
@@ -297,6 +295,19 @@ impl RandomOpSequenceDataGenerator {
             does_match,
         )
     }
+
+    pub(crate) fn gen_match_values_or(
+        &mut self,
+        key_to_match: u64,
+    ) -> (MatchValues<u64>, u64, u64) {
+        let (mv, match_val, does_match) = self.gen_match_values(key_to_match);
+
+        let or_value = self.deterministic_seeder.seed().0 as u64;
+
+        let expected = if does_match { match_val } else { or_value };
+
+        (mv, or_value, expected)
+    }
 }
 
 #[allow(clippy::too_many_arguments)]
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs
index d23b0733b..f3e3d1cf6 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs
@@ -79,6 +79,13 @@ pub(crate) type MatchValueExecutor = Box<
     >,
 >;
 
+pub(crate) type MatchValueOrExecutor = Box<
+    dyn for<'a> OpSequenceFunctionExecutor<
+        (&'a RadixCiphertext, &'a MatchValues<u64>, u64),
+        RadixCiphertext,
+    >,
+>;
+
 pub(crate) type OprfExecutor =
     Box<dyn for<'a> OpSequenceFunctionExecutor<(Seed, u64), RadixCiphertext>>;
 
@@ -474,6 +481,15 @@ where
     let mut match_value_ops: Vec<(MatchValueExecutor, String)> =
         vec![(Box::new(match_value_executor), "match_value".to_string())];
 
+    // Match Values Or Executor
+    let match_value_or_executor =
+        OpSequenceCpuFunctionExecutor::new(&ServerKey::match_value_or_parallelized);
+
+    let mut match_value_or_ops: Vec<(MatchValueOrExecutor, String)> = vec![(
+        Box::new(match_value_or_executor),
+        "match_value_or".to_string(),
+    )];
+
     // OPRF Executors
     let oprf_executor = OpSequenceCpuFunctionExecutor::new(
         &ServerKey::par_generate_oblivious_pseudo_random_unsigned_integer,
@@ -514,6 +530,7 @@
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
@@ -535,6 +552,7 @@
         &mut scalar_div_rem_op,
         &mut log2_ops,
         &mut match_value_ops,
+        &mut match_value_or_ops,
         &mut oprf_ops,
         &mut oprf_bounded_ops,
         &mut oprf_custom_range_ops,
@@ -572,6 +590,7 @@ pub(crate) fn random_op_sequence_test_init_cpu<P>(
     )],
     log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
     match_value_ops: &mut [(MatchValueExecutor, String)],
+    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
     oprf_ops: &mut [(OprfExecutor, String)],
     oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
     oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -602,6 +621,7 @@ where
         + scalar_div_rem_op.len()
         + log2_ops.len()
         + match_value_ops.len()
+        + match_value_or_ops.len()
         + oprf_ops.len()
         + oprf_bounded_ops.len()
         + oprf_custom_range_ops.len();
@@ -661,6 +681,9 @@ where
     for x in match_value_ops.iter_mut() {
         x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
     }
+    for x in match_value_or_ops.iter_mut() {
+        x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
+    }
     for x in oprf_ops.iter_mut() {
         x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
     }
@@ -706,6 +729,7 @@ pub(crate) fn random_op_sequence_test(
     )],
     log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
     match_value_ops: &mut [(MatchValueExecutor, String)],
+    match_value_or_ops: &mut [(MatchValueOrExecutor, String)],
     oprf_ops: &mut [(OprfExecutor, String)],
     oprf_bounded_ops: &mut [(OprfBoundedExecutor, String)],
     oprf_custom_range_ops: &mut [(OprfCustomRangeExecutor, String)],
@@ -729,7 +753,10 @@
         div_rem_op_range.end..div_rem_op_range.end + scalar_div_rem_op.len();
     let log2_ops_range = scalar_div_rem_op_range.end..scalar_div_rem_op_range.end + log2_ops.len();
     let match_value_ops_range = log2_ops_range.end..log2_ops_range.end + match_value_ops.len();
-    let oprf_ops_range = match_value_ops_range.end..match_value_ops_range.end + oprf_ops.len();
+    let match_value_or_ops_range =
+        match_value_ops_range.end..match_value_ops_range.end + match_value_or_ops.len();
+    let oprf_ops_range =
+        match_value_or_ops_range.end..match_value_or_ops_range.end + oprf_ops.len();
     let oprf_bounded_ops_range = oprf_ops_range.end..oprf_ops_range.end + oprf_bounded_ops.len();
     let oprf_custom_range_ops_range =
         oprf_bounded_ops_range.end..oprf_bounded_ops_range.end + oprf_custom_range_ops.len();
@@ -1149,6 +1176,35 @@ pub(crate) fn random_op_sequence_test(
             operand.p,
             operand.p,
         );
+    } else if match_value_or_ops_range.contains(&i) {
+        let index = i - match_value_or_ops_range.start;
+        let (match_value_or_executor, fn_name) = &mut match_value_or_ops[index];
+        let operand = datagen.gen_op_single_operand(idx, fn_name);
+
+        let (match_values, or_value, expected_value) = datagen.gen_match_values_or(operand.p);
+
+        println!("{idx}: MatchValuesOr generated. Expected: {expected_value}");
+
+        let res = match_value_or_executor.execute((&operand.c, &match_values, or_value));
+        // Determinism check
+        let res_1 = match_value_or_executor.execute((&operand.c, &match_values, or_value));
+
+        let decrypted_res: u64 = cks.decrypt(&res);
+
+        let res_casted = sks.cast_to_unsigned(res.clone(), NB_CTXT_LONG_RUN);
+        datagen.put_op_result_random_side(expected_value, &res_casted, fn_name, idx);
+
+        sanity_check_op_sequence_result_u64(
+            idx,
+            fn_name,
+            fn_index,
+            &res,
+            &res_1,
+            decrypted_res,
+            expected_value,
+            operand.p,
+            operand.p,
+        );
     } else if oprf_ops_range.contains(&i) {
         let index = i - oprf_ops_range.start;
         let (op_executor, fn_name) = &mut oprf_ops[index];