mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(gpu): rework select to avoid using local streams
This commit is contained in:
1
.github/workflows/benchmark_gpu_erc20.yml
vendored
1
.github/workflows/benchmark_gpu_erc20.yml
vendored
@@ -12,6 +12,7 @@ on:
|
||||
- "l40 (n3-L40x1)"
|
||||
- "single-h100 (n3-H100x1)"
|
||||
- "2-h100 (n3-H100x2)"
|
||||
- "4-h100 (n3-H100x4)"
|
||||
- "multi-h100 (n3-H100x8)"
|
||||
- "multi-h100-nvlink (n3-H100x8-NVLink)"
|
||||
- "multi-h100-sxm5 (n3-H100x8-SXM5)"
|
||||
|
||||
@@ -2954,14 +2954,11 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
|
||||
|
||||
template <typename Torus> struct int_cmux_buffer {
|
||||
int_radix_lut<Torus> *predicate_lut;
|
||||
int_radix_lut<Torus> *inverted_predicate_lut;
|
||||
int_radix_lut<Torus> *message_extract_lut;
|
||||
|
||||
Torus *tmp_true_ct;
|
||||
Torus *tmp_false_ct;
|
||||
|
||||
int_zero_out_if_buffer<Torus> *zero_if_true_buffer;
|
||||
int_zero_out_if_buffer<Torus> *zero_if_false_buffer;
|
||||
Torus *buffer_in;
|
||||
Torus *buffer_out;
|
||||
Torus *condition_array;
|
||||
|
||||
int_radix_params params;
|
||||
|
||||
@@ -2977,17 +2974,12 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
Torus big_size =
|
||||
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus);
|
||||
|
||||
tmp_true_ct =
|
||||
(Torus *)cuda_malloc_async(big_size, streams[0], gpu_indexes[0]);
|
||||
tmp_false_ct =
|
||||
(Torus *)cuda_malloc_async(big_size, streams[0], gpu_indexes[0]);
|
||||
|
||||
zero_if_true_buffer = new int_zero_out_if_buffer<Torus>(
|
||||
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
|
||||
allocate_gpu_memory);
|
||||
zero_if_false_buffer = new int_zero_out_if_buffer<Torus>(
|
||||
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
|
||||
allocate_gpu_memory);
|
||||
buffer_in =
|
||||
(Torus *)cuda_malloc_async(2 * big_size, streams[0], gpu_indexes[0]);
|
||||
buffer_out =
|
||||
(Torus *)cuda_malloc_async(2 * big_size, streams[0], gpu_indexes[0]);
|
||||
condition_array =
|
||||
(Torus *)cuda_malloc_async(2 * big_size, streams[0], gpu_indexes[0]);
|
||||
|
||||
auto lut_f = [predicate_lut_f](Torus block, Torus condition) -> Torus {
|
||||
return predicate_lut_f(condition) ? 0 : block;
|
||||
@@ -3001,12 +2993,8 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
};
|
||||
|
||||
predicate_lut =
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 1,
|
||||
num_radix_blocks, allocate_gpu_memory);
|
||||
|
||||
inverted_predicate_lut =
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 1,
|
||||
num_radix_blocks, allocate_gpu_memory);
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
|
||||
2 * num_radix_blocks, allocate_gpu_memory);
|
||||
|
||||
message_extract_lut =
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 1,
|
||||
@@ -3015,21 +3003,33 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams[0], gpu_indexes[0], predicate_lut->get_lut(0, 0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, lut_f);
|
||||
params.carry_modulus, inverted_lut_f);
|
||||
|
||||
generate_device_accumulator_bivariate<Torus>(
|
||||
streams[0], gpu_indexes[0], inverted_predicate_lut->get_lut(0, 0),
|
||||
streams[0], gpu_indexes[0], predicate_lut->get_lut(0, 1),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, inverted_lut_f);
|
||||
params.carry_modulus, lut_f);
|
||||
|
||||
generate_device_accumulator<Torus>(
|
||||
streams[0], gpu_indexes[0], message_extract_lut->get_lut(0, 0),
|
||||
params.glwe_dimension, params.polynomial_size, params.message_modulus,
|
||||
params.carry_modulus, message_extract_lut_f);
|
||||
Torus *h_lut_indexes =
|
||||
(Torus *)malloc(2 * num_radix_blocks * sizeof(Torus));
|
||||
for (int index = 0; index < 2 * num_radix_blocks; index++) {
|
||||
if (index < num_radix_blocks) {
|
||||
h_lut_indexes[index] = 0;
|
||||
} else {
|
||||
h_lut_indexes[index] = 1;
|
||||
}
|
||||
}
|
||||
cuda_memcpy_async_to_gpu(
|
||||
predicate_lut->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
2 * num_radix_blocks * sizeof(Torus), streams[0], gpu_indexes[0]);
|
||||
|
||||
predicate_lut->broadcast_lut(streams, gpu_indexes, 0);
|
||||
inverted_predicate_lut->broadcast_lut(streams, gpu_indexes, 0);
|
||||
message_extract_lut->broadcast_lut(streams, gpu_indexes, 0);
|
||||
free(h_lut_indexes);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3037,18 +3037,12 @@ template <typename Torus> struct int_cmux_buffer {
|
||||
uint32_t gpu_count) {
|
||||
predicate_lut->release(streams, gpu_indexes, gpu_count);
|
||||
delete predicate_lut;
|
||||
inverted_predicate_lut->release(streams, gpu_indexes, gpu_count);
|
||||
delete inverted_predicate_lut;
|
||||
message_extract_lut->release(streams, gpu_indexes, gpu_count);
|
||||
delete message_extract_lut;
|
||||
|
||||
zero_if_true_buffer->release(streams, gpu_indexes, gpu_count);
|
||||
delete zero_if_true_buffer;
|
||||
zero_if_false_buffer->release(streams, gpu_indexes, gpu_count);
|
||||
delete zero_if_false_buffer;
|
||||
|
||||
cuda_drop_async(tmp_true_ct, streams[0], gpu_indexes[0]);
|
||||
cuda_drop_async(tmp_false_ct, streams[0], gpu_indexes[0]);
|
||||
cuda_drop_async(buffer_in, streams[0], gpu_indexes[0]);
|
||||
cuda_drop_async(buffer_out, streams[0], gpu_indexes[0]);
|
||||
cuda_drop_async(condition_array, streams[0], gpu_indexes[0]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -37,39 +37,32 @@ __host__ void host_integer_radix_cmux_kb(
|
||||
uint32_t num_radix_blocks) {
|
||||
|
||||
auto params = mem_ptr->params;
|
||||
|
||||
// Since our CPU threads will be working on different streams we shall assert
|
||||
// the work in the main stream is completed
|
||||
auto true_streams = mem_ptr->zero_if_true_buffer->true_streams;
|
||||
auto false_streams = mem_ptr->zero_if_false_buffer->false_streams;
|
||||
for (uint j = 0; j < gpu_count; j++) {
|
||||
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
|
||||
}
|
||||
|
||||
auto mem_true = mem_ptr->zero_if_true_buffer;
|
||||
zero_out_if<Torus>(true_streams, gpu_indexes, gpu_count, mem_ptr->tmp_true_ct,
|
||||
lwe_array_true, lwe_condition, mem_true,
|
||||
mem_ptr->inverted_predicate_lut, bsks, ksks,
|
||||
num_radix_blocks);
|
||||
auto mem_false = mem_ptr->zero_if_false_buffer;
|
||||
zero_out_if<Torus>(false_streams, gpu_indexes, gpu_count,
|
||||
mem_ptr->tmp_false_ct, lwe_array_false, lwe_condition,
|
||||
mem_false, mem_ptr->predicate_lut, bsks, ksks,
|
||||
num_radix_blocks);
|
||||
for (uint j = 0; j < mem_ptr->zero_if_true_buffer->active_gpu_count; j++) {
|
||||
cuda_synchronize_stream(true_streams[j], gpu_indexes[j]);
|
||||
}
|
||||
for (uint j = 0; j < mem_ptr->zero_if_false_buffer->active_gpu_count; j++) {
|
||||
cuda_synchronize_stream(false_streams[j], gpu_indexes[j]);
|
||||
Torus lwe_size = params.big_lwe_dimension + 1;
|
||||
Torus radix_lwe_size = lwe_size * num_radix_blocks;
|
||||
cuda_memcpy_async_gpu_to_gpu(mem_ptr->buffer_in, lwe_array_true,
|
||||
radix_lwe_size * sizeof(Torus), streams[0],
|
||||
gpu_indexes[0]);
|
||||
cuda_memcpy_async_gpu_to_gpu(mem_ptr->buffer_in + radix_lwe_size,
|
||||
lwe_array_false, radix_lwe_size * sizeof(Torus),
|
||||
streams[0], gpu_indexes[0]);
|
||||
for (uint i = 0; i < 2 * num_radix_blocks; i++) {
|
||||
cuda_memcpy_async_gpu_to_gpu(mem_ptr->condition_array + i * lwe_size,
|
||||
lwe_condition, lwe_size * sizeof(Torus),
|
||||
streams[0], gpu_indexes[0]);
|
||||
}
|
||||
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
|
||||
streams, gpu_indexes, gpu_count, mem_ptr->buffer_out, mem_ptr->buffer_in,
|
||||
mem_ptr->condition_array, bsks, ksks, 2 * num_radix_blocks,
|
||||
mem_ptr->predicate_lut, params.message_modulus);
|
||||
|
||||
// If the condition was true, true_ct will have kept its value and false_ct
|
||||
// will be 0 If the condition was false, true_ct will be 0 and false_ct will
|
||||
// have kept its value
|
||||
auto added_cts = mem_ptr->tmp_true_ct;
|
||||
host_addition<Torus>(streams[0], gpu_indexes[0], added_cts,
|
||||
mem_ptr->tmp_true_ct, mem_ptr->tmp_false_ct,
|
||||
params.big_lwe_dimension, num_radix_blocks);
|
||||
auto mem_true = mem_ptr->buffer_out;
|
||||
auto mem_false = &mem_ptr->buffer_out[radix_lwe_size];
|
||||
auto added_cts = mem_true;
|
||||
host_addition<Torus>(streams[0], gpu_indexes[0], added_cts, mem_true,
|
||||
mem_false, params.big_lwe_dimension, num_radix_blocks);
|
||||
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
streams, gpu_indexes, gpu_count, lwe_array_out, added_cts, bsks, ksks,
|
||||
|
||||
Reference in New Issue
Block a user