mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
1 Commits
as/ensure_
...
go/fix/avo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4a489e9b0 |
@@ -486,4 +486,38 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Event pool for managing temporary CUDA events in scatter/gather operations
|
||||
struct CudaEventPool {
|
||||
private:
|
||||
std::vector<cudaEvent_t> _events;
|
||||
std::vector<uint32_t> _gpu_indices;
|
||||
|
||||
public:
|
||||
CudaEventPool() {}
|
||||
|
||||
// Requests a new event from the pool (creates and stores it)
|
||||
cudaEvent_t request_event(uint32_t gpu_index) {
|
||||
cudaEvent_t event = cuda_create_event(gpu_index);
|
||||
_events.push_back(event);
|
||||
_gpu_indices.push_back(gpu_index);
|
||||
return event;
|
||||
}
|
||||
|
||||
// Releases all pooled events
|
||||
// This should always be called in the release of the LUT, so streams
|
||||
// are already synchronized
|
||||
void release() {
|
||||
for (size_t i = 0; i < _events.size(); i++) {
|
||||
cuda_event_destroy(_events[i], _gpu_indices[i]);
|
||||
}
|
||||
_events.clear();
|
||||
_gpu_indices.clear();
|
||||
}
|
||||
|
||||
~CudaEventPool() {
|
||||
GPU_ASSERT(_events.empty(),
|
||||
"CudaEventPool: must call release before destruction");
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -349,6 +349,7 @@ struct int_radix_lut_custom_input_output {
|
||||
|
||||
CudaStreamsBarrier multi_gpu_scatter_barrier, multi_gpu_broadcast_barrier;
|
||||
CudaStreamsBarrier multi_gpu_gather_barrier;
|
||||
CudaEventPool event_pool;
|
||||
|
||||
// Setup the LUT configuration:
|
||||
// input_big_lwe_dimension: BIG LWE dimension of the KS output / PBS input
|
||||
@@ -865,6 +866,7 @@ struct int_radix_lut_custom_input_output {
|
||||
|
||||
if (active_streams.count() > 1) {
|
||||
active_streams.synchronize();
|
||||
event_pool.release();
|
||||
multi_gpu_gather_barrier.release();
|
||||
multi_gpu_broadcast_barrier.release();
|
||||
multi_gpu_scatter_barrier.release();
|
||||
|
||||
@@ -368,7 +368,7 @@ host_integer_decompress(CudaStreams streams,
|
||||
/// gather data to GPU 0 we can copy back to the original indexing
|
||||
multi_gpu_scatter_lwe_async<Torus>(
|
||||
active_streams, lwe_array_in_vec, extracted_lwe, lut->lwe_indexes_in,
|
||||
lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
|
||||
lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec, lut->event_pool,
|
||||
lut->active_streams.count(), num_blocks_to_decompress,
|
||||
compression_params.small_lwe_dimension + 1);
|
||||
|
||||
@@ -388,7 +388,7 @@ host_integer_decompress(CudaStreams streams,
|
||||
multi_gpu_gather_lwe_async<Torus>(
|
||||
active_streams, (Torus *)d_lwe_array_out->ptr, lwe_after_pbs_vec,
|
||||
lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, num_blocks_to_decompress,
|
||||
lut->lwe_aligned_vec, lut->event_pool, num_blocks_to_decompress,
|
||||
encryption_params.big_lwe_dimension + 1);
|
||||
|
||||
/// Synchronize all GPUs
|
||||
|
||||
@@ -570,8 +570,8 @@ __host__ void integer_radix_apply_univariate_lookup_table(
|
||||
multi_gpu_scatter_lwe_async<Torus>(
|
||||
active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
|
||||
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_vec, lut->event_pool, lut->active_streams.count(),
|
||||
num_radix_blocks, big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
/// Apply KS to go from a big LWE dimension to a small LWE dimension
|
||||
execute_keyswitch_async<Torus>(
|
||||
@@ -594,7 +594,8 @@ __host__ void integer_radix_apply_univariate_lookup_table(
|
||||
multi_gpu_gather_lwe_async<Torus>(
|
||||
active_streams, (Torus *)lwe_array_out->ptr, lwe_after_pbs_vec,
|
||||
lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, num_radix_blocks, big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_vec, lut->event_pool, num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
|
||||
active_streams);
|
||||
@@ -674,8 +675,8 @@ __host__ void integer_radix_apply_many_univariate_lookup_table(
|
||||
multi_gpu_scatter_lwe_async<Torus>(
|
||||
active_streams, lwe_array_in_vec, (Torus *)lwe_array_in->ptr,
|
||||
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_vec, lut->event_pool, lut->active_streams.count(),
|
||||
num_radix_blocks, big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
/// Apply KS to go from a big LWE dimension to a small LWE dimension
|
||||
execute_keyswitch_async<Torus>(
|
||||
@@ -791,8 +792,8 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
|
||||
multi_gpu_scatter_lwe_async<Torus>(
|
||||
active_streams, lwe_array_in_vec, (Torus *)lwe_array_pbs_in->ptr,
|
||||
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, lut->active_streams.count(), num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_vec, lut->event_pool, lut->active_streams.count(),
|
||||
num_radix_blocks, big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
/// Apply KS to go from a big LWE dimension to a small LWE dimension
|
||||
execute_keyswitch_async<Torus>(
|
||||
@@ -815,7 +816,8 @@ __host__ void integer_radix_apply_bivariate_lookup_table(
|
||||
multi_gpu_gather_lwe_async<Torus>(
|
||||
active_streams, (Torus *)(lwe_array_out->ptr), lwe_after_pbs_vec,
|
||||
lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, num_radix_blocks, big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_vec, lut->event_pool, num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
|
||||
active_streams);
|
||||
@@ -2373,8 +2375,9 @@ integer_radix_apply_noise_squashing(CudaStreams streams,
|
||||
multi_gpu_scatter_lwe_async<InputTorus>(
|
||||
active_streams, lwe_array_in_vec, (InputTorus *)lwe_array_pbs_in->ptr,
|
||||
lut->lwe_indexes_in, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_scatter_vec, lut->active_streams.count(),
|
||||
lwe_array_out->num_radix_blocks, lut->input_big_lwe_dimension + 1);
|
||||
lut->lwe_aligned_scatter_vec, lut->event_pool,
|
||||
lut->active_streams.count(), lwe_array_out->num_radix_blocks,
|
||||
lut->input_big_lwe_dimension + 1);
|
||||
|
||||
execute_keyswitch_async<InputTorus>(
|
||||
active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec,
|
||||
@@ -2397,7 +2400,8 @@ integer_radix_apply_noise_squashing(CudaStreams streams,
|
||||
multi_gpu_gather_lwe_async<__uint128_t>(
|
||||
active_streams, (__uint128_t *)lwe_array_out->ptr, lwe_after_pbs_vec,
|
||||
nullptr, lut->using_trivial_lwe_indexes, lut->lwe_aligned_gather_vec,
|
||||
lwe_array_out->num_radix_blocks, big_lwe_dimension + 1);
|
||||
lut->event_pool, lwe_array_out->num_radix_blocks,
|
||||
big_lwe_dimension + 1);
|
||||
|
||||
/// Synchronize all GPUs
|
||||
streams.synchronize();
|
||||
|
||||
@@ -54,7 +54,7 @@ void host_integer_grouped_oprf(CudaStreams streams,
|
||||
PUSH_RANGE("scatter")
|
||||
multi_gpu_scatter_lwe_async<Torus>(
|
||||
active_streams, lwe_array_in_vec, seeded_lwe_input, lut->lwe_indexes_in,
|
||||
lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec,
|
||||
lut->using_trivial_lwe_indexes, lut->lwe_aligned_vec, lut->event_pool,
|
||||
active_streams.count(), num_blocks_to_process,
|
||||
mem_ptr->params.small_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
@@ -72,7 +72,7 @@ void host_integer_grouped_oprf(CudaStreams streams,
|
||||
multi_gpu_gather_lwe_async<Torus>(
|
||||
active_streams, (Torus *)radix_lwe_out->ptr, lwe_after_pbs_vec,
|
||||
lut->lwe_indexes_out, lut->using_trivial_lwe_indexes,
|
||||
lut->lwe_aligned_vec, num_blocks_to_process,
|
||||
lut->lwe_aligned_vec, lut->event_pool, num_blocks_to_process,
|
||||
mem_ptr->params.big_lwe_dimension + 1);
|
||||
POP_RANGE()
|
||||
lut->multi_gpu_gather_barrier.stream_0_wait_for_local_streams(
|
||||
|
||||
@@ -158,13 +158,11 @@ __global__ void realign_with_indexes(Torus *d_vector,
|
||||
/// The output indexing is always the trivial one
|
||||
/// num_inputs: total num of lwe in src
|
||||
template <typename Torus>
|
||||
void multi_gpu_scatter_lwe_async(CudaStreams streams,
|
||||
const std::vector<Torus *> &dest,
|
||||
Torus const *src, Torus const *d_src_indexes,
|
||||
bool is_trivial_index,
|
||||
std::vector<Torus *> &aligned_vec,
|
||||
uint32_t max_active_gpu_count,
|
||||
uint32_t num_inputs, uint32_t lwe_size) {
|
||||
void multi_gpu_scatter_lwe_async(
|
||||
CudaStreams streams, const std::vector<Torus *> &dest, Torus const *src,
|
||||
Torus const *d_src_indexes, bool is_trivial_index,
|
||||
std::vector<Torus *> &aligned_vec, CudaEventPool &event_pool,
|
||||
uint32_t max_active_gpu_count, uint32_t num_inputs, uint32_t lwe_size) {
|
||||
|
||||
PANIC_IF_FALSE(
|
||||
max_active_gpu_count >= streams.count(),
|
||||
@@ -193,7 +191,7 @@ void multi_gpu_scatter_lwe_async(CudaStreams streams,
|
||||
if (d_src_indexes == nullptr)
|
||||
PANIC("Cuda error: source indexes should be initialized!");
|
||||
|
||||
cudaEvent_t temp_event2 = cuda_create_event(streams.gpu_index(0));
|
||||
cudaEvent_t temp_event2 = event_pool.request_event(streams.gpu_index(0));
|
||||
cuda_set_device(streams.gpu_index(0));
|
||||
align_with_indexes<Torus><<<inputs_on_gpu, 1024, 0, streams.stream(0)>>>(
|
||||
aligned_vec[i], (Torus *)src, (Torus *)d_src_indexes + gpu_offset,
|
||||
@@ -207,7 +205,7 @@ void multi_gpu_scatter_lwe_async(CudaStreams streams,
|
||||
dest[i], aligned_vec[i], inputs_on_gpu * lwe_size * sizeof(Torus),
|
||||
streams.stream(i), streams.gpu_index(i), true);
|
||||
|
||||
cudaEvent_t temp_event = cuda_create_event(streams.gpu_index(i));
|
||||
cudaEvent_t temp_event = event_pool.request_event(streams.gpu_index(i));
|
||||
cuda_event_record(temp_event, streams.stream(i), streams.gpu_index(i));
|
||||
cuda_stream_wait_event(streams.stream(0), temp_event,
|
||||
streams.gpu_index(0));
|
||||
@@ -223,7 +221,8 @@ void multi_gpu_gather_lwe_async(CudaStreams streams, Torus *dest,
|
||||
const std::vector<Torus *> &src,
|
||||
Torus *d_dest_indexes, bool is_trivial_index,
|
||||
std::vector<Torus *> &aligned_vec,
|
||||
uint32_t num_inputs, uint32_t lwe_size) {
|
||||
CudaEventPool &event_pool, uint32_t num_inputs,
|
||||
uint32_t lwe_size) {
|
||||
|
||||
PANIC_IF_FALSE(src.size() >= streams.count(),
|
||||
"Cuda error: src vector was not allocated for enough GPUs");
|
||||
@@ -247,7 +246,7 @@ void multi_gpu_gather_lwe_async(CudaStreams streams, Torus *dest,
|
||||
if (d_dest_indexes == nullptr)
|
||||
PANIC("Cuda error: destination indexes should be initialized!");
|
||||
|
||||
cudaEvent_t temp_event2 = cuda_create_event(streams.gpu_index(0));
|
||||
cudaEvent_t temp_event2 = event_pool.request_event(streams.gpu_index(0));
|
||||
|
||||
cuda_event_record(temp_event2, streams.stream(0), streams.gpu_index(0));
|
||||
cuda_stream_wait_event(streams.stream(i), temp_event2,
|
||||
@@ -257,7 +256,7 @@ void multi_gpu_gather_lwe_async(CudaStreams streams, Torus *dest,
|
||||
aligned_vec[i], src[i], inputs_on_gpu * lwe_size * sizeof(Torus),
|
||||
streams.stream(i), streams.gpu_index(i), true);
|
||||
|
||||
cudaEvent_t temp_event3 = cuda_create_event(streams.gpu_index(i));
|
||||
cudaEvent_t temp_event3 = event_pool.request_event(streams.gpu_index(i));
|
||||
cuda_event_record(temp_event3, streams.stream(i), streams.gpu_index(i));
|
||||
cuda_stream_wait_event(streams.stream(0), temp_event3,
|
||||
streams.gpu_index(0));
|
||||
|
||||
Reference in New Issue
Block a user