Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-11 07:38:08 -05:00

Compare commits: main...as/radix_l (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 705f383ad2 |  |
|  | 3be7aae8f3 |  |
@@ -5,6 +5,7 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cuda_runtime.h>
+#include <memory>
 
 extern "C" {
@@ -140,4 +141,34 @@ template <typename Torus>
 void cuda_set_value_async(cudaStream_t stream, uint32_t gpu_index,
                           Torus *d_array, Torus value, Torus n);
 
+template <class T> struct malloc_with_size_tracking_async_deleter {
+private:
+  cudaStream_t _stream;
+  uint32_t _gpu_index;
+  uint64_t &_size_tracker;
+  bool _allocate_gpu_memory;
+
+public:
+  malloc_with_size_tracking_async_deleter(cudaStream_t stream,
+                                          uint32_t gpu_index,
+                                          uint64_t &size_tracker,
+                                          bool allocate_gpu_memory)
+      : _stream(stream), _gpu_index(gpu_index), _size_tracker(size_tracker),
+        _allocate_gpu_memory(allocate_gpu_memory) {}
+
+  void operator()(T *ptr) {
+    cuda_drop_with_size_tracking_async(ptr, _stream, _gpu_index,
+                                       _allocate_gpu_memory);
+  }
+};
+
+template <class T>
+std::shared_ptr<T> cuda_make_shared_with_size_tracking_async(
+    uint64_t size, cudaStream_t stream, uint32_t gpu_index,
+    uint64_t &size_tracker, bool allocate_gpu_memory) {
+  return std::shared_ptr<T>(
+      (T *)cuda_malloc_with_size_tracking_async(size, stream, gpu_index,
+                                                size_tracker,
+                                                allocate_gpu_memory),
+      malloc_with_size_tracking_async_deleter<T>(
+          stream, gpu_index, size_tracker, allocate_gpu_memory));
+}
+
 #endif
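The helper above wraps cuda_malloc_with_size_tracking_async in a std::shared_ptr whose deleter releases the buffer through cuda_drop_with_size_tracking_async on the same stream and GPU. A minimal usage sketch, assuming a caller that already holds a stream on GPU 0 (the buffer name and element count are illustrative, not part of the diff):

    // Sketch: allocate 1024 64-bit words with size tracking on (stream, GPU 0).
    uint64_t size_tracker = 0;
    std::shared_ptr<uint64_t> d_buf =
        cuda_make_shared_with_size_tracking_async<uint64_t>(
            1024 * sizeof(uint64_t), stream, /*gpu_index=*/0, size_tracker,
            /*allocate_gpu_memory=*/true);
    // Kernels keep taking raw device pointers, so call sites pass d_buf.get();
    // when the last owner goes away, the deleter issues the async drop on the
    // stream the buffer was allocated on.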
@@ -183,4 +183,93 @@ public:
   }
 };
 
+struct CudaStreamsBarrier {
+private:
+  std::vector<cudaEvent_t> _events;
+  CudaStreams _streams;
+
+  CudaStreamsBarrier(const CudaStreamsBarrier &) {} // Prevent copy-construction
+  CudaStreamsBarrier &operator=(const CudaStreamsBarrier &) {
+    return *this;
+  } // Prevent assignment
+
+public:
+  void create_on(const CudaStreams &streams) {
+    _streams = streams;
+
+    GPU_ASSERT(streams.count() > 1, "CudaStreamsFirstWaitsWorkersBarrier: "
+                                    "Attempted to create on single GPU");
+    _events.resize(streams.count());
+    for (int i = 0; i < streams.count(); i++) {
+      _events[i] = cuda_create_event(streams.gpu_index(i));
+    }
+  }
+
+  CudaStreamsBarrier(){};
+
+  void local_streams_wait_for_stream_0(const CudaStreams &user_streams) {
+    GPU_ASSERT(!_events.empty(),
+               "CudaStreamsBarrier: must call create_on before use");
+    GPU_ASSERT(user_streams.gpu_index(0) == _streams.gpu_index(0),
+               "CudaStreamsBarrier: synchronization can only be performed on "
+               "the GPUs the barrier was initially created on.");
+
+    cuda_event_record(_events[0], user_streams.stream(0),
+                      user_streams.gpu_index(0));
+    for (int j = 1; j < user_streams.count(); j++) {
+      GPU_ASSERT(user_streams.gpu_index(j) == _streams.gpu_index(j),
+                 "CudaStreamsBarrier: synchronization can only be performed on "
+                 "the GPUs the barrier was initially created on.");
+      cuda_stream_wait_event(user_streams.stream(j), _events[0],
+                             user_streams.gpu_index(j));
+    }
+  }
+
+  void stream_0_wait_for_local_streams(const CudaStreams &user_streams) {
+    GPU_ASSERT(
+        !_events.empty(),
+        "CudaStreamsFirstWaitsWorkersBarrier: must call create_on before use");
+    GPU_ASSERT(
+        user_streams.count() <= _events.size(),
+        "CudaStreamsFirstWaitsWorkersBarrier: trying to synchronize too many "
+        "streams. "
+        "The barrier was created on a LUT that had %lu active streams, while "
+        "the user stream set has %u streams",
+        _events.size(), user_streams.count());
+
+    if (user_streams.count() > 1) {
+      // Worker GPUs record their events
+      for (int j = 1; j < user_streams.count(); j++) {
+        GPU_ASSERT(_streams.gpu_index(j) == user_streams.gpu_index(j),
+                   "CudaStreamsBarrier: The user stream "
+                   "set GPU[%d]=%u while the LUT stream set GPU[%d]=%u",
+                   j, user_streams.gpu_index(j), j, _streams.gpu_index(j));
+
+        cuda_event_record(_events[j], user_streams.stream(j),
+                          user_streams.gpu_index(j));
+      }
+
+      // GPU 0 waits for all workers
+      for (int j = 1; j < user_streams.count(); j++) {
+        cuda_stream_wait_event(user_streams.stream(0), _events[j],
+                               user_streams.gpu_index(0));
+      }
+    }
+  }
+
+  void release() {
+    for (int j = 0; j < _streams.count(); j++) {
+      cuda_event_destroy(_events[j], _streams.gpu_index(j));
+    }
+
+    _events.clear();
+  }
+
+  ~CudaStreamsBarrier() {
+    GPU_ASSERT(_events.empty(),
+               "CudaStreamsBarrier: must "
+               "call release before destruction: events size = %lu",
+               _events.size());
+  }
+};
+
 #endif
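CudaStreamsBarrier bundles the two event patterns the integer entry points previously open-coded: worker streams waiting on stream 0 before a scatter, and stream 0 waiting on all workers after a gather. A sketch of the intended call sequence, assuming a multi-GPU stream set named active_streams (the variable names are illustrative, not from the diff):

    // One barrier per direction, created once on the active stream set.
    CudaStreamsBarrier scatter_barrier, gather_barrier;
    scatter_barrier.create_on(active_streams); // asserts count() > 1
    gather_barrier.create_on(active_streams);

    // Before scattering: streams 1..n-1 wait until stream 0 has produced the input.
    scatter_barrier.local_streams_wait_for_stream_0(active_streams);
    //   ... per-GPU keyswitch / PBS work on each local stream ...
    // After gathering: stream 0 waits until every worker stream has finished.
    gather_barrier.stream_0_wait_for_local_streams(active_streams);

    // Events must be released explicitly before the barrier is destroyed.
    scatter_barrier.release();
    gather_barrier.release();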
File diff suppressed because it is too large
@@ -266,6 +266,11 @@ void cuda_memcpy_with_size_tracking_async_gpu_to_gpu(
     uint32_t gpu_index, bool gpu_memory_allocated) {
   if (size == 0 || !gpu_memory_allocated)
     return;
+  GPU_ASSERT(dest != nullptr,
+             "Cuda error: trying to copy gpu->gpu to null ptr");
+  GPU_ASSERT(src != nullptr,
+             "Cuda error: trying to copy gpu->gpu from null ptr");
+
   cudaPointerAttributes attr_dest;
   check_cuda_error(cudaPointerGetAttributes(&attr_dest, dest));
   PANIC_IF_FALSE(
@@ -31,8 +31,8 @@ __host__ void zero_out_if(CudaStreams streams,
   // second operand is not an array
   auto tmp_lwe_array_input = mem_ptr->tmp;
   host_pack_bivariate_blocks_with_single_block<Torus>(
-      streams, tmp_lwe_array_input, predicate->lwe_indexes_in, lwe_array_input,
-      lwe_condition, predicate->lwe_indexes_in, params.message_modulus,
+      streams, tmp_lwe_array_input, predicate->lwe_indexes_in.get(), lwe_array_input,
+      lwe_condition, predicate->lwe_indexes_in.get(), params.message_modulus,
       num_radix_blocks);
 
   integer_radix_apply_univariate_lookup_table_kb<Torus>(
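The recurring change from lut->lwe_indexes_in to lut->lwe_indexes_in.get() in the hunks below is the call-site half of moving the index buffers to shared_ptr ownership; device kernels and host wrappers still take raw pointers. A sketch of the owning side implied by these call sites (the member declaration and allocation parameters here are assumptions, not shown in this diff):

    // Assumed member shape behind the .get() call sites:
    //   before: Torus *lwe_indexes_in;                  // freed by hand
    //   after:  std::shared_ptr<Torus> lwe_indexes_in;  // freed by its deleter
    lwe_indexes_in = cuda_make_shared_with_size_tracking_async<Torus>(
        num_radix_blocks * sizeof(Torus), streams.stream(0),
        streams.gpu_index(0), size_tracker, allocate_gpu_memory);
    // Kernels then receive the raw device pointer via lwe_indexes_in.get().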
@@ -344,7 +344,7 @@ host_integer_decompress(CudaStreams streams,
     execute_pbs_async<Torus, Torus>(
         active_streams, (Torus *)d_lwe_array_out->ptr, lut->lwe_indexes_out,
         lut->lut_vec, lut->lut_indexes_vec, extracted_lwe,
-        lut->lwe_indexes_in, d_bsks, lut->buffer,
+        lut->lwe_indexes_in.get(), d_bsks, lut->buffer,
         encryption_params.glwe_dimension,
         compression_params.small_lwe_dimension,
         encryption_params.polynomial_size, encryption_params.pbs_base_log,
@@ -359,17 +359,13 @@ host_integer_decompress(CudaStreams streams,
       std::vector<Torus *> lwe_trivial_indexes_vec =
           lut->lwe_trivial_indexes_vec;
 
       /// Make sure all data that should be on GPU 0 is indeed there
-      cuda_event_record(lut->event_scatter_in, streams.stream(0),
-                        streams.gpu_index(0));
-      for (int j = 1; j < active_streams.count(); j++) {
-        cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
-                               streams.gpu_index(j));
-      }
+      lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(
+          active_streams);
       /// With multiple GPUs we push to the vectors on each GPU then when we
       /// gather data to GPU 0 we can copy back to the original indexing
       multi_gpu_scatter_lwe_async<Torus>(
-          active_streams, lwe_array_in_vec, extracted_lwe, lut->lwe_indexes_in,
+          active_streams, lwe_array_in_vec, extracted_lwe, lut->lwe_indexes_in.get(),
           lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
           lut->active_streams.count(), num_blocks_to_decompress,
           compression_params.small_lwe_dimension + 1);
@@ -395,15 +391,8 @@ host_integer_decompress(CudaStreams streams,
 
       /// Synchronize all GPUs
-      // other gpus record their events
-      for (int j = 1; j < active_streams.count(); j++) {
-        cuda_event_record(lut->event_scatter_out[j], active_streams.stream(j),
-                          active_streams.gpu_index(j));
-      }
-      // GPU 0 waits for all
-      for (int j = 1; j < active_streams.count(); j++) {
-        cuda_stream_wait_event(streams.stream(0), lut->event_scatter_out[j],
-                               streams.gpu_index(0));
-      }
+      lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
+          active_streams);
     }
   } else {
     static_assert(std::is_same_v<Torus, __uint128_t>,
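The removed record/wait loops and the new one-line calls express the same stream dependency; the event bookkeeping simply moves into the LUT's barrier objects, whose events are created per GPU in create_on. Roughly, for the scatter direction:

    // Before, inlined at every call site:
    //   cuda_event_record(lut->event_scatter_in, streams.stream(0),
    //                     streams.gpu_index(0));
    //   for (int j = 1; j < active_streams.count(); j++)
    //     cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
    //                            streams.gpu_index(j));
    // After, with the event owned and recorded by the barrier itself:
    lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(active_streams);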
@@ -546,7 +546,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
 
@@ -560,19 +560,15 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
         grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
   } else {
     /// Make sure all data that should be on GPU 0 is indeed there
-    cuda_event_record(lut->event_scatter_in, streams.stream(0),
-                      streams.gpu_index(0));
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
-                             streams.gpu_index(j));
-    }
+    lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(
+        active_streams);
 
     /// With multiple GPUs we push to the vectors on each GPU then when we
     /// gather data to GPU 0 we can copy back to the original indexing
     PUSH_RANGE("scatter")
     multi_gpu_scatter_lwe_async<Torus>(
         active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
-        lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
+        lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
         big_lwe_dimension + 1);
     POP_RANGE()
@@ -598,16 +594,8 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
         lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_vec, num_radix_blocks, big_lwe_dimension + 1);
     POP_RANGE()
-    // other gpus record their events
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_event_record(lut->event_scatter_out[j], streams.stream(j),
-                        streams.gpu_index(j));
-    }
-    // GPU 0 waits for all
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(0), lut->event_scatter_out[j],
-                             streams.gpu_index(0));
-    }
+    lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
+        active_streams);
   }
   for (uint i = 0; i < num_radix_blocks; i++) {
     auto degrees_index = lut->h_lut_indexes[i];
@@ -660,7 +648,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
 
@@ -674,18 +662,15 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
         grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
   } else {
     /// Make sure all data that should be on GPU 0 is indeed there
-    cuda_event_record(lut->event_scatter_in, streams.stream(0),
-                      streams.gpu_index(0));
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
-                             streams.gpu_index(j));
-    }
+    lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(
+        active_streams);
 
     /// With multiple GPUs we push to the vectors on each GPU then when we
     /// gather data to GPU 0 we can copy back to the original indexing
     PUSH_RANGE("scatter")
     multi_gpu_scatter_lwe_async<Torus>(
         active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
-        lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
+        lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
         big_lwe_dimension + 1);
     POP_RANGE()
@@ -712,16 +697,8 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
         num_radix_blocks, big_lwe_dimension + 1, num_many_lut);
     POP_RANGE()
 
-    // other gpus record their events
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_event_record(lut->event_scatter_out[j], streams.stream(j),
-                        streams.gpu_index(j));
-    }
-    // GPU 0 waits for all
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(0), lut->event_scatter_out[j],
-                             streams.gpu_index(0));
-    }
+    lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
+        active_streams);
   }
   for (uint i = 0; i < lwe_array_out->num_radix_blocks; i++) {
     auto degrees_index = lut->h_lut_indexes[i % lut->num_blocks];
@@ -771,10 +748,10 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
   uint32_t lut_stride = 0;
 
   // Left message is shifted
-  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks;
+  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks.get();
   host_pack_bivariate_blocks<Torus>(
       streams, lwe_array_pbs_in, lut->lwe_trivial_indexes, lwe_array_1,
-      lwe_array_2, lut->lwe_indexes_in, shift, num_radix_blocks,
+      lwe_array_2, lut->lwe_indexes_in.get(), shift, num_radix_blocks,
      params.message_modulus, params.carry_modulus);
   check_cuda_error(cudaGetLastError());
 
@@ -789,7 +766,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<Torus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
+        (Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in.get(), ksks,
         big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level,
         num_radix_blocks);
 
@@ -802,16 +779,13 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
         small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
         grouping_factor, num_radix_blocks, pbs_type, num_many_lut, lut_stride);
   } else {
-    cuda_event_record(lut->event_scatter_in, streams.stream(0),
-                      streams.gpu_index(0));
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
-                             streams.gpu_index(j));
-    }
+    lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(
+        active_streams);
 
     PUSH_RANGE("scatter")
     multi_gpu_scatter_lwe_async<Torus>(
         active_streams, lwe_array_in_vec, (Torus *)lwe_array_pbs_in->ptr,
-        lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
+        lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
         big_lwe_dimension + 1);
     POP_RANGE()
@@ -837,16 +811,8 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
         lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_vec, num_radix_blocks, big_lwe_dimension + 1);
     POP_RANGE()
-    // other gpus record their events
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_event_record(lut->event_scatter_out[j], streams.stream(j),
-                        streams.gpu_index(j));
-    }
-    // GPU 0 waits for all
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(0), lut->event_scatter_out[j],
-                             streams.gpu_index(0));
-    }
+    lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
+        active_streams);
   }
   for (uint i = 0; i < num_radix_blocks; i++) {
     auto degrees_index = lut->h_lut_indexes[i];
@@ -2368,7 +2334,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
 
   /// For multi GPU execution we create vectors of pointers for inputs and
   /// outputs
-  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks;
+  auto lwe_array_pbs_in = lut->tmp_lwe_before_ks.get();
   std::vector<InputTorus *> lwe_array_in_vec = lut->lwe_array_in_vec;
   std::vector<InputTorus *> lwe_after_ks_vec = lut->lwe_after_ks_vec;
   std::vector<__uint128_t *> lwe_after_pbs_vec = lut->lwe_after_pbs_vec;
@@ -2387,7 +2353,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
   if (active_streams.count() == 1) {
     execute_keyswitch_async<InputTorus>(
         streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0],
-        (InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks,
+        (InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in.get(), ksks,
         lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
         ks_level, lwe_array_out->num_radix_blocks);
 
@@ -2399,7 +2365,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
     execute_pbs_async<uint64_t, __uint128_t>(
         streams.get_ith(0), (__uint128_t *)lwe_array_out->ptr,
         lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
-        lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->pbs_buffer,
+        lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer,
         glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
         pbs_level, grouping_factor, lwe_array_out->num_radix_blocks,
         params.pbs_type, 0, 0);
@@ -2411,7 +2377,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
     /// gather data to GPU 0 we can copy back to the original indexing
     multi_gpu_scatter_lwe_async<InputTorus>(
         active_streams, lwe_array_in_vec, (InputTorus *)lwe_array_pbs_in->ptr,
-        lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
+        lut->lwe_indexes_in.get(), lut->using_trivial_lwe_indexes,
         lut->lwe_aligned_scatter_vec, lut->active_streams.count(),
         lwe_array_out->num_radix_blocks, lut->input_big_lwe_dimension + 1);
 
@@ -2426,7 +2392,7 @@ __host__ void integer_radix_apply_noise_squashing_kb(
     execute_pbs_async<uint64_t, __uint128_t>(
         active_streams, lwe_after_pbs_vec, lwe_trivial_indexes_vec,
         lut->lut_vec, lwe_trivial_indexes_vec, lwe_after_ks_vec,
-        lwe_trivial_indexes_vec, bsks, lut->pbs_buffer, glwe_dimension,
+        lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension,
         small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
         grouping_factor, lwe_array_out->num_radix_blocks, params.pbs_type, 0,
         0);
@@ -375,7 +375,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
 
   while (needs_processing) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
-    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
+    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in.get();
     auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
     calculate_chunks<Torus>
         <<<number_of_blocks_2d, number_of_threads, 0, streams.stream(0)>>>(
@@ -433,7 +433,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
 
   if (mem_ptr->reduce_degrees_for_single_carry_propagation) {
     auto luts_message_carry = mem_ptr->luts_message_carry;
-    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in;
+    auto d_pbs_indexes_in = mem_ptr->luts_message_carry->lwe_indexes_in.get();
    auto d_pbs_indexes_out = mem_ptr->luts_message_carry->lwe_indexes_out;
     prepare_final_pbs_indexes<Torus>
         <<<1, 2 * num_radix_blocks, 0, streams.stream(0)>>>(
@@ -34,7 +34,7 @@ void host_integer_grouped_oprf(CudaStreams streams,
     execute_pbs_async<Torus, Torus>(
         streams.get_ith(0), (Torus *)(radix_lwe_out->ptr), lut->lwe_indexes_out,
         lut->lut_vec, lut->lut_indexes_vec,
-        const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
+        const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in.get(), bsks,
         lut->buffer, mem_ptr->params.glwe_dimension,
         mem_ptr->params.small_lwe_dimension, mem_ptr->params.polynomial_size,
         mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
@@ -45,16 +45,12 @@ void host_integer_grouped_oprf(CudaStreams streams,
     std::vector<Torus *> lwe_after_pbs_vec = lut->lwe_after_pbs_vec;
     std::vector<Torus *> lwe_trivial_indexes_vec = lut->lwe_trivial_indexes_vec;
 
-    cuda_event_record(lut->event_scatter_in, streams.stream(0),
-                      streams.gpu_index(0));
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(j), lut->event_scatter_in,
-                             streams.gpu_index(j));
-    }
+    lut->multi_gpu_scatter_barrier.local_streams_wait_for_stream_0(
+        active_streams);
 
     PUSH_RANGE("scatter")
     multi_gpu_scatter_lwe_async<Torus>(
-        active_streams, lwe_array_in_vec, seeded_lwe_input, lut->lwe_indexes_in,
+        active_streams, lwe_array_in_vec, seeded_lwe_input, lut->lwe_indexes_in.get(),
         lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
         active_streams.count(), num_blocks_to_process,
         mem_ptr->params.small_lwe_dimension + 1);
@@ -76,16 +72,8 @@ void host_integer_grouped_oprf(CudaStreams streams,
         lut->lwe_aligned_vec, num_blocks_to_process,
         mem_ptr->params.big_lwe_dimension + 1);
     POP_RANGE()
-    // other gpus record their events
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_event_record(lut->event_scatter_out[j], streams.stream(j),
-                        streams.gpu_index(j));
-    }
-    // GPU 0 waits for all
-    for (int j = 1; j < active_streams.count(); j++) {
-      cuda_stream_wait_event(streams.stream(0), lut->event_scatter_out[j],
-                             streams.gpu_index(0));
-    }
+    lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
+        active_streams);
   }
 
   for (uint32_t i = 0; i < num_blocks_to_process; i++) {
@@ -150,7 +150,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
       // control_bit|b|a
       host_pack_bivariate_blocks<Torus>(
           streams, mux_inputs, mux_lut->lwe_indexes_out, rotated_input,
-          input_bits_a, mux_lut->lwe_indexes_in, 2, total_nb_bits,
+          input_bits_a, mux_lut->lwe_indexes_in.get(), 2, total_nb_bits,
           mem->params.message_modulus, mem->params.carry_modulus);
 
       // The shift bit is already properly aligned/positioned