chore(gpu): use smart pointers in radix lut

Author: Andrei Stoian
Date: 2025-10-03 14:59:13 +02:00
parent 3be7aae8f3
commit 705f383ad2

8 changed files with 71 additions and 38 deletions

File 1 of 8

@@ -5,6 +5,7 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cuda_runtime.h>
+#include <memory>
 extern "C" {
@@ -140,4 +141,34 @@ template <typename Torus>
 void cuda_set_value_async(cudaStream_t stream, uint32_t gpu_index,
                           Torus *d_array, Torus value, Torus n);
+
+template <class T> struct malloc_with_size_tracking_async_deleter {
+private:
+  cudaStream_t _stream;
+  uint32_t _gpu_index;
+  uint64_t &_size_tracker;
+  bool _allocate_gpu_memory;
+
+public:
+  malloc_with_size_tracking_async_deleter(cudaStream_t stream,
+                                          uint32_t gpu_index,
+                                          uint64_t &size_tracker,
+                                          bool allocate_gpu_memory)
+      : _stream(stream), _gpu_index(gpu_index), _size_tracker(size_tracker),
+        _allocate_gpu_memory(allocate_gpu_memory) {}
+
+  void operator()(T *ptr) {
+    cuda_drop_with_size_tracking_async(ptr, _stream, _gpu_index,
+                                       _allocate_gpu_memory);
+  }
+};
+
+template <class T>
+std::shared_ptr<T> cuda_make_shared_with_size_tracking_async(
+    uint64_t size, cudaStream_t stream, uint32_t gpu_index,
+    uint64_t &size_tracker, bool allocate_gpu_memory) {
+  return std::shared_ptr<T>(
+      (T *)cuda_malloc_with_size_tracking_async(size, stream, gpu_index,
+                                                size_tracker,
+                                                allocate_gpu_memory),
+      malloc_with_size_tracking_async_deleter<T>(
+          stream, gpu_index, size_tracker, allocate_gpu_memory));
+}
 #endif
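
For context, the new helper pairs cuda_malloc_with_size_tracking_async with a deleter that captures the stream, GPU index and size tracker at allocation time, so the matching cuda_drop_with_size_tracking_async is enqueued automatically once the last owner releases the pointer. A minimal usage sketch (not part of the commit; the function name allocate_indexes_example, its parameters, and uint64_t standing in for Torus are illustrative assumptions):

// Sketch only: assumes the header above is included.
void allocate_indexes_example(cudaStream_t stream, uint32_t gpu_index,
                              uint32_t num_blocks) {
  uint64_t size_tracker = 0;
  // Device buffer whose lifetime is tied to the shared_ptr.
  std::shared_ptr<uint64_t> d_indexes =
      cuda_make_shared_with_size_tracking_async<uint64_t>(
          num_blocks * sizeof(uint64_t), stream, gpu_index, size_tracker,
          /*allocate_gpu_memory=*/true);
  // Kernels and C-style APIs still take the raw pointer.
  uint64_t *raw = d_indexes.get();
  (void)raw;
} // last owner released: the deleter enqueues the async drop on `stream`

One consequence of this design is that the deleter frees on the stream recorded at allocation time, so such a shared_ptr should not outlive the stream it was allocated on.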

File 2 of 8

@@ -306,7 +306,7 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
   // All tmp lwe arrays and index arrays for lwe contain the total
   // amount of blocks to be computed on, there is no split between GPUs
   // for the moment
-  Torus *lwe_indexes_in = nullptr;
+  std::shared_ptr<Torus> lwe_indexes_in = nullptr;
   Torus *lwe_indexes_out = nullptr;
   Torus *h_lwe_indexes_in = nullptr;
   Torus *h_lwe_indexes_out = nullptr;
@@ -315,9 +315,8 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
   // lwe_trivial_indexes is the intermediary index we need in case
   // lwe_indexes_in != lwe_indexes_out
   Torus *lwe_trivial_indexes = nullptr;
-  // buffer to store packed message bits of a radix ciphertext
-  CudaRadixCiphertextFFI *tmp_lwe_before_ks = nullptr;
+  std::shared_ptr<CudaRadixCiphertextFFI> tmp_lwe_before_ks;
   /// For multi GPU execution we create vectors of pointers for inputs and
   /// outputs
@@ -384,10 +383,10 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
       buffer.push_back(gpu_pbs_buffer);
     }
-    tmp_lwe_before_ks = new CudaRadixCiphertextFFI;
+    tmp_lwe_before_ks = std::make_shared<CudaRadixCiphertextFFI>();
     create_zero_radix_ciphertext_async<Torus>(
         active_streams.stream(0), active_streams.gpu_index(0),
-        tmp_lwe_before_ks, num_radix_blocks, input_big_lwe_dimension,
+        tmp_lwe_before_ks.get(), num_radix_blocks, input_big_lwe_dimension,
         size_tracker, allocate_gpu_memory);
   }
@@ -454,7 +453,7 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
                          uint64_t &size_tracker) {
     // lwe_(input/output)_indexes are initialized to range(num_radix_blocks)
     // by default
-    lwe_indexes_in = (Torus *)cuda_malloc_with_size_tracking_async(
+    lwe_indexes_in = cuda_make_shared_with_size_tracking_async<Torus>(
         num_radix_blocks * sizeof(Torus), active_streams.stream(0),
         active_streams.gpu_index(0), size_tracker, allocate_gpu_memory);
     lwe_indexes_out = (Torus *)cuda_malloc_with_size_tracking_async(
@@ -471,9 +470,9 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
       h_lwe_indexes_in[i] = i;
     cuda_memcpy_with_size_tracking_async_to_gpu(
-        lwe_indexes_in, h_lwe_indexes_in, num_radix_blocks * sizeof(Torus),
-        active_streams.stream(0), active_streams.gpu_index(0),
-        allocate_gpu_memory);
+        lwe_indexes_in.get(), h_lwe_indexes_in,
+        num_radix_blocks * sizeof(Torus), active_streams.stream(0),
+        active_streams.gpu_index(0), allocate_gpu_memory);
     cuda_memcpy_with_size_tracking_async_to_gpu(
         lwe_indexes_out, h_lwe_indexes_in, num_radix_blocks * sizeof(Torus),
         active_streams.stream(0), active_streams.gpu_index(0),
@@ -660,8 +659,8 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
     memcpy(h_lwe_indexes_out, h_indexes_out, num_blocks * sizeof(Torus));
     cuda_memcpy_with_size_tracking_async_to_gpu(
-        lwe_indexes_in, h_lwe_indexes_in, num_blocks * sizeof(Torus), stream,
-        gpu_index, gpu_memory_allocated);
+        lwe_indexes_in.get(), h_lwe_indexes_in, num_blocks * sizeof(Torus),
+        stream, gpu_index, gpu_memory_allocated);
     cuda_memcpy_with_size_tracking_async_to_gpu(
         lwe_indexes_out, h_lwe_indexes_out, num_blocks * sizeof(Torus), stream,
         gpu_index, gpu_memory_allocated);
@@ -766,9 +765,10 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
           lut_indexes_vec[i], active_streams.stream(i),
           active_streams.gpu_index(i), gpu_memory_allocated);
     }
-    cuda_drop_with_size_tracking_async(lwe_indexes_in, active_streams.stream(0),
-                                       active_streams.gpu_index(0),
-                                       gpu_memory_allocated);
+    lwe_indexes_in.reset();
+    /*cuda_drop_with_size_tracking_async(lwe_indexes_in,
+       active_streams.stream(0), active_streams.gpu_index(0),
+       gpu_memory_allocated);*/
     cuda_drop_with_size_tracking_async(
         lwe_indexes_out, active_streams.stream(0), active_streams.gpu_index(0),
         gpu_memory_allocated);
@@ -791,9 +791,11 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
     }
     if (!mem_reuse) {
-      release_radix_ciphertext_async(active_streams.stream(0),
-                                     active_streams.gpu_index(0),
-                                     tmp_lwe_before_ks, gpu_memory_allocated);
+      GPU_ASSERT(tmp_lwe_before_ks.use_count() == 1,
+                 "This int_radix_lut is still sharing memory with another");
+      release_radix_ciphertext_async(
+          active_streams.stream(0), active_streams.gpu_index(0),
+          tmp_lwe_before_ks.get(), gpu_memory_allocated);
       for (int i = 0; i < buffer.size(); i++) {
         switch (params.pbs_type) {
         case MULTI_BIT:
@@ -812,7 +814,7 @@ template <typename Torus, typename OutputTorus> struct int_radix_lut_generic {
       cuda_synchronize_stream(active_streams.stream(i),
                               active_streams.gpu_index(i));
     }
-    delete tmp_lwe_before_ks;
+    tmp_lwe_before_ks.reset();
     buffer.clear();
     if (gpu_memory_allocated) {
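
The release path now replaces the manual cuda_drop_with_size_tracking_async/delete pair with reset(), and guards the device free with use_count() so a LUT whose tmp_lwe_before_ks is still shared through mem_reuse cannot free memory another LUT is using. A small host-only sketch of the ownership rule being enforced (illustrative; plain assert stands in for GPU_ASSERT, and int for the FFI struct):

#include <cassert>
#include <memory>

void ownership_rule_example() {
  auto owner = std::make_shared<int>(0); // stands in for tmp_lwe_before_ks
  auto borrower = owner;                 // a second LUT built with mem_reuse
  assert(owner.use_count() == 2);        // freeing now would dangle `borrower`
  borrower.reset();                      // the borrowing LUT releases first
  assert(owner.use_count() == 1);        // exclusive: safe to free device data
  owner.reset();                         // replaces `delete tmp_lwe_before_ks`
}

With raw pointers this invariant existed only by convention; the shared_ptr makes it checkable at the point of release.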

File 3 of 8

@@ -31,8 +31,8 @@ __host__ void zero_out_if(CudaStreams streams,
   // second operand is not an array
   auto tmp_lwe_array_input = mem_ptr->tmp;
   host_pack_bivariate_blocks_with_single_block<Torus>(
-      streams, tmp_lwe_array_input, predicate->lwe_indexes_in, lwe_array_input,
-      lwe_condition, predicate->lwe_indexes_in, params.message_modulus,
+      streams, tmp_lwe_array_input, predicate->lwe_indexes_in.get(), lwe_array_input,
+      lwe_condition, predicate->lwe_indexes_in.get(), params.message_modulus,
       num_radix_blocks);
   integer_radix_apply_univariate_lookup_table_kb<Torus>(

File 4 of 8

@@ -344,7 +344,7 @@ host_integer_decompress(CudaStreams streams,
   execute_pbs_async<Torus, Torus>(
       active_streams, (Torus *)d_lwe_array_out->ptr, lut->lwe_indexes_out,
       lut->lut_vec, lut->lut_indexes_vec, extracted_lwe,
-      lut->lwe_indexes_in, d_bsks, lut->buffer,
+      lut->lwe_indexes_in.get(), d_bsks, lut->buffer,
       encryption_params.glwe_dimension,
       compression_params.small_lwe_dimension,
       encryption_params.polynomial_size, encryption_params.pbs_base_log,
@@ -365,7 +365,7 @@ host_integer_decompress(CudaStreams streams,
     /// With multiple GPUs we push to the vectors on each GPU then when we
     /// gather data to GPU 0 we can copy back to the original indexing
     multi_gpu_scatter_lwe_async<Torus>(
-        active_streams, lwe_array_in_vec, extracted_lwe, lut->lwe_indexes_in,
+        active_streams, lwe_array_in_vec, extracted_lwe, lut->lwe_indexes_in.get(),
         lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
         lut->active_streams.count(), num_blocks_to_decompress,
         compression_params.small_lwe_dimension + 1);

File 5 of 8

@@ -546,7 +546,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
@@ -568,7 +568,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
PUSH_RANGE("scatter")
multi_gpu_scatter_lwe_async<Torus>(
active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
big_lwe_dimension + 1);
POP_RANGE()
@@ -648,7 +648,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
@@ -670,7 +670,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
PUSH_RANGE("scatter")
multi_gpu_scatter_lwe_async<Torus>(
active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
big_lwe_dimension + 1);
POP_RANGE()
@@ -748,10 +748,10 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
   uint32_t lut_stride = 0;
   // Left message is shifted
-  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks;
+  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks.get();
   host_pack_bivariate_blocks<Torus>(
       streams, lwe_array_pbs_in, lut->lwe_trivial_indexes, lwe_array_1,
-      lwe_array_2, lut->lwe_indexes_in, shift, num_radix_blocks,
+      lwe_array_2, lut->lwe_indexes_in.get(), shift, num_radix_blocks,
       params.message_modulus, params.carry_modulus);
   check_cuda_error(cudaGetLastError());
@@ -766,7 +766,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
@@ -785,7 +785,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
PUSH_RANGE("scatter")
multi_gpu_scatter_lwe_async<Torus>(
active_streams, lwe_array_in_vec, (Torus *)lwe_array_pbs_in->ptr,
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
big_lwe_dimension + 1);
POP_RANGE()
@@ -2334,7 +2334,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
   /// For multi GPU execution we create vectors of pointers for inputs and
   /// outputs
-  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks;
+  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks.get();
   std::vector<InputTorus *> lwe_array_in_vec = lut->lwe_array_in_vec;
   std::vector<InputTorus *> lwe_after_ks_vec = lut->lwe_after_ks_vec;
   std::vector<__uint128_t *> lwe_after_pbs_vec = lut->lwe_after_pbs_vec;
@@ -2353,7 +2353,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<InputTorus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
+        (InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in.get(), ksks,
         lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
         ks_level, lwe_array_out->num_radix_blocks);
@@ -2377,7 +2377,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
     /// gather data to GPU 0 we can copy back to the original indexing
     multi_gpu_scatter_lwe_async<InputTorus>(
         active_streams, lwe_array_in_vec, (InputTorus *)lwe_array_pbs_in->ptr,
-        lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
+        lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_scatter_vec, lut->active_streams.count(),
         lwe_array_out->num_radix_blocks, lut->input_big_lwe_dimension + 1);

File 6 of 8

@@ -375,7 +375,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
   while (needs_processing) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
-    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
+    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in.get();
     auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
     calculate_chunks<Torus>
         <<<number_of_blocks_2d, number_of_threads, 0, streams.stream(0)>>>(
@@ -433,7 +433,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
   if (mem_ptr->reduce_degrees_for_single_carry_propagation) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
-    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
+    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in.get();
     auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
     prepare_final_pbs_indexes<Torus>
         <<<1, 2 * num_radix_blocks, 0, streams.stream(0)>>>(

File 7 of 8

@@ -34,7 +34,7 @@ void host_integer_grouped_oprf(CudaStreams streams,
   execute_pbs_async<Torus, Torus>(
       streams.get_ith(0), (Torus *)(radix_lwe_out->ptr), lut->lwe_indexes_out,
       lut->lut_vec, lut->lut_indexes_vec,
-      const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
+      const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in.get(), bsks,
       lut->buffer, mem_ptr->params.glwe_dimension,
       mem_ptr->params.small_lwe_dimension, mem_ptr->params.polynomial_size,
       mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
@@ -50,7 +50,7 @@ void host_integer_grouped_oprf(CudaStreams streams,
PUSH_RANGE("scatter")
multi_gpu_scatter_lwe_async<Torus>(
active_streams, lwe_array_in_vec, seeded_lwe_input, lut->lwe_indexes_in,
active_streams, lwe_array_in_vec, seeded_lwe_input, lut->lwe_indexes_in.get(),
lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
active_streams.count(), num_blocks_to_process,
mem_ptr->params.small_lwe_dimension + 1);

File 8 of 8

@@ -150,7 +150,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
     // control_bit|b|a
     host_pack_bivariate_blocks<Torus>(
         streams, mux_inputs, mux_lut->lwe_indexes_out, rotated_input,
-        input_bits_a, mux_lut->lwe_indexes_in, 2, total_nb_bits,
+        input_bits_a, mux_lut->lwe_indexes_in.get(), 2, total_nb_bits,
         mem->params.message_modulus, mem->params.carry_modulus);
     // The shift bit is already properly aligned/positioned