mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
- rename l_gadget and stop calling low lat PBS with N too large - rename trlwe and trgsw - rename lwe_mask_size into lwe_dimension - rename lwe_in into lwe_array_in - rename lwe_out into lwe_array_out - rename decomp_level into level - rename lwe_dimension_before/after into lwe_dimension_in/out
572 lines
22 KiB
Plaintext
572 lines
22 KiB
Plaintext
#ifndef WOP_PBS_H
|
|
#define WOP_PBS_H
|
|
|
|
#include "cooperative_groups.h"
|
|
|
|
#include "../include/helper_cuda.h"
|
|
#include "bootstrap.h"
|
|
#include "bootstrap_low_latency.cuh"
|
|
#include "complex/operations.cuh"
|
|
#include "crypto/ggsw.cuh"
|
|
#include "crypto/torus.cuh"
|
|
#include "fft/bnsmfft.cuh"
|
|
#include "fft/smfft.cuh"
|
|
#include "fft/twiddles.cuh"
|
|
#include "keyswitch.cuh"
|
|
#include "polynomial/functions.cuh"
|
|
#include "polynomial/parameters.cuh"
|
|
#include "polynomial/polynomial.cuh"
|
|
#include "polynomial/polynomial_math.cuh"
|
|
#include "utils/memory.cuh"
|
|
#include "utils/timer.cuh"
|
|
|
|
/*
 * Forward transform of a real-coefficient polynomial, executed cooperatively
 * by the threads of one block.
 *
 * - output: complex buffer that receives the transform. It is first filled
 *   with the packed input and then transformed in place.
 * - input: polynomial coefficients of type T.
 *
 * A block-level synchronization is issued on entry, between every stage and
 * on exit, so callers may touch `input`/`output` right before and right after
 * the call.
 */
template <typename T, class params>
__device__ void fft(double2 *output, T *input) {
  synchronize_threads_in_block();

  // Pack the real-valued polynomial into a complex polynomial so that a
  // smaller FFT (HalfDegree<params>) is enough.
  real_to_complex_compressed<T, params>(input, output);
  synchronize_threads_in_block();

  // Move to the Fourier domain.
  NSMFFT_direct<HalfDegree<params>>(output);
  synchronize_threads_in_block();

  // Post-FFT correction step, applied in place.
  correction_direct_fft_inplace<params>(output);
  synchronize_threads_in_block();
}
|
|
|
|
/*
 * Forward transform of a real-coefficient polynomial, executed cooperatively
 * by the threads of one block. Same pipeline as the two-parameter overload;
 * the only difference is that the extra type ST is forwarded to
 * real_to_complex_compressed.
 *
 * - output: complex buffer that receives the transform. It is first filled
 *   with the packed input and then transformed in place.
 * - input: polynomial coefficients of type T.
 *
 * A block-level synchronization is issued on entry, between every stage and
 * on exit.
 */
template <typename T, typename ST, class params>
__device__ void fft(double2 *output, T *input) {
  synchronize_threads_in_block();

  // Pack the real-valued polynomial into a complex polynomial so that a
  // smaller FFT (HalfDegree<params>) is enough.
  real_to_complex_compressed<T, ST, params>(input, output);
  synchronize_threads_in_block();

  // Move to the Fourier domain.
  NSMFFT_direct<HalfDegree<params>>(output);
  synchronize_threads_in_block();

  // Post-FFT correction step, applied in place.
  correction_direct_fft_inplace<params>(output);
  synchronize_threads_in_block();
}
|
|
|
|
/*
 * Inverse transform performed in place on `data`, executed cooperatively by
 * the threads of one block. It applies the inverse correction first and the
 * inverse FFT (over HalfDegree<params>) second, i.e. the stages of fft() in
 * reverse order.
 *
 * A block-level synchronization is issued on entry, between the two stages
 * and on exit.
 */
