#ifndef ZK_UTILITIES_H
#define ZK_UTILITIES_H

#include "../integer/integer_utilities.h"
#include "integer/integer.cuh"
#include <cstdint>

////////////////////////////////////
// Helper structures used in expand
template <typename Torus> struct lwe_mask {
  Torus *mask;

  lwe_mask(Torus *mask) : mask{mask} {}
};

template <typename Torus> struct compact_lwe_body {
  Torus *body;
  uint64_t monomial_degree;

  /* Body id is the index of the body in the compact ciphertext list.
   * It's used to compute the rotation. */
  compact_lwe_body(Torus *body, const uint64_t body_id)
      : body{body}, monomial_degree{body_id} {}
};

template <typename Torus> struct compact_lwe_list {
  Torus *ptr;
  uint32_t lwe_dimension;
  uint32_t total_num_lwes;

  compact_lwe_list(Torus *ptr, uint32_t lwe_dimension, uint32_t total_num_lwes)
      : ptr{ptr}, lwe_dimension{lwe_dimension},
        total_num_lwes{total_num_lwes} {}

  lwe_mask<Torus> get_mask() { return lwe_mask<Torus>(ptr); }

  // Returns the index-th body
  compact_lwe_body<Torus> get_body(uint32_t index) {
    if (index >= total_num_lwes) {
      PANIC("index out of range in compact_lwe_list::get_body");
    }
    return compact_lwe_body<Torus>(&ptr[lwe_dimension + index],
                                   uint64_t(index));
  }
};

template <typename Torus> struct flattened_compact_lwe_lists {
  Torus *d_ptr;
  Torus **d_ptr_to_compact_list;
  const uint32_t *h_num_lwes_per_compact_list;
  uint32_t num_compact_lists;
  uint32_t lwe_dimension;
  uint32_t total_num_lwes;

  flattened_compact_lwe_lists(Torus *d_ptr,
                              const uint32_t *h_num_lwes_per_compact_list,
                              uint32_t num_compact_lists,
                              uint32_t lwe_dimension)
      : d_ptr(d_ptr), h_num_lwes_per_compact_list(h_num_lwes_per_compact_list),
        num_compact_lists(num_compact_lists), lwe_dimension(lwe_dimension) {
    d_ptr_to_compact_list =
        static_cast<Torus **>(malloc(num_compact_lists * sizeof(Torus **)));
    total_num_lwes = 0;
    auto curr_list = d_ptr;
    for (auto i = 0; i < num_compact_lists; ++i) {
      total_num_lwes += h_num_lwes_per_compact_list[i];
      d_ptr_to_compact_list[i] = curr_list;
      curr_list += lwe_dimension + h_num_lwes_per_compact_list[i];
    }
  }

  compact_lwe_list<Torus>
  get_device_compact_list(uint32_t compact_list_index) {
    if (compact_list_index >= num_compact_lists) {
      PANIC("index out of range in flattened_compact_lwe_lists::get");
    }
    return compact_lwe_list<Torus>(
        d_ptr_to_compact_list[compact_list_index], lwe_dimension,
        h_num_lwes_per_compact_list[compact_list_index]);
  }
};

/*
 * An expand_job tells the expand kernel exactly which input mask and body to
 * use and what rotation to apply
 */
template <typename Torus> struct expand_job {
  lwe_mask<Torus> mask_to_use;
  compact_lwe_body<Torus> body_to_use;

  expand_job(lwe_mask<Torus> mask_to_use, compact_lwe_body<Torus> body_to_use)
      : mask_to_use{mask_to_use}, body_to_use{body_to_use} {}
};
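
// Illustrative sketch (not part of the header's API; the pointer names and
// sizes below are made-up assumptions): a compact list stores one shared mask
// of lwe_dimension elements followed by one body per LWE. With a single list
// of 3 LWEs and lwe_dimension = 4, the flattened device buffer looks like
//
//   [ a_0 a_1 a_2 a_3 | b_0 b_1 b_2 ]
//
// and the accessors map into it as follows:
//
//   uint32_t h_counts[1] = {3};
//   flattened_compact_lwe_lists<uint64_t> lists(d_ptr, h_counts, 1, 4);
//   auto list = lists.get_device_compact_list(0);
//   auto mask = list.get_mask();  // points at a_0
//   auto body = list.get_body(1); // points at b_1, monomial_degree == 1
//
// An expand_job pairing this mask with this body asks the expand kernel to
// reconstruct the second LWE of the list, using monomial_degree as the
// rotation to apply.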

////////////////////////////////////
template <typename Torus> struct zk_expand_mem {
  int_radix_params computing_params;
  int_radix_params casting_params;
  KS_TYPE casting_key_type;
  uint32_t num_lwes;
  uint32_t num_compact_lists;

  int_radix_lut<Torus> *message_and_carry_extract_luts;

  Torus *tmp_expanded_lwes;
  Torus *tmp_ksed_small_to_big_expanded_lwes;

  bool gpu_memory_allocated;

  uint32_t *num_lwes_per_compact_list;
  expand_job<Torus> *d_expand_jobs;
  expand_job<Torus> *h_expand_jobs;

  zk_expand_mem(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                uint32_t gpu_count, int_radix_params computing_params,
                int_radix_params casting_params, KS_TYPE casting_key_type,
                const uint32_t *num_lwes_per_compact_list,
                const bool *is_boolean_array, uint32_t num_compact_lists,
                bool allocate_gpu_memory, uint64_t &size_tracker)
      : computing_params(computing_params), casting_params(casting_params),
        num_compact_lists(num_compact_lists),
        casting_key_type(casting_key_type) {

    gpu_memory_allocated = allocate_gpu_memory;
    // We copy num_lwes_per_compact_list so we get protection against
    // num_lwes_per_compact_list being freed while this buffer is still in use
    this->num_lwes_per_compact_list =
        (uint32_t *)malloc(num_compact_lists * sizeof(uint32_t));
    memcpy(this->num_lwes_per_compact_list, num_lwes_per_compact_list,
           num_compact_lists * sizeof(uint32_t));

    num_lwes = 0;
    for (int i = 0; i < num_compact_lists; i++) {
      num_lwes += this->num_lwes_per_compact_list[i];
    }

    if (computing_params.carry_modulus != computing_params.message_modulus) {
      PANIC("GPU backend requires carry_modulus equal to message_modulus");
    }

    auto message_extract_lut_f = [casting_params](Torus x) -> Torus {
      return x % casting_params.message_modulus;
    };
    auto carry_extract_lut_f = [casting_params](Torus x) -> Torus {
      return (x / casting_params.carry_modulus) %
             casting_params.message_modulus;
    };

    // Booleans have to be sanitized
    auto sanitize_bool_f = [](Torus x) -> Torus { return x == 0 ? 0 : 1; };

    auto message_extract_and_sanitize_bool_lut_f =
        [message_extract_lut_f, sanitize_bool_f](Torus x) -> Torus {
      return sanitize_bool_f(message_extract_lut_f(x));
    };
    auto carry_extract_and_sanitize_bool_lut_f =
        [carry_extract_lut_f, sanitize_bool_f](Torus x) -> Torus {
      return sanitize_bool_f(carry_extract_lut_f(x));
    };

    /** In case the casting key casts from BIG to SMALL key, we run a single KS
     * to expand, using the casting key as ksk. Otherwise, in case the casting
     * key casts from SMALL to BIG key, we first keyswitch from SMALL to BIG
     * using the casting key as ksk, then we keyswitch from BIG to SMALL using
     * the computing ksk, and lastly we apply the PBS. The output is always on
     * the BIG key.
     **/
    auto params = casting_params;
    if (casting_key_type == SMALL_TO_BIG) {
      params = computing_params;
    }

    message_and_carry_extract_luts = new int_radix_lut<Torus>(
        streams, gpu_indexes, gpu_count, params, 4, 2 * num_lwes,
        allocate_gpu_memory, size_tracker);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 0),
        message_and_carry_extract_luts->get_degree(0),
        message_and_carry_extract_luts->get_max_degree(0),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, message_extract_lut_f, gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 1),
        message_and_carry_extract_luts->get_degree(1),
        message_and_carry_extract_luts->get_max_degree(1),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, carry_extract_lut_f, gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 2),
        message_and_carry_extract_luts->get_degree(2),
        message_and_carry_extract_luts->get_max_degree(2),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, message_extract_and_sanitize_bool_lut_f,
        gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 3),
        message_and_carry_extract_luts->get_degree(3),
        message_and_carry_extract_luts->get_max_degree(3),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, carry_extract_and_sanitize_bool_lut_f,
        gpu_memory_allocated);

    // We are always packing two LWEs. We just need to be sure we have enough
    // space in the carry part to store a message of the same size as is in the
    // message part.
    if (params.carry_modulus < params.message_modulus)
      PANIC("Carry modulus must be at least as large as message modulus");

    auto num_packed_msgs = 2;

    // Adjust indexes to permute the output and access the correct LUT
    auto h_indexes_in = static_cast<Torus *>(
        malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
    auto h_indexes_out = static_cast<Torus *>(
        malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
    auto h_lut_indexes = static_cast<Torus *>(
        malloc(num_packed_msgs * num_lwes * sizeof(Torus)));

    d_expand_jobs =
        static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
            num_lwes * sizeof(expand_job<Torus>), streams[0], gpu_indexes[0],
            size_tracker, allocate_gpu_memory));

    h_expand_jobs = static_cast<expand_job<Torus> *>(
        malloc(num_lwes * sizeof(expand_job<Torus>)));

    /*
     * Each LWE contains encrypted data in both carry and message spaces
     * that needs to be extracted.
     *
     * The loop processes each compact list (k) and for each LWE within that
     * list:
     * 1. Sets input indexes to read each LWE twice (for carry and message
     *    extraction)
     * 2. Creates output indexes to properly reorder the results
     * 3. Selects the appropriate LUT index based on whether boolean
     *    sanitization is needed
     *
     * We want the output to always hold first the content of the message part
     * and then the content of the carry part of each LWE,
     *
     * i.e. msg_extract(LWE_0), carry_extract(LWE_0), msg_extract(LWE_1),
     * carry_extract(LWE_1), ...
     *
     * Aiming at that behavior, with 4 LWEs we would have:
     *
     * // Each LWE is processed twice
     * h_indexes_in = {0, 1, 2, 3, 0, 1, 2, 3}
     *
     * // First 4 use message LUT, last 4 use carry LUT
     * h_lut_indexes = {0, 0, 0, 0, 1, 1, 1, 1}
     *
     * // Reorders output so message and carry for each LWE appear together
     * h_indexes_out = {0, 2, 4, 6, 1, 3, 5, 7}
     *
     * If an LWE contains a boolean value, its LUT index is shifted by
     * num_packed_msgs to use the sanitization LUT (which ensures the output
     * is exactly 0 or 1).
     */
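    // Worked example (illustration only, assuming is_boolean_array is all
    // false): with two compact lists of 2 LWEs each, the loop below yields
    //
    //   h_indexes_in  = {0, 1, 0, 1, 2, 3, 2, 3}
    //   h_lut_indexes = {0, 0, 1, 1, 0, 0, 1, 1}
    //   h_indexes_out = {0, 2, 1, 3, 4, 6, 5, 7}
    //
    // i.e. within each compact list every LWE is read twice (message LUT
    // first, then carry LUT), while the outputs are reordered so that the
    // message and carry extractions of each LWE land next to each other.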
    auto offset = 0;
    for (int k = 0; k < num_compact_lists; k++) {
      auto num_lwes_in_kth = this->num_lwes_per_compact_list[k];
      for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
        auto lwe_index = i + num_packed_msgs * offset;
        auto lwe_index_in_list = i % num_lwes_in_kth;
        h_indexes_in[lwe_index] = lwe_index_in_list + offset;
        h_indexes_out[lwe_index] =
            num_packed_msgs * h_indexes_in[lwe_index] + i / num_lwes_in_kth;
        // If the input relates to a boolean, shift the LUT so the correct one
        // with sanitization is used
        auto boolean_offset =
            is_boolean_array[h_indexes_out[lwe_index]] ? num_packed_msgs : 0;
        h_lut_indexes[lwe_index] = i / num_lwes_in_kth + boolean_offset;
      }
      offset += num_lwes_in_kth;
    }

    message_and_carry_extract_luts->set_lwe_indexes(
        streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
    auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);
    cuda_memcpy_with_size_tracking_async_to_gpu(
        lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
        streams[0], gpu_indexes[0], allocate_gpu_memory);

    auto active_gpu_count = get_active_gpu_count(2 * num_lwes, gpu_count);
    message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes,
                                                  active_gpu_count);
    message_and_carry_extract_luts
        ->allocate_lwe_vector_for_non_trivial_indexes(
            streams, gpu_indexes, active_gpu_count, 2 * num_lwes, size_tracker,
            allocate_gpu_memory);

    // The expanded LWEs will always be on the casting key format
    tmp_expanded_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
        num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),
        streams[0], gpu_indexes[0], size_tracker, allocate_gpu_memory);

    tmp_ksed_small_to_big_expanded_lwes =
        (Torus *)cuda_malloc_with_size_tracking_async(
            num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),
            streams[0], gpu_indexes[0], size_tracker, allocate_gpu_memory);

    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
    free(h_indexes_in);
    free(h_indexes_out);
    free(h_lut_indexes);
  }
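
  /* Illustrative lifecycle sketch (the driver code and variable names below
   * are assumptions, not part of this header): the buffer is built once per
   * expand call and must be released on the same streams before deletion.
   *
   *   uint64_t size_tracker = 0;
   *   auto mem = new zk_expand_mem<uint64_t>(
   *       streams, gpu_indexes, gpu_count, computing_params, casting_params,
   *       casting_key_type, h_num_lwes_per_compact_list, h_is_boolean_array,
   *       num_compact_lists, true, size_tracker);
   *   // ... run the expand / keyswitch / PBS steps using mem ...
   *   mem->release(streams, gpu_indexes, gpu_count);
   *   delete mem;
   */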
  void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
               uint32_t gpu_count) {
    message_and_carry_extract_luts->release(streams, gpu_indexes, gpu_count);
    delete message_and_carry_extract_luts;

    cuda_drop_with_size_tracking_async(tmp_expanded_lwes, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(tmp_ksed_small_to_big_expanded_lwes,
                                       streams[0], gpu_indexes[0],
                                       gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(d_expand_jobs, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);

    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
    free(num_lwes_per_compact_list);
    free(h_expand_jobs);
  }
};

#endif // ZK_UTILITIES_H