mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(cuda): Add circuit bootstrap in the cuda backend
- Add FP-Keyswitch. - Add entry points for cuda fp ksk in the public API. - Add test for fp_ksk in cuda backend. - Add fixture for bit extract. Co-authored-by: agnesLeroy <agnes.leroy@zama.ai>
This commit is contained in:
committed by
Agnès Leroy
parent
e10c2936d1
commit
0aedb1a4f4
@@ -86,6 +86,25 @@ void cuda_extract_bits_64(
|
||||
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk, uint32_t number_of_samples,
|
||||
uint32_t max_shared_memory);
|
||||
// Circuit bootstrap entry point for 32-bit torus ciphertexts.
// All buffers are passed as untyped pointers; the definition casts the
// ciphertext, key-switching key and LUT buffers to uint32_t *, fourier_bsk to
// double2 * and lut_vector_indexes to uint32_t *, then dispatches on
// polynomial_size (512, 1024, 2048, 4096 or 8192). glwe_dimension must be 1.
void cuda_circuit_bootstrap_32(
|
||||
void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
|
||||
void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
|
||||
void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
|
||||
void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
|
||||
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
|
||||
uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
|
||||
uint32_t number_of_samples, uint32_t max_shared_memory);
|
||||
|
||||
// Circuit bootstrap entry point for 64-bit torus ciphertexts.
// All buffers are passed as untyped pointers; the definition casts the
// ciphertext, key-switching key and LUT buffers to uint64_t *, fourier_bsk to
// double2 * and lut_vector_indexes to uint32_t *, then dispatches on
// polynomial_size (512, 1024, 2048, 4096 or 8192). glwe_dimension must be 1.
void cuda_circuit_bootstrap_64(
|
||||
void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
|
||||
void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
|
||||
void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
|
||||
void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
|
||||
uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
|
||||
uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
|
||||
uint32_t number_of_samples, uint32_t max_shared_memory);
|
||||
}
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
@@ -14,6 +14,24 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
|
||||
|
||||
// Functional packing keyswitch (LWE to GLWE) entry point for 32-bit torus
// ciphertexts. The definition casts glwe_array_out, lwe_array_in and
// fp_ksk_array to uint32_t * and forwards every argument unchanged to the
// templated cuda_fp_keyswitch_lwe_to_glwe.
void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
|
||||
void *lwe_array_in, void *fp_ksk_array,
|
||||
uint32_t input_lwe_dimension,
|
||||
uint32_t output_glwe_dimension,
|
||||
uint32_t output_polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
uint32_t number_of_input_lwe,
|
||||
uint32_t number_of_keys);
|
||||
|
||||
// Functional packing keyswitch (LWE to GLWE) entry point for 64-bit torus
// ciphertexts. The definition casts glwe_array_out, lwe_array_in and
// fp_ksk_array to uint64_t * and forwards every argument unchanged to the
// templated cuda_fp_keyswitch_lwe_to_glwe.
void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
|
||||
void *lwe_array_in, void *fp_ksk_array,
|
||||
uint32_t input_lwe_dimension,
|
||||
uint32_t output_glwe_dimension,
|
||||
uint32_t output_polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
uint32_t number_of_input_lwe,
|
||||
uint32_t number_of_keys);
|
||||
}
|
||||
|
||||
#endif // CNCRT_KS_H_
|
||||
|
||||
@@ -343,3 +343,180 @@ void cuda_blind_rotate_and_sample_extraction_64(
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Circuit bootstrap on 32-bit torus ciphertexts.
 *
 * Validates the parameters, then casts the untyped buffers (ciphertexts, keys
 * and LUT to uint32_t *, fourier_bsk to double2 *, lut_vector_indexes to
 * uint32_t *) and forwards everything to host_circuit_bootstrap instantiated
 * for the requested polynomial size.
 *
 * Preconditions (checked with assert, so only in builds without NDEBUG):
 *  - glwe_dimension == 1
 *  - polynomial_size is one of 512, 1024, 2048, 4096, 8192
 *  - number_of_samples <= number_of_sm / (4 * 2 * level_bsk), where
 *    number_of_sm is the multiprocessor count of device `gpu_index`
 *
 * When asserts are compiled out, an unsupported polynomial_size falls through
 * the switch and the function returns without doing any work.
 */
void cuda_circuit_bootstrap_32(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
          "512, 1024, 2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  int number_of_sm = 0;
  // Query the device selected by the caller rather than always device 0, so
  // the bound is correct on machines with heterogeneous GPUs.
  cudaError_t attr_status = cudaDeviceGetAttribute(
      &number_of_sm, cudaDevAttrMultiProcessorCount,
      static_cast<int>(gpu_index));
  assert(("Error (GPU circuit bootstrap): could not query the number of "
          "streaming multiprocessors on the device",
          attr_status == cudaSuccess));
  (void)attr_status; // only read by the assert above
  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
          "lower or equal to the "
          "number of streaming multiprocessors on the device divided by 8 * "
          "level_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
  switch (polynomial_size) {
  case 512:
    host_circuit_bootstrap<uint32_t, Degree<512>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 1024:
    host_circuit_bootstrap<uint32_t, Degree<1024>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 2048:
    host_circuit_bootstrap<uint32_t, Degree<2048>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 4096:
    host_circuit_bootstrap<uint32_t, Degree<4096>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 8192:
    host_circuit_bootstrap<uint32_t, Degree<8192>>(
        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  default:
    break;
  }
}
|
||||
|
||||
/*
 * Circuit bootstrap on 64-bit torus ciphertexts.
 *
 * Validates the parameters, then casts the untyped buffers (ciphertexts, keys
 * and LUT to uint64_t *, fourier_bsk to double2 *, lut_vector_indexes to
 * uint32_t *) and forwards everything to host_circuit_bootstrap instantiated
 * for the requested polynomial size.
 *
 * Preconditions (checked with assert, so only in builds without NDEBUG):
 *  - glwe_dimension == 1
 *  - polynomial_size is one of 512, 1024, 2048, 4096, 8192
 *  - number_of_samples <= number_of_sm / (4 * 2 * level_bsk), where
 *    number_of_sm is the multiprocessor count of device `gpu_index`
 *
 * When asserts are compiled out, an unsupported polynomial_size falls through
 * the switch and the function returns without doing any work.
 */
void cuda_circuit_bootstrap_64(
    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
          "512, 1024, 2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  int number_of_sm = 0;
  // Query the device selected by the caller rather than always device 0, so
  // the bound is correct on machines with heterogeneous GPUs.
  cudaError_t attr_status = cudaDeviceGetAttribute(
      &number_of_sm, cudaDevAttrMultiProcessorCount,
      static_cast<int>(gpu_index));
  assert(("Error (GPU circuit bootstrap): could not query the number of "
          "streaming multiprocessors on the device",
          attr_status == cudaSuccess));
  (void)attr_status; // only read by the assert above
  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
          "lower or equal to the "
          "number of streaming multiprocessors on the device divided by 8 * "
          "level_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
  switch (polynomial_size) {
  case 512:
    host_circuit_bootstrap<uint64_t, Degree<512>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 1024:
    host_circuit_bootstrap<uint64_t, Degree<1024>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 2048:
    host_circuit_bootstrap<uint64_t, Degree<2048>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 4096:
    host_circuit_bootstrap<uint64_t, Degree<4096>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  case 8192:
    host_circuit_bootstrap<uint64_t, Degree<8192>>(
        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
        max_shared_memory);
    break;
  default:
    break;
  }
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "../include/helper_cuda.h"
|
||||
#include "bootstrap.h"
|
||||
#include "bootstrap_amortized.cuh"
|
||||
#include "bootstrap_low_latency.cuh"
|
||||
#include "complex/operations.cuh"
|
||||
#include "crypto/ggsw.cuh"
|
||||
@@ -377,6 +378,31 @@ __global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
|
||||
}
|
||||
}
|
||||
|
||||
// Writes scaled copies of LWE ciphertexts: every coefficient of a source
// ciphertext is multiplied by `value` (arithmetic wraps in Torus).
//
// Works for any lwe_size and any 1D block size (the block-stride loop below
// does not require blockDim.x to be a power of two or to divide lwe_size).
//
// Grid layout:
//  - blockIdx.y selects the source ciphertext in `src`
//  - (blockIdx.y, blockIdx.x) selects the destination ciphertext in
//    `dst_shift`, at index blockIdx.y * gridDim.x + blockIdx.x, so each source
//    ciphertext is written gridDim.x times
//
// `params` is not used by the kernel body; it is kept so existing
// instantiations stay valid.
template <typename Torus, class params>
__global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
                              size_t lwe_size) {

  size_t blockId = static_cast<size_t>(blockIdx.y) * gridDim.x + blockIdx.x;

  Torus *cur_dst = &dst_shift[blockId * lwe_size];
  Torus *cur_src = &src[blockIdx.y * lwe_size];

  // Each thread handles coefficients threadIdx.x, threadIdx.x + blockDim.x, ...
  for (size_t i = threadIdx.x; i < lwe_size; i += blockDim.x)
    cur_dst[i] = cur_src[i] * value;
}
|
||||
|
||||
// only works for small lwe in ks+bs case
|
||||
// function copies lwe when length is not a power of two
|
||||
template <typename Torus>
|
||||
@@ -413,11 +439,11 @@ __global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
|
||||
|
||||
// Fill lut(only body) for the current bit (equivalent to trivial encryption as
|
||||
// mask is 0s)
|
||||
// The LUT is filled with -alpha in each coefficient where alpha =
|
||||
// delta*2^{bit_idx-1}
|
||||
// The LUT is filled with value
|
||||
template <typename Torus, class params>
|
||||
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
|
||||
Torus *cur_poly = &lut[params::degree];
|
||||
|
||||
Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
|
||||
size_t tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
@@ -426,6 +452,46 @@ __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
|
||||
}
|
||||
}
|
||||
|
||||
// Fill one LUT per block (equivalent to a trivial encryption, as the mask is
// all zeros). Block b writes 2 * params::degree coefficients starting at
// lut[b * 2 * params::degree]: the first params::degree are the mask (0), the
// next params::degree are the body, each set to -alpha with
//   alpha = 2^{ciphertext_n_bits - 1 - base_log_cbs * (b + 1)}
//
// Expected launch: blockDim.x == params::degree / params::opt, so that the
// params::opt iterations of each thread cover the whole polynomial.
// Precondition: base_log_cbs * (blockIdx.x + 1) <= ciphertext_n_bits - 1,
// otherwise the shift amount wraps around.
template <typename Torus, class params>
__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
                                      uint32_t base_log_cbs) {

  Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
  Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];

  // -alpha computed in unsigned Torus arithmetic. The previous expression
  // `0ll - 1ll << s` parsed as `(0ll - 1ll) << s`, i.e. a left shift of a
  // negative signed value; this form yields the same bit pattern without
  // relying on that.
  uint32_t shift = ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1);
  Torus minus_alpha = static_cast<Torus>(static_cast<Torus>(0) -
                                         (static_cast<Torus>(1) << shift));

  size_t tid = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    cur_mask[tid] = 0;
    cur_poly[tid] = minus_alpha;
    tid += params::degree / params::opt;
  }
}
|
||||
|
||||
// Copies one LWE ciphertext of params::degree + 1 coefficients from lwe_src to
// lwe_dst and adds `value` to the last coefficient (index params::degree) of
// the copy.
//
// Grid layout: blockIdx.x selects the destination ciphertext, blockIdx.y the
// source ciphertext. Expected block size: params::degree / params::opt
// threads, each copying params::opt coefficients.
template <typename Torus, class params>
__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src, Torus value) {

  const size_t dst_index = blockIdx.x;
  const size_t src_index = blockIdx.y;

  Torus *dst = lwe_dst + dst_index * (params::degree + 1);
  Torus *src = lwe_src + src_index * (params::degree + 1);

  // Copy the first params::degree coefficients, strided across the block
  size_t coef = threadIdx.x;
#pragma unroll
  for (int step = 0; step < params::opt; ++step) {
    dst[coef] = src[coef];
    coef += params::degree / params::opt;
  }

  // A single thread writes the last coefficient, shifted by `value`
  if (threadIdx.x == 0)
    dst[params::degree] = src[params::degree] + value;
}
|
||||
|
||||
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0
|
||||
// if the extracted bit was 0 and 1 in the other case
|
||||
//
|
||||
@@ -683,4 +749,64 @@ void host_blind_rotate_and_sample_extraction(
|
||||
cuda_drop_async(d_mem, *stream, gpu_index);
|
||||
}
|
||||
|
||||
// Host-side driver of the circuit bootstrap. All work is enqueued on the
// stream pointed to by v_stream, in this order:
//  1. shift_lwe_cbs: writes level_cbs scaled copies of each input LWE
//     (multiplied by 2^{ciphertext_n_bits - delta_log - 1}) into
//     lwe_array_in_shifted_buffer
//  2. add_to_body: adds q/4 (2^{ciphertext_n_bits - 2}) to each of the
//     pbs_count shifted ciphertexts
//  3. fill_lut_body_for_cbs: fills level_cbs LUTs in lut_vector
//  4. host_bootstrap_amortized: PBS of the pbs_count ciphertexts into
//     lwe_array_out_pbs_buffer
//  5. copy_add_lwe_cbs: copies PBS outputs into lwe_array_in_fp_ks_buffer
//     while adding a constant to the last coefficient
//  6. cuda_fp_keyswitch_lwe_to_glwe: functional packing keyswitch of the
//     copies into ggsw_out
//
// The caller provides every buffer; none is allocated or freed here.
template <typename Torus, class params>
__host__ void host_circuit_bootstrap(
    void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in,
    double2 *fourier_bsk, Torus *fp_ksk_array,
    Torus *lwe_array_in_shifted_buffer, Torus *lut_vector,
    uint32_t *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
    Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  auto stream = static_cast<cudaStream_t *>(v_stream);

  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
  uint32_t lwe_size = lwe_dimension + 1;
  // one PBS per (input sample, decomposition level) pair
  int pbs_count = number_of_samples * level_cbs;

  dim3 blocks(level_cbs, number_of_samples, 1);
  int threads = 256;

  // Shift message LSB on padding bit, at this point we expect to have messages
  // with only 1 bit of information
  shift_lwe_cbs<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_array_in_shifted_buffer, lwe_array_in,
      1LL << (ciphertext_n_bits - delta_log - 1), lwe_size);

  // Add q/4 to center the error while computing a negacyclic LUT
  add_to_body<Torus>
      <<<pbs_count, 1, 0, *stream>>>(lwe_array_in_shifted_buffer, lwe_dimension,
                                     1ll << (ciphertext_n_bits - 2));
  // Fill lut (equivalent to trivial encryption as mask is 0s)
  // The LUT is filled with -alpha in each coefficient where
  // alpha = 2^{log(q) - 1 - base_log * level}
  fill_lut_body_for_cbs<Torus, params>
      <<<level_cbs, params::degree / params::opt, 0, *stream>>>(
          lut_vector, ciphertext_n_bits, base_log_cbs);

  // Applying a negacyclic LUT on a ciphertext with one bit of message in the
  // MSB and no bit of padding
  host_bootstrap_amortized<Torus, params>(
      v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_vector,
      lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk,
      lwe_dimension, polynomial_size, base_log_bsk, level_bsk, pbs_count,
      level_cbs, 0, max_shared_memory);

  // NOTE(review): this grid spans every (destination, source) pair, while
  // copy_add_lwe_cbs takes the destination from blockIdx.x and the source
  // from blockIdx.y, so each destination is written once per source. Confirm
  // the intended source -> destination mapping.
  dim3 copy_grid(pbs_count * (glwe_dimension + 1), pbs_count, 1);
  dim3 copy_block(params::degree / params::opt, 1, 1);

  // Copy the pbs result (glwe_dimension + 1) times to be an input of fp-ks,
  // adding 2^{log(q) - 1 - base_log_cbs * level_cbs} to the body of each copy.
  // The launch is issued on the caller's stream (it used to go to the default
  // stream) so that it stays ordered with the rest of the pipeline.
  // NOTE(review): the added constant uses level_cbs for every copy whereas
  // the LUTs use a per-level exponent; confirm this is intended.
  copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
      lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer,
      1ll << (ciphertext_n_bits - 1 - base_log_cbs * level_cbs));

  cuda_fp_keyswitch_lwe_to_glwe(
      v_stream, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
      polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
      level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
}
|
||||
|
||||
#endif // WO_PBS_H
|
||||
|
||||
@@ -43,3 +43,36 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
}
|
||||
|
||||
/*
 * 32-bit entry point of the functional packing keyswitch (LWE to GLWE).
 * Reinterprets the untyped buffers as uint32_t arrays and forwards every
 * argument unchanged to the templated cuda_fp_keyswitch_lwe_to_glwe.
 */
