mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(gpu): refactor subset_first and subset
This commit is contained in:
@@ -78,14 +78,8 @@ public:
|
||||
get_active_gpu_count(num_radix_blocks, _gpu_count));
|
||||
}
|
||||
|
||||
// Returns a subset containing only the first gpu of this set. It
|
||||
// is used to create subset of streams for mono-GPU functions
|
||||
CudaStreams subset_first_gpu() const {
|
||||
return CudaStreams(_streams, _gpu_indexes, 1);
|
||||
}
|
||||
|
||||
// Returns a subset containing only the ith stream
|
||||
CudaStreams subset(int i) const {
|
||||
// Returns a CudaStreams struct containing only the ith stream
|
||||
CudaStreams get_ith(int i) const {
|
||||
return CudaStreams(&_streams[i], &_gpu_indexes[i], 1);
|
||||
}
|
||||
|
||||
|
||||
@@ -1347,7 +1347,7 @@ template <typename Torus> struct int_fullprop_buffer {
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
this->params = params;
|
||||
gpu_memory_allocated = allocate_gpu_memory;
|
||||
lut = new int_radix_lut<Torus>(streams.subset_first_gpu(), params, 2, 2,
|
||||
lut = new int_radix_lut<Torus>(streams.get_ith(0), params, 2, 2,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
// LUTs
|
||||
@@ -1407,7 +1407,7 @@ template <typename Torus> struct int_fullprop_buffer {
|
||||
tmp_small_lwe_vector, gpu_memory_allocated);
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
|
||||
tmp_big_lwe_vector, gpu_memory_allocated);
|
||||
lut->release(streams.subset_first_gpu());
|
||||
lut->release(streams.get_ith(0));
|
||||
delete tmp_small_lwe_vector;
|
||||
delete tmp_big_lwe_vector;
|
||||
delete lut;
|
||||
@@ -4488,19 +4488,19 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
bool allocate_gpu_memory, uint64_t &size_tracker) {
|
||||
|
||||
zero_out_if_not_1_lut_1 =
|
||||
new int_radix_lut<Torus>(streams.subset(0), params, 1, num_blocks,
|
||||
new int_radix_lut<Torus>(streams.get_ith(0), params, 1, num_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
zero_out_if_not_2_lut_1 =
|
||||
new int_radix_lut<Torus>(streams.subset(1), params, 1, num_blocks,
|
||||
new int_radix_lut<Torus>(streams.get_ith(1), params, 1, num_blocks,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
zero_out_if_not_2_lut_2 =
|
||||
new int_radix_lut<Torus>(streams.subset(2), params, 1, num_blocks,
|
||||
new int_radix_lut<Torus>(streams.get_ith(2), params, 1, num_blocks,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
zero_out_if_not_1_lut_2 =
|
||||
new int_radix_lut<Torus>(streams.subset(3), params, 1, num_blocks,
|
||||
new int_radix_lut<Torus>(streams.get_ith(3), params, 1, num_blocks,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
auto zero_out_if_not_1_lut_f = [](Torus x) -> Torus {
|
||||
@@ -4539,12 +4539,14 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
params.carry_modulus, zero_out_if_not_2_lut_f, gpu_memory_allocated);
|
||||
}
|
||||
|
||||
quotient_lut_1 = new int_radix_lut<Torus>(
|
||||
streams.subset(2), params, 1, 1, allocate_gpu_memory, tmp_size_tracker);
|
||||
quotient_lut_2 = new int_radix_lut<Torus>(
|
||||
streams.subset(1), params, 1, 1, allocate_gpu_memory, tmp_size_tracker);
|
||||
quotient_lut_1 =
|
||||
new int_radix_lut<Torus>(streams.get_ith(2), params, 1, 1,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
quotient_lut_2 =
|
||||
new int_radix_lut<Torus>(streams.get_ith(1), params, 1, 1,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
quotient_lut_3 = new int_radix_lut<Torus>(
|
||||
streams.subset(0), params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
streams.get_ith(0), params, 1, 1, allocate_gpu_memory, size_tracker);
|
||||
|
||||
auto quotient_lut_1_f = [](Torus cond) -> Torus {
|
||||
return (Torus)(cond == 2);
|
||||
@@ -4610,22 +4612,22 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
gpu_memory_allocated = allocate_gpu_memory;
|
||||
|
||||
sub_and_propagate_mem = new int_sub_and_propagate<Torus>(
|
||||
streams.subset(0), params, num_blocks + 1, outputFlag::FLAG_NONE,
|
||||
streams.get_ith(0), params, num_blocks + 1, outputFlag::FLAG_NONE,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
|
||||
shift_mem = new int_logical_scalar_shift_buffer<Torus>(
|
||||
streams.subset(1), SHIFT_OR_ROTATE_TYPE::LEFT_SHIFT, params,
|
||||
streams.get_ith(1), SHIFT_OR_ROTATE_TYPE::LEFT_SHIFT, params,
|
||||
2 * num_blocks, allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
uint32_t compute_overflow = 1;
|
||||
overflow_sub_mem_1 = new int_borrow_prop_memory<Torus>(
|
||||
streams.subset(0), params, num_blocks, compute_overflow,
|
||||
streams.get_ith(0), params, num_blocks, compute_overflow,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
overflow_sub_mem_2 = new int_borrow_prop_memory<Torus>(
|
||||
streams.subset(1), params, num_blocks, compute_overflow,
|
||||
streams.get_ith(1), params, num_blocks, compute_overflow,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
overflow_sub_mem_3 = new int_borrow_prop_memory<Torus>(
|
||||
streams.subset(2), params, num_blocks, compute_overflow,
|
||||
streams.get_ith(2), params, num_blocks, compute_overflow,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
uint32_t group_size = overflow_sub_mem_1->group_size;
|
||||
bool use_seq = overflow_sub_mem_1->prop_simu_group_carries_mem
|
||||
@@ -4633,7 +4635,7 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
|
||||
cuda_set_device(0);
|
||||
cudaEventCreateWithFlags(&create_indexes_done, cudaEventDisableTiming);
|
||||
create_indexes_for_overflow_sub(streams.subset(0), num_blocks, group_size,
|
||||
create_indexes_for_overflow_sub(streams.get_ith(0), num_blocks, group_size,
|
||||
use_seq, allocate_gpu_memory, size_tracker);
|
||||
cudaEventRecord(create_indexes_done, streams.stream(0));
|
||||
cuda_set_device(1);
|
||||
@@ -4653,22 +4655,22 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
num_blocks, allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
comparison_buffer_1 = new int_comparison_buffer<Torus>(
|
||||
streams.subset(0), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
streams.get_ith(0), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
comparison_buffer_2 = new int_comparison_buffer<Torus>(
|
||||
streams.subset(1), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
streams.get_ith(1), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
comparison_buffer_3 = new int_comparison_buffer<Torus>(
|
||||
streams.subset(2), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
streams.get_ith(2), COMPARISON_TYPE::EQ, params, num_blocks, false,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
bitor_mem_1 = new int_bitop_buffer<Torus>(
|
||||
streams.subset(0), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
streams.get_ith(0), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
bitor_mem_2 = new int_bitop_buffer<Torus>(
|
||||
streams.subset(1), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
streams.get_ith(1), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
bitor_mem_3 = new int_bitop_buffer<Torus>(
|
||||
streams.subset(2), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
streams.get_ith(2), BITOP_TYPE::BITOR, params, num_blocks,
|
||||
allocate_gpu_memory, tmp_size_tracker);
|
||||
|
||||
init_lookup_tables(streams, num_blocks, allocate_gpu_memory, size_tracker);
|
||||
@@ -4823,17 +4825,17 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
}
|
||||
|
||||
// release and delete integer ops memory objects
|
||||
sub_and_propagate_mem->release(streams.subset(0));
|
||||
shift_mem->release(streams.subset(1));
|
||||
overflow_sub_mem_1->release(streams.subset(0));
|
||||
overflow_sub_mem_2->release(streams.subset(1));
|
||||
overflow_sub_mem_3->release(streams.subset(2));
|
||||
comparison_buffer_1->release(streams.subset(0));
|
||||
comparison_buffer_2->release(streams.subset(1));
|
||||
comparison_buffer_3->release(streams.subset(2));
|
||||
bitor_mem_1->release(streams.subset(0));
|
||||
bitor_mem_2->release(streams.subset(1));
|
||||
bitor_mem_3->release(streams.subset(2));
|
||||
sub_and_propagate_mem->release(streams.get_ith(0));
|
||||
shift_mem->release(streams.get_ith(1));
|
||||
overflow_sub_mem_1->release(streams.get_ith(0));
|
||||
overflow_sub_mem_2->release(streams.get_ith(1));
|
||||
overflow_sub_mem_3->release(streams.get_ith(2));
|
||||
comparison_buffer_1->release(streams.get_ith(0));
|
||||
comparison_buffer_2->release(streams.get_ith(1));
|
||||
comparison_buffer_3->release(streams.get_ith(2));
|
||||
bitor_mem_1->release(streams.get_ith(0));
|
||||
bitor_mem_2->release(streams.get_ith(1));
|
||||
bitor_mem_3->release(streams.get_ith(2));
|
||||
|
||||
delete sub_and_propagate_mem;
|
||||
sub_and_propagate_mem = nullptr;
|
||||
@@ -4861,13 +4863,13 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
|
||||
// release and delete lut objects
|
||||
message_extract_lut_1->release(streams);
|
||||
message_extract_lut_2->release(streams);
|
||||
zero_out_if_not_1_lut_1->release(streams.subset(0));
|
||||
zero_out_if_not_1_lut_2->release(streams.subset(3));
|
||||
zero_out_if_not_2_lut_1->release(streams.subset(1));
|
||||
zero_out_if_not_2_lut_2->release(streams.subset(2));
|
||||
quotient_lut_1->release(streams.subset(2));
|
||||
quotient_lut_2->release(streams.subset(1));
|
||||
quotient_lut_3->release(streams.subset(0));
|
||||
zero_out_if_not_1_lut_1->release(streams.get_ith(0));
|
||||
zero_out_if_not_1_lut_2->release(streams.get_ith(3));
|
||||
zero_out_if_not_2_lut_1->release(streams.get_ith(1));
|
||||
zero_out_if_not_2_lut_2->release(streams.get_ith(2));
|
||||
quotient_lut_1->release(streams.get_ith(2));
|
||||
quotient_lut_2->release(streams.get_ith(1));
|
||||
quotient_lut_3->release(streams.get_ith(0));
|
||||
|
||||
delete message_extract_lut_1;
|
||||
message_extract_lut_1 = nullptr;
|
||||
|
||||
@@ -118,21 +118,21 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
|
||||
// Computes 2*d by extending and shifting on gpu[1]
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(
|
||||
mem_ptr->d2, divisor_gpu_1, streams.subset(1));
|
||||
mem_ptr->d2, divisor_gpu_1, streams.get_ith(1));
|
||||
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
|
||||
streams.subset(1), mem_ptr->d2, 1, mem_ptr->shift_mem, &bsks[1], &ksks[1],
|
||||
ms_noise_reduction_keys[1], mem_ptr->d2->num_radix_blocks);
|
||||
streams.get_ith(1), mem_ptr->d2, 1, mem_ptr->shift_mem, &bsks[1],
|
||||
&ksks[1], ms_noise_reduction_keys[1], mem_ptr->d2->num_radix_blocks);
|
||||
|
||||
// Computes 3*d = 4*d - d using block shift and subtraction on gpu[0]
|
||||
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(
|
||||
mem_ptr->tmp_gpu_0, divisor_gpu_0, streams.subset(0));
|
||||
host_radix_blocks_rotate_right<Torus>(streams.subset(0), mem_ptr->d3,
|
||||
mem_ptr->tmp_gpu_0, divisor_gpu_0, streams.get_ith(0));
|
||||
host_radix_blocks_rotate_right<Torus>(streams.get_ith(0), mem_ptr->d3,
|
||||
mem_ptr->tmp_gpu_0, 1,
|
||||
mem_ptr->tmp_gpu_0->num_radix_blocks);
|
||||
set_zero_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), mem_ptr->d3, 0, 1);
|
||||
host_sub_and_propagate_single_carry(
|
||||
streams.subset(0), mem_ptr->d3, mem_ptr->tmp_gpu_0, nullptr, nullptr,
|
||||
streams.get_ith(0), mem_ptr->d3, mem_ptr->tmp_gpu_0, nullptr, nullptr,
|
||||
mem_ptr->sub_and_propagate_mem, &bsks[0], &ksks[0],
|
||||
ms_noise_reduction_keys[0], outputFlag::FLAG_NONE, 0);
|
||||
|
||||
@@ -189,10 +189,10 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
uint32_t uses_input_borrow = 0;
|
||||
sub_result->num_radix_blocks = low->num_radix_blocks;
|
||||
overflow_sub_mem->update_lut_indexes(
|
||||
streams.subset(gpu_index), first_indexes, second_indexes,
|
||||
streams.get_ith(gpu_index), first_indexes, second_indexes,
|
||||
scalar_indexes, rem->num_radix_blocks);
|
||||
host_integer_overflowing_sub<uint64_t>(
|
||||
streams.subset(gpu_index), sub_result, rem, low, sub_overflowed,
|
||||
streams.get_ith(gpu_index), sub_result, rem, low, sub_overflowed,
|
||||
(const CudaRadixCiphertextFFI *)nullptr, overflow_sub_mem,
|
||||
&bsks[gpu_index], &ksks[gpu_index],
|
||||
ms_noise_reduction_keys[gpu_index], compute_overflow,
|
||||
@@ -216,12 +216,12 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
streams.stream(gpu_index), streams.gpu_index(gpu_index));
|
||||
} else {
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
streams.subset(gpu_index), comparison_blocks, d_msb,
|
||||
streams.get_ith(gpu_index), comparison_blocks, d_msb,
|
||||
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
|
||||
ms_noise_reduction_keys[gpu_index], d_msb->num_radix_blocks,
|
||||
comparison_buffer->is_zero_lut);
|
||||
are_all_comparisons_block_true(
|
||||
streams.subset(gpu_index), out_boolean_block, comparison_blocks,
|
||||
streams.get_ith(gpu_index), out_boolean_block, comparison_blocks,
|
||||
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
|
||||
ms_noise_reduction_keys[gpu_index],
|
||||
comparison_blocks->num_radix_blocks);
|
||||
@@ -287,15 +287,15 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
auto o3 = mem_ptr->sub_1_overflowed;
|
||||
|
||||
// used as a bitor
|
||||
host_integer_radix_bitop_kb(streams.subset(0), o3, o3, mem_ptr->cmp_1,
|
||||
host_integer_radix_bitop_kb(streams.get_ith(0), o3, o3, mem_ptr->cmp_1,
|
||||
mem_ptr->bitor_mem_1, &bsks[0], &ksks[0],
|
||||
ms_noise_reduction_keys[0]);
|
||||
// used as a bitor
|
||||
host_integer_radix_bitop_kb(streams.subset(1), o2, o2, mem_ptr->cmp_2,
|
||||
host_integer_radix_bitop_kb(streams.get_ith(1), o2, o2, mem_ptr->cmp_2,
|
||||
mem_ptr->bitor_mem_2, &bsks[1], &ksks[1],
|
||||
ms_noise_reduction_keys[1]);
|
||||
// used as a bitor
|
||||
host_integer_radix_bitop_kb(streams.subset(2), o1, o1, mem_ptr->cmp_3,
|
||||
host_integer_radix_bitop_kb(streams.get_ith(2), o1, o1, mem_ptr->cmp_3,
|
||||
mem_ptr->bitor_mem_3, &bsks[2], &ksks[2],
|
||||
ms_noise_reduction_keys[2]);
|
||||
|
||||
@@ -378,8 +378,9 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
streams.gpu_index(gpu_index),
|
||||
rx, rx, cx, 4, 4);
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
streams.subset(gpu_index), rx, rx, &bsks[gpu_index], &ksks[gpu_index],
|
||||
ms_noise_reduction_keys[gpu_index], lut, rx->num_radix_blocks);
|
||||
streams.get_ith(gpu_index), rx, rx, &bsks[gpu_index],
|
||||
&ksks[gpu_index], ms_noise_reduction_keys[gpu_index], lut,
|
||||
rx->num_radix_blocks);
|
||||
};
|
||||
|
||||
for (uint j = 0; j < 4; j++) {
|
||||
@@ -396,15 +397,15 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
|
||||
|
||||
// calculate quotient bits GPU[2]
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
mem_ptr->sub_streams_1.subset(2), mem_ptr->q1, c1, &bsks[2], &ksks[2],
|
||||
mem_ptr->sub_streams_1.get_ith(2), mem_ptr->q1, c1, &bsks[2], &ksks[2],
|
||||
ms_noise_reduction_keys[2], mem_ptr->quotient_lut_1, 1);
|
||||
// calculate quotient bits GPU[1]
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
mem_ptr->sub_streams_1.subset(1), mem_ptr->q2, c2, &bsks[1], &ksks[1],
|
||||
mem_ptr->sub_streams_1.get_ith(1), mem_ptr->q2, c2, &bsks[1], &ksks[1],
|
||||
ms_noise_reduction_keys[1], mem_ptr->quotient_lut_2, 1);
|
||||
// calculate quotient bits GPU[0]
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
mem_ptr->sub_streams_1.subset(0), mem_ptr->q3, c3, &bsks[0], &ksks[0],
|
||||
mem_ptr->sub_streams_1.get_ith(0), mem_ptr->q3, c3, &bsks[0], &ksks[0],
|
||||
ms_noise_reduction_keys[0], mem_ptr->quotient_lut_3, 1);
|
||||
|
||||
for (uint j = 0; j < 4; j++) {
|
||||
|
||||
@@ -547,20 +547,20 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
|
||||
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr,
|
||||
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
|
||||
ks_base_log, ks_level, num_radix_blocks);
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)lwe_array_out->ptr,
|
||||
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
|
||||
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
|
||||
ms_noise_reduction_key, lut->buffer, glwe_dimension,
|
||||
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
|
||||
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
|
||||
streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out,
|
||||
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
|
||||
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
|
||||
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
|
||||
lut_stride);
|
||||
} else {
|
||||
/// Make sure all data that should be on GPU 0 is indeed there
|
||||
cuda_event_record(lut->event_scatter_in, streams.stream(0),
|
||||
@@ -664,20 +664,20 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
|
||||
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr,
|
||||
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
|
||||
ks_base_log, ks_level, num_radix_blocks);
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)lwe_array_out->ptr,
|
||||
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
|
||||
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
|
||||
ms_noise_reduction_key, lut->buffer, glwe_dimension,
|
||||
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
|
||||
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
|
||||
streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out,
|
||||
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
|
||||
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
|
||||
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
|
||||
lut_stride);
|
||||
} else {
|
||||
/// Make sure all data that should be on GPU 0 is indeed there
|
||||
cuda_event_record(lut->event_scatter_in, streams.stream(0),
|
||||
@@ -796,20 +796,20 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
|
||||
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_pbs_in->ptr,
|
||||
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
|
||||
ks_base_log, ks_level, num_radix_blocks);
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
|
||||
num_radix_blocks);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)(lwe_array_out->ptr),
|
||||
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
|
||||
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
|
||||
ms_noise_reduction_key, lut->buffer, glwe_dimension,
|
||||
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
|
||||
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
|
||||
streams.get_ith(0), (Torus *)(lwe_array_out->ptr), lut->lwe_indexes_out,
|
||||
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
|
||||
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
|
||||
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
|
||||
lut_stride);
|
||||
} else {
|
||||
cuda_event_record(lut->event_scatter_in, streams.stream(0),
|
||||
streams.gpu_index(0));
|
||||
@@ -1575,8 +1575,7 @@ void host_full_propagate_inplace(
|
||||
|
||||
/// Since the keyswitch is done on one input only, use only 1 GPU
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(),
|
||||
(Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
|
||||
streams.get_ith(0), (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
|
||||
mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr,
|
||||
mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
|
||||
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1);
|
||||
@@ -1586,7 +1585,7 @@ void host_full_propagate_inplace(
|
||||
1, 2, mem_ptr->tmp_small_lwe_vector, 0, 1);
|
||||
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
|
||||
streams.get_ith(0), (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
|
||||
mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
|
||||
mem_ptr->lut->lut_indexes_vec,
|
||||
(Torus *)mem_ptr->tmp_small_lwe_vector->ptr,
|
||||
@@ -2419,11 +2418,10 @@ __host__ void integer_radix_apply_noise_squashing_kb(
|
||||
streams.active_gpu_subset(lwe_array_out->num_radix_blocks);
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<InputTorus>(
|
||||
streams.subset_first_gpu(), lwe_after_ks_vec[0],
|
||||
lwe_trivial_indexes_vec[0], (InputTorus *)lwe_array_pbs_in->ptr,
|
||||
lut->lwe_indexes_in, ksks, lut->input_big_lwe_dimension,
|
||||
small_lwe_dimension, ks_base_log, ks_level,
|
||||
lwe_array_out->num_radix_blocks);
|
||||
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
|
||||
(InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
|
||||
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
|
||||
ks_level, lwe_array_out->num_radix_blocks);
|
||||
|
||||
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
|
||||
/// dimension to a big LWE dimension
|
||||
@@ -2431,7 +2429,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
|
||||
/// int_noise_squashing_lut doesn't support a different output or lut
|
||||
/// indexing than the trivial
|
||||
execute_pbs_async<uint64_t, __uint128_t>(
|
||||
streams.subset_first_gpu(), (__uint128_t *)lwe_array_out->ptr,
|
||||
streams.get_ith(0), (__uint128_t *)lwe_array_out->ptr,
|
||||
lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
|
||||
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
|
||||
ms_noise_reduction_key, lut->pbs_buffer, glwe_dimension,
|
||||
|
||||
@@ -398,19 +398,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
|
||||
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)small_lwe_vector->ptr,
|
||||
d_pbs_indexes_in, (Torus *)current_blocks->ptr, d_pbs_indexes_in,
|
||||
ksks, big_lwe_dimension, small_lwe_dimension,
|
||||
mem_ptr->params.ks_base_log, mem_ptr->params.ks_level,
|
||||
total_messages);
|
||||
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
|
||||
(Torus *)current_blocks->ptr, d_pbs_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
|
||||
mem_ptr->params.ks_level, total_messages);
|
||||
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)current_blocks->ptr,
|
||||
d_pbs_indexes_out, luts_message_carry->lut_vec,
|
||||
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
|
||||
d_pbs_indexes_in, bsks, ms_noise_reduction_key,
|
||||
luts_message_carry->buffer, glwe_dimension, small_lwe_dimension,
|
||||
polynomial_size, mem_ptr->params.pbs_base_log,
|
||||
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
|
||||
luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec,
|
||||
(Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks,
|
||||
ms_noise_reduction_key, luts_message_carry->buffer, glwe_dimension,
|
||||
small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log,
|
||||
mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor,
|
||||
total_ciphertexts, mem_ptr->params.pbs_type, num_many_lut,
|
||||
lut_stride);
|
||||
@@ -451,18 +449,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
|
||||
|
||||
if (active_streams.count() == 1) {
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)small_lwe_vector->ptr,
|
||||
d_pbs_indexes_in, (Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
|
||||
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
|
||||
(Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
|
||||
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
|
||||
mem_ptr->params.ks_level, num_radix_blocks);
|
||||
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)current_blocks->ptr,
|
||||
d_pbs_indexes_out, luts_message_carry->lut_vec,
|
||||
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
|
||||
d_pbs_indexes_in, bsks, ms_noise_reduction_key,
|
||||
luts_message_carry->buffer, glwe_dimension, small_lwe_dimension,
|
||||
polynomial_size, mem_ptr->params.pbs_base_log,
|
||||
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
|
||||
luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec,
|
||||
(Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks,
|
||||
ms_noise_reduction_key, luts_message_carry->buffer, glwe_dimension,
|
||||
small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log,
|
||||
mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor,
|
||||
2 * num_radix_blocks, mem_ptr->params.pbs_type, num_many_lut,
|
||||
lut_stride);
|
||||
|
||||
@@ -31,8 +31,8 @@ void host_integer_grouped_oprf(
|
||||
|
||||
if (active_streams.count() == 1) {
|
||||
execute_pbs_async<Torus, Torus>(
|
||||
streams.subset_first_gpu(), (Torus *)(radix_lwe_out->ptr),
|
||||
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
|
||||
streams.get_ith(0), (Torus *)(radix_lwe_out->ptr), lut->lwe_indexes_out,
|
||||
lut->lut_vec, lut->lut_indexes_vec,
|
||||
const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
|
||||
ms_noise_reduction_key, lut->buffer, mem_ptr->params.glwe_dimension,
|
||||
mem_ptr->params.small_lwe_dimension, mem_ptr->params.polynomial_size,
|
||||
|
||||
@@ -77,7 +77,7 @@ __host__ void host_expand_without_verification(
|
||||
|
||||
// apply keyswitch to BIG
|
||||
execute_keyswitch_async<Torus>(
|
||||
streams.subset_first_gpu(), ksed_small_to_big_expanded_lwes,
|
||||
streams.get_ith(0), ksed_small_to_big_expanded_lwes,
|
||||
lwe_trivial_indexes_vec[0], expanded_lwes, lwe_trivial_indexes_vec[0],
|
||||
casting_keys, casting_input_dimension, casting_output_dimension,
|
||||
casting_ks_base_log, casting_ks_level, num_lwes);
|
||||
|
||||
Reference in New Issue
Block a user