refactor(gpu): unchecked_index_of_clear to backend

Author: Enzo Di Maria
Date: 2025-11-27 17:00:27 +01:00
Committed by: Agnès Leroy
parent a731be3878
commit 3cfbaa40c3
7 changed files with 536 additions and 323 deletions

View File

@@ -884,24 +884,6 @@ void cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
uint64_t scratch_cuda_compute_final_index_from_selectors_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_compute_final_index_from_selectors_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem,
void *const *bsks, void *const *ksks);
void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
uint64_t scratch_cuda_unchecked_index_in_clears_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
@@ -1002,6 +984,26 @@ void cuda_unchecked_index_of_64(CudaStreamsFFI streams,
void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
uint64_t scratch_cuda_unchecked_index_of_clear_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_unchecked_index_of_clear_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
const void *d_scalar_blocks, bool is_scalar_obviously_bigger,
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
uint32_t num_blocks_index, int8_t *mem, void *const *bsks,
void *const *ksks);
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void);
} // extern C
#endif // CUDA_INTEGER_H
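
Like the other vector-find entry points, the new functions follow a scratch → run → cleanup lifecycle. A minimal calling sketch against the declarations above, assuming the caller has already prepared the streams, ciphertexts, keys, and parameter values (all variable names here are placeholders, not code from this commit):

int8_t *mem = nullptr;

// 1) Allocate scratch memory; the return value reports the GPU memory
//    footprint in bytes.
uint64_t size_bytes = scratch_cuda_unchecked_index_of_clear_64(
    streams, &mem, glwe_dimension, polynomial_size, big_lwe_dimension,
    small_lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log,
    grouping_factor, num_inputs, num_blocks, num_blocks_index,
    message_modulus, carry_modulus, pbs_type,
    /*allocate_gpu_memory=*/true, noise_reduction_type);

// 2) Run the operation on the prepared scratch buffer.
cuda_unchecked_index_of_clear_64(
    streams, &index_ct, &match_ct, inputs, d_scalar_blocks,
    is_scalar_obviously_bigger, num_inputs, num_blocks, num_scalar_blocks,
    num_blocks_index, mem, bsks, ksks);

// 3) Release the scratch buffer.
cleanup_cuda_unchecked_index_of_clear_64(streams, &mem);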

View File

@@ -101,7 +101,7 @@ template <typename Torus> struct int_equality_selectors_buffer {
size_tracker, allocate_gpu_memory);
this->reduction_buffers[j] = new int_comparison_buffer<Torus>(
-sub_streams[j], COMPARISON_TYPE::EQ, params, num_blocks, false,
+streams, COMPARISON_TYPE::EQ, params, num_blocks, false,
allocate_gpu_memory, size_tracker);
}
}
@@ -461,14 +461,14 @@ template <typename Torus> struct int_aggregate_one_hot_buffer {
delete this->carry_extract_lut;
for (uint32_t i = 0; i < num_streams; i++) {
-release_radix_ciphertext_async(
-    sub_streams[i].stream(0), sub_streams[i].gpu_index(0),
-    this->partial_aggregated_vectors[i], this->allocate_gpu_memory);
+release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+    this->partial_aggregated_vectors[i],
+    this->allocate_gpu_memory);
delete this->partial_aggregated_vectors[i];
-release_radix_ciphertext_async(
-    sub_streams[i].stream(0), sub_streams[i].gpu_index(0),
-    this->partial_temp_vectors[i], this->allocate_gpu_memory);
+release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
+    this->partial_temp_vectors[i],
+    this->allocate_gpu_memory);
delete this->partial_temp_vectors[i];
}
delete[] partial_aggregated_vectors;
@@ -752,7 +752,7 @@ template <typename Torus> struct int_unchecked_contains_buffer {
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
-sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
+streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}
@@ -853,7 +853,7 @@ template <typename Torus> struct int_unchecked_contains_clear_buffer {
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
-sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
+streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}
@@ -1272,7 +1272,7 @@ template <typename Torus> struct int_unchecked_first_index_of_clear_buffer {
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
-sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
+streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}
@@ -1496,7 +1496,7 @@ template <typename Torus> struct int_unchecked_first_index_of_buffer {
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
-sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
+streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}
@@ -1690,7 +1690,94 @@ template <typename Torus> struct int_unchecked_index_of_buffer {
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
-sub_streams[i], EQ, params, num_blocks, false, allocate_gpu_memory,
+streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}
this->final_index_buf = new int_final_index_from_selectors_buffer<Torus>(
streams, params, num_inputs, num_blocks_index, allocate_gpu_memory,
size_tracker);
}
void release(CudaStreams streams) {
for (uint32_t i = 0; i < num_streams; i++) {
eq_buffers[i]->release(streams);
delete eq_buffers[i];
}
delete[] eq_buffers;
this->final_index_buf->release(streams);
delete this->final_index_buf;
cuda_event_destroy(incoming_event, streams.gpu_index(0));
uint32_t num_gpus = active_streams.count();
for (uint32_t j = 0; j < num_streams; j++) {
for (uint32_t k = 0; k < num_gpus; k++) {
cuda_event_destroy(outgoing_events[j * num_gpus + k],
active_streams.gpu_index(k));
}
}
delete[] outgoing_events;
for (uint32_t i = 0; i < num_streams; i++) {
sub_streams[i].release();
}
delete[] sub_streams;
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
}
};
template <typename Torus> struct int_unchecked_index_of_clear_buffer {
int_radix_params params;
bool allocate_gpu_memory;
uint32_t num_inputs;
int_comparison_buffer<Torus> **eq_buffers;
int_final_index_from_selectors_buffer<Torus> *final_index_buf;
CudaStreams active_streams;
CudaStreams *sub_streams;
cudaEvent_t incoming_event;
cudaEvent_t *outgoing_events;
uint32_t num_streams;
int_unchecked_index_of_clear_buffer(CudaStreams streams,
int_radix_params params,
uint32_t num_inputs, uint32_t num_blocks,
uint32_t num_blocks_index,
bool allocate_gpu_memory,
uint64_t &size_tracker) {
this->params = params;
this->allocate_gpu_memory = allocate_gpu_memory;
this->num_inputs = num_inputs;
uint32_t num_streams_to_use =
std::min((uint32_t)MAX_STREAMS_FOR_VECTOR_FIND, num_inputs);
if (num_streams_to_use == 0)
num_streams_to_use = 1;
this->num_streams = num_streams_to_use;
this->active_streams = streams.active_gpu_subset(num_blocks);
uint32_t num_gpus = active_streams.count();
incoming_event = cuda_create_event(streams.gpu_index(0));
sub_streams = new CudaStreams[num_streams_to_use];
outgoing_events = new cudaEvent_t[num_streams_to_use * num_gpus];
for (uint32_t i = 0; i < num_streams_to_use; i++) {
sub_streams[i].create_on_same_gpus(active_streams);
for (uint32_t j = 0; j < num_gpus; j++) {
outgoing_events[i * num_gpus + j] =
cuda_create_event(active_streams.gpu_index(j));
}
}
this->eq_buffers = new int_comparison_buffer<Torus> *[num_streams];
for (uint32_t i = 0; i < num_streams; i++) {
this->eq_buffers[i] = new int_comparison_buffer<Torus>(
streams, EQ, params, num_blocks, false, allocate_gpu_memory,
size_tracker);
}

View File

@@ -228,49 +228,6 @@ void cleanup_cuda_unchecked_is_in_clears_64(CudaStreamsFFI streams,
*mem_ptr_void = nullptr;
}
uint64_t scratch_cuda_compute_final_index_from_selectors_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_inputs, uint32_t num_blocks_index, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
message_modulus, carry_modulus, noise_reduction_type);
return scratch_cuda_compute_final_index_from_selectors<uint64_t>(
CudaStreams(streams),
(int_final_index_from_selectors_buffer<uint64_t> **)mem_ptr, params,
num_inputs, num_blocks_index, allocate_gpu_memory);
}
void cuda_compute_final_index_from_selectors_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors,
uint32_t num_inputs, uint32_t num_blocks_index, int8_t *mem,
void *const *bsks, void *const *ksks) {
host_compute_final_index_from_selectors<uint64_t>(
CudaStreams(streams), index_ct, match_ct, selectors, num_inputs,
num_blocks_index, (int_final_index_from_selectors_buffer<uint64_t> *)mem,
bsks, (uint64_t *const *)ksks);
}
void cleanup_cuda_compute_final_index_from_selectors_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
int_final_index_from_selectors_buffer<uint64_t> *mem_ptr =
(int_final_index_from_selectors_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}
uint64_t scratch_cuda_unchecked_index_in_clears_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
@@ -500,3 +457,50 @@ void cleanup_cuda_unchecked_index_of_64(CudaStreamsFFI streams,
delete mem_ptr;
*mem_ptr_void = nullptr;
}
uint64_t scratch_cuda_unchecked_index_of_clear_64(
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t big_lwe_dimension,
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
big_lwe_dimension, small_lwe_dimension, ks_level,
ks_base_log, pbs_level, pbs_base_log, grouping_factor,
message_modulus, carry_modulus, noise_reduction_type);
return scratch_cuda_unchecked_index_of_clear<uint64_t>(
CudaStreams(streams),
(int_unchecked_index_of_clear_buffer<uint64_t> **)mem_ptr, params,
num_inputs, num_blocks, num_blocks_index, allocate_gpu_memory);
}
void cuda_unchecked_index_of_clear_64(
CudaStreamsFFI streams, CudaRadixCiphertextFFI *index_ct,
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
const void *d_scalar_blocks, bool is_scalar_obviously_bigger,
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
uint32_t num_blocks_index, int8_t *mem, void *const *bsks,
void *const *ksks) {
host_unchecked_index_of_clear<uint64_t>(
CudaStreams(streams), index_ct, match_ct, inputs,
(const uint64_t *)d_scalar_blocks, is_scalar_obviously_bigger, num_inputs,
num_blocks, num_scalar_blocks, num_blocks_index,
(int_unchecked_index_of_clear_buffer<uint64_t> *)mem, bsks,
(uint64_t *const *)ksks);
}
void cleanup_cuda_unchecked_index_of_clear_64(CudaStreamsFFI streams,
int8_t **mem_ptr_void) {
int_unchecked_index_of_clear_buffer<uint64_t> *mem_ptr =
(int_unchecked_index_of_clear_buffer<uint64_t> *)(*mem_ptr_void);
mem_ptr->release(CudaStreams(streams));
delete mem_ptr;
*mem_ptr_void = nullptr;
}

View File

@@ -5,6 +5,7 @@
#include "integer/comparison.cuh"
#include "integer/integer.cuh"
#include "integer/radix_ciphertext.cuh"
#include "integer/scalar_comparison.cuh"
#include "integer/vector_find.h"
template <typename Torus>
@@ -1110,3 +1111,96 @@ __host__ void host_unchecked_index_of(
mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks,
num_inputs);
}
template <typename Torus>
uint64_t scratch_cuda_unchecked_index_of_clear(
CudaStreams streams, int_unchecked_index_of_clear_buffer<Torus> **mem_ptr,
int_radix_params params, uint32_t num_inputs, uint32_t num_blocks,
uint32_t num_blocks_index, bool allocate_gpu_memory) {
uint64_t size_tracker = 0;
*mem_ptr = new int_unchecked_index_of_clear_buffer<Torus>(
streams, params, num_inputs, num_blocks, num_blocks_index,
allocate_gpu_memory, size_tracker);
return size_tracker;
}
template <typename Torus>
__host__ void host_unchecked_index_of_clear(
CudaStreams streams, CudaRadixCiphertextFFI *index_ct,
CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *inputs,
const Torus *d_scalar_blocks, bool is_scalar_obviously_bigger,
uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks,
uint32_t num_blocks_index,
int_unchecked_index_of_clear_buffer<Torus> *mem_ptr, void *const *bsks,
Torus *const *ksks) {
CudaRadixCiphertextFFI *packed_selectors =
mem_ptr->final_index_buf->packed_selectors;
if (is_scalar_obviously_bigger) {
set_zero_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), packed_selectors, 0,
num_inputs);
} else {
cuda_event_record(mem_ptr->incoming_event, streams.stream(0),
streams.gpu_index(0));
for (uint32_t j = 0; j < mem_ptr->num_streams; j++) {
for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) {
cuda_stream_wait_event(mem_ptr->sub_streams[j].stream(i),
mem_ptr->incoming_event,
mem_ptr->sub_streams[j].gpu_index(i));
}
}
uint32_t num_streams = mem_ptr->num_streams;
uint32_t num_gpus = mem_ptr->active_streams.count();
for (uint32_t i = 0; i < num_inputs; i++) {
uint32_t stream_idx = i % num_streams;
CudaStreams current_stream = mem_ptr->sub_streams[stream_idx];
CudaRadixCiphertextFFI const *input_ct = &inputs[i];
CudaRadixCiphertextFFI current_selector_dest;
as_radix_ciphertext_slice<Torus>(&current_selector_dest, packed_selectors,
i, i + 1);
host_scalar_equality_check<Torus>(
current_stream, &current_selector_dest, input_ct, d_scalar_blocks,
mem_ptr->eq_buffers[stream_idx], bsks, (Torus **)ksks, num_blocks,
num_scalar_blocks);
}
for (uint32_t j = 0; j < mem_ptr->num_streams; j++) {
for (uint32_t i = 0; i < mem_ptr->active_streams.count(); i++) {
cuda_event_record(mem_ptr->outgoing_events[j * num_gpus + i],
mem_ptr->sub_streams[j].stream(i),
mem_ptr->sub_streams[j].gpu_index(i));
cuda_stream_wait_event(streams.stream(0),
mem_ptr->outgoing_events[j * num_gpus + i],
streams.gpu_index(0));
}
}
}
uint32_t packed_len = (num_blocks_index + 1) / 2;
host_create_possible_results<Torus>(
streams, mem_ptr->final_index_buf->possible_results_ct_list,
mem_ptr->final_index_buf->unpacked_selectors, num_inputs,
(const uint64_t *)mem_ptr->final_index_buf->h_indices, packed_len,
mem_ptr->final_index_buf->possible_results_buf, bsks, ksks);
host_aggregate_one_hot_vector<Torus>(
streams, index_ct, mem_ptr->final_index_buf->possible_results_ct_list,
num_inputs, packed_len, mem_ptr->final_index_buf->aggregate_buf, bsks,
ksks);
host_integer_is_at_least_one_comparisons_block_true<Torus>(
streams, match_ct, packed_selectors,
mem_ptr->final_index_buf->reduction_buf, bsks, (Torus **)ksks,
num_inputs);
}
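
The stream handling above is a standard fork/join: the main stream records an incoming event that every sub-stream waits on, the per-input host_scalar_equality_check calls are spread round-robin over the sub-streams, and each sub-stream records an outgoing event that the main stream waits on before the final create/aggregate/reduce steps. A stripped-down sketch of that pattern using the plain CUDA runtime API (the kernel launch is a placeholder; this mirrors the wiring rather than reproducing the buffer code):

#include <cuda_runtime.h>
#include <vector>

// Fork independent work items from `main_stream` onto `subs`, then join back.
void fork_join(cudaStream_t main_stream, std::vector<cudaStream_t> &subs,
               size_t num_items) {
  cudaEvent_t incoming;
  cudaEventCreateWithFlags(&incoming, cudaEventDisableTiming);
  // Fork: work previously enqueued on the main stream must complete first.
  cudaEventRecord(incoming, main_stream);
  for (cudaStream_t s : subs)
    cudaStreamWaitEvent(s, incoming, 0);

  // Round-robin distribution of the independent items.
  for (size_t i = 0; i < num_items; i++) {
    cudaStream_t s = subs[i % subs.size()];
    // work_item<<<grid, block, 0, s>>>(i);  // placeholder kernel launch
    (void)s;
  }

  // Join: the main stream resumes only once every sub-stream has finished.
  for (cudaStream_t s : subs) {
    cudaEvent_t outgoing;
    cudaEventCreateWithFlags(&outgoing, cudaEventDisableTiming);
    cudaEventRecord(outgoing, s);
    cudaStreamWaitEvent(main_stream, outgoing, 0);
    cudaEventDestroy(outgoing); // safe: release is deferred until completion
  }
  cudaEventDestroy(incoming);
}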

View File

@@ -1922,47 +1922,6 @@ unsafe extern "C" {
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_compute_final_index_from_selectors_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_inputs: u32,
num_blocks_index: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_compute_final_index_from_selectors_64(
streams: CudaStreamsFFI,
index_ct: *mut CudaRadixCiphertextFFI,
match_ct: *mut CudaRadixCiphertextFFI,
selectors: *const CudaRadixCiphertextFFI,
num_inputs: u32,
num_blocks_index: u32,
mem: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_compute_final_index_from_selectors_64(
streams: CudaStreamsFFI,
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_unchecked_index_in_clears_64(
streams: CudaStreamsFFI,
@@ -2181,6 +2140,52 @@ unsafe extern "C" {
unsafe extern "C" {
pub fn cleanup_cuda_unchecked_index_of_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8);
}
unsafe extern "C" {
pub fn scratch_cuda_unchecked_index_of_clear_64(
streams: CudaStreamsFFI,
mem_ptr: *mut *mut i8,
glwe_dimension: u32,
polynomial_size: u32,
big_lwe_dimension: u32,
small_lwe_dimension: u32,
ks_level: u32,
ks_base_log: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_inputs: u32,
num_blocks: u32,
num_blocks_index: u32,
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
allocate_gpu_memory: bool,
noise_reduction_type: PBS_MS_REDUCTION_T,
) -> u64;
}
unsafe extern "C" {
pub fn cuda_unchecked_index_of_clear_64(
streams: CudaStreamsFFI,
index_ct: *mut CudaRadixCiphertextFFI,
match_ct: *mut CudaRadixCiphertextFFI,
inputs: *const CudaRadixCiphertextFFI,
d_scalar_blocks: *const ffi::c_void,
is_scalar_obviously_bigger: bool,
num_inputs: u32,
num_blocks: u32,
num_scalar_blocks: u32,
num_blocks_index: u32,
mem: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
);
}
unsafe extern "C" {
pub fn cleanup_cuda_unchecked_index_of_clear_64(
streams: CudaStreamsFFI,
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn scratch_cuda_integer_compress_radix_ciphertext_64(
streams: CudaStreamsFFI,

View File

@@ -9321,129 +9321,6 @@ pub(crate) unsafe fn cuda_backend_unchecked_is_in_clears<
update_noise_degree(&mut output.0.ciphertext, &ffi_output);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_compute_final_index_from_selectors<
T: UnsignedInteger,
B: Numeric,
>(
streams: &CudaStreams,
index_ct: &mut CudaRadixCiphertext,
match_ct: &mut CudaBooleanBlock,
selectors: &[CudaBooleanBlock],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
let num_inputs = selectors.len() as u32;
let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32;
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut index_degrees = index_ct
.info
.blocks
.iter()
.map(|b| b.degree.get())
.collect();
let mut index_noise_levels = index_ct
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut ffi_index =
prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels);
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
let mut ffi_match = prepare_cuda_radix_ffi(
&match_ct.0.ciphertext,
&mut match_degrees,
&mut match_noise_levels,
);
let mut ffi_selectors_degrees: Vec<Vec<u64>> = Vec::with_capacity(selectors.len());
let mut ffi_selectors_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(selectors.len());
let ffi_selectors: Vec<CudaRadixCiphertextFFI> = selectors
.iter()
.map(|ct| {
let degrees = vec![ct.0.ciphertext.info.blocks[0].degree.get()];
let noise_levels = vec![ct.0.ciphertext.info.blocks[0].noise_level.0];
ffi_selectors_degrees.push(degrees);
ffi_selectors_noise_levels.push(noise_levels);
prepare_cuda_radix_ffi(
&ct.0.ciphertext,
ffi_selectors_degrees.last_mut().unwrap(),
ffi_selectors_noise_levels.last_mut().unwrap(),
)
})
.collect();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_compute_final_index_from_selectors_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_inputs,
num_blocks_index,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_compute_final_index_from_selectors_64(
streams.ffi(),
&raw mut ffi_index,
&raw mut ffi_match,
ffi_selectors.as_ptr(),
num_inputs,
num_blocks_index,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_compute_final_index_from_selectors_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
);
update_noise_degree(index_ct, &ffi_index);
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
@@ -10152,3 +10029,158 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of<
update_noise_degree(index_ct, &ffi_index);
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - The data must not be moved or dropped while being used by the CUDA kernel.
/// - This function assumes exclusive access to the passed data; violating this may lead to
/// undefined behavior.
pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear<
T: UnsignedInteger,
B: Numeric,
C: CudaIntegerRadixCiphertext,
Clear: DecomposableInto<u64> + CastInto<usize>,
>(
streams: &CudaStreams,
index_ct: &mut CudaRadixCiphertext,
match_ct: &mut CudaBooleanBlock,
inputs: &[C],
clear: Clear,
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
big_lwe_dimension: LweDimension,
small_lwe_dimension: LweDimension,
ks_level: DecompositionLevelCount,
ks_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
pbs_base_log: DecompositionBaseLog,
pbs_type: PBSType,
grouping_factor: LweBskGroupingFactor,
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
) {
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
let num_inputs = inputs.len() as u32;
let num_blocks_in_ct = inputs[0].as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
let num_blocks_index = index_ct.d_blocks.lwe_ciphertext_count().0 as u32;
let mut scalar_blocks =
BlockDecomposer::with_early_stop_at_zero(clear, message_modulus.0.ilog2())
.iter_as::<u64>()
.collect::<Vec<_>>();
let is_scalar_obviously_bigger = scalar_blocks
.get(num_blocks_in_ct as usize..)
.is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
scalar_blocks.truncate(num_blocks_in_ct as usize);
let num_scalar_blocks = scalar_blocks.len() as u32;
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
let mut index_degrees = index_ct
.info
.blocks
.iter()
.map(|b| b.degree.get())
.collect();
let mut index_noise_levels = index_ct
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
let mut ffi_index =
prepare_cuda_radix_ffi(index_ct, &mut index_degrees, &mut index_noise_levels);
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
let mut ffi_match = prepare_cuda_radix_ffi(
&match_ct.0.ciphertext,
&mut match_degrees,
&mut match_noise_levels,
);
let mut ffi_inputs_degrees: Vec<Vec<u64>> = Vec::with_capacity(inputs.len());
let mut ffi_inputs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(inputs.len());
let ffi_inputs: Vec<CudaRadixCiphertextFFI> = inputs
.iter()
.map(|ct| {
let degrees = ct
.as_ref()
.info
.blocks
.iter()
.map(|b| b.degree.get())
.collect();
let noise_levels = ct
.as_ref()
.info
.blocks
.iter()
.map(|b| b.noise_level.0)
.collect();
ffi_inputs_degrees.push(degrees);
ffi_inputs_noise_levels.push(noise_levels);
prepare_cuda_radix_ffi(
ct.as_ref(),
ffi_inputs_degrees.last_mut().unwrap(),
ffi_inputs_noise_levels.last_mut().unwrap(),
)
})
.collect();
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
scratch_cuda_unchecked_index_of_clear_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
big_lwe_dimension.0 as u32,
small_lwe_dimension.0 as u32,
ks_level.0 as u32,
ks_base_log.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_inputs,
num_blocks_in_ct,
num_blocks_index,
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
true,
noise_reduction_type as u32,
);
cuda_unchecked_index_of_clear_64(
streams.ffi(),
&raw mut ffi_index,
&raw mut ffi_match,
ffi_inputs.as_ptr(),
d_scalar_blocks.as_c_ptr(0),
is_scalar_obviously_bigger,
num_inputs,
num_blocks_in_ct,
num_scalar_blocks,
num_blocks_index,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),
);
cleanup_cuda_unchecked_index_of_clear_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
update_noise_degree(index_ct, &ffi_index);
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
}
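
Before the FFI call, the clear value is decomposed into message-modulus-sized blocks, least significant first, stopping early once the remaining value is zero; is_scalar_obviously_bigger is set when any block beyond the ciphertext's block count is nonzero, in which case the scalar cannot possibly match. A self-contained C++ sketch of that logic (hypothetical helpers, not the BlockDecomposer implementation):

#include <cstdint>
#include <vector>

// Split `clear` into blocks of `bits_per_block` bits, least significant first,
// stopping once the remaining value is zero (mirrors the early-stop decomposer).
std::vector<uint64_t> decompose(uint64_t clear, uint32_t bits_per_block) {
  std::vector<uint64_t> blocks;
  const uint64_t mask = (uint64_t{1} << bits_per_block) - 1;
  while (clear != 0) {
    blocks.push_back(clear & mask);
    clear >>= bits_per_block;
  }
  return blocks;
}

// True when a block past the ciphertext width is nonzero, i.e. the scalar
// cannot fit in `num_blocks_in_ct` blocks.
bool obviously_bigger(const std::vector<uint64_t> &blocks,
                      size_t num_blocks_in_ct) {
  for (size_t i = num_blocks_in_ct; i < blocks.size(); i++)
    if (blocks[i] != 0)
      return true;
  return false;
}

// Example with message_modulus = 4 (2-bit blocks): clear = 27 = 0b01'10'11
// decomposes to [3, 2, 1]. With num_blocks_in_ct = 2, block 2 is nonzero, so
// is_scalar_obviously_bigger = true and the blocks are truncated to [3, 2]
// before the upload to the GPU.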

View File

@@ -5,14 +5,13 @@ use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
-cuda_backend_compute_final_index_from_selectors,
cuda_backend_get_unchecked_match_value_or_size_on_gpu,
cuda_backend_get_unchecked_match_value_size_on_gpu, cuda_backend_unchecked_contains,
cuda_backend_unchecked_contains_clear, cuda_backend_unchecked_first_index_in_clears,
cuda_backend_unchecked_first_index_of, cuda_backend_unchecked_first_index_of_clear,
cuda_backend_unchecked_index_in_clears, cuda_backend_unchecked_index_of,
-cuda_backend_unchecked_is_in_clears, cuda_backend_unchecked_match_value,
-cuda_backend_unchecked_match_value_or, PBSType,
+cuda_backend_unchecked_index_of_clear, cuda_backend_unchecked_is_in_clears,
+cuda_backend_unchecked_match_value, cuda_backend_unchecked_match_value_or, PBSType,
};
pub use crate::integer::server_key::radix_parallel::MatchValues;
use crate::prelude::CastInto;
@@ -1490,12 +1489,80 @@ impl CudaServerKey {
);
return (trivial_ct, trivial_bool);
}
-let selectors = cts
-    .iter()
-    .map(|ct| self.scalar_eq(ct, clear, streams))
-    .collect::<Vec<_>>();
-self.compute_final_index_from_selectors(&selectors, streams)
let num_inputs = cts.len();
let num_blocks_index =
(num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize;
let mut index_ct: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_blocks_index, streams);
let trivial_bool =
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_index_of_clear(
streams,
index_ct.as_mut(),
&mut match_ct,
cts,
clear,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_index_of_clear(
streams,
index_ct.as_mut(),
&mut match_ct,
cts,
clear,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
(index_ct, match_ct)
}
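
The width of the index ciphertext follows from the bit width of the largest possible index: num_inputs.ilog2() + 1 bits, packed into radix blocks of message_modulus.ilog2() bits each. A worked C++20 sketch of the same arithmetic (hypothetical helper):

#include <bit>
#include <cstdint>

// Blocks needed to hold any index in 0..num_inputs (num_inputs > 0).
uint32_t num_blocks_for_index(uint32_t num_inputs, uint32_t message_modulus) {
  uint32_t index_bits = std::bit_width(num_inputs);              // ilog2(n) + 1
  uint32_t bits_per_block = std::bit_width(message_modulus) - 1; // ilog2(m)
  return (index_bits + bits_per_block - 1) / bits_per_block;     // ceil division
}

// Example: num_inputs = 10 needs 4 index bits; with message_modulus = 4
// (2 bits per block) that is 2 blocks. num_inputs = 16 needs 5 bits -> 3 blocks.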
/// Returns the encrypted index of the clear `value` in the ciphertext slice
@@ -1932,82 +1999,4 @@ impl CudaServerKey {
};
self.unchecked_first_index_of(cts, value, streams)
}
fn compute_final_index_from_selectors(
&self,
selectors: &[CudaBooleanBlock],
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
let num_inputs = selectors.len();
let num_blocks_index =
(num_inputs.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize;
let mut index_ct: CudaUnsignedRadixCiphertext =
self.create_trivial_zero_radix(num_blocks_index, streams);
let trivial_bool =
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_compute_final_index_from_selectors(
streams,
index_ct.as_mut(),
&mut match_ct,
selectors,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_compute_final_index_from_selectors(
streams,
index_ct.as_mut(),
&mut match_ct,
selectors,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
(index_ct, match_ct)
}
}