fix(gpu): fix race condition on expand when on multi-gpu

This commit is contained in:
Pedro Alves
2025-05-29 15:42:33 -03:00
committed by Agnès Leroy
parent f511bdc279
commit 16ff092ce4
2 changed files with 62 additions and 25 deletions

View File

@@ -112,15 +112,15 @@ template <typename Torus> struct zk_expand_mem {
// Hint for future readers: if message_modulus == 4 then
// packed_messages_per_lwe becomes 2
auto packed_messages_per_lwe = log2_int(params.message_modulus);
auto num_packed_msgs = log2_int(params.message_modulus);
// Adjust indexes to permute the output and access the correct LUT
auto h_indexes_in = static_cast<Torus *>(
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
auto h_indexes_out = static_cast<Torus *>(
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
auto h_lut_indexes = static_cast<Torus *>(
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
auto h_body_id_per_compact_list =
static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
auto h_lwe_compact_input_indexes =
@@ -138,6 +138,10 @@ template <typename Torus> struct zk_expand_mem {
auto compact_list_id = 0;
auto idx = 0;
auto count = 0;
// During flatenning, all num_lwes LWEs from all compact lists are stored
// sequentially on a Torus array. h_lwe_compact_input_indexes stores the
// index of the first LWE related to the compact list that contains the i-th
// LWE
for (int i = 0; i < num_lwes; i++) {
h_lwe_compact_input_indexes[i] = idx;
count++;
@@ -148,6 +152,8 @@ template <typename Torus> struct zk_expand_mem {
}
}
// Stores the index of the i-th LWE (within each compact list) related to
// the k-th compact list.
auto offset = 0;
for (int k = 0; k < num_compact_lists; k++) {
auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
@@ -159,46 +165,75 @@ template <typename Torus> struct zk_expand_mem {
offset += num_lwes_in_kth_compact_list;
}
/*
* Each LWE contains encrypted data in both carry and message spaces
* that needs to be extracted.
*
* The loop processes each compact list (k) and for each LWE within that
* list:
* 1. Sets input indexes to read each LWE twice (for carry and message
* extraction)
* 2. Creates output indexes to properly reorder the results
* 3. Selects appropriate LUT index based on whether boolean sanitization is
* needed
*
* We want the output to have always first the content of the message part
* and then the content of the carry part of each LWE.
*
* i.e. msg_extract(LWE_0), carry_extract(LWE_0), msg_extract(LWE_1),
* carry_extract(LWE_1), ...
*
* Aiming that behavior, with 4 LWEs we would have:
*
* // Each LWE is processed twice
* h_indexes_in = {0, 1, 2, 3, 0, 1, 2, 3}
*
* // First 4 use message LUT, last 4 use carry LUT
* h_lut_indexes = {0, 0, 0, 0, 1, 1, 1, 1}
*
* // Reorders output so message and carry for each LWE appear together
* h_indexes_out = {0, 2, 4, 6, 1, 3, 5, 7}
*
* If an LWE contains a boolean value, its LUT index is shifted by
* num_packed_msgs to use the sanitization LUT (which ensures output is
* exactly 0 or 1).
*/
offset = 0;
for (int k = 0; k < num_compact_lists; k++) {
auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
for (int i = 0;
i < packed_messages_per_lwe * num_lwes_in_kth_compact_list; i++) {
Torus j = i % num_lwes_in_kth_compact_list;
h_indexes_in[i + packed_messages_per_lwe * offset] = j + offset;
h_indexes_out[i + packed_messages_per_lwe * offset] =
packed_messages_per_lwe * (j + offset) +
(i / num_lwes_in_kth_compact_list);
auto num_lwes_in_kth = num_lwes_per_compact_list[k];
for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
auto lwe_index = i + num_packed_msgs * offset;
auto lwe_index_in_list = i % num_lwes_in_kth;
h_indexes_in[lwe_index] = lwe_index_in_list + offset;
h_indexes_out[lwe_index] =
num_packed_msgs * h_indexes_in[lwe_index] + i / num_lwes_in_kth;
// If the input relates to a boolean, shift the LUT so the correct one
// with sanitization is used
h_lut_indexes[i + packed_messages_per_lwe * offset] =
(is_boolean_array[h_indexes_out[i +
packed_messages_per_lwe * offset]]
? packed_messages_per_lwe
: 0) +
i / num_lwes_in_kth_compact_list;
auto boolean_offset =
is_boolean_array[h_indexes_out[lwe_index]] ? num_packed_msgs : 0;
h_lut_indexes[lwe_index] = i / num_lwes_in_kth + boolean_offset;
}
offset += num_lwes_in_kth_compact_list;
offset += num_lwes_in_kth;
}
message_and_carry_extract_luts->set_lwe_indexes(
streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);
message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0);
cuda_memcpy_with_size_tracking_async_to_gpu(
d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
allocate_gpu_memory);
cuda_memcpy_with_size_tracking_async_to_gpu(
lut_indexes, h_lut_indexes,
packed_messages_per_lwe * num_lwes * sizeof(Torus), streams[0],
gpu_indexes[0], allocate_gpu_memory);
lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
streams[0], gpu_indexes[0], allocate_gpu_memory);
cuda_memcpy_with_size_tracking_async_to_gpu(
d_body_id_per_compact_list, h_body_id_per_compact_list,
num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
allocate_gpu_memory);
message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0);
// The expanded LWEs will always be on the casting key format
tmp_expanded_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),

View File

@@ -861,7 +861,10 @@ mod tests {
use super::*;
use crate::prelude::*;
use crate::shortint::parameters::*;
use crate::{CompressedServerKey, set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32};
use crate::{set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32};
#[cfg(all(feature = "zk-pok", feature = "gpu"))]
use crate::CompressedServerKey;
#[test]
fn test_compact_list() {
@@ -1100,7 +1103,6 @@ mod tests {
}
}
#[cfg(all(feature = "zk-pok", feature = "gpu"))]
#[test]
fn test_gpu_proven_compact_list() {