fix(gpu): clean unused variables in specialized classical pbs

This commit is contained in:
Guillermo Oyarzun
2026-02-05 14:59:15 +01:00
committed by Agnès Leroy
parent 2355cf4d89
commit b218c98194

View File

@@ -212,8 +212,8 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ bootstrapping_key, double2 *join_buffer,
uint32_t lwe_dimension, uint32_t num_many_lut, uint32_t lut_stride,
const double2 *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
uint32_t num_many_lut, uint32_t lut_stride,
PBS_MS_REDUCTION_T noise_reduction_type) {
constexpr uint32_t level_count = 1;
@@ -259,9 +259,6 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
&lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
(glwe_dimension + 1)];
double2 *block_join_buffer =
&join_buffer[blockIdx.x * level_count * (glwe_dimension + 1) *
params::degree / 2];
// Since the space in L1 cache is small, we use the same memory location for
// the rotated accumulator and the fft accumulator, since we know that the
// rotated array is not in use anymore by the time we perform the fft
@@ -284,13 +281,14 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
Torus temp_a_hat = 0;
for (int i = 0; i < lwe_dimension; i++) {
constexpr int WARP_SIZE = 32;
// We calculate the modulus switch of a warp size of elements
if (i % 32 == 0 && (i + threadIdx.x % 32) < lwe_dimension) {
modulus_switch(block_lwe_array_in[i + threadIdx.x % 32], temp_a_hat,
log_modulus);
if (i % WARP_SIZE == 0 && (i + threadIdx.x % WARP_SIZE) < lwe_dimension) {
modulus_switch(block_lwe_array_in[i + threadIdx.x % WARP_SIZE],
temp_a_hat, log_modulus);
}
// On each iteration we broadcast the corresponding modulus-switched value
// that was previously calculated
Torus a_hat = __shfl_sync(0xFFFFFFFF, temp_a_hat, i % 32);
Torus a_hat = __shfl_sync(0xFFFFFFFF, temp_a_hat, i % WARP_SIZE);
__syncthreads();
Torus reg_acc_rotated[params::opt];
@@ -556,8 +554,8 @@ __host__ void host_programmable_bootstrap_tbc(
&config,
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,
lwe_dimension, num_many_lut, lut_stride, noise_reduction_type));
lwe_array_in, lwe_input_indexes, bootstrapping_key, lwe_dimension,
num_many_lut, lut_stride, noise_reduction_type));
} else {
config.dynamicSmemBytes = full_sm + minimum_sm_tbc;