fix(gpu): fix pbs128 selection for small num samples

This commit is contained in:
Guillermo Oyarzun
2025-11-21 08:54:59 +01:00
committed by Agnès Leroy
parent e5742e63e9
commit 02312e23ea
2 changed files with 16 additions and 10 deletions

View File

@@ -378,9 +378,11 @@ struct int_radix_lut_custom_input_output {
for (uint i = 0; i < active_streams.count(); i++) {
cuda_set_device(active_streams.gpu_index(i));
int8_t *gpu_pbs_buffer;
auto num_blocks_on_gpu = std::max(
THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(num_radix_blocks, i, active_streams.count()));
auto num_blocks_on_gpu =
std::min((int)num_radix_blocks,
std::max(THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(num_radix_blocks, i,
active_streams.count())));
uint64_t size = 0;
execute_scratch_pbs<OutputTorus>(
@@ -768,9 +770,11 @@ struct int_radix_lut_custom_input_output {
lwe_aligned_vec.resize(active_streams.count());
for (uint i = 0; i < active_streams.count(); i++) {
uint64_t size_tracker_on_array_i = 0;
auto inputs_on_gpu = std::max(
THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu(max_num_radix_blocks, i,
active_streams.count()));
auto inputs_on_gpu =
std::min((int)max_num_radix_blocks,
std::max(THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(max_num_radix_blocks, i,
active_streams.count())));
InputTorus *d_array =
(InputTorus *)cuda_malloc_with_size_tracking_async(
inputs_on_gpu * (params.big_lwe_dimension + 1) *

View File

@@ -65,9 +65,10 @@ void multi_gpu_alloc_lwe_async(CudaStreams streams, std::vector<Torus *> &dest,
dest.resize(streams.count());
for (uint i = 0; i < streams.count(); i++) {
uint64_t size_tracker_on_gpu_i = 0;
auto inputs_on_gpu =
auto inputs_on_gpu = std::min(
(int)num_inputs,
std::max(THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(num_inputs, i, streams.count()));
get_num_inputs_on_gpu(num_inputs, i, streams.count())));
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
inputs_on_gpu * lwe_size * sizeof(Torus), streams.stream(i),
streams.gpu_index(i), size_tracker_on_gpu_i, allocate_gpu_memory);
@@ -97,9 +98,10 @@ void multi_gpu_alloc_lwe_many_lut_output_async(
dest.resize(streams.count());
for (uint i = 0; i < streams.count(); i++) {
uint64_t size_tracker = 0;
auto inputs_on_gpu =
auto inputs_on_gpu = std::min(
(int)num_inputs,
std::max(THRESHOLD_MULTI_GPU,
get_num_inputs_on_gpu(num_inputs, i, streams.count()));
get_num_inputs_on_gpu(num_inputs, i, streams.count())));
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
num_many_lut * inputs_on_gpu * lwe_size * sizeof(Torus),
streams.stream(i), streams.gpu_index(i), size_tracker,