#ifndef ZK_UTILITIES_H
#define ZK_UTILITIES_H

#include "../integer/integer_utilities.h"
#include "integer/integer.cuh"
#include <cstdint>

template <typename Torus> struct zk_expand_mem {
  int_radix_params computing_params;
  int_radix_params casting_params;
  bool casting_key_type;
  uint32_t num_lwes;
  uint32_t num_compact_lists;
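
  // Four accumulators live in this single LUT object, applied over
  // 2 * num_lwes output ciphertexts (one message and one carry extraction
  // per input LWE), as set up in the constructor below:
  //   LUT 0: message extraction
  //   LUT 1: carry extraction
  //   LUT 2: message extraction + boolean sanitization
  //   LUT 3: carry extraction + boolean sanitization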
  int_radix_lut<Torus> *message_and_carry_extract_luts;

  Torus *tmp_expanded_lwes;
  Torus *tmp_ksed_small_to_big_expanded_lwes;

  uint32_t *d_lwe_compact_input_indexes;

  uint32_t *d_body_id_per_compact_list;
  bool gpu_memory_allocated;

  zk_expand_mem(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                uint32_t gpu_count, int_radix_params computing_params,
                int_radix_params casting_params, KS_TYPE casting_key_type,
                const uint32_t *num_lwes_per_compact_list,
                const bool *is_boolean_array, uint32_t num_compact_lists,
                bool allocate_gpu_memory, uint64_t *size_tracker)
      : computing_params(computing_params), casting_params(casting_params),
        num_compact_lists(num_compact_lists),
        casting_key_type(casting_key_type) {

    gpu_memory_allocated = allocate_gpu_memory;
    num_lwes = 0;
    for (int i = 0; i < num_compact_lists; i++) {
      num_lwes += num_lwes_per_compact_list[i];
    }

    if (computing_params.carry_modulus != computing_params.message_modulus) {
      PANIC("GPU backend requires carry_modulus equal to message_modulus")
    }

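    // Each LWE coming out of a compact list packs two plaintext chunks: one
    // in the message part and one in the carry part (hence the 2 * num_lwes
    // outputs and the carry_modulus == message_modulus requirement above).
    // The LUTs below extract one chunk each; boolean inputs are additionally
    // sanitized to {0, 1}.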
    auto message_extract_lut_f = [casting_params](Torus x) -> Torus {
      return x % casting_params.message_modulus;
    };
    auto carry_extract_lut_f = [casting_params](Torus x) -> Torus {
      return (x / casting_params.carry_modulus) %
             casting_params.message_modulus;
    };

    // Booleans have to be sanitized
    auto sanitize_bool_f = [](Torus x) -> Torus { return x == 0 ? 0 : 1; };
    auto message_extract_and_sanitize_bool_lut_f =
        [message_extract_lut_f, sanitize_bool_f](Torus x) -> Torus {
      return sanitize_bool_f(message_extract_lut_f(x));
    };
    auto carry_extract_and_sanitize_bool_lut_f =
        [carry_extract_lut_f, sanitize_bool_f](Torus x) -> Torus {
      return sanitize_bool_f(carry_extract_lut_f(x));
    };

    /** In case the casting key casts from the BIG to the SMALL key, we run a
     * single KS to expand, using the casting key as ksk. Otherwise, in case
     * the casting key casts from the SMALL to the BIG key, we first
     * keyswitch from SMALL to BIG using the casting key as ksk, then we
     * keyswitch from BIG to SMALL using the computing ksk, and lastly we
     * apply the PBS. The output is always on the BIG key.
     **/
    auto params = casting_params;
    if (casting_key_type == SMALL_TO_BIG) {
      params = computing_params;
    }
    message_and_carry_extract_luts = new int_radix_lut<Torus>(
        streams, gpu_indexes, gpu_count, params, 4, 2 * num_lwes,
        allocate_gpu_memory, size_tracker);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 0),
        message_and_carry_extract_luts->get_degree(0),
        message_and_carry_extract_luts->get_max_degree(0),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, message_extract_lut_f, gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 1),
        message_and_carry_extract_luts->get_degree(1),
        message_and_carry_extract_luts->get_max_degree(1),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, carry_extract_lut_f, gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 2),
        message_and_carry_extract_luts->get_degree(2),
        message_and_carry_extract_luts->get_max_degree(2),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, message_extract_and_sanitize_bool_lut_f,
        gpu_memory_allocated);

    generate_device_accumulator<Torus>(
        streams[0], gpu_indexes[0],
        message_and_carry_extract_luts->get_lut(0, 3),
        message_and_carry_extract_luts->get_degree(3),
        message_and_carry_extract_luts->get_max_degree(3),
        params.glwe_dimension, params.polynomial_size, params.message_modulus,
        params.carry_modulus, carry_extract_and_sanitize_bool_lut_f,
        gpu_memory_allocated);

    // Hint for future readers: if message_modulus == 4 then
    // packed_messages_per_lwe becomes 2
    auto packed_messages_per_lwe = log2_int(params.message_modulus);

    // Adjust indexes to permute the output and access the correct LUT
    auto h_indexes_in = static_cast<Torus *>(
        malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
    auto h_indexes_out = static_cast<Torus *>(
        malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
    auto h_lut_indexes = static_cast<Torus *>(
        malloc(packed_messages_per_lwe * num_lwes * sizeof(Torus)));
    auto h_body_id_per_compact_list =
        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
    auto h_lwe_compact_input_indexes =
        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));

    d_body_id_per_compact_list =
        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
            size_tracker, allocate_gpu_memory));
    d_lwe_compact_input_indexes =
        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
            size_tracker, allocate_gpu_memory));

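    // Offset of the compact list containing each input LWE. The loop below
    // assumes each compact list is laid out as a single shared mask of size
    // big_lwe_dimension followed by its bodies, hence the
    // `big_lwe_dimension + count` stride between lists.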
    auto compact_list_id = 0;
    auto idx = 0;
    auto count = 0;
    for (int i = 0; i < num_lwes; i++) {
      h_lwe_compact_input_indexes[i] = idx;
      count++;
      if (count == num_lwes_per_compact_list[compact_list_id]) {
        compact_list_id++;
        idx += casting_params.big_lwe_dimension + count;
        count = 0;
      }
    }

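    // Position (body id) of each LWE within its own compact list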
    auto offset = 0;
    for (int k = 0; k < num_compact_lists; k++) {
      auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
      uint32_t body_count = 0;
      for (int i = 0; i < num_lwes_in_kth_compact_list; i++) {
        h_body_id_per_compact_list[i + offset] = body_count;
        body_count++;
      }
      offset += num_lwes_in_kth_compact_list;
    }

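    // Gather/scatter pattern: input LWE j of a compact list is read
    // packed_messages_per_lwe times, and its extracted chunks are written to
    // consecutive output slots. As an illustration, with a single compact
    // list of 2 LWEs and packed_messages_per_lwe == 2:
    //   h_indexes_in  = {0, 1, 0, 1}
    //   h_indexes_out = {0, 2, 1, 3}
    // i.e. the message part of LWE j lands at output 2 * j and the carry part
    // at output 2 * j + 1, with the LUT index selecting message vs. carry
    // extraction (shifted by packed_messages_per_lwe for booleans).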
    offset = 0;
    for (int k = 0; k < num_compact_lists; k++) {
      auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
      for (int i = 0;
           i < packed_messages_per_lwe * num_lwes_in_kth_compact_list; i++) {
        Torus j = i % num_lwes_in_kth_compact_list;
        h_indexes_in[i + packed_messages_per_lwe * offset] = j + offset;
        h_indexes_out[i + packed_messages_per_lwe * offset] =
            packed_messages_per_lwe * (j + offset) +
            (i / num_lwes_in_kth_compact_list);
        // If the input relates to a boolean, shift the LUT so the correct one
        // with sanitization is used
        h_lut_indexes[i + packed_messages_per_lwe * offset] =
            (is_boolean_array[h_indexes_out[i +
                                            packed_messages_per_lwe * offset]]
                 ? packed_messages_per_lwe
                 : 0) +
            i / num_lwes_in_kth_compact_list;
      }
      offset += num_lwes_in_kth_compact_list;
    }

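    // Upload the permutation and LUT-selection indexes; broadcast_lut then
    // makes the LUT data and indexes available on every GPU taking part in
    // the computation (relevant for multi-GPU runs).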
    message_and_carry_extract_luts->set_lwe_indexes(
        streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
    auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);
    message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes, 0);

    cuda_memcpy_with_size_tracking_async_to_gpu(
        d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
        allocate_gpu_memory);
    cuda_memcpy_with_size_tracking_async_to_gpu(
        lut_indexes, h_lut_indexes,
        packed_messages_per_lwe * num_lwes * sizeof(Torus), streams[0],
        gpu_indexes[0], allocate_gpu_memory);
    cuda_memcpy_with_size_tracking_async_to_gpu(
        d_body_id_per_compact_list, h_body_id_per_compact_list,
        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
        allocate_gpu_memory);

    // The expanded LWEs will always be on the casting key format
    tmp_expanded_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
        num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),
        streams[0], gpu_indexes[0], size_tracker, allocate_gpu_memory);

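    // Scratch buffer for the SMALL_TO_BIG casting path: it holds the expanded
    // LWEs after the first keyswitch (SMALL to BIG with the casting key),
    // before the computing keyswitch and PBS described above.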
    tmp_ksed_small_to_big_expanded_lwes =
        (Torus *)cuda_malloc_with_size_tracking_async(
            num_lwes * (casting_params.big_lwe_dimension + 1) * sizeof(Torus),
            streams[0], gpu_indexes[0], size_tracker, allocate_gpu_memory);

    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
    free(h_indexes_in);
    free(h_indexes_out);
    free(h_lut_indexes);
    free(h_body_id_per_compact_list);
    free(h_lwe_compact_input_indexes);
  }

  void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
               uint32_t gpu_count) {

    message_and_carry_extract_luts->release(streams, gpu_indexes, gpu_count);
    delete message_and_carry_extract_luts;

    cuda_drop_with_size_tracking_async(d_body_id_per_compact_list, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(d_lwe_compact_input_indexes, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(tmp_expanded_lwes, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(tmp_ksed_small_to_big_expanded_lwes,
                                       streams[0], gpu_indexes[0],
                                       gpu_memory_allocated);
  }
};

#endif // ZK_UTILITIES_H