mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
chore(cuda): Reduces shared memory consumption in the amortized PBS and improves loop unrolling.
This commit is contained in:
@@ -124,11 +124,12 @@ __global__ void device_bootstrap_amortized(
|
||||
// The polynomial multiplications happens at the block level
|
||||
// and each thread handles two or more coefficients
|
||||
int pos = threadIdx.x;
|
||||
for (int j = 0; j < (glwe_dimension + 1) * params::opt / 2; j++) {
|
||||
res_fft[pos].x = 0;
|
||||
res_fft[pos].y = 0;
|
||||
pos += params::degree / params::opt;
|
||||
}
|
||||
for (int i = 0; i < (glwe_dimension + 1); i++)
|
||||
for (int j = 0; j < params::opt / 2; j++) {
|
||||
res_fft[pos].x = 0;
|
||||
res_fft[pos].y = 0;
|
||||
pos += params::degree / params::opt;
|
||||
}
|
||||
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count,
|
||||
accumulator_rotated, glwe_dimension + 1);
|
||||
@@ -136,20 +137,17 @@ __global__ void device_bootstrap_amortized(
|
||||
// coefficients so as to multiply each decomposed level with the
|
||||
// corresponding part of the bootstrapping key
|
||||
for (int level = level_count - 1; level >= 0; level--) {
|
||||
gadget.decompose_and_compress_next(accumulator_fft);
|
||||
for (int i = 0; i < (glwe_dimension + 1); i++) {
|
||||
auto accumulator_fft_slice = accumulator_fft + i * params::degree / 2;
|
||||
gadget.decompose_and_compress_next_polynomial(accumulator_fft, i);
|
||||
|
||||
// Switch to the FFT space
|
||||
NSMFFT_direct<HalfDegree<params>>(accumulator_fft_slice);
|
||||
synchronize_threads_in_block();
|
||||
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
|
||||
|
||||
// Get the bootstrapping key piece necessary for the multiplication
|
||||
// It is already in the Fourier domain
|
||||
auto bsk_slice = get_ith_mask_kth_block(bootstrapping_key, iteration, i,
|
||||
level, polynomial_size,
|
||||
glwe_dimension, level_count);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Perform the coefficient-wise product with the two pieces of
|
||||
// bootstrapping key
|
||||
@@ -157,7 +155,7 @@ __global__ void device_bootstrap_amortized(
|
||||
auto bsk_poly = bsk_slice + j * params::degree / 2;
|
||||
auto res_fft_poly = res_fft + j * params::degree / 2;
|
||||
polynomial_product_accumulate_in_fourier_domain<params, double2>(
|
||||
res_fft_poly, accumulator_fft_slice, bsk_poly);
|
||||
res_fft_poly, accumulator_fft, bsk_poly);
|
||||
}
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
@@ -219,9 +217,8 @@ get_buffer_size_full_sm_bootstrap_amortized(uint32_t polynomial_size,
|
||||
uint32_t glwe_dimension) {
|
||||
return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator
|
||||
sizeof(Torus) * polynomial_size *
|
||||
(glwe_dimension + 1) + // accumulator rotated
|
||||
sizeof(double2) * polynomial_size / 2 *
|
||||
(glwe_dimension + 1) + // accumulator fft
|
||||
(glwe_dimension + 1) + // accumulator rotated
|
||||
sizeof(double2) * polynomial_size / 2 + // accumulator fft
|
||||
sizeof(double2) * polynomial_size / 2 *
|
||||
(glwe_dimension + 1); // calculate buffer fft
|
||||
}
|
||||
@@ -230,8 +227,7 @@ template <typename Torus>
|
||||
__host__ __device__ int
|
||||
get_buffer_size_partial_sm_bootstrap_amortized(uint32_t polynomial_size,
|
||||
uint32_t glwe_dimension) {
|
||||
return sizeof(double2) * polynomial_size / 2 *
|
||||
(glwe_dimension + 1); // calculate buffer fft
|
||||
return sizeof(double2) * polynomial_size / 2; // calculate buffer fft
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -75,11 +75,12 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
// The polynomial multiplications happens at the block level
|
||||
// and each thread handles two or more coefficients
|
||||
int pos = threadIdx.x;
|
||||
for (int j = 0; j < (glwe_dim + 1) * params::opt / 2; j++) {
|
||||
res_fft[pos].x = 0;
|
||||
res_fft[pos].y = 0;
|
||||
pos += params::degree / params::opt;
|
||||
}
|
||||
for (int i = 0; i < (glwe_dim + 1); i++)
|
||||
for (int j = 0; j < params::opt / 2; j++) {
|
||||
res_fft[pos].x = 0;
|
||||
res_fft[pos].y = 0;
|
||||
pos += params::degree / params::opt;
|
||||
}
|
||||
|
||||
synchronize_threads_in_block();
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
|
||||
@@ -124,10 +125,11 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
Torus *mb = &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
|
||||
|
||||
int tid = threadIdx.x;
|
||||
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
|
||||
mb[tid] = m0[tid];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
for (int i = 0; i < (glwe_dim + 1); i++)
|
||||
for (int j = 0; j < params::opt; j++) {
|
||||
mb[tid] = m0[tid];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (glwe_dim + 1); i++) {
|
||||
auto res_fft_slice = res_fft + i * params::degree / 2;
|
||||
@@ -425,10 +427,11 @@ __global__ void device_blind_rotation_and_sample_extraction(
|
||||
// Input LUT
|
||||
auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
|
||||
int tid = threadIdx.x;
|
||||
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
|
||||
accumulator_c0[tid] = mi[tid];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
for (int i = 0; i < (glwe_dim + 1); i++)
|
||||
for (int j = 0; j < params::opt; j++) {
|
||||
accumulator_c0[tid] = mi[tid];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
int monomial_degree = 0;
|
||||
for (int i = mbr_size - 1; i >= 0; i--) {
|
||||
|
||||
Reference in New Issue
Block a user