// Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-08 06:13:58 -05:00
#ifndef CNCRT_KS_CUH
#define CNCRT_KS_CUH

#include <algorithm>

#include "device.h"
#include "gadget.cuh"
#include "helper_multi_gpu.h"
#include "polynomial/functions.cuh"
#include "polynomial/polynomial_math.cuh"
#include "torus.cuh"
#include "utils/helper.cuh"
#include "utils/kernel_dimensions.cuh"
#include <thread>
#include <vector>

const int BLOCK_SIZE_DECOMP = 8;
const int BLOCK_SIZE_GEMM_KS = 36;
const int THREADS_GEMM_KS = 6;

inline uint64_t get_threshold_ks_gemm() { return 128; }

template <typename Torus> struct ks_mem {
  Torus *d_buffer;
  uint64_t num_lwes;
  uint32_t lwe_dimension;
};

template <typename Torus>
uint64_t scratch_cuda_keyswitch_size(uint32_t lwe_dimension_in,
                                     uint32_t lwe_dimension_out,
                                     uint32_t num_lwes) {
  GPU_ASSERT(lwe_dimension_in >= lwe_dimension_out,
             "Trying to allocate KS temp buffer for invalid LWE dimensions");
  return (uint64_t)num_lwes * lwe_dimension_in * sizeof(Torus) * 2;
}
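// Hedged worked example (illustration only; the parameter values are
// assumptions, not defaults of this library): for 128 LWEs with an input
// dimension of 2048 and Torus = uint64_t, the call
//
//   uint64_t bytes = scratch_cuda_keyswitch_size<uint64_t>(2048, 742, 128);
//
// returns 128 * 2048 * 8 * 2 = 4 MiB: one half of the buffer holds the
// decomposed mask values, the other half holds the decomposer state.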
template <typename Torus>
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
                                uint32_t lwe_dimension_out,
                                uint32_t level_count) {
  int pos = i * level_count * (lwe_dimension_out + 1) +
            level * (lwe_dimension_out + 1);
  Torus *ptr = &ksk[pos];
  return ptr;
}
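// Hedged layout note (derived from the offset computation above): the KSK is
// assumed to be stored as a dense row-major array indexed as
// ksk[input_mask_index][level][output_lwe_coefficient], i.e. block (i, level)
// starts at (i * level_count + level) * (lwe_dimension_out + 1). For example,
// with level_count = 5 and lwe_dimension_out = 742, block (i = 2, level = 3)
// starts at (2 * 5 + 3) * 743 = 9659.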
template <typename T>
__device__ T closest_repr(T input, uint32_t base_log, uint32_t level_count) {
  T minus_2 = static_cast<T>(-2);
  const T rep_bit_count = level_count * base_log;            // e.g. 32
  const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count; // e.g. 32
  auto shift = (non_rep_bit_count - 1);                      // e.g. 31
  T res = input >> shift;
  res++;
  res &= minus_2;
  res <<= shift;
  return res;
}
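// Hedged host-side sketch (not part of the original header): an equivalent
// formulation of closest_repr, kept here purely to document its intent. It
// rounds `input` to the nearest multiple of
// 2^(bit_width - base_log * level_count), i.e. it keeps only the
// representable most-significant bits, rounding half up.
template <typename T>
inline T closest_repr_reference(T input, uint32_t base_log,
                                uint32_t level_count) {
  const uint32_t non_rep_bit_count = sizeof(T) * 8 - level_count * base_log;
  const T half = (T)1 << (non_rep_bit_count - 1);
  const T mask = ~(((T)1 << non_rep_bit_count) - (T)1);
  // Add half of the non-representable range, then clear the low bits.
  return (T)((input + half) & mask);
}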
template <typename T>
__global__ void closest_representable(const T *input, T *output,
                                      uint32_t base_log, uint32_t level_count) {
  output[0] = closest_repr(input[0], base_log, level_count);
}

template <typename T>
__host__ void
host_cuda_closest_representable(cudaStream_t stream, uint32_t gpu_index,
                                const T *input, T *output, uint32_t base_log,
                                uint32_t level_count) {
  dim3 grid(1, 1, 1);
  dim3 threads(1, 1, 1);

  cuda_set_device(gpu_index);
  closest_representable<<<grid, threads, 0, stream>>>(input, output, base_log,
                                                      level_count);
}
// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
                                         uint32_t lwe_dimension,
                                         uint32_t num_lwe, uint32_t base_log,
                                         uint32_t level_count) {

  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  // Input LWE array is [mask_0, .., mask_lwe_dim, message] and
  // we only decompose the mask. Thus the stride for reading
  // is lwe_dimension + 1, while for writing it is lwe_dimension
  auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
  auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto write_state_idx =
      num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

  Torus a_i = lwe_in[read_val_idx];

  Torus state = init_decomposer_state(a_i, base_log, level_count);

  Torus mod_b_mask = (1ll << base_log) - 1ll;
  KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
  kst_ptr_lwe_out[write_val_idx] =
      decompose_one<Torus>(state, mod_b_mask, base_log);
  __syncthreads();
  lwe_out[write_state_idx] = state;
}
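// Hedged launch sketch (mirrors how the host GEMM keyswitch below drives the
// indexed variant of this kernel; stream setup and error handling are elided):
//
//   dim3 grid(CEIL_DIV(num_lwe, BLOCK_SIZE_DECOMP),
//             CEIL_DIV(lwe_dimension, BLOCK_SIZE_DECOMP));
//   dim3 threads(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
//   decompose_vectorize_init<Torus, KSTorus><<<grid, threads, 0, stream>>>(
//       lwe_in, lwe_out, lwe_dimension, num_lwe, base_log, level_count);
//
// The output buffer must hold 2 * num_lwe * lwe_dimension Torus elements:
// decomposed values in the first half, decomposer states in the second half.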
// Decompose an array of LWEs with indirection through lwe_input_indices.
// The LWE array can contain total_lwe LWEs, where total_lwe can be different
// from num_lwe, and every index must be smaller than total_lwe. num_lwe is the
// number of LWEs to decompose. The output buffer must have space for num_lwe
// LWEs; they are gathered according to the input indices.
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init_with_indices(
    Torus const *lwe_in, const Torus *__restrict__ lwe_input_indices,
    Torus *lwe_out, uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log,
    uint32_t level_count) {

  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  // Input LWE array is [mask_0, .., mask_lwe_dim, message] and
  // we only decompose the mask. Thus the stride for reading
  // is lwe_dimension + 1, while for writing it is lwe_dimension
  auto read_val_idx =
      lwe_input_indices[lwe_idx] * (lwe_dimension + 1) + lwe_sample_idx;
  auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto write_state_idx =
      num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

  Torus a_i = lwe_in[read_val_idx];

  Torus state = init_decomposer_state(a_i, base_log, level_count);

  Torus mod_b_mask = (1ll << base_log) - 1ll;
  KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
  kst_ptr_lwe_out[write_val_idx] =
      decompose_one<Torus>(state, mod_b_mask, base_log);
  __syncthreads();
  lwe_out[write_state_idx] = state;
}
// Continue decomposition of an array of Torus elements in place. Assumes
// that the array already contains decomposed elements and
// computes the next decomposition level in place.
template <typename Torus, typename KSTorus>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
                                 uint32_t num_lwe, uint32_t base_log,
                                 uint32_t level_count) {

  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto state_idx = num_lwe * lwe_dimension + val_idx;

  Torus state = buffer_in[state_idx];
  __syncthreads();

  Torus mod_b_mask = (1ll << base_log) - 1ll;

  KSTorus *kst_ptr_lwe_out = (KSTorus *)buffer_in;
  kst_ptr_lwe_out[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
  __syncthreads();
  buffer_in[state_idx] = state;
}
/* LWE inputs to the keyswitch function are stored as a_0,...,a_{lwe_dim},b,
 * where the a_i are mask elements and b is the message. We initialize
 * the output keyswitched LWEs to 0, ..., 0, -b. The GEMM keyswitch is computed
 * as:
 * -(-b + sum(a_i A_KSK))
 */
template <typename Torus, typename KSTorus>
__global__ void keyswitch_gemm_copy_negated_message_with_indices(
    const Torus *__restrict__ lwe_in,
    const Torus *__restrict__ lwe_input_indices, KSTorus *__restrict__ lwe_out,
    const Torus *__restrict__ lwe_output_indices,

    uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {

  uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;

  if (lwe_id >= num_lwes)
    return;

  uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
  uint32_t lwe_out_idx = lwe_output_indices[lwe_id];

  Torus body_in =
      lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
  Torus body_out;
  if constexpr (std::is_same_v<KSTorus, Torus>) {
    body_out = -body_in;
  } else {
    body_out = closest_repr(body_in, sizeof(KSTorus) * 8, 1);

    // Powers of two are encoded in the MSBs of the types, so we need to scale
    // from one type to the other without having to worry about the moduli
    static_assert(sizeof(Torus) >= sizeof(KSTorus),
                  "Cannot compile keyswitch with given input/output dtypes");
    Torus input_to_output_scaling_factor =
        (sizeof(Torus) - sizeof(KSTorus)) * 8;

    auto rounded_downscaled_body =
        (KSTorus)(body_out >> input_to_output_scaling_factor);

    body_out = -rounded_downscaled_body;
  }
  lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
      (KSTorus)body_out;
}
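// Hedged worked example of the 64 -> 32 bit body handling above (values are
// made up for illustration): with Torus = uint64_t and KSTorus = uint32_t, a
// body 0x123456789ABCDEF0 is first rounded to the nearest multiple of 2^32,
// giving 0x1234567900000000, then shifted down by 32 bits to the 32-bit body
// 0x12345679, and finally negated.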
// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK)).
// This function finishes the KS computation by negating all elements in the
// array using output indices. The array contains -b + SUM(a_i x LWE_i) and this
// final step computes b - SUM(a_i x LWE_i).
template <typename Torus, typename KSTorus>
__global__ void keyswitch_negate_with_output_indices(
    KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
    uint32_t lwe_size, uint32_t num_lwe) {

  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
    return;

  auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;

  Torus val = buffer_in[val_idx];
  buffer_in[val_idx] = -val;
}
template <typename Torus, typename KSTorus>
__global__ void keyswitch_zero_output_with_output_indices(
    KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
    uint32_t lwe_size, uint32_t num_lwe) {

  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
    return;

  auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;

  buffer_in[val_idx] = 0;
}
/*
 * keyswitch kernel
 * Each thread handles a piece of the following equation:
 * $$ GLWE_{s2}(\Delta.m + e) = (0, 0, .., 0, b)
 *    - \sum_{i=0}^{k-1} \langle Dec(a_i),
 *      (GLWE_{s2}(s1_i q/\beta), .., GLWE_{s2}(s1_i q/\beta^l)) \rangle $$
 * where k is the dimension of the GLWE ciphertext. If the polynomial dimension
 * in GLWE is > 1, this equation is solved for each polynomial coefficient.
 * Dec denotes the decomposition with base beta and l levels, and the inner
 * product is done between the decomposition of a_i and the l GLWE encryptions
 * of s1_i q/\beta^j, with j in [1,l].
 * We obtain a GLWE encryption of \Delta.m (with \Delta the scaling factor)
 * under key s2 instead of s1, with an increased noise.
 *
 */
// Each thread in x is used to calculate one output.
// Threads in y are used to parallelize the lwe_dimension_in loop.
// Shared memory is used to store intermediate results of the reduction.
// Note: To reduce register pressure we have slightly changed the algorithm;
// the idea consists in calculating the negated value of the output. So, instead
// of accumulating subtractions using -=, we accumulate additions using += in
// local_lwe_out. This seems to work better because it profits from madd ops and
// saves some registers. For this to work, we need to negate the input
// lwe_array_in[lwe_dimension_in], and negate the output back at the end to get
// the correct results. Additionally, we split the calculation of the ksk offset
// in two parts: a constant part is calculated before the loop, and a variable
// part is calculated inside the loop. This seems to help with the register
// pressure as well.
template <typename Torus, typename KSTorus>
__global__ void
keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
          const Torus *__restrict__ lwe_array_in,
          const Torus *__restrict__ lwe_input_indexes,
          const KSTorus *__restrict__ ksk, uint32_t lwe_dimension_in,
          uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
  const int tid = threadIdx.x + blockIdx.y * blockDim.x;
  const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;

  extern __shared__ int8_t sharedmem[];
  Torus *lwe_acc_out = (Torus *)sharedmem;
  auto block_lwe_array_out = get_chunk(
      lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);

  if (tid <= lwe_dimension_out) {

    KSTorus local_lwe_out = 0;
    auto block_lwe_array_in = get_chunk(
        lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);

    if (tid == lwe_dimension_out && threadIdx.y == 0) {
      if constexpr (std::is_same_v<KSTorus, Torus>) {
        local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
      } else {
        auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in],
                                     sizeof(KSTorus) * 8, 1);

        // Powers of two are encoded in the MSBs of the types, so we need to
        // scale from one type to the other without having to worry about the
        // moduli
        Torus input_to_output_scaling_factor =
            (sizeof(Torus) - sizeof(KSTorus)) * 8;

        auto rounded_downscaled_body =
            (KSTorus)(new_body >> input_to_output_scaling_factor);

        local_lwe_out = -rounded_downscaled_body;
      }
    }
    const Torus mask_mod_b = (1ll << base_log) - 1ll;

    const int pack_size = (lwe_dimension_in + blockDim.y - 1) / blockDim.y;
    const int start_i = pack_size * threadIdx.y;
    const int end_i = SEL(lwe_dimension_in, pack_size * (threadIdx.y + 1),
                          pack_size * (threadIdx.y + 1) <= lwe_dimension_in);

    // This loop distribution seems to benefit the global mem reads
    for (int i = start_i; i < end_i; i++) {
      Torus state =
          init_decomposer_state(block_lwe_array_in[i], base_log, level_count);
      uint32_t offset = i * level_count * (lwe_dimension_out + 1);
      for (int j = 0; j < level_count; j++) {

        KSTorus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
        local_lwe_out +=
            (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
            decomposed;
      }
    }

    lwe_acc_out[shmem_index] = local_lwe_out;
  }

  if (tid <= lwe_dimension_out) {
    for (int offset = blockDim.y / 2; offset > 0 && threadIdx.y < offset;
         offset /= 2) {
      __syncthreads();
      lwe_acc_out[shmem_index] +=
          lwe_acc_out[shmem_index + offset * blockDim.x];
    }
    if (threadIdx.y == 0)
      block_lwe_array_out[tid] = -lwe_acc_out[shmem_index];
  }
}
template <typename Torus, typename KSTorus>
__host__ void host_keyswitch_lwe_ciphertext_vector(
    cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
    Torus const *lwe_output_indexes, Torus const *lwe_array_in,
    Torus const *lwe_input_indexes, KSTorus const *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples) {

  cuda_set_device(gpu_index);

  constexpr int num_threads_y = 32;
  int num_blocks_per_sample, num_threads_x;

  getNumBlocksAndThreads2D(lwe_dimension_out + 1, 512, num_threads_y,
                           num_blocks_per_sample, num_threads_x);

  int shared_mem = sizeof(Torus) * num_threads_y * num_threads_x;
  PANIC_IF_FALSE(
      num_blocks_per_sample <= 65536,
      "Cuda error (Keyswitch): number of blocks per sample (%d) is too large",
      num_blocks_per_sample);

  // In multiplication of large integers (512, 1024, 2048), the number of
  // samples can be larger than 65536, so we need to set it in the first
  // dimension of the grid
  dim3 grid(num_samples, num_blocks_per_sample, 1);
  dim3 threads(num_threads_x, num_threads_y, 1);

  keyswitch<Torus, KSTorus><<<grid, threads, shared_mem, stream>>>(
      lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
      lwe_dimension_in, lwe_dimension_out, base_log, level_count);
  check_cuda_error(cudaGetLastError());
}
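// Hedged notes on the launch configuration above (derived from the kernel and
// host code, not an additional API): each blockIdx.x handles one input LWE,
// the (blockIdx.y, threadIdx.x) pair covers the lwe_dimension_out + 1 output
// coefficients, and the 32 threads in y split the lwe_dimension_in loop, with
// their partial sums reduced through shared memory at the end of the kernel.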
// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK))
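// Hedged step-by-step summary of the host function below (paraphrased from its
// body; nothing here is additional API):
//   1. Zero the output LWEs (a plain memset when the indices are trivial).
//   2. Copy the negated (and, for 64 -> 32 bit KS, rounded and downscaled)
//      input bodies into the output bodies.
//   3. Decompose the first level of the input masks into fp_tmp_buffer.
//   4. Multiply the decomposed masks with the corresponding KSK block via a
//      tiled GEMM, accumulating into the output LWEs.
//   5. Repeat steps 3-4 for the remaining level_count - 1 levels, advancing
//      the KSK pointer by (lwe_dimension_out + 1) per level.
//   6. Negate the accumulated outputs to obtain b - sum(a_i A_KSK).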
template <typename Torus, typename KSTorus>
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
    cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
    Torus const *lwe_output_indices, Torus const *lwe_array_in,
    Torus const *lwe_input_indices, KSTorus const *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples, Torus *fp_tmp_buffer,
    bool uses_trivial_indices) {
  cuda_set_device(gpu_index);
  check_cuda_error(cudaGetLastError());

  // fp_tmp_buffer contains 2x the space needed to store the input LWE masks
  // (without the body). The first half can be interpreted with a smaller dtype
  // when performing 64->32 KS; the second half, storing the decomposition
  // state, must be interpreted as Torus* (usually 64-bit).
  KSTorus *d_mem_0 =
      (KSTorus *)fp_tmp_buffer; // keeps decomposed value (in KSTorus type)

  // Set the output LWEs to 0 as they are used to accumulate the GEMM results
  if (uses_trivial_indices) {
    cuda_memset_async(lwe_array_out, 0,
                      num_samples * (lwe_dimension_out + 1) * sizeof(KSTorus),
                      stream, gpu_index);
  } else {
    // zero the outputs through the output indices
    dim3 grid_zero(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
    dim3 threads_zero(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

    keyswitch_zero_output_with_output_indices<Torus, KSTorus>
        <<<grid_zero, threads_zero, 0, stream>>>(
            lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
            num_samples);
  }
  check_cuda_error(cudaGetLastError());

  dim3 grid_copy(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
  dim3 threads_copy(BLOCK_SIZE_DECOMP);

  // lwe_array_out is num_samples x (lwe_dimension_out + 1). Copy the bodies
  // lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
  // and negate them
  keyswitch_gemm_copy_negated_message_with_indices<Torus, KSTorus>
      <<<grid_copy, threads_copy, 0, stream>>>(
          lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
          lwe_dimension_in, num_samples, lwe_dimension_out);
  check_cuda_error(cudaGetLastError());

  // Decompose the LWE masks. The input LWEs have lwe_dimension_in + 1
  // elements; the body (last element) is not decomposed, so the grid is sized
  // on lwe_dimension_in only and the decomposition kernels bound-check against
  // lwe_dimension_in.
  dim3 grid_decomp(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP));
  dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

  uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
  // Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
  // 32, 64, and 128-bit Torus elements
  // Sanity check: the shared memory size is a constant defined by the algorithm
  GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
             "GEMM kernel error: shared memory required might be too large");

  auto stride_KSK_buffer = (lwe_dimension_out + 1) * level_count;

  // GEMM between the decomposed masks and the keyswitch key
  dim3 grid_gemm(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_GEMM_KS),
                 CEIL_DIV(num_samples, BLOCK_SIZE_GEMM_KS));
  dim3 threads_gemm(BLOCK_SIZE_GEMM_KS * THREADS_GEMM_KS);

  // decompose first level (skips the body in the input buffer)
  decompose_vectorize_init_with_indices<Torus, KSTorus>
      <<<grid_decomp, threads_decomp, 0, stream>>>(
          lwe_array_in, lwe_input_indices, fp_tmp_buffer, lwe_dimension_in,
          num_samples, base_log, level_count);
  check_cuda_error(cudaGetLastError());

  if (uses_trivial_indices) {
    tgemm<KSTorus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
            ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
    check_cuda_error(cudaGetLastError());

  } else {
    tgemm_with_indices<KSTorus, Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
            ksk, stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1,
            lwe_output_indices);
    check_cuda_error(cudaGetLastError());
  }

  auto ksk_block_size = (lwe_dimension_out + 1);

  for (int li = 1; li < level_count; ++li) {
    decompose_vectorize_step_inplace<Torus, KSTorus>
        <<<grid_decomp, threads_decomp, 0, stream>>>(
            fp_tmp_buffer, lwe_dimension_in, num_samples, base_log,
            level_count);
    check_cuda_error(cudaGetLastError());

    if (uses_trivial_indices) {
      tgemm<KSTorus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
          <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
              num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
              ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
              lwe_dimension_out + 1);
      check_cuda_error(cudaGetLastError());

    } else {
      tgemm_with_indices<KSTorus, Torus, BLOCK_SIZE_GEMM_KS, THREADS_GEMM_KS>
          <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
              num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
              ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
              lwe_dimension_out + 1, lwe_output_indices);
      check_cuda_error(cudaGetLastError());
    }
  }

  dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
  dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
  // Negate all outputs in the output LWEs. This is the final step in the GEMM
  // keyswitch computed as: -(-b + sum(a_i A_KSK))
  keyswitch_negate_with_output_indices<Torus, KSTorus>
      <<<grid_negate, threads_negate, 0, stream>>>(
          lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
          num_samples);
  check_cuda_error(cudaGetLastError());
}
template <typename Torus, typename KSTorus>
void execute_keyswitch_async(
    CudaStreams streams, const LweArrayVariant<Torus> &lwe_array_out,
    const LweArrayVariant<Torus> &lwe_output_indexes,
    const LweArrayVariant<Torus> &lwe_array_in,
    const LweArrayVariant<Torus> &lwe_input_indexes, KSTorus *const *ksks,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples, bool uses_trivial_indices,
    const std::vector<ks_mem<Torus> *> &fp_tmp_buffer) {

  /// If the number of radix blocks is lower than the number of GPUs, not all
  /// GPUs will be active and there will be 1 input per GPU
  for (uint i = 0; i < streams.count(); i++) {
    int num_samples_on_gpu =
        get_num_inputs_on_gpu(num_samples, i, streams.count());

    Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i);
    Torus *current_lwe_output_indexes =
        get_variant_element(lwe_output_indexes, i);
    Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i);
    Torus *current_lwe_input_indexes =
        get_variant_element(lwe_input_indexes, i);

    if (!fp_tmp_buffer.empty() &&
        num_samples_on_gpu >= get_threshold_ks_gemm()) {
      GPU_ASSERT(fp_tmp_buffer.size() >= streams.count(),
                 "GEMM KS Buffers %ld were not initialized for this number of "
                 "streams, %d",
                 fp_tmp_buffer.size(), streams.count());

      GPU_ASSERT(fp_tmp_buffer[i]->num_lwes >= num_samples_on_gpu,
                 "KS temp buffer not big enough");

      GPU_ASSERT(fp_tmp_buffer[i]->lwe_dimension ==
                     std::max(lwe_dimension_in, lwe_dimension_out),
                 "KS temp buffer was created for a different input LWE size: "
                 "%d vs (in:%d, out:%d)",
                 fp_tmp_buffer[i]->lwe_dimension, lwe_dimension_in,
                 lwe_dimension_out);

      // Compute Keyswitch
      host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
          streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
          current_lwe_output_indexes, current_lwe_array_in,
          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
          lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
          fp_tmp_buffer[i]->d_buffer, uses_trivial_indices);

    } else {
      // Compute Keyswitch
      host_keyswitch_lwe_ciphertext_vector<Torus>(
          streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
          current_lwe_output_indexes, current_lwe_array_in,
          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
          lwe_dimension_out, base_log, level_count, num_samples_on_gpu);
    }
  }
}
template <typename Torus>
__host__ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe(
    cudaStream_t stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
    uint32_t num_lwes, bool allocate_gpu_memory) {
  cuda_set_device(gpu_index);

  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

  // allocate at least twice the LWE mask size: to keep both the decomposition
  // state and the decomposed intermediate value
  uint64_t memory_unit = glwe_accumulator_size > lwe_dimension * 2
                             ? glwe_accumulator_size
                             : lwe_dimension * 2;

  uint64_t size_tracker = 0;
  uint64_t buffer_size = 2 * num_lwes * memory_unit * sizeof(Torus);
  *fp_ks_buffer = (int8_t *)cuda_malloc_with_size_tracking_async(
      buffer_size, stream, gpu_index, size_tracker, allocate_gpu_memory);
  return size_tracker;
}
// public functional packing keyswitch for a single LWE ciphertext
//
// Assumes there are (glwe_dimension+1) * polynomial_size threads split through
// different thread blocks at the x-axis to work on that input.
template <typename Torus>
__device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
    Torus *glwe_out, Torus const *lwe_in, Torus const *fp_ksk,
    uint32_t lwe_dimension_in, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count) {

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  size_t glwe_size = (glwe_dimension + 1);

  if (tid < glwe_size * polynomial_size) {
    const int local_index = threadIdx.x;
    // the output_glwe is split in polynomials and each x-block takes one of
    // them
    size_t poly_id = blockIdx.x;
    size_t coef_per_block = blockDim.x;

    // number of coefficients inside fp-ksk block for each lwe_input coefficient
    size_t ksk_block_size = glwe_size * polynomial_size * level_count;

    // initialize the accumulator to 0, except for the constant coefficient of
    // the body polynomial, which is seeded with the input LWE body
    glwe_out[tid] = SEL(0, lwe_in[lwe_dimension_in],
                        tid == glwe_dimension * polynomial_size);

    // Iterate through all lwe elements
    for (int i = 0; i < lwe_dimension_in; i++) {
      // Round and prepare decomposition
      Torus state = init_decomposer_state(lwe_in[i], base_log, level_count);

      Torus mod_b_mask = (1ll << base_log) - 1ll;

      // block of key for current lwe coefficient (cur_input_lwe[i])
      auto ksk_block = &fp_ksk[i * ksk_block_size];
      for (int j = 0; j < level_count; j++) {
        // Iterate through each level and multiply by the ksk piece
        auto ksk_glwe = &ksk_block[j * glwe_size * polynomial_size];
        auto ksk_glwe_chunk = &ksk_glwe[poly_id * coef_per_block];
        Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
        glwe_out[tid] -= decomposed * ksk_glwe_chunk[local_index];
      }
    }
  }
}
/// To-do: Rewrite this kernel for efficiency
template <typename Torus>
__global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
                                 uint32_t glwe_dimension,
                                 uint32_t polynomial_size, uint32_t num_lwes) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < (glwe_dimension + 1) * polynomial_size) {
    glwe_out[tid] = glwe_array_in[tid];

    // Accumulate
    for (int i = 1; i < num_lwes; i++) {
      auto glwe_in = glwe_array_in + i * (glwe_dimension + 1) * polynomial_size;
      glwe_out[tid] += glwe_in[tid];
    }
  }
}
template <typename Torus>
uint64_t scratch_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
                                ks_mem<Torus> **ks_tmp_memory,
                                uint32_t lwe_dimension_in,
                                uint32_t lwe_dimension_out, uint32_t num_lwes,
                                bool allocate_gpu_memory) {
  uint64_t sub_size_tracker = 0;
  uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
      lwe_dimension_in, lwe_dimension_out, num_lwes);

  *ks_tmp_memory = new ks_mem<Torus>;
  (*ks_tmp_memory)->d_buffer = (Torus *)cuda_malloc_with_size_tracking_async(
      buffer_size, stream, gpu_index, sub_size_tracker, allocate_gpu_memory);
  (*ks_tmp_memory)->lwe_dimension =
      std::max(lwe_dimension_in, lwe_dimension_out);
  (*ks_tmp_memory)->num_lwes = num_lwes;
  return sub_size_tracker;
}
template <typename Torus>
void cleanup_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
                            ks_mem<Torus> *ks_tmp_memory,
                            bool allocate_gpu_memory) {
  cuda_drop_with_size_tracking_async(ks_tmp_memory->d_buffer, stream, gpu_index,
                                     allocate_gpu_memory);
  delete ks_tmp_memory;
}
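// Hedged end-to-end usage sketch (illustration only; the parameter values are
// assumptions and error handling is elided). The temporary-buffer lifecycle is:
//
//   ks_mem<uint64_t> *ks_buffer = nullptr;
//   scratch_cuda_keyswitch<uint64_t>(stream, gpu_index, &ks_buffer,
//                                    /*lwe_dimension_in=*/2048,
//                                    /*lwe_dimension_out=*/742,
//                                    /*num_lwes=*/128,
//                                    /*allocate_gpu_memory=*/true);
//   // ... run keyswitches, e.g. through execute_keyswitch_async, passing the
//   // buffer(s) in the fp_tmp_buffer vector ...
//   cleanup_cuda_keyswitch<uint64_t>(stream, gpu_index, ks_buffer,
//                                    /*allocate_gpu_memory=*/true);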
#endif