feat(gpu): extend specialized version to classical pbs

Guillermo Oyarzun
2025-10-19 11:36:44 +02:00
committed by Agnès Leroy
parent 79f1d22573
commit e12638dabe
7 changed files with 366 additions and 19 deletions

View File

@@ -51,6 +51,14 @@ uint64_t get_buffer_size_sm_dsm_plus_tbc_classic_programmable_bootstrap(
return sizeof(double2) * polynomial_size / 2; // tbc
}
template <typename Torus>
uint64_t get_buffer_size_full_sm_programmable_bootstrap_tbc_2_2_params(
uint32_t polynomial_size) {
// In the first implementation with 2-2 params, we need up to 5 polynomials in
// shared memory; we can optimize this later
return sizeof(Torus) * polynomial_size * 5;
}
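As a cross-check, the 5-polynomial figure matches the shared memory carving in device_programmable_bootstrap_tbc_2_2_params below: four double2 buffers of polynomial_size / 2 elements (ping, pong, twiddles, fft) plus one Torus accumulator. A minimal sketch of the arithmetic, assuming Torus = uint64_t and polynomial_size = 2048:

// Sketch: the "5 polynomials" breakdown for N = 2048, Torus = uint64_t.
// Four double2 buffers of N/2 elements occupy 4 * (N/2) * 16 = 4 * N * 8
// bytes, and the Torus accumulator adds N * 8 bytes, i.e. 80 KiB in total.
constexpr uint32_t N = 2048;
static_assert(4 * (N / 2) * sizeof(double2) + N * sizeof(uint64_t) ==
                  5 * N * sizeof(uint64_t),
              "full-sm 2_2 buffer is 5 polynomials (80 KiB for N = 2048)");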
template <typename Torus>
uint64_t
get_buffer_size_full_sm_programmable_bootstrap_cg(uint32_t polynomial_size) {

View File

@@ -203,6 +203,49 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
// the buffer in registers to avoid synchronizations and shared memory usage
}
// We need a different version for classical accumulation because the
// bootstrapping key is not stored in the same way as the keybundles. This is
// a suboptimal version because global reads are not coalesced, but the bsk is
// small and will hopefully be stored in cache. We can optimize this later.
template <typename G, class params, uint32_t polynomial_size,
uint32_t glwe_dimension, uint32_t level_count>
__device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params_classical(
double2 *fft, double2 *fft_regs, double2 *buffer_regs,
const double2 *__restrict__ bootstrapping_key, int iteration, G &group,
int this_block_rank) {
// Continues multiplying fft by every polynomial in that particular bsk level
// Each y-block accumulates in a different polynomial at each iteration
// We accumulate in registers to free shared memory
// In 2_2 params we only have one level
constexpr uint32_t level_id = 0;
// The first product doesn't require dsm
auto bsk_slice =
get_ith_mask_kth_block_2_2_params<double2, polynomial_size,
glwe_dimension, level_count, level_id>(
bootstrapping_key, iteration, this_block_rank);
auto bsk_poly = bsk_slice + blockIdx.y * polynomial_size / 2;
polynomial_product_accumulate_in_fourier_domain_2_2_params_classical<
params, double2, true>(buffer_regs, fft_regs, bsk_poly);
// Synchronize to ensure all blocks have written their fft results
group.sync();
constexpr uint32_t glwe_id = 1;
int idx = (glwe_id + this_block_rank) % (glwe_dimension + 1);
bsk_slice =
get_ith_mask_kth_block_2_2_params<double2, polynomial_size,
glwe_dimension, level_count, level_id>(
bootstrapping_key, iteration, idx);
bsk_poly = bsk_slice + blockIdx.y * polynomial_size / 2;
auto fft_slice =
get_join_buffer_element_tbc<G, level_id, glwe_dimension>(idx, group, fft);
polynomial_product_accumulate_in_fourier_domain_2_2_params_classical<
params, double2, false>(buffer_regs, fft_slice, bsk_poly);
// We don't need to synchronize here because we are going to use a different
// buffer than the input. In 2_2 params, level_count=1, so we can just return
// the buffer in registers to avoid synchronizations and shared memory usage
}
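The index rotation idx = (glwe_id + this_block_rank) % (glwe_dimension + 1) staggers which bsk polynomial each block of the cluster reads in each pass. A standalone illustration of the resulting schedule for glwe_dimension = 1, level_count = 1 (not part of the diff):

// Illustration only: block rank r first multiplies its own fft by bsk
// polynomial r, then the peer's fft (fetched via dsm) by polynomial
// (1 + r) % 2, so the two blocks in a cluster read disjoint bsk polynomials
// in each pass.
#include <cstdio>
int main() {
  for (int r = 0; r < 2; r++)            // this_block_rank
    for (int pass = 0; pass < 2; pass++) // glwe_id
      printf("rank %d, pass %d -> bsk polynomial %d\n", r, pass,
             (pass + r) % 2);
}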
template <typename InputTorus, typename OutputTorus>
void execute_pbs_async(CudaStreams streams,
const LweArrayVariant<OutputTorus> &lwe_array_out,

View File

@@ -43,9 +43,8 @@ bool has_support_to_cuda_programmable_bootstrap_tbc(
max_shared_memory);
case 2048:
return supports_thread_block_clusters_on_classic_programmable_bootstrap<
Torus, AmortizedDegree<2048>>(num_samples, glwe_dimension,
polynomial_size, level_count,
max_shared_memory);
Torus, Degree<2048>>(num_samples, glwe_dimension, polynomial_size,
level_count, max_shared_memory);
case 4096:
return supports_thread_block_clusters_on_classic_programmable_bootstrap<
Torus, AmortizedDegree<4096>>(num_samples, glwe_dimension,
@@ -96,7 +95,7 @@ uint64_t scratch_cuda_programmable_bootstrap_tbc(
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count, allocate_gpu_memory, noise_reduction_type);
case 2048:
return scratch_programmable_bootstrap_tbc<Torus, AmortizedDegree<2048>>(
return scratch_programmable_bootstrap_tbc<Torus, Degree<2048>>(
static_cast<cudaStream_t>(stream), gpu_index, pbs_buffer, lwe_dimension,
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count, allocate_gpu_memory, noise_reduction_type);
@@ -159,7 +158,7 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
num_many_lut, lut_stride);
break;
case 2048:
host_programmable_bootstrap_tbc<Torus, AmortizedDegree<2048>>(
host_programmable_bootstrap_tbc<Torus, Degree<2048>>(
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension,

View File

@@ -200,6 +200,188 @@ __global__ void device_programmable_bootstrap_tbc(
}
}
template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_programmable_bootstrap_tbc_2_2_params(
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lut_vector,
const Torus *__restrict__ lut_vector_indexes,
const Torus *__restrict__ lwe_array_in,
const Torus *__restrict__ lwe_input_indexes,
const double2 *__restrict__ bootstrapping_key, double2 *join_buffer,
uint32_t lwe_dimension, uint32_t num_many_lut, uint32_t lut_stride,
PBS_MS_REDUCTION_T noise_reduction_type) {
constexpr uint32_t level_count = 1;
constexpr uint32_t polynomial_size = 2048;
constexpr uint32_t glwe_dimension = 1;
constexpr uint32_t base_log = 23;
constexpr bool support_dsm = true;
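// (Added note: these constants pin the 2_2 parameter set this specialization
// targets; the host-side dispatch only selects this kernel when
// polynomial_size == 2048, level_count == 1, glwe_dimension == 1 and
// base_log == 23.)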
cluster_group cluster = this_cluster();
auto this_block_rank = cluster.block_index().y;
// We use shared memory for the polynomials that are used often during the
// bootstrap, since shared memory is kept in L1 cache and accessing it is
// much faster than global memory
extern __shared__ int8_t sharedmem[];
int8_t *selected_memory;
// When using 2_2 params and tbc we know everything fits in shared memory
// The first (polynomial_size/2) * sizeof(double2) bytes are reserved for
// the external product using distributed shared memory
selected_memory = sharedmem;
// We know that dsm is supported and we have enough memory
constexpr uint32_t num_buffers_ping_pong = 4;
selected_memory += sizeof(Torus) * polynomial_size * num_buffers_ping_pong;
double2 *accumulator_ping = (double2 *)sharedmem;
double2 *accumulator_pong = accumulator_ping + (polynomial_size / 2);
double2 *shared_twiddles = accumulator_pong + (polynomial_size / 2);
double2 *shared_fft = shared_twiddles + (polynomial_size / 2);
Torus *accumulator = (Torus *)selected_memory;
// Copying the twiddles from global to shared for extra performance
for (int k = 0; k < params::opt / 2; k++) {
shared_twiddles[threadIdx.x + k * (params::degree / params::opt)] =
negtwiddles[threadIdx.x + k * (params::degree / params::opt)];
}
// The x dimension of the grid determines on which ciphertext this block is
// operating, in the case of batch bootstraps
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.x] * (lwe_dimension + 1)];
const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
(glwe_dimension + 1)];
double2 *block_join_buffer =
&join_buffer[blockIdx.x * level_count * (glwe_dimension + 1) *
params::degree / 2];
// Since the space in L1 cache is small, we use the same memory location for
// the rotated accumulator and the fft accumulator, since we know that the
// rotated array is no longer in use by the time we perform the fft
// Put "b" in [0, 2N[
constexpr auto log_modulus = params::log2_degree + 1;
Torus b_hat = 0;
Torus correction = 0;
if (noise_reduction_type == PBS_MS_REDUCTION_T::CENTERED) {
correction = centered_binary_modulus_switch_body_correction_to_add(
block_lwe_array_in, lwe_dimension, log_modulus);
}
modulus_switch(block_lwe_array_in[lwe_dimension] + correction, b_hat,
log_modulus);
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator, &block_lut_vector[blockIdx.y * params::degree], b_hat,
false);
Torus temp_a_hat = 0;
for (int i = 0; i < lwe_dimension; i++) {
// We calculate the modulus switch of a warp's worth of elements
if (i % 32 == 0) {
modulus_switch(block_lwe_array_in[i + threadIdx.x % 32], temp_a_hat,
log_modulus);
}
// Each iteration we broadcast the corresponding modulus-switched value
// calculated above
Torus a_hat = __shfl_sync(0xFFFFFFFF, temp_a_hat, i % 32);
__syncthreads();
Torus reg_acc_rotated[params::opt];
// Perform ACC * (X^â - 1)
multiply_by_monomial_negacyclic_and_sub_polynomial_in_regs<
Torus, params::opt, params::degree / params::opt>(
accumulator, reg_acc_rotated, a_hat);
init_decomposer_state_inplace_2_2_params<Torus, params::opt,
params::degree / params::opt,
base_log, level_count>(
reg_acc_rotated);
auto accumulator_fft = i % 2 ? accumulator_ping : accumulator_pong;
double2 fft_out_regs[params::opt / 2];
// Decompose the accumulator. Each block gets one level of the
// decomposition, for the mask and the body (so block 0 will have the
// accumulator decomposed at level 0, 1 at 1, etc.)
decompose_and_compress_level_2_2_params<Torus, params, base_log>(
fft_out_regs, reg_acc_rotated);
NSMFFT_direct_2_2_params<HalfDegree<params>>(shared_fft, fft_out_regs,
shared_twiddles);
// We move the registers into shared memory to use dsm
int tid = threadIdx.x;
for (Index k = 0; k < params::opt / 4; k++) {
accumulator_fft[tid] = fft_out_regs[k];
accumulator_fft[tid + params::degree / 4] =
fft_out_regs[k + params::opt / 4];
tid = tid + params::degree / params::opt;
}
double2 buffer_regs[params::opt / 2];
// Perform G^-1(ACC) * GGSW -> GLWE
mul_ggsw_glwe_in_fourier_domain_2_2_params_classical<
cluster_group, params, polynomial_size, glwe_dimension, level_count>(
accumulator_fft, fft_out_regs, buffer_regs, bootstrapping_key, i,
cluster, this_block_rank);
NSMFFT_inverse_2_2_params<HalfDegree<params>>(shared_fft, buffer_regs,
shared_twiddles);
// We need a new add_to_torus variant that writes to shared memory
add_to_torus_2_2_params_using_shared<Torus, params>(buffer_regs,
accumulator);
}
__syncthreads();
auto block_lwe_array_out =
&lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];
if (blockIdx.z == 0) {
if (blockIdx.y < glwe_dimension) {
// Perform a sample extract. At this point, all blocks have the result,
// but we do the computation at block 0 to avoid waiting for extra blocks,
// in case they're not synchronized
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);
if (num_many_lut > 1) {
for (int i = 1; i < num_many_lut; i++) {
auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];
sample_extract_mask<Torus, params>(next_block_lwe_array_out,
accumulator, 1, i * lut_stride);
}
}
} else if (blockIdx.y == glwe_dimension) {
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator, 0);
if (num_many_lut > 1) {
for (int i = 1; i < num_many_lut; i++) {
auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];
sample_extract_body<Torus, params>(next_block_lwe_array_out,
accumulator, 0, i * lut_stride);
}
}
}
}
cluster.sync();
}
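The inner loop of the kernel batches the modulus switch at warp granularity: every 32 iterations each lane switches one mask coefficient, and __shfl_sync then hands lane (i % 32)'s value to the whole warp. A self-contained sketch of that pattern, where switch_one is a hypothetical stand-in for modulus_switch:

#include <cstdint>
// Hypothetical stand-in for modulus_switch: keep the top bits of the input.
__device__ uint64_t switch_one(uint64_t v) { return v >> 41; }
// Each lane pre-switches one of the next 32 inputs; every iteration the warp
// then broadcasts the value held by lane (i % 32). Assumes n is a multiple of
// 32 (or `in` is padded).
__global__ void warp_broadcast_demo(const uint64_t *in, uint64_t *out, int n) {
  uint64_t mine = 0;
  for (int i = 0; i < n; i++) {
    if (i % 32 == 0)
      mine = switch_one(in[i + threadIdx.x % 32]);
    uint64_t a_hat = __shfl_sync(0xFFFFFFFF, mine, i % 32);
    out[i] = a_hat; // all threads observe the same a_hat for iteration i
  }
}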
template <typename Torus, typename params>
__host__ uint64_t scratch_programmable_bootstrap_tbc(
cudaStream_t stream, uint32_t gpu_index,
@@ -340,14 +522,40 @@ __host__ void host_programmable_bootstrap_tbc(
partial_dm, supports_dsm, num_many_lut, lut_stride,
noise_reduction_type));
} else {
config.dynamicSmemBytes = full_sm + minimum_sm_tbc;
if (polynomial_size == 2048 && level_count == 1 && glwe_dimension == 1 &&
base_log == 23) {
uint64_t full_sm_2_2 =
get_buffer_size_full_sm_programmable_bootstrap_tbc_2_2_params<Torus>(
polynomial_size);
config.dynamicSmemBytes = full_sm_2_2;
check_cuda_error(cudaLaunchKernelEx(
&config, device_programmable_bootstrap_tbc<Torus, params, FULLSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,
lwe_dimension, polynomial_size, base_log, level_count, d_mem, 0,
supports_dsm, num_many_lut, lut_stride, buffer->noise_reduction_type));
check_cuda_error(cudaFuncSetAttribute(
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_2_2));
check_cuda_error(cudaFuncSetAttribute(
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared));
check_cuda_error(cudaFuncSetCacheConfig(
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
cudaFuncCachePreferShared));
check_cuda_error(cudaLaunchKernelEx(
&config,
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,
lwe_dimension, num_many_lut, lut_stride, noise_reduction_type));
} else {
config.dynamicSmemBytes = full_sm + minimum_sm_tbc;
check_cuda_error(cudaLaunchKernelEx(
&config, device_programmable_bootstrap_tbc<Torus, params, FULLSM>,
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft,
lwe_dimension, polynomial_size, base_log, level_count, d_mem, 0,
supports_dsm, num_many_lut, lut_stride,
buffer->noise_reduction_type));
}
}
}
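A note on the attribute calls in the new branch: 5 * 2048 * 8 bytes = 80 KiB of dynamic shared memory exceeds the default 48 KiB per-block limit, so the kernel must opt in through cudaFuncAttributeMaxDynamicSharedMemorySize before launch, and the carveout hint asks the driver to configure as much L1 as possible as shared memory. A minimal pre-flight check one could add, as an illustration only:

// Illustration only: verify the device can grant the 2_2 configuration's
// dynamic shared memory before selecting the specialized kernel.
int max_optin_smem = 0;
check_cuda_error(cudaDeviceGetAttribute(
    &max_optin_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu_index));
uint64_t needed =
    get_buffer_size_full_sm_programmable_bootstrap_tbc_2_2_params<uint64_t>(
        2048); // 80 KiB
bool fits_in_shared = (uint64_t)max_optin_smem >= needed;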
@@ -466,12 +674,22 @@ __host__ bool supports_thread_block_clusters_on_classic_programmable_bootstrap(
&cluster_size,
device_programmable_bootstrap_tbc<Torus, params, PARTIALSM>, &config));
} else {
check_cuda_error(cudaFuncSetAttribute(
device_programmable_bootstrap_tbc<Torus, params, FULLSM>,
cudaFuncAttributeNonPortableClusterSizeAllowed, false));
check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
&cluster_size, device_programmable_bootstrap_tbc<Torus, params, FULLSM>,
&config));
if (polynomial_size == 2048 && level_count == 1 && glwe_dimension == 1) {
check_cuda_error(cudaFuncSetAttribute(
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
cudaFuncAttributeNonPortableClusterSizeAllowed, false));
check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
device_programmable_bootstrap_tbc_2_2_params<Torus, params, FULLSM>,
&config));
} else {
check_cuda_error(cudaFuncSetAttribute(
device_programmable_bootstrap_tbc<Torus, params, FULLSM>,
cudaFuncAttributeNonPortableClusterSizeAllowed, false));
check_cuda_error(cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
device_programmable_bootstrap_tbc<Torus, params, FULLSM>, &config));
}
}
return cluster_size >= level_count * (glwe_dimension + 1);
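For the 2_2 parameters handled above (level_count = 1, glwe_dimension = 1), this predicate requires a cluster of at least 1 * (1 + 1) = 2 blocks, one per GLWE polynomial, comfortably below the portable cluster size of 8 that cudaFuncAttributeNonPortableClusterSizeAllowed = false enforces.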

View File

@@ -159,6 +159,33 @@ __device__ void multiply_by_monomial_negacyclic_and_sub_polynomial(
}
}
/*
 * Receives a single polynomial of type T and performs
 * result_acc = acc * (X^j - 1), i.e. acc * X^j - acc.
 * Takes a single buffer as input and returns a single rotated buffer.
 * result_acc must be in registers; acc must be in shared memory.
 */
template <typename T, int elems_per_thread, int block_size>
__device__ void multiply_by_monomial_negacyclic_and_sub_polynomial_in_regs(
T *acc, T *result_acc, uint32_t j) {
constexpr int degree = block_size * elems_per_thread;
int tid = threadIdx.x;
for (int i = 0; i < elems_per_thread; i++) {
if (j < degree) {
int x = tid - j + SEL(0, degree, tid < j);
result_acc[i] = SEL(1, -1, tid < j) * acc[x] - acc[tid];
} else {
int32_t jj = j - degree;
int x = tid - jj + SEL(0, degree, tid < jj);
result_acc[i] = SEL(-1, 1, tid < jj) * acc[x] - acc[tid];
}
tid += block_size;
}
}
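A worked instance of the negacyclic wrap the SEL branches implement: in Z[X]/(X^N + 1), a coefficient rotated past the end re-enters negated, and since X^N = -1, rotations by j >= N flip every sign. For N = 4 and acc = (a0, a1, a2, a3):

// acc * X^1       = (-a3,  a0,  a1,  a2)   wrapped coefficient re-enters negated
// acc * X^5       = ( a3, -a0, -a1, -a2)   j >= N: X^4 = -1 flips all signs
// acc * (X^1 - 1) = (-a3 - a0, a0 - a1, a1 - a2, a2 - a3)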
/*
* Receives num_poly concatenated polynomials of type T. For each performs a
* rounding to increase accuracy of the PBS. Calculates inplace.
@@ -252,6 +279,31 @@ __device__ void add_to_torus_2_2_params(double2 *m_values, Torus *result) {
tid = tid + params::degree / params::opt;
}
}
/**
 * In case of 2_2_classical PBS, this method accumulates the result,
 * which is stored in shared memory.
 */
template <typename Torus, class params>
__device__ void add_to_torus_2_2_params_using_shared(double2 *m_values,
Torus *result) {
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
double double_real = m_values[i].x;
double double_imag = m_values[i].y;
Torus torus_real = 0;
typecast_double_round_to_torus<Torus>(double_real, torus_real);
Torus torus_imag = 0;
typecast_double_round_to_torus<Torus>(double_imag, torus_imag);
result[tid] += torus_real;
result[tid + params::degree / 2] += torus_imag;
tid = tid + params::degree / params::opt;
}
}
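The split writes above mirror the packing used by the negacyclic FFT: a degree-N polynomial travels through the transform as N/2 complex values, coefficient k in the real part and coefficient k + N/2 in the imaginary part (up to the twisting applied inside the transform), hence the two strided writes per thread:

// Sketch of the fold (illustration): z_k = c_k + i * c_{k+N/2} for k in
// [0, N/2), so after the inverse FFT, .x accumulates into result[tid] and
// .y into result[tid + params::degree / 2].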
/**
* In case of classical PBS, this method should accumulate the result.

View File

@@ -84,6 +84,32 @@ __device__ void polynomial_product_accumulate_in_fourier_domain_2_2_params(
}
}
// Computes the same as above, but adapted to the bootstrapping key storage
// layout. There is room for optimization here, but we can do it later
// if the classical version is needed in production.
// Note: when init_accumulator is true, `first` is a per-thread register array
// (indexed by the loop counter); otherwise it is a shared memory slice
// fetched via dsm (indexed by tid), as in the caller above.
template <class params, typename T, bool init_accumulator>
__device__ void
polynomial_product_accumulate_in_fourier_domain_2_2_params_classical(
T *__restrict__ result, T *__restrict__ first,
const T *__restrict__ second) {
int tid = threadIdx.x;
if constexpr (init_accumulator) {
for (int i = 0; i < params::opt / 4; i++) {
result[i] = first[i] * second[2 * tid];
result[i + params::opt / 4] =
first[i + params::opt / 4] * second[2 * tid + 1];
tid += (params::degree / params::opt);
}
} else {
for (int i = 0; i < params::opt / 4; i++) {
result[i] += first[tid] * second[2 * tid];
result[i + params::opt / 4] +=
first[tid + params::degree / 4] * second[2 * tid + 1];
tid += params::degree / params::opt;
}
}
}
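Each element-wise product above is a complex multiply(-accumulate) in the Fourier domain, assuming double2's operator* is the overloaded complex product as in the other Fourier-domain routines; in scalar form:

// (a + b*i) * (c + d*i) = (a*c - b*d) + (a*d + b*c)*i
// i.e. result[i] += first[tid] * second[2 * tid] expands to
//   result[i].x += first[tid].x * second[2 * tid].x
//                - first[tid].y * second[2 * tid].y;
//   result[i].y += first[tid].x * second[2 * tid].y
//                + first[tid].y * second[2 * tid].x;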
// Computes result += first * second
// If init_accumulator is set, assumes that result was not initialized and does
// that with the outcome of first * second

View File

@@ -851,7 +851,8 @@ fn main() {
#[cfg(feature = "gpu")]
fn main() {
let params = benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let params =
benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let config = tfhe::ConfigBuilder::with_custom_parameters(params).build();
let cks = ClientKey::generate(config);