mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
refactor(gpu): unchecked_index_of_clear to backend
This commit is contained in:
committed by
Agnès Leroy
parent
a731be3878
commit
3cfbaa40c3
@@ -884,24 +884,6 @@ void cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
|
||||
void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_compute_final_index_from_selectors_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_compute_final_index_from_selectors_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
|
||||
uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem,
|
||||
void *const *bsks, void *const *ksks);
|
||||
|
||||
void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_unchecked_index_in_clears_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
@@ -1002,6 +984,26 @@ void cuda_unchecked_index_of_64(CudaStreamsFFI streams,
|
||||
|
||||
void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
|
||||
uint64_t scratch_cuda_unchecked_index_of_clear_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
|
||||
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
|
||||
|
||||
void cuda_unchecked_index_of_clear_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
|
||||
const void *d_scalar_blocks, bool is_scalar_obviously_bigger,
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
|
||||
uint32_t num_blocks_index, int8_t *mem, void *const *bsks,
|
||||
void *const *ksks);
|
||||
|
||||
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void);
|
||||
} // extern C
|
||||
|
||||
#endif // CUDA_INTEGER_H
|
||||
|
||||
@@ -101,7 +101,7 @@ template <typename Torus> struct int_equality_selectors_buffer {
|
||||
size_tracker, allocate_gpu_memory);
|
||||
|
||||
this->reduction_buffers[j] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[j], COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
streams, COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
}
|
||||
}
|
||||
@@ -461,14 +461,14 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
|
||||
delete this->carry_extract_lut;
|
||||
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
release_radix_ciphertext_async(
|
||||
sub_streams[i].stream(0), sub_streams[i].gpu_index(0),
|
||||
this->partial_aggregated_vectors[i], this->allocate_gpu_memory);
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->partial_aggregated_vectors[i],
|
||||
this->allocate_gpu_memory);
|
||||
delete this->partial_aggregated_vectors[i];
|
||||
|
||||
release_radix_ciphertext_async(
|
||||
sub_streams[i].stream(0), sub_streams[i].gpu_index(0),
|
||||
this->partial_temp_vectors[i], this->allocate_gpu_memory);
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
this->partial_temp_vectors[i],
|
||||
this->allocate_gpu_memory);
|
||||
delete this->partial_temp_vectors[i];
|
||||
}
|
||||
delete[] partial_aggregated_vectors;
|
||||
@@ -752,7 +752,7 @@ template <typename Torus> struct int_unchecked_contains_buffer {
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
@@ -853,7 +853,7 @@ template <typename Torus> struct int_unchecked_contains_clear_buffer {
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
@@ -1272,7 +1272,7 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
@@ -1496,7 +1496,7 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
@@ -1690,7 +1690,94 @@ template <typename Torus> struct int_unchecked_index_of_buffer {
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
this->final_index_buf = new int_final_index_from_selectors_buffer<Torus>(
|
||||
streams, params, num_inputs, num_blocks_index, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
void release(CudaStreams streams) {
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
eq_buffers[i]->release(streams);
|
||||
delete eq_buffers[i];
|
||||
}
|
||||
delete[] eq_buffers;
|
||||
|
||||
this->final_index_buf->release(streams);
|
||||
delete this->final_index_buf;
|
||||
|
||||
cuda_event_destroy(incoming_event, streams.gpu_index(0));
|
||||
|
||||
uint32_t num_gpus = active_streams.count();
|
||||
for (uint j = 0; j < num_streams; j++) {
|
||||
for (uint k = 0; k < num_gpus; k++) {
|
||||
cuda_event_destroy(outgoing_events[j * num_gpus + k],
|
||||
active_streams.gpu_index(k));
|
||||
}
|
||||
}
|
||||
delete[] outgoing_events;
|
||||
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
sub_streams[i].release();
|
||||
}
|
||||
delete[] sub_streams;
|
||||
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Torus> struct int_unchecked_index_of_clear_buffer {
|
||||
int_radix_params params;
|
||||
bool allocate_gpu_memory;
|
||||
uint32_t num_inputs;
|
||||
|
||||
int_comparison_buffer<Torus> **eq_buffers;
|
||||
int_final_index_from_selectors_buffer<Torus> *final_index_buf;
|
||||
|
||||
CudaStreams active_streams;
|
||||
CudaStreams *sub_streams;
|
||||
cudaEvent_t incoming_event;
|
||||
cudaEvent_t *outgoing_events;
|
||||
uint32_t num_streams;
|
||||
|
||||
int_unchecked_index_of_clear_buffer(CudaStreams streams,
|
||||
int_radix_params params,
|
||||
uint32_t num_inputs, uint32_t num_blocks,
|
||||
uint32_t num_blocks_index,
|
||||
bool allocate_gpu_memory,
|
||||
uint64_t &size_tracker) {
|
||||
this->params = params;
|
||||
this->allocate_gpu_memory = allocate_gpu_memory;
|
||||
this->num_inputs = num_inputs;
|
||||
|
||||
uint32_t num_streams_to_use =
|
||||
std::min((uint32_t)MAX_STREAMS_FOR_VECTOR_FIND, num_inputs);
|
||||
if (num_streams_to_use == 0)
|
||||
num_streams_to_use = 1;
|
||||
|
||||
this->num_streams = num_streams_to_use;
|
||||
this->active_streams = streams.active_gpu_subset(num_blocks);
|
||||
uint32_t num_gpus = active_streams.count();
|
||||
|
||||
incoming_event = cuda_create_event(streams.gpu_index(0));
|
||||
sub_streams = new CudaStreams[num_streams_to_use];
|
||||
outgoing_events = new cudaEvent_t[num_streams_to_use * num_gpus];
|
||||
|
||||
for (uint32_t i = 0; i < num_streams_to_use; i++) {
|
||||
sub_streams[i].create_on_same_gpus(active_streams);
|
||||
for (uint32_t j = 0; j < num_gpus; j++) {
|
||||
outgoing_events[i * num_gpus + j] =
|
||||
cuda_create_event(active_streams.gpu_index(j));
|
||||
}
|
||||
}
|
||||
|
||||
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
|
||||
for (uint32_t i = 0; i < num_streams; i++) {
|
||||
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
|
||||
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
|
||||
size_tracker);
|
||||
}
|
||||
|
||||
|
||||
@@ -228,49 +228,6 @@ void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_compute_final_index_from_selectors_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus,
|
||||
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_compute_final_index_from_selectors<uint64_t>(
|
||||
CudaStreams(streams),
|
||||
(int_final_index_from_selectors_buffer<uint64_t> **)mem_ptr, params,
|
||||
num_inputs, num_blocks_index, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_compute_final_index_from_selectors_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
|
||||
uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem,
|
||||
void *const *bsks, void *const *ksks) {
|
||||
|
||||
host_compute_final_index_from_selectors<uint64_t>(
|
||||
CudaStreams(streams), index_ct, match_ct, selectors, num_inputs,
|
||||
num_blocks_index, (int_final_index_from_selectors_buffer<uint64_t> *)mem,
|
||||
bsks, (uint64_t *const *)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
int_final_index_from_selectors_buffer<uint64_t> *mem_ptr =
|
||||
(int_final_index_from_selectors_buffer<uint64_t> *)(*mem_ptr_void);
|
||||
|
||||
mem_ptr->release(CudaStreams(streams));
|
||||
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_unchecked_index_in_clears_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
@@ -500,3 +457,50 @@ void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams,
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
uint64_t scratch_cuda_unchecked_index_of_clear_64(
|
||||
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t big_lwe_dimension,
|
||||
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
|
||||
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
|
||||
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
|
||||
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_level,
|
||||
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
|
||||
message_modulus, carry_modulus, noise_reduction_type);
|
||||
|
||||
return scratch_cuda_unchecked_index_of_clear<uint64_t>(
|
||||
CudaStreams(streams),
|
||||
(int_unchecked_index_of_clear_buffer<uint64_t> **)mem_ptr, params,
|
||||
num_inputs, num_blocks, num_blocks_index, allocate_gpu_memory);
|
||||
}
|
||||
|
||||
void cuda_unchecked_index_of_clear_64(
|
||||
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
|
||||
const void *d_scalar_blocks, bool is_scalar_obviously_bigger,
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
|
||||
uint32_t num_blocks_index, int8_t *mem, void *const *bsks,
|
||||
void *const *ksks) {
|
||||
|
||||
host_unchecked_index_of_clear<uint64_t>(
|
||||
CudaStreams(streams), index_ct, match_ct, inputs,
|
||||
(const uint64_t *)d_scalar_blocks, is_scalar_obviously_bigger, num_inputs,
|
||||
num_blocks, num_scalar_blocks, num_blocks_index,
|
||||
(int_unchecked_index_of_clear_buffer<uint64_t> *)mem, bsks,
|
||||
(uint64_t *const *)ksks);
|
||||
}
|
||||
|
||||
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
|
||||
int8_t **mem_ptr_void) {
|
||||
int_unchecked_index_of_clear_buffer<uint64_t> *mem_ptr =
|
||||
(int_unchecked_index_of_clear_buffer<uint64_t> *)(*mem_ptr_void);
|
||||
|
||||
mem_ptr->release(CudaStreams(streams));
|
||||
|
||||
delete mem_ptr;
|
||||
*mem_ptr_void = nullptr;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "integer/comparison.cuh"
|
||||
#include "integer/integer.cuh"
|
||||
#include "integer/radix_ciphertext.cuh"
|
||||
#include "integer/scalar_comparison.cuh"
|
||||
#include "integer/vector_find.h"
|
||||
|
||||
template <typename Torus>
|
||||
@@ -1110,3 +1111,96 @@ __host__ void host_unchecked_index_of(
|
||||
mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks,
|
||||
num_inputs);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
uint64_t scratch_cuda_unchecked_index_of_clear(
|
||||
CudaStreams streams, int_unchecked_index_of_clear_buffer<Torus> **mem_ptr,
|
||||
int_radix_params params, uint32_t num_inputs, uint32_t num_blocks,
|
||||
uint32_t num_blocks_index, bool allocate_gpu_memory) {
|
||||
|
||||
uint64_t size_tracker = 0;
|
||||
*mem_ptr = new int_unchecked_index_of_clear_buffer<Torus>(
|
||||
streams, params, num_inputs, num_blocks, num_blocks_index,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
return size_tracker;
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void host_unchecked_index_of_clear(
|
||||
CudaStreams streams, CudaRadixCiphertextFFI *index_ct,
|
||||
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
|
||||
const Torus *d_scalar_blocks, bool is_scalar_obviously_bigger,
|
||||
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
|
||||
uint32_t num_blocks_index,
|
||||
int_unchecked_index_of_clear_buffer<Torus> *mem_ptr, void *const *bsks,
|
||||
Torus *const *ksks) {
|
||||
|
||||
CudaRadixCiphertextFFI *packed_selectors =
|
||||
mem_ptr->final_index_buf->packed_selectors;
|
||||
|
||||
if (is_scalar_obviously_bigger) {
|
||||
set_zero_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), packed_selectors, 0,
|
||||
num_inputs);
|
||||
} else {
|
||||
cuda_event_record(mem_ptr->incoming_event, streams.stream(0),
|
||||
streams.gpu_index(0));
|
||||
|
||||
for (uint32_t j = 0; j < mem_ptr->num_streams; j++) {
|
||||
for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) {
|
||||
cuda_stream_wait_event(mem_ptr->sub_streams[j].stream(i),
|
||||
mem_ptr->incoming_event,
|
||||
mem_ptr->sub_streams[j].gpu_index(i));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t num_streams = mem_ptr->num_streams;
|
||||
uint32_t num_gpus = mem_ptr->active_streams.count();
|
||||
|
||||
for (uint32_t i = 0; i < num_inputs; i++) {
|
||||
uint32_t stream_idx = i % num_streams;
|
||||
CudaStreams current_stream = mem_ptr->sub_streams[stream_idx];
|
||||
|
||||
CudaRadixCiphertextFFI const *input_ct = &inputs[i];
|
||||
|
||||
CudaRadixCiphertextFFI current_selector_dest;
|
||||
as_radix_ciphertext_slice<Torus>(&current_selector_dest, packed_selectors,
|
||||
i, i + 1);
|
||||
|
||||
host_scalar_equality_check<Torus>(
|
||||
current_stream, &current_selector_dest, input_ct, d_scalar_blocks,
|
||||
mem_ptr->eq_buffers[stream_idx], bsks, (Torus **)ksks, num_blocks,
|
||||
num_scalar_blocks);
|
||||
}
|
||||
|
||||
for (uint32_t j = 0; j < mem_ptr->num_streams; j++) {
|
||||
for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) {
|
||||
cuda_event_record(mem_ptr->outgoing_events[j * num_gpus + i],
|
||||
mem_ptr->sub_streams[j].stream(i),
|
||||
mem_ptr->sub_streams[j].gpu_index(i));
|
||||
cuda_stream_wait_event(streams.stream(0),
|
||||
mem_ptr->outgoing_events[j * num_gpus + i],
|
||||
streams.gpu_index(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t packed_len = (num_blocks_index + 1) / 2;
|
||||
|
||||
host_create_possible_results<Torus>(
|
||||
streams, mem_ptr->final_index_buf->possible_results_ct_list,
|
||||
mem_ptr->final_index_buf->unpacked_selectors, num_inputs,
|
||||
(const uint64_t *)mem_ptr->final_index_buf->h_indices, packed_len,
|
||||
mem_ptr->final_index_buf->possible_results_buf, bsks, ksks);
|
||||
|
||||
host_aggregate_one_hot_vector<Torus>(
|
||||
streams, index_ct, mem_ptr->final_index_buf->possible_results_ct_list,
|
||||
num_inputs, packed_len, mem_ptr->final_index_buf->aggregate_buf, bsks,
|
||||
ksks);
|
||||
|
||||
host_integer_is_at_least_one_comparisons_block_true<Torus>(
|
||||
streams, match_ct, packed_selectors,
|
||||
mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks,
|
||||
num_inputs);
|
||||
}
|
||||
|
||||
@@ -1922,47 +1922,6 @@ unsafe extern "C" {
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_compute_final_index_from_selectors_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_inputs: u32,
|
||||
num_blocks_index: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_compute_final_index_from_selectors_64(
|
||||
streams: CudaStreamsFFI,
|
||||
index_ct: *mut CudaRadixCiphertextFFI,
|
||||
match_ct: *mut CudaRadixCiphertextFFI,
|
||||
selectors: *const CudaRadixCiphertextFFI,
|
||||
num_inputs: u32,
|
||||
num_blocks_index: u32,
|
||||
mem: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_compute_final_index_from_selectors_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_unchecked_index_in_clears_64(
|
||||
streams: CudaStreamsFFI,
|
||||
@@ -2181,6 +2140,52 @@ unsafe extern "C" {
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_unchecked_index_of_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_unchecked_index_of_clear_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr: *mut *mut i8,
|
||||
glwe_dimension: u32,
|
||||
polynomial_size: u32,
|
||||
big_lwe_dimension: u32,
|
||||
small_lwe_dimension: u32,
|
||||
ks_level: u32,
|
||||
ks_base_log: u32,
|
||||
pbs_level: u32,
|
||||
pbs_base_log: u32,
|
||||
grouping_factor: u32,
|
||||
num_inputs: u32,
|
||||
num_blocks: u32,
|
||||
num_blocks_index: u32,
|
||||
message_modulus: u32,
|
||||
carry_modulus: u32,
|
||||
pbs_type: PBS_TYPE,
|
||||
allocate_gpu_memory: bool,
|
||||
noise_reduction_type: PBS_MS_REDUCTION_T,
|
||||
) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cuda_unchecked_index_of_clear_64(
|
||||
streams: CudaStreamsFFI,
|
||||
index_ct: *mut CudaRadixCiphertextFFI,
|
||||
match_ct: *mut CudaRadixCiphertextFFI,
|
||||
inputs: *const CudaRadixCiphertextFFI,
|
||||
d_scalar_blocks: *const ffi::c_void,
|
||||
is_scalar_obviously_bigger: bool,
|
||||
num_inputs: u32,
|
||||
num_blocks: u32,
|
||||
num_scalar_blocks: u32,
|
||||
num_blocks_index: u32,
|
||||
mem: *mut i8,
|
||||
bsks: *const *mut ffi::c_void,
|
||||
ksks: *const *mut ffi::c_void,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn cleanup_cuda_unchecked_index_of_clear_64(
|
||||
streams: CudaStreamsFFI,
|
||||
mem_ptr_void: *mut *mut i8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
|
||||
streams: CudaStreamsFFI,
|
||||
|
||||
@@ -9321,129 +9321,6 @@ pub(crate) unsafe fn cuda_backend_unchecked_is_in_clears<
|
||||
update_noise_degree(&mut output.0.ciphertext, &ffi_output);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_compute_final_index_from_selectors<
|
||||
T: UnsignedInteger,
|
||||
B: Numeric,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
index_ct: &mut CudaRadixCiphertext,
|
||||
match_ct: &mut CudaBooleanBlock,
|
||||
selectors: &[CudaBooleanBlock],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
|
||||
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
|
||||
|
||||
let num_inputs = selectors.len() as u32;
|
||||
let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut index_degrees = index_ct
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let mut index_noise_levels = index_ct
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
let mut ffi_index =
|
||||
prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels);
|
||||
|
||||
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
|
||||
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
|
||||
let mut ffi_match = prepare_cuda_radix_ffi(
|
||||
&match_ct.0.ciphertext,
|
||||
&mut match_degrees,
|
||||
&mut match_noise_levels,
|
||||
);
|
||||
|
||||
let mut ffi_selectors_degrees: Vec<Vec<u64>> = Vec::with_capacity(selectors.len());
|
||||
let mut ffi_selectors_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(selectors.len());
|
||||
let ffi_selectors: Vec<CudaRadixCiphertextFFI> = selectors
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = vec![ct.0.ciphertext.info.blocks[0].degree.get()];
|
||||
let noise_levels = vec![ct.0.ciphertext.info.blocks[0].noise_level.0];
|
||||
ffi_selectors_degrees.push(degrees);
|
||||
ffi_selectors_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
&ct.0.ciphertext,
|
||||
ffi_selectors_degrees.last_mut().unwrap(),
|
||||
ffi_selectors_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_compute_final_index_from_selectors_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_inputs,
|
||||
num_blocks_index,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_compute_final_index_from_selectors_64(
|
||||
streams.ffi(),
|
||||
&raw mut ffi_index,
|
||||
&raw mut ffi_match,
|
||||
ffi_selectors.as_ptr(),
|
||||
num_inputs,
|
||||
num_blocks_index,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_compute_final_index_from_selectors_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
);
|
||||
|
||||
update_noise_degree(index_ct, &ffi_index);
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -10152,3 +10029,158 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of<
|
||||
update_noise_degree(index_ct, &ffi_index);
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear<
|
||||
T: UnsignedInteger,
|
||||
B: Numeric,
|
||||
C: CudaIntegerRadixCiphertext,
|
||||
Clear: DecomposableInto<u64> + CastInto<usize>,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
index_ct: &mut CudaRadixCiphertext,
|
||||
match_ct: &mut CudaBooleanBlock,
|
||||
inputs: &[C],
|
||||
clear: Clear,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
|
||||
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
|
||||
|
||||
let num_inputs = inputs.len() as u32;
|
||||
let num_blocks_in_ct = inputs[0].as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
|
||||
let mut scalar_blocks =
|
||||
BlockDecomposer::with_early_stop_at_zero(clear, message_modulus.0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let is_scalar_obviously_bigger = scalar_blocks
|
||||
.get(num_blocks_in_ct as usize..)
|
||||
.is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
|
||||
|
||||
scalar_blocks.truncate(num_blocks_in_ct as usize);
|
||||
let num_scalar_blocks = scalar_blocks.len() as u32;
|
||||
|
||||
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut index_degrees = index_ct
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let mut index_noise_levels = index_ct
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
let mut ffi_index =
|
||||
prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels);
|
||||
|
||||
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
|
||||
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
|
||||
let mut ffi_match = prepare_cuda_radix_ffi(
|
||||
&match_ct.0.ciphertext,
|
||||
&mut match_degrees,
|
||||
&mut match_noise_levels,
|
||||
);
|
||||
|
||||
let mut ffi_inputs_degrees: Vec<Vec<u64>> = Vec::with_capacity(inputs.len());
|
||||
let mut ffi_inputs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(inputs.len());
|
||||
let ffi_inputs: Vec<CudaRadixCiphertextFFI> = inputs
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let noise_levels = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
ffi_inputs_degrees.push(degrees);
|
||||
ffi_inputs_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
ct.as_ref(),
|
||||
ffi_inputs_degrees.last_mut().unwrap(),
|
||||
ffi_inputs_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_unchecked_index_of_clear_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_inputs,
|
||||
num_blocks_in_ct,
|
||||
num_blocks_index,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_unchecked_index_of_clear_64(
|
||||
streams.ffi(),
|
||||
&raw mut ffi_index,
|
||||
&raw mut ffi_match,
|
||||
ffi_inputs.as_ptr(),
|
||||
d_scalar_blocks.as_c_ptr(0),
|
||||
is_scalar_obviously_bigger,
|
||||
num_inputs,
|
||||
num_blocks_in_ct,
|
||||
num_scalar_blocks,
|
||||
num_blocks_index,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_unchecked_index_of_clear_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(index_ct, &ffi_index);
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
@@ -5,14 +5,13 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_compute_final_index_from_selectors,
|
||||
cuda_backend_get_unchecked_match_value_or_size_on_gpu,
|
||||
cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_contains,
|
||||
cuda_backend_unchecked_contains_clear, cuda_backend_unchecked_first_index_in_clears,
|
||||
cuda_backend_unchecked_first_index_of, cuda_backend_unchecked_first_index_of_clear,
|
||||
cuda_backend_unchecked_index_in_clears, cuda_backend_unchecked_index_of,
|
||||
cuda_backend_unchecked_is_in_clears, cuda_backend_unchecked_match_value,
|
||||
cuda_backend_unchecked_match_value_or, PBSType,
|
||||
cuda_backend_unchecked_index_of_clear, cuda_backend_unchecked_is_in_clears,
|
||||
cuda_backend_unchecked_match_value, cuda_backend_unchecked_match_value_or, PBSType,
|
||||
};
|
||||
pub use crate::integer::server_key::radix_parallel::MatchValues;
|
||||
use crate::prelude::CastInto;
|
||||
@@ -1490,12 +1489,80 @@ impl CudaServerKey {
|
||||
);
|
||||
return (trivial_ct, trivial_bool);
|
||||
}
|
||||
let selectors = cts
|
||||
.iter()
|
||||
.map(|ct| self.scalar_eq(ct, clear, streams))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.compute_final_index_from_selectors(&selectors, streams)
|
||||
let num_inputs = cts.len();
|
||||
let num_blocks_index =
|
||||
(num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize;
|
||||
|
||||
let mut index_ct: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(num_blocks_index, streams);
|
||||
|
||||
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_index_of_clear(
|
||||
streams,
|
||||
index_ct.as_mut(),
|
||||
&mut match_ct,
|
||||
cts,
|
||||
clear,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_index_of_clear(
|
||||
streams,
|
||||
index_ct.as_mut(),
|
||||
&mut match_ct,
|
||||
cts,
|
||||
clear,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(index_ct, match_ct)
|
||||
}
|
||||
|
||||
/// Returns the encrypted index of the clear `value` in the ciphertext slice
|
||||
@@ -1932,82 +1999,4 @@ impl CudaServerKey {
|
||||
};
|
||||
self.unchecked_first_index_of(cts, value, streams)
|
||||
}
|
||||
|
||||
/// Allocates the output ciphertexts and forwards a slice of boolean selector
/// blocks to `cuda_backend_compute_final_index_from_selectors`.
///
/// # Arguments
/// * `selectors` - boolean blocks handed unchanged to the backend.
///   NOTE(review): presumably one "is equal" flag per candidate input; confirm
///   against the callers.
/// * `streams` - CUDA streams used for the allocations and the backend call.
///
/// # Returns
/// A pair `(index_ct, match_ct)`:
/// * `index_ct` starts as a trivial-zero radix ciphertext of `num_blocks_index`
///   blocks and is passed mutably to the backend.
/// * `match_ct` starts as a one-block trivial zero wrapped in a
///   `CudaBooleanBlock` and is passed mutably to the backend.
///
/// NOTE(review): the exact values the backend writes into both outputs are
/// defined by the backend, which is not visible here.
///
/// # Panics
/// `ilog2` panics on zero, so this panics if `selectors` is empty.
/// NOTE(review): callers appear to be expected to handle the empty case
/// before calling; confirm.
fn compute_final_index_from_selectors(
|
||||
&self,
|
||||
selectors: &[CudaBooleanBlock],
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
|
||||
let num_inputs = selectors.len();
|
||||
// Number of radix blocks for the index output:
// ceil((ilog2(num_inputs) + 1) / ilog2(message_modulus)).
// `ilog2(num_inputs) + 1` is the bit length of `num_inputs`.
// NOTE(review): assumes message_modulus is a power of two >= 2, so that
// `ilog2(message_modulus)` is the non-zero bit count per block; confirm.
let num_blocks_index =
|
||||
(num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize;
|
||||
|
||||
// Output buffer for the index, initialised to a trivial zero.
let mut index_ct: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_zero_radix(num_blocks_index, streams);
|
||||
|
||||
// Output buffer for the match flag: a single trivial-zero block
// re-wrapped as a boolean block.
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
// NOTE(review): the safety contract of the backend call is defined on the
// backend function, which is not visible here.
unsafe {
|
||||
// Both arms make the same backend call; only the bootstrapping-key fields,
// the PBS type, the grouping factor and the noise-reduction argument differ.
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_compute_final_index_from_selectors(
|
||||
streams,
|
||||
index_ct.as_mut(),
|
||||
&mut match_ct,
|
||||
selectors,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
// Classic keys carry no grouping factor; 0 is passed here.
// NOTE(review): presumably ignored by the backend for classical PBS.
LweBskGroupingFactor(0),
|
||||
// Noise-reduction configuration comes from the classic key.
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_compute_final_index_from_selectors(
|
||||
streams,
|
||||
index_ct.as_mut(),
|
||||
&mut match_ct,
|
||||
selectors,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
// No noise-reduction configuration is passed for multi-bit keys.
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(index_ct, match_ct)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user