template <class params>
__device__ void ifft_inplace(double2 *data) {
  synchronize_threads_in_block();

  // Pre-IFFT correction step.
  correction_inverse_fft_inplace<params>(data);
  synchronize_threads_in_block();

  // Leave the Fourier domain.
  NSMFFT_inverse<HalfDegree<params>>(data);
  synchronize_threads_in_block();
}
|
|
|
|
/*
 * Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
 * array, and an array of GGSW ciphertexts with an index to one ciphertext in
 * it. Computes a CMUX with these operands,
 *
 *     out = m0 + ggsw * (m1 - m0)
 *
 * (m0 = input_idx1, m1 = input_idx2, product = external product accumulated
 * in the Fourier domain), and writes the result at output_idx of
 * glwe_array_out.
 *
 * Launch requirements: the whole block cooperates on one CMUX and every
 * thread handles params::opt coefficients spaced by
 * params::degree / params::opt, so the block must have
 * params::degree / params::opt threads.
 *
 * - glwe_array_out: An array where the result should be written to.
 * - glwe_array_in: An array where the GLWE inputs are stored.
 * - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
 * - selected_memory: Scratch area for the accumulators. Can be in the shared
 * memory or global memory. It is carved, in order, into 2 Torus polynomials,
 * 2 int16_t polynomials and 3 double2 buffers of polynomial_size / 2 entries.
 * - output_idx: The index of the output where the glwe ciphertext should be
 * written.
 * - input_idx1: The index of the first glwe ciphertext we will use.
 * - input_idx2: The index of the second glwe ciphertext we will use.
 * - glwe_dim: This is k. NOTE(review): the body polynomial is addressed as
 * mask + polynomial_size, which only matches k == 1 - confirm with callers.
 * - polynomial_size: size of the polynomials. This is N.
 * - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
 * - level_count: number of decomposition levels in the gadget matrix (~4)
 * - ggsw_idx: The index of the GGSW we will use.
 */
template <typename Torus, typename STorus, class params>
__device__ void
cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
     char *selected_memory, uint32_t output_idx, uint32_t input_idx1,
     uint32_t input_idx2, uint32_t glwe_dim, uint32_t polynomial_size,
     uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {

  // Carve the scratch area: difference polynomials first ...
  Torus *glwe_sub_mask = (Torus *)selected_memory;
  Torus *glwe_sub_body = glwe_sub_mask + polynomial_size;

  // ... then one decomposed level of each of them ...
  int16_t *glwe_mask_decomposed = (int16_t *)(glwe_sub_body + polynomial_size);
  int16_t *glwe_body_decomposed = glwe_mask_decomposed + polynomial_size;

  // ... then the Fourier-domain accumulators and the FFT work buffer.
  double2 *mask_res_fft = (double2 *)(glwe_body_decomposed + polynomial_size);
  double2 *body_res_fft = mask_res_fft + polynomial_size / 2;
  double2 *glwe_fft = body_res_fft + polynomial_size / 2;

  GadgetMatrix<Torus, params> gadget(base_log, level_count);

  /////////////////////////////////////

  // Number of Torus elements in one GLWE ciphertext. Offsets are computed in
  // size_t so that large input arrays do not wrap 32-bit arithmetic.
  const size_t glwe_size = (size_t)(glwe_dim + 1) * polynomial_size;

  // m0 and m1 are read directly from glwe_array_in (no copy is made)
  Torus *m0_mask = glwe_array_in + (size_t)input_idx1 * glwe_size;
  Torus *m0_body = m0_mask + polynomial_size;

  Torus *m1_mask = glwe_array_in + (size_t)input_idx2 * glwe_size;
  Torus *m1_body = m1_mask + polynomial_size;

  // glwe_sub = m1 - m0
  // Mask
  sub_polynomial<Torus, params>(glwe_sub_mask, m1_mask, m0_mask);
  // Body
  sub_polynomial<Torus, params>(glwe_sub_body, m1_body, m0_body);

  synchronize_threads_in_block();

  // Initialize the polynomial multiplication via FFT arrays
  // The polynomial multiplications happens at the block level
  // and each thread handles two or more coefficients
  int pos = threadIdx.x;
  for (int j = 0; j < params::opt / 2; j++) {
    mask_res_fft[pos].x = 0;
    mask_res_fft[pos].y = 0;
    body_res_fft[pos].x = 0;
    body_res_fft[pos].y = 0;
    pos += params::degree / params::opt;
  }

  // Decompose the difference polynomials level by level and multiply each
  // decomposed level with the corresponding part of the GGSW
  for (int level = 0; level < level_count; level++) {

    // Decomposition
    gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
    gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);

    // First, perform the polynomial multiplication for the mask
    synchronize_threads_in_block();
    fft<int16_t, params>(glwe_fft, glwe_mask_decomposed);

    // External product and accumulate
    // Get the piece necessary for the multiplication
    auto mask_fourier = get_ith_mask_kth_block(
        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
    auto body_fourier = get_ith_body_kth_block(
        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);

    synchronize_threads_in_block();

    // Perform the coefficient-wise product
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        mask_res_fft, glwe_fft, mask_fourier);
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        body_res_fft, glwe_fft, body_fourier);

    // Now handle the polynomial multiplication for the body
    // in the same way
    synchronize_threads_in_block();
    fft<int16_t, params>(glwe_fft, glwe_body_decomposed);

    // External product and accumulate
    // Get the piece necessary for the multiplication
    mask_fourier = get_ith_mask_kth_block(
        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
    body_fourier = get_ith_body_kth_block(
        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);

    synchronize_threads_in_block();

    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        mask_res_fft, glwe_fft, mask_fourier);
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        body_res_fft, glwe_fft, body_fourier);
  }

  // IFFT
  synchronize_threads_in_block();
  ifft_inplace<params>(mask_res_fft);
  ifft_inplace<params>(body_res_fft);
  synchronize_threads_in_block();

  // Write the output: start from m0 ...
  Torus *mb_mask = glwe_array_out + (size_t)output_idx * glwe_size;
  Torus *mb_body = mb_mask + polynomial_size;

  int tid = threadIdx.x;
  for (int i = 0; i < params::opt; i++) {
    mb_mask[tid] = m0_mask[tid];
    mb_body[tid] = m0_body[tid];
    tid += params::degree / params::opt;
  }

  // ... and add the external product result on top of it
  add_to_torus<Torus, params>(mask_res_fft, mb_mask);
  add_to_torus<Torus, params>(body_res_fft, mb_body);
}
|
|
|
|
/**
 * Kernel running a batch of CMUXes that all share one GGSW ciphertext.
 * Thread block b reads the input GLWE ciphertexts 2b and 2b + 1 and writes
 * its result at index b of the output, so the grid size is the number of
 * CMUXes to compute. The block size must be the one expected by cmux().
 *
 * - glwe_array_out: An array where the result should be written to.
 * - glwe_array_in: An array where the GLWE inputs are stored.
 * - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
 * - device_mem: A pointer to global memory used for the accumulators when the
 * kernel is not instantiated with SMD == FULLSM. Block b uses the slice
 * starting at b * device_memory_size_per_block.
 * - device_memory_size_per_block: Memory size needed to store all accumulators
 * for a single block.
 * - glwe_dim: This is k.
 * - polynomial_size: size of the polynomials. This is N.
 * - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
 * - level_count: number of decomposition levels in the gadget matrix (~4)
 * - ggsw_idx: The index of the GGSW we will use.
 *
 * With SMD == FULLSM the accumulators live in the dynamic shared memory
 * passed at launch and device_mem is not read.
 */
