refactor(cuda): Implement support for k>1 (glwe_dimension > 1) in the Wop-PBS.

This commit is contained in:
Pedro Alves
2023-02-24 15:48:27 -03:00
committed by Agnès Leroy
parent b0a362af6d
commit 400786f3f9
4 changed files with 51 additions and 44 deletions

View File

@@ -135,15 +135,17 @@ get_buffer_size_extract_bits(uint32_t glwe_dimension, uint32_t lwe_dimension,
uint32_t polynomial_size,
uint32_t number_of_inputs) {
return sizeof(Torus) * number_of_inputs // lut_vector_indexes
+ ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus) // lwe_array_in_buffer
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus) // lwe_array_in_shifted_buffer
+ (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus); // lwe_array_out_pbs_buffer
int buffer_size =
sizeof(Torus) * number_of_inputs // lut_vector_indexes
+ ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus) // lwe_array_in_buffer
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus) // lwe_array_in_shifted_buffer
+ (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer
+ (glwe_dimension * polynomial_size + 1) *
sizeof(Torus); // lwe_array_out_pbs_buffer
return buffer_size + buffer_size % sizeof(double2);
}
template <typename Torus, typename STorus, typename params>

View File

@@ -43,35 +43,45 @@ public:
synchronize_threads_in_block();
}
// Decomposes all polynomials at once
__device__ void decompose_and_compress_next(double2 *result) {
current_level -= 1;
for (int j = 0; j < num_poly; j++) {
int tid = threadIdx.x;
auto state_slice = state + j * params::degree;
auto result_slice = result + j * params::degree / 2;
for (int i = 0; i < params::opt / 2; i++) {
T res_re = state_slice[tid] & mask_mod_b;
T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
state_slice[tid] >>= base_log;
state_slice[tid + params::degree / 2] >>= base_log;
T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
T carry_im =
((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
carry_re >>= (base_log - 1);
carry_im >>= (base_log - 1);
state_slice[tid] += carry_re;
state_slice[tid + params::degree / 2] += carry_im;
res_re -= carry_re << base_log;
res_im -= carry_im << base_log;
decompose_and_compress_next_polynomial(result_slice, j);
}
}
result_slice[tid].x = (int32_t)res_re;
result_slice[tid].y = (int32_t)res_im;
// Decomposes a single polynomial
__device__ void decompose_and_compress_next_polynomial(double2 *result,
int j) {
if (j == 0)
current_level -= 1;
tid += params::degree / params::opt;
}
int tid = threadIdx.x;
auto state_slice = state + j * params::degree;
for (int i = 0; i < params::opt / 2; i++) {
T res_re = state_slice[tid] & mask_mod_b;
T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
state_slice[tid] >>= base_log;
state_slice[tid + params::degree / 2] >>= base_log;
T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
T carry_im =
((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
carry_re >>= (base_log - 1);
carry_im >>= (base_log - 1);
state_slice[tid] += carry_re;
state_slice[tid + params::degree / 2] += carry_im;
res_re -= carry_re << base_log;
res_im -= carry_im << base_log;
result[tid].x = (int32_t)res_re;
result[tid].y = (int32_t)res_im;
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
}
__device__ void decompose_and_compress_level(double2 *result, int level) {
for (int i = 0; i < level_count - level; i++)
decompose_and_compress_next(result);

View File

@@ -84,31 +84,29 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
synchronize_threads_in_block();
GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
glwe_dim + 1);
// Subtract each glwe operand, decompose the resulting
// polynomial coefficients to multiply each decomposed level
// with the corresponding part of the LUT
for (int level = level_count - 1; level >= 0; level--) {
// Decomposition
gadget.decompose_and_compress_next(glwe_fft);
synchronize_threads_in_block();
for (int i = 0; i < (glwe_dim + 1); i++) {
auto glwe_fft_slice = glwe_fft + i * params::degree / 2;
gadget.decompose_and_compress_next_polynomial(glwe_fft, i);
// First, perform the polynomial multiplication
NSMFFT_direct<HalfDegree<params>>(glwe_fft_slice);
NSMFFT_direct<HalfDegree<params>>(glwe_fft);
// External product and accumulate
// Get the piece necessary for the multiplication
auto bsk_slice = get_ith_mask_kth_block(
ggsw_in, ggsw_idx, i, level, polynomial_size, glwe_dim, level_count);
synchronize_threads_in_block();
// Perform the coefficient-wise product
for (int j = 0; j < (glwe_dim + 1); j++) {
auto bsk_poly = bsk_slice + j * params::degree / 2;
auto res_fft_poly = res_fft + j * params::degree / 2;
polynomial_product_accumulate_in_fourier_domain<params, double2>(
res_fft_poly, glwe_fft_slice, bsk_poly);
res_fft_poly, glwe_fft, bsk_poly);
}
}
synchronize_threads_in_block();
@@ -215,9 +213,8 @@ get_memory_needed_per_block_cmux_tree(uint32_t glwe_dimension,
uint32_t polynomial_size) {
return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // glwe_sub
sizeof(double2) * polynomial_size / 2 *
(glwe_dimension + 1) + // res_fft
sizeof(double2) * polynomial_size / 2 *
(glwe_dimension + 1); // glwe_fft
(glwe_dimension + 1) + // res_fft
sizeof(double2) * polynomial_size / 2; // glwe_fft
}
template <typename Torus>
@@ -538,8 +535,6 @@ __host__ void host_blind_rotate_and_sample_extraction(
uint32_t level_count, uint32_t max_shared_memory) {
cudaSetDevice(gpu_index);
assert(glwe_dimension ==
1); // For larger k we will need to adjust the mask size
auto stream = static_cast<cudaStream_t *>(v_stream);
int memory_needed_per_block =

View File

@@ -305,9 +305,9 @@ __host__ void host_wop_pbs(
host_extract_bits<Torus, params>(
v_stream, gpu_index, (Torus *)lwe_array_out_bit_extract, lwe_array_in,
bit_extract_buffer, ksk, fourier_bsk, number_of_bits_to_extract,
delta_log, polynomial_size, lwe_dimension, glwe_dimension,
polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk,
level_count_ksk, number_of_inputs, max_shared_memory);
delta_log, glwe_dimension * polynomial_size, lwe_dimension,
glwe_dimension, polynomial_size, base_log_bsk, level_count_bsk,
base_log_ksk, level_count_ksk, number_of_inputs, max_shared_memory);
check_cuda_error(cudaGetLastError());
int8_t *cbs_vp_buffer =