mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
feat(gpu): expose ILP in standard keyswitch
This commit is contained in:
@@ -301,7 +301,11 @@ __global__ void keyswitch_zero_output_with_output_indices(
|
||||
// in two parts, a constant part is calculated before the loop, and a variable
|
||||
// part is calculated inside the loop. This seems to help with the register
|
||||
// pressure as well.
|
||||
template <typename Torus, typename KSTorus>
|
||||
// LevelCount template fully unrolls the level loop.
|
||||
// Accumulation step is divided in even and odd to expose more instruction level
|
||||
// parallelism. When LevelCount == 0 (default), the original runtime path is
|
||||
// used unchanged.
|
||||
template <typename Torus, typename KSTorus, int LevelCount = 0>
|
||||
__global__ void
|
||||
keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
@@ -318,13 +322,15 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
|
||||
if (tid <= lwe_dimension_out) {
|
||||
|
||||
KSTorus local_lwe_out = 0;
|
||||
// Accumulator is splited into even and odd iterations.
|
||||
KSTorus local_lwe_out_even = 0;
|
||||
KSTorus local_lwe_out_odd = 0;
|
||||
auto block_lwe_array_in = get_chunk(
|
||||
lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
|
||||
|
||||
if (tid == lwe_dimension_out && threadIdx.y == 0) {
|
||||
if constexpr (std::is_same_v<KSTorus, Torus>) {
|
||||
local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
|
||||
local_lwe_out_even = -block_lwe_array_in[lwe_dimension_in];
|
||||
} else {
|
||||
auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in],
|
||||
sizeof(KSTorus) * 8, 1);
|
||||
@@ -337,7 +343,7 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
auto rounded_downscaled_body =
|
||||
(KSTorus)(new_body >> input_to_output_scaling_factor);
|
||||
|
||||
local_lwe_out = -rounded_downscaled_body;
|
||||
local_lwe_out_even = -rounded_downscaled_body;
|
||||
}
|
||||
}
|
||||
const Torus mask_mod_b = (1ll << base_log) - 1ll;
|
||||
@@ -351,18 +357,41 @@ keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
for (int i = start_i; i < end_i; i++) {
|
||||
Torus state =
|
||||
init_decomposer_state(block_lwe_array_in[i], base_log, level_count);
|
||||
uint32_t offset = i * level_count * (lwe_dimension_out + 1);
|
||||
#pragma unroll 1
|
||||
for (int j = 0; j < level_count; j++) {
|
||||
|
||||
KSTorus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
|
||||
local_lwe_out +=
|
||||
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
|
||||
decomposed;
|
||||
// Fully unrolled level loop in case LevelCount is precompiled.
|
||||
if constexpr (LevelCount > 0) {
|
||||
uint32_t offset = i * LevelCount * (lwe_dimension_out + 1);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < LevelCount; j++) {
|
||||
KSTorus decomposed =
|
||||
decompose_one<Torus>(state, mask_mod_b, base_log);
|
||||
if (j & 1)
|
||||
local_lwe_out_odd +=
|
||||
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
|
||||
decomposed;
|
||||
else
|
||||
local_lwe_out_even +=
|
||||
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
|
||||
decomposed;
|
||||
}
|
||||
} else {
|
||||
// Runtime fallback: original behaviour preserved exactly.
|
||||
uint32_t offset = i * level_count * (lwe_dimension_out + 1);
|
||||
#pragma unroll 1
|
||||
for (int j = 0; j < level_count; j++) {
|
||||
KSTorus decomposed =
|
||||
decompose_one<Torus>(state, mask_mod_b, base_log);
|
||||
local_lwe_out_even +=
|
||||
(KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
|
||||
decomposed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lwe_acc_out[shmem_index] = local_lwe_out;
|
||||
if constexpr (LevelCount > 0)
|
||||
lwe_acc_out[shmem_index] = local_lwe_out_even + local_lwe_out_odd;
|
||||
else
|
||||
lwe_acc_out[shmem_index] = local_lwe_out_even;
|
||||
}
|
||||
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
@@ -405,9 +434,46 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
|
||||
dim3 grid(num_samples, num_blocks_per_sample, 1);
|
||||
dim3 threads(num_threads_x, num_threads_y, 1);
|
||||
|
||||
keyswitch<Torus, KSTorus><<<grid, threads, shared_mem, stream>>>(
|
||||
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count);
|
||||
// Dispatch to a statically-specialised kernel for common level_count values
|
||||
// so the level loop is fully unrolled.
|
||||
#define KS_LAUNCH(N) \
|
||||
keyswitch<Torus, KSTorus, N><<<grid, threads, shared_mem, stream>>>( \
|
||||
lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk, \
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count)
|
||||
|
||||
switch (level_count) {
|
||||
case 1:
|
||||
KS_LAUNCH(1);
|
||||
break;
|
||||
case 2:
|
||||
KS_LAUNCH(2);
|
||||
break;
|
||||
case 3:
|
||||
KS_LAUNCH(3);
|
||||
break;
|
||||
case 4:
|
||||
KS_LAUNCH(4);
|
||||
break;
|
||||
case 5:
|
||||
KS_LAUNCH(5);
|
||||
break;
|
||||
case 6:
|
||||
KS_LAUNCH(6);
|
||||
break;
|
||||
case 7:
|
||||
KS_LAUNCH(7);
|
||||
break;
|
||||
case 8:
|
||||
KS_LAUNCH(8);
|
||||
break;
|
||||
case 9:
|
||||
KS_LAUNCH(9);
|
||||
break;
|
||||
default:
|
||||
KS_LAUNCH(0);
|
||||
break;
|
||||
}
|
||||
#undef KS_LAUNCH
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -264,7 +264,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC)
|
||||
scratch_cuda_programmable_bootstrap_tbc<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, CLASSICAL> **)&buffer,
|
||||
lwe_dimension, glwe_dimension, polynomial_size, pbs_level,
|
||||
input_lwe_ciphertext_count, true, false);
|
||||
input_lwe_ciphertext_count, true, PBS_MS_REDUCTION_T::NO_REDUCTION);
|
||||
uint32_t num_many_lut = 1;
|
||||
uint32_t lut_stride = 0;
|
||||
for (auto _ : st) {
|
||||
|
||||
Reference in New Issue
Block a user