feat(cuda): Add circuit bootstrap in the cuda backend

- Add FP-Keyswitch.
- Add entry points for cuda fp ksk in the public API.
- Add test for fp_ksk in cuda backend.
- Add fixture for bit extract

Co-authored-by: agnesLeroy <agnes.leroy@zama.ai>
This commit is contained in:
Beka Barbakadze
2022-11-03 14:10:57 +04:00
committed by Agnès Leroy
parent e10c2936d1
commit 0aedb1a4f4
6 changed files with 465 additions and 3 deletions

View File

@@ -86,6 +86,25 @@ void cuda_extract_bits_64(
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
uint32_t base_log_ksk, uint32_t level_count_ksk, uint32_t number_of_samples,
uint32_t max_shared_memory);
// Circuit bootstrap on a batch of 32-bit LWE ciphertexts: each input LWE is
// converted on the GPU into a GGSW ciphertext (level_cbs PBS followed by a
// functional packing keyswitch — see the implementation in the .cu file).
// All array arguments are device pointers; the *_buffer arguments are
// caller-allocated scratch space. Work is enqueued on the stream wrapped by
// v_stream.
void cuda_circuit_bootstrap_32(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory);
// Same as cuda_circuit_bootstrap_32, for 64-bit LWE ciphertexts.
void cuda_circuit_bootstrap_64(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory);
}
#ifdef __CUDACC__

View File

@@ -14,6 +14,24 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
// Functional packing keyswitch of a batch of 32-bit LWE ciphertexts into
// GLWE ciphertexts. All array arguments are device pointers. number_of_keys
// fp-ksks are cycled over the inputs (in the kernel, input i uses key
// i % number_of_keys). Work is enqueued on the stream wrapped by v_stream.
void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys);
// Same as cuda_fp_keyswitch_lwe_to_glwe_32, for 64-bit ciphertexts.
void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys);
}
#endif // CNCRT_KS_H_

View File

@@ -343,3 +343,180 @@ void cuda_blind_rotate_and_sample_extraction_64(
break;
}
}
/*
 * C API entry point: circuit bootstrap on a batch of 32-bit LWE ciphertexts.
 *
 * Dispatches on polynomial_size to the templated host implementation
 * (host_circuit_bootstrap), which converts each input LWE into a GGSW
 * ciphertext via level_cbs PBS followed by a functional packing keyswitch.
 * All array arguments are device pointers; work is enqueued on the stream
 * wrapped by v_stream.
 *
 * Preconditions (asserted):
 * - glwe_dimension == 1 (only supported value for now);
 * - polynomial_size in {512, 1024, 2048, 4096, 8192};
 * - number_of_samples <= number_of_sm / (4 * 2 * level_bsk), the occupancy
 *   constraint of the amortized PBS (factor 4 for 50% occupancy, factor 2
 *   for glwe_size = glwe_dimension + 1 = 2).
 */
void cuda_circuit_bootstrap_32(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
          "512, 1024, 2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  int number_of_sm = 0;
  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount, 0);
  // Bug fix: the message used to say "GPU extract bits" / "level_count_bsk"
  // (copy-pasted from the bit extraction entry point).
  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
          "lower or equal to the number of streaming multiprocessors on the "
          "device divided by 8 * level_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
  switch (polynomial_size) {
  case 512:
    host_circuit_bootstrap<uint32_t, Degree<512>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 1024:
    host_circuit_bootstrap<uint32_t, Degree<1024>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 2048:
    host_circuit_bootstrap<uint32_t, Degree<2048>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 4096:
    host_circuit_bootstrap<uint32_t, Degree<4096>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 8192:
    host_circuit_bootstrap<uint32_t, Degree<8192>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  default:
    // Unreachable: polynomial_size was validated by the assert above.
    break;
  }
}
/*
 * C API entry point: circuit bootstrap on a batch of 64-bit LWE ciphertexts.
 *
 * Same contract as cuda_circuit_bootstrap_32, instantiated for uint64_t
 * torus elements (lut_vector_indexes stays uint32_t).
 *
 * Preconditions (asserted):
 * - glwe_dimension == 1 (only supported value for now);
 * - polynomial_size in {512, 1024, 2048, 4096, 8192};
 * - number_of_samples <= number_of_sm / (4 * 2 * level_bsk), the occupancy
 *   constraint of the amortized PBS (factor 4 for 50% occupancy, factor 2
 *   for glwe_size = glwe_dimension + 1 = 2).
 */
