diff --git a/backends/tfhe-cuda-backend/cuda/include/zk/zk_utilities.h b/backends/tfhe-cuda-backend/cuda/include/zk/zk_utilities.h index cc57408f9..430ab4762 100644 --- a/backends/tfhe-cuda-backend/cuda/include/zk/zk_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/zk/zk_utilities.h @@ -112,15 +112,15 @@ template struct zk_expand_mem { // Hint for future readers: if message_modulus == 4 then // packed_messages_per_lwe becomes 2 - auto packed_messages_per_lwe = log2_int(params.message_modulus); + auto num_packed_msgs = log2_int(params.message_modulus); // Adjust indexes to permute the output and access the correct LUT auto h_indexes_in = static_cast( - malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus))); + malloc(num_packed_msgs * num_lwes * sizeof(Torus))); auto h_indexes_out = static_cast( - malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus))); + malloc(num_packed_msgs * num_lwes * sizeof(Torus))); auto h_lut_indexes = static_cast( - malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus))); + malloc(num_packed_msgs * num_lwes * sizeof(Torus))); auto h_body_id_per_compact_list = static_cast(malloc(num_lwes * sizeof(uint32_t))); auto h_lwe_compact_input_indexes = @@ -138,6 +138,10 @@ template struct zk_expand_mem { auto compact_list_id = 0; auto idx = 0; auto count = 0; + // During flatenning, all num_lwes LWEs from all compact lists are stored + // sequentially on a Torus array. h_lwe_compact_input_indexes stores the + // index of the first LWE related to the compact list that contains the i-th + // LWE for (int i = 0; i < num_lwes; i++) { h_lwe_compact_input_indexes[i] = idx; count++; @@ -148,6 +152,8 @@ template struct zk_expand_mem { } } + // Stores the index of the i-th LWE (within each compact list) related to + // the k-th compact list. auto offset = 0; for (int k = 0; k < num_compact_lists; k++) { auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k]; @@ -159,46 +165,75 @@ template struct zk_expand_mem { offset += num_lwes_in_kth_compact_list; } + /* + * Each LWE contains encrypted data in both carry and message spaces + * that needs to be extracted. + * + * The loop processes each compact list (k) and for each LWE within that + * list: + * 1. Sets input indexes to read each LWE twice (for carry and message + * extraction) + * 2. Creates output indexes to properly reorder the results + * 3. Selects appropriate LUT index based on whether boolean sanitization is + * needed + * + * We want the output to have always first the content of the message part + * and then the content of the carry part of each LWE. + * + * i.e. msg_extract(LWE_0), carry_extract(LWE_0), msg_extract(LWE_1), + * carry_extract(LWE_1), ... + * + * Aiming that behavior, with 4 LWEs we would have: + * + * // Each LWE is processed twice + * h_indexes_in = {0, 1, 2, 3, 0, 1, 2, 3} + * + * // First 4 use message LUT, last 4 use carry LUT + * h_lut_indexes = {0, 0, 0, 0, 1, 1, 1, 1} + * + * // Reorders output so message and carry for each LWE appear together + * h_indexes_out = {0, 2, 4, 6, 1, 3, 5, 7} + * + * If an LWE contains a boolean value, its LUT index is shifted by + * num_packed_msgs to use the sanitization LUT (which ensures output is + * exactly 0 or 1). + */ offset = 0; for (int k = 0; k < num_compact_lists; k++) { - auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k]; - for (int i = 0; - i < packed_messages_per_lwe * num_lwes_in_kth_compact_list; i++) { - Torus j = i % num_lwes_in_kth_compact_list; - h_indexes_in[i + packed_messages_per_lwe * offset] = j + offset; - h_indexes_out[i + packed_messages_per_lwe * offset] = - packed_messages_per_lwe * (j + offset) + - (i / num_lwes_in_kth_compact_list); + auto num_lwes_in_kth = num_lwes_per_compact_list[k]; + for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) { + auto lwe_index = i + num_packed_msgs * offset; + auto lwe_index_in_list = i % num_lwes_in_kth; + h_indexes_in[lwe_index] = lwe_index_in_list + offset; + h_indexes_out[lwe_index] = + num_packed_msgs * h_indexes_in[lwe_index] + i / num_lwes_in_kth; // If the input relates to a boolean, shift the LUT so the correct one // with sanitization is used - h_lut_indexes[i + packed_messages_per_lwe * offset] = - (is_boolean_array[h_indexes_out[i + - packed_messages_per_lwe * offset]] - ? packed_messages_per_lwe - : 0) + - i / num_lwes_in_kth_compact_list; + auto boolean_offset = + is_boolean_array[h_indexes_out[lwe_index]] ? num_packed_msgs : 0; + h_lut_indexes[lwe_index] = i / num_lwes_in_kth + boolean_offset; } - offset += num_lwes_in_kth_compact_list; + offset += num_lwes_in_kth; } message_and_carry_extract_luts->set_lwe_indexes( streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out); auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0); - message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0); cuda_memcpy_with_size_tracking_async_to_gpu( d_lwe_compact_input_indexes, h_lwe_compact_input_indexes, num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0], allocate_gpu_memory); cuda_memcpy_with_size_tracking_async_to_gpu( - lut_indexes, h_lut_indexes, - packed_messages_per_lwe * num_lwes * sizeof(Torus), streams[0], - gpu_indexes[0], allocate_gpu_memory); + lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus), + streams[0], gpu_indexes[0], allocate_gpu_memory); cuda_memcpy_with_size_tracking_async_to_gpu( d_body_id_per_compact_list, h_body_id_per_compact_list, num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0], allocate_gpu_memory); + message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0); + // The expanded LWEs will always be on the casting key format tmp_expanded_lwes = (Torus *)cuda_malloc_with_size_tracking_async( num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus), diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index d7066c5e7..fd43e0c28 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -861,7 +861,10 @@ mod tests { use super::*; use crate::prelude::*; use crate::shortint::parameters::*; - use crate::{CompressedServerKey, set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32}; + use crate::{set_server_key, FheBool, FheInt64, FheUint16, FheUint2, FheUint32}; + + #[cfg(all(feature = "zk-pok", feature = "gpu"))] + use crate::CompressedServerKey; #[test] fn test_compact_list() { @@ -1100,7 +1103,6 @@ mod tests { } } - #[cfg(all(feature = "zk-pok", feature = "gpu"))] #[test] fn test_gpu_proven_compact_list() {