feat(gpu): expose ILP in standard keyswitch

Guillermo Oyarzun
2026-04-15 13:19:33 +02:00
parent 96d230cf6f
commit 4e9dc1caee
2 changed files with 82 additions and 16 deletions

View File

@@ -301,7 +301,11 @@ __global__ void keyswitch_zero_output_with_output_indices(
// in two parts, a constant part is calculated before the loop, and a variable
// part is calculated inside the loop. This seems to help with the register
// pressure as well.
template <typename Torus, typename KSTorus>
// The LevelCount template parameter fully unrolls the level loop. The
// accumulation is split into even and odd accumulators to expose more
// instruction-level parallelism. When LevelCount == 0 (the default), the
// original runtime path is used unchanged.
template <typename Torus, typename KSTorus, int LevelCount = 0>
__global__ void
keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lwe_array_in,
@@ -318,13 +322,15 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
if (tid <= lwe_dimension_out) {
KSTorus local_lwe_out = 0;
// The accumulator is split into even and odd iterations.
KSTorus local_lwe_out_even = 0;
KSTorus local_lwe_out_odd = 0;
auto block_lwe_array_in = get_chunk(
lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
if (tid == lwe_dimension_out && threadIdx.y == 0) {
if constexpr (std::is_same_v<KSTorus, Torus>) {
local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
local_lwe_out_even = -block_lwe_array_in[lwe_dimension_in];
} else {
auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in],
sizeof(KSTorus) * 8, 1);
@@ -337,7 +343,7 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
auto rounded_downscaled_body =
(KSTorus)(new_body >> input_to_output_scaling_factor);
local_lwe_out = -rounded_downscaled_body;
local_lwe_out_even = -rounded_downscaled_body;
}
}
const Torus mask_mod_b = (1ll << base_log) - 1ll;
@@ -351,18 +357,41 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
for (int i = start_i; i < end_i; i++) {
Torus state =
init_decomposer_state(block_lwe_array_in[i], base_log, level_count);
uint32_t offset = i * level_count * (lwe_dimension_out + 1);
#pragma unroll 1
for (int j = 0; j < level_count; j++) {
KSTorus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
local_lwe_out +=
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
decomposed;
// Fully unrolled level loop when LevelCount is known at compile time.
if constexpr (LevelCount > 0) {
uint32_t offset = i * LevelCount * (lwe_dimension_out + 1);
#pragma unroll
for (int j = 0; j < LevelCount; j++) {
KSTorus decomposed =
decompose_one<Torus>(state, mask_mod_b, base_log);
if (j & 1)
local_lwe_out_odd +=
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
decomposed;
else
local_lwe_out_even +=
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
decomposed;
}
} else {
// Runtime fallback: original behaviour preserved exactly.
uint32_t offset = i * level_count * (lwe_dimension_out + 1);
#pragma unroll 1
for (int j = 0; j < level_count; j++) {
KSTorus decomposed =
decompose_one<Torus>(state, mask_mod_b, base_log);
local_lwe_out_even +=
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
decomposed;
}
}
}
lwe_acc_out[shmem_index] = local_lwe_out;
if constexpr (LevelCount > 0)
lwe_acc_out[shmem_index] = local_lwe_out_even + local_lwe_out_odd;
else
lwe_acc_out[shmem_index] = local_lwe_out_even;
}
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
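Aside: the speed-up from the even/odd split comes from shortening the serial dependency chain on the accumulator. With a single register, every multiply-add waits on the previous one; with two independent registers, alternating iterations can overlap and are only combined once after the loop. A minimal, self-contained sketch of the pattern (hypothetical helper, not the kernel above):

#include <cstdint>

// Dot product over a fully unrolled, compile-time level count, accumulated
// into two independent registers (even/odd iterations) instead of one.
template <int LevelCount>
__device__ inline uint64_t dot_two_accumulators(const uint64_t *ksk_col,
                                                const uint64_t *decomposed) {
  uint64_t acc_even = 0;
  uint64_t acc_odd = 0;
#pragma unroll
  for (int j = 0; j < LevelCount; j++) {
    // Iterations j and j + 1 target different registers, so their
    // multiply-adds do not have to wait on each other.
    if (j & 1)
      acc_odd += ksk_col[j] * decomposed[j];
    else
      acc_even += ksk_col[j] * decomposed[j];
  }
  // The two partial sums are combined exactly once, after the loop.
  return acc_even + acc_odd;
}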
@@ -405,9 +434,46 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
dim3 grid(num_samples, num_blocks_per_sample, 1);
dim3 threads(num_threads_x, num_threads_y, 1);
keyswitch<Torus, KSTorus><<<grid, threads, shared_mem, stream>>>(
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
lwe_dimension_in, lwe_dimension_out, base_log, level_count);
// Dispatch to a statically-specialised kernel for common level_count values
// so the level loop is fully unrolled.
#define KS_LAUNCH(N) \
keyswitch<Torus, KSTorus, N><<<grid, threads, shared_mem, stream>>>( \
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk, \
lwe_dimension_in, lwe_dimension_out, base_log, level_count)
switch (level_count) {
case 1:
KS_LAUNCH(1);
break;
case 2:
KS_LAUNCH(2);
break;
case 3:
KS_LAUNCH(3);
break;
case 4:
KS_LAUNCH(4);
break;
case 5:
KS_LAUNCH(5);
break;
case 6:
KS_LAUNCH(6);
break;
case 7:
KS_LAUNCH(7);
break;
case 8:
KS_LAUNCH(8);
break;
case 9:
KS_LAUNCH(9);
break;
default:
KS_LAUNCH(0);
break;
}
#undef KS_LAUNCH
check_cuda_error(cudaGetLastError());
}
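Aside: the switch plus the KS_LAUNCH macro is the usual runtime-to-compile-time dispatch idiom, with LevelCount == 0 acting as the "no specialisation" sentinel for uncommon level counts. A reduced, self-contained sketch of the same idiom with hypothetical kernel and launcher names (not the code above):

#include <cstdint>

// Kernel specialised on a compile-time level count; 0 means "use the runtime
// value", which keeps a single generic fallback instantiation.
template <int LevelCount>
__global__ void ks_kernel(uint64_t *out, int level_count) {
  const int n = LevelCount > 0 ? LevelCount : level_count;
  uint64_t acc = 0;
  for (int j = 0; j < n; j++)
    acc += (uint64_t)j;
  out[threadIdx.x] = acc;
}

// Host-side dispatch: map the runtime value onto a specialisation when one
// exists, otherwise fall back to the generic kernel.
inline void launch_ks(uint64_t *out, int level_count, dim3 grid, dim3 threads,
                      cudaStream_t stream) {
  switch (level_count) {
  case 1:
    ks_kernel<1><<<grid, threads, 0, stream>>>(out, level_count);
    break;
  case 2:
    ks_kernel<2><<<grid, threads, 0, stream>>>(out, level_count);
    break;
  default:
    ks_kernel<0><<<grid, threads, 0, stream>>>(out, level_count);
    break;
  }
}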

View File

@@ -264,7 +264,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC)
scratch_cuda_programmable_bootstrap_tbc<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, CLASSICAL> **)&buffer,
lwe_dimension, glwe_dimension, polynomial_size, pbs_level,
input_lwe_ciphertext_count, true, false);
input_lwe_ciphertext_count, true, PBS_MS_REDUCTION_T::NO_REDUCTION);
uint32_t num_many_lut = 1;
uint32_t lut_stride = 0;
for (auto _ : st) {