void cuda_circuit_bootstrap_64(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
          "512, 1024, 2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  int number_of_sm = 0;
  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount, 0);
  // Bug fix: the message used to say "GPU extract bits" / "level_count_bsk"
  // (copy-pasted from the bit extraction entry point); also dropped a
  // dangling duplicate of the comment above.
  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
          "lower or equal to the number of streaming multiprocessors on the "
          "device divided by 8 * level_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
  switch (polynomial_size) {
  case 512:
    host_circuit_bootstrap<uint64_t, Degree<512>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 1024:
    host_circuit_bootstrap<uint64_t, Degree<1024>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 2048:
    host_circuit_bootstrap<uint64_t, Degree<2048>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 4096:
    host_circuit_bootstrap<uint64_t, Degree<4096>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 8192:
    host_circuit_bootstrap<uint64_t, Degree<8192>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  default:
    // Unreachable: polynomial_size was validated by the assert above.
    break;
  }
}

View File

@@ -5,6 +5,7 @@
#include "../include/helper_cuda.h"
#include "bootstrap.h"
#include "bootstrap_amortized.cuh"
#include "bootstrap_low_latency.cuh"
#include "complex/operations.cuh"
#include "crypto/ggsw.cuh"
@@ -377,6 +378,31 @@ __global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
}
}
// Scalar-multiply LWE ciphertexts, replicating each input across the grid's
// x-dimension. Works for generic (non-power-of-two) lwe_size, but assumes
// blockDim.x is a power of two (the remainder is computed with a mask).
// blockIdx.y selects the input ciphertext; the flattened block id
// (blockIdx.y * gridDim.x + blockIdx.x) selects the output slot — at the
// call site in host_circuit_bootstrap, gridDim.x is the number of cbs
// levels, so each input is written once per level.
template <typename Torus, class params>
__global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
                              size_t lwe_size) {
  size_t out_id = blockIdx.y * gridDim.x + blockIdx.x;
  Torus *out = &dst_shift[out_id * lwe_size];
  Torus *in = &src[blockIdx.y * lwe_size];
  // Each thread strides through the ciphertext blockDim.x apart; the first
  // (lwe_size mod blockDim.x) threads then handle the tail.
  size_t full_rounds = lwe_size / blockDim.x;
  size_t tail = lwe_size & (blockDim.x - 1);
  size_t pos = threadIdx.x;
  for (size_t round = 0; round < full_rounds; round++) {
    out[pos] = in[pos] * value;
    pos += blockDim.x;
  }
  if (threadIdx.x < tail)
    out[pos] = in[pos] * value;
}
// only works for small lwe in ks+bs case
// function copies lwe when length is not a power of two
template <typename Torus>
@@ -413,11 +439,11 @@ __global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
// Fill lut(only body) for the current bit (equivalent to trivial encryption as
// mask is 0s)
// The LUT is filled with -alpha in each coefficient where alpha =
// delta*2^{bit_idx-1}
// The LUT is filled with value
template <typename Torus, class params>
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
Torus *cur_poly = &lut[params::degree];
Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
size_t tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
@@ -426,6 +452,46 @@ __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
}
}
// Fill lut (equivalent to trivial encryption as mask is 0s)
// The LUT is filled with -alpha in each coefficient where
// alpha = 2^{log(q) - 1 - base_log * level}
// Launch: one block per cbs level (blockIdx.x = level index, 0-based),
// params::degree / params::opt threads per block; each LUT entry is a GLWE
// of 2 * params::degree coefficients (mask polynomial + body polynomial).
template <typename Torus, class params>
__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
                                      uint32_t base_log_cbs) {
  Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
  Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
  // Bug fix: the body value was written as `0ll - 1ll << shift`, which
  // parses as `(-1ll) << shift` — a left shift of a negative signed value
  // (undefined behavior before C++20, and shift can reach 63 for a 64-bit
  // Torus). Computing -alpha in unsigned Torus arithmetic is well defined
  // and yields the same two's-complement value.
  Torus alpha = (Torus)1
                << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1));
  Torus body_value = (Torus)0 - alpha;
  size_t tid = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_mask[tid] = 0;
    cur_poly[tid] = body_value;
    tid += params::degree / params::opt;
  }
}
// Replicate each source LWE into consecutive destination slots, adding
// `value` to the body coefficient of every copy.
//
// Expected launch (see host_circuit_bootstrap): gridDim.x = number of
// output LWEs = pbs_count * (glwe_dimension + 1), gridDim.y = number of
// source LWEs = pbs_count, blockDim.x = params::degree / params::opt.
// The LWE size is params::degree + 1 (assumes the PBS output dimension
// equals polynomial_size, i.e. k = 1 — TODO confirm for k > 1).
//
// Bug fix: the source index used to be blockIdx.y, so the gridDim.y blocks
// sharing one blockIdx.x wrote *different* sources into the same
// destination — a data race with an undefined winner. Destination j must be
// a copy of source j / (glwe_dimension + 1); the replication factor is
// recovered from the launch geometry (gridDim.x / gridDim.y) so the kernel
// signature stays unchanged. With the existing launch, blocks that share a
// blockIdx.x now store identical data, which is deterministic.
template <typename Torus, class params>
__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src, Torus value) {
  size_t tid = threadIdx.x;
  size_t copies_per_src = gridDim.x / gridDim.y;
  size_t dst_lwe_id = blockIdx.x;
  size_t src_lwe_id = dst_lwe_id / copies_per_src;
  auto cur_src = &lwe_src[src_lwe_id * (params::degree + 1)];
  auto cur_dst = &lwe_dst[dst_lwe_id * (params::degree + 1)];
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_dst[tid] = cur_src[tid];
    tid += params::degree / params::opt;
  }
  // Single thread per block updates the body (last coefficient).
  if (threadIdx.x == 0) {
    cur_dst[params::degree] = cur_src[params::degree] + value;
  }
}
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0
// if the extracted bit was 0 and 1 in the other case
//
@@ -683,4 +749,64 @@ void host_blind_rotate_and_sample_extraction(
cuda_drop_async(d_mem, *stream, gpu_index);
}
/*
 * Host driver for the circuit bootstrap of a batch of LWE ciphertexts.
 *
 * Pipeline (all enqueued on the caller's stream):
 *  1. shift_lwe_cbs: multiply each input LWE by
 *     2^(n_bits - delta_log - 1), replicated level_cbs times, to move the
 *     single message bit next to the padding bit.
 *  2. add_to_body: add q/4 to center the error for the negacyclic LUT.
 *  3. fill_lut_body_for_cbs: build one trivial LUT per cbs level, with body
 *     coefficients -2^(n_bits - 1 - base_log_cbs * level).
 *  4. host_bootstrap_amortized: one PBS per (sample, level) pair.
 *  5. copy_add_lwe_cbs: replicate each PBS output (glwe_dimension + 1)
 *     times and re-add 2^(n_bits - 1 - base_log_cbs * level_cbs) to the
 *     body, producing the fp-ks inputs.
 *  6. cuda_fp_keyswitch_lwe_to_glwe: pack the results into the output GGSW,
 *     cycling over the glwe_dimension + 1 fp-ksks.
 *
 * All array arguments are caller-allocated device pointers. NOTE(review):
 * steps 5-6 assume the PBS output LWE dimension equals polynomial_size
 * (k = 1) — confirm before supporting glwe_dimension > 1.
 */
