mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(gpu): remove sub streams from overflowing subtraction
This commit is contained in:
@@ -2846,12 +2846,6 @@ template <typename Torus> struct int_borrow_prop_memory {
|
||||
int_radix_params params;
|
||||
|
||||
CudaStreams active_streams;
|
||||
CudaStreams sub_streams_1;
|
||||
CudaStreams sub_streams_2;
|
||||
|
||||
cudaEvent_t *incoming_events;
|
||||
cudaEvent_t *outgoing_events1;
|
||||
cudaEvent_t *outgoing_events2;
|
||||
|
||||
uint32_t compute_overflow;
|
||||
bool gpu_memory_allocated;
|
||||
@@ -2927,20 +2921,6 @@ template <typename Torus> struct int_borrow_prop_memory {
|
||||
}
|
||||
|
||||
active_streams = streams.active_gpu_subset(num_radix_blocks);
|
||||
sub_streams_1.create_on_same_gpus(active_streams);
|
||||
sub_streams_2.create_on_same_gpus(active_streams);
|
||||
|
||||
incoming_events =
|
||||
(cudaEvent_t *)malloc(active_streams.count() * sizeof(cudaEvent_t));
|
||||
outgoing_events1 =
|
||||
(cudaEvent_t *)malloc(active_streams.count() * sizeof(cudaEvent_t));
|
||||
outgoing_events2 =
|
||||
(cudaEvent_t *)malloc(active_streams.count() * sizeof(cudaEvent_t));
|
||||
for (uint j = 0; j < active_streams.count(); j++) {
|
||||
incoming_events[j] = cuda_create_event(active_streams.gpu_index(j));
|
||||
outgoing_events1[j] = cuda_create_event(active_streams.gpu_index(j));
|
||||
outgoing_events2[j] = cuda_create_event(active_streams.gpu_index(j));
|
||||
}
|
||||
};
|
||||
|
||||
// needed for the division to update the lut indexes
|
||||
@@ -2966,22 +2946,6 @@ template <typename Torus> struct int_borrow_prop_memory {
|
||||
lut_borrow_flag->release(streams);
|
||||
delete lut_borrow_flag;
|
||||
}
|
||||
|
||||
// The substreams have to be synchronized before destroying events
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
|
||||
// release events
|
||||
for (uint j = 0; j < active_streams.count(); j++) {
|
||||
cuda_event_destroy(incoming_events[j], active_streams.gpu_index(j));
|
||||
cuda_event_destroy(outgoing_events1[j], active_streams.gpu_index(j));
|
||||
cuda_event_destroy(outgoing_events2[j], active_streams.gpu_index(j));
|
||||
}
|
||||
free(incoming_events);
|
||||
free(outgoing_events1);
|
||||
free(outgoing_events2);
|
||||
|
||||
sub_streams_1.release();
|
||||
sub_streams_2.release();
|
||||
};
|
||||
};
|
||||
|
||||
@@ -2991,8 +2955,6 @@ template <typename Torus> struct int_zero_out_if_buffer {
|
||||
|
||||
CudaRadixCiphertextFFI *tmp;
|
||||
|
||||
CudaStreams true_streams;
|
||||
CudaStreams false_streams;
|
||||
bool gpu_memory_allocated;
|
||||
|
||||
int_zero_out_if_buffer(CudaStreams streams, int_radix_params params,
|
||||
@@ -3006,17 +2968,12 @@ template <typename Torus> struct int_zero_out_if_buffer {
|
||||
create_zero_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), tmp, num_radix_blocks,
|
||||
params.big_lwe_dimension, size_tracker, allocate_gpu_memory);
|
||||
// We may use a different stream to allow concurrent operation
|
||||
true_streams.create_on_same_gpus(active_streams);
|
||||
false_streams.create_on_same_gpus(active_streams);
|
||||
}
|
||||
void release(CudaStreams streams) {
|
||||
release_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), tmp,
|
||||
gpu_memory_allocated);
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
delete tmp;
|
||||
true_streams.release();
|
||||
false_streams.release();
|
||||
tmp = nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -2322,53 +2322,27 @@ void host_single_borrow_propagate(
|
||||
params.carry_modulus);
|
||||
}
|
||||
|
||||
cuda_event_record(mem->incoming_events[0], streams.stream(0),
|
||||
streams.gpu_index(0));
|
||||
for (int j = 0; j < mem->active_streams.count(); j++) {
|
||||
cuda_stream_wait_event(mem->sub_streams_1.stream(j),
|
||||
mem->incoming_events[0],
|
||||
mem->sub_streams_1.gpu_index(j));
|
||||
cuda_stream_wait_event(mem->sub_streams_2.stream(j),
|
||||
mem->incoming_events[0],
|
||||
mem->sub_streams_1.gpu_index(j));
|
||||
}
|
||||
|
||||
if (compute_overflow == outputFlag::FLAG_OVERFLOW) {
|
||||
auto borrow_flag = mem->lut_borrow_flag;
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
mem->sub_streams_1, overflow_block, mem->overflow_block, bsks, ksks,
|
||||
streams, overflow_block, mem->overflow_block, bsks, ksks,
|
||||
ms_noise_reduction_key, borrow_flag, 1);
|
||||
}
|
||||
for (int j = 0; j < mem->active_streams.count(); j++) {
|
||||
cuda_event_record(mem->outgoing_events1[j], mem->sub_streams_1.stream(j),
|
||||
mem->sub_streams_1.gpu_index(j));
|
||||
}
|
||||
|
||||
// subtract borrow and cleanup prepared blocks
|
||||
auto resolved_carries = mem->prop_simu_group_carries_mem->resolved_carries;
|
||||
host_negation<Torus>(
|
||||
mem->sub_streams_2.stream(0), mem->sub_streams_2.gpu_index(0),
|
||||
(Torus *)resolved_carries->ptr, (Torus *)resolved_carries->ptr,
|
||||
big_lwe_dimension, num_groups);
|
||||
streams.stream(0), streams.gpu_index(0), (Torus *)resolved_carries->ptr,
|
||||
(Torus *)resolved_carries->ptr, big_lwe_dimension, num_groups);
|
||||
|
||||
host_radix_sum_in_groups<Torus>(
|
||||
mem->sub_streams_2.stream(0), mem->sub_streams_2.gpu_index(0),
|
||||
prepared_blocks, prepared_blocks, resolved_carries, num_radix_blocks,
|
||||
mem->group_size);
|
||||
streams.stream(0), streams.gpu_index(0), prepared_blocks, prepared_blocks,
|
||||
resolved_carries, num_radix_blocks, mem->group_size);
|
||||
|
||||
auto message_extract = mem->lut_message_extract;
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
mem->sub_streams_2, lwe_array, prepared_blocks, bsks, ksks,
|
||||
ms_noise_reduction_key, message_extract, num_radix_blocks);
|
||||
|
||||
for (int j = 0; j < mem->active_streams.count(); j++) {
|
||||
cuda_event_record(mem->outgoing_events2[j], mem->sub_streams_2.stream(j),
|
||||
mem->sub_streams_2.gpu_index(j));
|
||||
cuda_stream_wait_event(streams.stream(0), mem->outgoing_events1[j],
|
||||
streams.gpu_index(0));
|
||||
cuda_stream_wait_event(streams.stream(0), mem->outgoing_events2[j],
|
||||
streams.gpu_index(0));
|
||||
}
|
||||
streams, lwe_array, prepared_blocks, bsks, ksks, ms_noise_reduction_key,
|
||||
message_extract, num_radix_blocks);
|
||||
}
|
||||
|
||||
/// num_radix_blocks corresponds to the number of blocks on which to apply the
|
||||
|
||||
Reference in New Issue
Block a user