void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys) {
  auto *typed_glwe_out = static_cast<uint32_t *>(glwe_array_out);
  auto *typed_lwe_in = static_cast<uint32_t *>(lwe_array_in);
  auto *typed_fp_ksk = static_cast<uint32_t *>(fp_ksk_array);

  cuda_fp_keyswitch_lwe_to_glwe<uint32_t>(
      v_stream, typed_glwe_out, typed_lwe_in, typed_fp_ksk,
      input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
      base_log, level_count, number_of_input_lwe, number_of_keys);
}
|
||||
/*
 * 64-bit entry point of the functional packing keyswitch (LWE to GLWE).
 * Reinterprets the untyped buffers as uint64_t arrays and forwards every
 * argument unchanged to the templated cuda_fp_keyswitch_lwe_to_glwe.
 */
void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
                                      void *lwe_array_in, void *fp_ksk_array,
                                      uint32_t input_lwe_dimension,
                                      uint32_t output_glwe_dimension,
                                      uint32_t output_polynomial_size,
                                      uint32_t base_log, uint32_t level_count,
                                      uint32_t number_of_input_lwe,
                                      uint32_t number_of_keys) {
  auto *typed_glwe_out = static_cast<uint64_t *>(glwe_array_out);
  auto *typed_lwe_in = static_cast<uint64_t *>(lwe_array_in);
  auto *typed_fp_ksk = static_cast<uint64_t *>(fp_ksk_array);

  cuda_fp_keyswitch_lwe_to_glwe<uint64_t>(
      v_stream, typed_glwe_out, typed_lwe_in, typed_fp_ksk,
      input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
      base_log, level_count, number_of_input_lwe, number_of_keys);
}
|
||||
|
||||
@@ -17,6 +17,75 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// Functional packing keyswitch kernel (LWE -> GLWE).
//
// Grid layout:
//  - blockIdx.y: index of the input LWE ciphertext (and of the output GLWE)
//  - blockIdx.x: index of the chunk of blockDim.x consecutive output
//    coefficients handled by this block
// Each thread produces exactly one output coefficient: it walks over the
// lwe_dimension_in + 1 input coefficients and, for each, over the level_count
// decomposition levels, subtracting (decomposed digit * key coefficient).
//
// Dynamic shared memory: blockDim.x * sizeof(Torus) bytes, one accumulator
// slot per thread.
// The key used for ciphertext c is number (c % number_of_keys) in
// fp_ksk_array. There is no bounds check: the launch must cover exactly
// (glwe_dimension + 1) * polynomial_size coefficients per ciphertext.
template <typename Torus>
__global__ void
fp_keyswitch(Torus *glwe_array_out, Torus *lwe_array_in, Torus *fp_ksk_array,
             uint32_t lwe_dimension_in, uint32_t glwe_dimension,
             uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
             uint32_t number_of_input_lwe, uint32_t number_of_keys) {
  extern __shared__ char sharedmem[];
  // per-thread result accumulators
  Torus *accumulator = (Torus *)sharedmem;

  const size_t lane = threadIdx.x;
  const size_t input_id = blockIdx.y;
  const size_t block_chunk = blockIdx.x;
  const size_t chunk_len = blockDim.x;
  // first coefficient (within one GLWE) handled by this block
  const size_t chunk_offset = block_chunk * chunk_len;

  const size_t glwe_size = (glwe_dimension + 1);
  const size_t lwe_size = (lwe_dimension_in + 1);
  // coefficients in one GLWE ciphertext
  const size_t glwe_coef_count = glwe_size * polynomial_size;
  // coefficients of the key attached to one input LWE coefficient
  const size_t key_block_len = glwe_coef_count * level_count;
  // coefficients in one complete fp-ksk
  const size_t key_len = lwe_size * key_block_len;

  Torus *input_lwe = lwe_array_in + input_id * lwe_size;
  Torus *output_chunk =
      glwe_array_out + input_id * glwe_coef_count + chunk_offset;
  Torus *key = fp_ksk_array + (input_id % number_of_keys) * key_len;

  accumulator[lane] = 0;

  for (size_t i = 0; i <= lwe_dimension_in; i++) {
    Torus rounded =
        round_to_closest_multiple(input_lwe[i], base_log, level_count);

    Torus state = rounded >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mod_b_mask = (1ll << base_log) - 1ll;

    // key block matching input coefficient i
    Torus *key_block = key + i * key_block_len;

    // the decomposition is produced in reverse order, so the levels are
    // visited from level_count - 1 down to 0
    for (size_t level = level_count; level-- > 0;) {
      Torus *key_chunk = key_block + level * glwe_coef_count + chunk_offset;
      Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
      accumulator[lane] -= decomposed * key_chunk[lane];
    }
  }

  output_chunk[lane] = accumulator[lane];
}
|
||||
|
||||
/*
|
||||
* keyswitch kernel
|
||||
* Each thread handles a piece of the following equation:
|
||||
@@ -138,4 +207,24 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
|
||||
cudaStreamSynchronize(*stream);
|
||||
}
|
||||
|
||||
// Launches the fp_keyswitch kernel on the stream pointed to by v_stream and
// blocks until that stream has finished.
//
// Launch configuration: 256 threads per block, one block per chunk of 256
// output coefficients (grid x) and per input LWE (grid y), with one Torus of
// dynamic shared memory per thread.
// The chunk count is a truncating division, so
// (glwe_dimension + 1) * polynomial_size is expected to be a multiple of 256.
template <typename Torus>
__host__ void cuda_fp_keyswitch_lwe_to_glwe(
    void *v_stream, Torus *glwe_array_out, Torus *lwe_array_in,
    Torus *fp_ksk_array, uint32_t lwe_dimension_in, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
    uint32_t number_of_input_lwe, uint32_t number_of_keys) {
  auto stream = static_cast<cudaStream_t *>(v_stream);

  int block_threads = 256;
  int coefs_per_glwe = (glwe_dimension + 1) * polynomial_size;
  int accumulator_bytes = sizeof(Torus) * block_threads;

  dim3 grid(coefs_per_glwe / block_threads, number_of_input_lwe, 1);

  fp_keyswitch<<<grid, block_threads, accumulator_bytes, *stream>>>(
      glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
      glwe_dimension, polynomial_size, base_log, level_count,
      number_of_input_lwe, number_of_keys);

  cudaStreamSynchronize(*stream);
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user