diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h index ca5503a96..acd7bcfd2 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h @@ -884,24 +884,6 @@ void cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams, void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams, int8_t **mem_ptr_void); -uint64_t scratch_cuda_compute_final_index_from_selectors_64( - CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, - uint32_t polynomial_size, uint32_t big_lwe_dimension, - uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, - uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, - uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus, - uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, - PBS_MS_REDUCTION_T noise_reduction_type); - -void cuda_compute_final_index_from_selectors_64( - CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct, - CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors, - uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem, - void *const *bsks, void *const *ksks); - -void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams, - int8_t **mem_ptr_void); - uint64_t scratch_cuda_unchecked_index_in_clears_64( CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t big_lwe_dimension, @@ -1002,6 +984,26 @@ void cuda_unchecked_index_of_64(CudaStreamsFFI streams, void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams, int8_t **mem_ptr_void); + +uint64_t scratch_cuda_unchecked_index_of_clear_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t 
pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); + +void cuda_unchecked_index_of_clear_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct, + CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs, + const void *d_scalar_blocks, bool is_scalar_obviously_bigger, + uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks, + uint32_t num_blocks_index, int8_t *mem, void *const *bsks, + void *const *ksks); + +void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void); } // extern C #endif // CUDA_INTEGER_H diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h b/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h index 55bc89be7..d1776140f 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/vector_find.h @@ -101,7 +101,7 @@ template struct int_equality_selectors_buffer { size_tracker, allocate_gpu_memory); this->reduction_buffers[j] = new int_comparison_buffer( - sub_streams[j], COMPARISON_TYPE::EQ, params, num_blocks, false, + streams, COMPARISON_TYPE::EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } } @@ -461,14 +461,14 @@ template struct int_aggregate_one_hot_buffer { delete this->carry_extract_lut; for (uint32_t i = 0; i < num_streams; i++) { - release_radix_ciphertext_async( - sub_streams[i].stream(0), sub_streams[i].gpu_index(0), - this->partial_aggregated_vectors[i], this->allocate_gpu_memory); + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->partial_aggregated_vectors[i], + this->allocate_gpu_memory); delete this->partial_aggregated_vectors[i]; - release_radix_ciphertext_async( - sub_streams[i].stream(0), 
sub_streams[i].gpu_index(0), - this->partial_temp_vectors[i], this->allocate_gpu_memory); + release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), + this->partial_temp_vectors[i], + this->allocate_gpu_memory); delete this->partial_temp_vectors[i]; } delete[] partial_aggregated_vectors; @@ -752,7 +752,7 @@ template struct int_unchecked_contains_buffer { this->eq_buffers = new int_comparison_buffer *[num_streams]; for (uint32_t i = 0; i < num_streams; i++) { this->eq_buffers[i] = new int_comparison_buffer( - sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory, + streams, EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } @@ -853,7 +853,7 @@ template struct int_unchecked_contains_clear_buffer { this->eq_buffers = new int_comparison_buffer *[num_streams]; for (uint32_t i = 0; i < num_streams; i++) { this->eq_buffers[i] = new int_comparison_buffer( - sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory, + streams, EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } @@ -1272,7 +1272,7 @@ template struct int_unchecked_first_index_of_clear_buffer { this->eq_buffers = new int_comparison_buffer *[num_streams]; for (uint32_t i = 0; i < num_streams; i++) { this->eq_buffers[i] = new int_comparison_buffer( - sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory, + streams, EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } @@ -1496,7 +1496,7 @@ template struct int_unchecked_first_index_of_buffer { this->eq_buffers = new int_comparison_buffer *[num_streams]; for (uint32_t i = 0; i < num_streams; i++) { this->eq_buffers[i] = new int_comparison_buffer( - sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory, + streams, EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } @@ -1690,7 +1690,94 @@ template struct int_unchecked_index_of_buffer { this->eq_buffers = new int_comparison_buffer *[num_streams]; for (uint32_t i = 0; i < num_streams; i++) 
{ this->eq_buffers[i] = new int_comparison_buffer( - sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory, + streams, EQ, params, num_blocks, false, allocate_gpu_memory, + size_tracker); + } + + this->final_index_buf = new int_final_index_from_selectors_buffer( + streams, params, num_inputs, num_blocks_index, allocate_gpu_memory, + size_tracker); + } + + void release(CudaStreams streams) { + for (uint32_t i = 0; i < num_streams; i++) { + eq_buffers[i]->release(streams); + delete eq_buffers[i]; + } + delete[] eq_buffers; + + this->final_index_buf->release(streams); + delete this->final_index_buf; + + cuda_event_destroy(incoming_event, streams.gpu_index(0)); + + uint32_t num_gpus = active_streams.count(); + for (uint j = 0; j < num_streams; j++) { + for (uint k = 0; k < num_gpus; k++) { + cuda_event_destroy(outgoing_events[j * num_gpus + k], + active_streams.gpu_index(k)); + } + } + delete[] outgoing_events; + + for (uint32_t i = 0; i < num_streams; i++) { + sub_streams[i].release(); + } + delete[] sub_streams; + + cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0)); + } +}; + +template struct int_unchecked_index_of_clear_buffer { + int_radix_params params; + bool allocate_gpu_memory; + uint32_t num_inputs; + + int_comparison_buffer **eq_buffers; + int_final_index_from_selectors_buffer *final_index_buf; + + CudaStreams active_streams; + CudaStreams *sub_streams; + cudaEvent_t incoming_event; + cudaEvent_t *outgoing_events; + uint32_t num_streams; + + int_unchecked_index_of_clear_buffer(CudaStreams streams, + int_radix_params params, + uint32_t num_inputs, uint32_t num_blocks, + uint32_t num_blocks_index, + bool allocate_gpu_memory, + uint64_t &size_tracker) { + this->params = params; + this->allocate_gpu_memory = allocate_gpu_memory; + this->num_inputs = num_inputs; + + uint32_t num_streams_to_use = + std::min((uint32_t)MAX_STREAMS_FOR_VECTOR_FIND, num_inputs); + if (num_streams_to_use == 0) + num_streams_to_use = 1; + + this->num_streams = 
num_streams_to_use; + this->active_streams = streams.active_gpu_subset(num_blocks); + uint32_t num_gpus = active_streams.count(); + + incoming_event = cuda_create_event(streams.gpu_index(0)); + sub_streams = new CudaStreams[num_streams_to_use]; + outgoing_events = new cudaEvent_t[num_streams_to_use * num_gpus]; + + for (uint32_t i = 0; i < num_streams_to_use; i++) { + sub_streams[i].create_on_same_gpus(active_streams); + for (uint32_t j = 0; j < num_gpus; j++) { + outgoing_events[i * num_gpus + j] = + cuda_create_event(active_streams.gpu_index(j)); + } + } + + this->eq_buffers = new int_comparison_buffer *[num_streams]; + for (uint32_t i = 0; i < num_streams; i++) { + this->eq_buffers[i] = new int_comparison_buffer( + streams, EQ, params, num_blocks, false, allocate_gpu_memory, size_tracker); } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu index 9d74f8981..44cb1e821 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cu @@ -228,49 +228,6 @@ void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams, *mem_ptr_void = nullptr; } -uint64_t scratch_cuda_compute_final_index_from_selectors_64( - CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, - uint32_t polynomial_size, uint32_t big_lwe_dimension, - uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, - uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, - uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus, - uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, - PBS_MS_REDUCTION_T noise_reduction_type) { - - int_radix_params params(pbs_type, glwe_dimension, polynomial_size, - big_lwe_dimension, small_lwe_dimension, ks_level, - ks_base_log, pbs_level, pbs_base_log, grouping_factor, - message_modulus, carry_modulus, noise_reduction_type); - - return 
scratch_cuda_compute_final_index_from_selectors( - CudaStreams(streams), - (int_final_index_from_selectors_buffer **)mem_ptr, params, - num_inputs, num_blocks_index, allocate_gpu_memory); -} - -void cuda_compute_final_index_from_selectors_64( - CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct, - CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors, - uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem, - void *const *bsks, void *const *ksks) { - - host_compute_final_index_from_selectors( - CudaStreams(streams), index_ct, match_ct, selectors, num_inputs, - num_blocks_index, (int_final_index_from_selectors_buffer *)mem, - bsks, (uint64_t *const *)ksks); -} - -void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams, - int8_t **mem_ptr_void) { - int_final_index_from_selectors_buffer *mem_ptr = - (int_final_index_from_selectors_buffer *)(*mem_ptr_void); - - mem_ptr->release(CudaStreams(streams)); - - delete mem_ptr; - *mem_ptr_void = nullptr; -} - uint64_t scratch_cuda_unchecked_index_in_clears_64( CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t big_lwe_dimension, @@ -500,3 +457,50 @@ void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams, delete mem_ptr; *mem_ptr_void = nullptr; } + +uint64_t scratch_cuda_unchecked_index_of_clear_64( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, 
pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_unchecked_index_of_clear( + CudaStreams(streams), + (int_unchecked_index_of_clear_buffer **)mem_ptr, params, + num_inputs, num_blocks, num_blocks_index, allocate_gpu_memory); +} + +void cuda_unchecked_index_of_clear_64( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct, + CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs, + const void *d_scalar_blocks, bool is_scalar_obviously_bigger, + uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks, + uint32_t num_blocks_index, int8_t *mem, void *const *bsks, + void *const *ksks) { + + host_unchecked_index_of_clear( + CudaStreams(streams), index_ct, match_ct, inputs, + (const uint64_t *)d_scalar_blocks, is_scalar_obviously_bigger, num_inputs, + num_blocks, num_scalar_blocks, num_blocks_index, + (int_unchecked_index_of_clear_buffer *)mem, bsks, + (uint64_t *const *)ksks); +} + +void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams, + int8_t **mem_ptr_void) { + int_unchecked_index_of_clear_buffer *mem_ptr = + (int_unchecked_index_of_clear_buffer *)(*mem_ptr_void); + + mem_ptr->release(CudaStreams(streams)); + + delete mem_ptr; + *mem_ptr_void = nullptr; +} diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh index dbf00a593..c42a4bf96 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh @@ -5,6 +5,7 @@ #include "integer/comparison.cuh" #include "integer/integer.cuh" #include "integer/radix_ciphertext.cuh" +#include "integer/scalar_comparison.cuh" #include "integer/vector_find.h" template @@ -1110,3 +1111,96 @@ __host__ void host_unchecked_index_of( mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks, num_inputs); } + +template +uint64_t 
scratch_cuda_unchecked_index_of_clear( + CudaStreams streams, int_unchecked_index_of_clear_buffer **mem_ptr, + int_radix_params params, uint32_t num_inputs, uint32_t num_blocks, + uint32_t num_blocks_index, bool allocate_gpu_memory) { + + uint64_t size_tracker = 0; + *mem_ptr = new int_unchecked_index_of_clear_buffer( + streams, params, num_inputs, num_blocks, num_blocks_index, + allocate_gpu_memory, size_tracker); + + return size_tracker; +} + +template +__host__ void host_unchecked_index_of_clear( + CudaStreams streams, CudaRadixCiphertextFFI *index_ct, + CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs, + const Torus *d_scalar_blocks, bool is_scalar_obviously_bigger, + uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks, + uint32_t num_blocks_index, + int_unchecked_index_of_clear_buffer *mem_ptr, void *const *bsks, + Torus *const *ksks) { + + CudaRadixCiphertextFFI *packed_selectors = + mem_ptr->final_index_buf->packed_selectors; + + if (is_scalar_obviously_bigger) { + set_zero_radix_ciphertext_slice_async( + streams.stream(0), streams.gpu_index(0), packed_selectors, 0, + num_inputs); + } else { + cuda_event_record(mem_ptr->incoming_event, streams.stream(0), + streams.gpu_index(0)); + + for (uint32_t j = 0; j < mem_ptr->num_streams; j++) { + for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) { + cuda_stream_wait_event(mem_ptr->sub_streams[j].stream(i), + mem_ptr->incoming_event, + mem_ptr->sub_streams[j].gpu_index(i)); + } + } + + uint32_t num_streams = mem_ptr->num_streams; + uint32_t num_gpus = mem_ptr->active_streams.count(); + + for (uint32_t i = 0; i < num_inputs; i++) { + uint32_t stream_idx = i % num_streams; + CudaStreams current_stream = mem_ptr->sub_streams[stream_idx]; + + CudaRadixCiphertextFFI const *input_ct = &inputs[i]; + + CudaRadixCiphertextFFI current_selector_dest; + as_radix_ciphertext_slice(&current_selector_dest, packed_selectors, + i, i + 1); + + host_scalar_equality_check( + current_stream, 
&current_selector_dest, input_ct, d_scalar_blocks, + mem_ptr->eq_buffers[stream_idx], bsks, (Torus **)ksks, num_blocks, + num_scalar_blocks); + } + + for (uint32_t j = 0; j < mem_ptr->num_streams; j++) { + for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) { + cuda_event_record(mem_ptr->outgoing_events[j * num_gpus + i], + mem_ptr->sub_streams[j].stream(i), + mem_ptr->sub_streams[j].gpu_index(i)); + cuda_stream_wait_event(streams.stream(0), + mem_ptr->outgoing_events[j * num_gpus + i], + streams.gpu_index(0)); + } + } + } + + uint32_t packed_len = (num_blocks_index + 1) / 2; + + host_create_possible_results( + streams, mem_ptr->final_index_buf->possible_results_ct_list, + mem_ptr->final_index_buf->unpacked_selectors, num_inputs, + (const uint64_t *)mem_ptr->final_index_buf->h_indices, packed_len, + mem_ptr->final_index_buf->possible_results_buf, bsks, ksks); + + host_aggregate_one_hot_vector( + streams, index_ct, mem_ptr->final_index_buf->possible_results_ct_list, + num_inputs, packed_len, mem_ptr->final_index_buf->aggregate_buf, bsks, + ksks); + + host_integer_is_at_least_one_comparisons_block_true( + streams, match_ct, packed_selectors, + mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks, + num_inputs); +} diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index 22406f094..489e69666 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -1922,47 +1922,6 @@ unsafe extern "C" { mem_ptr_void: *mut *mut i8, ); } -unsafe extern "C" { - pub fn scratch_cuda_compute_final_index_from_selectors_64( - streams: CudaStreamsFFI, - mem_ptr: *mut *mut i8, - glwe_dimension: u32, - polynomial_size: u32, - big_lwe_dimension: u32, - small_lwe_dimension: u32, - ks_level: u32, - ks_base_log: u32, - pbs_level: u32, - pbs_base_log: u32, - grouping_factor: u32, - num_inputs: u32, - num_blocks_index: u32, - message_modulus: u32, - carry_modulus: u32, - pbs_type: 
PBS_TYPE, - allocate_gpu_memory: bool, - noise_reduction_type: PBS_MS_REDUCTION_T, - ) -> u64; -} -unsafe extern "C" { - pub fn cuda_compute_final_index_from_selectors_64( - streams: CudaStreamsFFI, - index_ct: *mut CudaRadixCiphertextFFI, - match_ct: *mut CudaRadixCiphertextFFI, - selectors: *const CudaRadixCiphertextFFI, - num_inputs: u32, - num_blocks_index: u32, - mem: *mut i8, - bsks: *const *mut ffi::c_void, - ksks: *const *mut ffi::c_void, - ); -} -unsafe extern "C" { - pub fn cleanup_cuda_compute_final_index_from_selectors_64( - streams: CudaStreamsFFI, - mem_ptr_void: *mut *mut i8, - ); -} unsafe extern "C" { pub fn scratch_cuda_unchecked_index_in_clears_64( streams: CudaStreamsFFI, @@ -2181,6 +2140,52 @@ unsafe extern "C" { unsafe extern "C" { pub fn cleanup_cuda_unchecked_index_of_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } +unsafe extern "C" { + pub fn scratch_cuda_unchecked_index_of_clear_64( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_inputs: u32, + num_blocks: u32, + num_blocks_index: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_unchecked_index_of_clear_64( + streams: CudaStreamsFFI, + index_ct: *mut CudaRadixCiphertextFFI, + match_ct: *mut CudaRadixCiphertextFFI, + inputs: *const CudaRadixCiphertextFFI, + d_scalar_blocks: *const ffi::c_void, + is_scalar_obviously_bigger: bool, + num_inputs: u32, + num_blocks: u32, + num_scalar_blocks: u32, + num_blocks_index: u32, + mem: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} +unsafe extern "C" { + pub fn cleanup_cuda_unchecked_index_of_clear_64( + streams: CudaStreamsFFI, + mem_ptr_void: *mut *mut 
i8, + ); +} unsafe extern "C" { pub fn scratch_cuda_integer_compress_radix_ciphertext_64( streams: CudaStreamsFFI, diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 98b47a6ca..7db17ac79 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -9321,129 +9321,6 @@ pub(crate) unsafe fn cuda_backend_unchecked_is_in_clears< update_noise_degree(&mut output.0.ciphertext, &ffi_output); } -#[allow(clippy::too_many_arguments)] -/// # Safety -/// -/// - The data must not be moved or dropped while being used by the CUDA kernel. -/// - This function assumes exclusive access to the passed data; violating this may lead to -/// undefined behavior. -pub(crate) unsafe fn cuda_backend_compute_final_index_from_selectors< - T: UnsignedInteger, - B: Numeric, ->( - streams: &CudaStreams, - index_ct: &mut CudaRadixCiphertext, - match_ct: &mut CudaBooleanBlock, - selectors: &[CudaBooleanBlock], - bootstrapping_key: &CudaVec, - keyswitch_key: &CudaVec, - message_modulus: MessageModulus, - carry_modulus: CarryModulus, - glwe_dimension: GlweDimension, - polynomial_size: PolynomialSize, - big_lwe_dimension: LweDimension, - small_lwe_dimension: LweDimension, - ks_level: DecompositionLevelCount, - ks_base_log: DecompositionBaseLog, - pbs_level: DecompositionLevelCount, - pbs_base_log: DecompositionBaseLog, - pbs_type: PBSType, - grouping_factor: LweBskGroupingFactor, - ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, -) { - assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0)); - assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0)); - - let num_inputs = selectors.len() as u32; - let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32; - - let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration); - - let mut index_degrees = index_ct - .info - .blocks - .iter() - .map(|b| b.degree.get()) - .collect(); - let mut index_noise_levels = 
index_ct - .info - .blocks - .iter() - .map(|b| b.noise_level.0) - .collect(); - let mut ffi_index = - prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels); - - let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()]; - let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0]; - let mut ffi_match = prepare_cuda_radix_ffi( - &match_ct.0.ciphertext, - &mut match_degrees, - &mut match_noise_levels, - ); - - let mut ffi_selectors_degrees: Vec> = Vec::with_capacity(selectors.len()); - let mut ffi_selectors_noise_levels: Vec> = Vec::with_capacity(selectors.len()); - let ffi_selectors: Vec = selectors - .iter() - .map(|ct| { - let degrees = vec![ct.0.ciphertext.info.blocks[0].degree.get()]; - let noise_levels = vec![ct.0.ciphertext.info.blocks[0].noise_level.0]; - ffi_selectors_degrees.push(degrees); - ffi_selectors_noise_levels.push(noise_levels); - - prepare_cuda_radix_ffi( - &ct.0.ciphertext, - ffi_selectors_degrees.last_mut().unwrap(), - ffi_selectors_noise_levels.last_mut().unwrap(), - ) - }) - .collect(); - - let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - - scratch_cuda_compute_final_index_from_selectors_64( - streams.ffi(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - num_inputs, - num_blocks_index, - message_modulus.0 as u32, - carry_modulus.0 as u32, - pbs_type as u32, - true, - noise_reduction_type as u32, - ); - - cuda_compute_final_index_from_selectors_64( - streams.ffi(), - &raw mut ffi_index, - &raw mut ffi_match, - ffi_selectors.as_ptr(), - num_inputs, - num_blocks_index, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); - - cleanup_cuda_compute_final_index_from_selectors_64( - streams.ffi(), - 
std::ptr::addr_of_mut!(mem_ptr), - ); - - update_noise_degree(index_ct, &ffi_index); - update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match); -} - #[allow(clippy::too_many_arguments)] /// # Safety /// @@ -10152,3 +10029,158 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of< update_noise_degree(index_ct, &ffi_index); update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match); } + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - The data must not be moved or dropped while being used by the CUDA kernel. +/// - This function assumes exclusive access to the passed data; violating this may lead to +/// undefined behavior. +pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear< + T: UnsignedInteger, + B: Numeric, + C: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, +>( + streams: &CudaStreams, + index_ct: &mut CudaRadixCiphertext, + match_ct: &mut CudaBooleanBlock, + inputs: &[C], + clear: Clear, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + big_lwe_dimension: LweDimension, + small_lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, + ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>, +) { + assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0)); + assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0)); + + let num_inputs = inputs.len() as u32; + let num_blocks_in_ct = inputs[0].as_ref().d_blocks.lwe_ciphertext_count().0 as u32; + let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32; + + let mut scalar_blocks = + BlockDecomposer::with_early_stop_at_zero(clear, message_modulus.0.ilog2()) + .iter_as::() + 
.collect::>(); + + let is_scalar_obviously_bigger = scalar_blocks + .get(num_blocks_in_ct as usize..) + .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0)); + + scalar_blocks.truncate(num_blocks_in_ct as usize); + let num_scalar_blocks = scalar_blocks.len() as u32; + + let d_scalar_blocks: CudaVec = CudaVec::from_cpu_async(&scalar_blocks, streams, 0); + + let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration); + + let mut index_degrees = index_ct + .info + .blocks + .iter() + .map(|b| b.degree.get()) + .collect(); + let mut index_noise_levels = index_ct + .info + .blocks + .iter() + .map(|b| b.noise_level.0) + .collect(); + let mut ffi_index = + prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels); + + let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()]; + let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0]; + let mut ffi_match = prepare_cuda_radix_ffi( + &match_ct.0.ciphertext, + &mut match_degrees, + &mut match_noise_levels, + ); + + let mut ffi_inputs_degrees: Vec> = Vec::with_capacity(inputs.len()); + let mut ffi_inputs_noise_levels: Vec> = Vec::with_capacity(inputs.len()); + let ffi_inputs: Vec = inputs + .iter() + .map(|ct| { + let degrees = ct + .as_ref() + .info + .blocks + .iter() + .map(|b| b.degree.get()) + .collect(); + let noise_levels = ct + .as_ref() + .info + .blocks + .iter() + .map(|b| b.noise_level.0) + .collect(); + ffi_inputs_degrees.push(degrees); + ffi_inputs_noise_levels.push(noise_levels); + + prepare_cuda_radix_ffi( + ct.as_ref(), + ffi_inputs_degrees.last_mut().unwrap(), + ffi_inputs_noise_levels.last_mut().unwrap(), + ) + }) + .collect(); + + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + + scratch_cuda_unchecked_index_of_clear_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + 
small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_inputs, + num_blocks_in_ct, + num_blocks_index, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); + + cuda_unchecked_index_of_clear_64( + streams.ffi(), + &raw mut ffi_index, + &raw mut ffi_match, + ffi_inputs.as_ptr(), + d_scalar_blocks.as_c_ptr(0), + is_scalar_obviously_bigger, + num_inputs, + num_blocks_in_ct, + num_scalar_blocks, + num_blocks_index, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + + cleanup_cuda_unchecked_index_of_clear_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); + + update_noise_degree(index_ct, &ffi_index); + update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/vector_find.rs b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs index 26a2b0546..5c1a4d580 100644 --- a/tfhe/src/integer/gpu/server_key/radix/vector_find.rs +++ b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs @@ -5,14 +5,13 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; use crate::integer::gpu::{ - cuda_backend_compute_final_index_from_selectors, cuda_backend_get_unchecked_match_value_or_size_on_gpu, cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_contains, cuda_backend_unchecked_contains_clear, cuda_backend_unchecked_first_index_in_clears, cuda_backend_unchecked_first_index_of, cuda_backend_unchecked_first_index_of_clear, cuda_backend_unchecked_index_in_clears, cuda_backend_unchecked_index_of, - cuda_backend_unchecked_is_in_clears, cuda_backend_unchecked_match_value, - cuda_backend_unchecked_match_value_or, PBSType, + 
cuda_backend_unchecked_index_of_clear, cuda_backend_unchecked_is_in_clears, + cuda_backend_unchecked_match_value, cuda_backend_unchecked_match_value_or, PBSType, }; pub use crate::integer::server_key::radix_parallel::MatchValues; use crate::prelude::CastInto; @@ -1490,12 +1489,80 @@ impl CudaServerKey { ); return (trivial_ct, trivial_bool); } - let selectors = cts - .iter() - .map(|ct| self.scalar_eq(ct, clear, streams)) - .collect::>(); - self.compute_final_index_from_selectors(&selectors, streams) + let num_inputs = cts.len(); + let num_blocks_index = + (num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let mut index_ct: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(num_blocks_index, streams); + + let trivial_bool = + self.create_trivial_zero_radix::(1, streams); + let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner()); + + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_index_of_clear( + streams, + index_ct.as_mut(), + &mut match_ct, + cts, + clear, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_index_of_clear( + streams, + index_ct.as_mut(), + &mut match_ct, + cts, + clear, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, 
+ self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } + } + + (index_ct, match_ct) } /// Returns the encrypted index of the of clear `value` in the ciphertext slice @@ -1932,82 +1999,4 @@ impl CudaServerKey { }; self.unchecked_first_index_of(cts, value, streams) } - - fn compute_final_index_from_selectors( - &self, - selectors: &[CudaBooleanBlock], - streams: &CudaStreams, - ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) { - let num_inputs = selectors.len(); - let num_blocks_index = - (num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; - - let mut index_ct: CudaUnsignedRadixCiphertext = - self.create_trivial_zero_radix(num_blocks_index, streams); - - let trivial_bool = - self.create_trivial_zero_radix::(1, streams); - let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner()); - - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_compute_final_index_from_selectors( - streams, - index_ct.as_mut(), - &mut match_ct, - selectors, - &d_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - 
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_compute_final_index_from_selectors( - streams, - index_ct.as_mut(), - &mut match_ct, - selectors, - &d_multibit_bsk.d_vec, - &self.key_switching_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - self.key_switching_key - .input_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key - .output_key_lwe_size() - .to_lwe_dimension(), - self.key_switching_key.decomposition_level_count(), - self.key_switching_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); - } - } - } - - (index_ct, match_ct) - } }