#ifndef CNCRT_KS_CUH
#define CNCRT_KS_CUH

#include "device.h"
#include "gadget.cuh"
#include "helper_multi_gpu.h"
#include "polynomial/functions.cuh"
#include "polynomial/polynomial_math.cuh"
#include "torus.cuh"
#include "utils/helper.cuh"
#include "utils/kernel_dimensions.cuh"
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <vector>

const int BLOCK_SIZE_DECOMP = 8;
const int BLOCK_SIZE_GEMM_KS = 36;
const int THREADS_GEMM_KS = 6;

inline uint64_t get_threshold_ks_gemm() { return 128; }

template <typename Torus> struct ks_mem {
  Torus *d_buffer;
  uint64_t num_lwes;
  uint32_t lwe_dimension;
};

template <typename Torus>
uint64_t scratch_cuda_keyswitch_size(uint32_t lwe_dimension_in,
                                     uint32_t lwe_dimension_out,
                                     uint32_t num_lwes) {
  GPU_ASSERT(lwe_dimension_in >= lwe_dimension_out,
             "Trying to allocate KS temp buffer for invalid LWE dimensions");
  return (uint64_t)num_lwes * lwe_dimension_in * sizeof(Torus) * 2;
}

template <typename Torus>
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
                                uint32_t lwe_dimension_out,
                                uint32_t level_count) {
  int pos = i * level_count * (lwe_dimension_out + 1) +
            level * (lwe_dimension_out + 1);
  Torus *ptr = &ksk[pos];
  return ptr;
}

template <typename T>
__device__ T closest_repr(T input, uint32_t base_log, uint32_t level_count) {
  T minus_2 = static_cast<T>(-2);
  const T rep_bit_count = level_count * base_log;            // e.g. 32
  const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count; // e.g. 32
  auto shift = (non_rep_bit_count - 1);                      // e.g. 31
  T res = input >> shift;
  res++;
  res &= minus_2;
  res <<= shift;
  return res;
}

template <typename T>
__global__ void closest_representable(const T *input, T *output,
                                      uint32_t base_log,
                                      uint32_t level_count) {
  output[0] = closest_repr(input[0], base_log, level_count);
}

template <typename T>
__host__ void host_cuda_closest_representable(cudaStream_t stream,
                                              uint32_t gpu_index,
                                              const T *input, T *output,
                                              uint32_t base_log,
                                              uint32_t level_count) {
  dim3 grid(1, 1, 1);
  dim3 threads(1, 1, 1);
  cuda_set_device(gpu_index);
  closest_representable<<<grid, threads, 0, stream>>>(input, output, base_log,
                                                      level_count);
}

// Initialize decomposition by performing rounding and decomposing one level
// of an array of Torus LWEs. Only decomposes the mask elements of the
// incoming LWEs.
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
                                         uint32_t lwe_dimension,
                                         uint32_t num_lwe, uint32_t base_log,
                                         uint32_t level_count) {
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  // The input LWE array is [mask_0, .., mask_{lwe_dim-1}, body] and we only
  // decompose the mask. Thus the stride for reading is lwe_dimension + 1,
  // while for writing it is lwe_dimension
  auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
  auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto write_state_idx =
      num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

  Torus a_i = lwe_in[read_val_idx];

  Torus state = init_decomposer_state(a_i, base_log, level_count);

  Torus mod_b_mask = (1ll << base_log) - 1ll;
  KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
  kst_ptr_lwe_out[write_val_idx] = decompose_one(state, mod_b_mask, base_log);
  __syncthreads();
  lwe_out[write_state_idx] = state;
}
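// Illustrative note (a sketch, not part of the API): after
// decompose_vectorize_init, the temporary buffer holds, for num_lwe = 2 and
// lwe_dimension = 3, the following layout (v = decomposed value, s = carried
// decomposer state):
//
//   [ v00 v01 v02 | v10 v11 v12 | s00 s01 s02 | s10 s11 s12 ]
//     first half: current level, read as KSTorus by the GEMM kernels
//                                 second half: states, kept as full-width Torus
//
// This matches the factor of two in scratch_cuda_keyswitch_size above.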
// Decompose an array of LWEs with indirection through lwe_input_indices.
// The input LWE array can contain total_lwe LWEs, where total_lwe can differ
// from num_lwe; every index must be < total_lwe. num_lwe is the number of
// LWEs to decompose. The output buffer must have space for num_lwe LWEs,
// written densely in the order given by the input indices.
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_init_with_indices(
    Torus const *lwe_in, const Torus *__restrict__ lwe_input_indices,
    Torus *lwe_out, uint32_t lwe_dimension, uint32_t num_lwe,
    uint32_t base_log, uint32_t level_count) {
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  // The input LWE array is [mask_0, .., mask_{lwe_dim-1}, body] and we only
  // decompose the mask. Thus the stride for reading is lwe_dimension + 1,
  // while for writing it is lwe_dimension
  auto read_val_idx =
      lwe_input_indices[lwe_idx] * (lwe_dimension + 1) + lwe_sample_idx;
  auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto write_state_idx =
      num_lwe * lwe_dimension + lwe_idx * lwe_dimension + lwe_sample_idx;

  Torus a_i = lwe_in[read_val_idx];

  Torus state = init_decomposer_state(a_i, base_log, level_count);

  Torus mod_b_mask = (1ll << base_log) - 1ll;
  KSTorus *kst_ptr_lwe_out = (KSTorus *)lwe_out;
  kst_ptr_lwe_out[write_val_idx] = decompose_one(state, mod_b_mask, base_log);
  __syncthreads();
  lwe_out[write_state_idx] = state;
}

// Continue the decomposition of an array of Torus elements in place. Assumes
// the array already contains decomposed elements and computes the next
// decomposition level in place.
template <typename Torus, typename KSTorus>
__global__ void decompose_vectorize_step_inplace(Torus *buffer_in,
                                                 uint32_t lwe_dimension,
                                                 uint32_t num_lwe,
                                                 uint32_t base_log,
                                                 uint32_t level_count) {
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
    return;

  auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;
  auto state_idx = num_lwe * lwe_dimension + val_idx;

  Torus state = buffer_in[state_idx];
  __syncthreads();

  Torus mod_b_mask = (1ll << base_log) - 1ll;

  KSTorus *kst_ptr_lwe_out = (KSTorus *)buffer_in;
  kst_ptr_lwe_out[val_idx] = decompose_one(state, mod_b_mask, base_log);
  __syncthreads();
  buffer_in[state_idx] = state;
}
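// Illustrative call sequence (a sketch with hypothetical sizes and handles,
// mirroring how host_gemm_keyswitch_lwe_ciphertext_vector below drives the
// with-indices variants): one init pass writes level 0 plus the carried
// states; each subsequent in-place pass overwrites the value half of the
// buffer with the next level.
//
//   decompose_vectorize_init<uint64_t, uint64_t>
//       <<<grid, threads, 0, stream>>>(d_lwe_in, d_tmp,
//                                      /*lwe_dimension=*/742, /*num_lwe=*/16,
//                                      /*base_log=*/3, /*level_count=*/5);
//   for (int level = 1; level < 5; ++level)
//     decompose_vectorize_step_inplace<uint64_t, uint64_t>
//         <<<grid, threads, 0, stream>>>(d_tmp, 742, 16, 3, 5);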
/* LWE inputs to the keyswitch function are stored as a_0,...,a_{lwe_dim-1},b,
 * where the a_i are mask elements and b is the body. We initialize the
 * output keyswitched LWEs to 0, ..., 0, -b. The GEMM keyswitch is computed
 * as:
 *    -(-b + sum(a_i A_KSK))
 */
template <typename Torus, typename KSTorus>
__global__ void keyswitch_gemm_copy_negated_message_with_indices(
    const Torus *__restrict__ lwe_in,
    const Torus *__restrict__ lwe_input_indices,
    KSTorus *__restrict__ lwe_out,
    const Torus *__restrict__ lwe_output_indices, uint32_t lwe_dimension_in,
    uint32_t num_lwes, uint32_t lwe_dimension_out) {

  uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (lwe_id >= num_lwes)
    return;

  uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
  uint32_t lwe_out_idx = lwe_output_indices[lwe_id];

  Torus body_in =
      lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
  Torus body_out;
  if constexpr (std::is_same_v<Torus, KSTorus>) {
    body_out = -body_in;
  } else {
    body_out = closest_repr(
        lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in],
        sizeof(KSTorus) * 8, 1);
    // Powers of two are encoded in the MSBs of the types, so we can scale
    // from one type to the other without having to worry about the moduli
    static_assert(sizeof(Torus) >= sizeof(KSTorus),
                  "Cannot compile keyswitch with given input/output dtypes");
    Torus input_to_output_scaling_factor =
        (sizeof(Torus) - sizeof(KSTorus)) * 8;
    auto rounded_downscaled_body =
        (KSTorus)(body_out >> input_to_output_scaling_factor);
    body_out = -rounded_downscaled_body;
  }

  lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
      (KSTorus)body_out;
}

// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK)).
// This function finishes the KS computation by negating all elements in the
// array using output indices. The array contains -b + sum(a_i x LWE_i) and
// this final step computes b - sum(a_i x LWE_i).
template <typename Torus, typename KSTorus>
__global__ void keyswitch_negate_with_output_indices(
    KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
    uint32_t lwe_size, uint32_t num_lwe) {
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
    return;

  auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;

  KSTorus val = buffer_in[val_idx];
  buffer_in[val_idx] = -val;
}

template <typename Torus, typename KSTorus>
__global__ void keyswitch_zero_output_with_output_indices(
    KSTorus *buffer_in, const Torus *__restrict__ lwe_output_indices,
    uint32_t lwe_size, uint32_t num_lwe) {
  // index of the LWE sample in the LWE ct
  auto lwe_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // index of this LWE ct in the buffer
  auto lwe_idx = blockIdx.y * blockDim.y + threadIdx.y;

  if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_size)
    return;

  auto val_idx = lwe_output_indices[lwe_idx] * lwe_size + lwe_sample_idx;
  buffer_in[val_idx] = 0;
}
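// Worked example of the identity above (a sketch): for one output
// coefficient, the copy-negate kernel seeds the buffer with -b, the
// per-level GEMMs accumulate sum(a_i A_KSK) on top, and the final negation
// yields
//
//   -(-b + sum(a_i A_KSK)) = b - sum(a_i A_KSK),
//
// which is exactly the keyswitched coefficient. Accumulating with += and
// negating once at the end avoids a subtraction in the inner GEMM loop.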
/*
 * keyswitch kernel
 * Each thread handles a piece of the following equation:
 * $$ GLWE_{s2}(\Delta.m + e) = (0,..,0,b) - \sum_{i=0}^{k-1}
 *      \langle Dec(a_i), (GLWE_{s2}(s1_i \cdot q/\beta^j))_{j \in [1,l]} \rangle $$
 * where k is the dimension of the GLWE ciphertext. If the polynomial
 * dimension in the GLWE is > 1, this equation is solved for each polynomial
 * coefficient. Dec denotes the decomposition with base beta and l levels,
 * and the inner product is taken between the decomposition of a_i and the l
 * GLWE encryptions of s1_i q/\beta^j, with j in [1,l]. We obtain a GLWE
 * encryption of Delta.m (with Delta the scaling factor) under key s2 instead
 * of s1, with an increased noise.
 */
// Threads in x each calculate one output coefficient.
// Threads in y parallelize the lwe_dimension_in loop.
// Shared memory is used to store intermediate results of the reduction.
// Note: to reduce register pressure we have slightly changed the algorithm;
// the idea is to compute the negated value of the output. So, instead of
// accumulating subtractions using -=, we accumulate additions using += into
// local_lwe_out. This performs better because it profits from madd ops and
// saves some registers. For this to work, we need to negate the input body
// lwe_array_in[lwe_dimension_in], and negate the output back at the end to
// get the correct result. Additionally, we split the calculation of the ksk
// offset in two parts: a constant part computed before the loop, and a
// variable part computed inside the loop. This also seems to help with
// register pressure.
template <typename Torus, typename KSTorus>
__global__ void
keyswitch(KSTorus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
          const Torus *__restrict__ lwe_array_in,
          const Torus *__restrict__ lwe_input_indexes,
          const KSTorus *__restrict__ ksk, uint32_t lwe_dimension_in,
          uint32_t lwe_dimension_out, uint32_t base_log,
          uint32_t level_count) {
  const int tid = threadIdx.x + blockIdx.y * blockDim.x;
  const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;

  extern __shared__ int8_t sharedmem[];
  Torus *lwe_acc_out = (Torus *)sharedmem;
  auto block_lwe_array_out = get_chunk(
      lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);

  if (tid <= lwe_dimension_out) {
    KSTorus local_lwe_out = 0;
    auto block_lwe_array_in = get_chunk(
        lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);

    if (tid == lwe_dimension_out && threadIdx.y == 0) {
      if constexpr (std::is_same_v<Torus, KSTorus>) {
        local_lwe_out = -block_lwe_array_in[lwe_dimension_in];
      } else {
        auto new_body = closest_repr(block_lwe_array_in[lwe_dimension_in],
                                     sizeof(KSTorus) * 8, 1);
        // Powers of two are encoded in the MSBs of the types, so we can
        // scale from one type to the other without having to worry about
        // the moduli
        Torus input_to_output_scaling_factor =
            (sizeof(Torus) - sizeof(KSTorus)) * 8;
        auto rounded_downscaled_body =
            (KSTorus)(new_body >> input_to_output_scaling_factor);
        local_lwe_out = -rounded_downscaled_body;
      }
    }
    const Torus mask_mod_b = (1ll << base_log) - 1ll;

    const int pack_size = (lwe_dimension_in + blockDim.y - 1) / blockDim.y;
    const int start_i = pack_size * threadIdx.y;
    const int end_i = SEL(lwe_dimension_in, pack_size * (threadIdx.y + 1),
                          pack_size * (threadIdx.y + 1) <= lwe_dimension_in);

    // This loop distribution seems to benefit the global memory reads
    for (int i = start_i; i < end_i; i++) {
      Torus state =
          init_decomposer_state(block_lwe_array_in[i], base_log, level_count);

      uint32_t offset = i * level_count * (lwe_dimension_out + 1);
      for (int j = 0; j < level_count; j++) {
        KSTorus decomposed = decompose_one(state, mask_mod_b, base_log);
        local_lwe_out +=
            (KSTorus)ksk[tid + j * (lwe_dimension_out + 1) + offset] *
            decomposed;
      }
    }
    lwe_acc_out[shmem_index] = local_lwe_out;
  }

  if (tid <= lwe_dimension_out) {
    for (int offset = blockDim.y / 2; offset > 0 && threadIdx.y < offset;
         offset /= 2) {
      __syncthreads();
      lwe_acc_out[shmem_index] +=
          lwe_acc_out[shmem_index + offset * blockDim.x];
    }
    if (threadIdx.y == 0)
      block_lwe_array_out[tid] = -lwe_acc_out[shmem_index];
  }
}
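// Reduction sketch (illustrative, assuming blockDim.y = 4): each row y of
// the block accumulates a partial sum p_y over its slice of
// [0, lwe_dimension_in); the tree reduction then folds rows in
// log2(blockDim.y) steps before row 0 writes the negated total:
//
//   offset = 2:  p0 += p2,  p1 += p3
//   offset = 1:  p0 += p1
//   row 0:       block_lwe_array_out[tid] = -p0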
template <typename Torus, typename KSTorus>
__host__ void host_keyswitch_lwe_ciphertext_vector(
    cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
    Torus const *lwe_output_indexes, Torus const *lwe_array_in,
    Torus const *lwe_input_indexes, KSTorus const *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples) {

  cuda_set_device(gpu_index);

  constexpr int num_threads_y = 32;
  int num_blocks_per_sample, num_threads_x;

  getNumBlocksAndThreads2D(lwe_dimension_out + 1, 512, num_threads_y,
                           num_blocks_per_sample, num_threads_x);
  int shared_mem = sizeof(Torus) * num_threads_y * num_threads_x;

  PANIC_IF_FALSE(
      num_blocks_per_sample <= 65536,
      "Cuda error (Keyswitch): number of blocks per sample (%d) is too large",
      num_blocks_per_sample);

  // In multiplication of large integers (512, 1024, 2048 bits), the number
  // of samples can be larger than 65536, so we need to set it in the first
  // dimension of the grid
  dim3 grid(num_samples, num_blocks_per_sample, 1);
  dim3 threads(num_threads_x, num_threads_y, 1);

  keyswitch<<<grid, threads, shared_mem, stream>>>(
      lwe_array_out, lwe_output_indexes, lwe_array_in, lwe_input_indexes, ksk,
      lwe_dimension_in, lwe_dimension_out, base_log, level_count);
  check_cuda_error(cudaGetLastError());
}
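// Illustrative usage (a sketch; d_out, d_out_idx, d_in, d_in_idx and d_ksk
// are hypothetical, assumed-valid device allocations for a 64-bit to 64-bit
// keyswitch):
//
//   host_keyswitch_lwe_ciphertext_vector<uint64_t, uint64_t>(
//       stream, gpu_index, d_out, d_out_idx, d_in, d_in_idx, d_ksk,
//       /*lwe_dimension_in=*/2048, /*lwe_dimension_out=*/742,
//       /*base_log=*/3, /*level_count=*/5, /*num_samples=*/16);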
// The GEMM keyswitch is computed as: -(-b + sum(a_i A_KSK))
template <typename Torus, typename KSTorus>
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
    cudaStream_t stream, uint32_t gpu_index, KSTorus *lwe_array_out,
    Torus const *lwe_output_indices, Torus const *lwe_array_in,
    Torus const *lwe_input_indices, KSTorus const *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples, Torus *fp_tmp_buffer,
    bool uses_trivial_indices) {
  cuda_set_device(gpu_index);
  check_cuda_error(cudaGetLastError());

  // fp_tmp_buffer contains 2x the space needed to store the input LWE masks
  // without the body: the first half can be interpreted with a smaller dtype
  // when performing 64->32 KS; the second half, storing the decomposition
  // state, must be interpreted as Torus* (usually 64b)
  KSTorus *d_mem_0 =
      (KSTorus *)fp_tmp_buffer; // keeps decomposed values (as KSTorus)

  // Set the output LWEs to 0, as they accumulate the per-level GEMM results
  if (uses_trivial_indices) {
    cuda_memset_async(lwe_array_out, 0,
                      num_samples * (lwe_dimension_out + 1) * sizeof(KSTorus),
                      stream, gpu_index);
  } else {
    // zero the outputs at the given output indices
    dim3 grid_zero(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
    dim3 threads_zero(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
    keyswitch_zero_output_with_output_indices
        <<<grid_zero, threads_zero, 0, stream>>>(
            lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
            num_samples);
  }
  check_cuda_error(cudaGetLastError());

  dim3 grid_copy(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
  dim3 threads_copy(BLOCK_SIZE_DECOMP);

  // lwe_array_out is num_samples x (lwe_dimension_out + 1). Copy the bodies
  // lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
  // and negate them
  keyswitch_gemm_copy_negated_message_with_indices
      <<<grid_copy, threads_copy, 0, stream>>>(
          lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
          lwe_dimension_in, num_samples, lwe_dimension_out);
  check_cuda_error(cudaGetLastError());

  // Decompose the LWEs, skipping the LWE body: the LWE has lwe_size + 1
  // elements, and the last element (the body) is ignored by rounding down
  // the number of blocks, assuming here that the LWE dimension is a multiple
  // of the block size
  dim3 grid_decomp(CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP));
  dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

  uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
  // The shared memory requirement is 4096, 8192, and 16384 bytes for 32-,
  // 64-, and 128-bit Torus elements respectively.
  // Sanity check: the shared memory size is a constant defined by the
  // algorithm
  GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
             "GEMM kernel error: shared memory required might be too large");

  auto stride_KSK_buffer = (lwe_dimension_out + 1) * level_count;

  // GEMM to keyswitch the individual LWEs
  dim3 grid_gemm(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_GEMM_KS),
                 CEIL_DIV(num_samples, BLOCK_SIZE_GEMM_KS));
  dim3 threads_gemm(BLOCK_SIZE_GEMM_KS * THREADS_GEMM_KS);

  // decompose first level (skips the body in the input buffer)
  decompose_vectorize_init_with_indices<Torus, KSTorus>
      <<<grid_decomp, threads_decomp, 0, stream>>>(
          lwe_array_in, lwe_input_indices, fp_tmp_buffer, lwe_dimension_in,
          num_samples, base_log, level_count);
  check_cuda_error(cudaGetLastError());

  if (uses_trivial_indices) {
    tgemm<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
        num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0, ksk,
        stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
    check_cuda_error(cudaGetLastError());
  } else {
    tgemm_with_indices<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
        num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0, ksk,
        stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1,
        lwe_output_indices);
    check_cuda_error(cudaGetLastError());
  }

  auto ksk_block_size = (lwe_dimension_out + 1);

  for (int li = 1; li < level_count; ++li) {
    decompose_vectorize_step_inplace<Torus, KSTorus>
        <<<grid_decomp, threads_decomp, 0, stream>>>(
            fp_tmp_buffer, lwe_dimension_in, num_samples, base_log,
            level_count);
    check_cuda_error(cudaGetLastError());

    if (uses_trivial_indices) {
      tgemm<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
          num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
          ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
          lwe_dimension_out + 1);
      check_cuda_error(cudaGetLastError());
    } else {
      tgemm_with_indices<<<grid_gemm, threads_gemm, shared_mem_size,
                           stream>>>(
          num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
          ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
          lwe_dimension_out + 1, lwe_output_indices);
      check_cuda_error(cudaGetLastError());
    }
  }

  // grid for the final negation over the output LWEs
  dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
                   CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
  dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

  // Negate all elements in the output LWEs. This is the final step of the
  // GEMM keyswitch computed as: -(-b + sum(a_i A_KSK))
  keyswitch_negate_with_output_indices
      <<<grid_negate, threads_negate, 0, stream>>>(
          lwe_array_out, lwe_output_indices, lwe_dimension_out + 1,
          num_samples);
  check_cuda_error(cudaGetLastError());
}
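// Illustrative usage (a sketch; pointers and parameters are hypothetical,
// with d_tmp sized via scratch_cuda_keyswitch_size<uint64_t>):
//
//   host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t, uint64_t>(
//       stream, gpu_index, d_out, d_out_idx, d_in, d_in_idx, d_ksk,
//       /*lwe_dimension_in=*/2048, /*lwe_dimension_out=*/742,
//       /*base_log=*/3, /*level_count=*/5, /*num_samples=*/256,
//       d_tmp, /*uses_trivial_indices=*/true);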
template <typename Torus, typename KSTorus>
void execute_keyswitch_async(
    CudaStreams streams, const LweArrayVariant<Torus> &lwe_array_out,
    const LweArrayVariant<Torus> &lwe_output_indexes,
    const LweArrayVariant<Torus> &lwe_array_in,
    const LweArrayVariant<Torus> &lwe_input_indexes, KSTorus *const *ksks,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples, bool uses_trivial_indices,
    const std::vector<ks_mem<Torus> *> &fp_tmp_buffer) {

  /// If the number of radix blocks is lower than the number of GPUs, not all
  /// GPUs will be active and there will be 1 input per GPU
  for (uint i = 0; i < streams.count(); i++) {
    int num_samples_on_gpu =
        get_num_inputs_on_gpu(num_samples, i, streams.count());

    Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i);
    Torus *current_lwe_output_indexes =
        get_variant_element(lwe_output_indexes, i);
    Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i);
    Torus *current_lwe_input_indexes =
        get_variant_element(lwe_input_indexes, i);

    if (!fp_tmp_buffer.empty() &&
        num_samples_on_gpu >= get_threshold_ks_gemm()) {
      GPU_ASSERT(fp_tmp_buffer.size() >= streams.count(),
                 "GEMM KS: %ld buffers were not initialized for this number "
                 "of streams, %d",
                 fp_tmp_buffer.size(), streams.count());
      GPU_ASSERT(fp_tmp_buffer[i]->num_lwes >= num_samples_on_gpu,
                 "KS temp buffer not big enough");
      GPU_ASSERT(fp_tmp_buffer[i]->lwe_dimension ==
                     std::max(lwe_dimension_in, lwe_dimension_out),
                 "KS temp buffer was created for a different input LWE size: "
                 "%d vs (in:%d, out:%d)",
                 fp_tmp_buffer[i]->lwe_dimension, lwe_dimension_in,
                 lwe_dimension_out);

      // Compute the keyswitch with the GEMM implementation
      host_gemm_keyswitch_lwe_ciphertext_vector(
          streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
          current_lwe_output_indexes, current_lwe_array_in,
          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
          lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
          fp_tmp_buffer[i]->d_buffer, uses_trivial_indices);
    } else {
      // Compute the keyswitch with the classical implementation
      host_keyswitch_lwe_ciphertext_vector(
          streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
          current_lwe_output_indexes, current_lwe_array_in,
          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
          lwe_dimension_out, base_log, level_count, num_samples_on_gpu);
    }
  }
}

template <typename Torus>
__host__ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe(
    cudaStream_t stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
    uint32_t num_lwes, bool allocate_gpu_memory) {
  cuda_set_device(gpu_index);
  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;

  // allocate at least twice the LWE-mask size: enough to keep both the
  // decomposition state and the decomposed intermediate value
  uint64_t memory_unit = glwe_accumulator_size > lwe_dimension * 2
                             ? glwe_accumulator_size
                             : lwe_dimension * 2;

  uint64_t size_tracker = 0;
  uint64_t buffer_size = 2 * num_lwes * memory_unit * sizeof(Torus);
  *fp_ks_buffer = (int8_t *)cuda_malloc_with_size_tracking_async(
      buffer_size, stream, gpu_index, size_tracker, allocate_gpu_memory);
  return size_tracker;
}
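// Worked sizing example (illustrative parameters): with glwe_dimension = 1,
// polynomial_size = 2048 and lwe_dimension = 742, the accumulator needs
// (1 + 1) * 2048 = 4096 elements, which exceeds 2 * 742 = 1484, so
// memory_unit = 4096. For num_lwes = 16 and Torus = uint64_t the buffer is
// 2 * 16 * 4096 * 8 bytes = 1 MiB.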
// public functional packing keyswitch for a single LWE ciphertext
//
// Assumes there are (glwe_dimension+1) * polynomial_size threads split across
// different thread blocks on the x-axis to work on that input.
template <typename Torus>
__device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
    Torus *glwe_out, Torus const *lwe_in, Torus const *fp_ksk,
    uint32_t lwe_dimension_in, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count) {

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  size_t glwe_size = (glwe_dimension + 1);

  if (tid < glwe_size * polynomial_size) {
    const int local_index = threadIdx.x;
    // the output_glwe is split in polynomials and each x-block takes one of
    // them
    size_t poly_id = blockIdx.x;
    size_t coef_per_block = blockDim.x;

    // number of coefficients inside the fp-ksk block for each lwe_input
    // coefficient
    size_t ksk_block_size = glwe_size * polynomial_size * level_count;

    // initialize the accumulator to 0 (to the body b for the body's first
    // coefficient)
    glwe_out[tid] = SEL(0, lwe_in[lwe_dimension_in],
                        tid == glwe_dimension * polynomial_size);

    // Iterate through all lwe elements
    for (int i = 0; i < lwe_dimension_in; i++) {
      // Round and prepare decomposition
      Torus state = init_decomposer_state(lwe_in[i], base_log, level_count);

      Torus mod_b_mask = (1ll << base_log) - 1ll;

      // block of key for the current lwe coefficient (cur_input_lwe[i])
      auto ksk_block = &fp_ksk[i * ksk_block_size];
      // Iterate through each level and multiply by the ksk piece
      for (int j = 0; j < level_count; j++) {
        auto ksk_glwe = &ksk_block[j * glwe_size * polynomial_size];
        auto ksk_glwe_chunk = &ksk_glwe[poly_id * coef_per_block];
        Torus decomposed = decompose_one(state, mod_b_mask, base_log);
        glwe_out[tid] -= decomposed * ksk_glwe_chunk[local_index];
      }
    }
  }
}

/// To-do: rewrite this kernel for efficiency
template <typename Torus>
__global__ void accumulate_glwes(Torus *glwe_out, Torus *glwe_array_in,
                                 uint32_t glwe_dimension,
                                 uint32_t polynomial_size,
                                 uint32_t num_lwes) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < (glwe_dimension + 1) * polynomial_size) {
    glwe_out[tid] = glwe_array_in[tid];

    // Accumulate
    for (int i = 1; i < num_lwes; i++) {
      auto glwe_in =
          glwe_array_in + i * (glwe_dimension + 1) * polynomial_size;
      glwe_out[tid] += glwe_in[tid];
    }
  }
}

template <typename Torus>
uint64_t scratch_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
                                ks_mem<Torus> **ks_tmp_memory,
                                uint32_t lwe_dimension_in,
                                uint32_t lwe_dimension_out, uint32_t num_lwes,
                                bool allocate_gpu_memory) {
  uint64_t sub_size_tracker = 0;
  uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
      lwe_dimension_in, lwe_dimension_out, num_lwes);
  *ks_tmp_memory = new ks_mem<Torus>;
  (*ks_tmp_memory)->d_buffer = (Torus *)cuda_malloc_with_size_tracking_async(
      buffer_size, stream, gpu_index, sub_size_tracker, allocate_gpu_memory);
  (*ks_tmp_memory)->lwe_dimension =
      std::max(lwe_dimension_in, lwe_dimension_out);
  (*ks_tmp_memory)->num_lwes = num_lwes;

  return sub_size_tracker;
}

template <typename Torus>
void cleanup_cuda_keyswitch(cudaStream_t stream, uint32_t gpu_index,
                            ks_mem<Torus> *ks_tmp_memory,
                            bool allocate_gpu_memory) {
  cuda_drop_with_size_tracking_async(ks_tmp_memory->d_buffer, stream,
                                     gpu_index, allocate_gpu_memory);
  delete ks_tmp_memory;
}

#endif // CNCRT_KS_CUH