Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-08 11:35:02 -05:00)
refactor(cuda): Implement support for k > 1 in the cmux tree.
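Editor's note: the change generalizes the cmux kernels from a hard-coded mask/body pair (GLWE dimension k = 1) to loops over the (k + 1) polynomials of each GLWE ciphertext. As a reading aid, here is a minimal host-side C++ sketch of the flattened layout the kernels index into; the helper name is illustrative and not part of the CUDA sources:

#include <cstddef>
#include <cstdint>

// Offset (in Torus elements) of polynomial `poly_idx` inside the GLWE
// ciphertext at position `glwe_idx` of a flat array, matching the kernel
// indexing glwe_idx * (k + 1) * N + poly_idx * N. Polynomials 0..k-1 are the
// mask and polynomial k is the body; for k = 1 this reduces to the old
// mask/body pair at offsets 0 and N.
static std::size_t glwe_poly_offset(std::size_t glwe_idx, std::size_t poly_idx,
                                    uint32_t glwe_dimension,
                                    uint32_t polynomial_size) {
  return glwe_idx * (glwe_dimension + 1) * polynomial_size +
         poly_idx * polynomial_size;
}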
@@ -15,19 +15,6 @@
 #include "polynomial/polynomial_math.cuh"
 #include "utils/timer.cuh"
 
-template <class params> __device__ void fft(double2 *output) {
-  // Switch to the FFT space
-  NSMFFT_direct<HalfDegree<params>>(output);
-  synchronize_threads_in_block();
-}
-
-template <class params> __device__ void ifft_inplace(double2 *data) {
-  synchronize_threads_in_block();
-
-  NSMFFT_inverse<HalfDegree<params>>(data);
-  synchronize_threads_in_block();
-}
-
 /*
  * Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
  * array, and an array of GGSW ciphertexts with a index to one ciphertext in it.
@@ -59,121 +46,96 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
      uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
 
   // Define glwe_sub
-  Torus *glwe_sub_mask = (Torus *)selected_memory;
-  Torus *glwe_sub_body = (Torus *)glwe_sub_mask + (ptrdiff_t)polynomial_size;
+  Torus *glwe_sub = (Torus *)selected_memory;
 
-  double2 *mask_res_fft = (double2 *)glwe_sub_body +
-                          polynomial_size / (sizeof(double2) / sizeof(Torus));
-  double2 *body_res_fft =
-      (double2 *)mask_res_fft + (ptrdiff_t)polynomial_size / 2;
+  double2 *res_fft =
+      (double2 *)glwe_sub +
+      (glwe_dim + 1) * polynomial_size / (sizeof(double2) / sizeof(Torus));
 
   double2 *glwe_fft =
-      (double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
+      (double2 *)res_fft + (ptrdiff_t)((glwe_dim + 1) * polynomial_size / 2);
 
   /////////////////////////////////////
 
-  // glwe2-glwe1
-
-  // Copy m0 to shared memory to preserve data
-  auto m0_mask = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
-  auto m0_body = m0_mask + polynomial_size;
+  // Gets the pointers for the global memory
+  auto m0 = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
+  auto m1 = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
 
-  // Just gets the pointer for m1 on global memory
-  auto m1_mask = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
-  auto m1_body = m1_mask + polynomial_size;
-
-  // Mask
-  sub_polynomial<Torus, params>(glwe_sub_mask, m1_mask, m0_mask);
-  // Body
-  sub_polynomial<Torus, params>(glwe_sub_body, m1_body, m0_body);
-
-  synchronize_threads_in_block();
+  // Subtraction: m1-m0
+  for (int i = 0; i < (glwe_dim + 1); i++) {
+    auto glwe_sub_slice = glwe_sub + i * params::degree;
+    auto m0_slice = m0 + i * params::degree;
+    auto m1_slice = m1 + i * params::degree;
+    sub_polynomial<Torus, params>(glwe_sub_slice, m1_slice, m0_slice);
+  }
 
   // Initialize the polynomial multiplication via FFT arrays
   // The polynomial multiplications happens at the block level
   // and each thread handles two or more coefficients
   int pos = threadIdx.x;
-  for (int j = 0; j < params::opt / 2; j++) {
-    mask_res_fft[pos].x = 0;
-    mask_res_fft[pos].y = 0;
-    body_res_fft[pos].x = 0;
-    body_res_fft[pos].y = 0;
+  for (int j = 0; j < (glwe_dim + 1) * params::opt / 2; j++) {
+    res_fft[pos].x = 0;
+    res_fft[pos].y = 0;
     pos += params::degree / params::opt;
   }
 
-  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count, glwe_sub_mask,
-                                          1);
-  GadgetMatrix<Torus, params> gadget_body(base_log, level_count, glwe_sub_body,
-                                          1);
+  synchronize_threads_in_block();
+  GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
+                                     glwe_dim + 1);
+  // Subtract each glwe operand, decompose the resulting
+  // polynomial coefficients to multiply each decomposed level
+  // with the corresponding part of the LUT
   for (int level = level_count - 1; level >= 0; level--) {
-
-    // Decomposition
-    gadget_mask.decompose_and_compress_next(glwe_fft);
-
-    // First, perform the polynomial multiplication for the mask
-    fft<params>(glwe_fft);
-
-    // External product and accumulate
-    // Get the piece necessary for the multiplication
-    auto mask_fourier = get_ith_mask_kth_block(
-        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
-    auto body_fourier = get_ith_body_kth_block(
-        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
-
+    gadget.decompose_and_compress_next(glwe_fft);
+    synchronize_threads_in_block();
+    for (int i = 0; i < (glwe_dim + 1); i++) {
+      auto glwe_fft_slice = glwe_fft + i * params::degree / 2;
+
+      // Perform the coefficient-wise product
+      // First, perform the polynomial multiplication
+      NSMFFT_direct<HalfDegree<params>>(glwe_fft_slice);
+
+      // External product and accumulate
+      // Get the piece necessary for the multiplication
+      auto bsk_slice = get_ith_mask_kth_block(
+          ggsw_in, ggsw_idx, i, level, polynomial_size, glwe_dim, level_count);
+
+      synchronize_threads_in_block();
+      // Perform the coefficient-wise product
+      for (int j = 0; j < (glwe_dim + 1); j++) {
+        auto bsk_poly = bsk_slice + j * params::degree / 2;
+        auto res_fft_poly = res_fft + j * params::degree / 2;
+        polynomial_product_accumulate_in_fourier_domain<params, double2>(
+            res_fft_poly, glwe_fft_slice, bsk_poly);
+      }
+    }
     synchronize_threads_in_block();
-    polynomial_product_accumulate_in_fourier_domain<params, double2>(
-        mask_res_fft, glwe_fft, mask_fourier);
-    polynomial_product_accumulate_in_fourier_domain<params, double2>(
-        body_res_fft, glwe_fft, body_fourier);
-
-    // Now handle the polynomial multiplication for the body
-    // in the same way
-    synchronize_threads_in_block();
-
-    gadget_body.decompose_and_compress_next(glwe_fft);
-    fft<params>(glwe_fft);
-
-    // External product and accumulate
-    // Get the piece necessary for the multiplication
-    mask_fourier = get_ith_mask_kth_block(
-        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
-    body_fourier = get_ith_body_kth_block(
-        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
-
-    synchronize_threads_in_block();
-
-    polynomial_product_accumulate_in_fourier_domain<params, double2>(
-        mask_res_fft, glwe_fft, mask_fourier);
-    polynomial_product_accumulate_in_fourier_domain<params, double2>(
-        body_res_fft, glwe_fft, body_fourier);
   }
 
   // IFFT
   synchronize_threads_in_block();
-  ifft_inplace<params>(mask_res_fft);
-  ifft_inplace<params>(body_res_fft);
+  for (int i = 0; i < (glwe_dim + 1); i++) {
+    auto res_fft_slice = res_fft + i * params::degree / 2;
+    NSMFFT_inverse<HalfDegree<params>>(res_fft_slice);
+  }
   synchronize_threads_in_block();
 
   // Write the output
-  Torus *mb_mask =
-      &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
-  Torus *mb_body = mb_mask + polynomial_size;
+  Torus *mb = &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
 
   int tid = threadIdx.x;
-  for (int i = 0; i < params::opt; i++) {
-    mb_mask[tid] = m0_mask[tid];
-    mb_body[tid] = m0_body[tid];
+  for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
+    mb[tid] = m0[tid];
     tid += params::degree / params::opt;
   }
 
-  add_to_torus<Torus, params>(mask_res_fft, mb_mask);
-  add_to_torus<Torus, params>(body_res_fft, mb_body);
+  for (int i = 0; i < (glwe_dim + 1); i++) {
+    auto res_fft_slice = res_fft + i * params::degree / 2;
+    auto mb_slice = mb + i * params::degree;
+    add_to_torus<Torus, params>(res_fft_slice, mb_slice);
+  }
 }
 
 // Appends zeroed paddings between each LUT
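Editor's note: inside the level loop of the cmux above, the kernel now decomposes all (k + 1) polynomials of m1 - m0 and accumulates them against every (i, j) block of the selected GGSW level, instead of treating mask and body as two special cases. A scalar CPU sketch of that accumulation pattern, using std::complex<double> in place of the device double2; the container types and function name are illustrative only, and the inputs are assumed to already be in the Fourier domain:

#include <complex>
#include <cstddef>
#include <vector>

using Poly = std::vector<std::complex<double>>; // N/2 Fourier coefficients per polynomial

// res_fft[j] += decomp_fft[i] * ggsw_level_fft[i][j] for all i, j in [0, k],
// mirroring the nested i/j loops the kernel runs for each decomposition level.
void accumulate_external_product(
    std::vector<Poly> &res_fft,                             // (k + 1) output polynomials
    const std::vector<Poly> &decomp_fft,                    // (k + 1) decomposed polynomials
    const std::vector<std::vector<Poly>> &ggsw_level_fft) { // (k + 1) x (k + 1) GGSW blocks
  const std::size_t kp1 = res_fft.size();
  const std::size_t half_n = res_fft[0].size();
  for (std::size_t i = 0; i < kp1; i++)
    for (std::size_t j = 0; j < kp1; j++)
      for (std::size_t c = 0; c < half_n; c++)
        res_fft[j][c] += decomp_fft[i][c] * ggsw_level_fft[i][j][c];
}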
@@ -187,8 +149,9 @@ __host__ void add_padding_to_lut_async(Torus *lut_out, Torus *lut_in,
       *stream));
   for (int i = 0; i < num_lut; i++)
     check_cuda_error(cudaMemcpyAsync(
-        lut_out + (2 * i + 1) * params::degree, lut_in + i * params::degree,
-        params::degree * sizeof(Torus), cudaMemcpyDeviceToDevice, *stream));
+        lut_out + ((glwe_dimension + 1) * i + glwe_dimension) * params::degree,
+        lut_in + i * params::degree, params::degree * sizeof(Torus),
+        cudaMemcpyDeviceToDevice, *stream));
 }
 
 /**
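Editor's note: with GLWE dimension k, each cleartext LUT is laid out as a trivial GLWE ciphertext, i.e. k zeroed mask polynomials followed by the LUT itself as the body, which is why LUT i is now copied to offset ((glwe_dimension + 1) * i + glwe_dimension) * N. A synchronous host-side sketch of the same layout in plain C++ (an illustrative helper, not the library's async API):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Expands num_lut LUTs of N coefficients each into trivial GLWE ciphertexts:
// for every LUT, glwe_dimension zeroed mask polynomials precede the LUT body.
template <typename Torus>
std::vector<Torus> pad_luts(const std::vector<Torus> &lut_in, uint32_t num_lut,
                            uint32_t glwe_dimension, uint32_t N) {
  std::vector<Torus> lut_out(std::size_t(num_lut) * (glwe_dimension + 1) * N, 0);
  for (uint32_t i = 0; i < num_lut; i++)
    std::copy(lut_in.begin() + std::size_t(i) * N,
              lut_in.begin() + std::size_t(i + 1) * N,
              lut_out.begin() +
                  std::size_t((glwe_dimension + 1) * i + glwe_dimension) * N);
  return lut_out;
}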
@@ -221,6 +184,9 @@ __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in,
   int tree_idx = blockIdx.y;
   int tree_offset = tree_idx * num_lut * (glwe_dim + 1) * polynomial_size;
 
+  auto block_glwe_array_out = glwe_array_out + tree_offset;
+  auto block_glwe_array_in = glwe_array_in + tree_offset;
+
   // The x-axis handles a single cmux tree. Each block computes one cmux.
   int cmux_idx = blockIdx.x;
   int output_idx = cmux_idx;
@@ -237,20 +203,21 @@ __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in,
     selected_memory = &device_mem[(blockIdx.x + blockIdx.y * gridDim.x) *
                                   device_memory_size_per_block];
 
-  cmux<Torus, STorus, params>(
-      glwe_array_out + tree_offset, glwe_array_in + tree_offset, ggsw_in,
-      selected_memory, output_idx, input_idx1, input_idx2, glwe_dim,
-      polynomial_size, base_log, level_count, ggsw_idx);
+  cmux<Torus, STorus, params>(block_glwe_array_out, block_glwe_array_in,
+                              ggsw_in, selected_memory, output_idx, input_idx1,
+                              input_idx2, glwe_dim, polynomial_size, base_log,
+                              level_count, ggsw_idx);
 }
 
 template <typename Torus>
 __host__ __device__ int
-get_memory_needed_per_block_cmux_tree(uint32_t polynomial_size) {
-  return sizeof(Torus) * polynomial_size +       // glwe_sub_mask
-         sizeof(Torus) * polynomial_size +       // glwe_sub_body
-         sizeof(double2) * polynomial_size / 2 + // mask_res_fft
-         sizeof(double2) * polynomial_size / 2 + // body_res_fft
-         sizeof(double2) * polynomial_size / 2;  // glwe_fft
+get_memory_needed_per_block_cmux_tree(uint32_t glwe_dimension,
+                                      uint32_t polynomial_size) {
+  return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // glwe_sub
+         sizeof(double2) * polynomial_size / 2 *
+             (glwe_dimension + 1) + // res_fft
+         sizeof(double2) * polynomial_size / 2 *
+             (glwe_dimension + 1); // glwe_fft
 }
 
 template <typename Torus>
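Editor's note: the per-block scratch now scales with glwe_dimension + 1: one GLWE worth of Torus coefficients for glwe_sub plus two buffers of (k + 1) * N/2 double2 values for res_fft and glwe_fft. A quick host-side check of the new formula for a 64-bit Torus; the parameter values below are examples only, not defaults taken from the library:

#include <cstdint>
#include <cstdio>

// Mirrors the updated get_memory_needed_per_block_cmux_tree for
// Torus = uint64_t; sizeof(double) * 2 stands in for sizeof(double2).
static std::size_t cmux_tree_bytes_per_block(uint32_t glwe_dimension,
                                             uint32_t polynomial_size) {
  return sizeof(uint64_t) * polynomial_size * (glwe_dimension + 1) +         // glwe_sub
         sizeof(double) * 2 * polynomial_size / 2 * (glwe_dimension + 1) +   // res_fft
         sizeof(double) * 2 * polynomial_size / 2 * (glwe_dimension + 1);    // glwe_fft
}

int main() {
  // N = 1024 gives 48 KiB for k = 1 and 72 KiB for k = 2, so a larger k can
  // push a block past the available shared memory and into the device_mem
  // fallback selected by the max_shared_memory checks.
  std::printf("k=1: %zu bytes\n", cmux_tree_bytes_per_block(1, 1024));
  std::printf("k=2: %zu bytes\n", cmux_tree_bytes_per_block(2, 1024));
  return 0;
}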
@@ -259,8 +226,8 @@ get_buffer_size_cmux_tree(uint32_t glwe_dimension, uint32_t polynomial_size,
                           uint32_t level_count, uint32_t r, uint32_t tau,
                           uint32_t max_shared_memory) {
 
-  int memory_needed_per_block =
-      get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
+  int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
+      glwe_dimension, polynomial_size);
   int num_lut = (1 << r);
   int ggsw_size = polynomial_size * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * level_count;
@@ -286,8 +253,8 @@ scratch_cmux_tree(void *v_stream, uint32_t gpu_index, int8_t **cmux_tree_buffer,
   cudaSetDevice(gpu_index);
   auto stream = static_cast<cudaStream_t *>(v_stream);
 
-  int memory_needed_per_block =
-      get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
+  int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
+      glwe_dimension, polynomial_size);
   if (max_shared_memory >= memory_needed_per_block) {
     check_cuda_error(cudaFuncSetAttribute(
         device_batch_cmux<Torus, STorus, params, FULLSM>,
@@ -341,8 +308,8 @@ host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
     return;
   }
 
-  int memory_needed_per_block =
-      get_memory_needed_per_block_cmux_tree<Torus>(polynomial_size);
+  int memory_needed_per_block = get_memory_needed_per_block_cmux_tree<Torus>(
+      glwe_dimension, polynomial_size);
 
   dim3 thds(polynomial_size / params::opt, 1, 1);
 
@@ -504,16 +471,10 @@ __global__ void device_blind_rotation_and_sample_extraction(
 template <typename Torus>
 __host__ __device__ int
 get_memory_needed_per_block_blind_rotation_sample_extraction(
-    uint32_t polynomial_size) {
-  return sizeof(Torus) * polynomial_size +       // accumulator_c0 mask
-         sizeof(Torus) * polynomial_size +       // accumulator_c0 body
-         sizeof(Torus) * polynomial_size +       // accumulator_c1 mask
-         sizeof(Torus) * polynomial_size +       // accumulator_c1 body
-         sizeof(Torus) * polynomial_size +       // glwe_sub_mask
-         sizeof(Torus) * polynomial_size +       // glwe_sub_body
-         sizeof(double2) * polynomial_size / 2 + // mask_res_fft
-         sizeof(double2) * polynomial_size / 2 + // body_res_fft
-         sizeof(double2) * polynomial_size / 2;  // glwe_fft
+    uint32_t glwe_dimension, uint32_t polynomial_size) {
+  return sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c0
+         sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c1
+         + get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension, polynomial_size);
 }
 
 template <typename Torus>
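Editor's note: after the change, the blind-rotate-and-sample-extract block needs two full (k + 1)-polynomial accumulators on top of the same cmux scratch. A companion host-side sketch of the new total, again assuming a 64-bit Torus and illustrative parameters:

#include <cstdint>

// accumulator_c0 and accumulator_c1, each (k + 1) * N Torus coefficients,
// plus the cmux-tree scratch (glwe_sub + res_fft + glwe_fft) computed as in
// the previous sketch. For k = 1 and N = 1024 this comes to 80 KiB per block.
static std::size_t br_se_bytes_per_block(uint32_t glwe_dimension,
                                         uint32_t polynomial_size) {
  std::size_t cmux_scratch =
      sizeof(uint64_t) * polynomial_size * (glwe_dimension + 1) +
      2 * (sizeof(double) * 2 * polynomial_size / 2 * (glwe_dimension + 1));
  return 2 * sizeof(uint64_t) * polynomial_size * (glwe_dimension + 1) +
         cmux_scratch;
}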
@@ -523,7 +484,7 @@ __host__ __device__ int get_buffer_size_blind_rotation_sample_extraction(
 
   int memory_needed_per_block =
       get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
-          polynomial_size);
+          glwe_dimension, polynomial_size);
   int device_mem = 0;
   if (max_shared_memory < memory_needed_per_block) {
     device_mem = memory_needed_per_block * tau;
@@ -548,7 +509,7 @@ __host__ void scratch_blind_rotation_sample_extraction(
 
   int memory_needed_per_block =
       get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
-          polynomial_size);
+          glwe_dimension, polynomial_size);
   if (max_shared_memory >= memory_needed_per_block) {
     check_cuda_error(cudaFuncSetAttribute(
         device_blind_rotation_and_sample_extraction<Torus, STorus, params,
@@ -583,7 +544,7 @@ __host__ void host_blind_rotate_and_sample_extraction(
 
   int memory_needed_per_block =
       get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
-          polynomial_size);
+          glwe_dimension, polynomial_size);
 
   // Prepare the buffers
   int ggsw_size = polynomial_size * (glwe_dimension + 1) *