From 92df46f8f2a7c17bd6a4923f0c30f008faed6758 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Mon, 22 Dec 2025 18:16:15 +0100 Subject: [PATCH] fix(gpu): return to 64 regs in multi-bit pbs --- .../programmable_bootstrap_cg_multibit.cuh | 7 ++++--- .../pbs/programmable_bootstrap_multibit.cuh | 4 ++-- .../programmable_bootstrap_tbc_multibit.cuh | 21 ++++++++++--------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh index db0c27959..7ea811a2f 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh @@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt) Torus *global_accumulator, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset, - uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input, + uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input, int8_t *device_mem, uint64_t device_memory_size_per_block, uint32_t num_many_lut, uint32_t lut_stride) { @@ -321,8 +321,9 @@ __host__ void execute_cg_external_product_loop( lwe_chunk_size * level_count * (glwe_dimension + 1) * (glwe_dimension + 1) * (polynomial_size / 2); - uint64_t chunk_size = std::min( - lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset); + uint32_t chunk_size = (uint32_t)(std::min( + lwe_chunk_size, + (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset)); auto d_mem = buffer->d_mem_acc_cg; auto keybundle_fft = buffer->keybundle_fft; diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index 64c04050c..8f79df734 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -373,7 +373,7 @@ __global__ void __launch_bounds__(params::degree / params::opt) Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, const double2 *__restrict__ keybundle_array, Torus *global_accumulator, double2 *join_buffer, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t iteration, uint64_t lwe_chunk_size, + uint32_t level_count, uint32_t iteration, uint32_t lwe_chunk_size, int8_t *device_mem, uint64_t device_memory_size_per_block, uint32_t num_many_lut, uint32_t lut_stride) { // We use shared memory for the polynomials that are used often during the @@ -790,7 +790,7 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out, uint32_t lut_stride) { cuda_set_device(gpu_index); - auto lwe_chunk_size = buffer->lwe_chunk_size; + uint32_t lwe_chunk_size = (uint32_t)(buffer->lwe_chunk_size); uint64_t full_sm_accumulate_step_two = get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( polynomial_size); diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh index 9891dfff8..a2f0f3001 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh @@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt) Torus *global_accumulator, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset, - uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input, + uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input, int8_t *device_mem, uint64_t device_memory_size_per_block, bool support_dsm, uint32_t num_many_lut, uint32_t lut_stride) { @@ -205,10 +205,10 @@ __global__ void __launch_bounds__(params::degree / params::opt) const Torus *__restrict__ lut_vector_indexes, const Torus *__restrict__ lwe_array_in, const Torus *__restrict__ lwe_input_indexes, - const double2 *__restrict__ keybundle_array, double2 *join_buffer, - Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset, - uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input, - uint32_t num_many_lut, uint32_t lut_stride) { + const double2 *__restrict__ keybundle_array, Torus *global_accumulator, + uint32_t lwe_dimension, uint32_t lwe_offset, uint32_t lwe_chunk_size, + uint64_t keybundle_size_per_input, uint32_t num_many_lut, + uint32_t lut_stride) { constexpr uint32_t level_count = 1; constexpr uint32_t grouping_factor = 4; @@ -548,8 +548,9 @@ __host__ void execute_tbc_external_product_loop( lwe_chunk_size * level_count * (glwe_dimension + 1) * (glwe_dimension + 1) * (polynomial_size / 2); - uint64_t chunk_size = std::min( - lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset); + uint32_t chunk_size = (uint32_t)(std::min( + lwe_chunk_size, + (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset)); auto d_mem = buffer->d_mem_acc_tbc; auto keybundle_fft = buffer->keybundle_fft; @@ -624,9 +625,9 @@ __host__ void execute_tbc_external_product_loop( device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params< Torus, params, FULLSM>, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, - lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft, - global_accumulator, lwe_dimension, lwe_offset, chunk_size, - keybundle_size_per_input, num_many_lut, lut_stride)); + lwe_array_in, lwe_input_indexes, keybundle_fft, global_accumulator, + lwe_dimension, lwe_offset, chunk_size, keybundle_size_per_input, + num_many_lut, lut_stride)); } else { check_cuda_error(cudaLaunchKernelEx( &config,