fix(gpu): fix get oprf size on gpu

This commit is contained in:
Agnes Leroy
2025-09-26 16:57:12 +02:00
committed by Agnès Leroy
parent c5ad73865c
commit daf0e79e4a
2 changed files with 27 additions and 6 deletions

View File

@@ -6066,15 +6066,17 @@ template <typename Torus> struct int_grouped_oprf_memory {
// to a correction and all others to 0.
// All lwes in h_corrections have a mask equal to 0.
// Copy the prepared plaintext corrections to the GPU.
cuda_memcpy_async_to_gpu(this->plaintext_corrections->ptr, h_corrections,
num_blocks_to_process * lwe_size * sizeof(Torus),
streams.stream(0), streams.gpu_index(0));
cuda_memcpy_with_size_tracking_async_to_gpu(
this->plaintext_corrections->ptr, h_corrections,
num_blocks_to_process * lwe_size * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
// Copy the prepared LUT indexes to the GPU 0, before broadcast to all other
// GPUs.
cuda_memcpy_async_to_gpu(luts->get_lut_indexes(0, 0), this->h_lut_indexes,
num_blocks_to_process * sizeof(Torus),
streams.stream(0), streams.gpu_index(0));
cuda_memcpy_with_size_tracking_async_to_gpu(
luts->get_lut_indexes(0, 0), this->h_lut_indexes,
num_blocks_to_process * sizeof(Torus), streams.stream(0),
streams.gpu_index(0), allocate_gpu_memory);
auto active_streams = streams.active_gpu_subset(num_blocks_to_process);
luts->broadcast_lut(active_streams);