mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): fix race condition on expand when on multi-gpu
This commit is contained in:
@@ -112,15 +112,15 @@ template <typename Torus> struct zk_expand_mem {
|
||||
|
||||
// Hint for future readers: if message_modulus == 4 then
|
||||
// packed_messages_per_lwe becomes 2
|
||||
auto packed_messages_per_lwe = log2_int(params.message_modulus);
|
||||
auto num_packed_msgs = log2_int(params.message_modulus);
|
||||
|
||||
// Adjust indexes to permute the output and access the correct LUT
|
||||
auto h_indexes_in = static_cast<Torus *>(
|
||||
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
|
||||
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
|
||||
auto h_indexes_out = static_cast<Torus *>(
|
||||
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
|
||||
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
|
||||
auto h_lut_indexes = static_cast<Torus *>(
|
||||
malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
|
||||
malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
|
||||
auto h_body_id_per_compact_list =
|
||||
static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
|
||||
auto h_lwe_compact_input_indexes =
|
||||
@@ -138,6 +138,10 @@ template <typename Torus> struct zk_expand_mem {
|
||||
auto compact_list_id = 0;
|
||||
auto idx = 0;
|
||||
auto count = 0;
|
||||
// During flatenning, all num_lwes LWEs from all compact lists are stored
|
||||
// sequentially on a Torus array. h_lwe_compact_input_indexes stores the
|
||||
// index of the first LWE related to the compact list that contains the i-th
|
||||
// LWE
|
||||
for (int i = 0; i < num_lwes; i++) {
|
||||
h_lwe_compact_input_indexes[i] = idx;
|
||||
count++;
|
||||
@@ -148,6 +152,8 @@ template <typename Torus> struct zk_expand_mem {
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the index of the i-th LWE (within each compact list) related to
|
||||
// the k-th compact list.
|
||||
auto offset = 0;
|
||||
for (int k = 0; k < num_compact_lists; k++) {
|
||||
auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
|
||||
@@ -159,46 +165,75 @@ template <typename Torus> struct zk_expand_mem {
|
||||
offset += num_lwes_in_kth_compact_list;
|
||||
}
|
||||
|
||||
/*
|
||||
* Each LWE contains encrypted data in both carry and message spaces
|
||||
* that needs to be extracted.
|
||||
*
|
||||
* The loop processes each compact list (k) and for each LWE within that
|
||||
* list:
|
||||
* 1. Sets input indexes to read each LWE twice (for carry and message
|
||||
* extraction)
|
||||
* 2. Creates output indexes to properly reorder the results
|
||||
* 3. Selects appropriate LUT index based on whether boolean sanitization is
|
||||
* needed
|
||||
*
|
||||
* We want the output to have always first the content of the message part
|
||||
* and then the content of the carry part of each LWE.
|
||||
*
|
||||
* i.e. msg_extract(LWE_0), carry_extract(LWE_0), msg_extract(LWE_1),
|
||||
* carry_extract(LWE_1), ...
|
||||
*
|
||||
* Aiming that behavior, with 4 LWEs we would have:
|
||||
*
|
||||
* // Each LWE is processed twice
|
||||
* h_indexes_in = {0, 1, 2, 3, 0, 1, 2, 3}
|
||||
*
|
||||
* // First 4 use message LUT, last 4 use carry LUT
|
||||
* h_lut_indexes = {0, 0, 0, 0, 1, 1, 1, 1}
|
||||
*
|
||||
* // Reorders output so message and carry for each LWE appear together
|
||||
* h_indexes_out = {0, 2, 4, 6, 1, 3, 5, 7}
|
||||
*
|
||||
* If an LWE contains a boolean value, its LUT index is shifted by
|
||||
* num_packed_msgs to use the sanitization LUT (which ensures output is
|
||||
* exactly 0 or 1).
|
||||
*/
|
||||
offset = 0;
|
||||
for (int k = 0; k < num_compact_lists; k++) {
|
||||
auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
|
||||
for (int i = 0;
|
||||
i < packed_messages_per_lwe * num_lwes_in_kth_compact_list; i++) {
|
||||
Torus j = i % num_lwes_in_kth_compact_list;
|
||||
h_indexes_in[i + packed_messages_per_lwe * offset] = j + offset;
|
||||
h_indexes_out[i + packed_messages_per_lwe * offset] =
|
||||
packed_messages_per_lwe * (j + offset) +
|
||||
(i / num_lwes_in_kth_compact_list);
|
||||
auto num_lwes_in_kth = num_lwes_per_compact_list[k];
|
||||
for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
|
||||
auto lwe_index = i + num_packed_msgs * offset;
|
||||
auto lwe_index_in_list = i % num_lwes_in_kth;
|
||||
h_indexes_in[lwe_index] = lwe_index_in_list + offset;
|
||||
h_indexes_out[lwe_index] =
|
||||
num_packed_msgs * h_indexes_in[lwe_index] + i / num_lwes_in_kth;
|
||||
// If the input relates to a boolean, shift the LUT so the correct one
|
||||
// with sanitization is used
|
||||
h_lut_indexes[i + packed_messages_per_lwe * offset] =
|
||||
(is_boolean_array[h_indexes_out[i +
|
||||
packed_messages_per_lwe * offset]]
|
||||
? packed_messages_per_lwe
|
||||
: 0) +
|
||||
i / num_lwes_in_kth_compact_list;
|
||||
auto boolean_offset =
|
||||
is_boolean_array[h_indexes_out[lwe_index]] ? num_packed_msgs : 0;
|
||||
h_lut_indexes[lwe_index] = i / num_lwes_in_kth + boolean_offset;
|
||||
}
|
||||
offset += num_lwes_in_kth_compact_list;
|
||||
offset += num_lwes_in_kth;
|
||||
}
|
||||
|
||||
message_and_carry_extract_luts->set_lwe_indexes(
|
||||
streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
|
||||
auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);
|
||||
message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0);
|
||||
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
|
||||
num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
|
||||
allocate_gpu_memory);
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
lut_indexes, h_lut_indexes,
|
||||
packed_messages_per_lwe * num_lwes * sizeof(Torus), streams[0],
|
||||
gpu_indexes[0], allocate_gpu_memory);
|
||||
lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
|
||||
streams[0], gpu_indexes[0], allocate_gpu_memory);
|
||||
cuda_memcpy_with_size_tracking_async_to_gpu(
|
||||
d_body_id_per_compact_list, h_body_id_per_compact_list,
|
||||
num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
|
||||
allocate_gpu_memory);
|
||||
|
||||
message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0);
|
||||
|
||||
// The expanded LWEs will always be on the casting key format
|
||||
tmp_expanded_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),
|
||||
|
||||
@@ -861,7 +861,10 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::prelude::*;
|
||||
use crate::shortint::parameters::*;
|
||||
use crate::{CompressedServerKey, set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32};
|
||||
use crate::{set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32};
|
||||
|
||||
#[cfg(all(feature = "zk-pok", feature = "gpu"))]
|
||||
use crate::CompressedServerKey;
|
||||
|
||||
#[test]
|
||||
fn test_compact_list() {
|
||||
@@ -1100,7 +1103,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(all(feature = "zk-pok", feature = "gpu"))]
|
||||
#[test]
|
||||
fn test_gpu_proven_compact_list() {
|
||||
|
||||
Reference in New Issue
Block a user