refactor(cuda): Implement support for k > 1 in the cmux tree.

This commit is contained in:
Pedro Alves
2023-02-22 13:03:36 -03:00
committed by Agnès Leroy
parent 184d453387
commit eb8aeb5a01

View File

@@ -15,19 +15,6 @@
#include "polynomial/polynomial_math.cuh"
#include "utils/timer.cuh"
template <class params> __device__ void fft(double2 *data) {
  // Move the coefficients into the Fourier domain. The block-wide barrier
  // ensures the transform has finished before any thread reads the result.
  NSMFFT_direct<HalfDegree<params>>(data);
  synchronize_threads_in_block();
}
template <class params> __device__ void ifft_inplace(double2 *buf) {
  // Barrier first so every pending write to `buf` has landed before the
  // inverse transform starts, then barrier again so the whole block sees
  // the time-domain result.
  synchronize_threads_in_block();
  NSMFFT_inverse<HalfDegree<params>>(buf);
  synchronize_threads_in_block();
}
/*
* Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
* array, and an array of GGSW ciphertexts with a index to one ciphertext in it.
@@ -59,121 +46,96 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
  // Layout of the per-block scratch buffer (selected_memory):
  //   glwe_sub : (glwe_dim + 1) polynomials of polynomial_size Torus coeffs
  //   res_fft  : (glwe_dim + 1) polynomials of polynomial_size / 2 double2
  //   glwe_fft : (glwe_dim + 1) polynomials of polynomial_size / 2 double2
  Torus *glwe_sub = (Torus *)selected_memory;
  double2 *res_fft =
      (double2 *)glwe_sub +
      (glwe_dim + 1) * polynomial_size / (sizeof(double2) / sizeof(Torus));
  double2 *glwe_fft =
      (double2 *)res_fft + (ptrdiff_t)((glwe_dim + 1) * polynomial_size / 2);

  /////////////////////////////////////
  // Gets the pointers for the global memory
  auto m0 = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
  auto m1 = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];

  // Subtraction: m1-m0, one polynomial of the GLWE at a time
  for (int i = 0; i < (glwe_dim + 1); i++) {
    auto glwe_sub_slice = glwe_sub + i * params::degree;
    auto m0_slice = m0 + i * params::degree;
    auto m1_slice = m1 + i * params::degree;
    sub_polynomial<Torus, params>(glwe_sub_slice, m1_slice, m0_slice);
  }

  // Initialize the polynomial multiplication via FFT arrays
  // The polynomial multiplications happens at the block level
  // and each thread handles two or more coefficients
  int pos = threadIdx.x;
  for (int j = 0; j < (glwe_dim + 1) * params::opt / 2; j++) {
    res_fft[pos].x = 0;
    res_fft[pos].y = 0;
    pos += params::degree / params::opt;
  }

  // Make the subtraction visible to the whole block before decomposing it
  synchronize_threads_in_block();
  GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
                                     glwe_dim + 1);
  // Subtract each glwe operand, decompose the resulting
  // polynomial coefficients to multiply each decomposed level
  // with the corresponding part of the LUT
  for (int level = level_count - 1; level >= 0; level--) {
    // Decomposition
    gadget.decompose_and_compress_next(glwe_fft);
    synchronize_threads_in_block();

    for (int i = 0; i < (glwe_dim + 1); i++) {
      auto glwe_fft_slice = glwe_fft + i * params::degree / 2;
      // First, perform the polynomial multiplication
      NSMFFT_direct<HalfDegree<params>>(glwe_fft_slice);

      // External product and accumulate
      // Get the piece necessary for the multiplication
      auto bsk_slice = get_ith_mask_kth_block(
          ggsw_in, ggsw_idx, i, level, polynomial_size, glwe_dim, level_count);
      synchronize_threads_in_block();

      // Perform the coefficient-wise product
      for (int j = 0; j < (glwe_dim + 1); j++) {
        auto bsk_poly = bsk_slice + j * params::degree / 2;
        auto res_fft_poly = res_fft + j * params::degree / 2;
        polynomial_product_accumulate_in_fourier_domain<params, double2>(
            res_fft_poly, glwe_fft_slice, bsk_poly);
      }
    }
    synchronize_threads_in_block();
  }

  // IFFT: bring every accumulated polynomial back to the torus domain
  synchronize_threads_in_block();
  for (int i = 0; i < (glwe_dim + 1); i++) {
    auto res_fft_slice = res_fft + i * params::degree / 2;
    NSMFFT_inverse<HalfDegree<params>>(res_fft_slice);
  }
  synchronize_threads_in_block();

  // Write the output: copy m0, then add the external product on top of it
  Torus *mb = &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
  int tid = threadIdx.x;
  for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
    mb[tid] = m0[tid];
    tid += params::degree / params::opt;
  }
  for (int i = 0; i < (glwe_dim + 1); i++) {
    auto res_fft_slice = res_fft + i * params::degree / 2;
    auto mb_slice = mb + i * params::degree;
    add_to_torus<Torus, params>(res_fft_slice, mb_slice);
  }
}
// Appends zeroed paddings between each LUT
@@ -187,8 +149,9 @@ __host__ void add_padding_to_lut_async(Torus *lut_out, Torus *lut_in,
*stream));
for (int i = 0; i < num_lut; i++)
check_cuda_error(cudaMemcpyAsync(
lut_out + (2 * i + 1) * params::degree, lut_in + i * params::degree,
params::degree * sizeof(Torus), cudaMemcpyDeviceToDevice, *stream));
lut_out + ((glwe_dimension + 1) * i + glwe_dimension) * params::degree,
lut_in + i * params::degree, params::degree * sizeof(Torus),
cudaMemcpyDeviceToDevice, *stream));
}
/**
@@ -221,6 +184,9 @@ __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in,
int tree_idx = blockIdx.y;
int tree_offset = tree_idx * num_lut * (glwe_dim + 1) * polynomial_size;
auto block_glwe_array_out = glwe_array_out + tree_offset;
auto block_glwe_array_in = glwe_array_in + tree_offset;
// The x-axis handles a single cmux tree. Each block computes one cmux.
int cmux_idx = blockIdx.x;
int output_idx = cmux_idx;
@@ -237,20 +203,21 @@ __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in,
selected_memory = &device_mem[(blockIdx.x + blockIdx.y * gridDim.x) *
device_memory_size_per_block];
cmux<Torus, STorus, params>(
glwe_array_out + tree_offset, glwe_array_in + tree_offset, ggsw_in,
selected_memory, output_idx, input_idx1, input_idx2, glwe_dim,
polynomial_size, base_log, level_count, ggsw_idx);
cmux<Torus, STorus, params>(block_glwe_array_out, block_glwe_array_in,
ggsw_in, selected_memory, output_idx, input_idx1,
input_idx2, glwe_dim, polynomial_size, base_log,
level_count, ggsw_idx);
}
// Returns the number of bytes of per-block scratch memory one cmux needs:
// the GLWE subtraction buffer plus the two Fourier-domain buffers, each
// sized for glwe_dimension + 1 polynomials.
template <typename Torus>
__host__ __device__ int
get_memory_needed_per_block_cmux_tree(uint32_t glwe_dimension,
                                      uint32_t polynomial_size) {
  return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // glwe_sub
         sizeof(double2) * polynomial_size / 2 *
             (glwe_dimension + 1) + // res_fft
         sizeof(double2) * polynomial_size / 2 *
             (glwe_dimension + 1); // glwe_fft
}
template <typename Torus>
@@ -259,8 +226,8 @@ get_buffer_size_cmux_tree(uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t r, uint32_t tau,
uint32_t max_shared_memory) {
int memory_needed_per_block =
get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
glwe_dimension, polynomial_size);
int num_lut = (1 << r);
int ggsw_size = polynomial_size * (glwe_dimension + 1) *
(glwe_dimension + 1) * level_count;
@@ -286,8 +253,8 @@ scratch_cmux_tree(void *v_stream, uint32_t gpu_index, int8_t **cmux_tree_buffer,
cudaSetDevice(gpu_index);
auto stream = static_cast<cudaStream_t *>(v_stream);
int memory_needed_per_block =
get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
glwe_dimension, polynomial_size);
if (max_shared_memory >= memory_needed_per_block) {
check_cuda_error(cudaFuncSetAttribute(
device_batch_cmux<Torus, STorus, params, FULLSM>,
@@ -341,8 +308,8 @@ host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
return;
}
int memory_needed_per_block =
get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
glwe_dimension, polynomial_size);
dim3 thds(polynomial_size / params::opt, 1, 1);
@@ -504,16 +471,10 @@ __global__ void device_blind_rotation_and_sample_extraction(
// Returns the per-block scratch memory (in bytes) for the blind rotation +
// sample extraction kernel: two GLWE accumulators of glwe_dimension + 1
// polynomials each, plus the cmux scratch space (glwe_sub, res_fft,
// glwe_fft). Note: the accidental unary '+' from the merged diff is removed.
template <typename Torus>
__host__ __device__ int
get_memory_needed_per_block_blind_rotation_sample_extraction(
    uint32_t glwe_dimension, uint32_t polynomial_size) {
  return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator_c0
         sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator_c1
         get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension,
                                                      polynomial_size);
}
template <typename Torus>
@@ -523,7 +484,7 @@ __host__ __device__ int get_buffer_size_blind_rotation_sample_extraction(
int memory_needed_per_block =
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
polynomial_size);
glwe_dimension, polynomial_size);
int device_mem = 0;
if (max_shared_memory < memory_needed_per_block) {
device_mem = memory_needed_per_block * tau;
@@ -548,7 +509,7 @@ __host__ void scratch_blind_rotation_sample_extraction(
int memory_needed_per_block =
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
polynomial_size);
glwe_dimension, polynomial_size);
if (max_shared_memory >= memory_needed_per_block) {
check_cuda_error(cudaFuncSetAttribute(
device_blind_rotation_and_sample_extraction<Torus, STorus, params,
@@ -583,7 +544,7 @@ __host__ void host_blind_rotate_and_sample_extraction(
int memory_needed_per_block =
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
polynomial_size);
glwe_dimension, polynomial_size);
// Prepare the buffers
int ggsw_size = polynomial_size * (glwe_dimension + 1) *