template <typename Torus, class params>
__host__ void host_circuit_bootstrap(
    void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in,
    double2 *fourier_bsk, Torus *fp_ksk_array,
    Torus *lwe_array_in_shifted_buffer, Torus *lut_vector,
    uint32_t *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
    Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  auto stream = static_cast<cudaStream_t *>(v_stream);
  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
  uint32_t lwe_size = lwe_dimension + 1;
  int pbs_count = number_of_samples * level_cbs;
  dim3 blocks(level_cbs, number_of_samples, 1);
  int threads = 256;
  // Shift message LSB on padding bit, at this point we expect to have messages
  // with only 1 bit of information
  shift_lwe_cbs<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_array_in_shifted_buffer, lwe_array_in,
      1LL << (ciphertext_n_bits - delta_log - 1), lwe_size);
  // Add q/4 to center the error while computing a negacyclic LUT
  add_to_body<Torus>
      <<<pbs_count, 1, 0, *stream>>>(lwe_array_in_shifted_buffer, lwe_dimension,
                                     1ll << (ciphertext_n_bits - 2));
  // Fill lut (equivalent to trivial encryption as mask is 0s)
  // The LUT is filled with -alpha in each coefficient where
  // alpha = 2^{log(q) - 1 - base_log * level}
  fill_lut_body_for_cbs<Torus, params>
      <<<level_cbs, params::degree / params::opt, 0, *stream>>>(
          lut_vector, ciphertext_n_bits, base_log_cbs);
  // Applying a negacyclic LUT on a ciphertext with one bit of message in the
  // MSB and no bit of padding
  host_bootstrap_amortized<Torus, params>(
      v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_vector,
      lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk,
      lwe_dimension, polynomial_size, base_log_bsk, level_bsk, pbs_count,
      level_cbs, 0, max_shared_memory);
  // NOTE(review): with this grid every destination LWE is visited by
  // gridDim.y blocks — verify copy_add_lwe_cbs' source indexing against this
  // launch geometry.
  dim3 copy_grid(pbs_count * (glwe_dimension + 1), pbs_count, 1);
  dim3 copy_block(params::degree / params::opt, 1, 1);
  // Copy each PBS result (glwe_dimension + 1) times, adding alpha to the
  // body, to build the fp-ks input. Bug fix: the kernel is now launched on
  // the caller's stream (it used to go to the default stream, so it was not
  // ordered with the rest of this pipeline on non-blocking streams).
  copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
      lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer,
      1ll << (ciphertext_n_bits - 1 - base_log_cbs * level_cbs));
  cuda_fp_keyswitch_lwe_to_glwe(
      v_stream, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
      polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
      level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
}
#endif // WO_PBS_H

