fix(gpu): return to 64 regs in multi-bit pbs

This commit is contained in:
Guillermo Oyarzun
2025-12-22 18:16:15 +01:00
parent effb7ada6d
commit 92df46f8f2
3 changed files with 17 additions and 15 deletions

View File

@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
Torus *global_accumulator, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input,
int8_t *device_mem, uint64_t device_memory_size_per_block,
uint32_t num_many_lut, uint32_t lut_stride) {
@@ -321,8 +321,9 @@ __host__ void execute_cg_external_product_loop(
lwe_chunk_size * level_count * (glwe_dimension + 1) *
(glwe_dimension + 1) * (polynomial_size / 2);
uint64_t chunk_size = std::min(
lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);
uint32_t chunk_size = (uint32_t)(std::min(
lwe_chunk_size,
(uint64_t)(lwe_dimension / grouping_factor) - lwe_offset));
auto d_mem = buffer->d_mem_acc_cg;
auto keybundle_fft = buffer->keybundle_fft;

View File

@@ -373,7 +373,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const double2 *__restrict__ keybundle_array, Torus *global_accumulator,
double2 *join_buffer, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t iteration, uint64_t lwe_chunk_size,
uint32_t level_count, uint32_t iteration, uint32_t lwe_chunk_size,
int8_t *device_mem, uint64_t device_memory_size_per_block,
uint32_t num_many_lut, uint32_t lut_stride) {
// We use shared memory for the polynomials that are used often during the
@@ -790,7 +790,7 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
uint32_t lut_stride) {
cuda_set_device(gpu_index);
auto lwe_chunk_size = buffer->lwe_chunk_size;
uint32_t lwe_chunk_size = (uint32_t)(buffer->lwe_chunk_size);
uint64_t full_sm_accumulate_step_two =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
polynomial_size);

View File

@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
Torus *global_accumulator, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
uint32_t lwe_chunk_size, uint64_t keybundle_size_per_input,
int8_t *device_mem, uint64_t device_memory_size_per_block,
bool support_dsm, uint32_t num_many_lut, uint32_t lut_stride) {
@@ -205,10 +205,10 @@ __global__ void __launch_bounds__(params::degree / params::opt)
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ keybundle_array, double2 *join_buffer,
Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
uint32_t num_many_lut, uint32_t lut_stride) {
const double2 *__restrict__ keybundle_array, Torus *global_accumulator,
uint32_t lwe_dimension, uint32_t lwe_offset, uint32_t lwe_chunk_size,
uint64_t keybundle_size_per_input, uint32_t num_many_lut,
uint32_t lut_stride) {
constexpr uint32_t level_count = 1;
constexpr uint32_t grouping_factor = 4;
@@ -548,8 +548,9 @@ __host__ void execute_tbc_external_product_loop(
lwe_chunk_size * level_count * (glwe_dimension + 1) *
(glwe_dimension + 1) * (polynomial_size / 2);
uint64_t chunk_size = std::min(
lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);
uint32_t chunk_size = (uint32_t)(std::min(
lwe_chunk_size,
(uint64_t)(lwe_dimension / grouping_factor) - lwe_offset));
auto d_mem = buffer->d_mem_acc_tbc;
auto keybundle_fft = buffer->keybundle_fft;
@@ -624,9 +625,9 @@ __host__ void execute_tbc_external_product_loop(
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
Torus, params, FULLSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft,
global_accumulator, lwe_dimension, lwe_offset, chunk_size,
keybundle_size_per_input, num_many_lut, lut_stride));
lwe_array_in, lwe_input_indexes, keybundle_fft, global_accumulator,
lwe_dimension, lwe_offset, chunk_size, keybundle_size_per_input,
num_many_lut, lut_stride));
} else {
check_cuda_error(cudaLaunchKernelEx(
&config,