mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
230 lines
8.2 KiB
Plaintext
230 lines
8.2 KiB
Plaintext
#ifndef CNCRT_KS_H
|
|
#define CNCRT_KS_H
|
|
|
|
#include "crypto/gadget.cuh"
|
|
#include "crypto/torus.cuh"
|
|
#include "polynomial/polynomial.cuh"
|
|
#include <thread>
|
|
#include <vector>
|
|
|
|
// Returns a pointer to the key-switching-key block for input coefficient `i`
// at decomposition level `level`.
//
// The key is addressed as a flat array in which every (i, level) pair owns
// (lwe_dimension_out + 1) consecutive elements, and the `level_count` levels
// of one input coefficient are stored back to back.
//
// The offset is computed in size_t so that large keys cannot wrap a 32-bit
// intermediate. The function only does pointer arithmetic, so it is usable
// from both host and device code.
template <typename Torus>
__host__ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
                                         uint32_t lwe_dimension_out,
                                         uint32_t level_count) {
  // number of elements in one (i, level) block
  const size_t lwe_size_out = static_cast<size_t>(lwe_dimension_out) + 1;
  // i * level_count * lwe_size_out + level * lwe_size_out, factored
  const size_t pos =
      (static_cast<size_t>(i) * level_count + static_cast<size_t>(level)) *
      lwe_size_out;
  return ksk + pos;
}
|
|
|
|
// Launch layout, as used by the indexing below:
//  * blockIdx.y  - index of the input lwe ciphertext (and of its output glwe)
//  * blockIdx.x  - index of a chunk of blockDim.x consecutive output
//                  coefficients inside that glwe
//  * threadIdx.x - coefficient inside the chunk
// Each thread accumulates exactly one output coefficient, performing
// (lwe_dimension_in + 1) * level_count multiply-subtract steps.
// Dynamic shared memory must provide blockDim.x elements of type Torus.
template <typename Torus>
__global__ void
fp_keyswitch(Torus *glwe_array_out, Torus *lwe_array_in, Torus *fp_ksk_array,
             uint32_t lwe_dimension_in, uint32_t glwe_dimension,
             uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
             uint32_t number_of_input_lwe, uint32_t number_of_keys) {
  extern __shared__ int8_t sharedmem[];
  // one accumulator slot per thread, kept in dynamic shared memory
  Torus *acc = (Torus *)sharedmem;

  const size_t lane = threadIdx.x;
  const size_t sample = blockIdx.y;
  // offset of this block's chunk inside one glwe ciphertext
  const size_t chunk_offset = (size_t)blockIdx.x * blockDim.x;

  const size_t glwe_size = (glwe_dimension + 1);
  const size_t lwe_size = (lwe_dimension_in + 1);

  // coefficients in one glwe ciphertext
  const size_t glwe_len = glwe_size * polynomial_size;
  // coefficients of the key attached to one input lwe coefficient
  const size_t key_len_per_coef = glwe_len * level_count;
  // coefficients in one complete fp-ksk
  const size_t key_len = lwe_size * key_len_per_coef;

  // input ciphertext, output chunk and key selected for this block;
  // keys are reused cyclically across ciphertexts
  Torus *lwe_in = lwe_array_in + sample * lwe_size;
  Torus *glwe_out_chunk = glwe_array_out + sample * glwe_len + chunk_offset;
  Torus *key = fp_ksk_array + (sample % number_of_keys) * key_len;

  // start from a zero accumulator
  acc[lane] = 0;

  // visit every coefficient of the input lwe (mask and body alike)
  for (size_t i = 0; i <= lwe_dimension_in; i++) {
    Torus rounded =
        round_to_closest_multiple(lwe_in[i], base_log, level_count);

    // keep only the base_log * level_count most significant bits
    Torus state = rounded >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mod_b_mask = (1ll << base_log) - 1ll;

    // part of the key that belongs to lwe_in[i]
    Torus *key_for_coef = key + i * key_len_per_coef;

    // the decomposition yields its terms in reverse level order, so the key
    // levels are walked from the last one down to the first
    for (size_t level = level_count; level-- > 0;) {
      Torus digit = decompose_one<Torus>(state, mod_b_mask, base_log);
      acc[lane] -=
          digit * key_for_coef[level * glwe_len + chunk_offset + lane];
    }
  }

  // publish the finished coefficient
  glwe_out_chunk[lane] = acc[lane];
}
|
|
|
|
/*
 * keyswitch kernel
 * Each thread handles a piece of the following equation:
 * $$GLWE_s2(\Delta.m+e) = (0,0,..,0,b) - \sum_{i=0,k-1} <Dec(a_i),
 * (GLWE_s2(s1_i q/beta),..,GLWE(s1_i q/beta^l)>$$ where k is the dimension of
 * the GLWE ciphertext. If the polynomial dimension in GLWE is > 1, this
 * equation is solved for each polynomial coefficient. Dec denotes the
 * decomposition with base beta and l levels, and the inner product is done
 * between the decomposition of a_i and l GLWE encryptions of s1_i q/\beta^j,
 * with j in [1,l]. We obtain a GLWE encryption of Delta.m (with Delta the
 * scaling factor) under key s2 instead of s1, with an increased noise.
 *
 * Work split, as used by the indexing below:
 *  - blockIdx.x is passed to get_chunk to select the input and output
 *    ciphertext (presumably the blockIdx.x-th slice of the given length;
 *    confirm against get_chunk's definition),
 *  - thread `tid` owns the output coefficients tid + k * blockDim.x for
 *    k in [0, lwe_part_per_thd), where lwe_part_per_thd is lwe_upper for
 *    tid < cutoff and lwe_lower otherwise.
 * Dynamic shared memory must hold lwe_dimension_out + 1 elements of Torus.
 */