template <typename Torus, typename STorus, class params, sharedMemDegree SMD>
__global__ void
device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
                  char *device_mem, size_t device_memory_size_per_block,
                  uint32_t glwe_dim, uint32_t polynomial_size,
                  uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {

  // One CMUX per block; inputs are consumed pairwise
  const int block_id = blockIdx.x;
  const int first_input = 2 * block_id;
  const int second_input = first_input + 1;

  // Pick where the intermediate results are kept
  extern __shared__ char sharedmem[];
  char *scratch;

  if constexpr (SMD == FULLSM) {
    scratch = sharedmem;
  } else {
    scratch = device_mem + blockIdx.x * device_memory_size_per_block;
  }

  cmux<Torus, STorus, params>(glwe_array_out, glwe_array_in, ggsw_in, scratch,
                              block_id, first_input, second_input, glwe_dim,
                              polynomial_size, base_log, level_count,
                              ggsw_idx);
}
|
|
/*
 * Host function that executes the CMUX tree used by the hybrid packing of the
 * WoPBS. Layer `i` of the tree halves the number of GLWE ciphertexts using the
 * i-th GGSW ciphertext; after `r` layers a single GLWE remains and is copied
 * to glwe_array_out.
 *
 * Intermediate results are kept in shared memory when one block's scratch
 * area fits in max_shared_memory, and in a temporary global buffer otherwise.
 * The function synchronizes the stream before returning, so the output is
 * ready when it returns.
 *
 * - v_stream: Pointer to the cudaStream_t that should be used.
 * - glwe_array_out: A device array for the output GLWE ciphertext.
 * - ggsw_in: A device array for the GGSW ciphertexts used in each layer
 * (standard domain; converted to the Fourier domain here).
 * - lut_vector: A device array for the 2^r GLWE ciphertexts used in the first
 * layer.
 * - glwe_dimension: This is k.
 * - polynomial_size: size of the polynomials. This is N.
 * - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
 * - level_count: number of decomposition levels in the gadget matrix (~4)
 * - r: Number of layers in the tree. With r == 0 there is nothing to select
 * and the single input GLWE is copied to the output.
 * - max_shared_memory: Shared memory (in bytes) a block is allowed to use.
 */
