mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
refactor(cuda): Implements support to k>1 on bit extraction.
This commit is contained in:
@@ -121,14 +121,11 @@ void cuda_extract_bits_32(void *v_stream, uint32_t gpu_index,
|
||||
uint32_t max_shared_memory) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 32",
|
||||
base_log_bsk <= 32));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
|
||||
assert(("Error (GPU extract bits): polynomial_size should be one of "
|
||||
"256, 512, 1024, 2048, 4096, 8192",
|
||||
lwe_dimension_in == 256 || lwe_dimension_in == 512 ||
|
||||
lwe_dimension_in == 1024 || lwe_dimension_in == 2048 ||
|
||||
lwe_dimension_in == 4096 || lwe_dimension_in == 8192));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be equal to "
|
||||
"polynomial_size",
|
||||
lwe_dimension_in == polynomial_size));
|
||||
polynomial_size == 256 || polynomial_size == 512 ||
|
||||
polynomial_size == 1024 || polynomial_size == 2048 ||
|
||||
polynomial_size == 4096 || polynomial_size == 8192));
|
||||
// The number of samples should be lower than four times the number of
|
||||
// streaming multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being
|
||||
// related to the occupancy of 50%). The only supported value for k is 1, so
|
||||
@@ -141,7 +138,7 @@ void cuda_extract_bits_32(void *v_stream, uint32_t gpu_index,
|
||||
"level_count_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
|
||||
|
||||
switch (lwe_dimension_in) {
|
||||
switch (polynomial_size) {
|
||||
case 256:
|
||||
host_extract_bits<uint32_t, Degree<256>>(
|
||||
v_stream, gpu_index, (uint32_t *)list_lwe_array_out,
|
||||
@@ -258,14 +255,11 @@ void cuda_extract_bits_64(void *v_stream, uint32_t gpu_index,
|
||||
uint32_t max_shared_memory) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 64",
|
||||
base_log_bsk <= 64));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
|
||||
assert(("Error (GPU extract bits): polynomial_size should be one of "
|
||||
"256, 512, 1024, 2048, 4096, 8192",
|
||||
lwe_dimension_in == 256 || lwe_dimension_in == 512 ||
|
||||
lwe_dimension_in == 1024 || lwe_dimension_in == 2048 ||
|
||||
lwe_dimension_in == 4096 || lwe_dimension_in == 8192));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be equal to "
|
||||
"polynomial_size",
|
||||
lwe_dimension_in == polynomial_size));
|
||||
polynomial_size == 256 || polynomial_size == 512 ||
|
||||
polynomial_size == 1024 || polynomial_size == 2048 ||
|
||||
polynomial_size == 4096 || polynomial_size == 8192));
|
||||
// The number of samples should be lower than four times the number of
|
||||
// streaming multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being
|
||||
// related to the occupancy of 50%). The only supported value for k is 1, so
|
||||
@@ -278,7 +272,7 @@ void cuda_extract_bits_64(void *v_stream, uint32_t gpu_index,
|
||||
"level_count_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
|
||||
|
||||
switch (lwe_dimension_in) {
|
||||
switch (polynomial_size) {
|
||||
case 256:
|
||||
host_extract_bits<uint64_t, Degree<256>>(
|
||||
v_stream, gpu_index, (uint64_t *)list_lwe_array_out,
|
||||
|
||||
@@ -11,30 +11,27 @@
|
||||
#include "utils/timer.cuh"
|
||||
|
||||
/*
|
||||
* Function copies batch lwe input to two different buffers,
|
||||
* one is shifted by value
|
||||
* one is copied without any modification
|
||||
* Function copies batch lwe input to one that is shifted by value
|
||||
* works for ciphertexts with sizes supported by params::degree
|
||||
*
|
||||
* Each x-block handles a params::degree-chunk of src
|
||||
*/
|
||||
template <typename Torus, class params>
|
||||
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
|
||||
Torus *src, Torus value) {
|
||||
int blockId = blockIdx.x;
|
||||
__global__ void copy_and_shift_lwe(Torus *dst_shift, Torus *src, Torus value,
|
||||
uint32_t glwe_dimension) {
|
||||
int tid = threadIdx.x;
|
||||
auto cur_dst_copy = &dst_copy[blockId * (params::degree + 1)];
|
||||
auto cur_dst_shift = &dst_shift[blockId * (params::degree + 1)];
|
||||
auto cur_src = &src[blockId * (params::degree + 1)];
|
||||
auto cur_dst_shift = &dst_shift[blockIdx.x * params::degree];
|
||||
auto cur_src = &src[blockIdx.x * params::degree];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_dst_copy[tid] = cur_src[tid];
|
||||
cur_dst_shift[tid] = cur_src[tid] * value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
if (threadIdx.x == params::degree / params::opt - 1) {
|
||||
cur_dst_copy[params::degree] = cur_src[params::degree];
|
||||
cur_dst_shift[params::degree] = cur_src[params::degree] * value;
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
cur_dst_shift[glwe_dimension * params::degree] =
|
||||
cur_src[glwe_dimension * params::degree] * value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,24 +86,27 @@ __global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
|
||||
template <typename Torus, class params>
|
||||
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
|
||||
Torus *pbs_lwe_array_out, Torus add_value,
|
||||
Torus mul_value) {
|
||||
Torus mul_value, uint32_t glwe_dimension) {
|
||||
size_t tid = threadIdx.x;
|
||||
size_t blockId = blockIdx.x;
|
||||
auto cur_shifted_lwe = &shifted_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_state_lwe = &state_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_shifted_lwe =
|
||||
&shifted_lwe[blockId * (glwe_dimension * params::degree + 1)];
|
||||
auto cur_state_lwe =
|
||||
&state_lwe[blockId * (glwe_dimension * params::degree + 1)];
|
||||
auto cur_pbs_lwe_array_out =
|
||||
&pbs_lwe_array_out[blockId * (params::degree + 1)];
|
||||
&pbs_lwe_array_out[blockId * (glwe_dimension * params::degree + 1)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
for (int i = 0; i < glwe_dimension * params::opt; i++) {
|
||||
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_array_out[tid];
|
||||
cur_shifted_lwe[tid] *= mul_value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
if (threadIdx.x == params::degree / params::opt - 1) {
|
||||
cur_shifted_lwe[params::degree] = cur_state_lwe[params::degree] -=
|
||||
(cur_pbs_lwe_array_out[params::degree] + add_value);
|
||||
cur_shifted_lwe[params::degree] *= mul_value;
|
||||
if (threadIdx.x == 0) {
|
||||
cur_shifted_lwe[glwe_dimension * params::degree] =
|
||||
cur_state_lwe[glwe_dimension * params::degree] -=
|
||||
(cur_pbs_lwe_array_out[glwe_dimension * params::degree] + add_value);
|
||||
cur_shifted_lwe[glwe_dimension * params::degree] *= mul_value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,9 +116,11 @@ __global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
|
||||
* blockIdx.x refers to the id of the lut vector
|
||||
*/
|
||||
template <typename Torus, class params>
|
||||
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
|
||||
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value,
|
||||
uint32_t glwe_dimension) {
|
||||
|
||||
Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
|
||||
Torus *cur_poly = &lut[(blockIdx.x * (glwe_dimension + 1) + glwe_dimension) *
|
||||
params::degree];
|
||||
size_t tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
@@ -196,7 +198,6 @@ __host__ void host_extract_bits(
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
|
||||
|
||||
int blocks = 1;
|
||||
int threads = params::degree / params::opt;
|
||||
|
||||
// Always define the PBS buffer first, because it has the strongest memory
|
||||
@@ -224,9 +225,13 @@ __host__ void host_extract_bits(
|
||||
(ptrdiff_t)((glwe_dimension * polynomial_size + 1) * sizeof(Torus));
|
||||
|
||||
// shift lwe on padding bit and copy in new buffer
|
||||
copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
|
||||
lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_in,
|
||||
(Torus)(1ll << (ciphertext_n_bits - delta_log - 1)));
|
||||
check_cuda_error(
|
||||
cudaMemcpyAsync(lwe_array_in_buffer, lwe_array_in,
|
||||
(glwe_dimension * polynomial_size + 1) * sizeof(Torus),
|
||||
cudaMemcpyDeviceToDevice, *stream));
|
||||
copy_and_shift_lwe<Torus, params><<<glwe_dimension, threads, 0, *stream>>>(
|
||||
lwe_array_in_shifted_buffer, lwe_array_in,
|
||||
(Torus)(1ll << (ciphertext_n_bits - delta_log - 1)), glwe_dimension);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
|
||||
@@ -234,7 +239,6 @@ __host__ void host_extract_bits(
|
||||
v_stream, gpu_index, lwe_array_out_ks_buffer,
|
||||
lwe_array_in_shifted_buffer, ksk, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_ksk, level_count_ksk, 1);
|
||||
|
||||
copy_small_lwe<<<1, 256, 0, *stream>>>(
|
||||
list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
|
||||
number_of_bits, number_of_bits - bit_idx - 1);
|
||||
@@ -253,15 +257,15 @@ __host__ void host_extract_bits(
|
||||
// Fill lut for the current bit (equivalent to trivial encryption as mask is
|
||||
// 0s) The LUT is filled with -alpha in each coefficient where alpha =
|
||||
// delta*2^{bit_idx-1}
|
||||
fill_lut_body_for_current_bit<Torus, params>
|
||||
<<<blocks, threads, 0, *stream>>>(
|
||||
lut_pbs, (Torus)(0ll - 1ll << (delta_log - 1 + bit_idx)));
|
||||
fill_lut_body_for_current_bit<Torus, params><<<1, threads, 0, *stream>>>(
|
||||
lut_pbs, (Torus)(0ll - 1ll << (delta_log - 1 + bit_idx)),
|
||||
glwe_dimension);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
host_bootstrap_low_latency<Torus, params>(
|
||||
v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_pbs,
|
||||
lut_vector_indexes, lwe_array_out_ks_buffer, fourier_bsk, pbs_buffer,
|
||||
glwe_dimension, lwe_dimension_out, lwe_dimension_in, base_log_bsk,
|
||||
glwe_dimension, lwe_dimension_out, polynomial_size, base_log_bsk,
|
||||
level_count_bsk, number_of_samples, 1, max_shared_memory);
|
||||
|
||||
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption
|
||||
@@ -269,7 +273,8 @@ __host__ void host_extract_bits(
|
||||
add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
|
||||
lwe_array_in_shifted_buffer, lwe_array_in_buffer,
|
||||
lwe_array_out_pbs_buffer, (Torus)(1ll << (delta_log - 1 + bit_idx)),
|
||||
(Torus)(1ll << (ciphertext_n_bits - delta_log - bit_idx - 2)));
|
||||
(Torus)(1ll << (ciphertext_n_bits - delta_log - bit_idx - 2)),
|
||||
glwe_dimension);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user