mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
refactor(cuda): introduce scratch for blind rotation and sample extraction
This commit is contained in:
@@ -38,11 +38,26 @@ void cuda_cmux_tree_64(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void cleanup_cuda_cmux_tree(void *v_stream, uint32_t gpu_index,
|
||||
int8_t **cmux_tree_buffer);
|
||||
|
||||
void scratch_cuda_blind_rotation_sample_extraction_32(
|
||||
void *v_stream, uint32_t gpu_index, int8_t **br_se_buffer,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory,
|
||||
bool allocate_gpu_memory);
|
||||
|
||||
void scratch_cuda_blind_rotation_sample_extraction_64(
|
||||
void *v_stream, uint32_t gpu_index, int8_t **br_se_buffer,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory,
|
||||
bool allocate_gpu_memory);
|
||||
|
||||
void cuda_blind_rotate_and_sample_extraction_64(
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
|
||||
uint32_t max_shared_memory);
|
||||
}
|
||||
void *lut_vector, int8_t *br_se_buffer, uint32_t mbr_size, uint32_t tau,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t max_shared_memory);
|
||||
|
||||
void cleanup_cuda_blind_rotation_sample_extraction(void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
int8_t **br_se_buffer);
|
||||
}
|
||||
#endif // VERTICAL_PACKING_H
|
||||
|
||||
@@ -239,6 +239,90 @@ void cleanup_cuda_cmux_tree(void *v_stream, uint32_t gpu_index,
|
||||
cuda_drop_async(*cmux_tree_buffer, stream, gpu_index);
|
||||
}
|
||||
|
||||
/*
|
||||
* This scratch function allocates the necessary amount of data on the GPU for
|
||||
* the Cmux tree on 32 bits inputs, into `br_se_buffer`. It also configures
|
||||
* SM options on the GPU in case FULLSM mode is going to be used.
|
||||
*/
|
||||
void scratch_cuda_blind_rotation_sample_extraction_32(
|
||||
void *v_stream, uint32_t gpu_index, int8_t **br_se_buffer,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory,
|
||||
bool allocate_gpu_memory) {
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
scratch_blind_rotation_sample_extraction<uint32_t, int32_t, Degree<512>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 1024:
|
||||
scratch_blind_rotation_sample_extraction<uint32_t, int32_t, Degree<1024>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 2048:
|
||||
scratch_blind_rotation_sample_extraction<uint32_t, int32_t, Degree<2048>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 4096:
|
||||
scratch_blind_rotation_sample_extraction<uint32_t, int32_t, Degree<4096>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 8192:
|
||||
scratch_blind_rotation_sample_extraction<uint32_t, int32_t, Degree<8192>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* This scratch function allocates the necessary amount of data on the GPU for
|
||||
* the Cmux tree on 64 bits inputs, into `br_se_buffer`. It also configures
|
||||
* SM options on the GPU in case FULLSM mode is going to be used.
|
||||
*/
|
||||
void scratch_cuda_blind_rotation_sample_extraction_64(
|
||||
void *v_stream, uint32_t gpu_index, int8_t **br_se_buffer,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory,
|
||||
bool allocate_gpu_memory) {
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
scratch_blind_rotation_sample_extraction<uint64_t, int64_t, Degree<512>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 1024:
|
||||
scratch_blind_rotation_sample_extraction<uint64_t, int64_t, Degree<1024>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 2048:
|
||||
scratch_blind_rotation_sample_extraction<uint64_t, int64_t, Degree<2048>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 4096:
|
||||
scratch_blind_rotation_sample_extraction<uint64_t, int64_t, Degree<4096>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
case 8192:
|
||||
scratch_blind_rotation_sample_extraction<uint64_t, int64_t, Degree<8192>>(
|
||||
v_stream, gpu_index, br_se_buffer, glwe_dimension, polynomial_size,
|
||||
level_count, mbr_size, tau, max_shared_memory, allocate_gpu_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Performs blind rotation on batch of 64-bit input ggsw ciphertexts
|
||||
* - `v_stream` is a void pointer to the Cuda stream to be used in the kernel
|
||||
@@ -264,40 +348,52 @@ void cleanup_cuda_cmux_tree(void *v_stream, uint32_t gpu_index,
|
||||
*/
|
||||
void cuda_blind_rotate_and_sample_extraction_64(
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
|
||||
uint32_t max_shared_memory) {
|
||||
void *lut_vector, int8_t *br_se_buffer, uint32_t mbr_size, uint32_t tau,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t max_shared_memory) {
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<512>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
(uint64_t *)lut_vector, br_se_buffer, mbr_size, tau, glwe_dimension,
|
||||
polynomial_size, base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<1024>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
(uint64_t *)lut_vector, br_se_buffer, mbr_size, tau, glwe_dimension,
|
||||
polynomial_size, base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<2048>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
(uint64_t *)lut_vector, br_se_buffer, mbr_size, tau, glwe_dimension,
|
||||
polynomial_size, base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<4096>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
(uint64_t *)lut_vector, br_se_buffer, mbr_size, tau, glwe_dimension,
|
||||
polynomial_size, base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<8192>>(
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
(uint64_t *)lut_vector, br_se_buffer, mbr_size, tau, glwe_dimension,
|
||||
polynomial_size, base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* This cleanup function frees the data for the blind rotation and sample
|
||||
* extraction on GPU in br_se_buffer for 32 or 64 bits inputs.
|
||||
*/
|
||||
void cleanup_cuda_blind_rotation_sample_extraction(void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
int8_t **br_se_buffer) {
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
// Free memory
|
||||
cuda_drop_async(*br_se_buffer, stream, gpu_index);
|
||||
}
|
||||
|
||||
@@ -502,34 +502,55 @@ __global__ void device_blind_rotation_and_sample_extraction(
|
||||
sample_extract_body<Torus, params>(block_lwe_out, accumulator_c0, 1);
|
||||
}
|
||||
|
||||
template <typename Torus, typename STorus, class params>
|
||||
__host__ void host_blind_rotate_and_sample_extraction(
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_out, Torus *ggsw_in,
|
||||
Torus *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t max_shared_memory) {
|
||||
template <typename Torus>
|
||||
__host__ __device__ int
|
||||
get_memory_needed_per_block_blind_rotation_sample_extraction(
|
||||
uint32_t polynomial_size) {
|
||||
return sizeof(Torus) * polynomial_size + // accumulator_c0 mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c0 body
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c1 mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c1 body
|
||||
sizeof(Torus) * polynomial_size + // glwe_sub_mask
|
||||
sizeof(Torus) * polynomial_size + // glwe_sub_body
|
||||
sizeof(double2) * polynomial_size / 2 + // mask_res_fft
|
||||
sizeof(double2) * polynomial_size / 2 + // body_res_fft
|
||||
sizeof(double2) * polynomial_size / 2; // glwe_fft
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ __device__ int get_buffer_size_blind_rotation_sample_extraction(
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory) {
|
||||
|
||||
int memory_needed_per_block =
|
||||
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
|
||||
polynomial_size);
|
||||
int device_mem = 0;
|
||||
if (max_shared_memory < memory_needed_per_block) {
|
||||
device_mem = memory_needed_per_block * tau;
|
||||
}
|
||||
if (max_shared_memory < polynomial_size * sizeof(double)) {
|
||||
device_mem += polynomial_size * sizeof(double);
|
||||
}
|
||||
int ggsw_size = polynomial_size * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * level_count;
|
||||
return mbr_size * ggsw_size * sizeof(double) // d_ggsw_fft_in
|
||||
+ device_mem;
|
||||
}
|
||||
|
||||
template <typename Torus, typename STorus, typename params>
|
||||
__host__ void scratch_blind_rotation_sample_extraction(
|
||||
void *v_stream, uint32_t gpu_index, int8_t **br_se_buffer,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t max_shared_memory,
|
||||
bool allocate_gpu_memory) {
|
||||
cudaSetDevice(gpu_index);
|
||||
assert(glwe_dimension ==
|
||||
1); // For larger k we will need to adjust the mask size
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int memory_needed_per_block =
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c0 mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c0 body
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c1 mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator_c1 body
|
||||
sizeof(Torus) * polynomial_size + // glwe_sub_mask
|
||||
sizeof(Torus) * polynomial_size + // glwe_sub_body
|
||||
sizeof(double2) * polynomial_size / 2 + // mask_res_fft
|
||||
sizeof(double2) * polynomial_size / 2 + // body_res_fft
|
||||
sizeof(double2) * polynomial_size / 2; // glwe_fft
|
||||
|
||||
int8_t *d_mem;
|
||||
if (max_shared_memory < memory_needed_per_block)
|
||||
d_mem = (int8_t *)cuda_malloc_async(memory_needed_per_block * tau, stream,
|
||||
gpu_index);
|
||||
else {
|
||||
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
|
||||
polynomial_size);
|
||||
if (max_shared_memory >= memory_needed_per_block) {
|
||||
check_cuda_error(cudaFuncSetAttribute(
|
||||
device_blind_rotation_and_sample_extraction<Torus, STorus, params,
|
||||
FULLSM>,
|
||||
@@ -540,21 +561,47 @@ __host__ void host_blind_rotate_and_sample_extraction(
|
||||
cudaFuncCachePreferShared));
|
||||
}
|
||||
|
||||
// Applying the FFT on m^br
|
||||
if (allocate_gpu_memory) {
|
||||
int buffer_size = get_buffer_size_blind_rotation_sample_extraction<Torus>(
|
||||
glwe_dimension, polynomial_size, level_count, mbr_size, tau,
|
||||
max_shared_memory);
|
||||
*br_se_buffer = (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus, typename STorus, class params>
|
||||
__host__ void host_blind_rotate_and_sample_extraction(
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_out, Torus *ggsw_in,
|
||||
Torus *lut_vector, int8_t *br_se_buffer, uint32_t mbr_size, uint32_t tau,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t max_shared_memory) {
|
||||
|
||||
cudaSetDevice(gpu_index);
|
||||
assert(glwe_dimension ==
|
||||
1); // For larger k we will need to adjust the mask size
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int memory_needed_per_block =
|
||||
get_memory_needed_per_block_blind_rotation_sample_extraction<Torus>(
|
||||
polynomial_size);
|
||||
|
||||
// Prepare the buffers
|
||||
int ggsw_size = polynomial_size * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * level_count;
|
||||
double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
|
||||
mbr_size * ggsw_size * sizeof(double), stream, gpu_index);
|
||||
|
||||
int8_t *d_mem_fft = (int8_t *)cuda_malloc_async(
|
||||
polynomial_size * sizeof(double), stream, gpu_index);
|
||||
double2 *d_ggsw_fft_in = (double2 *)br_se_buffer;
|
||||
int8_t *d_mem_fft = (int8_t *)d_ggsw_fft_in +
|
||||
(ptrdiff_t)(mbr_size * ggsw_size * sizeof(double));
|
||||
int8_t *d_mem = d_mem_fft;
|
||||
if (max_shared_memory < polynomial_size * sizeof(double)) {
|
||||
d_mem = d_mem_fft + (ptrdiff_t)(polynomial_size * sizeof(double));
|
||||
}
|
||||
// Apply the FFT on m^br
|
||||
batch_fft_ggsw_vector<Torus, STorus, params>(
|
||||
stream, d_ggsw_fft_in, ggsw_in, d_mem_fft, mbr_size, glwe_dimension,
|
||||
polynomial_size, level_count, gpu_index, max_shared_memory);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_drop_async(d_mem_fft, stream, gpu_index);
|
||||
|
||||
//
|
||||
dim3 thds(polynomial_size / params::opt, 1, 1);
|
||||
dim3 grid(tau, 1, 1);
|
||||
|
||||
@@ -573,10 +620,5 @@ __host__ void host_blind_rotate_and_sample_extraction(
|
||||
polynomial_size, base_log, level_count, memory_needed_per_block,
|
||||
d_mem);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
//
|
||||
cuda_drop_async(d_ggsw_fft_in, stream, gpu_index);
|
||||
if (max_shared_memory < memory_needed_per_block)
|
||||
cuda_drop_async(d_mem, stream, gpu_index);
|
||||
}
|
||||
#endif // VERTICAL_PACKING_CUH
|
||||
|
||||
@@ -65,6 +65,7 @@ __host__ void scratch_circuit_bootstrap_vertical_packing(
|
||||
Torus *h_lut_vector_indexes =
|
||||
(Torus *)malloc(number_of_inputs * level_count_cbs * sizeof(Torus));
|
||||
uint32_t r = number_of_inputs - params::log2_degree;
|
||||
uint32_t mbr_size = number_of_inputs - r;
|
||||
// allocate and initialize device pointers for circuit bootstrap and vertical
|
||||
// packing
|
||||
if (allocate_gpu_memory) {
|
||||
@@ -74,7 +75,10 @@ __host__ void scratch_circuit_bootstrap_vertical_packing(
|
||||
number_of_inputs, tau) +
|
||||
get_buffer_size_cmux_tree<Torus>(glwe_dimension, polynomial_size,
|
||||
level_count_cbs, r, tau,
|
||||
max_shared_memory);
|
||||
max_shared_memory) +
|
||||
get_buffer_size_blind_rotation_sample_extraction<Torus>(
|
||||
glwe_dimension, polynomial_size, level_count_cbs, mbr_size, tau,
|
||||
max_shared_memory);
|
||||
*cbs_vp_buffer =
|
||||
(int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index);
|
||||
}
|
||||
@@ -95,6 +99,9 @@ __host__ void scratch_circuit_bootstrap_vertical_packing(
|
||||
scratch_cmux_tree<Torus, STorus, params>(
|
||||
v_stream, gpu_index, cbs_vp_buffer, glwe_dimension, polynomial_size,
|
||||
level_count_cbs, r, tau, max_shared_memory, false);
|
||||
scratch_blind_rotation_sample_extraction<Torus, STorus, params>(
|
||||
v_stream, gpu_index, cbs_vp_buffer, glwe_dimension, polynomial_size,
|
||||
level_count_cbs, mbr_size, tau, max_shared_memory, false);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -166,7 +173,7 @@ __host__ void host_circuit_bootstrap_vertical_packing(
|
||||
uint32_t r = number_of_inputs - params::log2_degree;
|
||||
int8_t *cmux_tree_buffer =
|
||||
(int8_t *)glwe_array_out +
|
||||
tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus);
|
||||
(ptrdiff_t)(tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus));
|
||||
// CMUX Tree
|
||||
// r = tau * p - log2(N)
|
||||
host_cmux_tree<Torus, STorus, params>(
|
||||
@@ -180,8 +187,12 @@ __host__ void host_circuit_bootstrap_vertical_packing(
|
||||
Torus *br_ggsw = (Torus *)ggsw_out +
|
||||
(ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * polynomial_size);
|
||||
int8_t *br_se_buffer =
|
||||
cmux_tree_buffer + (ptrdiff_t)(get_buffer_size_cmux_tree<Torus>(
|
||||
glwe_dimension, polynomial_size, level_count_cbs,
|
||||
r, tau, max_shared_memory));
|
||||
host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
|
||||
v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
|
||||
v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out, br_se_buffer,
|
||||
number_of_inputs - r, tau, glwe_dimension, polynomial_size, base_log_cbs,
|
||||
level_count_cbs, max_shared_memory);
|
||||
}
|
||||
@@ -225,6 +236,7 @@ scratch_wop_pbs(void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer,
|
||||
number_of_inputs * number_of_bits_to_extract;
|
||||
uint32_t tau = number_of_inputs;
|
||||
uint32_t r = cbs_vp_number_of_inputs - params::log2_degree;
|
||||
uint32_t mbr_size = cbs_vp_number_of_inputs - r;
|
||||
int buffer_size =
|
||||
get_buffer_size_cbs_vp<Torus>(glwe_dimension, lwe_dimension,
|
||||
polynomial_size, level_count_cbs,
|
||||
@@ -232,6 +244,9 @@ scratch_wop_pbs(void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer,
|
||||
get_buffer_size_cmux_tree<Torus>(glwe_dimension, polynomial_size,
|
||||
level_count_cbs, r, tau,
|
||||
max_shared_memory) +
|
||||
get_buffer_size_blind_rotation_sample_extraction<Torus>(
|
||||
glwe_dimension, polynomial_size, level_count_cbs, mbr_size, tau,
|
||||
max_shared_memory) +
|
||||
wop_pbs_buffer_size;
|
||||
|
||||
*wop_pbs_buffer = (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index);
|
||||
|
||||
Reference in New Issue
Block a user