template <typename Torus>
__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
                          uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
                          uint32_t base_log, uint32_t level_count,
                          int lwe_lower, int lwe_upper, int cutoff) {
  int tid = threadIdx.x;

  extern __shared__ int8_t sharedmem[];

  // output accumulator for the whole ciphertext handled by this block
  Torus *local_lwe_array_out = (Torus *)sharedmem;

  auto block_lwe_array_in =
      get_chunk(lwe_array_in, blockIdx.x, lwe_dimension_in + 1);
  auto block_lwe_array_out =
      get_chunk(lwe_array_out, blockIdx.x, lwe_dimension_out + 1);

  // number of output coefficients owned by this thread
  const int lwe_part_per_thd = (tid < cutoff) ? lwe_upper : lwe_lower;

  // every thread clears the accumulator slots it owns
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    local_lwe_array_out[idx] = 0;
  }

  // The body slot below is usually owned by a thread other than thread 0.
  // Without this barrier its owner could zero the slot after thread 0 has
  // stored the body, losing the (0,..,0,b) term.
  __syncthreads();

  // start from (0,..,0,b): copy the input body into the output body slot
  if (tid == 0) {
    local_lwe_array_out[lwe_dimension_out] =
        block_lwe_array_in[lwe_dimension_in];
  }

  // make the body visible to the thread that owns that slot
  __syncthreads();

  // From here on each thread reads and writes only the accumulator slots it
  // owns, so no further block-level synchronization is needed.
  for (int i = 0; i < lwe_dimension_in; i++) {
    Torus a_i =
        round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);

    // keep only the base_log * level_count most significant bits
    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mask_mod_b = (1ll << base_log) - 1ll;

    // the decomposition is consumed in reverse level order
    for (int j = 0; j < level_count; j++) {
      auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
                                     lwe_dimension_out, level_count);
      Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
      for (int k = 0; k < lwe_part_per_thd; k++) {
        int idx = tid + k * blockDim.x;
        local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
      }
    }
  }

  // each thread writes back the coefficients it owns
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    block_lwe_array_out[idx] = local_lwe_array_out[idx];
  }
}
|
|
|
|
/// Launches the keyswitch kernel on `num_samples` LWE ciphertexts.
///
/// lwe_array_in, lwe_array_out and ksk are handed directly to the kernel and
/// to cudaMemsetAsync, so they are assumed to be device pointers.
/// v_stream must point to a cudaStream_t; all work is enqueued on that stream
/// and the function does not synchronize it.
///
/// One block of 128 threads is launched per ciphertext; the
/// lwe_dimension_out + 1 output coefficients are split across the threads.
template <typename Torus>
__host__ void cuda_keyswitch_lwe_ciphertext_vector(
    void *v_stream, uint32_t gpu_index, Torus *lwe_array_out,
    Torus *lwe_array_in, Torus *ksk, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
    uint32_t num_samples) {

  check_cuda_error(cudaSetDevice(gpu_index));

  // Nothing to switch; a grid with zero blocks is also an invalid launch
  // configuration.
  if (num_samples == 0)
    return;

  constexpr int ideal_threads = 128;

  // Split lwe_dim coefficients across ideal_threads threads: the first
  // `cutoff` threads take lwe_upper coefficients, the others lwe_lower.
  const int lwe_dim = lwe_dimension_out + 1;
  const int lwe_lower = lwe_dim / ideal_threads;
  const int cutoff = lwe_dim % ideal_threads;
  const int lwe_upper = lwe_lower + (cutoff != 0 ? 1 : 0);

  // total number of output elements, kept in size_t so that large batches
  // cannot overflow the byte count given to cudaMemsetAsync
  const size_t lwe_size_after = static_cast<size_t>(lwe_dim) * num_samples;

  // the kernel keeps one full output ciphertext in dynamic shared memory
  int shared_mem = sizeof(Torus) * (lwe_dimension_out + 1);

  auto stream = static_cast<cudaStream_t *>(v_stream);
  check_cuda_error(cudaMemsetAsync(lwe_array_out, 0,
                                   sizeof(Torus) * lwe_size_after, *stream));

  dim3 grid(num_samples, 1, 1);
  dim3 threads(ideal_threads, 1, 1);

  // allow the kernel to use the requested amount of dynamic shared memory
  check_cuda_error(cudaFuncSetAttribute(
      keyswitch<Torus>, cudaFuncAttributeMaxDynamicSharedMemorySize,
      shared_mem));

  keyswitch<<<grid, threads, shared_mem, *stream>>>(
      lwe_array_out, lwe_array_in, ksk, lwe_dimension_in, lwe_dimension_out,
      base_log, level_count, lwe_lower, lwe_upper, cutoff);
  // catches launch-configuration errors; execution errors surface later
  check_cuda_error(cudaGetLastError());
}
|
|
|
|
/// Launches the fp_keyswitch kernel on `number_of_input_lwe` LWE ciphertexts.
///
/// glwe_array_out, lwe_array_in and fp_ksk_array are handed directly to the
/// kernel, so they are assumed to be device pointers. v_stream must point to
/// a cudaStream_t; the kernel is enqueued on that stream and the function
/// does not synchronize it.
///
/// Grid layout: grid.y is the ciphertext index, grid.x the index of a chunk
/// of 256 output coefficients; one thread per output coefficient.
///
/// NOTE: the number of chunks is computed with a floor division, so when
/// (glwe_dimension + 1) * polynomial_size is not a multiple of 256 the
/// trailing coefficients of every output ciphertext are not written.
template <typename Torus>
__host__ void cuda_fp_keyswitch_lwe_to_glwe(
    void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
    Torus *lwe_array_in, Torus *fp_ksk_array, uint32_t lwe_dimension_in,
    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
    uint32_t level_count, uint32_t number_of_input_lwe,
    uint32_t number_of_keys) {
  check_cuda_error(cudaSetDevice(gpu_index));

  // Nothing to switch; a grid with zero blocks is also an invalid launch
  // configuration.
  if (number_of_input_lwe == 0)
    return;

  int threads = 256;
  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
  dim3 blocks(glwe_accumulator_size / threads, number_of_input_lwe, 1);

  // one accumulator slot of type Torus per thread
  int shared_mem = sizeof(Torus) * threads;
  auto stream = static_cast<cudaStream_t *>(v_stream);
  fp_keyswitch<<<blocks, threads, shared_mem, *stream>>>(
      glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
      glwe_dimension, polynomial_size, base_log, level_count,
      number_of_input_lwe, number_of_keys);
  // catches launch-configuration errors; execution errors surface later
  check_cuda_error(cudaGetLastError());
}
|
|
|
|
#endif
|