mirror of https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00

chore(gpu): move sum ctxt lut allocation to host to save memory
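Summary of the diff below: allocation of the message/carry lookup tables for partial_sum_ciphertexts is deferred. Instead of allocating them when the scratch buffer is created, setup_lookup_tables is now called from the host function once the actual input degrees are known, so the LUT is sized to the real number of PBS operations; the constructor only performs a dry-run allocation to top up the size_tracker with an upper bound. The reduce_degrees_for_single_carry_propagation flag moves from the run entry point into the scratch call and the memory struct.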
@@ -395,14 +395,14 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, bool allocate_ms_array);
bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory,
bool allocate_ms_array);

void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *radix_lwe_vec,
bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr,
void *const *bsks, void *const *ksks,
CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks,
void *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key);

void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec(
@@ -10,6 +10,7 @@
#include "utils/helper_multi_gpu.cuh"
#include <cmath>
#include <functional>
#include <queue>

class NoiseLevel {
public:
@@ -82,6 +83,165 @@ void generate_many_lut_device_accumulator(
uint32_t message_modulus, uint32_t carry_modulus,
std::vector<std::function<Torus(Torus)>> &f, bool gpu_memory_allocated);

template <typename Torus> struct radix_columns {
std::vector<std::vector<Torus>> columns;
std::vector<size_t> columns_counter;
std::vector<std::vector<Torus>> new_columns;
std::vector<size_t> new_columns_counter;

uint32_t num_blocks;
uint32_t num_radix_in_vec;
uint32_t chunk_size;
radix_columns(const uint64_t *const input_degrees, uint32_t num_blocks,
uint32_t num_radix_in_vec, uint32_t chunk_size,
bool &needs_processing)
: num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec),
chunk_size(chunk_size) {
needs_processing = false;
columns.resize(num_blocks);
columns_counter.resize(num_blocks, 0);
new_columns.resize(num_blocks);
new_columns_counter.resize(num_blocks, 0);
for (uint32_t i = 0; i < num_blocks; ++i) {
new_columns[i].resize(num_radix_in_vec);
}
for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
for (uint32_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j]) {
columns[j].push_back(i * num_blocks + j);
columns_counter[j]++;
}
}
}
for (uint32_t i = 0; i < num_blocks; ++i) {
if (columns_counter[i] > chunk_size) {
needs_processing = true;
break;
}
}
}

void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes, uint32_t &total_ciphertexts,
uint32_t &message_ciphertexts,
bool &needs_processing) {
message_ciphertexts = 0;
total_ciphertexts = 0;
needs_processing = false;

uint32_t pbs_count = 0;
for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) {
const uint32_t column_len = columns_counter[c_id];
new_columns_counter[c_id] = 0;
uint32_t ct_count = 0;
// add message cts into new columns
for (uint32_t i = 0; i + chunk_size <= column_len; i += chunk_size) {
const Torus in_index = columns[c_id][i];
new_columns[c_id][ct_count] = in_index;
if (h_indexes_in != nullptr)
h_indexes_in[pbs_count] = in_index;
if (h_indexes_out != nullptr)
h_indexes_out[pbs_count] = in_index;
if (h_lut_indexes != nullptr)
h_lut_indexes[pbs_count] = 0;
++pbs_count;
++ct_count;
++message_ciphertexts;
}
new_columns_counter[c_id] = ct_count;
}

for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) {
const uint32_t column_len = columns_counter[c_id];
uint32_t ct_count = new_columns_counter[c_id];
// add carry cts into new columns
if (c_id > 0) {
const uint32_t prev_c_id = c_id - 1;
const uint32_t prev_column_len = columns_counter[prev_c_id];
for (uint32_t i = 0; i + chunk_size <= prev_column_len;
i += chunk_size) {
const Torus in_index = columns[prev_c_id][i];
const Torus out_index = columns[prev_c_id][i + 1];
new_columns[c_id][ct_count] = out_index;
if (h_indexes_in != nullptr)
h_indexes_in[pbs_count] = in_index;
if (h_indexes_out != nullptr)
h_indexes_out[pbs_count] = out_index;
if (h_lut_indexes != nullptr)
h_lut_indexes[pbs_count] = 1;
++pbs_count;
++ct_count;
}
}
// add remaining cts into new columns
const uint32_t start_index = column_len - column_len % chunk_size;
for (uint32_t i = start_index; i < column_len; ++i) {
new_columns[c_id][ct_count] = columns[c_id][i];
++ct_count;
}
new_columns_counter[c_id] = ct_count;
if (ct_count > chunk_size) {
needs_processing = true;
}
}
total_ciphertexts = pbs_count;
swap(columns, new_columns);
swap(columns_counter, new_columns_counter);
}

void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes) {
for (uint32_t idx = 0; idx < 2 * num_blocks; ++idx) {
if (h_indexes_in != nullptr)
h_indexes_in[idx] = idx % num_blocks;
if (h_indexes_out != nullptr)
h_indexes_out[idx] = idx + idx / num_blocks;
if (h_lut_indexes != nullptr)
h_lut_indexes[idx] = idx / num_blocks;
}
}
};
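To make the index bookkeeping of final_calculation concrete, here is a minimal host-only sketch (not part of the commit) for a hypothetical num_blocks = 4: the first num_blocks PBSs extract messages in place, and the second num_blocks extract carries into the slot one block higher.

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t num_blocks = 4;
      for (uint32_t idx = 0; idx < 2 * num_blocks; ++idx) {
        uint32_t in = idx % num_blocks;        // every block is read twice
        uint32_t out = idx + idx / num_blocks; // carries shift one block up
        uint32_t lut = idx / num_blocks;       // 0: message LUT, 1: carry LUT
        std::printf("pbs %u: in=%u out=%u lut=%u\n", idx, in, out, lut);
      }
      return 0;
    }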

inline void calculate_final_degrees(uint64_t *const out_degrees,
const uint64_t *const input_degrees,
uint32_t num_blocks,
uint32_t num_radix_in_vec,
uint32_t chunk_size,
uint64_t message_modulus) {

auto get_degree = [message_modulus](uint64_t degree) -> uint64_t {
return std::min(message_modulus - 1, degree);
};
std::vector<std::queue<uint64_t>> columns(num_blocks);
for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
for (uint32_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j])
columns[j].push(input_degrees[i * num_blocks + j]);
}
}

for (uint32_t i = 0; i < num_blocks; ++i) {
auto &col = columns[i];
while (col.size() > 1) {
uint32_t cur_degree = 0;
uint32_t mn = std::min(chunk_size, (uint32_t)col.size());
for (int j = 0; j < mn; ++j) {
cur_degree += col.front();
col.pop();
}
const uint64_t new_degree = get_degree(cur_degree);
col.push(new_degree);
if ((i + 1) < num_blocks) {
columns[i + 1].push(new_degree);
}
}
}

for (int i = 0; i < num_blocks; i++) {
out_degrees[i] = (columns[i].empty()) ? 0 : columns[i].front();
}
}
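As a worked example with assumed parameters (not from the diff): with message_modulus = 4 and carry_modulus = 4, chunk_size = (4 * 4 - 1) / (4 - 1) = 5, so up to five blocks can be accumulated before a PBS must extract message and carry. The hypothetical driver below feeds calculate_final_degrees two 2-block radix ciphertexts whose blocks all have degree 3; each column sums to 6 and is clamped to message_modulus - 1 = 3, and it assumes the function above is in scope.

    #include <cstdint>
    #include <iostream>

    int main() {
      const uint64_t input_degrees[4] = {3, 3, 3, 3}; // [radix][block], flattened
      uint64_t out_degrees[2] = {0, 0};
      calculate_final_degrees(out_degrees, input_degrees, /*num_blocks=*/2,
                              /*num_radix_in_vec=*/2, /*chunk_size=*/5,
                              /*message_modulus=*/4);
      std::cout << out_degrees[0] << " " << out_degrees[1] << "\n"; // prints 3 3
      return 0;
    }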

struct int_radix_params {
PBS_TYPE pbs_type;
uint32_t glwe_dimension;
@@ -1325,8 +1485,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
uint32_t num_blocks_in_radix;
uint32_t max_num_radix_in_vec;
uint32_t chunk_size;
uint64_t *size_tracker;
bool gpu_memory_allocated;
bool reduce_degrees_for_single_carry_propagation;

// temporary buffers
CudaRadixCiphertextFFI *current_blocks;
@@ -1349,7 +1509,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
bool allocated_luts_message_carry;

void setup_index_buffers(cudaStream_t const *streams,
uint32_t const *gpu_indexes) {
uint32_t const *gpu_indexes,
uint64_t *size_tracker) {

d_degrees = (uint64_t *)cuda_malloc_with_size_tracking_async(
max_total_blocks_in_vec * sizeof(uint64_t), streams[0], gpu_indexes[0],
@@ -1394,18 +1555,29 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
}

void setup_lookup_tables(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count) {
uint32_t const *gpu_indexes, uint32_t gpu_count,
uint32_t num_radix_in_vec,
const uint64_t *const degrees) {
uint32_t message_modulus = params.message_modulus;
bool _needs_processing = false;
radix_columns<Torus> current_columns(degrees, num_blocks_in_radix,
num_radix_in_vec, chunk_size,
_needs_processing);
uint32_t total_ciphertexts = 0;
uint32_t total_messages = 0;
current_columns.next_accumulation(nullptr, nullptr, nullptr,
total_ciphertexts, total_messages,
_needs_processing);

if (!mem_reuse) {
uint32_t pbs_count = std::max(2 * (max_total_blocks_in_vec / chunk_size),
2 * num_blocks_in_radix);
if (max_total_blocks_in_vec > 0) {
luts_message_carry = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 2, pbs_count,
gpu_memory_allocated, size_tracker);
} else {
allocated_luts_message_carry = false;
uint32_t pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix);
if (total_ciphertexts > 0 ||
reduce_degrees_for_single_carry_propagation) {
uint64_t size_tracker = 0;
luts_message_carry =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
pbs_count, true, &size_tracker);
allocated_luts_message_carry = true;
}
}
if (allocated_luts_message_carry) {
@@ -1436,25 +1608,35 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);
}
}
int_sum_ciphertexts_vec_memory(cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params,
uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec,
bool allocate_gpu_memory,
uint64_t *size_tracker) {
int_sum_ciphertexts_vec_memory(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec,
bool reduce_degrees_for_single_carry_propagation,
bool allocate_gpu_memory, uint64_t *size_tracker) {
this->params = params;
this->mem_reuse = false;
this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
this->num_blocks_in_radix = num_blocks_in_radix;
this->max_num_radix_in_vec = max_num_radix_in_vec;
this->gpu_memory_allocated = allocate_gpu_memory;
this->size_tracker = size_tracker;
this->chunk_size = (params.message_modulus * params.carry_modulus - 1) /
(params.message_modulus - 1);
this->allocated_luts_message_carry = true;
setup_index_buffers(streams, gpu_indexes);
setup_lookup_tables(streams, gpu_indexes, gpu_count);
this->allocated_luts_message_carry = false;
this->reduce_degrees_for_single_carry_propagation =
reduce_degrees_for_single_carry_propagation;
setup_index_buffers(streams, gpu_indexes, size_tracker);
// Because the LUTs for sum_ciphertexts are now set up in the host function
// to save memory, the size_tracker is topped up here to give an upper
// bound on the memory that will be used.
uint32_t max_pbs_count = std::max(
2 * (max_total_blocks_in_vec / chunk_size), 2 * num_blocks_in_radix);
if (max_pbs_count > 0) {
int_radix_lut<Torus> *luts_message_carry_dry_run =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
max_pbs_count, false, size_tracker);
luts_message_carry_dry_run->release(streams, gpu_indexes, gpu_count);
delete luts_message_carry_dry_run;
}
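The dry run above relies on the convention that the backend's allocation helpers always add their byte count to the tracker and only touch the GPU when the allocate flag is set. A minimal sketch of that idiom (hypothetical helper using host malloc, not the tfhe-rs API):

    #include <cstdint>
    #include <cstdlib>

    void *malloc_with_size_tracking(uint64_t bytes, uint64_t *size_tracker,
                                    bool allocate_gpu_memory) {
      *size_tracker += bytes; // always record the worst-case footprint
      return allocate_gpu_memory ? std::malloc(bytes) // real allocation path
                                 : nullptr;           // dry run: count only
    }

    int main() {
      uint64_t tracked = 0;
      void *p = malloc_with_size_tracking(1 << 20, &tracked, false);
      // tracked is now 1 MiB even though nothing was allocated (p == nullptr)
      std::free(p); // freeing nullptr is a no-op
      return 0;
    }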

// create and allocate intermediate buffers
current_blocks = new CudaRadixCiphertextFFI;
@@ -1472,23 +1654,25 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec, CudaRadixCiphertextFFI *current_blocks,
CudaRadixCiphertextFFI *small_lwe_vector,
int_radix_lut<Torus> *reused_lut, bool allocate_gpu_memory,
uint64_t *size_tracker) {
int_radix_lut<Torus> *reused_lut,
bool reduce_degrees_for_single_carry_propagation,
bool allocate_gpu_memory, uint64_t *size_tracker) {
this->mem_reuse = true;
this->params = params;
this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
this->num_blocks_in_radix = num_blocks_in_radix;
this->max_num_radix_in_vec = max_num_radix_in_vec;
this->gpu_memory_allocated = allocate_gpu_memory;
this->size_tracker = size_tracker;
this->chunk_size = (params.message_modulus * params.carry_modulus - 1) /
(params.message_modulus - 1);
this->allocated_luts_message_carry = true;
this->reduce_degrees_for_single_carry_propagation =
reduce_degrees_for_single_carry_propagation;

this->current_blocks = current_blocks;
this->small_lwe_vector = small_lwe_vector;
this->luts_message_carry = reused_lut;
setup_index_buffers(streams, gpu_indexes);
setup_index_buffers(streams, gpu_indexes, size_tracker);
}

void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -1518,7 +1702,6 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
luts_message_carry->release(streams, gpu_indexes, gpu_count);
delete luts_message_carry;
}

delete current_blocks;
delete small_lwe_vector;
}
@@ -2898,7 +3081,7 @@ template <typename Torus> struct int_mul_memory {
sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
2 * num_radix_blocks, vector_result_sb, small_lwe_vector, luts_array,
allocate_gpu_memory, size_tracker);
true, allocate_gpu_memory, size_tracker);
uint32_t uses_carry = 0;
uint32_t requested_flag = outputFlag::FLAG_NONE;
sc_prop_mem = new int_sc_prop_memory<Torus>(
@@ -4694,10 +4877,11 @@ template <typename Torus> struct int_scalar_mul_buffer {
//// The idea is that with these we can create all other shifts that are
/// in range (0..total_bits) for free (block rotation)
preshifted_buffer = new CudaRadixCiphertextFFI;
uint64_t anticipated_drop_mem = 0;
create_zero_radix_ciphertext_async<Torus>(
streams[0], gpu_indexes[0], preshifted_buffer,
msg_bits * num_radix_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
msg_bits * num_radix_blocks, params.big_lwe_dimension,
&anticipated_drop_mem, allocate_gpu_memory);

all_shifted_buffer = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
@@ -4708,22 +4892,28 @@ template <typename Torus> struct int_scalar_mul_buffer {
if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2)
logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks,
allocate_gpu_memory, all_shifted_buffer, size_tracker);
allocate_gpu_memory, all_shifted_buffer, &anticipated_drop_mem);
else
logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks,
allocate_gpu_memory, size_tracker);
allocate_gpu_memory, &anticipated_drop_mem);

uint64_t last_step_mem = 0;
if (num_ciphertext_bits > 0) {
sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
num_ciphertext_bits, allocate_gpu_memory, size_tracker);
num_ciphertext_bits, true, allocate_gpu_memory, &last_step_mem);
}
uint32_t uses_carry = 0;
uint32_t requested_flag = outputFlag::FLAG_NONE;
sc_prop_mem = new int_sc_prop_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
requested_flag, uses_carry, allocate_gpu_memory, size_tracker);
requested_flag, uses_carry, allocate_gpu_memory, &last_step_mem);
if (anticipated_buffer_drop) {
*size_tracker += std::max(anticipated_drop_mem, last_step_mem);
} else {
*size_tracker += anticipated_drop_mem + last_step_mem;
}
}
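The accounting at the end of this constructor encodes a simple peak-memory argument: if the pre-shifted buffers are dropped before the final steps run (anticipated_buffer_drop), the two phases never coexist and the bound is the maximum of the two, otherwise their sum. A sketch with illustrative numbers (300 and 200 MiB are assumptions, not values from the commit):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      const uint64_t anticipated_drop_mem = 300; // MiB, phase 1 (illustrative)
      const uint64_t last_step_mem = 200;        // MiB, phase 2 (illustrative)
      for (bool anticipated_buffer_drop : {true, false}) {
        uint64_t size_tracker = 0;
        if (anticipated_buffer_drop)
          size_tracker += std::max(anticipated_drop_mem, last_step_mem); // 300
        else
          size_tracker += anticipated_drop_mem + last_step_mem; // 500
        std::cout << size_tracker << " MiB\n";
      }
      return 0;
    }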

void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -210,7 +210,8 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, bool allocate_ms_array) {
bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory,
bool allocate_ms_array) {

int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
glwe_dimension * polynomial_size, lwe_dimension,
@@ -220,15 +221,15 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
return scratch_cuda_integer_partial_sum_ciphertexts_vec_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(int_sum_ciphertexts_vec_memory<uint64_t> **)mem_ptr, num_blocks_in_radix,
max_num_radix_in_vec, params, allocate_gpu_memory);
max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation, params,
allocate_gpu_memory);
}

void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *radix_lwe_vec,
bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr,
void *const *bsks, void *const *ksks,
CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks,
void *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key) {

auto mem = (int_sum_ciphertexts_vec_memory<uint64_t> *)mem_ptr;
@@ -239,8 +240,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
case 512:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t, AmortizedDegree<512>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

@@ -248,8 +248,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<1024>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

@@ -257,8 +256,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<2048>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

@@ -266,8 +264,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<4096>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

@@ -275,8 +272,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<8192>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

@@ -284,8 +280,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<16384>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -20,7 +20,6 @@
#include <fstream>
#include <iostream>
#include <omp.h>
#include <queue>
#include <sstream>
#include <string>
#include <vector>
@@ -273,164 +272,19 @@ __global__ void fill_radix_from_lsb_msb(Torus *result_blocks, Torus *lsb_blocks,
}
}

template <typename Torus> struct radix_columns {
std::vector<std::vector<Torus>> columns;
std::vector<size_t> columns_counter;
std::vector<std::vector<Torus>> new_columns;
std::vector<size_t> new_columns_counter;

size_t num_blocks;
size_t num_radix_in_vec;
size_t chunk_size;
radix_columns(const uint64_t *const input_degrees, size_t num_blocks,
size_t num_radix_in_vec, size_t chunk_size,
bool &needs_processing)
: num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec),
chunk_size(chunk_size) {
needs_processing = false;
columns.resize(num_blocks);
columns_counter.resize(num_blocks, 0);
new_columns.resize(num_blocks);
new_columns_counter.resize(num_blocks, 0);
for (size_t i = 0; i < num_blocks; ++i) {
new_columns[i].resize(num_radix_in_vec);
}
for (size_t i = 0; i < num_radix_in_vec; ++i) {
for (size_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j]) {
columns[j].push_back(i * num_blocks + j);
columns_counter[j]++;
}
}
}
for (size_t i = 0; i < num_blocks; ++i) {
if (columns_counter[i] > chunk_size) {
needs_processing = true;
break;
}
}
}

void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes, size_t &total_ciphertexts,
size_t &message_ciphertexts, bool &needs_processing) {
message_ciphertexts = 0;
total_ciphertexts = 0;
needs_processing = false;

size_t pbs_count = 0;
for (size_t c_id = 0; c_id < num_blocks; ++c_id) {
const size_t column_len = columns_counter[c_id];
new_columns_counter[c_id] = 0;
size_t ct_count = 0;
// add message cts into new columns
for (size_t i = 0; i + chunk_size <= column_len; i += chunk_size) {
const Torus in_index = columns[c_id][i];
new_columns[c_id][ct_count] = in_index;
h_indexes_in[pbs_count] = in_index;
h_indexes_out[pbs_count] = in_index;
h_lut_indexes[pbs_count] = 0;
++pbs_count;
++ct_count;
++message_ciphertexts;
}
new_columns_counter[c_id] = ct_count;
}

for (size_t c_id = 0; c_id < num_blocks; ++c_id) {
const size_t column_len = columns_counter[c_id];
size_t ct_count = new_columns_counter[c_id];
// add carry cts into new columns
if (c_id > 0) {
const size_t prev_c_id = c_id - 1;
const size_t prev_column_len = columns_counter[prev_c_id];
for (size_t i = 0; i + chunk_size <= prev_column_len; i += chunk_size) {
const Torus in_index = columns[prev_c_id][i];
const Torus out_index = columns[prev_c_id][i + 1];
new_columns[c_id][ct_count] = out_index;
h_indexes_in[pbs_count] = in_index;
h_indexes_out[pbs_count] = out_index;
h_lut_indexes[pbs_count] = 1;
++pbs_count;
++ct_count;
}
}
// add remaining cts into new columns
const size_t start_index = column_len - column_len % chunk_size;
for (size_t i = start_index; i < column_len; ++i) {
new_columns[c_id][ct_count] = columns[c_id][i];
++ct_count;
}
new_columns_counter[c_id] = ct_count;
if (ct_count > chunk_size) {
needs_processing = true;
}
}
total_ciphertexts = pbs_count;
swap(columns, new_columns);
swap(columns_counter, new_columns_counter);
}

void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes) {
for (size_t idx = 0; idx < 2 * num_blocks; ++idx) {
h_indexes_in[idx] = idx % num_blocks;
h_indexes_out[idx] = idx + idx / num_blocks;
h_lut_indexes[idx] = idx / num_blocks;
}
}
};

inline void calculate_final_degrees(uint64_t *const out_degrees,
const uint64_t *const input_degrees,
size_t num_blocks, size_t num_radix_in_vec,
size_t chunk_size,
uint64_t message_modulus) {

auto get_degree = [message_modulus](uint64_t degree) -> uint64_t {
return std::min(message_modulus - 1, degree);
};
std::vector<std::queue<uint64_t>> columns(num_blocks);
for (size_t i = 0; i < num_radix_in_vec; ++i) {
for (size_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j])
columns[j].push(input_degrees[i * num_blocks + j]);
}
}

for (size_t i = 0; i < num_blocks; ++i) {
auto &col = columns[i];
while (col.size() > 1) {
uint32_t cur_degree = 0;
size_t mn = std::min(chunk_size, col.size());
for (int j = 0; j < mn; ++j) {
cur_degree += col.front();
col.pop();
}
const uint64_t new_degree = get_degree(cur_degree);
col.push(new_degree);
if ((i + 1) < num_blocks) {
columns[i + 1].push(new_degree);
}
}
}

for (int i = 0; i < num_blocks; i++) {
out_degrees[i] = (columns[i].empty()) ? 0 : columns[i].front();
}
}

template <typename Torus>
__host__ uint64_t scratch_cuda_integer_partial_sum_ciphertexts_vec_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_sum_ciphertexts_vec_memory<Torus> **mem_ptr,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
int_radix_params params, bool allocate_gpu_memory) {
bool reduce_degrees_for_single_carry_propagation, int_radix_params params,
bool allocate_gpu_memory) {

uint64_t size_tracker = 0;
*mem_ptr = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_blocks_in_radix,
max_num_radix_in_vec, allocate_gpu_memory, &size_tracker);
max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation,
allocate_gpu_memory, &size_tracker);
return size_tracker;
}
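The scratch function returns the tracked size so callers can surface or check memory requirements before running anything. A toy sketch of the scratch / run / cleanup lifecycle used throughout the backend (stand-in functions with host allocations, not the real API):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-ins for the scratch/run/cleanup triple declared above.
    uint64_t scratch(int8_t **mem_ptr) { *mem_ptr = new int8_t[16]; return 16; }
    void run(int8_t *mem_ptr) { (void)mem_ptr; /* kernels would launch here */ }
    void cleanup(int8_t **mem_ptr) { delete[] *mem_ptr; *mem_ptr = nullptr; }

    int main() {
      int8_t *mem = nullptr;
      uint64_t bytes = scratch(&mem); // returns the tracked worst-case size
      std::printf("scratch reserved up to %llu bytes\n",
                  (unsigned long long)bytes);
      run(mem);
      cleanup(&mem);
      return 0;
    }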
@@ -438,9 +292,7 @@ template <typename Torus, class params>
__host__ void host_integer_partial_sum_ciphertexts_vec_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *terms,
bool reduce_degrees_for_single_carry_propagation, void *const *bsks,
uint64_t *const *ksks,
CudaRadixCiphertextFFI *terms, void *const *bsks, uint64_t *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
uint32_t num_radix_blocks, uint32_t num_radix_in_vec) {
@@ -465,10 +317,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
auto d_columns_counter = mem_ptr->d_columns_counter;
auto d_new_columns = mem_ptr->d_new_columns;
auto d_new_columns_counter = mem_ptr->d_new_columns_counter;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;

auto luts_message_carry = mem_ptr->luts_message_carry;

auto glwe_dimension = mem_ptr->params.glwe_dimension;
auto polynomial_size = mem_ptr->params.polynomial_size;
@@ -483,8 +331,9 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
uint32_t num_many_lut = 1;
uint32_t lut_stride = 0;

if (terms->num_radix_blocks == 0)
if (terms->num_radix_blocks == 0) {
return;
}
if (num_radix_in_vec == 1) {
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
radix_lwe_out, 0, num_radix_blocks,
@@ -501,10 +350,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
return;
}

if (mem_ptr->mem_reuse) {
mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count);
}

if (current_blocks != terms) {
copy_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
current_blocks, terms);
@@ -523,11 +368,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
radix_columns<Torus> current_columns(current_blocks->degrees,
num_radix_blocks, num_radix_in_vec,
chunk_size, needs_processing);
int number_of_threads = min(256, params::degree);
int number_of_threads = std::min(256, params::degree);
int part_count = (big_lwe_size + number_of_threads - 1) / number_of_threads;
const dim3 number_of_blocks_2d(num_radix_blocks, part_count, 1);

mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count,
num_radix_in_vec, current_blocks->degrees);

while (needs_processing) {
auto luts_message_carry = mem_ptr->luts_message_carry;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
calculate_chunks<Torus>
<<<number_of_blocks_2d, number_of_threads, 0, streams[0]>>>(
(Torus *)(current_blocks->ptr), d_columns, d_columns_counter,
@@ -538,8 +389,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
d_pbs_indexes_out, luts_message_carry->get_lut_indexes(0, 0), d_columns,
d_columns_counter, chunk_size);

size_t total_ciphertexts;
size_t total_messages;
uint32_t total_ciphertexts;
uint32_t total_messages;
current_columns.next_accumulation(luts_message_carry->h_lwe_indexes_in,
luts_message_carry->h_lwe_indexes_out,
luts_message_carry->h_lut_indexes,
@@ -570,9 +421,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);

integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, active_gpu_count, current_blocks,
current_blocks, bsks, ksks, ms_noise_reduction_key,
luts_message_carry, total_ciphertexts);
streams, gpu_indexes, gpu_count, current_blocks, current_blocks, bsks,
ksks, ms_noise_reduction_key, luts_message_carry, total_ciphertexts);
}
cuda_set_device(gpu_indexes[0]);
std::swap(d_columns, d_new_columns);
@@ -584,7 +434,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
(Torus *)(radix_lwe_out->ptr), (Torus *)(current_blocks->ptr),
d_columns, d_columns_counter, chunk_size, big_lwe_size);

if (reduce_degrees_for_single_carry_propagation) {
if (mem_ptr->reduce_degrees_for_single_carry_propagation) {
auto luts_message_carry = mem_ptr->luts_message_carry;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
prepare_final_pbs_indexes<Torus>
<<<1, 2 * num_radix_blocks, 0, streams[0]>>>(
d_pbs_indexes_in, d_pbs_indexes_out,
@@ -772,9 +625,9 @@ __host__ void host_integer_mult_radix_kb(
terms_degree_msb[i] = (b_id > r_id) ? message_modulus - 2 : 0;
}
host_integer_partial_sum_ciphertexts_vec_kb<Torus, params>(
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, true,
bsks, ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem,
num_blocks, 2 * num_blocks);
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, bsks,
ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem, num_blocks,
2 * num_blocks);

auto scp_mem_ptr = mem_ptr->sc_prop_mem;
uint32_t requested_flag = outputFlag::FLAG_NONE;
@@ -117,8 +117,8 @@ __host__ void host_integer_scalar_mul_radix(
lwe_array, 0, num_radix_blocks);
} else {
host_integer_partial_sum_ciphertexts_vec_kb<T, params>(
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, true,
bsks, ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem,
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, bsks,
ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem,
num_radix_blocks, j);

auto scp_mem_ptr = mem->sc_prop_mem;
@@ -1007,6 +1007,7 @@ unsafe extern "C" {
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
reduce_degrees_for_single_carry_propagation: bool,
allocate_gpu_memory: bool,
allocate_ms_array: bool,
) -> u64;
@@ -1018,7 +1019,6 @@ unsafe extern "C" {
gpu_count: u32,
radix_lwe_out: *mut CudaRadixCiphertextFFI,
radix_lwe_vec: *mut CudaRadixCiphertextFFI,
reduce_degrees_for_single_carry_propagation: bool,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,
@@ -4729,6 +4729,7 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async<
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
reduce_degrees_for_single_carry_propagation,
true,
allocate_ms_noise_array,
);
@@ -4738,7 +4739,6 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async<
streams.len() as u32,
&raw mut cuda_ffi_result,
&raw mut cuda_ffi_radix_list,
reduce_degrees_for_single_carry_propagation,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),