mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
fix(cuda): Fix the CUDA test for CBS+VP when tau > 1 and tau * p == log2(N).
This commit is contained in:
@@ -276,8 +276,8 @@ void host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
|
||||
int num_lut = (1 << r);
|
||||
if (r == 0) {
|
||||
// Simply copy the LUTs
|
||||
add_padding_to_lut_async<Torus, params>(
|
||||
glwe_array_out, lut_vector, glwe_dimension, num_lut * tau, stream);
|
||||
add_padding_to_lut_async<Torus, params>(glwe_array_out, lut_vector,
|
||||
glwe_dimension, tau, stream);
|
||||
checkCudaErrors(cudaStreamSynchronize(*stream));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ __host__ void host_circuit_bootstrap_vertical_packing(
|
||||
Torus *cbs_fpksk, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_pksk, uint32_t level_count_pksk, uint32_t base_log_cbs,
|
||||
uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t lut_number,
|
||||
uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t tau,
|
||||
uint32_t max_shared_memory) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
@@ -100,46 +100,31 @@ __host__ void host_circuit_bootstrap_vertical_packing(
|
||||
free(h_lut_vector_indexes);
|
||||
|
||||
// number_of_inputs = tau * p is the total number of GGSWs
|
||||
if (number_of_inputs > params::log2_degree) {
|
||||
// split the vec of GGSW in two, the msb GGSW is for the CMux tree and the
|
||||
// lsb GGSW is for the last blind rotation.
|
||||
uint32_t r = number_of_inputs - params::log2_degree;
|
||||
Torus *br_ggsw = (Torus *)ggsw_out +
|
||||
(ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * polynomial_size);
|
||||
Torus *glwe_array_out = (Torus *)cuda_malloc_async(
|
||||
lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
stream, gpu_index);
|
||||
// CMUX Tree
|
||||
// r = tau * p - log2(N)
|
||||
host_cmux_tree<Torus, STorus, params>(
|
||||
v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector,
|
||||
glwe_dimension, polynomial_size, base_log_cbs, level_count_cbs, r,
|
||||
lut_number, max_shared_memory);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
cuda_drop_async(glwe_array_out, stream, gpu_index);
|
||||
// split the vec of GGSW in two, the msb GGSW is for the CMux tree and the
|
||||
// lsb GGSW is for the last blind rotation.
|
||||
uint32_t r = number_of_inputs - params::log2_degree;
|
||||
Torus *glwe_array_out = (Torus *)cuda_malloc_async(
|
||||
tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus), stream,
|
||||
gpu_index);
|
||||
// CMUX Tree
|
||||
// r = tau * p - log2(N)
|
||||
host_cmux_tree<Torus, STorus, params>(
|
||||
v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector, glwe_dimension,
|
||||
polynomial_size, base_log_cbs, level_count_cbs, r, tau,
|
||||
max_shared_memory);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
|
||||
// Blind rotation + sample extraction
|
||||
// mbr = tau * p - r = log2(N)
|
||||
host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
|
||||
v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
|
||||
number_of_inputs - r, lut_number, glwe_dimension, polynomial_size,
|
||||
base_log_cbs, level_count_cbs, max_shared_memory);
|
||||
} else {
|
||||
// we need to expand the lut to fill the masks with zeros
|
||||
Torus *lut_vector_glwe = (Torus *)cuda_malloc_async(
|
||||
lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
stream, gpu_index);
|
||||
add_padding_to_lut_async<Torus, params>(lut_vector_glwe, lut_vector,
|
||||
glwe_dimension, lut_number, stream);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
// Blind rotation + sample extraction
|
||||
// mbr = tau * p - r = log2(N)
|
||||
Torus *br_ggsw = (Torus *)ggsw_out +
|
||||
(ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * polynomial_size);
|
||||
host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
|
||||
v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
|
||||
number_of_inputs - r, tau, glwe_dimension, polynomial_size, base_log_cbs,
|
||||
level_count_cbs, max_shared_memory);
|
||||
|
||||
// Blind rotation + sample extraction
|
||||
host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
|
||||
v_stream, gpu_index, lwe_array_out, ggsw_out, lut_vector_glwe,
|
||||
number_of_inputs, lut_number, glwe_dimension, polynomial_size,
|
||||
base_log_cbs, level_count_cbs, max_shared_memory);
|
||||
}
|
||||
cuda_drop_async(glwe_array_out, stream, gpu_index);
|
||||
cuda_drop_async(ggsw_out, stream, gpu_index);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user