mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
refactor(cuda): Implements support to k>1 on cbs+vp.
This commit is contained in:
@@ -45,15 +45,15 @@ __global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
|
||||
*/
|
||||
template <typename Torus, class params>
|
||||
__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
|
||||
uint32_t base_log_cbs) {
|
||||
uint32_t base_log_cbs,
|
||||
uint32_t glwe_dimension) {
|
||||
|
||||
Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
|
||||
Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
|
||||
Torus *cur_body = &lut[(blockIdx.x * (glwe_dimension + 1) + glwe_dimension) *
|
||||
params::degree];
|
||||
size_t tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_mask[tid] = 0;
|
||||
cur_poly[tid] =
|
||||
cur_body[tid] =
|
||||
0ll -
|
||||
(1ll << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1)));
|
||||
tid += params::degree / params::opt;
|
||||
@@ -76,22 +76,27 @@ __global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
|
||||
template <typename Torus, class params>
|
||||
__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src,
|
||||
uint32_t ciphertext_n_bits,
|
||||
uint32_t base_log_cbs, uint32_t level_cbs) {
|
||||
uint32_t base_log_cbs, uint32_t level_cbs,
|
||||
uint32_t glwe_dimension) {
|
||||
size_t tid = threadIdx.x;
|
||||
size_t src_lwe_id = blockIdx.x / (glwe_dimension + 1);
|
||||
size_t dst_lwe_id = blockIdx.x;
|
||||
size_t src_lwe_id = dst_lwe_id / 2;
|
||||
size_t cur_cbs_level = src_lwe_id % level_cbs + 1;
|
||||
|
||||
auto cur_src = &lwe_src[src_lwe_id * (params::degree + 1)];
|
||||
auto cur_dst = &lwe_dst[dst_lwe_id * (params::degree + 1)];
|
||||
auto cur_src = &lwe_src[src_lwe_id * (glwe_dimension * params::degree + 1)];
|
||||
auto cur_dst = &lwe_dst[dst_lwe_id * (glwe_dimension * params::degree + 1)];
|
||||
|
||||
auto cur_src_slice = cur_src + blockIdx.y * params::degree;
|
||||
auto cur_dst_slice = cur_dst + blockIdx.y * params::degree;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_dst[tid] = cur_src[tid];
|
||||
cur_dst_slice[tid] = cur_src_slice[tid];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
Torus val = 1ll << (ciphertext_n_bits - 1 - base_log_cbs * cur_cbs_level);
|
||||
if (threadIdx.x == 0) {
|
||||
cur_dst[params::degree] = cur_src[params::degree] + val;
|
||||
if (threadIdx.x == 0 && blockIdx.y == 0) {
|
||||
cur_dst[glwe_dimension * params::degree] =
|
||||
cur_src[glwe_dimension * params::degree] + val;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,9 +107,10 @@ get_buffer_size_cbs(uint32_t glwe_dimension, uint32_t lwe_dimension,
|
||||
uint32_t number_of_inputs) {
|
||||
|
||||
return number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
|
||||
(polynomial_size + 1) *
|
||||
(glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) + // lwe_array_in_fp_ks_buffer
|
||||
number_of_inputs * level_count_cbs * (polynomial_size + 1) *
|
||||
number_of_inputs * level_count_cbs *
|
||||
(glwe_dimension * polynomial_size + 1) *
|
||||
sizeof(Torus) + // lwe_array_out_pbs_buffer
|
||||
number_of_inputs * level_count_cbs * (lwe_dimension + 1) *
|
||||
sizeof(Torus) + // lwe_array_in_shifted_buffer
|
||||
@@ -174,7 +180,8 @@ __host__ void host_circuit_bootstrap(
|
||||
sizeof(Torus));
|
||||
Torus *lwe_array_in_shifted_buffer =
|
||||
lwe_array_out_pbs_buffer +
|
||||
(ptrdiff_t)(number_of_inputs * level_cbs * (polynomial_size + 1));
|
||||
(ptrdiff_t)(number_of_inputs * level_cbs *
|
||||
(glwe_dimension * polynomial_size + 1));
|
||||
Torus *lut_vector =
|
||||
lwe_array_in_shifted_buffer +
|
||||
(ptrdiff_t)(number_of_inputs * level_cbs * (lwe_dimension + 1));
|
||||
@@ -195,9 +202,13 @@ __host__ void host_circuit_bootstrap(
|
||||
// Fill lut (equivalent to trivial encryption as mask is 0s)
|
||||
// The LUT is filled with -alpha in each coefficient where
|
||||
// alpha = 2^{log(q) - 1 - base_log * level}
|
||||
check_cuda_error(cudaMemsetAsync(lut_vector, 0,
|
||||
level_cbs * (glwe_dimension + 1) *
|
||||
polynomial_size * sizeof(Torus),
|
||||
*stream));
|
||||
fill_lut_body_for_cbs<Torus, params>
|
||||
<<<level_cbs, params::degree / params::opt, 0, *stream>>>(
|
||||
lut_vector, ciphertext_n_bits, base_log_cbs);
|
||||
lut_vector, ciphertext_n_bits, base_log_cbs, glwe_dimension);
|
||||
|
||||
// Applying a negacyclic LUT on a ciphertext with one bit of message in the
|
||||
// MSB and no bit of padding
|
||||
@@ -207,18 +218,19 @@ __host__ void host_circuit_bootstrap(
|
||||
glwe_dimension, lwe_dimension, polynomial_size, base_log_bsk, level_bsk,
|
||||
pbs_count, level_cbs, 0, max_shared_memory);
|
||||
|
||||
dim3 copy_grid(pbs_count * (glwe_dimension + 1), 1, 1);
|
||||
dim3 copy_grid(pbs_count * (glwe_dimension + 1), glwe_dimension, 1);
|
||||
dim3 copy_block(params::degree / params::opt, 1, 1);
|
||||
// Add q/4 to center the error while computing a negacyclic LUT
|
||||
// copy pbs result (glwe_dimension + 1) times to be an input of fp-ks
|
||||
copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
|
||||
lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer, ciphertext_n_bits,
|
||||
base_log_cbs, level_cbs);
|
||||
base_log_cbs, level_cbs, glwe_dimension);
|
||||
|
||||
cuda_fp_keyswitch_lwe_to_glwe(
|
||||
v_stream, gpu_index, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
|
||||
polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
|
||||
level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
|
||||
glwe_dimension * polynomial_size, glwe_dimension, polynomial_size,
|
||||
base_log_pksk, level_pksk, pbs_count * (glwe_dimension + 1),
|
||||
glwe_dimension + 1);
|
||||
}
|
||||
|
||||
#endif // CBS_CUH
|
||||
|
||||
@@ -420,14 +420,16 @@ __global__ void device_blind_rotation_and_sample_extraction(
|
||||
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
|
||||
|
||||
Torus *accumulator_c0 = (Torus *)selected_memory;
|
||||
Torus *accumulator_c1 = (Torus *)accumulator_c0 + 2 * polynomial_size;
|
||||
Torus *accumulator_c1 =
|
||||
(Torus *)accumulator_c0 + (glwe_dim + 1) * polynomial_size;
|
||||
int8_t *cmux_memory =
|
||||
(int8_t *)(accumulator_c1 + (glwe_dim + 1) * polynomial_size);
|
||||
|
||||
// Input LUT
|
||||
auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
|
||||
int tid = threadIdx.x;
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
|
||||
accumulator_c0[tid] = mi[tid];
|
||||
accumulator_c0[tid + params::degree] = mi[tid + params::degree];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
@@ -436,45 +438,43 @@ __global__ void device_blind_rotation_and_sample_extraction(
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Compute x^ai * ACC
|
||||
// Body
|
||||
// Mask and Body
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
accumulator_c1, accumulator_c0, (1 << monomial_degree), false, 1);
|
||||
// Mask
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
accumulator_c1 + polynomial_size, accumulator_c0 + polynomial_size,
|
||||
(1 << monomial_degree), false, 1);
|
||||
accumulator_c1, accumulator_c0, (1 << monomial_degree), false,
|
||||
(glwe_dim + 1));
|
||||
|
||||
monomial_degree += 1;
|
||||
|
||||
// ACC = CMUX ( Ci, x^ai * ACC, ACC )
|
||||
synchronize_threads_in_block();
|
||||
cmux<Torus, STorus, params>(
|
||||
accumulator_c0, accumulator_c0, ggsw_in,
|
||||
(int8_t *)(accumulator_c0 + 4 * polynomial_size), 0, 0, 1, glwe_dim,
|
||||
polynomial_size, base_log, level_count, i);
|
||||
cmux<Torus, STorus, params>(accumulator_c0, accumulator_c0, ggsw_in,
|
||||
cmux_memory, 0, 0, 1, glwe_dim, polynomial_size,
|
||||
base_log, level_count, i);
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Write the output
|
||||
auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];
|
||||
auto block_lwe_out = &lwe_out[blockIdx.x * (glwe_dim * polynomial_size + 1)];
|
||||
|
||||
// The blind rotation for this block is over
|
||||
// Now we can perform the sample extraction: for the body it's just
|
||||
// the resulting constant coefficient of the accumulator
|
||||
// For the mask it's more complicated
|
||||
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_c0, 1);
|
||||
sample_extract_body<Torus, params>(block_lwe_out, accumulator_c0, 1);
|
||||
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_c0, glwe_dim);
|
||||
sample_extract_body<Torus, params>(block_lwe_out, accumulator_c0, glwe_dim);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
__host__ __device__ int
|
||||
get_memory_needed_per_block_blind_rotation_sample_extraction(
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size) {
|
||||
return sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c0
|
||||
sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c1
|
||||
+ get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension, polynomial_size);
|
||||
return sizeof(Torus) * polynomial_size *
|
||||
(glwe_dimension + 1) + // accumulator_c0
|
||||
sizeof(Torus) * polynomial_size *
|
||||
(glwe_dimension + 1) + // accumulator_c1
|
||||
+get_memory_needed_per_block_cmux_tree<Torus>(glwe_dimension,
|
||||
polynomial_size);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
Reference in New Issue
Block a user