mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): fix pbs128 selection for a small number of samples
This commit is contained in:
committed by
Agnès Leroy
parent
e5742e63e9
commit
02312e23ea
@@ -378,9 +378,11 @@ struct int_radix_lut_custom_input_output {
|
||||
for (uint i = 0; i < active_streams.count(); i++) {
|
||||
cuda_set_device(active_streams.gpu_index(i));
|
||||
int8_t *gpu_pbs_buffer;
|
||||
auto num_blocks_on_gpu = std::max(
|
||||
THRESHOLD_MULTI_GPU,
|
||||
get_num_inputs_on_gpu(num_radix_blocks, i, active_streams.count()));
|
||||
auto num_blocks_on_gpu =
|
||||
std::min((int)num_radix_blocks,
|
||||
std::max(THRESHOLD_MULTI_GPU,
|
||||
get_num_inputs_on_gpu(num_radix_blocks, i,
|
||||
active_streams.count())));
|
||||
|
||||
uint64_t size = 0;
|
||||
execute_scratch_pbs<OutputTorus>(
|
||||
@@ -768,9 +770,11 @@ struct int_radix_lut_custom_input_output {
|
||||
lwe_aligned_vec.resize(active_streams.count());
|
||||
for (uint i = 0; i < active_streams.count(); i++) {
|
||||
uint64_t size_tracker_on_array_i = 0;
|
||||
auto inputs_on_gpu = std::max(
|
||||
THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu(max_num_radix_blocks, i,
|
||||
active_streams.count()));
|
||||
auto inputs_on_gpu =
|
||||
std::min((int)max_num_radix_blocks,
|
||||
std::max(THRESHOLD_MULTI_GPU,
|
||||
get_num_inputs_on_gpu(max_num_radix_blocks, i,
|
||||
active_streams.count())));
|
||||
InputTorus *d_array =
|
||||
(InputTorus *)cuda_malloc_with_size_tracking_async(
|
||||
inputs_on_gpu * (params.big_lwe_dimension + 1) *
|
||||
|
||||
@@ -65,9 +65,10 @@ void multi_gpu_alloc_lwe_async(CudaStreams streams, std::vector<Torus *> &dest,
|
||||
dest.resize(streams.count());
|
||||
for (uint i = 0; i < streams.count(); i++) {
|
||||
uint64_t size_tracker_on_gpu_i = 0;
|
||||
auto inputs_on_gpu =
|
||||
auto inputs_on_gpu = std::min(
|
||||
(int)num_inputs,
|
||||
std::max(THRESHOLD_MULTI_GPU,
|
||||
get_num_inputs_on_gpu(num_inputs, i, streams.count()));
|
||||
get_num_inputs_on_gpu(num_inputs, i, streams.count())));
|
||||
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
inputs_on_gpu * lwe_size * sizeof(Torus), streams.stream(i),
|
||||
streams.gpu_index(i), size_tracker_on_gpu_i, allocate_gpu_memory);
|
||||
@@ -97,9 +98,10 @@ void multi_gpu_alloc_lwe_many_lut_output_async(
|
||||
dest.resize(streams.count());
|
||||
for (uint i = 0; i < streams.count(); i++) {
|
||||
uint64_t size_tracker = 0;
|
||||
auto inputs_on_gpu =
|
||||
auto inputs_on_gpu = std::min(
|
||||
(int)num_inputs,
|
||||
std::max(THRESHOLD_MULTI_GPU,
|
||||
get_num_inputs_on_gpu(num_inputs, i, streams.count()));
|
||||
get_num_inputs_on_gpu(num_inputs, i, streams.count())));
|
||||
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
|
||||
num_many_lut * inputs_on_gpu * lwe_size * sizeof(Torus),
|
||||
streams.stream(i), streams.gpu_index(i), size_tracker,
|
||||
|
||||
Reference in New Issue
Block a user