View File

@@ -43,3 +43,36 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
}
/*
 * C API entry point: functional packing keyswitch from 32-bit LWE
 * ciphertexts to GLWE ciphertexts. Thin wrapper that casts the opaque
 * device pointers to uint32_t and forwards every parameter unchanged to the
 * templated host implementation.
 */
void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys) {
  auto glwe_out = static_cast<uint32_t *>(glwe_array_out);
  auto lwe_in = static_cast<uint32_t *>(lwe_array_in);
  auto ksk = static_cast<uint32_t *>(fp_ksk_array);
  cuda_fp_keyswitch_lwe_to_glwe(v_stream, glwe_out, lwe_in, ksk,
                                input_lwe_dimension, output_glwe_dimension,
                                output_polynomial_size, base_log, level_count,
                                number_of_input_lwe, number_of_keys);
}
/*
 * C API entry point: functional packing keyswitch from 64-bit LWE
 * ciphertexts to GLWE ciphertexts. Thin wrapper that casts the opaque
 * device pointers to uint64_t and forwards every parameter unchanged to the
 * templated host implementation.
 */
void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys) {
  auto glwe_out = static_cast<uint64_t *>(glwe_array_out);
  auto lwe_in = static_cast<uint64_t *>(lwe_array_in);
  auto ksk = static_cast<uint64_t *>(fp_ksk_array);
  cuda_fp_keyswitch_lwe_to_glwe(v_stream, glwe_out, lwe_in, ksk,
                                input_lwe_dimension, output_glwe_dimension,
                                output_polynomial_size, base_log, level_count,
                                number_of_input_lwe, number_of_keys);
}

View File