template <typename Torus, typename STorus, class params>
void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
                    Torus *lut_vector, uint32_t glwe_dimension,
                    uint32_t polynomial_size, uint32_t base_log,
                    uint32_t level_count, uint32_t r,
                    uint32_t max_shared_memory) {

  auto stream = static_cast<cudaStream_t *>(v_stream);

  cuda_initialize_twiddles(polynomial_size, 0);

  // Number of Torus elements / bytes of one GLWE ciphertext
  const size_t glwe_size = (size_t)(glwe_dimension + 1) * polynomial_size;
  const size_t glwe_bytes = glwe_size * sizeof(Torus);

  // Degenerate tree: no layer to run, the only LUT is the result
  if (r == 0) {
    checkCudaErrors(cudaMemcpyAsync(glwe_array_out, lut_vector, glwe_bytes,
                                    cudaMemcpyDeviceToDevice, *stream));
    checkCudaErrors(cudaStreamSynchronize(*stream));
    return;
  }

  const size_t num_lut = (size_t)1 << r;

  // Scratch area of one block, must match the layout carved in cmux()
  const size_t memory_needed_per_block =
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
      sizeof(int16_t) * polynomial_size +     // glwe_mask_decomposed
      sizeof(int16_t) * polynomial_size +     // glwe_body_decomposed
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft

  // Fall back to global memory when the scratch area does not fit
  const bool use_global_memory = max_shared_memory < memory_needed_per_block;

  dim3 thds(polynomial_size / params::opt, 1, 1);

  //////////////////////
  // Convert the GGSW ciphertexts to the Fourier domain.
  // ggsw_size counts doubles: a polynomial of N coefficients becomes
  // N / 2 double2 values, i.e. N doubles.
  double2 *d_ggsw_fft_in;
  const size_t ggsw_size = (size_t)r * polynomial_size * (glwe_dimension + 1) *
                           (glwe_dimension + 1) * level_count;

#if (CUDART_VERSION < 11020)
  checkCudaErrors(
      cudaMalloc((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double)));
#else
  checkCudaErrors(cudaMallocAsync((void **)&d_ggsw_fft_in,
                                  ggsw_size * sizeof(double), *stream));
#endif

  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
                                               r, glwe_dimension,
                                               polynomial_size, level_count);

  //////////////////////

  // Allocate global memory in case parameters are too large. The first layer
  // is the widest one and runs 2^(r-1) blocks.
  char *d_mem = nullptr;
  if (use_global_memory) {
    const size_t d_mem_bytes = memory_needed_per_block * ((size_t)1 << (r - 1));
#if (CUDART_VERSION < 11020)
    checkCudaErrors(cudaMalloc((void **)&d_mem, d_mem_bytes));
#else
    checkCudaErrors(cudaMallocAsync((void **)&d_mem, d_mem_bytes, *stream));
#endif
  } else {
    checkCudaErrors(cudaFuncSetAttribute(
        device_batch_cmux<Torus, STorus, params, FULLSM>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        (int)memory_needed_per_block));
    checkCudaErrors(
        cudaFuncSetCacheConfig(device_batch_cmux<Torus, STorus, params, FULLSM>,
                               cudaFuncCachePreferShared));
  }

  // Allocate the two ping-pong buffers and load the LUTs in the first one
  Torus *d_buffer1, *d_buffer2;

#if (CUDART_VERSION < 11020)
  checkCudaErrors(cudaMalloc((void **)&d_buffer1, num_lut * glwe_bytes));
  checkCudaErrors(cudaMalloc((void **)&d_buffer2, num_lut * glwe_bytes));
