Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-08 19:44:57 -05:00)
fix(cuda): fix scratch functions to avoid misaligned pointers

@@ -247,7 +247,7 @@ __host__ __device__ int get_buffer_size_bootstrap_amortized(
   } else if (max_shared_memory < full_sm) {
     device_mem = partial_dm * input_lwe_ciphertext_count;
   }
-  return device_mem;
+  return device_mem + device_mem % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

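Note: every get_buffer_size_* change in this commit applies the same padding idiom, `return size + size % sizeof(double2);`. Below is a minimal host-side sketch of why this works (illustration only, not from the commit; pad_to_double2 is a hypothetical name): these sizes are sums of terms carrying a sizeof(Torus) or sizeof(double) factor, so when Torus is a 64-bit integer the size is already a multiple of 8, and an 8-byte-aligned size plus its remainder modulo 16 is always a multiple of 16.

// Sketch: the double2 padding idiom from this commit, checked on the host.
// pad_to_double2 is a hypothetical helper introduced for illustration.
#include <cassert>
#include <cstddef>
#include <cstdio>

struct double2 { double x, y; }; // stand-in for CUDA's built-in double2

static size_t pad_to_double2(size_t size) {
  return size + size % sizeof(double2);
}

int main() {
  // For any size that is a multiple of sizeof(double) (8 bytes),
  // size % 16 is either 0 or 8, so the padded size is a multiple of
  // sizeof(double2) (16 bytes); a buffer of that size ends on a
  // double2-aligned boundary, so the next buffer carved out of the same
  // allocation keeps double2 alignment.
  for (size_t size = 0; size <= (1 << 20); size += sizeof(double)) {
    size_t padded = pad_to_double2(size);
    assert(padded % sizeof(double2) == 0);
    assert(padded >= size);
  }
  printf("double2 padding idiom holds for all 8-byte-aligned sizes tested\n");
  return 0;
}
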
@@ -277,9 +277,10 @@ __host__ __device__ int get_buffer_size_bootstrap_low_latency(
     device_mem = partial_dm * input_lwe_ciphertext_count * level_count *
                  (glwe_dimension + 1);
   }
-  return device_mem + (glwe_dimension + 1) * level_count *
-                          input_lwe_ciphertext_count * polynomial_size / 2 *
-                          sizeof(double2);
+  int buffer_size = device_mem + (glwe_dimension + 1) * level_count *
+                                     input_lwe_ciphertext_count *
+                                     polynomial_size / 2 * sizeof(double2);
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

@@ -106,16 +106,17 @@ get_buffer_size_cbs(uint32_t glwe_dimension, uint32_t lwe_dimension,
                     uint32_t polynomial_size, uint32_t level_count_cbs,
                     uint32_t number_of_inputs) {

-  return number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
-             (glwe_dimension * polynomial_size + 1) *
-             sizeof(Torus) + // lwe_array_in_fp_ks_buffer
-         number_of_inputs * level_count_cbs *
-             (glwe_dimension * polynomial_size + 1) *
-             sizeof(Torus) + // lwe_array_out_pbs_buffer
-         number_of_inputs * level_count_cbs * (lwe_dimension + 1) *
-             sizeof(Torus) + // lwe_array_in_shifted_buffer
-         level_count_cbs * (glwe_dimension + 1) * polynomial_size *
-             sizeof(Torus); // lut_vector_cbs
+  int buffer_size = number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
+                        (glwe_dimension * polynomial_size + 1) *
+                        sizeof(Torus) + // lwe_array_in_fp_ks_buffer
+                    number_of_inputs * level_count_cbs *
+                        (glwe_dimension * polynomial_size + 1) *
+                        sizeof(Torus) + // lwe_array_out_pbs_buffer
+                    number_of_inputs * level_count_cbs * (lwe_dimension + 1) *
+                        sizeof(Torus) + // lwe_array_in_shifted_buffer
+                    level_count_cbs * (glwe_dimension + 1) * polynomial_size *
+                        sizeof(Torus); // lut_vector_cbs
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

@@ -238,9 +238,10 @@ get_buffer_size_cmux_tree(uint32_t glwe_dimension, uint32_t polynomial_size,
   if (max_shared_memory < polynomial_size * sizeof(double)) {
     device_mem += polynomial_size * sizeof(double);
   }
-  return r * ggsw_size * sizeof(double) +
-         num_lut * tau * glwe_size * sizeof(Torus) +
-         num_lut * tau * glwe_size * sizeof(Torus) + device_mem;
+  int buffer_size = r * ggsw_size * sizeof(double) +
+                    num_lut * tau * glwe_size * sizeof(Torus) +
+                    num_lut * tau * glwe_size * sizeof(Torus) + device_mem;
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

@@ -315,29 +316,32 @@ host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
   //////////////////////
   int ggsw_size = polynomial_size * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * level_count;
+  int glwe_size = (glwe_dimension + 1) * polynomial_size;
+
+  // Define the buffers
+  // Always define the buffers with strongest memory alignment constraints first
+  // d_buffer1 and d_buffer2 are aligned with Torus, so they're defined last
   double2 *d_ggsw_fft_in = (double2 *)cmux_tree_buffer;

-  int8_t *d_mem_fft =
+  int8_t *d_mem =
       (int8_t *)d_ggsw_fft_in + (ptrdiff_t)(r * ggsw_size * sizeof(double));
-  batch_fft_ggsw_vector<Torus, STorus, params>(
-      stream, d_ggsw_fft_in, ggsw_in, d_mem_fft, r, glwe_dimension,
-      polynomial_size, level_count, gpu_index, max_shared_memory);

   //////////////////////

+  // Allocate global memory in case parameters are too large
+  int8_t *d_mem_fft = d_mem;
+  if (max_shared_memory < memory_needed_per_block) {
+    d_mem_fft =
+        d_mem + (ptrdiff_t)(memory_needed_per_block * (1 << (r - 1)) * tau);
+  }
   int8_t *d_buffer1 = d_mem_fft;
   if (max_shared_memory < polynomial_size * sizeof(double)) {
     d_buffer1 = d_mem_fft + (ptrdiff_t)(polynomial_size * sizeof(double));
   }

-  // Allocate buffers
-  int glwe_size = (glwe_dimension + 1) * polynomial_size;
-
   int8_t *d_buffer2 =
       d_buffer1 + (ptrdiff_t)(num_lut * tau * glwe_size * sizeof(Torus));

   //////////////////////

+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      stream, d_ggsw_fft_in, ggsw_in, d_mem_fft, r, glwe_dimension,
+      polynomial_size, level_count, gpu_index, max_shared_memory);
+
   add_padding_to_lut_async<Torus, params>(
       (Torus *)d_buffer1, lut_vector, glwe_dimension, num_lut * tau, stream);

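The new comments above state the layout rule this commit enforces: carve sub-buffers out of a single allocation in decreasing order of alignment requirement, so the byte-aligned int8_t regions can never push a double2 region onto a misaligned offset. A host-side sketch of the same carving pattern (illustration only; the sizes are placeholders and malloc stands in for cuda_malloc_async):

// Sketch: carving sub-buffers from one allocation in decreasing alignment
// order, as the reworked host_cmux_tree does. Sizes are placeholders and
// malloc stands in for cuda_malloc_async.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

struct double2 { double x, y; }; // stand-in for CUDA's built-in double2

int main() {
  const size_t fft_bytes = 4096;     // placeholder: r * ggsw_size * sizeof(double)
  const size_t scratch_bytes = 1024; // placeholder: kernels' global scratch
  const size_t lut_bytes = 2048;     // placeholder: num_lut * tau * glwe_size * sizeof(Torus)

  int8_t *pool =
      (int8_t *)std::malloc(fft_bytes + scratch_bytes + 2 * lut_bytes);

  // Strongest alignment first: the double2 buffer sits at offset 0 and
  // inherits the allocation's alignment.
  double2 *d_ggsw_fft_in = (double2 *)pool;
  assert((uintptr_t)d_ggsw_fft_in % alignof(double2) == 0);

  // Byte-aligned int8_t regions follow; they cannot misalign anything.
  int8_t *d_mem = pool + (ptrdiff_t)fft_bytes;
  int8_t *d_buffer1 = d_mem + (ptrdiff_t)scratch_bytes;
  int8_t *d_buffer2 = d_buffer1 + (ptrdiff_t)lut_bytes;
  (void)d_buffer2; // last region, would hold the second Torus-aligned buffer

  std::free(pool);
  return 0;
}
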
@@ -350,9 +354,6 @@ host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
     int num_cmuxes = (1 << (r - 1 - layer_idx));
     dim3 grid(num_cmuxes, tau, 1);

-    int8_t *d_mem =
-        d_buffer2 + (ptrdiff_t)(num_lut * tau * glwe_size * sizeof(Torus));
-
     // walks horizontally through the leaves
     if (max_shared_memory < memory_needed_per_block) {
       device_batch_cmux<Torus, STorus, params, NOSM>

@@ -494,8 +495,9 @@ __host__ __device__ int get_buffer_size_blind_rotation_sample_extraction(
   }
   int ggsw_size = polynomial_size * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * level_count;
-  return mbr_size * ggsw_size * sizeof(double) // d_ggsw_fft_in
-         + device_mem;
+  int buffer_size = mbr_size * ggsw_size * sizeof(double) // d_ggsw_fft_in
+                    + device_mem;
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

@@ -545,6 +547,7 @@ __host__ void host_blind_rotate_and_sample_extraction(
       glwe_dimension, polynomial_size);

   // Prepare the buffers
+  // Here all the buffers have double2 alignment
   int ggsw_size = polynomial_size * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * level_count;
   double2 *d_ggsw_fft_in = (double2 *)br_se_buffer;

@@ -34,11 +34,12 @@ get_buffer_size_cbs_vp(uint32_t glwe_dimension, uint32_t polynomial_size,

   int ggsw_size = level_count_cbs * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * polynomial_size;
-  return number_of_inputs * level_count_cbs *
-             sizeof(Torus) + // lut_vector_indexes
-         number_of_inputs * ggsw_size * sizeof(Torus) + // ggsw_out_cbs
-         tau * (glwe_dimension + 1) * polynomial_size *
-             sizeof(Torus); // glwe_array_out_cmux_tree
+  int buffer_size =
+      number_of_inputs * level_count_cbs * sizeof(Torus) + // lut_vector_indexes
+      number_of_inputs * ggsw_size * sizeof(Torus) + // ggsw_out_cbs
+      tau * (glwe_dimension + 1) * polynomial_size *
+          sizeof(Torus); // glwe_array_out_cmux_tree
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>

@@ -57,24 +58,23 @@ __host__ void scratch_circuit_bootstrap_vertical_packing(
       (Torus *)malloc(number_of_inputs * level_count_cbs * sizeof(Torus));
   uint32_t r = number_of_inputs - params::log2_degree;
   uint32_t mbr_size = number_of_inputs - r;
-  // allocate device pointer for circuit bootstrap and vertical
+  // allocate and initialize device pointers for circuit bootstrap and vertical
   // packing
+  int buffer_size =
+      get_buffer_size_cbs_vp<Torus>(glwe_dimension, polynomial_size,
+                                    level_count_cbs, tau, number_of_inputs) +
+      get_buffer_size_cbs<Torus>(glwe_dimension, lwe_dimension, polynomial_size,
+                                 level_count_cbs, number_of_inputs) +
+      get_buffer_size_bootstrap_amortized<Torus>(
+          glwe_dimension, polynomial_size, number_of_inputs * level_count_cbs,
+          max_shared_memory) +
+      get_buffer_size_cmux_tree<Torus>(glwe_dimension, polynomial_size,
+                                       level_count_cbs, r, tau,
+                                       max_shared_memory) +
+      get_buffer_size_blind_rotation_sample_extraction<Torus>(
+          glwe_dimension, polynomial_size, level_count_cbs, mbr_size, tau,
+          max_shared_memory);
   if (allocate_gpu_memory) {
-    int buffer_size =
-        get_buffer_size_cbs_vp<Torus>(glwe_dimension, polynomial_size,
-                                      level_count_cbs, tau, number_of_inputs) +
-        get_buffer_size_cbs<Torus>(glwe_dimension, lwe_dimension,
-                                   polynomial_size, level_count_cbs,
-                                   number_of_inputs) +
-        get_buffer_size_bootstrap_amortized<Torus>(
-            glwe_dimension, polynomial_size, number_of_inputs * level_count_cbs,
-            max_shared_memory) +
-        get_buffer_size_cmux_tree<Torus>(glwe_dimension, polynomial_size,
-                                         level_count_cbs, r, tau,
-                                         max_shared_memory) +
-        get_buffer_size_blind_rotation_sample_extraction<Torus>(
-            glwe_dimension, polynomial_size, level_count_cbs, mbr_size, tau,
-            max_shared_memory);
     *cbs_vp_buffer =
         (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index);
   }

@@ -82,10 +82,14 @@ __host__ void scratch_circuit_bootstrap_vertical_packing(
   for (uint index = 0; index < level_count_cbs * number_of_inputs; index++) {
     h_lut_vector_indexes[index] = index % level_count_cbs;
   }
-  // lut_vector_indexes is the first buffer in the cbs_vp_buffer
-  cuda_memcpy_async_to_gpu((Torus *)*cbs_vp_buffer, h_lut_vector_indexes,
-                           number_of_inputs * level_count_cbs * sizeof(Torus),
-                           stream, gpu_index);
+  // lut_vector_indexes is the last buffer in the cbs_vp_buffer
+  int lut_vector_indexes_size =
+      number_of_inputs * level_count_cbs * sizeof(Torus);
+  int8_t *d_lut_vector_indexes =
+      (int8_t *)*cbs_vp_buffer +
+      (ptrdiff_t)(buffer_size - lut_vector_indexes_size);
+  cuda_memcpy_async_to_gpu((Torus *)d_lut_vector_indexes, h_lut_vector_indexes,
+                           lut_vector_indexes_size, stream, gpu_index);
   check_cuda_error(cudaStreamSynchronize(*stream));
   free(h_lut_vector_indexes);
   check_cuda_error(cudaGetLastError());

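Moving lut_vector_indexes from the front of cbs_vp_buffer to its tail is the same rule applied to the weakest-aligned buffer: the index table only needs Torus alignment, so it now lives in the last number_of_inputs * level_count_cbs * sizeof(Torus) bytes, which is also why buffer_size had to move out of the if (allocate_gpu_memory) block in the previous hunk. A sketch of the tail placement (illustration only, placeholder values):

// Sketch: placing the weakly aligned index table at the tail of the pooled
// allocation, mirroring the new scratch code. Values are placeholders.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

int main() {
  typedef uint64_t Torus; // the 64-bit torus these kernels operate on
  const uint32_t number_of_inputs = 10, level_count_cbs = 4;
  const size_t buffer_size = 1 << 16; // placeholder for the summed, padded sizes

  int8_t *cbs_vp_buffer = (int8_t *)std::malloc(buffer_size);

  // Same arithmetic as the diff: the index table occupies the last
  // number_of_inputs * level_count_cbs * sizeof(Torus) bytes.
  size_t lut_vector_indexes_size =
      number_of_inputs * level_count_cbs * sizeof(Torus);
  int8_t *d_lut_vector_indexes =
      cbs_vp_buffer + (ptrdiff_t)(buffer_size - lut_vector_indexes_size);

  // The table ends exactly at the end of the allocation, so the
  // double2-aligned buffers at the front of the pool are untouched.
  assert(d_lut_vector_indexes + lut_vector_indexes_size ==
         cbs_vp_buffer + buffer_size);

  std::free(cbs_vp_buffer);
  return 0;
}
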
@@ -117,13 +121,16 @@ __host__ void host_circuit_bootstrap_vertical_packing(
     uint32_t level_count_pksk, uint32_t base_log_cbs, uint32_t level_count_cbs,
     uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory) {

+  // Define the buffers
+  // Always define the buffers with strongest memory alignment requirement first
+  // Here the only requirement is that lut_vector_indexes should be defined
+  // last, since all the other buffers are aligned with double2 (all buffers
+  // with a size that's a multiple of polynomial_size * sizeof(Torus) are
+  // aligned with double2)
   int ggsw_size = level_count_cbs * (glwe_dimension + 1) *
                   (glwe_dimension + 1) * polynomial_size;

-  Torus *lut_vector_indexes = (Torus *)cbs_vp_buffer;
-  int8_t *cbs_buffer =
-      (int8_t *)lut_vector_indexes +
-      (ptrdiff_t)(number_of_inputs * level_count_cbs * sizeof(Torus));
+  int8_t *cbs_buffer = (int8_t *)cbs_vp_buffer;
   int8_t *ggsw_out_cbs =
       cbs_buffer +
       (ptrdiff_t)(get_buffer_size_cbs<Torus>(glwe_dimension, lwe_dimension,

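The parenthetical claim in the new comment, that any buffer whose size is a multiple of polynomial_size * sizeof(Torus) is double2-aligned, holds because polynomial_size is a power of two of at least 2 and Torus is 8 bytes wide here, making every such size (and hence every offset built from them) a multiple of 16. A quick host-side check (illustration only):

// Sketch: offsets that are multiples of polynomial_size * sizeof(Torus) are
// always double2-aligned for the power-of-two polynomial sizes used here.
#include <cassert>
#include <cstdint>

int main() {
  typedef uint64_t Torus; // 8 bytes wide
  for (uint32_t polynomial_size = 2; polynomial_size <= 16384;
       polynomial_size *= 2) {
    for (uint64_t k = 0; k < 1000; k++) {
      uint64_t offset = k * polynomial_size * sizeof(Torus);
      assert(offset % 16 == 0); // safe to place a double2 buffer here
    }
  }
  return 0;
}
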
@@ -132,14 +139,6 @@ __host__ void host_circuit_bootstrap_vertical_packing(
           get_buffer_size_bootstrap_amortized<Torus>(
               glwe_dimension, polynomial_size,
               number_of_inputs * level_count_cbs, max_shared_memory));
-  host_circuit_bootstrap<Torus, params>(
-      v_stream, gpu_index, (Torus *)ggsw_out_cbs, lwe_array_in, fourier_bsk,
-      cbs_fpksk, lut_vector_indexes, cbs_buffer, cbs_delta_log, polynomial_size,
-      glwe_dimension, lwe_dimension, level_count_bsk, base_log_bsk,
-      level_count_pksk, base_log_pksk, level_count_cbs, base_log_cbs,
-      number_of_inputs, max_shared_memory);
-  check_cuda_error(cudaGetLastError());
-
   // number_of_inputs = tau * p is the total number of GGSWs
   // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the
   // lsb GGSW is for the last blind rotation.

@@ -150,6 +149,25 @@ __host__ void host_circuit_bootstrap_vertical_packing(
       cmux_tree_buffer + (ptrdiff_t)(get_buffer_size_cmux_tree<Torus>(
                              glwe_dimension, polynomial_size, level_count_cbs,
                              r, tau, max_shared_memory));
+  int8_t *br_se_buffer =
+      glwe_array_out_cmux_tree +
+      (ptrdiff_t)(tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus));
+  Torus *lut_vector_indexes =
+      (Torus *)br_se_buffer +
+      (ptrdiff_t)(get_buffer_size_blind_rotation_sample_extraction<Torus>(
+                      glwe_dimension, polynomial_size, level_count_cbs,
+                      number_of_inputs - r, tau, max_shared_memory) /
+                  sizeof(Torus));
+
+  // Circuit bootstrap
+  host_circuit_bootstrap<Torus, params>(
+      v_stream, gpu_index, (Torus *)ggsw_out_cbs, lwe_array_in, fourier_bsk,
+      cbs_fpksk, lut_vector_indexes, cbs_buffer, cbs_delta_log, polynomial_size,
+      glwe_dimension, lwe_dimension, level_count_bsk, base_log_bsk,
+      level_count_pksk, base_log_pksk, level_count_cbs, base_log_cbs,
+      number_of_inputs, max_shared_memory);
+  check_cuda_error(cudaGetLastError());
+
   // CMUX Tree
   // r = tau * p - log2(N)
   host_cmux_tree<Torus, STorus, params>(

@@ -161,12 +179,11 @@ __host__ void host_circuit_bootstrap_vertical_packing(

   // Blind rotation + sample extraction
   // mbr = tau * p - r = log2(N)
+  // br_ggsw is a pointer to a sub-part of the ggsw_out_cbs buffer, for the
+  // blind rotation
   Torus *br_ggsw = (Torus *)ggsw_out_cbs +
                    (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
                                (glwe_dimension + 1) * polynomial_size);
-  int8_t *br_se_buffer =
-      glwe_array_out_cmux_tree +
-      (ptrdiff_t)(tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus));
   host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
       v_stream, gpu_index, lwe_array_out, br_ggsw,
       (Torus *)glwe_array_out_cmux_tree, br_se_buffer, number_of_inputs - r,

@@ -179,8 +196,10 @@ __host__ __device__ int
 get_buffer_size_wop_pbs(uint32_t lwe_dimension,
                         uint32_t number_of_bits_of_message_including_padding) {

-  return (lwe_dimension + 1) * (number_of_bits_of_message_including_padding) *
-         sizeof(Torus); // lwe_array_out_bit_extract
+  int buffer_size = (lwe_dimension + 1) *
+                    (number_of_bits_of_message_including_padding) *
+                    sizeof(Torus); // lwe_array_out_bit_extract
+  return buffer_size + buffer_size % sizeof(double2);
 }

 template <typename Torus, typename STorus, typename params>