refactor(cuda): Implements support for k > 1 on cbs + vp.

This commit is contained in:
Pedro Alves
2023-02-22 16:55:14 -03:00
committed by Agnès Leroy
parent 7896d96a49
commit b0a362af6d
2 changed files with 52 additions and 40 deletions

View File

@@ -45,15 +45,15 @@ __global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
*/
template <typename Torus, class params>
__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
uint32_t base_log_cbs) {
uint32_t base_log_cbs,
uint32_t glwe_dimension) {
Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
Torus *cur_body = &lut[(blockIdx.x * (glwe_dimension + 1) + glwe_dimension) *
params::degree];
size_t tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
cur_mask[tid] = 0;
cur_poly[tid] =
cur_body[tid] =
0ll -
(1ll << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1)));
tid += params::degree / params::opt;
@@ -76,22 +76,27 @@ __global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
template <typename Torus, class params>
__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src,
uint32_t ciphertext_n_bits,
uint32_t base_log_cbs, uint32_t level_cbs) {
uint32_t base_log_cbs, uint32_t level_cbs,
uint32_t glwe_dimension) {
size_t tid = threadIdx.x;
size_t src_lwe_id = blockIdx.x / (glwe_dimension + 1);
size_t dst_lwe_id = blockIdx.x;
size_t src_lwe_id = dst_lwe_id / 2;
size_t cur_cbs_level = src_lwe_id % level_cbs + 1;
auto cur_src = &lwe_src[src_lwe_id * (params::degree + 1)];
auto cur_dst = &lwe_dst[dst_lwe_id * (params::degree + 1)];
auto cur_src = &lwe_src[src_lwe_id * (glwe_dimension * params::degree + 1)];
auto cur_dst = &lwe_dst[dst_lwe_id * (glwe_dimension * params::degree + 1)];
auto cur_src_slice = cur_src + blockIdx.y * params::degree;
auto cur_dst_slice = cur_dst + blockIdx.y * params::degree;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
cur_dst[tid] = cur_src[tid];
cur_dst_slice[tid] = cur_src_slice[tid];
tid += params::degree / params::opt;
}
Torus val = 1ll << (ciphertext_n_bits - 1 - base_log_cbs * cur_cbs_level);
if (threadIdx.x == 0) {
cur_dst[params::degree] = cur_src[params::degree] + val;
if (threadIdx.x == 0 && blockIdx.y == 0) {
cur_dst[glwe_dimension * params::degree] =
cur_src[glwe_dimension * params::degree] + val;
}
}
@@ -102,9 +107,10 @@ get_buffer_size_cbs(uint32_t glwe_dimension, uint32_t lwe_dimension,
uint32_t number_of_inputs) {
return number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
(polynomial_size + 1) *
(glwe_dimension * polynomial_size + 1) *
sizeof(Torus) + // lwe_array_in_fp_ks_buffer
number_of_inputs * level_count_cbs * (polynomial_size + 1) *
number_of_inputs * level_count_cbs *
(glwe_dimension * polynomial_size + 1) *
sizeof(Torus) + // lwe_array_out_pbs_buffer
number_of_inputs * level_count_cbs * (lwe_dimension + 1) *
sizeof(Torus) + // lwe_array_in_shifted_buffer
@@ -174,7 +180,8 @@ __host__ void host_circuit_bootstrap(
sizeof(Torus));
Torus *lwe_array_in_shifted_buffer =
lwe_array_out_pbs_buffer +
(ptrdiff_t)(number_of_inputs * level_cbs * (polynomial_size + 1));
(ptrdiff_t)(number_of_inputs * level_cbs *
(glwe_dimension * polynomial_size + 1));
Torus *lut_vector =
lwe_array_in_shifted_buffer +
(ptrdiff_t)(number_of_inputs * level_cbs * (lwe_dimension + 1));
@@ -195,9 +202,13 @@ __host__ void host_circuit_bootstrap(
// Fill lut (equivalent to trivial encryption as mask is 0s)
// The LUT is filled with -alpha in each coefficient where
// alpha = 2^{log(q) - 1 - base_log * level}
check_cuda_error(cudaMemsetAsync(lut_vector, 0,
level_cbs * (glwe_dimension + 1) *
polynomial_size * sizeof(Torus),
*stream));
fill_lut_body_for_cbs<Torus, params>
<<<level_cbs, params::degree / params::opt, 0, *stream>>>(
lut_vector, ciphertext_n_bits, base_log_cbs);
lut_vector, ciphertext_n_bits, base_log_cbs, glwe_dimension);
// Applying a negacyclic LUT on a ciphertext with one bit of message in the
// MSB and no bit of padding
@@ -207,18 +218,19 @@ __host__ void host_circuit_bootstrap(
glwe_dimension, lwe_dimension, polynomial_size, base_log_bsk, level_bsk,
pbs_count, level_cbs, 0, max_shared_memory);
dim3 copy_grid(pbs_count * (glwe_dimension + 1), 1, 1);
dim3 copy_grid(pbs_count * (glwe_dimension + 1), glwe_dimension, 1);
dim3 copy_block(params::degree / params::opt, 1, 1);
// Add q/4 to center the error while computing a negacyclic LUT
// copy pbs result (glwe_dimension + 1) times to be an input of fp-ks
copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer, ciphertext_n_bits,
base_log_cbs, level_cbs);
base_log_cbs, level_cbs, glwe_dimension);
cuda_fp_keyswitch_lwe_to_glwe(
v_stream, gpu_index, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
glwe_dimension * polynomial_size, glwe_dimension, polynomial_size,
base_log_pksk, level_pksk, pbs_count * (glwe_dimension + 1),
glwe_dimension + 1);
}
#endif // CBS_CUH