#else
  checkCudaErrors(
      cudaMallocAsync((void **)&d_buffer1, num_lut * glwe_bytes, *stream));
  checkCudaErrors(
      cudaMallocAsync((void **)&d_buffer2, num_lut * glwe_bytes, *stream));
#endif
  checkCudaErrors(cudaMemcpyAsync(d_buffer1, lut_vector, num_lut * glwe_bytes,
                                  cudaMemcpyDeviceToDevice, *stream));

  Torus *output = d_buffer1;
  // Run the cmux tree
  for (uint32_t layer_idx = 0; layer_idx < r; layer_idx++) {
    output = (layer_idx % 2 ? d_buffer1 : d_buffer2);
    Torus *input = (layer_idx % 2 ? d_buffer2 : d_buffer1);

    int num_cmuxes = (1 << (r - 1 - layer_idx));
    dim3 grid(num_cmuxes, 1, 1);

    // walks horizontally through the leafs
    if (use_global_memory)
      // The scratch area lives in d_mem: no dynamic shared memory must be
      // requested, otherwise the launch exceeds the device limit and fails.
      device_batch_cmux<Torus, STorus, params, NOSM>
          <<<grid, thds, 0, *stream>>>(
              output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
              glwe_dimension, // k
              polynomial_size, base_log, level_count,
              layer_idx // r
          );
    else
      device_batch_cmux<Torus, STorus, params, FULLSM>
          <<<grid, thds, memory_needed_per_block, *stream>>>(
              output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
              glwe_dimension, // k
              polynomial_size, base_log, level_count,
              layer_idx // r
          );
    // Launch-configuration errors are only reported through the last error
    checkCudaErrors(cudaGetLastError());
  }

  checkCudaErrors(cudaMemcpyAsync(glwe_array_out, output, glwe_bytes,
                                  cudaMemcpyDeviceToDevice, *stream));

  // We only need synchronization to assert that data is in glwe_array_out
  // before returning. Memory release can be added to the stream and processed
  // later.
  checkCudaErrors(cudaStreamSynchronize(*stream));

  // Free memory
#if (CUDART_VERSION < 11020)
  checkCudaErrors(cudaFree(d_ggsw_fft_in));
  checkCudaErrors(cudaFree(d_buffer1));
  checkCudaErrors(cudaFree(d_buffer2));
  if (use_global_memory)
    checkCudaErrors(cudaFree(d_mem));
#else
  checkCudaErrors(cudaFreeAsync(d_ggsw_fft_in, *stream));
  checkCudaErrors(cudaFreeAsync(d_buffer1, *stream));
  checkCudaErrors(cudaFreeAsync(d_buffer2, *stream));
  if (use_global_memory)
    checkCudaErrors(cudaFreeAsync(d_mem, *stream));
