mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
fix(gpu): clean unused variables in specialized classical pbs
This commit is contained in:
committed by
Agnès Leroy
parent
2355cf4d89
commit
b218c98194
@@ -212,8 +212,8 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
|
||||
const Torus *__restrict__ lut_vector_indexes,
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
const Torus *__restrict__ lwe_input_indexes,
|
||||
const double2 *__restrict__ bootstrapping_key, double2 *join_buffer,
|
||||
uint32_t lwe_dimension, uint32_t num_many_lut, uint32_t lut_stride,
|
||||
const double2 *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t num_many_lut, uint32_t lut_stride,
|
||||
PBS_MS_REDUCTION_T noise_reduction_type) {
|
||||
|
||||
constexpr uint32_t level_count = 1;
|
||||
@@ -259,9 +259,6 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
|
||||
&lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
|
||||
(glwe_dimension + 1)];
|
||||
|
||||
double2 *block_join_buffer =
|
||||
&join_buffer[blockIdx.x * level_count * (glwe_dimension + 1) *
|
||||
params::degree / 2];
|
||||
// Since the space in L1 cache is small, we use the same memory location for
|
||||
// the rotated accumulator and the fft accumulator, since we know that the
|
||||
// rotated array is not in use anymore by the time we perform the fft
|
||||
@@ -284,13 +281,14 @@ __global__ void device_programmable_bootstrap_tbc_2_2_params(
|
||||
Torus temp_a_hat = 0;
|
||||
for (int i = 0; i < lwe_dimension; i++) {
|
||||
|
||||
constexpr int WARP_SIZE = 32;
|
||||
// We calculate the modulus switch for one warp-sized batch of elements
|
||||
if (i % 32 == 0 && (i + threadIdx.x % 32) < lwe_dimension) {
|
||||
modulus_switch(block_lwe_array_in[i + threadIdx.x % 32], temp_a_hat,
|
||||
log_modulus);
|
||||
if (i % WARP_SIZE == 0 && (i + threadIdx.x % WARP_SIZE) < lwe_dimension) {
|
||||
modulus_switch(block_lwe_array_in[i + threadIdx.x % WARP_SIZE],
|
||||
temp_a_hat, log_modulus);
|
||||
}
|
||||
// On each iteration we broadcast the corresponding modulus-switched value calculated previously
|
||||
Torus a_hat = __shfl_sync(0xFFFFFFFF, temp_a_hat, i % 32);
|
||||
Torus a_hat = __shfl_sync(0xFFFFFFFF, temp_a_hat, i % WARP_SIZE);
|
||||
|
||||
__syncthreads();
|
||||
Torus reg_acc_rotated[params::opt];
|
||||
@@ -556,8 +554,8 @@ __host__ void host_programmable_bootstrap_tbc(
|
||||
&config,
|
||||
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
|
||||
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
|
||||
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,
|
||||
lwe_dimension, num_many_lut, lut_stride, noise_reduction_type));
|
||||
lwe_array_in, lwe_input_indexes, bootstrapping_key, lwe_dimension,
|
||||
num_many_lut, lut_stride, noise_reduction_type));
|
||||
} else {
|
||||
config.dynamicSmemBytes = full_sm + minimum_sm_tbc;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user