mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
refactor(cuda): Implements support for k>1 on the Wop-PBS.
This commit is contained in:
@@ -135,15 +135,17 @@ get_buffer_size_extract_bits(uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t polynomial_size,
|
||||
uint32_t number_of_inputs) {
|
||||
|
||||
return sizeof(Torus) * number_of_inputs // lut_vector_indexes
|
||||
+ ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) // lwe_array_in_buffer
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) // lwe_array_in_shifted_buffer
|
||||
+ (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus); // lwe_array_out_pbs_buffer
|
||||
int buffer_size =
|
||||
sizeof(Torus) * number_of_inputs // lut_vector_indexes
|
||||
+ ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) // lwe_array_in_buffer
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) // lwe_array_in_shifted_buffer
|
||||
+ (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer
|
||||
+ (glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus); // lwe_array_out_pbs_buffer
|
||||
return buffer_size + buffer_size % sizeof(double2);
|
||||
}
|
||||
|
||||
template <typename Torus, typename STorus, typename params>
|
||||
|
||||
@@ -43,35 +43,45 @@ public:
|
||||
synchronize_threads_in_block();
|
||||
}
|
||||
|
||||
// Decomposes all polynomials at once
|
||||
__device__ void decompose_and_compress_next(double2 *result) {
|
||||
current_level -= 1;
|
||||
for (int j = 0; j < num_poly; j++) {
|
||||
int tid = threadIdx.x;
|
||||
auto state_slice = state + j * params::degree;
|
||||
auto result_slice = result + j * params::degree / 2;
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
T res_re = state_slice[tid] & mask_mod_b;
|
||||
T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
|
||||
state_slice[tid] >>= base_log;
|
||||
state_slice[tid + params::degree / 2] >>= base_log;
|
||||
T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
|
||||
T carry_im =
|
||||
((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
|
||||
carry_re >>= (base_log - 1);
|
||||
carry_im >>= (base_log - 1);
|
||||
state_slice[tid] += carry_re;
|
||||
state_slice[tid + params::degree / 2] += carry_im;
|
||||
res_re -= carry_re << base_log;
|
||||
res_im -= carry_im << base_log;
|
||||
decompose_and_compress_next_polynomial(result_slice, j);
|
||||
}
|
||||
}
|
||||
|
||||
result_slice[tid].x = (int32_t)res_re;
|
||||
result_slice[tid].y = (int32_t)res_im;
|
||||
// Decomposes a single polynomial
|
||||
__device__ void decompose_and_compress_next_polynomial(double2 *result,
|
||||
int j) {
|
||||
if (j == 0)
|
||||
current_level -= 1;
|
||||
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
int tid = threadIdx.x;
|
||||
auto state_slice = state + j * params::degree;
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
T res_re = state_slice[tid] & mask_mod_b;
|
||||
T res_im = state_slice[tid + params::degree / 2] & mask_mod_b;
|
||||
state_slice[tid] >>= base_log;
|
||||
state_slice[tid + params::degree / 2] >>= base_log;
|
||||
T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re;
|
||||
T carry_im =
|
||||
((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im;
|
||||
carry_re >>= (base_log - 1);
|
||||
carry_im >>= (base_log - 1);
|
||||
state_slice[tid] += carry_re;
|
||||
state_slice[tid + params::degree / 2] += carry_im;
|
||||
res_re -= carry_re << base_log;
|
||||
res_im -= carry_im << base_log;
|
||||
|
||||
result[tid].x = (int32_t)res_re;
|
||||
result[tid].y = (int32_t)res_im;
|
||||
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
}
|
||||
|
||||
__device__ void decompose_and_compress_level(double2 *result, int level) {
|
||||
for (int i = 0; i < level_count - level; i++)
|
||||
decompose_and_compress_next(result);
|
||||
|
||||
@@ -84,31 +84,29 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
synchronize_threads_in_block();
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
|
||||
glwe_dim + 1);
|
||||
|
||||
// Subtract each glwe operand, decompose the resulting
|
||||
// polynomial coefficients to multiply each decomposed level
|
||||
// with the corresponding part of the LUT
|
||||
for (int level = level_count - 1; level >= 0; level--) {
|
||||
// Decomposition
|
||||
gadget.decompose_and_compress_next(glwe_fft);
|
||||
synchronize_threads_in_block();
|
||||
for (int i = 0; i < (glwe_dim + 1); i++) {
|
||||
auto glwe_fft_slice = glwe_fft + i * params::degree / 2;
|
||||
gadget.decompose_and_compress_next_polynomial(glwe_fft, i);
|
||||
|
||||
// First, perform the polynomial multiplication
|
||||
NSMFFT_direct<HalfDegree<params>>(glwe_fft_slice);
|
||||
NSMFFT_direct<HalfDegree<params>>(glwe_fft);
|
||||
|
||||
// External product and accumulate
|
||||
// Get the piece necessary for the multiplication
|
||||
auto bsk_slice = get_ith_mask_kth_block(
|
||||
ggsw_in, ggsw_idx, i, level, polynomial_size, glwe_dim, level_count);
|
||||
|
||||
synchronize_threads_in_block();
|
||||
// Perform the coefficient-wise product
|
||||
for (int j = 0; j < (glwe_dim + 1); j++) {
|
||||
auto bsk_poly = bsk_slice + j * params::degree / 2;
|
||||
auto res_fft_poly = res_fft + j * params::degree / 2;
|
||||
polynomial_product_accumulate_in_fourier_domain<params, double2>(
|
||||
res_fft_poly, glwe_fft_slice, bsk_poly);
|
||||
res_fft_poly, glwe_fft, bsk_poly);
|
||||
}
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
@@ -215,9 +213,8 @@ get_memory_needed_per_block_cmux_tree(uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size) {
|
||||
return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // glwe_sub
|
||||
sizeof(double2) * polynomial_size / 2 *
|
||||
(glwe_dimension + 1) + // res_fft
|
||||
sizeof(double2) * polynomial_size / 2 *
|
||||
(glwe_dimension + 1); // glwe_fft
|
||||
(glwe_dimension + 1) + // res_fft
|
||||
sizeof(double2) * polynomial_size / 2; // glwe_fft
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
@@ -538,8 +535,6 @@ __host__ void host_blind_rotate_and_sample_extraction(
|
||||
uint32_t level_count, uint32_t max_shared_memory) {
|
||||
|
||||
cudaSetDevice(gpu_index);
|
||||
assert(glwe_dimension ==
|
||||
1); // For larger k we will need to adjust the mask size
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int memory_needed_per_block =
|
||||
|
||||
@@ -305,9 +305,9 @@ __host__ void host_wop_pbs(
|
||||
host_extract_bits<Torus, params>(
|
||||
v_stream, gpu_index, (Torus *)lwe_array_out_bit_extract, lwe_array_in,
|
||||
bit_extract_buffer, ksk, fourier_bsk, number_of_bits_to_extract,
|
||||
delta_log, polynomial_size, lwe_dimension, glwe_dimension,
|
||||
polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk,
|
||||
level_count_ksk, number_of_inputs, max_shared_memory);
|
||||
delta_log, glwe_dimension * polynomial_size, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log_bsk, level_count_bsk,
|
||||
base_log_ksk, level_count_ksk, number_of_inputs, max_shared_memory);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
int8_t *cbs_vp_buffer =
|
||||
|
||||
Reference in New Issue
Block a user