#endif
}
|
|
|
|
// Only works for the big LWE of the ks+bs case: each ciphertext holds
// params::degree mask elements followed by one body element.
//
// For the ciphertext selected by blockIdx.x:
//   dst_copy  receives an unmodified copy of src (state LWE buffer)
//   dst_shift receives src multiplied element-wise by `value`
//
// Expects params::degree / params::opt threads per block; every thread
// handles params::opt mask elements and the last thread also handles the body.
template <typename Torus, class params>
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
                                   Torus *src, Torus value) {
  const int ciphertext_id = blockIdx.x;
  const int offset = ciphertext_id * (params::degree + 1);

  Torus *copy_out = dst_copy + offset;
  Torus *shift_out = dst_shift + offset;
  Torus *in = src + offset;

  // Mask elements
  int idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    copy_out[idx] = in[idx];
    shift_out[idx] = in[idx] * value;
    idx += params::degree / params::opt;
  }

  // Body element, handled by the last thread of the block
  if (threadIdx.x == params::degree / params::opt - 1) {
    copy_out[params::degree] = in[params::degree];
    shift_out[params::degree] = in[params::degree] * value;
  }
}
|
|
|
|
// Only works for the small LWE of the ks+bs case.
// Copies one LWE ciphertext whose length is not necessarily a power of two.
//
// Block b reads the ciphertext starting at src[b * small_lwe_size] and writes
// it to slot `lwe_id` of the b-th output list, each list holding
// `number_of_bits` ciphertexts of `small_lwe_size` elements.
//
// The copy is a block-stride loop, so any number of threads per block is
// valid (the block size does not have to be a power of two).
template <typename Torus>
__global__ void copy_small_lwe(Torus *dst, Torus *src, uint32_t small_lwe_size,
                               uint32_t number_of_bits, uint32_t lwe_id) {

  const size_t block_id = blockIdx.x;

  // Offsets are computed in size_t to avoid 32-bit wrap-around
  Torus *cur_dst =
      dst + (block_id * number_of_bits + lwe_id) * small_lwe_size;
  const Torus *cur_src = src + block_id * small_lwe_size;

  for (size_t i = threadIdx.x; i < small_lwe_size; i += blockDim.x)
    cur_dst[i] = cur_src[i];
}
|
|
|
|
// Adds `value` to the body (element at index lwe_dimension) of the LWE
// ciphertext selected by blockIdx.x, ciphertexts being lwe_dimension + 1
// elements apart.
//
// Only used in extract bits for one ciphertext; should be called with one
// block and one thread (there is no thread guard, so every thread of a block
// would update the same element).
// NOTE: check if putting this functionality in copy_small_lwe or
// fill_pbs_lut vector is faster
template <typename Torus>
__global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
  const size_t lwe_size = lwe_dimension + 1;
  Torus *body = lwe + blockIdx.x * lwe_size + lwe_dimension;
  *body += value;
}
|
|
|
|
// Fills the body polynomial of the LUT for the current bit (equivalent to a
// trivial encryption as the mask is 0s; the mask itself is not written here).
// The body is the polynomial starting params::degree elements after `lut`.
// Every coefficient is set to `value`; the caller passes -alpha where
// alpha = delta*2^{bit_idx-1}.
//
// Expects a single block of params::degree / params::opt threads, each thread
// writing params::opt coefficients.
template <typename Torus, class params>
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
  Torus *body = lut + params::degree;

  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    body[idx] = value;
    idx += params::degree / params::opt;
  }
}
|
|
|
|
// Updates the state LWE and the shifted LWE after one bit was extracted.
//
// Adding alpha (alpha = delta*2^{bit_idx-1}, passed as add_value) to the PBS
// output gives an encryption of 0 if the extracted bit was 0 and of the bit
// value otherwise. That quantity is subtracted from the state LWE so that the
// extracted bit location becomes 0:
//   mask: state -= pbs_out
//   body: state -= pbs_out + add_value
//
// The shifted LWE is then the new state multiplied by mul_value, which moves
// the next bit onto the padding bit for the next iteration; that's why
// mul_value = 1ll << (ciphertext_n_bits - delta_log - bit_idx - 2) is used
// instead of 1ll << (ciphertext_n_bits - delta_log - bit_idx - 1).
//
// blockIdx.x selects the ciphertext (params::degree + 1 elements each).
// Expects params::degree / params::opt threads per block; the last thread also
// handles the body.
template <typename Torus, class params>
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
                                    Torus *pbs_lwe_array_out, Torus add_value,
                                    Torus mul_value) {
  const size_t offset = (size_t)blockIdx.x * (params::degree + 1);
  Torus *shifted = shifted_lwe + offset;
  Torus *state = state_lwe + offset;
  Torus *pbs_out = pbs_lwe_array_out + offset;

  // Mask elements
  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    state[idx] -= pbs_out[idx];
    shifted[idx] = state[idx];
    shifted[idx] *= mul_value;
    idx += params::degree / params::opt;
  }

  // Body element, handled by the last thread of the block
  if (threadIdx.x == params::degree / params::opt - 1) {
    state[params::degree] -= (pbs_out[params::degree] + add_value);
    shifted[params::degree] = state[params::degree];
    shifted[params::degree] *= mul_value;
  }
}
|
|
|
|
/*
 * Host driver of the bit extraction: produces one small LWE ciphertext per
 * extracted bit of the input big LWE ciphertext. All work is enqueued on the
 * given stream; the function does not synchronize it.
 *
 * Steps, as enqueued:
 *  1. Copy the input to lwe_array_in_buffer (state) and store the input
 *     multiplied by 2^(ciphertext_n_bits - delta_log - 1) in
 *     lwe_array_in_shifted_buffer.
 *  2. For each bit_idx in [0, number_of_bits):
 *     a. keyswitch the shifted buffer into lwe_array_out_ks_buffer;
 *     b. copy the result to slot (number_of_bits - bit_idx - 1) of
 *        list_lwe_array_out; stop here for the last bit;
 *     c. add 2^(ciphertext_n_bits - 2) to the body of the keyswitched LWE;
 *     d. fill the LUT body with -alpha, alpha = 2^(delta_log - 1 + bit_idx);
 *     e. run the low latency PBS into lwe_array_out_pbs_buffer;
 *     f. update the state and shifted buffers (see add_sub_and_mul_lwe).
 *
 * - v_stream: Pointer to the cudaStream_t that should be used.
 * - list_lwe_array_out: output list of number_of_bits small LWE ciphertexts.
 * - lwe_array_in: input big LWE ciphertext.
 * - lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_out_ks_buffer,
 *   lwe_array_out_pbs_buffer, lut_pbs: device work buffers provided by the
 *   caller.
 * - lut_vector_indexes, fourier_bsk, ksk: forwarded to the PBS / keyswitch.
 * - lwe_dimension_in / lwe_dimension_out: dimensions of the big / small LWE.
 *
 * NOTE(review): the helper kernels are launched with one block and the
 * keyswitch is called for one ciphertext, while number_of_samples is only
 * forwarded to the PBS - presumably only number_of_samples == 1 is supported;
 * confirm with callers.
 * Requires delta_log >= 1 and delta_log + number_of_bits <= ciphertext_n_bits
 * for the shift amounts below to be valid.
 */