View File

@@ -420,14 +420,16 @@ __global__ void device_blind_rotation_and_sample_extraction(
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
Torus *accumulator_c0 = (Torus *)selected_memory;
Torus *accumulator_c1 = (Torus *)accumulator_c0 + 2 * polynomial_size;
Torus *accumulator_c1 =
(Torus *)accumulator_c0 + (glwe_dim + 1) * polynomial_size;
int8_t *cmux_memory =
(int8_t *)(accumulator_c1 + (glwe_dim + 1) * polynomial_size);
// Input LUT
auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
int tid = threadIdx.x;
for (int i = 0; i < params::opt; i++) {
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
accumulator_c0[tid] = mi[tid];
accumulator_c0[tid + params::degree] = mi[tid + params::degree];
tid += params::degree / params::opt;
}
@@ -436,45 +438,43 @@ __global__ void device_blind_rotation_and_sample_extraction(
synchronize_threads_in_block();
// Compute x^ai * ACC
// Body
// Mask and Body
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_c1, accumulator_c0, (1 << monomial_degree), false, 1);
// Mask
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_c1 + polynomial_size, accumulator_c0 + polynomial_size,
(1 << monomial_degree), false, 1);
accumulator_c1, accumulator_c0, (1 << monomial_degree), false,
(glwe_dim + 1));
monomial_degree += 1;
// ACC = CMUX ( Ci, x^ai * ACC, ACC )
synchronize_threads_in_block();
cmux<Torus, STorus, params>(
accumulator_c0, accumulator_c0, ggsw_in,
(int8_t *)(accumulator_c0 + 4 * polynomial_size), 0, 0, 1, glwe_dim,
polynomial_size, base_log, level_count, i);
cmux<Torus, STorus, params>(accumulator_c0, accumulator_c0, ggsw_in,
cmux_memory, 0, 0, 1, glwe_dim, polynomial_size,
base_log, level_count, i);
}
synchronize_threads_in_block();
// Write the output
auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];
auto block_lwe_out = &lwe_out[blockIdx.x * (glwe_dim * polynomial_size + 1)];
// The blind rotation for this block is over
// Now we can perform the sample extraction: for the body it's just
// the resulting constant coefficient of the accumulator
// For the mask it's more complicated
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_c0, 1);
sample_extract_body<Torus, params>(block_lwe_out, accumulator_c0, 1);
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_c0, glwe_dim);
sample_extract_body<Torus, params>(block_lwe_out, accumulator_c0, glwe_dim);
}
template <typename Torus>
__host__ __device__ int
get_memory_needed_per_block_blind_rotation_sample_extraction(
uint32_t glwe_dimension, uint32_t polynomial_size) {
return sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c0
sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c1
+ get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension, polynomial_size);
return sizeof(Torus) * polynomial_size *
(glwe_dimension + 1) + // accumulator_c0
sizeof(Torus) * polynomial_size *
(glwe_dimension + 1) + // accumulator_c1
+get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension,
polynomial_size);
}
template <typename Torus>