#ifndef CNCRT_KS_H
#define CNCRT_KS_H

#include "crypto/gadget.cuh"
#include "crypto/torus.cuh"
#include "polynomial/polynomial.cuh"
#include <thread>
#include <vector>
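
// Returns a pointer to the block of the keyswitching key holding the given
// decomposition level for the i-th coefficient of the input key. The layout
// assumed by the indexing below is level_count contiguous blocks of
// (lwe_dimension_out + 1) Torus elements per input coefficient.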
template <typename Torus>
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
                                uint32_t lwe_dimension_out,
                                uint32_t level_count) {
  int pos = i * level_count * (lwe_dimension_out + 1) +
            level * (lwe_dimension_out + 1);
  Torus *ptr = &ksk[pos];
  return ptr;
}
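
// Extracts one signed digit from the running decomposition state and updates
// the state in place: the carry taken from the top bit of the raw digit is
// propagated back into the state, keeping the digits balanced around zero.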
template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
  Torus res = state & mod_b_mask;
  state >>= base_log;
  Torus carry = ((res - 1ll) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return res;
}

/*
 * keyswitch kernel
 *
 * Each thread handles a piece of the following equation:
 * $$GLWE_{s_2}(\Delta \cdot m + e) = (0, \dots, 0, b)
 *     - \sum_{i=0}^{k-1} \langle Dec(a_i),
 *       (GLWE_{s_2}(s_{1,i} \cdot q/\beta), \dots,
 *        GLWE_{s_2}(s_{1,i} \cdot q/\beta^\ell)) \rangle$$
 * where k is the dimension of the GLWE ciphertext. If the polynomial
 * dimension in GLWE is > 1, this equation is solved for each polynomial
 * coefficient. Dec denotes the decomposition with base \beta and \ell
 * levels, and the inner product is taken between the decomposition of a_i
 * and the \ell GLWE encryptions of s_{1,i} \cdot q/\beta^j, for j in
 * [1, \ell]. The result is a GLWE encryption of \Delta \cdot m (with
 * \Delta the scaling factor) under the key s_2 instead of s_1, with
 * increased noise.
 */
template <typename Torus>
__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
                          uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
                          uint32_t base_log, uint32_t level_count,
                          int lwe_lower, int lwe_upper, int cutoff) {
  int tid = threadIdx.x;

  extern __shared__ char sharedmem[];
  Torus *local_lwe_array_out = (Torus *)sharedmem;
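
  // Each CUDA block handles one LWE sample: select this block's input and
  // output ciphertexts.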
  auto block_lwe_array_in =
      get_chunk(lwe_array_in, blockIdx.x, lwe_dimension_in + 1);
  auto block_lwe_array_out =
      get_chunk(lwe_array_out, blockIdx.x, lwe_dimension_out + 1);

  auto gadget = GadgetMatrixSingle<Torus>(base_log, level_count);
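
  // Split the (lwe_dimension_out + 1) output coefficients across the threads
  // of the block: the first `cutoff` threads take lwe_upper coefficients
  // each, the remaining threads take lwe_lower each.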
  int lwe_part_per_thd;
  if (tid < cutoff) {
    lwe_part_per_thd = lwe_upper;
  } else {
    lwe_part_per_thd = lwe_lower;
  }
  __syncthreads();
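
  // Zero this thread's slice of the shared-memory accumulator.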
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    local_lwe_array_out[idx] = 0;
  }
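
  // The body of the output starts out as the body of the input ciphertext.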
  if (tid == 0) {
    local_lwe_array_out[lwe_dimension_out] =
        block_lwe_array_in[lwe_dimension_in];
  }
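
  // For each mask coefficient of the input: round it onto the decomposition
  // grid, decompose it, and subtract the inner product with the matching
  // blocks of the keyswitching key, level by level.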
  for (int i = 0; i < lwe_dimension_in; i++) {

    __syncthreads();

    Torus a_i =
        round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);

    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
    Torus mod_b_mask = (1ll << base_log) - 1ll;
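
    // decompose_one yields digits starting from the least significant level,
    // so the matching key blocks are read in reverse order
    // (level_count - j - 1).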
    for (int j = 0; j < level_count; j++) {
      auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
                                     lwe_dimension_out, level_count);
      Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
      for (int k = 0; k < lwe_part_per_thd; k++) {
        int idx = tid + k * blockDim.x;
        local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
      }
    }
  }
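
  // Write the accumulated result from shared memory back to global memory.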
  for (int k = 0; k < lwe_part_per_thd; k++) {
    int idx = tid + k * blockDim.x;
    block_lwe_array_out[idx] = local_lwe_array_out[idx];
  }
}

/// Assumes lwe_array_in is already on the GPU.
template <typename Torus>
__host__ void cuda_keyswitch_lwe_ciphertext_vector(
    void *v_stream, Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples) {

  constexpr int ideal_threads = 128;
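
  // Balance the (lwe_dimension_out + 1) output coefficients over
  // ideal_threads threads: `cutoff` threads receive the larger share
  // lwe_upper, the rest receive lwe_lower.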
  int lwe_dim = lwe_dimension_out + 1;
  int lwe_lower, lwe_upper, cutoff;
  if (lwe_dim % ideal_threads == 0) {
    lwe_lower = lwe_dim / ideal_threads;
    lwe_upper = lwe_dim / ideal_threads;
    cutoff = 0;
  } else {
    int y =
        ceil((double)lwe_dim / (double)ideal_threads) * ideal_threads - lwe_dim;
    cutoff = ideal_threads - y;
    lwe_lower = lwe_dim / ideal_threads;
    lwe_upper = (int)ceil((double)lwe_dim / (double)ideal_threads);
  }

  int lwe_size_after = (lwe_dimension_out + 1) * num_samples;

  int shared_mem = sizeof(Torus) * (lwe_dimension_out + 1);

  cudaMemset(lwe_array_out, 0, sizeof(Torus) * lwe_size_after);
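
  // One block per sample; each block keyswitches one LWE ciphertext.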
  dim3 grid(num_samples, 1, 1);
  dim3 threads(ideal_threads, 1, 1);
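
  // Opt in to enough dynamic shared memory for the per-block accumulator of
  // (lwe_dimension_out + 1) Torus elements.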
  cudaFuncSetAttribute(keyswitch<Torus>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem);

  auto stream = static_cast<cudaStream_t *>(v_stream);
  keyswitch<<<grid, threads, shared_mem, *stream>>>(
      lwe_array_out, lwe_array_in, ksk, lwe_dimension_in, lwe_dimension_out,
      base_log, level_count, lwe_lower, lwe_upper, cutoff);
  checkCudaErrors(cudaGetLastError());

  cudaStreamSynchronize(*stream);
}
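
// Example usage (a sketch: the buffer names and parameter values below are
// illustrative assumptions, and the device pointers are expected to be
// allocated and filled by the caller):
//
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   cuda_keyswitch_lwe_ciphertext_vector<uint64_t>(
//       &stream, d_lwe_out, d_lwe_in, d_ksk,
//       /*lwe_dimension_in=*/1024, /*lwe_dimension_out=*/512,
//       /*base_log=*/4, /*level_count=*/5, /*num_samples=*/1);
//   cudaStreamDestroy(stream);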

#endif