mirror of https://github.com/zama-ai/tfhe-rs.git
fix(gpu): fix get oprf size on gpu
@@ -6066,15 +6066,17 @@ template <typename Torus> struct int_grouped_oprf_memory {
        // to a correction and all others to 0.
        // All lwes in h_corrections have a mask equal to 0.
        // Copy the prepared plaintext corrections to the GPU.
-       cuda_memcpy_async_to_gpu(this->plaintext_corrections->ptr, h_corrections,
-                                num_blocks_to_process * lwe_size * sizeof(Torus),
-                                streams.stream(0), streams.gpu_index(0));
+       cuda_memcpy_with_size_tracking_async_to_gpu(
+           this->plaintext_corrections->ptr, h_corrections,
+           num_blocks_to_process * lwe_size * sizeof(Torus), streams.stream(0),
+           streams.gpu_index(0), allocate_gpu_memory);

        // Copy the prepared LUT indexes to the GPU 0, before broadcast to all other
        // GPUs.
-       cuda_memcpy_async_to_gpu(luts->get_lut_indexes(0, 0), this->h_lut_indexes,
-                                num_blocks_to_process * sizeof(Torus),
-                                streams.stream(0), streams.gpu_index(0));
+       cuda_memcpy_with_size_tracking_async_to_gpu(
+           luts->get_lut_indexes(0, 0), this->h_lut_indexes,
+           num_blocks_to_process * sizeof(Torus), streams.stream(0),
+           streams.gpu_index(0), allocate_gpu_memory);
        auto active_streams = streams.active_gpu_subset(num_blocks_to_process);
        luts->broadcast_lut(active_streams);
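Both host-to-device copies now go through the size-tracking variant and forward the allocate_gpu_memory flag instead of issuing the transfer unconditionally. Presumably this is what fixes the OPRF size computation on GPU: when int_grouped_oprf_memory is constructed only to measure its GPU footprint (allocate_gpu_memory == false), the plain cuda_memcpy_async_to_gpu would still copy into buffers that were never allocated, while the tracking variant can skip the copy. Below is a minimal sketch of that pattern, under the assumption that the flag simply gates the copy; the name and body are illustrative, not the actual tfhe-rs implementation:

    // Sketch only: hypothetical stand-in for the size-tracking copy helper.
    #include <cstdint>
    #include <cuda_runtime.h>

    inline void memcpy_with_size_tracking_async_to_gpu_sketch(
        void *dest, void const *src, uint64_t size, cudaStream_t stream,
        uint32_t gpu_index, bool gpu_memory_allocated) {
      // Dry run: the destination was never allocated, so the byte count
      // (accounted for at allocation time) is all that matters and the
      // copy itself is skipped.
      if (!gpu_memory_allocated || size == 0)
        return;
      cudaSetDevice(static_cast<int>(gpu_index));
      cudaMemcpyAsync(dest, src, size, cudaMemcpyHostToDevice, stream);
    }

With a gate like this, both the real execution path and the size-estimation path can share the code above unchanged, which is consistent with the extra allocate_gpu_memory argument threaded through the diff.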