template <typename Torus, class params>
__host__ void host_extract_bits(
    void *v_stream, Torus *list_lwe_array_out, Torus *lwe_array_in,
    Torus *lwe_array_in_buffer, Torus *lwe_array_in_shifted_buffer,
    Torus *lwe_array_out_ks_buffer, Torus *lwe_array_out_pbs_buffer,
    Torus *lut_pbs, uint32_t *lut_vector_indexes, Torus *ksk,
    double2 *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
    uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_ksk,
    uint32_t level_count_ksk, uint32_t number_of_samples) {

  auto stream = static_cast<cudaStream_t *>(v_stream);
  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;

  int blocks = 1;
  int threads = params::degree / params::opt;

  copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_in,
      1ll << (ciphertext_n_bits - delta_log - 1));
  checkCudaErrors(cudaGetLastError());

  for (uint32_t bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
    cuda_keyswitch_lwe_ciphertext_vector(
        v_stream, lwe_array_out_ks_buffer, lwe_array_in_shifted_buffer, ksk,
        lwe_dimension_in, lwe_dimension_out, base_log_ksk, level_count_ksk, 1);

    copy_small_lwe<<<1, 256, 0, *stream>>>(
        list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
        number_of_bits, number_of_bits - bit_idx - 1);
    checkCudaErrors(cudaGetLastError());

    // Nothing left to remove from the state after the last bit
    if (bit_idx == number_of_bits - 1) {
      break;
    }

    add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_array_out_ks_buffer,
                                             lwe_dimension_out,
                                             1ll << (ciphertext_n_bits - 2));
    checkCudaErrors(cudaGetLastError());

    // alpha = 2^(delta_log - 1 + bit_idx). The LUT holds -alpha; the negation
    // is written explicitly: `0 - 1 << n` would parse as `(0 - 1) << n`.
    const Torus alpha = (Torus)(1ull << (delta_log - 1 + bit_idx));
    const Torus minus_alpha = (Torus)0 - alpha;

    fill_lut_body_for_current_bit<Torus, params>
        <<<blocks, threads, 0, *stream>>>(lut_pbs, minus_alpha);
    checkCudaErrors(cudaGetLastError());

    host_bootstrap_low_latency<Torus, params>(
        v_stream, lwe_array_out_pbs_buffer, lut_pbs, lut_vector_indexes,
        lwe_array_out_ks_buffer, fourier_bsk, lwe_dimension_out,
        lwe_dimension_in, base_log_bsk, level_count_bsk, number_of_samples, 1);

    add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
        lwe_array_in_shifted_buffer, lwe_array_in_buffer,
        lwe_array_out_pbs_buffer, alpha,
        1ll << (ciphertext_n_bits - delta_log - bit_idx - 2));
    checkCudaErrors(cudaGetLastError());
  }
}
|
|
|
|
#endif // WOP_PBS_H
|