mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
960 lines
39 KiB
Plaintext
960 lines
39 KiB
Plaintext
#ifndef WOP_PBS_H
|
|
#define WOP_PBS_H
|
|
|
|
#include "cooperative_groups.h"
|
|
|
|
#include "../include/helper_cuda.h"
|
|
#include "bootstrap.h"
|
|
#include "bootstrap_amortized.cuh"
|
|
#include "bootstrap_low_latency.cuh"
|
|
#include "complex/operations.cuh"
|
|
#include "crypto/ggsw.cuh"
|
|
#include "crypto/torus.cuh"
|
|
#include "device.h"
|
|
#include "fft/bnsmfft.cuh"
|
|
#include "fft/smfft.cuh"
|
|
#include "fft/twiddles.cuh"
|
|
#include "keyswitch.cuh"
|
|
#include "polynomial/functions.cuh"
|
|
#include "polynomial/parameters.cuh"
|
|
#include "polynomial/polynomial.cuh"
|
|
#include "polynomial/polynomial_math.cuh"
|
|
#include "utils/kernel_dimensions.cuh"
|
|
#include "utils/memory.cuh"
|
|
#include "utils/timer.cuh"
|
|
|
|
// In-place forward negacyclic FFT over the compressed (half-degree) complex
// representation stored in `output`.
//
// Must be called by all params::degree / params::opt threads of the block:
// the barriers below are block-wide.
template <class params> __device__ void fft(double2 *output) {
  // Switch to the FFT space
  NSMFFT_direct<HalfDegree<params>>(output);
  synchronize_threads_in_block();

  // Post-FFT correction step required by the half-degree negacyclic
  // transform before the coefficients can be used for products.
  correction_direct_fft_inplace<params>(output);
  synchronize_threads_in_block();
}
|
|
|
|
// In-place inverse of fft<params>() above: undoes the correction, then runs
// the inverse half-degree negacyclic FFT on `data`.
//
// Must be called by all params::degree / params::opt threads of the block:
// the barriers below are block-wide.
template <class params> __device__ void ifft_inplace(double2 *data) {
  // Make sure every producer of `data` is done before touching it.
  synchronize_threads_in_block();

  correction_inverse_fft_inplace<params>(data);
  synchronize_threads_in_block();

  NSMFFT_inverse<HalfDegree<params>>(data);
  synchronize_threads_in_block();
}
|
|
|
|
/*
|
|
* Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
|
|
* array, and an array of GGSW ciphertexts with a index to one ciphertext in it.
|
|
* Compute a CMUX with these operands and writes the output to a particular
|
|
* index of glwe_array_out.
|
|
*
|
|
* This function needs polynomial_size threads per block.
|
|
*
|
|
* - glwe_array_out: An array where the result should be written to.
|
|
* - glwe_array_in: An array where the GLWE inputs are stored.
|
|
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
|
|
* - selected_memory: An array to be used for the accumulators. Can be in the
|
|
* shared memory or global memory.
|
|
* - output_idx: The index of the output where the glwe ciphertext should be
|
|
* written.
|
|
* - input_idx1: The index of the first glwe ciphertext we will use.
|
|
* - input_idx2: The index of the second glwe ciphertext we will use.
|
|
* - glwe_dim: This is k.
|
|
* - polynomial_size: size of the polynomials. This is N.
|
|
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
|
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
|
* - ggsw_idx: The index of the GGSW we will use.
|
|
*/
|
|
template <typename Torus, typename STorus, class params>
__device__ void
cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
     char *selected_memory, uint32_t output_idx, uint32_t input_idx1,
     uint32_t input_idx2, uint32_t glwe_dim, uint32_t polynomial_size,
     uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {

  // Carve the scratch buffer (`selected_memory`, shared or global) into the
  // accumulators used below. Layout, in order:
  //   glwe_sub_mask : polynomial_size Torus
  //   glwe_sub_body : polynomial_size Torus
  //   mask_res_fft  : polynomial_size / 2 double2
  //   body_res_fft  : polynomial_size / 2 double2
  //   glwe_fft      : polynomial_size / 2 double2
  // This must match memory_needed_per_block in the host wrappers.
  Torus *glwe_sub_mask = (Torus *)selected_memory;
  Torus *glwe_sub_body = (Torus *)glwe_sub_mask + (ptrdiff_t)polynomial_size;

  // polynomial_size Torus elements expressed in double2 units; assumes
  // sizeof(double2) is a multiple of sizeof(Torus).
  double2 *mask_res_fft = (double2 *)glwe_sub_body +
                          polynomial_size / (sizeof(double2) / sizeof(Torus));
  double2 *body_res_fft =
      (double2 *)mask_res_fft + (ptrdiff_t)polynomial_size / 2;

  double2 *glwe_fft =
      (double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);

  /////////////////////////////////////

  // glwe2-glwe1

  // Copy m0 to shared memory to preserve data
  auto m0_mask = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
  auto m0_body = m0_mask + polynomial_size;

  // Just gets the pointer for m1 on global memory
  auto m1_mask = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
  auto m1_body = m1_mask + polynomial_size;

  // glwe_sub = m1 - m0, computed separately on mask and body.
  // Mask
  sub_polynomial<Torus, params>(glwe_sub_mask, m1_mask, m0_mask);
  // Body
  sub_polynomial<Torus, params>(glwe_sub_body, m1_body, m0_body);

  synchronize_threads_in_block();

  // Initialize the polynomial multiplication via FFT arrays
  // The polynomial multiplications happens at the block level
  // and each thread handles two or more coefficients
  int pos = threadIdx.x;
  for (int j = 0; j < params::opt / 2; j++) {
    mask_res_fft[pos].x = 0;
    mask_res_fft[pos].y = 0;
    body_res_fft[pos].x = 0;
    body_res_fft[pos].y = 0;
    pos += params::degree / params::opt;
  }

  // Gadget decomposers over the subtraction result, one per polynomial.
  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count, glwe_sub_mask);
  GadgetMatrix<Torus, params> gadget_body(base_log, level_count, glwe_sub_body);
  // Subtract each glwe operand, decompose the resulting
  // polynomial coefficients to multiply each decomposed level
  // with the corresponding part of the LUT
  for (int level = level_count - 1; level >= 0; level--) {

    // Decomposition
    gadget_mask.decompose_and_compress_next(glwe_fft);

    // First, perform the polynomial multiplication for the mask
    fft<params>(glwe_fft);

    // External product and accumulate
    // Get the piece necessary for the multiplication
    auto mask_fourier = get_ith_mask_kth_block(
        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
    auto body_fourier = get_ith_body_kth_block(
        ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);

    synchronize_threads_in_block();

    // Perform the coefficient-wise product
    synchronize_threads_in_block();
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        mask_res_fft, glwe_fft, mask_fourier);
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        body_res_fft, glwe_fft, body_fourier);

    // Now handle the polynomial multiplication for the body
    // in the same way
    synchronize_threads_in_block();

    gadget_body.decompose_and_compress_next(glwe_fft);
    fft<params>(glwe_fft);

    // External product and accumulate
    // Get the piece necessary for the multiplication
    mask_fourier = get_ith_mask_kth_block(
        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
    body_fourier = get_ith_body_kth_block(
        ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);

    synchronize_threads_in_block();

    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        mask_res_fft, glwe_fft, mask_fourier);
    polynomial_product_accumulate_in_fourier_domain<params, double2>(
        body_res_fft, glwe_fft, body_fourier);
  }

  // IFFT: bring both accumulated products back to the coefficient domain.
  synchronize_threads_in_block();
  ifft_inplace<params>(mask_res_fft);
  ifft_inplace<params>(body_res_fft);
  synchronize_threads_in_block();

  // Write the output: start from a copy of m0, then add the rounded
  // external-product result, yielding m0 + GGSW * (m1 - m0).
  Torus *mb_mask =
      &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
  Torus *mb_body = mb_mask + polynomial_size;

  int tid = threadIdx.x;
  for (int i = 0; i < params::opt; i++) {
    mb_mask[tid] = m0_mask[tid];
    mb_body[tid] = m0_body[tid];
    tid += params::degree / params::opt;
  }

  add_to_torus<Torus, params>(mask_res_fft, mb_mask);
  add_to_torus<Torus, params>(body_res_fft, mb_body);
}
|
|
|
|
// Expands `lut_in` (num_lut polynomials of params::degree coefficients) into
// `lut_out` as trivial GLWE ciphertexts: the whole destination is zeroed,
// then each LUT polynomial is copied into the body slot at index 2*i + 1.
// All work is enqueued asynchronously on `stream`; the caller synchronizes.
template <typename Torus, class params>
__host__ void add_padding_to_lut_async(Torus *lut_out, Torus *lut_in,
                                       uint32_t glwe_dimension,
                                       uint32_t num_lut, cudaStream_t *stream) {
  size_t poly_bytes = params::degree * sizeof(Torus);
  size_t out_bytes = num_lut * (glwe_dimension + 1) * params::degree *
                     sizeof(Torus);

  // Zero everything first: the mask parts stay 0 (trivial encryption).
  checkCudaErrors(cudaMemsetAsync(lut_out, 0, out_bytes, *stream));

  for (uint32_t i = 0; i < num_lut; i++) {
    Torus *body_dst = lut_out + (2 * i + 1) * params::degree;
    Torus *poly_src = lut_in + i * params::degree;
    checkCudaErrors(cudaMemcpyAsync(body_dst, poly_src, poly_bytes,
                                    cudaMemcpyDeviceToDevice, *stream));
  }
}
|
|
|
|
/**
|
|
* Computes several CMUXes using an array of GLWE ciphertexts and a single GGSW
|
|
* ciphertext. The GLWE ciphertexts are picked two-by-two in sequence. Each
|
|
* thread block computes a single CMUX.
|
|
*
|
|
* - glwe_array_out: An array where the result should be written to.
|
|
* - glwe_array_in: An array where the GLWE inputs are stored.
|
|
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
|
|
* - device_mem: An pointer for the global memory in case the shared memory is
|
|
* not big enough to store the accumulators.
|
|
* - device_memory_size_per_block: Memory size needed to store all accumulators
|
|
* for a single block.
|
|
* - glwe_dim: This is k.
|
|
* - polynomial_size: size of the polynomials. This is N.
|
|
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
|
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
|
* - ggsw_idx: The index of the GGSW we will use.
|
|
*/
|
|
template <typename Torus, typename STorus, class params, sharedMemDegree SMD>
__global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in,
                                  double2 *ggsw_in, char *device_mem,
                                  size_t device_memory_size_per_block,
                                  uint32_t glwe_dim, uint32_t polynomial_size,
                                  uint32_t base_log, uint32_t level_count,
                                  uint32_t ggsw_idx, uint32_t num_lut) {
  // gridDim.y independent cmux trees run side by side; blockIdx.y picks
  // which tree this block belongs to, giving an offset into both arrays.
  int tree_offset =
      blockIdx.y * num_lut * (glwe_dim + 1) * polynomial_size;

  // Along x, one block computes one cmux of this tree's current layer:
  // output j is produced from inputs 2j and 2j+1.
  int out_idx = blockIdx.x;
  int in_idx_lo = 2 * blockIdx.x;
  int in_idx_hi = 2 * blockIdx.x + 1;

  // Scratch for the accumulators: shared memory when it fits (FULLSM),
  // otherwise a per-block slice of preallocated global memory.
  extern __shared__ char sharedmem[];
  char *scratch;
  if constexpr (SMD == FULLSM)
    scratch = sharedmem;
  else
    scratch = device_mem + (blockIdx.x + blockIdx.y * gridDim.x) *
                               device_memory_size_per_block;

  cmux<Torus, STorus, params>(glwe_array_out + tree_offset,
                              glwe_array_in + tree_offset, ggsw_in, scratch,
                              out_idx, in_idx_lo, in_idx_hi, glwe_dim,
                              polynomial_size, base_log, level_count, ggsw_idx);
}
|
|
/*
|
|
* This kernel executes the CMUX tree used by the hybrid packing of the WoPBS.
|
|
*
|
|
* Uses shared memory for intermediate results
|
|
*
|
|
* - v_stream: The CUDA stream that should be used.
|
|
* - glwe_array_out: A device array for the output GLWE ciphertext.
|
|
* - ggsw_in: A device array for the GGSW ciphertexts used in each layer.
|
|
* - lut_vector: A device array for the GLWE ciphertexts used in the first
|
|
* layer.
|
|
* - polynomial_size: size of the polynomials. This is N.
|
|
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
|
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
|
* - r: Number of layers in the tree.
|
|
* - tau: The quantity of CMUX trees that should be executed
|
|
*/
|
|
// Host driver for the CMUX tree (hybrid packing of the WoPBS).
//
// Runs r layers of batched CMUXes over tau independent trees, double-buffering
// between layers, and writes the tau resulting GLWEs to glwe_array_out.
// All device work is enqueued on *v_stream; the function synchronizes once
// before returning so that glwe_array_out is valid for the caller.
//
// Fix vs. previous revision: d_mem is initialized to nullptr. It used to be
// left uninitialized and then passed to the kernel in the shared-memory path
// (never dereferenced there, but reading an uninitialized pointer is UB).
template <typename Torus, typename STorus, class params>
void host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
                    Torus *ggsw_in, Torus *lut_vector, uint32_t glwe_dimension,
                    uint32_t polynomial_size, uint32_t base_log,
                    uint32_t level_count, uint32_t r, uint32_t tau,
                    uint32_t max_shared_memory) {
  auto stream = static_cast<cudaStream_t *>(v_stream);

  int num_lut = (1 << r);
  if (r == 0) {
    // No CMUX layer to compute: just expand the LUTs with zeroed paddings.
    add_padding_to_lut_async<Torus, params>(
        glwe_array_out, lut_vector, glwe_dimension, num_lut * tau, stream);
    checkCudaErrors(cudaStreamSynchronize(*stream));
    return;
  }

  // NOTE(review): twiddles are initialized with device index 0, not
  // gpu_index — confirm this is intended for multi-GPU setups.
  cuda_initialize_twiddles(polynomial_size, 0);

  // Per-block scratch; must match the buffer layout carved out in cmux().
  int memory_needed_per_block =
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft

  dim3 thds(polynomial_size / params::opt, 1, 1);

  //////////////////////
  int ggsw_size = polynomial_size * (glwe_dimension + 1) *
                  (glwe_dimension + 1) * level_count;

  // N torus coefficients compress to N/2 double2, i.e. N doubles total,
  // hence sizeof(double) (not sizeof(double2)) here.
  double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
      r * ggsw_size * sizeof(double), stream, gpu_index);

  batch_fft_ggsw_vector<Torus, STorus, params>(
      stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size,
      level_count, gpu_index, max_shared_memory);

  //////////////////////

  // Allocate global memory in case parameters are too large. Initialized to
  // nullptr so we never pass an indeterminate pointer to the kernel when
  // shared memory suffices (the kernel ignores it in that case).
  char *d_mem = nullptr;
  if (max_shared_memory < memory_needed_per_block) {
    d_mem = (char *)cuda_malloc_async(
        memory_needed_per_block * (1 << (r - 1)) * tau, stream, gpu_index);
  } else {
    // Opt in to large dynamic shared memory and prefer shared over L1.
    checkCudaErrors(cudaFuncSetAttribute(
        device_batch_cmux<Torus, STorus, params, FULLSM>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, memory_needed_per_block));
    checkCudaErrors(
        cudaFuncSetCacheConfig(device_batch_cmux<Torus, STorus, params, FULLSM>,
                               cudaFuncCachePreferShared));
  }

  // Double buffers ping-ponged between layers.
  int glwe_size = (glwe_dimension + 1) * polynomial_size;

  Torus *d_buffer1 = (Torus *)cuda_malloc_async(
      num_lut * tau * glwe_size * sizeof(Torus), stream, gpu_index);
  Torus *d_buffer2 = (Torus *)cuda_malloc_async(
      num_lut * tau * glwe_size * sizeof(Torus), stream, gpu_index);

  add_padding_to_lut_async<Torus, params>(d_buffer1, lut_vector, glwe_dimension,
                                          num_lut * tau, stream);

  Torus *output;
  // Run the cmux tree: each layer halves the number of ciphertexts.
  for (int layer_idx = 0; layer_idx < r; layer_idx++) {
    output = (layer_idx % 2 ? d_buffer1 : d_buffer2);
    Torus *input = (layer_idx % 2 ? d_buffer2 : d_buffer1);

    int num_cmuxes = (1 << (r - 1 - layer_idx));
    dim3 grid(num_cmuxes, tau, 1);

    // walks horizontally through the leaves
    if (max_shared_memory < memory_needed_per_block)
      device_batch_cmux<Torus, STorus, params, NOSM>
          <<<grid, thds, 0, *stream>>>(output, input, d_ggsw_fft_in, d_mem,
                                       memory_needed_per_block,
                                       glwe_dimension, // k
                                       polynomial_size, base_log, level_count,
                                       layer_idx, // r
                                       num_lut);
    else
      device_batch_cmux<Torus, STorus, params, FULLSM>
          <<<grid, thds, memory_needed_per_block, *stream>>>(
              output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
              glwe_dimension, // k
              polynomial_size, base_log, level_count,
              layer_idx, // r
              num_lut);
    checkCudaErrors(cudaGetLastError());
  }

  // The first GLWE of each tree's final buffer is that tree's result.
  for (int i = 0; i < tau; i++)
    checkCudaErrors(cudaMemcpyAsync(
        glwe_array_out + i * glwe_size, output + i * num_lut * glwe_size,
        glwe_size * sizeof(Torus), cudaMemcpyDeviceToDevice, *stream));

  // We only need synchronization to assert that data is in glwe_array_out
  // before returning. Memory release can be added to the stream and processed
  // later.
  checkCudaErrors(cudaStreamSynchronize(*stream));

  // Free memory
  cuda_drop_async(d_ggsw_fft_in, stream, gpu_index);
  cuda_drop_async(d_buffer1, stream, gpu_index);
  cuda_drop_async(d_buffer2, stream, gpu_index);
  if (d_mem != nullptr)
    cuda_drop_async(d_mem, stream, gpu_index);
}
|
|
|
|
// Copies one LWE ciphertext of size params::degree + 1 into dst_copy and a
// scalar-multiplied version into dst_shift.
// Only works for the big LWE in the ks+bs case (the LWE dimension must equal
// params::degree). blockIdx.x selects the input ciphertext.
template <typename Torus, class params>
__global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
                                   Torus *src, Torus value) {
  constexpr int lane_stride = params::degree / params::opt;
  int ct_offset = blockIdx.x * (params::degree + 1);
  Torus *out_copy = dst_copy + ct_offset;
  Torus *out_shift = dst_shift + ct_offset;
  Torus *in = src + ct_offset;

  int idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    Torus coeff = in[idx];
    out_copy[idx] = coeff;
    out_shift[idx] = coeff * value;
    idx += lane_stride;
  }

  // The last thread also handles the trailing body coefficient.
  if (threadIdx.x == lane_stride - 1) {
    Torus body = in[params::degree];
    out_copy[params::degree] = body;
    out_shift[params::degree] = body * value;
  }
}
|
|
|
|
// Scalar-multiplies LWE ciphertexts of arbitrary size by `value`.
// blockIdx.y selects the input ciphertext; each (blockIdx.y, blockIdx.x)
// pair writes one output ciphertext, so each input is replicated gridDim.x
// times in dst_shift.
// NOTE: the tail computation assumes blockDim.x is a power of two.
template <typename Torus, class params>
__global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
                              size_t lwe_size) {
  size_t out_ct = blockIdx.y * gridDim.x + blockIdx.x;
  size_t block_threads = blockDim.x;
  size_t full_passes = lwe_size / block_threads;
  size_t tail = lwe_size & (block_threads - 1);

  Torus *out = &dst_shift[out_ct * lwe_size];
  Torus *in = &src[blockIdx.y * lwe_size];

  size_t idx = threadIdx.x;
  for (size_t pass = 0; pass < full_passes; pass++) {
    out[idx] = in[idx] * value;
    idx += block_threads;
  }

  // Handle the remainder when lwe_size is not a multiple of blockDim.x.
  if (threadIdx.x < tail)
    out[idx] = in[idx] * value;
}
|
|
|
|
// Copies one small LWE ciphertext into slot `lwe_id` of the per-block output
// list (used in the ks+bs path when the LWE length is not a power of two).
// blockIdx.x selects which input/output list is handled.
// NOTE: the tail computation assumes blockDim.x is a power of two.
template <typename Torus>
__global__ void copy_small_lwe(Torus *dst, Torus *src, uint32_t small_lwe_size,
                               uint32_t number_of_bits, uint32_t lwe_id) {
  size_t block_threads = blockDim.x;
  size_t full_passes = small_lwe_size / block_threads;
  size_t tail = small_lwe_size & (block_threads - 1);

  Torus *lwe_list = &dst[blockIdx.x * small_lwe_size * number_of_bits];
  Torus *out = &lwe_list[lwe_id * small_lwe_size];
  Torus *in = &src[blockIdx.x * small_lwe_size];

  size_t idx = threadIdx.x;
  for (size_t pass = 0; pass < full_passes; pass++) {
    out[idx] = in[idx];
    idx += block_threads;
  }

  // Handle the remainder when small_lwe_size is not a multiple of blockDim.x.
  if (threadIdx.x < tail)
    out[idx] = in[idx];
}
|
|
|
|
// Adds `value` to the body coefficient (index lwe_dimension) of the LWE
// ciphertext selected by blockIdx.x in an array of (lwe_dimension + 1)-sized
// ciphertexts. Launch with one thread per block, one block per ciphertext
// (this file launches it both as <<<1, 1>>> and <<<pbs_count, 1>>>).
// NOTE: check if putting this functionality in copy_small_lwe or
// fill_pbs_lut vector is faster
template <typename Torus>
__global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {
  lwe[blockIdx.x * (lwe_dimension + 1) + lwe_dimension] += value;
}
|
|
|
|
// Fills the body polynomial of the LUT selected by blockIdx.x with `value`.
// The mask is left untouched (equivalent to a trivial encryption when the
// mask is already zero). Each GLWE in `lut` occupies 2 * params::degree
// coefficients laid out as [mask | body].
template <typename Torus, class params>
__global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
  constexpr size_t stride = params::degree / params::opt;
  // Skip the mask half to reach the body polynomial.
  Torus *body = &lut[blockIdx.x * 2 * params::degree + params::degree];

  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    body[idx] = value;
    idx += stride;
  }
}
|
|
|
|
// Builds the circuit-bootstrap LUTs as trivial encryptions: the mask is
// zeroed and every body coefficient is set to -alpha, where
//   alpha = 2^{log(q) - 1 - base_log * level}
// and level = blockIdx.x + 1 (one LUT per decomposition level).
template <typename Torus, class params>
__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
                                      uint32_t base_log_cbs) {
  constexpr size_t stride = params::degree / params::opt;
  Torus *mask = &lut[blockIdx.x * 2 * params::degree];
  Torus *body = mask + params::degree;

  // -alpha for this block's level; constant across the loop, so hoisted.
  Torus neg_alpha =
      0ll -
      (1ll << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1)));

  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    mask[idx] = 0;
    body[idx] = neg_alpha;
    idx += stride;
  }
}
|
|
|
|
// Duplicates each source LWE into two destination LWEs (dst 2j and 2j+1 both
// come from src j) and adds 2^{log(q) - 1 - base_log * level} to the body,
// where level cycles through 1..level_cbs with the source index.
// One block per destination ciphertext; LWE size is params::degree + 1.
template <typename Torus, class params>
__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src,
                                 uint32_t ciphertext_n_bits,
                                 uint32_t base_log_cbs, uint32_t level_cbs) {
  size_t dst_id = blockIdx.x;
  size_t src_id = dst_id / 2;
  size_t level = src_id % level_cbs + 1;

  Torus *src_ct = &lwe_src[src_id * (params::degree + 1)];
  Torus *dst_ct = &lwe_dst[dst_id * (params::degree + 1)];

  size_t idx = threadIdx.x;
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    dst_ct[idx] = src_ct[idx];
    idx += params::degree / params::opt;
  }

  // Thread 0 copies the body coefficient and recenters it for this level.
  if (threadIdx.x == 0) {
    Torus delta = 1ll << (ciphertext_n_bits - 1 - base_log_cbs * level);
    dst_ct[params::degree] = src_ct[params::degree] + delta;
  }
}
|
|
|
|
// Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0
|
|
// if the extracted bit was 0 and 1 in the other case
|
|
//
|
|
// Remove the extracted bit from the state LWE to get a 0 at the extracted bit
|
|
// location.
|
|
//
|
|
// Shift on padding bit for next iteration, that's why
|
|
// alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 2) is used
|
|
// instead of alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 1)
|
|
// Updates the extract-bits state and produces the next shifted input, per
// ciphertext (one block each, LWE size params::degree + 1):
//   state   -= pbs_out            (body additionally subtracts add_value)
//   shifted  = state * mul_value
// The chained `a = b -= c` below is intentional: the state must be updated
// first, and the shifted buffer takes the *updated* state times mul_value.
template <typename Torus, class params>
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
                                    Torus *pbs_lwe_array_out, Torus add_value,
                                    Torus mul_value) {
  size_t tid = threadIdx.x;
  size_t blockId = blockIdx.x;
  auto cur_shifted_lwe = &shifted_lwe[blockId * (params::degree + 1)];
  auto cur_state_lwe = &state_lwe[blockId * (params::degree + 1)];
  auto cur_pbs_lwe_array_out =
      &pbs_lwe_array_out[blockId * (params::degree + 1)];
#pragma unroll
  for (int i = 0; i < params::opt; i++) {
    // Mask coefficients: subtract the PBS result in place, then scale.
    cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_array_out[tid];
    cur_shifted_lwe[tid] *= mul_value;
    tid += params::degree / params::opt;
  }

  // The last thread handles the body: also subtracts add_value to clear the
  // extracted bit from the state.
  if (threadIdx.x == params::degree / params::opt - 1) {
    cur_shifted_lwe[params::degree] = cur_state_lwe[params::degree] -=
        (cur_pbs_lwe_array_out[params::degree] + add_value);
    cur_shifted_lwe[params::degree] *= mul_value;
  }
}
|
|
|
|
// Host driver for bit extraction: iteratively keyswitches the (shifted) big
// LWE down, stores each extracted bit (MSB first in time, written LSB-last
// into list_lwe_array_out), and uses a PBS to remove the extracted bit from
// the running state before shifting the next bit onto the padding bit.
//
// Fix vs. previous revision: the fill value was written as
// `0ll - 1ll << n`, which parses as `(0ll - 1ll) << n` — a left shift of a
// negative value, undefined behavior before C++20. The parenthesized form
// `0ll - (1ll << n)` yields the same two's-complement value without UB.
template <typename Torus, class params>
__host__ void host_extract_bits(
    void *v_stream, uint32_t gpu_index, Torus *list_lwe_array_out,
    Torus *lwe_array_in, Torus *lwe_array_in_buffer,
    Torus *lwe_array_in_shifted_buffer, Torus *lwe_array_out_ks_buffer,
    Torus *lwe_array_out_pbs_buffer, Torus *lut_pbs,
    uint32_t *lut_vector_indexes, Torus *ksk, double2 *fourier_bsk,
    uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t base_log_bsk, uint32_t level_count_bsk,
    uint32_t base_log_ksk, uint32_t level_count_ksk, uint32_t number_of_samples,
    uint32_t max_shared_memory) {

  auto stream = static_cast<cudaStream_t *>(v_stream);
  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;

  int blocks = 1;
  int threads = params::degree / params::opt;

  // Keep a pristine copy of the input and build the first shifted version
  // (MSB of the message moved onto the padding bit).
  copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_in,
      1ll << (ciphertext_n_bits - delta_log - 1));
  checkCudaErrors(cudaGetLastError());

  for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
    // Keyswitch the shifted ciphertext down to the small LWE dimension.
    cuda_keyswitch_lwe_ciphertext_vector(
        v_stream, gpu_index, lwe_array_out_ks_buffer,
        lwe_array_in_shifted_buffer, ksk, lwe_dimension_in, lwe_dimension_out,
        base_log_ksk, level_count_ksk, 1);

    // Store this bit; bits are written from the last slot backwards.
    copy_small_lwe<<<1, 256, 0, *stream>>>(
        list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
        number_of_bits, number_of_bits - bit_idx - 1);
    checkCudaErrors(cudaGetLastError());

    // The last bit needs no state update.
    if (bit_idx == number_of_bits - 1) {
      break;
    }

    // Add q/4 to center the error for the negacyclic PBS LUT.
    add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_array_out_ks_buffer,
                                             lwe_dimension_out,
                                             1ll << (ciphertext_n_bits - 2));
    checkCudaErrors(cudaGetLastError());

    // LUT value is -2^{delta_log - 1 + bit_idx}; parentheses keep the shift
    // operand non-negative (see header comment).
    fill_lut_body_for_current_bit<Torus, params>
        <<<blocks, threads, 0, *stream>>>(
            lut_pbs, 0ll - (1ll << (delta_log - 1 + bit_idx)));
    checkCudaErrors(cudaGetLastError());

    host_bootstrap_low_latency<Torus, params>(
        v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_pbs,
        lut_vector_indexes, lwe_array_out_ks_buffer, fourier_bsk,
        lwe_dimension_out, lwe_dimension_in, base_log_bsk, level_count_bsk,
        number_of_samples, 1, max_shared_memory);

    // Remove the extracted bit from the state and shift the next bit onto
    // the padding bit for the following iteration.
    add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
        lwe_array_in_shifted_buffer, lwe_array_in_buffer,
        lwe_array_out_pbs_buffer, 1ll << (delta_log - 1 + bit_idx),
        1ll << (ciphertext_n_bits - delta_log - bit_idx - 2));
    checkCudaErrors(cudaGetLastError());
  }
}
|
|
|
|
/*
|
|
* Receives "tau" GLWE ciphertexts as LUTs and "mbr_size" GGSWs. Each block
|
|
* computes the blind rotation loop + sample extraction for a single LUT.
|
|
* Writes the lwe output to lwe_out.
|
|
*
|
|
* This function needs polynomial_size/params::opt threads per block and tau
|
|
* blocks
|
|
*
|
|
* - lwe_out: An array of lwe ciphertexts. The outcome is written here.
|
|
* - glwe_in: An array of "tau" GLWE ciphertexts. These are the LUTs.
|
|
* - ggsw_in: An array of "mbr_size" GGSWs in the fourier domain.
|
|
* - mbr_size: The number of GGSWs.
|
|
* - glwe_dim: This is k.
|
|
* - polynomial_size: size of the polynomials. This is N.
|
|
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
|
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
|
* - device_memory_size_per_sample: Amount of (shared/global) memory used for
|
|
* the accumulators.
|
|
* - device_mem: An array to be used for the accumulators. Can be in the shared
|
|
* memory or global memory.
|
|
*/
|
|
template <typename Torus, typename STorus, class params, sharedMemDegree SMD>
__global__ void device_blind_rotation_and_sample_extraction(
    Torus *lwe_out, Torus *glwe_in, double2 *ggsw_in, // m^BR
    uint32_t mbr_size, uint32_t glwe_dim, uint32_t polynomial_size,
    uint32_t base_log, uint32_t l_gadget, size_t device_memory_size_per_sample,
    char *device_mem) {

  // We use shared memory for intermediate result
  extern __shared__ char sharedmem[];
  char *selected_memory;

  if constexpr (SMD == FULLSM)
    selected_memory = sharedmem;
  else
    selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];

  // Scratch layout (the host wrapper asserts glwe_dim == 1, so a GLWE is
  // 2 * polynomial_size Torus: one mask + one body polynomial):
  //   accumulator_c0: 2 * polynomial_size Torus (the running accumulator)
  //   accumulator_c1: 2 * polynomial_size Torus (x^ai * ACC)
  //   then the cmux() scratch area, starting at 4 * polynomial_size Torus.
  Torus *accumulator_c0 = (Torus *)selected_memory;
  Torus *accumulator_c1 = (Torus *)accumulator_c0 + 2 * polynomial_size;

  // Input LUT: initialize the accumulator with this block's GLWE.
  auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
  int tid = threadIdx.x;
  for (int i = 0; i < params::opt; i++) {
    accumulator_c0[tid] = mi[tid];
    accumulator_c0[tid + params::degree] = mi[tid + params::degree];
    tid += params::degree / params::opt;
  }

  // Blind rotation: GGSWs are consumed from last to first, the monomial
  // degree doubling each step (1 << monomial_degree).
  int monomial_degree = 0;
  for (int i = mbr_size - 1; i >= 0; i--) {
    synchronize_threads_in_block();

    // Compute x^ai * ACC into accumulator_c1.
    // Body
    divide_by_monomial_negacyclic_inplace<Torus, params::opt,
                                          params::degree / params::opt>(
        accumulator_c1, accumulator_c0, (1 << monomial_degree), false);
    // Mask
    divide_by_monomial_negacyclic_inplace<Torus, params::opt,
                                          params::degree / params::opt>(
        accumulator_c1 + polynomial_size, accumulator_c0 + polynomial_size,
        (1 << monomial_degree), false);

    monomial_degree += 1;

    // ACC = CMUX ( Ci, x^ai * ACC, ACC )
    // c0 and c1 are adjacent in scratch, so cmux reads them as inputs 0 and 1
    // of a 2-element array and writes the result back over c0.
    synchronize_threads_in_block();
    cmux<Torus, STorus, params>(accumulator_c0, accumulator_c0, ggsw_in,
                                (char *)(accumulator_c0 + 4 * polynomial_size),
                                0, 0, 1, glwe_dim, polynomial_size, base_log,
                                l_gadget, i);
  }
  synchronize_threads_in_block();

  // Write the output
  auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];

  // The blind rotation for this block is over
  // Now we can perform the sample extraction: for the body it's just
  // the resulting constant coefficient of the accumulator
  // For the mask it's more complicated
  sample_extract_mask<Torus, params>(block_lwe_out, accumulator_c0);
  sample_extract_body<Torus, params>(block_lwe_out,
                                     accumulator_c0 + polynomial_size);
}
|
|
|
|
// Host driver for blind rotation + sample extraction: converts the mbr_size
// GGSWs to the Fourier domain, then launches one block per LUT (tau blocks)
// of device_blind_rotation_and_sample_extraction. Asynchronous on *v_stream;
// no final synchronization is performed here.
//
// Fix vs. previous revision: d_mem is initialized to nullptr. It used to be
// left uninitialized and then passed to the kernel in the shared-memory path
// (never dereferenced there, but reading an uninitialized pointer is UB).
template <typename Torus, typename STorus, class params>
void host_blind_rotate_and_sample_extraction(
    void *v_stream, uint32_t gpu_index, Torus *lwe_out, Torus *ggsw_in,
    Torus *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
    uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
    uint32_t max_shared_memory) {

  assert(glwe_dimension ==
         1); // For larger k we will need to adjust the mask size
  auto stream = static_cast<cudaStream_t *>(v_stream);

  // Per-block scratch: two accumulators plus the cmux() buffers; must match
  // the layout carved out in the kernel and in cmux().
  int memory_needed_per_block =
      sizeof(Torus) * polynomial_size +       // accumulator_c0 mask
      sizeof(Torus) * polynomial_size +       // accumulator_c0 body
      sizeof(Torus) * polynomial_size +       // accumulator_c1 mask
      sizeof(Torus) * polynomial_size +       // accumulator_c1 body
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft

  // Fall back to global memory when shared memory cannot hold the scratch.
  // nullptr-initialized so an indeterminate pointer is never passed to the
  // kernel in the shared-memory path.
  char *d_mem = nullptr;
  if (max_shared_memory < memory_needed_per_block)
    d_mem = (char *)cuda_malloc_async(memory_needed_per_block * tau, stream,
                                      gpu_index);
  else {
    checkCudaErrors(cudaFuncSetAttribute(
        device_blind_rotation_and_sample_extraction<Torus, STorus, params,
                                                    FULLSM>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, memory_needed_per_block));
    checkCudaErrors(cudaFuncSetCacheConfig(
        device_blind_rotation_and_sample_extraction<Torus, STorus, params,
                                                    FULLSM>,
        cudaFuncCachePreferShared));
  }

  // Applying the FFT on m^br: N torus coefficients become N/2 double2
  // values (N doubles), hence sizeof(double) here.
  int ggsw_size =
      polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) * l_gadget;
  double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
      mbr_size * ggsw_size * sizeof(double), stream, gpu_index);

  batch_fft_ggsw_vector<Torus, STorus, params>(
      stream, d_ggsw_fft_in, ggsw_in, mbr_size, glwe_dimension, polynomial_size,
      l_gadget, gpu_index, max_shared_memory);
  checkCudaErrors(cudaGetLastError());

  // One block per LUT; polynomial_size / params::opt threads per block.
  dim3 thds(polynomial_size / params::opt, 1, 1);
  dim3 grid(tau, 1, 1);

  if (max_shared_memory < memory_needed_per_block)
    device_blind_rotation_and_sample_extraction<Torus, STorus, params, NOSM>
        <<<grid, thds, 0, *stream>>>(lwe_out, lut_vector, d_ggsw_fft_in,
                                     mbr_size,
                                     glwe_dimension, // k
                                     polynomial_size, base_log, l_gadget,
                                     memory_needed_per_block, d_mem);
  else
    device_blind_rotation_and_sample_extraction<Torus, STorus, params, FULLSM>
        <<<grid, thds, memory_needed_per_block, *stream>>>(
            lwe_out, lut_vector, d_ggsw_fft_in, mbr_size,
            glwe_dimension, // k
            polynomial_size, base_log, l_gadget, memory_needed_per_block,
            d_mem);
  checkCudaErrors(cudaGetLastError());

  // Release the temporaries on the stream (no host sync needed here).
  cuda_drop_async(d_ggsw_fft_in, stream, gpu_index);
  if (d_mem != nullptr)
    cuda_drop_async(d_mem, stream, gpu_index);
}
|
|
|
|
// Host driver for the circuit bootstrap: shifts the input message onto the
// padding bit, runs an amortized PBS against per-level negacyclic LUTs, then
// functional-packing-keyswitches the results into a GGSW.
//
// Fix vs. previous revision: copy_add_lwe_cbs was launched with
// <<<copy_grid, copy_block>>> and no stream argument, i.e. on the DEFAULT
// stream, while all surrounding work runs on *stream — a synchronization
// hazard. It is now launched on *stream, and cudaGetLastError() checks are
// added after kernel launches, consistent with the rest of this file.
template <typename Torus, class params>
__host__ void host_circuit_bootstrap(
    void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in,
    double2 *fourier_bsk, Torus *fp_ksk_array,
    Torus *lwe_array_in_shifted_buffer, Torus *lut_vector,
    uint32_t *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
    Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
    uint32_t number_of_samples, uint32_t max_shared_memory) {
  auto stream = static_cast<cudaStream_t *>(v_stream);

  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
  uint32_t lwe_size = lwe_dimension + 1;
  // One PBS per (sample, decomposition level) pair.
  int pbs_count = number_of_samples * level_cbs;

  dim3 blocks(level_cbs, number_of_samples, 1);
  int threads = 256;

  // Shift message LSB on padding bit, at this point we expect to have messages
  // with only 1 bit of information
  shift_lwe_cbs<Torus, params><<<blocks, threads, 0, *stream>>>(
      lwe_array_in_shifted_buffer, lwe_array_in,
      1LL << (ciphertext_n_bits - delta_log - 1), lwe_size);
  checkCudaErrors(cudaGetLastError());

  // Add q/4 to center the error while computing a negacyclic LUT
  add_to_body<Torus>
      <<<pbs_count, 1, 0, *stream>>>(lwe_array_in_shifted_buffer, lwe_dimension,
                                     1ll << (ciphertext_n_bits - 2));
  checkCudaErrors(cudaGetLastError());

  // Fill lut (equivalent to trivial encryption as mask is 0s)
  // The LUT is filled with -alpha in each coefficient where
  // alpha = 2^{log(q) - 1 - base_log * level}
  fill_lut_body_for_cbs<Torus, params>
      <<<level_cbs, params::degree / params::opt, 0, *stream>>>(
          lut_vector, ciphertext_n_bits, base_log_cbs);
  checkCudaErrors(cudaGetLastError());

  // Applying a negacyclic LUT on a ciphertext with one bit of message in the
  // MSB and no bit of padding
  host_bootstrap_amortized<Torus, params>(
      v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_vector,
      lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk,
      lwe_dimension, polynomial_size, base_log_bsk, level_bsk, pbs_count,
      level_cbs, 0, max_shared_memory);

  dim3 copy_grid(pbs_count * (glwe_dimension + 1), 1, 1);
  dim3 copy_block(params::degree / params::opt, 1, 1);
  // Add q/4 to center the error while computing a negacyclic LUT
  // copy pbs result (glwe_dimension + 1) times to be an input of fp-ks
  copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
      lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer, ciphertext_n_bits,
      base_log_cbs, level_cbs);
  checkCudaErrors(cudaGetLastError());

  // Pack the LWE results into the output GGSW via functional keyswitch.
  cuda_fp_keyswitch_lwe_to_glwe(
      v_stream, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
      polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
      level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
}
|
|
|
|
// number_of_inputs is the total number of LWE ciphertexts passed to CBS + VP,
|
|
// i.e. tau * p where tau is the number of LUTs (the original number of LWEs
|
|
// before bit extraction) and p is the number of extracted bits
|
|
template <typename Torus, typename STorus, class params>
__host__ void host_circuit_bootstrap_vertical_packing(
    void *v_stream, uint32_t gpu_index, Torus *lwe_array_out,
    Torus *lwe_array_in, Torus *lut_vector, double2 *fourier_bsk,
    Torus *cbs_fpksk, uint32_t glwe_dimension, uint32_t lwe_dimension,
    uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk,
    uint32_t base_log_pksk, uint32_t level_count_pksk, uint32_t base_log_cbs,
    uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t lut_number,
    uint32_t max_shared_memory) {
  // Circuit bootstrap + vertical packing:
  //   1) CBS turns the `number_of_inputs` input LWEs into GGSW ciphertexts,
  //   2) if there are more GGSWs than log2(N), the most significant ones
  //      drive a CMux tree over the LUTs and the remaining log2(N) drive a
  //      blind rotation + sample extraction; otherwise a single blind
  //      rotation + sample extraction over zero-padded LUTs suffices.
  // `number_of_inputs` = tau * p (tau LUTs, p extracted bits per input).
  // All temporaries are stream-ordered allocations freed before return.

  auto stream = static_cast<cudaStream_t *>(v_stream);

  // Allocate and initialize device pointers for circuit bootstrap.
  // Output GGSW array for CBS.
  int ggsw_size = level_count_cbs * (glwe_dimension + 1) *
                  (glwe_dimension + 1) * polynomial_size;
  Torus *ggsw_out = (Torus *)cuda_malloc_async(
      number_of_inputs * ggsw_size * sizeof(Torus), stream, gpu_index);
  // Input LWE array for the functional-packing keyswitch.
  Torus *lwe_array_in_fp_ks_buffer = (Torus *)cuda_malloc_async(
      number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
          (polynomial_size + 1) * sizeof(Torus),
      stream, gpu_index);
  // Buffer for the PBS output.
  Torus *lwe_array_out_pbs_buffer =
      (Torus *)cuda_malloc_async(number_of_inputs * level_count_cbs *
                                     (polynomial_size + 1) * sizeof(Torus),
                                 stream, gpu_index);
  // Buffer for the shifted LWE input.
  Torus *lwe_array_in_shifted_buffer = (Torus *)cuda_malloc_async(
      number_of_inputs * level_count_cbs * (lwe_dimension + 1) * sizeof(Torus),
      stream, gpu_index);
  // LUT vector buffer for CBS.
  Torus *lut_vector_cbs = (Torus *)cuda_malloc_async(
      level_count_cbs * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
      stream, gpu_index);

  // Indexes of the LUT vectors for CBS: each input cycles through the
  // level_count_cbs per-level LUTs.
  uint32_t *h_lut_vector_indexes =
      (uint32_t *)malloc(number_of_inputs * level_count_cbs * sizeof(uint32_t));
  for (uint index = 0; index < level_count_cbs * number_of_inputs; index++) {
    h_lut_vector_indexes[index] = index % level_count_cbs;
  }
  uint32_t *lut_vector_indexes = (uint32_t *)cuda_malloc_async(
      number_of_inputs * level_count_cbs * sizeof(uint32_t), stream, gpu_index);
  // Pageable-host async copy: the call returns only after the source has
  // been staged, so freeing h_lut_vector_indexes afterwards is safe.
  cuda_memcpy_async_to_gpu(
      lut_vector_indexes, h_lut_vector_indexes,
      number_of_inputs * level_count_cbs * sizeof(uint32_t), stream, gpu_index);
  checkCudaErrors(cudaGetLastError());

  uint32_t bits = sizeof(Torus) * 8;
  // Inputs carry their single message bit in the MSB.
  uint32_t delta_log = (bits - 1);

  host_circuit_bootstrap<Torus, params>(
      v_stream, gpu_index, ggsw_out, lwe_array_in, fourier_bsk, cbs_fpksk,
      lwe_array_in_shifted_buffer, lut_vector_cbs, lut_vector_indexes,
      lwe_array_out_pbs_buffer, lwe_array_in_fp_ks_buffer, delta_log,
      polynomial_size, glwe_dimension, lwe_dimension, level_count_bsk,
      base_log_bsk, level_count_pksk, base_log_pksk, level_count_cbs,
      base_log_cbs, number_of_inputs, max_shared_memory);
  checkCudaErrors(cudaGetLastError());

  // Free the CBS-only temporaries.
  cuda_drop_async(lwe_array_in_fp_ks_buffer, stream, gpu_index);
  cuda_drop_async(lwe_array_in_shifted_buffer, stream, gpu_index);
  cuda_drop_async(lwe_array_out_pbs_buffer, stream, gpu_index);
  cuda_drop_async(lut_vector_cbs, stream, gpu_index);
  cuda_drop_async(lut_vector_indexes, stream, gpu_index);
  free(h_lut_vector_indexes);

  // number_of_inputs = tau * p is the total number of GGSWs
  if (number_of_inputs > params::log2_degree) {
    // Split the vec of GGSW in two: the r msb GGSWs feed the CMux tree and
    // the log2(N) lsb GGSWs feed the last blind rotation.
    uint32_t r = number_of_inputs - params::log2_degree;
    Torus *br_ggsw = (Torus *)ggsw_out +
                     (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
                                 (glwe_dimension + 1) * polynomial_size);
    Torus *glwe_array_out = (Torus *)cuda_malloc_async(
        lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
        stream, gpu_index);
    // CMux tree, r = tau * p - log2(N)
    host_cmux_tree<Torus, STorus, params>(
        v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector,
        glwe_dimension, polynomial_size, base_log_cbs, level_count_cbs, r,
        lut_number, max_shared_memory);
    checkCudaErrors(cudaGetLastError());

    // Blind rotation + sample extraction, mbr = tau * p - r = log2(N)
    host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
        v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
        number_of_inputs - r, lut_number, glwe_dimension, polynomial_size,
        base_log_cbs, level_count_cbs, max_shared_memory);
    // Only drop glwe_array_out AFTER the blind rotation that reads it has
    // been enqueued: dropping it first would free (stream-ordered) the
    // buffer before its last use.
    cuda_drop_async(glwe_array_out, stream, gpu_index);
  } else {
    // We need to expand the LUT to fill the masks with zeros
    // (trivial GLWE encryption of the LUT bodies).
    Torus *lut_vector_glwe = (Torus *)cuda_malloc_async(
        lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
        stream, gpu_index);
    add_padding_to_lut_async<Torus, params>(lut_vector_glwe, lut_vector,
                                            glwe_dimension, lut_number, stream);
    checkCudaErrors(cudaGetLastError());

    // Blind rotation + sample extraction
    host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
        v_stream, gpu_index, lwe_array_out, ggsw_out, lut_vector_glwe,
        number_of_inputs, lut_number, glwe_dimension, polynomial_size,
        base_log_cbs, level_count_cbs, max_shared_memory);
    // Free the expanded LUT buffer (was previously leaked).
    cuda_drop_async(lut_vector_glwe, stream, gpu_index);
  }
  cuda_drop_async(ggsw_out, stream, gpu_index);
}
|
|
|
|
#endif // WOP_PBS_H
|