@@ -17,6 +17,75 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
return ptr;
}
// Functional packing keyswitch kernel: LWE -> GLWE.
//
// Launch geometry (see cuda_fp_keyswitch_lwe_to_glwe below):
// - blockIdx.y selects one input LWE ciphertext;
// - blockIdx.x selects a chunk of the corresponding output GLWE, with
//   chunk_count = glwe_size * polynomial_size / blockDim.x;
// - each thread owns exactly one output coefficient of its chunk and
//   performs lwe_size * level_count multiply-accumulates for it.
// Requires blockDim.x * sizeof(Torus) bytes of dynamic shared memory.
// No __syncthreads() is needed: every thread reads and writes only its own
// shared-memory slot.
template <typename Torus>
__global__ void
fp_keyswitch(Torus *glwe_array_out, Torus *lwe_array_in, Torus *fp_ksk_array,
             uint32_t lwe_dimension_in, uint32_t glwe_dimension,
             uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
             uint32_t number_of_input_lwe, uint32_t number_of_keys) {
  size_t tid = threadIdx.x;
  size_t glwe_size = (glwe_dimension + 1);
  size_t lwe_size = (lwe_dimension_in + 1);
  // number of coefficients in a single fp-ksk
  size_t ksk_size = lwe_size * level_count * glwe_size * polynomial_size;
  // number of coefficients inside fp-ksk block for each lwe_input coefficient
  size_t ksk_block_size = glwe_size * polynomial_size * level_count;
  size_t ciphertext_id = blockIdx.y;
  // number of coefficients processed inside single block
  size_t coef_per_block = blockDim.x;
  size_t chunk_id = blockIdx.x;
  // Keys are cycled over the inputs: input i uses fp-ksk i % number_of_keys
  // (the circuit bootstrap relies on this to hit each of the glwe_size keys
  // in turn).
  size_t ksk_id = ciphertext_id % number_of_keys;
  extern __shared__ char sharedmem[];
  // result accumulator, shared memory is used because of frequent access
  Torus *local_glwe_chunk = (Torus *)sharedmem;
  // current input lwe ciphertext
  auto cur_input_lwe = &lwe_array_in[ciphertext_id * lwe_size];
  // current output glwe ciphertext
  auto cur_output_glwe =
      &glwe_array_out[ciphertext_id * glwe_size * polynomial_size];
  // current out glwe chunk, will be processed inside single block
  auto cur_glwe_chunk = &cur_output_glwe[chunk_id * coef_per_block];
  // fp key used for current ciphertext
  auto cur_ksk = &fp_ksk_array[ksk_id * ksk_size];
  // set shared mem accumulator to 0
  local_glwe_chunk[tid] = 0;
  // iterate through each coefficient of input lwe
  for (size_t i = 0; i <= lwe_dimension_in; i++) {
    // Round the coefficient to the decomposition grid, then keep only the
    // base_log * level_count most significant bits as the decomposition
    // state.
    Torus a_i =
        round_to_closest_multiple(cur_input_lwe[i], base_log, level_count);
    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mod_b_mask = (1ll << base_log) - 1ll;
    // block of key for current lwe coefficient (cur_input_lwe[i])
    auto ksk_block = &cur_ksk[i * ksk_block_size];
    // iterate through levels, calculating decomposition in reverse order
    // (ksk GLWEs are indexed level_count - 1 - j as each decompose_one call
    // presumably consumes the state from the least significant level up —
    // NOTE(review): confirm against decompose_one's definition).
    for (size_t j = 0; j < level_count; j++) {
      auto ksk_glwe =
          &ksk_block[(level_count - j - 1) * glwe_size * polynomial_size];
      auto ksk_glwe_chunk = &ksk_glwe[chunk_id * coef_per_block];
      Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
      // Accumulate -decomposed * key into this thread's output coefficient.
      local_glwe_chunk[tid] -= decomposed * ksk_glwe_chunk[tid];
    }
  }
  // Write the accumulated chunk back to global memory.
  cur_glwe_chunk[tid] = local_glwe_chunk[tid];
}
/*
* keyswitch kernel
* Each thread handles a piece of the following equation:
@@ -138,4 +207,24 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
cudaStreamSynchronize(*stream);
}
/*
 * Host launcher for the LWE -> GLWE functional packing keyswitch.
 *
 * Grid: one y-block per input LWE, and glwe_accumulator_size / 256 x-blocks
 * so each thread owns one output GLWE coefficient. Dynamic shared memory:
 * one Torus per thread (the kernel's per-thread accumulator).
 * NOTE(review): the x-dimension uses truncating integer division, so this
 * assumes (glwe_dimension + 1) * polynomial_size is a multiple of 256 —
 * holds for the power-of-two polynomial sizes used by the callers in this
 * backend, but confirm before accepting other sizes.
 * Blocks on cudaStreamSynchronize before returning, so the output is ready
 * when this function exits.
 */
template <typename Torus>
__host__ void cuda_fp_keyswitch_lwe_to_glwe(
    void *v_stream, Torus *glwe_array_out, Torus *lwe_array_in,
    Torus *fp_ksk_array, uint32_t lwe_dimension_in, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
    uint32_t number_of_input_lwe, uint32_t number_of_keys) {
  int threads = 256;
  // (glwe_dimension + 1) * polynomial_size coefficients per output GLWE.
  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
  dim3 blocks(glwe_accumulator_size / threads, number_of_input_lwe, 1);
  int shared_mem = sizeof(Torus) * threads;
  auto stream = static_cast<cudaStream_t *>(v_stream);
  fp_keyswitch<<<blocks, threads, shared_mem, *stream>>>(
      glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
      glwe_dimension, polynomial_size, base_log, level_count,
      number_of_input_lwe, number_of_keys);
  cudaStreamSynchronize(*stream);
}
#endif