diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh index dfd8f379a..502e29251 100644 --- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh @@ -301,7 +301,11 @@ __global__ void keyswitch_zero_output_with_output_indices( // in two parts, a constant part is calculated before the loop, and a variable // part is calculated inside the loop. This seems to help with the register // pressure as well. -template +// LevelCount template fully unrolls the level loop. +// Accumulation step is divided into even and odd to expose more instruction level +// parallelism. When LevelCount == 0 (default), the original runtime path is +// used unchanged. +template __global__ void keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, const Torus *__restrict__ lwe_array_in, @@ -318,13 +322,15 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, if (tid <= lwe_dimension_out) { - KSTorus local_lwe_out = 0; + // Accumulator is split into even and odd iterations. 
+ KSTorus local_lwe_out_even = 0; + KSTorus local_lwe_out_odd = 0; auto block_lwe_array_in = get_chunk( lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1); if (tid == lwe_dimension_out && threadIdx.y == 0) { if constexpr (std::is_same_v) { - local_lwe_out = -block_lwe_array_in[lwe_dimension_in]; + local_lwe_out_even = -block_lwe_array_in[lwe_dimension_in]; } else { auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in], sizeof(KSTorus) * 8, 1); @@ -337,7 +343,7 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, auto rounded_downscaled_body = (KSTorus)(new_body >> input_to_output_scaling_factor); - local_lwe_out = -rounded_downscaled_body; + local_lwe_out_even = -rounded_downscaled_body; } } const Torus mask_mod_b = (1ll << base_log) - 1ll; @@ -351,18 +357,41 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, for (int i = start_i; i < end_i; i++) { Torus state = init_decomposer_state(block_lwe_array_in[i], base_log, level_count); - uint32_t offset = i * level_count * (lwe_dimension_out + 1); -#pragma unroll 1 - for (int j = 0; j < level_count; j++) { - KSTorus decomposed = decompose_one(state, mask_mod_b, base_log); - local_lwe_out += - (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] * - decomposed; + // Fully unrolled level loop when LevelCount is known at compile time. + if constexpr (LevelCount > 0) { + uint32_t offset = i * LevelCount * (lwe_dimension_out + 1); +#pragma unroll + for (int j = 0; j < LevelCount; j++) { + KSTorus decomposed = + decompose_one(state, mask_mod_b, base_log); + if (j & 1) + local_lwe_out_odd += + (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] * + decomposed; + else + local_lwe_out_even += + (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] * + decomposed; + } + } else { + // Runtime fallback: original behaviour preserved exactly. 
+ uint32_t offset = i * level_count * (lwe_dimension_out + 1); +#pragma unroll 1 + for (int j = 0; j < level_count; j++) { + KSTorus decomposed = + decompose_one(state, mask_mod_b, base_log); + local_lwe_out_even += + (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] * + decomposed; + } } } - lwe_acc_out[shmem_index] = local_lwe_out; + if constexpr (LevelCount > 0) + lwe_acc_out[shmem_index] = local_lwe_out_even + local_lwe_out_odd; + else + lwe_acc_out[shmem_index] = local_lwe_out_even; } for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { @@ -405,9 +434,46 @@ __host__ void host_keyswitch_lwe_ciphertext_vector( dim3 grid(num_samples, num_blocks_per_sample, 1); dim3 threads(num_threads_x, num_threads_y, 1); - keyswitch<<>>( - lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk, - lwe_dimension_in, lwe_dimension_out, base_log, level_count); + // Dispatch to a statically-specialised kernel for common level_count values + // so the level loop is fully unrolled. 
+#define KS_LAUNCH(N) \ + keyswitch<<>>( \ + lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk, \ + lwe_dimension_in, lwe_dimension_out, base_log, level_count) + + switch (level_count) { + case 1: + KS_LAUNCH(1); + break; + case 2: + KS_LAUNCH(2); + break; + case 3: + KS_LAUNCH(3); + break; + case 4: + KS_LAUNCH(4); + break; + case 5: + KS_LAUNCH(5); + break; + case 6: + KS_LAUNCH(6); + break; + case 7: + KS_LAUNCH(7); + break; + case 8: + KS_LAUNCH(8); + break; + case 9: + KS_LAUNCH(9); + break; + default: + KS_LAUNCH(0); + break; + } +#undef KS_LAUNCH check_cuda_error(cudaGetLastError()); } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp index 30ff5c3ac..ce8f789fc 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp @@ -264,7 +264,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC) scratch_cuda_programmable_bootstrap_tbc( stream, gpu_index, (pbs_buffer **)&buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_level, - input_lwe_ciphertext_count, true, false); + input_lwe_ciphertext_count, true, PBS_MS_REDUCTION_T::NO_REDUCTION); uint32_t num_many_lut = 1; uint32_t lut_stride = 0; for (auto _ : st) {