Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-11 07:38:08 -05:00
Compare commits: ns/fix_ben ... al/backup (10 commits; author and date metadata not preserved in this mirror)

- a3d01c23ae
- ef3c02adaf
- 1b21e245f3
- 2a439d1be2
- 56481bc02e
- 101acf12a6
- 1fe2e11405
- ca7103bc79
- d2598a06cf
- d2f99512f3
@@ -113,8 +113,12 @@ void generate_many_lut_device_accumulator(
     uint32_t message_modulus, uint32_t carry_modulus,
     std::vector<std::function<Torus(Torus)>> &f, bool gpu_memory_allocated);

-struct radix_columns {
+template <typename Torus> struct radix_columns {
+  std::vector<std::vector<Torus>> columns;
+  std::vector<uint32_t> columns_counter;
+  std::vector<std::vector<Torus>> new_columns;
+  std::vector<uint32_t> new_columns_counter;

   uint32_t num_blocks;
   uint32_t num_radix_in_vec;
   uint32_t chunk_size;
@@ -124,14 +128,21 @@ struct radix_columns {
       : num_blocks(num_blocks), num_radix_in_vec(num_radix_in_vec),
         chunk_size(chunk_size) {
     needs_processing = false;
+    columns.resize(num_blocks);
     columns_counter.resize(num_blocks, 0);
+    new_columns.resize(num_blocks);
     new_columns_counter.resize(num_blocks, 0);
+    for (uint32_t i = 0; i < num_blocks; ++i) {
+      new_columns[i].resize(num_radix_in_vec);
+    }
     for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
       for (uint32_t j = 0; j < num_blocks; ++j) {
-        if (input_degrees[i * num_blocks + j])
-          columns_counter[j] += 1;
+        if (input_degrees[i * num_blocks + j]) {
+          columns[j].push_back(i * num_blocks + j);
+          columns_counter[j]++;
+        }
       }
     }

     for (uint32_t i = 0; i < num_blocks; ++i) {
       if (columns_counter[i] > chunk_size) {
         needs_processing = true;
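Note on the constructor above: `input_degrees` is the flattened `[num_radix_in_vec x num_blocks]` degree array, and column `j` collects the flat indices of every non-trivial block of significance `j` across the vector. A minimal standalone sketch of the same bookkeeping (hypothetical degree values, plain C++, no tfhe-rs headers):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy version of the radix_columns constructor: column j gathers the flat
// indices of all nonzero-degree blocks of significance j.
int main() {
  const uint32_t num_blocks = 4, num_radix_in_vec = 3;
  // Hypothetical degrees: radix 2 has a trivial (zero-degree) block at j = 1.
  std::vector<uint64_t> input_degrees = {3, 3, 3, 3,   // radix 0
                                         3, 3, 3, 3,   // radix 1
                                         3, 0, 3, 3};  // radix 2
  std::vector<std::vector<uint32_t>> columns(num_blocks);
  for (uint32_t i = 0; i < num_radix_in_vec; ++i)
    for (uint32_t j = 0; j < num_blocks; ++j)
      if (input_degrees[i * num_blocks + j])
        columns[j].push_back(i * num_blocks + j);
  for (uint32_t j = 0; j < num_blocks; ++j) {
    std::cout << "column " << j << ":";
    for (auto idx : columns[j])
      std::cout << ' ' << idx;
    std::cout << '\n'; // column 1 keeps only two entries: 1 and 5
  }
}
```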
@@ -140,70 +151,96 @@ struct radix_columns {
       }
     }
   }

-  void next_accumulation(uint32_t &total_ciphertexts,
+  void next_accumulation(Torus *h_indexes_in, Torus *h_indexes_out,
+                         Torus *h_lut_indexes, uint32_t &total_ciphertexts,
                          uint32_t &message_ciphertexts,
                          bool &needs_processing) {
     message_ciphertexts = 0;
     total_ciphertexts = 0;
     needs_processing = false;
-    for (int i = num_blocks - 1; i > 0; --i) {
-      uint32_t cur_count = columns_counter[i];
-      uint32_t prev_count = columns_counter[i - 1];
-      uint32_t new_count = 0;
-
-      // accumulated_blocks from current columns
-      new_count += cur_count / chunk_size;
-      // all accumulated message blocks needs pbs
-      message_ciphertexts += new_count;
-      // carry blocks from previous columns
-      new_count += prev_count / chunk_size;
-      // both carry and message blocks that needs pbs
-      total_ciphertexts += new_count;
-      // now add remaining non accumulated blocks that does not require pbs
-      new_count += cur_count % chunk_size;
+    uint32_t pbs_count = 0;
+    for (uint32_t c_id = 0; c_id < num_blocks; ++c_id) {
+      const uint32_t column_len = columns_counter[c_id];
+      new_columns_counter[c_id] = 0;
+      uint32_t ct_count = 0;
+      // add message cts into new columns
+      for (uint32_t i = 0; i + chunk_size <= column_len; i += chunk_size) {
+        const Torus in_index = columns[c_id][i];
+        new_columns[c_id][ct_count] = in_index;
+        if (h_indexes_in != nullptr)
+          h_indexes_in[pbs_count] = in_index;
+        if (h_indexes_out != nullptr)
+          h_indexes_out[pbs_count] = in_index;
+        if (h_lut_indexes != nullptr)
+          h_lut_indexes[pbs_count] = 0;
+        ++pbs_count;
+        ++ct_count;
+        ++message_ciphertexts;
+      }

-      columns_counter[i] = new_count;
+      // add carry cts into new columns
+      if (c_id > 0) {
+        const uint32_t prev_c_id = c_id - 1;
+        const uint32_t prev_column_len = columns_counter[prev_c_id];
+        for (uint32_t i = 0; i + chunk_size <= prev_column_len;
+             i += chunk_size) {
+          const Torus in_index = columns[prev_c_id][i];
+          const Torus out_index = columns[prev_c_id][i + 1];
+          new_columns[c_id][ct_count] = out_index;
+          if (h_indexes_in != nullptr)
+            h_indexes_in[pbs_count] = in_index;
+          if (h_indexes_out != nullptr)
+            h_indexes_out[pbs_count] = out_index;
+          if (h_lut_indexes != nullptr)
+            h_lut_indexes[pbs_count] = 1;
+          ++pbs_count;
+          ++ct_count;
+        }
+      }

-      if (new_count > chunk_size)
+      // add remaining cts into new columns
+      const uint32_t start_index = column_len - column_len % chunk_size;
+      for (uint32_t i = start_index; i < column_len; ++i) {
+        new_columns[c_id][ct_count] = columns[c_id][i];
+        ++ct_count;
+      }
+      new_columns_counter[c_id] = ct_count;
+      if (ct_count > chunk_size) {
         needs_processing = true;
+      }
-
-      new_columns_counter[c_id] = ct_count;
     }

-    // now do it for 0th block
-    uint32_t new_count = columns_counter[0] / chunk_size;
-    message_ciphertexts += new_count;
-    total_ciphertexts += new_count;
-    new_count += columns_counter[0] % chunk_size;
-    columns_counter[0] = new_count;
-
-    if (new_count > chunk_size) {
-      needs_processing = true;
-    }
+    total_ciphertexts = pbs_count;
     swap(columns, new_columns);
     swap(columns_counter, new_columns_counter);
   }
 };

 inline void calculate_final_degrees(uint64_t *const out_degrees,
                                     const uint64_t *const input_degrees,
-                                    uint32_t num_blocks,
-                                    uint32_t num_radix_in_vec,
-                                    uint32_t chunk_size,
+                                    size_t num_blocks, size_t num_radix_in_vec,
+                                    size_t chunk_size,
                                     uint64_t message_modulus) {

   auto get_degree = [message_modulus](uint64_t degree) -> uint64_t {
     return std::min(message_modulus - 1, degree);
   };
   std::vector<std::queue<uint64_t>> columns(num_blocks);
-  for (uint32_t i = 0; i < num_radix_in_vec; ++i) {
-    for (uint32_t j = 0; j < num_blocks; ++j) {
+  for (size_t i = 0; i < num_radix_in_vec; ++i) {
+    for (size_t j = 0; j < num_blocks; ++j) {
       if (input_degrees[i * num_blocks + j])
         columns[j].push(input_degrees[i * num_blocks + j]);
     }
   }

-  for (uint32_t i = 0; i < num_blocks; ++i) {
+  for (size_t i = 0; i < num_blocks; ++i) {
     auto &col = columns[i];
     while (col.size() > 1) {
       uint32_t cur_degree = 0;
-      uint32_t mn = std::min(chunk_size, (uint32_t)col.size());
+      size_t mn = std::min(chunk_size, col.size());
       for (int j = 0; j < mn; ++j) {
         cur_degree += col.front();
         col.pop();
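Note on `next_accumulation`: the rewrite replaces the old per-column counting pass with an explicit walk that also emits the PBS index triples on the host. Each full chunk of a column produces one message extraction (LUT 0) that stays in the same column and, taken from the previous column, one carry extraction (LUT 1) that lands one column higher; leftovers below `chunk_size` carry over to the next round untouched. A standalone sketch of the count bookkeeping only (toy values, not the library code):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// One reduction round over per-column block counts. Each full chunk of
// column c yields a message extraction kept in column c and, for c + 1 < n,
// a carry extraction pushed into column c + 1; remainders pass through.
static uint32_t round_pbs_count(std::vector<uint32_t> &counts,
                                uint32_t chunk_size) {
  const uint32_t n = (uint32_t)counts.size();
  std::vector<uint32_t> next(n, 0);
  uint32_t pbs = 0;
  for (uint32_t c = 0; c < n; ++c) {
    const uint32_t full = counts[c] / chunk_size;
    pbs += full;                            // message extractions (LUT 0)
    next[c] += full + counts[c] % chunk_size;
    if (c + 1 < n) {
      pbs += full;                          // carry extractions (LUT 1)
      next[c + 1] += full;
    }
  }
  counts.swap(next);
  return pbs;
}

int main() {
  std::vector<uint32_t> counts = {7, 7, 7, 7}; // hypothetical column fill
  const uint32_t chunk_size = 5;
  while (*std::max_element(counts.begin(), counts.end()) > chunk_size)
    std::cout << round_pbs_count(counts, chunk_size) << " PBS this round\n";
  // {7,7,7,7} -> one round of 7 PBS, leaving {3,4,4,4}
}
```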
@@ -686,8 +723,10 @@ template <typename Torus> struct int_radix_lut {
   void set_lwe_indexes(cudaStream_t stream, uint32_t gpu_index,
                        Torus *h_indexes_in, Torus *h_indexes_out) {

-    memcpy(h_lwe_indexes_in, h_indexes_in, num_blocks * sizeof(Torus));
-    memcpy(h_lwe_indexes_out, h_indexes_out, num_blocks * sizeof(Torus));
+    if (h_indexes_in != h_lwe_indexes_in)
+      memcpy(h_lwe_indexes_in, h_indexes_in, num_blocks * sizeof(Torus));
+    if (h_indexes_out != h_lwe_indexes_out)
+      memcpy(h_lwe_indexes_out, h_indexes_out, num_blocks * sizeof(Torus));

     cuda_memcpy_with_size_tracking_async_to_gpu(
         lwe_indexes_in, h_lwe_indexes_in, num_blocks * sizeof(Torus), stream,
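Note on the `set_lwe_indexes` change: callers may now hand the LUT's own `h_lwe_indexes_in`/`h_lwe_indexes_out` buffers straight back in (as `host_integer_partial_sum_ciphertexts_vec_kb` does below), and `memcpy` has undefined behavior for overlapping source and destination, so an exact self-copy must be skipped. A generic sketch of the guard:

```cpp
#include <cstring>

// Guarded copy (generic sketch, not the tfhe-rs implementation): skip the
// memcpy when the caller passes the destination buffer back as the source.
template <typename T>
void copy_indexes(T *dst, const T *src, std::size_t n) {
  if (src != dst)
    std::memcpy(dst, src, n * sizeof(T));
}
```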
@@ -1487,7 +1526,6 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
   // lookup table for extracting message and carry
   int_radix_lut<Torus> *luts_message_carry;

-  bool mem_reuse = false;
   bool allocated_luts_message_carry;

   void setup_index_buffers(cudaStream_t const *streams,
@@ -1542,24 +1580,23 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
                            const uint64_t *const degrees) {
     uint32_t message_modulus = params.message_modulus;
     bool _needs_processing = false;
-    radix_columns current_columns(degrees, num_blocks_in_radix,
-                                  num_radix_in_vec, chunk_size,
-                                  _needs_processing);
+    radix_columns<Torus> current_columns(degrees, num_blocks_in_radix,
+                                         num_radix_in_vec, chunk_size,
+                                         _needs_processing);
     uint32_t total_ciphertexts = 0;
     uint32_t total_messages = 0;
-    current_columns.next_accumulation(total_ciphertexts, total_messages,
+
+    current_columns.next_accumulation(nullptr, nullptr, nullptr,
+                                      total_ciphertexts, total_messages,
                                       _needs_processing);

-    if (!mem_reuse) {
-      uint32_t pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix);
-      if (total_ciphertexts > 0 ||
-          reduce_degrees_for_single_carry_propagation) {
-        uint64_t size_tracker = 0;
-        luts_message_carry =
-            new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
-                                     pbs_count, true, size_tracker);
-        allocated_luts_message_carry = true;
-      }
-    }
+    uint32_t pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix);
+    if (total_ciphertexts > 0 || reduce_degrees_for_single_carry_propagation) {
+      uint64_t size_tracker = 0;
+      luts_message_carry =
+          new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
+                                   pbs_count, true, size_tracker);
+      allocated_luts_message_carry = true;
+    }
     if (allocated_luts_message_carry) {
       auto message_acc = luts_message_carry->get_lut(0, 0);
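Note on the allocation above: `pbs_count = std::max(total_ciphertexts, 2 * num_blocks_in_radix)` because the same `luts_message_carry` buffers serve both the column-reduction rounds (up to `total_ciphertexts` PBS per round) and the final message/carry extraction (exactly `2 * num_radix_blocks` PBS, see the hunks further down). A worked sizing example with hypothetical numbers:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Sizing sketch: the LUT index buffers must fit the larger of the two uses.
int main() {
  uint32_t total_ciphertexts = 6;   // hypothetical first-round PBS count
  uint32_t num_blocks_in_radix = 8; // hypothetical radix width
  std::cout << std::max(total_ciphertexts, 2 * num_blocks_in_radix)
            << '\n'; // max(6, 16) = 16
}
```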
@@ -1596,7 +1633,6 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
                            bool reduce_degrees_for_single_carry_propagation,
                            bool allocate_gpu_memory, uint64_t &size_tracker) {
     this->params = params;
-    this->mem_reuse = false;
     this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
     this->num_blocks_in_radix = num_blocks_in_radix;
     this->max_num_radix_in_vec = max_num_radix_in_vec;
@@ -1630,32 +1666,6 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
         params.small_lwe_dimension, size_tracker, allocate_gpu_memory);
   }

-  int_sum_ciphertexts_vec_memory(
-      cudaStream_t const *streams, uint32_t const *gpu_indexes,
-      uint32_t gpu_count, int_radix_params params, uint32_t num_blocks_in_radix,
-      uint32_t max_num_radix_in_vec, CudaRadixCiphertextFFI *current_blocks,
-      CudaRadixCiphertextFFI *small_lwe_vector,
-      int_radix_lut<Torus> *reused_lut,
-      bool reduce_degrees_for_single_carry_propagation,
-      bool allocate_gpu_memory, uint64_t &size_tracker) {
-    this->mem_reuse = true;
-    this->params = params;
-    this->max_total_blocks_in_vec = num_blocks_in_radix * max_num_radix_in_vec;
-    this->num_blocks_in_radix = num_blocks_in_radix;
-    this->max_num_radix_in_vec = max_num_radix_in_vec;
-    this->gpu_memory_allocated = allocate_gpu_memory;
-    this->chunk_size = (params.message_modulus * params.carry_modulus - 1) /
-                       (params.message_modulus - 1);
-    this->allocated_luts_message_carry = true;
-    this->reduce_degrees_for_single_carry_propagation =
-        reduce_degrees_for_single_carry_propagation;
-
-    this->current_blocks = current_blocks;
-    this->small_lwe_vector = small_lwe_vector;
-    this->luts_message_carry = reused_lut;
-    setup_index_buffers(streams, gpu_indexes, size_tracker);
-  }
-
   void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                uint32_t gpu_count) {
     cuda_drop_with_size_tracking_async(d_degrees, streams[0], gpu_indexes[0],
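The deleted reuse constructor also shows the `chunk_size` formula kept by the remaining constructor: `(message_modulus * carry_modulus - 1) / (message_modulus - 1)`, i.e. the largest number of blocks of maximal degree whose sum still fits in the message-plus-carry space. Worked example, assuming the common 2_2 parameter set:

```cpp
#include <cstdint>
#include <iostream>

// chunk_size is the largest k with k * (message_modulus - 1) <=
// message_modulus * carry_modulus - 1: how many maximal-degree blocks can
// be summed before the result no longer fits in the plaintext space.
int main() {
  const uint64_t message_modulus = 4; // 2-bit message
  const uint64_t carry_modulus = 4;   // 2-bit carry
  const uint64_t chunk_size =
      (message_modulus * carry_modulus - 1) / (message_modulus - 1);
  std::cout << chunk_size << '\n'; // (16 - 1) / 3 = 5 blocks per chunk
}
```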
@@ -1674,18 +1684,16 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
     cuda_drop_with_size_tracking_async(d_new_columns, streams[0],
                                        gpu_indexes[0], gpu_memory_allocated);

-    if (!mem_reuse) {
-      release_radix_ciphertext_async(streams[0], gpu_indexes[0], current_blocks,
-                                     gpu_memory_allocated);
-      release_radix_ciphertext_async(streams[0], gpu_indexes[0],
-                                     small_lwe_vector, gpu_memory_allocated);
-      if (allocated_luts_message_carry) {
-        luts_message_carry->release(streams, gpu_indexes, gpu_count);
-        delete luts_message_carry;
-      }
-      delete current_blocks;
-      delete small_lwe_vector;
-    }
+    release_radix_ciphertext_async(streams[0], gpu_indexes[0], current_blocks,
+                                   gpu_memory_allocated);
+    release_radix_ciphertext_async(streams[0], gpu_indexes[0], small_lwe_vector,
+                                   gpu_memory_allocated);
+    if (allocated_luts_message_carry) {
+      luts_message_carry->release(streams, gpu_indexes, gpu_count);
+      delete luts_message_carry;
+    }
+    delete current_blocks;
+    delete small_lwe_vector;
   }
 };

@@ -3061,10 +3069,10 @@ template <typename Torus> struct int_mul_memory {

     luts_array->broadcast_lut(streams, gpu_indexes);
     // create memory object for sum ciphertexts
+
     sum_ciphertexts_mem = new int_sum_ciphertexts_vec_memory<Torus>(
         streams, gpu_indexes, gpu_count, params, num_radix_blocks,
-        2 * num_radix_blocks, vector_result_sb, small_lwe_vector, luts_array,
-        true, allocate_gpu_memory, size_tracker);
+        2 * num_radix_blocks, true, allocate_gpu_memory, size_tracker);
     uint32_t uses_carry = 0;
     uint32_t requested_flag = outputFlag::FLAG_NONE;
     sc_prop_mem = new int_sc_prop_memory<Torus>(
@@ -95,17 +95,10 @@ __global__ inline void radix_vec_to_columns(uint32_t *const *const columns,
 }

 template <typename Torus>
-__global__ inline void prepare_new_columns_and_pbs_indexes(
+__global__ inline void prepare_new_columns(
     uint32_t *const *const new_columns, uint32_t *const new_columns_counter,
-    Torus *const pbs_indexes_in, Torus *const pbs_indexes_out,
-    Torus *const lut_indexes, const uint32_t *const *const columns,
-    const uint32_t *const columns_counter, const uint32_t chunk_size) {
-  __shared__ uint32_t counter;
-
-  if (threadIdx.x == 0) {
-    counter = 0;
-  }
-  __syncthreads();
+    const uint32_t *const *const columns, const uint32_t *const columns_counter,
+    const uint32_t chunk_size) {

   const uint32_t base_id = threadIdx.x;
   const uint32_t column_len = columns_counter[base_id];
@@ -116,10 +109,6 @@ __global__ inline void prepare_new_columns_and_pbs_indexes(
       // for message ciphertexts in and out index should be same
       const uint32_t in_index = columns[base_id][i];
       new_columns[base_id][ct_count] = in_index;
-      const uint32_t pbs_index = atomicAdd(&counter, 1);
-      pbs_indexes_in[pbs_index] = in_index;
-      pbs_indexes_out[pbs_index] = in_index;
-      lut_indexes[pbs_index] = 0;
       ++ct_count;
     }
   __syncthreads();
@@ -135,10 +124,6 @@ __global__ inline void prepare_new_columns_and_pbs_indexes(
       const uint32_t in_index = columns[prev_base_id][i];
       const uint32_t out_index = columns[prev_base_id][i + 1];
       new_columns[base_id][ct_count] = out_index;
-      const uint32_t pbs_index = atomicAdd(&counter, 1);
-      pbs_indexes_in[pbs_index] = in_index;
-      pbs_indexes_out[pbs_index] = out_index;
-      lut_indexes[pbs_index] = 1;
       ++ct_count;
     }
   }
@@ -152,16 +137,6 @@ __global__ inline void prepare_new_columns_and_pbs_indexes(
   new_columns_counter[base_id] = ct_count;
 }

-template <typename Torus>
-__global__ inline void prepare_final_pbs_indexes(
-    Torus *const pbs_indexes_in, Torus *const pbs_indexes_out,
-    Torus *const lut_indexes, const uint32_t num_radix_blocks) {
-  int idx = threadIdx.x;
-  pbs_indexes_in[idx] = idx % num_radix_blocks;
-  pbs_indexes_out[idx] = idx + idx / num_radix_blocks;
-  lut_indexes[idx] = idx / num_radix_blocks;
-}
-
 template <typename Torus>
 __global__ void calculate_chunks(Torus *const input_blocks,
                                  const uint32_t *const *const columns,
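Note: the deleted `prepare_final_pbs_indexes` kernel is not lost; the same index mapping is now filled on the host (see the hunk at `@@ -458,12 +417,22` below). A standalone sketch of that mapping for a hypothetical `num_radix_blocks = 4`:

```cpp
#include <cstdint>
#include <iostream>

// Index mapping formerly produced on the device: entry i < n is a message
// extraction (LUT 0) of block i written back to slot i; entry i >= n is a
// carry extraction (LUT 1) of block i - n written one slot higher (i + 1).
int main() {
  const uint32_t n = 4; // hypothetical num_radix_blocks
  for (uint32_t i = 0; i < 2 * n; ++i)
    std::cout << "in=" << (i % n) << " out=" << (i + i / n)
              << " lut=" << (i / n) << '\n';
}
```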
@@ -367,8 +342,9 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
       num_radix_in_vec);

   bool needs_processing = false;
-  radix_columns current_columns(current_blocks->degrees, num_radix_blocks,
-                                num_radix_in_vec, chunk_size, needs_processing);
+  radix_columns<Torus> current_columns(current_blocks->degrees,
+                                       num_radix_blocks, num_radix_in_vec,
+                                       chunk_size, needs_processing);
   int number_of_threads = std::min(256, (int)mem_ptr->params.polynomial_size);
   int part_count = (big_lwe_size + number_of_threads - 1) / number_of_threads;
   const dim3 number_of_blocks_2d(num_radix_blocks, part_count, 1);
@@ -378,22 +354,31 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(

   while (needs_processing) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
-    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
-    auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
     calculate_chunks<Torus>
         <<<number_of_blocks_2d, number_of_threads, 0, streams[0]>>>(
             (Torus *)(current_blocks->ptr), d_columns, d_columns_counter,
             chunk_size, big_lwe_size);

-    prepare_new_columns_and_pbs_indexes<<<1, num_radix_blocks, 0, streams[0]>>>(
-        d_new_columns, d_new_columns_counter, d_pbs_indexes_in,
-        d_pbs_indexes_out, luts_message_carry->get_lut_indexes(0, 0), d_columns,
-        d_columns_counter, chunk_size);
+    prepare_new_columns<Torus><<<1, num_radix_blocks, 0, streams[0]>>>(
+        d_new_columns, d_new_columns_counter, d_columns, d_columns_counter,
+        chunk_size);

-    uint32_t total_ciphertexts;
-    uint32_t total_messages;
-    current_columns.next_accumulation(total_ciphertexts, total_messages,
-                                      needs_processing);
+    uint32_t total_ciphertexts = 0;
+    uint32_t total_messages = 0;
+    auto h_pbs_indexes_in = mem_ptr->luts_message_carry->h_lwe_indexes_in;
+    auto h_pbs_indexes_out = mem_ptr->luts_message_carry->h_lwe_indexes_out;
+    auto h_lut_indexes = mem_ptr->luts_message_carry->h_lut_indexes;
+    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
+    auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
+    current_columns.next_accumulation(h_pbs_indexes_in, h_pbs_indexes_out,
+                                      h_lut_indexes, total_ciphertexts,
+                                      total_messages, needs_processing);
+    luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
+                                        h_pbs_indexes_in, h_pbs_indexes_out);
+    cuda_memcpy_with_size_tracking_async_to_gpu(
+        luts_message_carry->get_lut_indexes(0, 0), h_lut_indexes,
+        total_ciphertexts * sizeof(Torus), streams[0], gpu_indexes[0], true);
+    luts_message_carry->broadcast_lut(streams, gpu_indexes);

     auto active_gpu_count = get_active_gpu_count(total_ciphertexts, gpu_count);
     if (active_gpu_count == 1) {
@@ -415,36 +400,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
           total_ciphertexts, mem_ptr->params.pbs_type, num_many_lut,
           lut_stride);
     } else {
-      Torus *h_lwe_indexes_in_pinned;
-      Torus *h_lwe_indexes_out_pinned;
-      cudaMallocHost((void **)&h_lwe_indexes_in_pinned,
-                     total_ciphertexts * sizeof(Torus));
-      cudaMallocHost((void **)&h_lwe_indexes_out_pinned,
-                     total_ciphertexts * sizeof(Torus));
-      for (uint32_t i = 0; i < total_ciphertexts; i++) {
-        h_lwe_indexes_in_pinned[i] = luts_message_carry->h_lwe_indexes_in[i];
-        h_lwe_indexes_out_pinned[i] = luts_message_carry->h_lwe_indexes_out[i];
-      }
-      cuda_memcpy_async_to_cpu(
-          h_lwe_indexes_in_pinned, luts_message_carry->lwe_indexes_in,
-          total_ciphertexts * sizeof(Torus), streams[0], gpu_indexes[0]);
-      cuda_memcpy_async_to_cpu(
-          h_lwe_indexes_out_pinned, luts_message_carry->lwe_indexes_out,
-          total_ciphertexts * sizeof(Torus), streams[0], gpu_indexes[0]);
-      cuda_synchronize_stream(streams[0], gpu_indexes[0]);
-      for (uint32_t i = 0; i < total_ciphertexts; i++) {
-        luts_message_carry->h_lwe_indexes_in[i] = h_lwe_indexes_in_pinned[i];
-        luts_message_carry->h_lwe_indexes_out[i] = h_lwe_indexes_out_pinned[i];
-      }
-      cudaFreeHost(h_lwe_indexes_in_pinned);
-      cudaFreeHost(h_lwe_indexes_out_pinned);
-
-      luts_message_carry->broadcast_lut(streams, gpu_indexes);
-      luts_message_carry->using_trivial_lwe_indexes = false;
-
       integer_radix_apply_univariate_lookup_table_kb<Torus>(
-          streams, gpu_indexes, gpu_count, current_blocks, current_blocks, bsks,
-          ksks, ms_noise_reduction_key, luts_message_carry, total_ciphertexts);
+          streams, gpu_indexes, active_gpu_count, current_blocks,
+          current_blocks, bsks, ksks, ms_noise_reduction_key,
+          luts_message_carry, total_ciphertexts);
     }
     cuda_set_device(gpu_indexes[0]);
     std::swap(d_columns, d_new_columns);
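Note on the deleted `else` branch: it existed only to stage the device-computed PBS indexes back into the LUT's host arrays through pinned buffers before the multi-GPU broadcast. With the indexes now produced host-first, the device-to-host round trip and its stream synchronization disappear. For reference, the generic staging pattern the branch implemented (standalone sketch using only CUDA runtime calls; names are hypothetical):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Read back device data through pinned (page-locked) host memory so the
// async copy is truly asynchronous, then synchronize before using it on the
// host. Generating the data on the host in the first place removes both the
// copy and the synchronization.
void readback_via_pinned(const uint64_t *d_data, uint64_t *h_result, size_t n,
                         cudaStream_t stream) {
  uint64_t *h_pinned = nullptr;
  cudaMallocHost((void **)&h_pinned, n * sizeof(uint64_t));
  cudaMemcpyAsync(h_pinned, d_data, n * sizeof(uint64_t),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream); // wait until the readback has landed
  for (size_t i = 0; i < n; ++i)
    h_result[i] = h_pinned[i];
  cudaFreeHost(h_pinned);
}
```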
@@ -458,12 +417,22 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(

   if (mem_ptr->reduce_degrees_for_single_carry_propagation) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
+    auto h_pbs_indexes_in = mem_ptr->luts_message_carry->h_lwe_indexes_in;
+    auto h_pbs_indexes_out = mem_ptr->luts_message_carry->h_lwe_indexes_out;
+    auto h_lut_indexes = mem_ptr->luts_message_carry->h_lut_indexes;
+    for (uint i = 0; i < 2 * num_radix_blocks; i++) {
+      h_pbs_indexes_in[i] = i % num_radix_blocks;
+      h_pbs_indexes_out[i] = i + i / num_radix_blocks;
+      h_lut_indexes[i] = i / num_radix_blocks;
+    }
+    mem_ptr->luts_message_carry->set_lwe_indexes(
+        streams[0], gpu_indexes[0], h_pbs_indexes_in, h_pbs_indexes_out);
+    cuda_memcpy_with_size_tracking_async_to_gpu(
+        luts_message_carry->get_lut_indexes(0, 0), h_lut_indexes,
+        2 * num_radix_blocks * sizeof(Torus), streams[0], gpu_indexes[0], true);
+    luts_message_carry->broadcast_lut(streams, gpu_indexes);
     auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
     auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
-    prepare_final_pbs_indexes<Torus>
-        <<<1, 2 * num_radix_blocks, 0, streams[0]>>>(
-            d_pbs_indexes_in, d_pbs_indexes_out,
-            luts_message_carry->get_lut_indexes(0, 0), num_radix_blocks);

     set_zero_radix_ciphertext_slice_async<Torus>(
         streams[0], gpu_indexes[0], current_blocks, num_radix_blocks,
@@ -490,38 +459,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
             2 * num_radix_blocks, mem_ptr->params.pbs_type, num_many_lut,
             lut_stride);
   } else {
-    uint32_t num_blocks_in_apply_lut = 2 * num_radix_blocks;
-    Torus *h_lwe_indexes_in_pinned;
-    Torus *h_lwe_indexes_out_pinned;
-    cudaMallocHost((void **)&h_lwe_indexes_in_pinned,
-                   num_blocks_in_apply_lut * sizeof(Torus));
-    cudaMallocHost((void **)&h_lwe_indexes_out_pinned,
-                   num_blocks_in_apply_lut * sizeof(Torus));
-    for (uint32_t i = 0; i < num_blocks_in_apply_lut; i++) {
-      h_lwe_indexes_in_pinned[i] = luts_message_carry->h_lwe_indexes_in[i];
-      h_lwe_indexes_out_pinned[i] = luts_message_carry->h_lwe_indexes_out[i];
-    }
-    cuda_memcpy_async_to_cpu(
-        h_lwe_indexes_in_pinned, luts_message_carry->lwe_indexes_in,
-        num_blocks_in_apply_lut * sizeof(Torus), streams[0], gpu_indexes[0]);
-    cuda_memcpy_async_to_cpu(
-        h_lwe_indexes_out_pinned, luts_message_carry->lwe_indexes_out,
-        num_blocks_in_apply_lut * sizeof(Torus), streams[0], gpu_indexes[0]);
-    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
-    for (uint32_t i = 0; i < num_blocks_in_apply_lut; i++) {
-      luts_message_carry->h_lwe_indexes_in[i] = h_lwe_indexes_in_pinned[i];
-      luts_message_carry->h_lwe_indexes_out[i] = h_lwe_indexes_out_pinned[i];
-    }
-    cudaFreeHost(h_lwe_indexes_in_pinned);
-    cudaFreeHost(h_lwe_indexes_out_pinned);
-
-    luts_message_carry->broadcast_lut(streams, gpu_indexes);
-    luts_message_carry->using_trivial_lwe_indexes = false;
-
     integer_radix_apply_univariate_lookup_table_kb<Torus>(
         streams, gpu_indexes, active_gpu_count, current_blocks, radix_lwe_out,
         bsks, ksks, ms_noise_reduction_key, luts_message_carry,
-        num_blocks_in_apply_lut);
+        2 * num_radix_blocks);
   }
   calculate_final_degrees(radix_lwe_out->degrees, terms->degrees,
                           num_radix_blocks, num_radix_in_vec, chunk_size,