chore(gpu): refactor subset_first and subset

This commit is contained in:
Agnes Leroy
2025-09-29 17:46:29 +02:00
committed by Agnès Leroy
parent 4f5d711c4e
commit 71b45c14da
7 changed files with 120 additions and 128 deletions

View File

@@ -78,14 +78,8 @@ public:
get_active_gpu_count(num_radix_blocks, _gpu_count));
}
// Returns a subset containing only the first GPU of this set: the result is
// built from this set's stream and GPU-index arrays with a count of 1. It is
// used to create a subset of streams for mono-GPU functions.
CudaStreams subset_first_gpu() const {
return CudaStreams(_streams, _gpu_indexes, 1);
}
// Returns a subset containing only the ith stream
CudaStreams subset(int i) const {
// Returns a CudaStreams struct containing only the ith stream: it is built
// from the addresses of the ith stream and the ith GPU index of this set,
// with a count of 1. The index i is not range-checked here, so the caller
// must pass a valid position in this set.
CudaStreams get_ith(int i) const {
return CudaStreams(&_streams[i], &_gpu_indexes[i], 1);
}

View File

@@ -1347,7 +1347,7 @@ template <typename Torus> struct int_fullprop_buffer {
bool allocate_gpu_memory, uint64_t &size_tracker) {
this->params = params;
gpu_memory_allocated = allocate_gpu_memory;
lut = new int_radix_lut<Torus>(streams.subset_first_gpu(), params, 2, 2,
lut = new int_radix_lut<Torus>(streams.get_ith(0), params, 2, 2,
allocate_gpu_memory, size_tracker);
// LUTs
@@ -1407,7 +1407,7 @@ template <typename Torus> struct int_fullprop_buffer {
tmp_small_lwe_vector, gpu_memory_allocated);
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0),
tmp_big_lwe_vector, gpu_memory_allocated);
lut->release(streams.subset_first_gpu());
lut->release(streams.get_ith(0));
delete tmp_small_lwe_vector;
delete tmp_big_lwe_vector;
delete lut;
@@ -4488,19 +4488,19 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
bool allocate_gpu_memory, uint64_t &size_tracker) {
zero_out_if_not_1_lut_1 =
new int_radix_lut<Torus>(streams.subset(0), params, 1, num_blocks,
new int_radix_lut<Torus>(streams.get_ith(0), params, 1, num_blocks,
allocate_gpu_memory, size_tracker);
zero_out_if_not_2_lut_1 =
new int_radix_lut<Torus>(streams.subset(1), params, 1, num_blocks,
new int_radix_lut<Torus>(streams.get_ith(1), params, 1, num_blocks,
allocate_gpu_memory, tmp_size_tracker);
zero_out_if_not_2_lut_2 =
new int_radix_lut<Torus>(streams.subset(2), params, 1, num_blocks,
new int_radix_lut<Torus>(streams.get_ith(2), params, 1, num_blocks,
allocate_gpu_memory, tmp_size_tracker);
zero_out_if_not_1_lut_2 =
new int_radix_lut<Torus>(streams.subset(3), params, 1, num_blocks,
new int_radix_lut<Torus>(streams.get_ith(3), params, 1, num_blocks,
allocate_gpu_memory, tmp_size_tracker);
auto zero_out_if_not_1_lut_f = [](Torus x) -> Torus {
@@ -4539,12 +4539,14 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
params.carry_modulus, zero_out_if_not_2_lut_f, gpu_memory_allocated);
}
quotient_lut_1 = new int_radix_lut<Torus>(
streams.subset(2), params, 1, 1, allocate_gpu_memory, tmp_size_tracker);
quotient_lut_2 = new int_radix_lut<Torus>(
streams.subset(1), params, 1, 1, allocate_gpu_memory, tmp_size_tracker);
quotient_lut_1 =
new int_radix_lut<Torus>(streams.get_ith(2), params, 1, 1,
allocate_gpu_memory, tmp_size_tracker);
quotient_lut_2 =
new int_radix_lut<Torus>(streams.get_ith(1), params, 1, 1,
allocate_gpu_memory, tmp_size_tracker);
quotient_lut_3 = new int_radix_lut<Torus>(
streams.subset(0), params, 1, 1, allocate_gpu_memory, size_tracker);
streams.get_ith(0), params, 1, 1, allocate_gpu_memory, size_tracker);
auto quotient_lut_1_f = [](Torus cond) -> Torus {
return (Torus)(cond == 2);
@@ -4610,22 +4612,22 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
gpu_memory_allocated = allocate_gpu_memory;
sub_and_propagate_mem = new int_sub_and_propagate<Torus>(
streams.subset(0), params, num_blocks + 1, outputFlag::FLAG_NONE,
streams.get_ith(0), params, num_blocks + 1, outputFlag::FLAG_NONE,
allocate_gpu_memory, size_tracker);
shift_mem = new int_logical_scalar_shift_buffer<Torus>(
streams.subset(1), SHIFT_OR_ROTATE_TYPE::LEFT_SHIFT, params,
streams.get_ith(1), SHIFT_OR_ROTATE_TYPE::LEFT_SHIFT, params,
2 * num_blocks, allocate_gpu_memory, tmp_size_tracker);
uint32_t compute_overflow = 1;
overflow_sub_mem_1 = new int_borrow_prop_memory<Torus>(
streams.subset(0), params, num_blocks, compute_overflow,
streams.get_ith(0), params, num_blocks, compute_overflow,
allocate_gpu_memory, size_tracker);
overflow_sub_mem_2 = new int_borrow_prop_memory<Torus>(
streams.subset(1), params, num_blocks, compute_overflow,
streams.get_ith(1), params, num_blocks, compute_overflow,
allocate_gpu_memory, tmp_size_tracker);
overflow_sub_mem_3 = new int_borrow_prop_memory<Torus>(
streams.subset(2), params, num_blocks, compute_overflow,
streams.get_ith(2), params, num_blocks, compute_overflow,
allocate_gpu_memory, tmp_size_tracker);
uint32_t group_size = overflow_sub_mem_1->group_size;
bool use_seq = overflow_sub_mem_1->prop_simu_group_carries_mem
@@ -4633,7 +4635,7 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
cuda_set_device(0);
cudaEventCreateWithFlags(&create_indexes_done, cudaEventDisableTiming);
create_indexes_for_overflow_sub(streams.subset(0), num_blocks, group_size,
create_indexes_for_overflow_sub(streams.get_ith(0), num_blocks, group_size,
use_seq, allocate_gpu_memory, size_tracker);
cudaEventRecord(create_indexes_done, streams.stream(0));
cuda_set_device(1);
@@ -4653,22 +4655,22 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
num_blocks, allocate_gpu_memory, tmp_size_tracker);
comparison_buffer_1 = new int_comparison_buffer<Torus>(
streams.subset(0), COMPARISON_TYPE::EQ, params, num_blocks, false,
streams.get_ith(0), COMPARISON_TYPE::EQ, params, num_blocks, false,
allocate_gpu_memory, size_tracker);
comparison_buffer_2 = new int_comparison_buffer<Torus>(
streams.subset(1), COMPARISON_TYPE::EQ, params, num_blocks, false,
streams.get_ith(1), COMPARISON_TYPE::EQ, params, num_blocks, false,
allocate_gpu_memory, tmp_size_tracker);
comparison_buffer_3 = new int_comparison_buffer<Torus>(
streams.subset(2), COMPARISON_TYPE::EQ, params, num_blocks, false,
streams.get_ith(2), COMPARISON_TYPE::EQ, params, num_blocks, false,
allocate_gpu_memory, tmp_size_tracker);
bitor_mem_1 = new int_bitop_buffer<Torus>(
streams.subset(0), BITOP_TYPE::BITOR, params, num_blocks,
streams.get_ith(0), BITOP_TYPE::BITOR, params, num_blocks,
allocate_gpu_memory, size_tracker);
bitor_mem_2 = new int_bitop_buffer<Torus>(
streams.subset(1), BITOP_TYPE::BITOR, params, num_blocks,
streams.get_ith(1), BITOP_TYPE::BITOR, params, num_blocks,
allocate_gpu_memory, tmp_size_tracker);
bitor_mem_3 = new int_bitop_buffer<Torus>(
streams.subset(2), BITOP_TYPE::BITOR, params, num_blocks,
streams.get_ith(2), BITOP_TYPE::BITOR, params, num_blocks,
allocate_gpu_memory, tmp_size_tracker);
init_lookup_tables(streams, num_blocks, allocate_gpu_memory, size_tracker);
@@ -4823,17 +4825,17 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
}
// release and delete integer ops memory objects
sub_and_propagate_mem->release(streams.subset(0));
shift_mem->release(streams.subset(1));
overflow_sub_mem_1->release(streams.subset(0));
overflow_sub_mem_2->release(streams.subset(1));
overflow_sub_mem_3->release(streams.subset(2));
comparison_buffer_1->release(streams.subset(0));
comparison_buffer_2->release(streams.subset(1));
comparison_buffer_3->release(streams.subset(2));
bitor_mem_1->release(streams.subset(0));
bitor_mem_2->release(streams.subset(1));
bitor_mem_3->release(streams.subset(2));
sub_and_propagate_mem->release(streams.get_ith(0));
shift_mem->release(streams.get_ith(1));
overflow_sub_mem_1->release(streams.get_ith(0));
overflow_sub_mem_2->release(streams.get_ith(1));
overflow_sub_mem_3->release(streams.get_ith(2));
comparison_buffer_1->release(streams.get_ith(0));
comparison_buffer_2->release(streams.get_ith(1));
comparison_buffer_3->release(streams.get_ith(2));
bitor_mem_1->release(streams.get_ith(0));
bitor_mem_2->release(streams.get_ith(1));
bitor_mem_3->release(streams.get_ith(2));
delete sub_and_propagate_mem;
sub_and_propagate_mem = nullptr;
@@ -4861,13 +4863,13 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
// release and delete lut objects
message_extract_lut_1->release(streams);
message_extract_lut_2->release(streams);
zero_out_if_not_1_lut_1->release(streams.subset(0));
zero_out_if_not_1_lut_2->release(streams.subset(3));
zero_out_if_not_2_lut_1->release(streams.subset(1));
zero_out_if_not_2_lut_2->release(streams.subset(2));
quotient_lut_1->release(streams.subset(2));
quotient_lut_2->release(streams.subset(1));
quotient_lut_3->release(streams.subset(0));
zero_out_if_not_1_lut_1->release(streams.get_ith(0));
zero_out_if_not_1_lut_2->release(streams.get_ith(3));
zero_out_if_not_2_lut_1->release(streams.get_ith(1));
zero_out_if_not_2_lut_2->release(streams.get_ith(2));
quotient_lut_1->release(streams.get_ith(2));
quotient_lut_2->release(streams.get_ith(1));
quotient_lut_3->release(streams.get_ith(0));
delete message_extract_lut_1;
message_extract_lut_1 = nullptr;

View File

@@ -118,21 +118,21 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
// Computes 2*d by extending and shifting on gpu[1]
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(
mem_ptr->d2, divisor_gpu_1, streams.subset(1));
mem_ptr->d2, divisor_gpu_1, streams.get_ith(1));
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
streams.subset(1), mem_ptr->d2, 1, mem_ptr->shift_mem, &bsks[1], &ksks[1],
ms_noise_reduction_keys[1], mem_ptr->d2->num_radix_blocks);
streams.get_ith(1), mem_ptr->d2, 1, mem_ptr->shift_mem, &bsks[1],
&ksks[1], ms_noise_reduction_keys[1], mem_ptr->d2->num_radix_blocks);
// Computes 3*d = 4*d - d using block shift and subtraction on gpu[0]
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(
mem_ptr->tmp_gpu_0, divisor_gpu_0, streams.subset(0));
host_radix_blocks_rotate_right<Torus>(streams.subset(0), mem_ptr->d3,
mem_ptr->tmp_gpu_0, divisor_gpu_0, streams.get_ith(0));
host_radix_blocks_rotate_right<Torus>(streams.get_ith(0), mem_ptr->d3,
mem_ptr->tmp_gpu_0, 1,
mem_ptr->tmp_gpu_0->num_radix_blocks);
set_zero_radix_ciphertext_slice_async<Torus>(
streams.stream(0), streams.gpu_index(0), mem_ptr->d3, 0, 1);
host_sub_and_propagate_single_carry(
streams.subset(0), mem_ptr->d3, mem_ptr->tmp_gpu_0, nullptr, nullptr,
streams.get_ith(0), mem_ptr->d3, mem_ptr->tmp_gpu_0, nullptr, nullptr,
mem_ptr->sub_and_propagate_mem, &bsks[0], &ksks[0],
ms_noise_reduction_keys[0], outputFlag::FLAG_NONE, 0);
@@ -189,10 +189,10 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
uint32_t uses_input_borrow = 0;
sub_result->num_radix_blocks = low->num_radix_blocks;
overflow_sub_mem->update_lut_indexes(
streams.subset(gpu_index), first_indexes, second_indexes,
streams.get_ith(gpu_index), first_indexes, second_indexes,
scalar_indexes, rem->num_radix_blocks);
host_integer_overflowing_sub<uint64_t>(
streams.subset(gpu_index), sub_result, rem, low, sub_overflowed,
streams.get_ith(gpu_index), sub_result, rem, low, sub_overflowed,
(const CudaRadixCiphertextFFI *)nullptr, overflow_sub_mem,
&bsks[gpu_index], &ksks[gpu_index],
ms_noise_reduction_keys[gpu_index], compute_overflow,
@@ -216,12 +216,12 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
streams.stream(gpu_index), streams.gpu_index(gpu_index));
} else {
host_compare_blocks_with_zero<Torus>(
streams.subset(gpu_index), comparison_blocks, d_msb,
streams.get_ith(gpu_index), comparison_blocks, d_msb,
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
ms_noise_reduction_keys[gpu_index], d_msb->num_radix_blocks,
comparison_buffer->is_zero_lut);
are_all_comparisons_block_true(
streams.subset(gpu_index), out_boolean_block, comparison_blocks,
streams.get_ith(gpu_index), out_boolean_block, comparison_blocks,
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
ms_noise_reduction_keys[gpu_index],
comparison_blocks->num_radix_blocks);
@@ -287,15 +287,15 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
auto o3 = mem_ptr->sub_1_overflowed;
// used as a bitor
host_integer_radix_bitop_kb(streams.subset(0), o3, o3, mem_ptr->cmp_1,
host_integer_radix_bitop_kb(streams.get_ith(0), o3, o3, mem_ptr->cmp_1,
mem_ptr->bitor_mem_1, &bsks[0], &ksks[0],
ms_noise_reduction_keys[0]);
// used as a bitor
host_integer_radix_bitop_kb(streams.subset(1), o2, o2, mem_ptr->cmp_2,
host_integer_radix_bitop_kb(streams.get_ith(1), o2, o2, mem_ptr->cmp_2,
mem_ptr->bitor_mem_2, &bsks[1], &ksks[1],
ms_noise_reduction_keys[1]);
// used as a bitor
host_integer_radix_bitop_kb(streams.subset(2), o1, o1, mem_ptr->cmp_3,
host_integer_radix_bitop_kb(streams.get_ith(2), o1, o1, mem_ptr->cmp_3,
mem_ptr->bitor_mem_3, &bsks[2], &ksks[2],
ms_noise_reduction_keys[2]);
@@ -378,8 +378,9 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
streams.gpu_index(gpu_index),
rx, rx, cx, 4, 4);
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams.subset(gpu_index), rx, rx, &bsks[gpu_index], &ksks[gpu_index],
ms_noise_reduction_keys[gpu_index], lut, rx->num_radix_blocks);
streams.get_ith(gpu_index), rx, rx, &bsks[gpu_index],
&ksks[gpu_index], ms_noise_reduction_keys[gpu_index], lut,
rx->num_radix_blocks);
};
for (uint j = 0; j < 4; j++) {
@@ -396,15 +397,15 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
// calculate quotient bits GPU[2]
integer_radix_apply_univariate_lookup_table_kb<Torus>(
mem_ptr->sub_streams_1.subset(2), mem_ptr->q1, c1, &bsks[2], &ksks[2],
mem_ptr->sub_streams_1.get_ith(2), mem_ptr->q1, c1, &bsks[2], &ksks[2],
ms_noise_reduction_keys[2], mem_ptr->quotient_lut_1, 1);
// calculate quotient bits GPU[1]
integer_radix_apply_univariate_lookup_table_kb<Torus>(
mem_ptr->sub_streams_1.subset(1), mem_ptr->q2, c2, &bsks[1], &ksks[1],
mem_ptr->sub_streams_1.get_ith(1), mem_ptr->q2, c2, &bsks[1], &ksks[1],
ms_noise_reduction_keys[1], mem_ptr->quotient_lut_2, 1);
// calculate quotient bits GPU[0]
integer_radix_apply_univariate_lookup_table_kb<Torus>(
mem_ptr->sub_streams_1.subset(0), mem_ptr->q3, c3, &bsks[0], &ksks[0],
mem_ptr->sub_streams_1.get_ith(0), mem_ptr->q3, c3, &bsks[0], &ksks[0],
ms_noise_reduction_keys[0], mem_ptr->quotient_lut_3, 1);
for (uint j = 0; j < 4; j++) {

View File

@@ -547,20 +547,20 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr,
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
ks_base_log, ks_level, num_radix_blocks);
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)lwe_array_out->ptr,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
ms_noise_reduction_key, lut->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out,
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
lut_stride);
} else {
/// Make sure all data that should be on GPU 0 is indeed there
cuda_event_record(lut->event_scatter_in, streams.stream(0),
@@ -664,20 +664,20 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr,
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
ks_base_log, ks_level, num_radix_blocks);
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)lwe_array_out->ptr,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
ms_noise_reduction_key, lut->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out,
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
lut_stride);
} else {
/// Make sure all data that should be on GPU 0 is indeed there
cuda_event_record(lut->event_scatter_in, streams.stream(0),
@@ -796,20 +796,20 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
auto active_streams = streams.active_gpu_subset(num_radix_blocks);
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], (Torus *)lwe_array_pbs_in->ptr,
lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension,
ks_base_log, ks_level, num_radix_blocks);
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
num_radix_blocks);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)(lwe_array_out->ptr),
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
ms_noise_reduction_key, lut->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
streams.get_ith(0), (Torus *)(lwe_array_out->ptr), lut->lwe_indexes_out,
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], bsks, ms_noise_reduction_key, lut->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
pbs_level, grouping_factor, num_radix_blocks, pbs_type, num_many_lut,
lut_stride);
} else {
cuda_event_record(lut->event_scatter_in, streams.stream(0),
streams.gpu_index(0));
@@ -1575,8 +1575,7 @@ void host_full_propagate_inplace(
/// Since the keyswitch is done on one input only, use only 1 GPU
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(),
(Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
streams.get_ith(0), (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr,
mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1);
@@ -1586,7 +1585,7 @@ void host_full_propagate_inplace(
1, 2, mem_ptr->tmp_small_lwe_vector, 0, 1);
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
streams.get_ith(0), (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
mem_ptr->lut->lut_indexes_vec,
(Torus *)mem_ptr->tmp_small_lwe_vector->ptr,
@@ -2419,11 +2418,10 @@ __host__ void integer_radix_apply_noise_squashing_kb(
streams.active_gpu_subset(lwe_array_out->num_radix_blocks);
if (active_streams.count() == 1) {
execute_keyswitch_async<InputTorus>(
streams.subset_first_gpu(), lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], (InputTorus *)lwe_array_pbs_in->ptr,
lut->lwe_indexes_in, ksks, lut->input_big_lwe_dimension,
small_lwe_dimension, ks_base_log, ks_level,
lwe_array_out->num_radix_blocks);
streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
(InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
ks_level, lwe_array_out->num_radix_blocks);
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
@@ -2431,7 +2429,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
/// int_noise_squashing_lut doesn't support a different output or lut
/// indexing than the trivial
execute_pbs_async<uint64_t, __uint128_t>(
streams.subset_first_gpu(), (__uint128_t *)lwe_array_out->ptr,
streams.get_ith(0), (__uint128_t *)lwe_array_out->ptr,
lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
ms_noise_reduction_key, lut->pbs_buffer, glwe_dimension,

View File

@@ -398,19 +398,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), (Torus *)small_lwe_vector->ptr,
d_pbs_indexes_in, (Torus *)current_blocks->ptr, d_pbs_indexes_in,
ksks, big_lwe_dimension, small_lwe_dimension,
mem_ptr->params.ks_base_log, mem_ptr->params.ks_level,
total_messages);
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
(Torus *)current_blocks->ptr, d_pbs_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
mem_ptr->params.ks_level, total_messages);
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)current_blocks->ptr,
d_pbs_indexes_out, luts_message_carry->lut_vec,
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
d_pbs_indexes_in, bsks, ms_noise_reduction_key,
luts_message_carry->buffer, glwe_dimension, small_lwe_dimension,
polynomial_size, mem_ptr->params.pbs_base_log,
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec,
(Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks,
ms_noise_reduction_key, luts_message_carry->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log,
mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor,
total_ciphertexts, mem_ptr->params.pbs_type, num_many_lut,
lut_stride);
@@ -451,18 +449,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
if (active_streams.count() == 1) {
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), (Torus *)small_lwe_vector->ptr,
d_pbs_indexes_in, (Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in,
(Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks,
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
mem_ptr->params.ks_level, num_radix_blocks);
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)current_blocks->ptr,
d_pbs_indexes_out, luts_message_carry->lut_vec,
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
d_pbs_indexes_in, bsks, ms_noise_reduction_key,
luts_message_carry->buffer, glwe_dimension, small_lwe_dimension,
polynomial_size, mem_ptr->params.pbs_base_log,
streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out,
luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec,
(Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks,
ms_noise_reduction_key, luts_message_carry->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log,
mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor,
2 * num_radix_blocks, mem_ptr->params.pbs_type, num_many_lut,
lut_stride);

View File

@@ -31,8 +31,8 @@ void host_integer_grouped_oprf(
if (active_streams.count() == 1) {
execute_pbs_async<Torus, Torus>(
streams.subset_first_gpu(), (Torus *)(radix_lwe_out->ptr),
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
streams.get_ith(0), (Torus *)(radix_lwe_out->ptr), lut->lwe_indexes_out,
lut->lut_vec, lut->lut_indexes_vec,
const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
ms_noise_reduction_key, lut->buffer, mem_ptr->params.glwe_dimension,
mem_ptr->params.small_lwe_dimension, mem_ptr->params.polynomial_size,

View File

@@ -77,7 +77,7 @@ __host__ void host_expand_without_verification(
// apply keyswitch to BIG
execute_keyswitch_async<Torus>(
streams.subset_first_gpu(), ksed_small_to_big_expanded_lwes,
streams.get_ith(0), ksed_small_to_big_expanded_lwes,
lwe_trivial_indexes_vec[0], expanded_lwes, lwe_trivial_indexes_vec[0],
casting_keys, casting_input_dimension, casting_output_dimension,
casting_ks_base_log, casting_ks_level, num_lwes);