chore(gpu): move sum ctxt lut allocation to host to save memory

Agnes Leroy
2025-06-19 14:48:05 +02:00
committed by Agnès Leroy
parent dbd158c641
commit 3ba6a72166
7 changed files with 269 additions and 231 deletions
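In short: `luts_message_carry` for the partial-sum entry point is no longer allocated eagerly at scratch time with a worst-case `pbs_count`. The scratch constructor now only performs a dry-run allocation to keep the reported size an upper bound, and the real LUT is created by `setup_lookup_tables` from inside the host function, once the actual block degrees are known. As a consequence, the `reduce_degrees_for_single_carry_propagation` flag moves from the run entry point to the scratch entry point and is stored in `int_sum_ciphertexts_vec_memory`. The C++/CUDA and Rust hunks below all follow from those two moves.

Aside: a minimal before/after sketch of a caller, with argument lists abridged and names as in the headers below:

// before: flag passed at run time
// scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
//     ..., pbs_type, allocate_gpu_memory, allocate_ms_array);
// cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
//     streams, gpu_indexes, gpu_count, radix_lwe_out, radix_lwe_vec,
//     /*reduce_degrees_for_single_carry_propagation=*/true, mem_ptr,
//     bsks, ksks, ms_noise_reduction_key);
//
// after: flag passed at scratch time, dropped from the run call
// scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
//     ..., pbs_type, /*reduce_degrees_for_single_carry_propagation=*/true,
//     allocate_gpu_memory, allocate_ms_array);
// cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
//     streams, gpu_indexes, gpu_count, radix_lwe_out, radix_lwe_vec,
//     mem_ptr, bsks, ksks, ms_noise_reduction_key);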

View File

@@ -395,14 +395,14 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, bool allocate_ms_array);
bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory,
bool allocate_ms_array);
void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *radix_lwe_vec,
bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr,
void *const *bsks, void *const *ksks,
CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks,
void *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key);
void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec(

View File

@@ -10,6 +10,7 @@
#include "utils/helper_multi_gpu.cuh"
#include <cmath>
#include <functional>
#include <queue>
class NoiseLevel {
public:
@@ -82,6 +83,165 @@ void generate_many_lut_device_accumulator(
uint32_t message_modulus, uint32_t carry_modulus,
std::vector<std::function<Torus(Torus)>> &f, bool gpu_memory_allocated);
template <typename Torus> struct radix_columns {
std::vector<std::vector<Torus>> columns;
std::vector<size_t> columns_counter;
std::vector<std::vector<Torus>> new_columns;
std::vector<size_t> new_columns_counter;
uint32_t num_blocks;
uint32_t num_radix_in_vec;
uint32_t chunk_size;
radix_columns(const uint64_t *const input_degrees, uint32_t num_blocks,
uint32_t num_radix_in_vec, uint32_t chunk_size,
bool &needs_processing)
: num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec),
chunk_size(chunk_size) {
needs_processing = false;
columns.resize(num_blocks);
columns_counter.resize(num_blocks, 0);
new_columns.resize(num_blocks);
new_columns_counter.resize(num_blocks, 0);
for (uint32_t i = 0; i < num_blocks; ++i) {
new_columns[i].resize(num_radix_in_vec);
}
for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
for (uint32_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j]) {
columns[j].push_back(i * num_blocks + j);
columns_counter[j]++;
}
}
}
for (uint32_t i = 0; i < num_blocks; ++i) {
if (columns_counter[i] > chunk_size) {
needs_processing = true;
break;
}
}
}
void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes, uint32_t &total_ciphertexts,
uint32_t &message_ciphertexts,
bool &needs_processing) {
message_ciphertexts = 0;
total_ciphertexts = 0;
needs_processing = false;
uint32_t pbs_count = 0;
for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) {
const uint32_t column_len = columns_counter[c_id];
new_columns_counter[c_id] = 0;
uint32_t ct_count = 0;
// add message cts into new columns
for (uint32_t i = 0; i + chunk_size <= column_len; i += chunk_size) {
const Torus in_index = columns[c_id][i];
new_columns[c_id][ct_count] = in_index;
if (h_indexes_in != nullptr)
h_indexes_in[pbs_count] = in_index;
if (h_indexes_out != nullptr)
h_indexes_out[pbs_count] = in_index;
if (h_lut_indexes != nullptr)
h_lut_indexes[pbs_count] = 0;
++pbs_count;
++ct_count;
++message_ciphertexts;
}
new_columns_counter[c_id] = ct_count;
}
for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) {
const uint32_t column_len = columns_counter[c_id];
uint32_t ct_count = new_columns_counter[c_id];
// add carry cts into new columns
if (c_id > 0) {
const uint32_t prev_c_id = c_id - 1;
const uint32_t prev_column_len = columns_counter[prev_c_id];
for (uint32_t i = 0; i + chunk_size <= prev_column_len;
i += chunk_size) {
const Torus in_index = columns[prev_c_id][i];
const Torus out_index = columns[prev_c_id][i + 1];
new_columns[c_id][ct_count] = out_index;
if (h_indexes_in != nullptr)
h_indexes_in[pbs_count] = in_index;
if (h_indexes_out != nullptr)
h_indexes_out[pbs_count] = out_index;
if (h_lut_indexes != nullptr)
h_lut_indexes[pbs_count] = 1;
++pbs_count;
++ct_count;
}
}
// add remaining cts into new columns
const uint32_t start_index = column_len - column_len % chunk_size;
for (uint32_t i = start_index; i < column_len; ++i) {
new_columns[c_id][ct_count] = columns[c_id][i];
++ct_count;
}
new_columns_counter[c_id] = ct_count;
if (ct_count > chunk_size) {
needs_processing = true;
}
}
total_ciphertexts = pbs_count;
swap(columns, new_columns);
swap(columns_counter, new_columns_counter);
}
void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes) {
for (uint32_t idx = 0; idx < 2 * num_blocks; ++idx) {
if (h_indexes_in != nullptr)
h_indexes_in[idx] = idx % num_blocks;
if (h_indexes_out != nullptr)
h_indexes_out[idx] = idx + idx / num_blocks;
if (h_lut_indexes != nullptr)
h_lut_indexes[idx] = idx / num_blocks;
}
}
};
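Aside: the nullptr checks added to `next_accumulation` above let the same routine both count work and generate indexes. A minimal counting-only sketch, assuming 64-bit `Torus` and host-side `degrees`, `num_blocks`, `num_radix_in_vec` and `chunk_size` in scope (this mirrors what `setup_lookup_tables` does below):

bool needs_processing = false;
radix_columns<uint64_t> cols(degrees, num_blocks, num_radix_in_vec,
                             chunk_size, needs_processing);
uint32_t total_ciphertexts = 0;
uint32_t message_ciphertexts = 0;
// All three index pointers are null: nothing is written, only counted.
cols.next_accumulation(nullptr, nullptr, nullptr, total_ciphertexts,
                       message_ciphertexts, needs_processing);
// total_ciphertexts bounds the pbs_count needed for the first pass; with real
// buffers the same call fills h_indexes_in / h_indexes_out / h_lut_indexes.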
inline void calculate_final_degrees(uint64_t *const out_degrees,
const uint64_t *const input_degrees,
uint32_t num_blocks,
uint32_t num_radix_in_vec,
uint32_t chunk_size,
uint64_t message_modulus) {
auto get_degree = [message_modulus](uint64_t degree) -> uint64_t {
return std::min(message_modulus - 1, degree);
};
std::vector<std::queue<uint64_t>> columns(num_blocks);
for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
for (uint32_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j])
columns[j].push(input_degrees[i * num_blocks + j]);
}
}
for (uint32_t i = 0; i < num_blocks; ++i) {
auto &col = columns[i];
while (col.size() > 1) {
uint32_t cur_degree = 0;
uint32_t mn = std::min(chunk_size, (uint32_t)col.size());
for (int j = 0; j < mn; ++j) {
cur_degree += col.front();
col.pop();
}
const uint64_t new_degree = get_degree(cur_degree);
col.push(new_degree);
if ((i + 1) < num_blocks) {
columns[i + 1].push(new_degree);
}
}
}
for (int i = 0; i < num_blocks; i++) {
out_degrees[i] = (columns[i].empty()) ? 0 : columns[i].front();
}
}
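For intuition on the chunking: the constructors below compute `chunk_size = (message_modulus * carry_modulus - 1) / (message_modulus - 1)`. With the usual 2_2 parameters (`message_modulus = carry_modulus = 4`) this gives `(16 - 1) / (4 - 1) = 5`: five blocks of degree at most 3 sum to degree at most 15, the largest value the combined message-and-carry space can hold. `calculate_final_degrees` replays the accumulation on degrees alone: each full chunk of a column collapses into one message block whose degree is clamped to `message_modulus - 1`, and a carry block with the same clamped degree is pushed into the next column, until every column is reduced to a single block.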
struct int_radix_params {
PBS_TYPE pbs_type;
uint32_t glwe_dimension;
@@ -1325,8 +1485,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
uint32_t num_blocks_in_radix;
uint32_t max_num_radix_in_vec;
uint32_t chunk_size;
uint64_t *size_tracker;
bool gpu_memory_allocated;
bool reduce_degrees_for_single_carry_propagation;
// temporary buffers
CudaRadixCiphertextFFI *current_blocks;
@@ -1349,7 +1509,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
bool allocated_luts_message_carry;
void setup_index_buffers(cudaStream_t const *streams,
uint32_t const *gpu_indexes) {
uint32_t const *gpu_indexes,
uint64_t *size_tracker) {
d_degrees = (uint64_t *)cuda_malloc_with_size_tracking_async(
max_total_blocks_in_vec * sizeof(uint64_t), streams[0], gpu_indexes[0],
@@ -1394,18 +1555,29 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
}
void setup_lookup_tables(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count) {
uint32_t const *gpu_indexes, uint32_t gpu_count,
uint32_t num_radix_in_vec,
const uint64_t *const degrees) {
uint32_t message_modulus = params.message_modulus;
bool _needs_processing = false;
radix_columns<Torus> current_columns(degrees, num_blocks_in_radix,
num_radix_in_vec, chunk_size,
_needs_processing);
uint32_t total_ciphertexts = 0;
uint32_t total_messages = 0;
current_columns.next_accumulation(nullptr, nullptr, nullptr,
total_ciphertexts, total_messages,
_needs_processing);
if (!mem_reuse) {
uint32_t pbs_count = std::max(2 * (max_total_blocks_in_vec / chunk_size),
2 * num_blocks_in_radix);
if (max_total_blocks_in_vec > 0) {
luts_message_carry = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 2, pbs_count,
gpu_memory_allocated, size_tracker);
} else {
allocated_luts_message_carry = false;
uint32_t pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix);
if (total_ciphertexts > 0 ||
reduce_degrees_for_single_carry_propagation) {
uint64_t size_tracker = 0;
luts_message_carry =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
pbs_count, true, &size_tracker);
allocated_luts_message_carry = true;
}
}
if (allocated_luts_message_carry) {
@@ -1436,25 +1608,35 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);
}
}
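Two details of the reworked `setup_lookup_tables` are easy to miss. First, in the non-reuse path `pbs_count` is now `max(total_ciphertexts, 2 * num_blocks_in_radix)`, derived from the actual degrees via the dry-run count above, and the LUT is skipped entirely when there is no PBS work and no final degree reduction requested. Second, the LUT is built with `allocate_gpu_memory` hard-coded to `true` and a throwaway local `size_tracker`: the global tracker was already topped up by the dry run in the constructor (next hunk), so the real allocation is not counted a second time.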
int_sum_ciphertexts_vec_memory(cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params,
uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec,
bool allocate_gpu_memory,
uint64_t *size_tracker) {
int_sum_ciphertexts_vec_memory(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec,
bool reduce_degrees_for_single_carry_propagation,
bool allocate_gpu_memory, uint64_t *size_tracker) {
this->params = params;
this->mem_reuse = false;
this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
this->num_blocks_in_radix = num_blocks_in_radix;
this->max_num_radix_in_vec = max_num_radix_in_vec;
this->gpu_memory_allocated = allocate_gpu_memory;
this->size_tracker = size_tracker;
this->chunk_size = (params.message_modulus * params.carry_modulus - 1) /
(params.message_modulus - 1);
this->allocated_luts_message_carry = true;
setup_index_buffers(streams, gpu_indexes);
setup_lookup_tables(streams, gpu_indexes, gpu_count);
this->allocated_luts_message_carry = false;
this->reduce_degrees_for_single_carry_propagation =
reduce_degrees_for_single_carry_propagation;
setup_index_buffers(streams, gpu_indexes, size_tracker);

// Because setup_lookup_tables now runs in the host function for
// sum_ciphertexts (to save memory), the size_tracker is topped up here so it
// still holds a max bound on the memory used.
uint32_t max_pbs_count = std::max(
2 * (max_total_blocks_in_vec / chunk_size), 2 * num_blocks_in_radix);
if (max_pbs_count > 0) {
int_radix_lut<Torus> *luts_message_carry_dry_run =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
max_pbs_count, false, size_tracker);
luts_message_carry_dry_run->release(streams, gpu_indexes, gpu_count);
delete luts_message_carry_dry_run;
}
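The dry run reuses exactly the worst-case `pbs_count` that the old code allocated eagerly, `max(2 * (max_total_blocks_in_vec / chunk_size), 2 * num_blocks_in_radix)`, but with `allocate_gpu_memory = false`, so (through the `cuda_malloc_with_size_tracking_async` machinery) only `size_tracker` is charged and no device memory is touched. The LUT later created in `setup_lookup_tables` uses a degree-dependent `pbs_count` that can only be smaller or equal, so the reported bound stays valid while the actual footprint shrinks.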
// create and allocate intermediate buffers
current_blocks = new CudaRadixCiphertextFFI;
@@ -1472,23 +1654,25 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec, CudaRadixCiphertextFFI *current_blocks,
CudaRadixCiphertextFFI *small_lwe_vector,
int_radix_lut<Torus> *reused_lut, bool allocate_gpu_memory,
uint64_t *size_tracker) {
int_radix_lut<Torus> *reused_lut,
bool reduce_degrees_for_single_carry_propagation,
bool allocate_gpu_memory, uint64_t *size_tracker) {
this->mem_reuse = true;
this->params = params;
this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
this->num_blocks_in_radix = num_blocks_in_radix;
this->max_num_radix_in_vec = max_num_radix_in_vec;
this->gpu_memory_allocated = allocate_gpu_memory;
this->size_tracker = size_tracker;
this->chunk_size = (params.message_modulus * params.carry_modulus - 1) /
(params.message_modulus - 1);
this->allocated_luts_message_carry = true;
this->reduce_degrees_for_single_carry_propagation =
reduce_degrees_for_single_carry_propagation;
this->current_blocks = current_blocks;
this->small_lwe_vector = small_lwe_vector;
this->luts_message_carry = reused_lut;
setup_index_buffers(streams, gpu_indexes);
setup_index_buffers(streams, gpu_indexes, size_tracker);
}
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -1518,7 +1702,6 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
luts_message_carry->release(streams, gpu_indexes, gpu_count);
delete luts_message_carry;
}
delete current_blocks;
delete small_lwe_vector;
}
@@ -2898,7 +3081,7 @@ template <typename Torus> struct int_mul_memory {
sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
2 * num_radix_blocks, vector_result_sb, small_lwe_vector, luts_array,
allocate_gpu_memory, size_tracker);
true, allocate_gpu_memory, size_tracker);
uint32_t uses_carry = 0;
uint32_t requested_flag = outputFlag::FLAG_NONE;
sc_prop_mem = new int_sc_prop_memory<Torus>(
@@ -4694,10 +4877,11 @@ template <typename Torus> struct int_scalar_mul_buffer {
//// The idea is that with these we can create all other shifts that are
/// in range (0..total_bits) for free (block rotation)
preshifted_buffer = new CudaRadixCiphertextFFI;
uint64_t anticipated_drop_mem = 0;
create_zero_radix_ciphertext_async<Torus>(
streams[0], gpu_indexes[0], preshifted_buffer,
msg_bits * num_radix_blocks, params.big_lwe_dimension, size_tracker,
allocate_gpu_memory);
msg_bits * num_radix_blocks, params.big_lwe_dimension,
&anticipated_drop_mem, allocate_gpu_memory);
all_shifted_buffer = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
@@ -4708,22 +4892,28 @@ template <typename Torus> struct int_scalar_mul_buffer {
if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2)
logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks,
allocate_gpu_memory, all_shifted_buffer, size_tracker);
allocate_gpu_memory, all_shifted_buffer, &anticipated_drop_mem);
else
logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks,
allocate_gpu_memory, size_tracker);
allocate_gpu_memory, &anticipated_drop_mem);
uint64_t last_step_mem = 0;
if (num_ciphertext_bits > 0) {
sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
num_ciphertext_bits, allocate_gpu_memory, size_tracker);
num_ciphertext_bits, true, allocate_gpu_memory, &last_step_mem);
}
uint32_t uses_carry = 0;
uint32_t requested_flag = outputFlag::FLAG_NONE;
sc_prop_mem = new int_sc_prop_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
requested_flag, uses_carry, allocate_gpu_memory, size_tracker);
requested_flag, uses_carry, allocate_gpu_memory, &last_step_mem);
if (anticipated_buffer_drop) {
*size_tracker += std::max(anticipated_drop_mem, last_step_mem);
} else {
*size_tracker += anticipated_drop_mem + last_step_mem;
}
}
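The split accounting above refines the peak-memory estimate: the shift buffers (charged to `anticipated_drop_mem`) and the sum/carry-propagation buffers (charged to `last_step_mem`) only coexist when intermediate buffers are not dropped early. Aside: a sketch with hypothetical sizes:

// anticipated_drop_mem = 3 GB  (preshifted_buffer + shift buffers, freed early)
// last_step_mem        = 2 GB  (sum_ciphertexts_vec_mem + sc_prop_mem)
// anticipated_buffer_drop == true:  *size_tracker += max(3 GB, 2 GB) = 3 GB
//                                   (the two groups are never live together)
// anticipated_buffer_drop == false: *size_tracker += 3 GB + 2 GB = 5 GB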
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,

View File

@@ -210,7 +210,8 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, bool allocate_ms_array) {
bool reduce_degrees_for_single_carry_propagation, bool allocate_gpu_memory,
bool allocate_ms_array) {
int_radix_params params(pbs_type, glwe_dimension, polynomial_size,
glwe_dimension * polynomial_size, lwe_dimension,
@@ -220,15 +221,15 @@ uint64_t scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
return scratch_cuda_integer_partial_sum_ciphertexts_vec_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(int_sum_ciphertexts_vec_memory<uint64_t> **)mem_ptr, num_blocks_in_radix,
max_num_radix_in_vec, params, allocate_gpu_memory);
max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation, params,
allocate_gpu_memory);
}
void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *radix_lwe_vec,
bool reduce_degrees_for_single_carry_propagation, int8_t *mem_ptr,
void *const *bsks, void *const *ksks,
CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks,
void *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key) {
auto mem = (int_sum_ciphertexts_vec_memory<uint64_t> *)mem_ptr;
@@ -239,8 +240,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
case 512:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t, AmortizedDegree<512>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -248,8 +248,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<1024>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -257,8 +256,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<2048>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -266,8 +264,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<4096>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -275,8 +272,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<8192>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;
@@ -284,8 +280,7 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<16384>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count, radix_lwe_out,
radix_lwe_vec, reduce_degrees_for_single_carry_propagation, bsks,
(uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_vec, bsks, (uint64_t **)(ksks), ms_noise_reduction_key, mem,
radix_lwe_out->num_radix_blocks,
radix_lwe_vec->num_radix_blocks / radix_lwe_out->num_radix_blocks);
break;

View File

@@ -20,7 +20,6 @@
#include <fstream>
#include <iostream>
#include <omp.h>
#include <queue>
#include <sstream>
#include <string>
#include <vector>
@@ -273,164 +272,19 @@ __global__ void fill_radix_from_lsb_msb(Torus *result_blocks, Torus *lsb_blocks,
}
}
template <typename Torus> struct radix_columns {
std::vector<std::vector<Torus>> columns;
std::vector<size_t> columns_counter;
std::vector<std::vector<Torus>> new_columns;
std::vector<size_t> new_columns_counter;
size_t num_blocks;
size_t num_radix_in_vec;
size_t chunk_size;
radix_columns(const uint64_t *const input_degrees, size_t num_blocks,
size_t num_radix_in_vec, size_t chunk_size,
bool &needs_processing)
: num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec),
chunk_size(chunk_size) {
needs_processing = false;
columns.resize(num_blocks);
columns_counter.resize(num_blocks, 0);
new_columns.resize(num_blocks);
new_columns_counter.resize(num_blocks, 0);
for (size_t i = 0; i < num_blocks; ++i) {
new_columns[i].resize(num_radix_in_vec);
}
for (size_t i = 0; i < num_radix_in_vec; ++i) {
for (size_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j]) {
columns[j].push_back(i * num_blocks + j);
columns_counter[j]++;
}
}
}
for (size_t i = 0; i < num_blocks; ++i) {
if (columns_counter[i] > chunk_size) {
needs_processing = true;
break;
}
}
}
void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes, size_t &total_ciphertexts,
size_t &message_ciphertexts, bool &needs_processing) {
message_ciphertexts = 0;
total_ciphertexts = 0;
needs_processing = false;
size_t pbs_count = 0;
for (size_t c_id = 0; c_id < num_blocks; ++c_id) {
const size_t column_len = columns_counter[c_id];
new_columns_counter[c_id] = 0;
size_t ct_count = 0;
// add message cts into new columns
for (size_t i = 0; i + chunk_size <= column_len; i += chunk_size) {
const Torus in_index = columns[c_id][i];
new_columns[c_id][ct_count] = in_index;
h_indexes_in[pbs_count] = in_index;
h_indexes_out[pbs_count] = in_index;
h_lut_indexes[pbs_count] = 0;
++pbs_count;
++ct_count;
++message_ciphertexts;
}
new_columns_counter[c_id] = ct_count;
}
for (size_t c_id = 0; c_id < num_blocks; ++c_id) {
const size_t column_len = columns_counter[c_id];
size_t ct_count = new_columns_counter[c_id];
// add carry cts into new columns
if (c_id > 0) {
const size_t prev_c_id = c_id - 1;
const size_t prev_column_len = columns_counter[prev_c_id];
for (size_t i = 0; i + chunk_size <= prev_column_len; i += chunk_size) {
const Torus in_index = columns[prev_c_id][i];
const Torus out_index = columns[prev_c_id][i + 1];
new_columns[c_id][ct_count] = out_index;
h_indexes_in[pbs_count] = in_index;
h_indexes_out[pbs_count] = out_index;
h_lut_indexes[pbs_count] = 1;
++pbs_count;
++ct_count;
}
}
// add remaining cts into new columns
const size_t start_index = column_len - column_len % chunk_size;
for (size_t i = start_index; i < column_len; ++i) {
new_columns[c_id][ct_count] = columns[c_id][i];
++ct_count;
}
new_columns_counter[c_id] = ct_count;
if (ct_count > chunk_size) {
needs_processing = true;
}
}
total_ciphertexts = pbs_count;
swap(columns, new_columns);
swap(columns_counter, new_columns_counter);
}
void final_calculation(Torus *h_indexes_in, Torus *h_indexes_out,
Torus *h_lut_indexes) {
for (size_t idx = 0; idx < 2 * num_blocks; ++idx) {
h_indexes_in[idx] = idx % num_blocks;
h_indexes_out[idx] = idx + idx / num_blocks;
h_lut_indexes[idx] = idx / num_blocks;
}
}
};
inline void calculate_final_degrees(uint64_t *const out_degrees,
const uint64_t *const input_degrees,
size_t num_blocks, size_t num_radix_in_vec,
size_t chunk_size,
uint64_t message_modulus) {
auto get_degree = [message_modulus](uint64_t degree) -> uint64_t {
return std::min(message_modulus - 1, degree);
};
std::vector<std::queue<uint64_t>> columns(num_blocks);
for (size_t i = 0; i < num_radix_in_vec; ++i) {
for (size_t j = 0; j < num_blocks; ++j) {
if (input_degrees[i * num_blocks + j])
columns[j].push(input_degrees[i * num_blocks + j]);
}
}
for (size_t i = 0; i < num_blocks; ++i) {
auto &col = columns[i];
while (col.size() > 1) {
uint32_t cur_degree = 0;
size_t mn = std::min(chunk_size, col.size());
for (int j = 0; j < mn; ++j) {
cur_degree += col.front();
col.pop();
}
const uint64_t new_degree = get_degree(cur_degree);
col.push(new_degree);
if ((i + 1) < num_blocks) {
columns[i + 1].push(new_degree);
}
}
}
for (int i = 0; i < num_blocks; i++) {
out_degrees[i] = (columns[i].empty()) ? 0 : columns[i].front();
}
}
template <typename Torus>
__host__ uint64_t scratch_cuda_integer_partial_sum_ciphertexts_vec_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_sum_ciphertexts_vec_memory<Torus> **mem_ptr,
uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec,
int_radix_params params, bool allocate_gpu_memory) {
bool reduce_degrees_for_single_carry_propagation, int_radix_params params,
bool allocate_gpu_memory) {
uint64_t size_tracker = 0;
*mem_ptr = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_blocks_in_radix,
max_num_radix_in_vec, allocate_gpu_memory, &size_tracker);
max_num_radix_in_vec, reduce_degrees_for_single_carry_propagation,
allocate_gpu_memory, &size_tracker);
return size_tracker;
}
@@ -438,9 +292,7 @@ template <typename Torus, class params>
__host__ void host_integer_partial_sum_ciphertexts_vec_kb(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *radix_lwe_out,
CudaRadixCiphertextFFI *terms,
bool reduce_degrees_for_single_carry_propagation, void *const *bsks,
uint64_t *const *ksks,
CudaRadixCiphertextFFI *terms, void *const *bsks, uint64_t *const *ksks,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
uint32_t num_radix_blocks, uint32_t num_radix_in_vec) {
@@ -465,10 +317,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
auto d_columns_counter = mem_ptr->d_columns_counter;
auto d_new_columns = mem_ptr->d_new_columns;
auto d_new_columns_counter = mem_ptr->d_new_columns_counter;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
auto luts_message_carry = mem_ptr->luts_message_carry;
auto glwe_dimension = mem_ptr->params.glwe_dimension;
auto polynomial_size = mem_ptr->params.polynomial_size;
@@ -483,8 +331,9 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
uint32_t num_many_lut = 1;
uint32_t lut_stride = 0;
if (terms->num_radix_blocks == 0)
if (terms->num_radix_blocks == 0) {
return;
}
if (num_radix_in_vec == 1) {
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
radix_lwe_out, 0, num_radix_blocks,
@@ -501,10 +350,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
return;
}
if (mem_ptr->mem_reuse) {
mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count);
}
if (current_blocks != terms) {
copy_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
current_blocks, terms);
@@ -523,11 +368,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
radix_columns<Torus> current_columns(current_blocks->degrees,
num_radix_blocks, num_radix_in_vec,
chunk_size, needs_processing);
int number_of_threads = min(256, params::degree);
int number_of_threads = std::min(256, params::degree);
int part_count = (big_lwe_size + number_of_threads - 1) / number_of_threads;
const dim3 number_of_blocks_2d(num_radix_blocks, part_count, 1);
mem_ptr->setup_lookup_tables(streams, gpu_indexes, gpu_count,
num_radix_in_vec, current_blocks->degrees);
while (needs_processing) {
auto luts_message_carry = mem_ptr->luts_message_carry;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
calculate_chunks<Torus>
<<<number_of_blocks_2d, number_of_threads, 0, streams[0]>>>(
(Torus *)(current_blocks->ptr), d_columns, d_columns_counter,
@@ -538,8 +389,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
d_pbs_indexes_out, luts_message_carry->get_lut_indexes(0, 0), d_columns,
d_columns_counter, chunk_size);
size_t total_ciphertexts;
size_t total_messages;
uint32_t total_ciphertexts;
uint32_t total_messages;
current_columns.next_accumulation(luts_message_carry->h_lwe_indexes_in,
luts_message_carry->h_lwe_indexes_out,
luts_message_carry->h_lut_indexes,
@@ -570,9 +421,8 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, active_gpu_count, current_blocks,
current_blocks, bsks, ksks, ms_noise_reduction_key,
luts_message_carry, total_ciphertexts);
streams, gpu_indexes, gpu_count, current_blocks, current_blocks, bsks,
ksks, ms_noise_reduction_key, luts_message_carry, total_ciphertexts);
}
cuda_set_device(gpu_indexes[0]);
std::swap(d_columns, d_new_columns);
@@ -584,7 +434,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
(Torus *)(radix_lwe_out->ptr), (Torus *)(current_blocks->ptr),
d_columns, d_columns_counter, chunk_size, big_lwe_size);
if (reduce_degrees_for_single_carry_propagation) {
if (mem_ptr->reduce_degrees_for_single_carry_propagation) {
auto luts_message_carry = mem_ptr->luts_message_carry;
auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
prepare_final_pbs_indexes<Torus>
<<<1, 2 * num_radix_blocks, 0, streams[0]>>>(
d_pbs_indexes_in, d_pbs_indexes_out,
@@ -772,9 +625,9 @@ __host__ void host_integer_mult_radix_kb(
terms_degree_msb[i] = (b_id > r_id) ? message_modulus - 2 : 0;
}
host_integer_partial_sum_ciphertexts_vec_kb<Torus, params>(
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, true,
bsks, ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem,
num_blocks, 2 * num_blocks);
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, bsks,
ksks, ms_noise_reduction_key, mem_ptr->sum_ciphertexts_mem, num_blocks,
2 * num_blocks);
auto scp_mem_ptr = mem_ptr->sc_prop_mem;
uint32_t requested_flag = outputFlag::FLAG_NONE;

View File

@@ -117,8 +117,8 @@ __host__ void host_integer_scalar_mul_radix(
lwe_array, 0, num_radix_blocks);
} else {
host_integer_partial_sum_ciphertexts_vec_kb<T, params>(
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, true,
bsks, ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem,
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, bsks,
ksks, ms_noise_reduction_key, mem->sum_ciphertexts_vec_mem,
num_radix_blocks, j);
auto scp_mem_ptr = mem->sc_prop_mem;

View File

@@ -1007,6 +1007,7 @@ unsafe extern "C" {
message_modulus: u32,
carry_modulus: u32,
pbs_type: PBS_TYPE,
reduce_degrees_for_single_carry_propagation: bool,
allocate_gpu_memory: bool,
allocate_ms_array: bool,
) -> u64;
@@ -1018,7 +1019,6 @@ unsafe extern "C" {
gpu_count: u32,
radix_lwe_out: *mut CudaRadixCiphertextFFI,
radix_lwe_vec: *mut CudaRadixCiphertextFFI,
reduce_degrees_for_single_carry_propagation: bool,
mem_ptr: *mut i8,
bsks: *const *mut ffi::c_void,
ksks: *const *mut ffi::c_void,

View File

@@ -4729,6 +4729,7 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async<
message_modulus.0 as u32,
carry_modulus.0 as u32,
pbs_type as u32,
reduce_degrees_for_single_carry_propagation,
true,
allocate_ms_noise_array,
);
@@ -4738,7 +4739,6 @@ pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async<
streams.len() as u32,
&raw mut cuda_ffi_result,
&raw mut cuda_ffi_radix_list,
reduce_degrees_for_single_carry_propagation,
mem_ptr,
bootstrapping_key.ptr.as_ptr(),
keyswitch_key.ptr.as_ptr(),