feat(cuda): new decomposition algorithm for pbs.

- removes the 16-bit limitation on base_log
- optimizes shared memory use: dedicated decomposition buffers are no longer needed; for the amortized PBS, the rotated buffers are reused as the decomposition state (a host-side sketch of the idea follows below)
- adds a private test for the CUDA PBS, as we already have for the FFT backend
Beka Barbakadze
2022-11-17 05:27:39 +04:00
committed by Agnès Leroy
parent d59b2f6dda
commit 56b986da8b
5 changed files with 97 additions and 138 deletions
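
The core of the change, in rough host-side form: instead of writing each decomposed level into a dedicated int16_t scratch buffer and then compressing it into the complex FFT buffer, the rotated accumulator itself becomes the decomposition state, and each call emits one level of signed digits straight into the complex buffer, which is what the new decompose_and_compress_next does per thread. The sketch below is a simplified CPU model of that idea; GadgetState, digit and the sample parameters are my own names and values, not library API. Because digits land directly in doubles rather than int16_t, a base_log larger than 16 (23 here) now works.

#include <complex>
#include <cstdint>
#include <cstdio>
#include <vector>

using Torus = uint64_t;

// Simplified CPU model of the state-based gadget decomposition. The vector
// handed to the constructor plays the role of the rotated accumulator that the
// kernels now reuse as decomposition state.
struct GadgetState {
  int base_log, level_count, current_level;
  Torus mask_mod_b;
  std::vector<Torus> state;

  GadgetState(int base_log, int level_count, std::vector<Torus> rotated)
      : base_log(base_log), level_count(level_count), current_level(level_count),
        mask_mod_b(((Torus)1 << base_log) - 1), state(std::move(rotated)) {
    // Keep only the level_count most significant digits of each coefficient
    // (rounding to the closest representable value happens beforehand in the kernels).
    for (Torus &s : state)
      s >>= 64 - base_log * level_count;
  }

  // Emit one level of signed digits (level_count-1 first, then down to 0),
  // packing two real digits into one complex slot for the negacyclic FFT.
  void decompose_and_compress_next(std::vector<std::complex<double>> &fft) {
    current_level--;
    for (size_t i = 0; i < state.size(); i += 2)
      fft[i / 2] = {(double)digit(state[i]), (double)digit(state[i + 1])};
  }

private:
  // One signed digit in [-B/2, B/2], propagating a carry into the state.
  int64_t digit(Torus &s) {
    Torus res = s & mask_mod_b;
    s >>= base_log;
    Torus carry = ((res - 1) | s) & res;
    carry >>= base_log - 1;
    s += carry;
    res -= carry << base_log;
    return (int64_t)res;
  }
};

int main() {
  const int base_log = 23, level_count = 2, N = 8; // base_log > 16 is now fine
  std::vector<Torus> rotated(N);
  for (int i = 0; i < N; i++)
    rotated[i] = 0x0123456789ABCDEFull * (Torus)(i + 1);
  GadgetState gadget(base_log, level_count, rotated);
  std::vector<std::complex<double>> fft(N / 2);
  for (int level = level_count - 1; level >= 0; level--) {
    gadget.decompose_and_compress_next(fft);
    std::printf("level %d, first slot: (%g, %g)\n", level, fft[0].real(), fft[0].imag());
  }
}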

View File

@@ -72,14 +72,8 @@ __global__ void device_bootstrap_amortized(
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
// For GPU bootstrapping the GLWE dimension is hard-set to 1: there is only
// one mask polynomial and 1 body to handle Also, since the decomposed
// polynomials take coefficients between -B/2 and B/2 they can be represented
// with only 16 bits, assuming the base log does not exceed 2^16
int16_t *accumulator_mask_decomposed = (int16_t *)selected_memory;
int16_t *accumulator_body_decomposed =
(int16_t *)accumulator_mask_decomposed + polynomial_size;
Torus *accumulator_mask = (Torus *)accumulator_body_decomposed +
polynomial_size / (sizeof(Torus) / sizeof(int16_t));
// one mask polynomial and 1 body to handle.
Torus *accumulator_mask = (Torus *)selected_memory;
Torus *accumulator_body =
(Torus *)accumulator_mask + (ptrdiff_t)polynomial_size;
Torus *accumulator_mask_rotated =
@@ -100,8 +94,6 @@ __global__ void device_bootstrap_amortized(
&lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
2];
GadgetMatrix<Torus, params> gadget(base_log, level_count);
// Put "b", the body, in [0, 2N[
Torus b_hat = 0;
rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
@@ -159,27 +151,17 @@ __global__ void device_bootstrap_amortized(
pos += params::degree / params::opt;
}
GadgetMatrix<Torus, params> gadget_mask(base_log, level_count,
accumulator_mask_rotated);
GadgetMatrix<Torus, params> gadget_body(base_log, level_count,
accumulator_body_rotated);
// Now that the rotation is done, decompose the resulting polynomial
// coefficients so as to multiply each decomposed level with the
// corresponding part of the bootstrapping key
for (int level = 0; level < level_count; level++) {
for (int level = level_count - 1; level >= 0; level--) {
gadget.decompose_one_level(accumulator_mask_decomposed,
accumulator_mask_rotated, level);
gadget.decompose_one_level(accumulator_body_decomposed,
accumulator_body_rotated, level);
synchronize_threads_in_block();
// First, perform the polynomial multiplication for the mask
// Reduce the size of the FFT to be performed by storing
// the real-valued polynomial into a complex polynomial
real_to_complex_compressed<params>(accumulator_mask_decomposed,
accumulator_fft);
synchronize_threads_in_block();
gadget_mask.decompose_and_compress_next(accumulator_fft);
// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
@@ -205,12 +187,10 @@ __global__ void device_bootstrap_amortized(
body_res_fft, accumulator_fft, bsk_body_slice);
synchronize_threads_in_block();
gadget_body.decompose_and_compress_next(accumulator_fft);
// Now handle the polynomial multiplication for the body
// in the same way
real_to_complex_compressed<params>(accumulator_body_decomposed,
accumulator_fft);
synchronize_threads_in_block();
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
@@ -306,12 +286,10 @@ __host__ void host_bootstrap_amortized(
uint32_t gpu_index = 0;
int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
sizeof(Torus) * polynomial_size + // accumulator body
sizeof(Torus) * polynomial_size + // accumulator mask rotated
sizeof(Torus) * polynomial_size + // accumulator body rotated
sizeof(int16_t) * polynomial_size + // accumulator_dec mask
sizeof(int16_t) * polynomial_size + // accumulator_dec_body
int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
sizeof(Torus) * polynomial_size + // accumulator body
sizeof(Torus) * polynomial_size + // accumulator mask rotated
sizeof(Torus) * polynomial_size + // accumulator body rotated
sizeof(double2) * polynomial_size / 2 + // accumulator fft mask
sizeof(double2) * polynomial_size / 2 + // accumulator fft body
sizeof(double2) * polynomial_size / 2; // calculate buffer fft
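
For the amortized kernel, dropping the two int16_t decomposition buffers shrinks the per-sample shared-memory footprint by 2 * sizeof(int16_t) * polynomial_size bytes. A small host-side check of the budget, mirroring the SM_FULL sums above; the function names are mine, Torus is assumed to be 64-bit, and double2 is counted as two doubles.

#include <cstddef>
#include <cstdint>
#include <cstdio>

static size_t sm_full_old(size_t N) {
  return sizeof(uint64_t) * N * 4          // mask, body, mask rotated, body rotated
       + sizeof(int16_t) * N * 2           // the two int16_t decomposition buffers
       + 2 * sizeof(double) * (N / 2) * 3; // fft mask, fft body, fft work buffer
}

static size_t sm_full_new(size_t N) {
  return sizeof(uint64_t) * N * 4          // the rotated buffers now double as state
       + 2 * sizeof(double) * (N / 2) * 3;
}

int main() {
  for (size_t N : {512u, 1024u, 2048u, 4096u})
    std::printf("N=%4zu  old=%6zu B  new=%6zu B  saved=%5zu B\n",
                N, sm_full_old(N), sm_full_new(N), sm_full_old(N) - sm_full_new(N));
}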

View File

@@ -30,14 +30,10 @@ namespace cg = cooperative_groups;
template <typename Torus, class params>
__device__ void
mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
double2 *mask_join_buffer, double2 *body_join_buffer,
double2 *bootstrapping_key, int polynomial_size, int level_count,
int iteration, grid_group &grid) {
// Put the decomposed GLWE sample in the Fourier domain
real_to_complex_compressed<params>(glwe_decomposed, fft);
synchronize_threads_in_block();
mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
double2 *body_join_buffer, double2 *bootstrapping_key,
int polynomial_size, int level_count, int iteration,
grid_group &grid) {
// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(fft);
@@ -157,16 +153,13 @@ __global__ void device_bootstrap_low_latency(
char *selected_memory = sharedmem;
int16_t *accumulator_decomposed = (int16_t *)selected_memory;
Torus *accumulator = (Torus *)accumulator_decomposed +
polynomial_size / (sizeof(Torus) / sizeof(int16_t));
Torus *accumulator = (Torus *)selected_memory;
Torus *accumulator_rotated =
(Torus *)accumulator + (ptrdiff_t)polynomial_size;
double2 *accumulator_fft =
(double2 *)accumulator +
(double2 *)accumulator_rotated +
polynomial_size / (sizeof(double2) / sizeof(Torus));
// Reuse memory from accumulator_fft for accumulator_rotated
Torus *accumulator_rotated = (Torus *)accumulator_fft;
// The third dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
auto block_lwe_array_in = &lwe_array_in[blockIdx.z * (lwe_dimension + 1)];
@@ -181,7 +174,6 @@ __global__ void device_bootstrap_low_latency(
// Since the space in L1 cache is small, we use the same memory location for
// the rotated accumulator and the fft accumulator, since we know that the
// rotated array is not in use anymore by the time we perform the fft
GadgetMatrix<Torus, params> gadget(base_log, level_count);
// Put "b" in [0, 2N[
Torus b_hat = 0;
@@ -217,11 +209,14 @@ __global__ void device_bootstrap_low_latency(
params::degree / params::opt>(
accumulator_rotated, base_log, level_count);
synchronize_threads_in_block();
// Decompose the accumulator. Each block gets one level of the
// decomposition, for the mask and the body (so block 0 will have the
// accumulator decomposed at level 0, 1 at 1, etc.)
gadget.decompose_one_level(accumulator_decomposed, accumulator_rotated,
blockIdx.x);
GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
accumulator_rotated);
gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);
// We are using the same memory space for accumulator_fft and
// accumulator_rotated, so we need to synchronize here to make sure they
@@ -229,9 +224,9 @@ __global__ void device_bootstrap_low_latency(
synchronize_threads_in_block();
// Perform G^-1(ACC) * GGSW -> GLWE
mul_ggsw_glwe<Torus, params>(accumulator, accumulator_fft,
accumulator_decomposed, block_mask_join_buffer,
block_body_join_buffer, bootstrapping_key,
polynomial_size, level_count, i, grid);
block_mask_join_buffer, block_body_join_buffer,
bootstrapping_key, polynomial_size,
level_count, i, grid);
}
auto block_lwe_array_out = &lwe_array_out[blockIdx.z * (polynomial_size + 1)];
@@ -270,8 +265,8 @@ host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
double2 *body_buffer_fft =
(double2 *)cuda_malloc_async(buffer_size_per_gpu, *stream, gpu_index);
int bytes_needed = sizeof(int16_t) * polynomial_size + // accumulator_decomp
sizeof(Torus) * polynomial_size + // accumulator
int bytes_needed = sizeof(Torus) * polynomial_size + // accumulator_rotated
sizeof(Torus) * polynomial_size + // accumulator
sizeof(double2) * polynomial_size / 2; // accumulator fft
int thds = polynomial_size / params::opt;
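
In the low-latency path the decomposition is spread across blocks: each block rebuilds the gadget state from its own rotated accumulator and calls decompose_and_compress_level(accumulator_fft, blockIdx.x), which applies the single-step routine (level_count - blockIdx.x) times so that the block ends up holding the digits of its own level. Below is a host-side model of that fast-forwarding for one coefficient; decompose_step, digit_for_level and the rounding expression are my own stand-ins, not library code.

#include <cstdint>
#include <cstdio>

using Torus = uint64_t;

// One decomposition step: extract a signed digit and update the state
// (same logic as decompose_one / decompose_and_compress_next).
static int64_t decompose_step(Torus &state, int base_log) {
  Torus mask_mod_b = ((Torus)1 << base_log) - 1;
  Torus res = state & mask_mod_b;
  state >>= base_log;
  Torus carry = ((res - 1) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return (int64_t)res;
}

// Digit that the block with blockIdx.x == level would feed into the FFT for
// one already-rounded coefficient: decompose_and_compress_level(fft, level)
// runs the single-step routine (level_count - level) times, and successive
// steps produce levels level_count-1, level_count-2, ..., level.
static int64_t digit_for_level(Torus rounded, int level, int base_log, int level_count) {
  Torus state = rounded >> (64 - base_log * level_count);
  int64_t digit = 0;
  for (int i = 0; i < level_count - level; i++)
    digit = decompose_step(state, base_log);
  return digit;
}

int main() {
  const int base_log = 7, level_count = 3;
  const int shift = 64 - base_log * level_count;
  Torus v = 0xDEADBEEF00000000ull;
  // Round to the closest multiple of 2^shift first, as the kernel does with
  // round_to_closest_multiple_inplace.
  Torus rounded = (((v >> (shift - 1)) + 1) >> 1) << shift;
  for (int level = 0; level < level_count; level++)
    std::printf("block/level %d -> digit %lld\n", level,
                (long long)digit_for_level(rounded, level, base_log, level_count));
}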

View File

@@ -21,14 +21,7 @@
#include "utils/memory.cuh"
#include "utils/timer.cuh"
template <class params> __device__ void fft(double2 *output, int16_t *input) {
synchronize_threads_in_block();
// Reduce the size of the FFT to be performed by storing
// the real-valued polynomial into a complex polynomial
real_to_complex_compressed<params>(input, output);
synchronize_threads_in_block();
template <class params> __device__ void fft(double2 *output) {
// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(output);
synchronize_threads_in_block();
@@ -81,19 +74,14 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
Torus *glwe_sub_mask = (Torus *)selected_memory;
Torus *glwe_sub_body = (Torus *)glwe_sub_mask + (ptrdiff_t)polynomial_size;
int16_t *glwe_mask_decomposed = (int16_t *)(glwe_sub_body + polynomial_size);
int16_t *glwe_body_decomposed =
(int16_t *)glwe_mask_decomposed + (ptrdiff_t)polynomial_size;
double2 *mask_res_fft = (double2 *)(glwe_body_decomposed + polynomial_size);
double2 *mask_res_fft = (double2 *)glwe_sub_body +
polynomial_size / (sizeof(double2) / sizeof(Torus));
double2 *body_res_fft =
(double2 *)mask_res_fft + (ptrdiff_t)polynomial_size / 2;
double2 *glwe_fft =
(double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
GadgetMatrix<Torus, params> gadget(base_log, level_count);
/////////////////////////////////////
// glwe2-glwe1
@@ -125,18 +113,18 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
pos += params::degree / params::opt;
}
GadgetMatrix<Torus, params> gadget_mask(base_log, level_count, glwe_sub_mask);
GadgetMatrix<Torus, params> gadget_body(base_log, level_count, glwe_sub_body);
// Subtract each glwe operand, decompose the resulting
// polynomial coefficients to multiply each decomposed level
// with the corresponding part of the LUT
for (int level = 0; level < level_count; level++) {
for (int level = level_count - 1; level >= 0; level--) {
// Decomposition
gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);
gadget_mask.decompose_and_compress_next(glwe_fft);
// First, perform the polynomial multiplication for the mask
synchronize_threads_in_block();
fft<params>(glwe_fft, glwe_mask_decomposed);
fft<params>(glwe_fft);
// External product and accumulate
// Get the piece necessary for the multiplication
@@ -157,7 +145,9 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
// Now handle the polynomial multiplication for the body
// in the same way
synchronize_threads_in_block();
fft<params>(glwe_fft, glwe_body_decomposed);
gadget_body.decompose_and_compress_next(glwe_fft);
fft<params>(glwe_fft);
// External product and accumulate
// Get the piece necessary for the multiplication
@@ -272,8 +262,6 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
int memory_needed_per_block =
sizeof(Torus) * polynomial_size + // glwe_sub_mask
sizeof(Torus) * polynomial_size + // glwe_sub_body
sizeof(int16_t) * polynomial_size + // glwe_mask_decomposed
sizeof(int16_t) * polynomial_size + // glwe_body_decomposed
sizeof(double2) * polynomial_size / 2 + // mask_res_fft
sizeof(double2) * polynomial_size / 2 + // body_res_fft
sizeof(double2) * polynomial_size / 2; // glwe_fft
@@ -331,11 +319,11 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
// walks horizontally through the leafs
if (max_shared_memory < memory_needed_per_block)
device_batch_cmux<Torus, STorus, params, NOSM>
<<<grid, thds, memory_needed_per_block, *stream>>>(
output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
glwe_dimension, // k
polynomial_size, base_log, level_count,
layer_idx // r
<<<grid, thds, 0, *stream>>>(output, input, d_ggsw_fft_in, d_mem,
memory_needed_per_block,
glwe_dimension, // k
polynomial_size, base_log, level_count,
layer_idx // r
);
else
device_batch_cmux<Torus, STorus, params, FULLSM>
@@ -640,8 +628,6 @@ void host_blind_rotate_and_sample_extraction(
sizeof(Torus) * polynomial_size + // accumulator_c1 body
sizeof(Torus) * polynomial_size + // glwe_sub_mask
sizeof(Torus) * polynomial_size + // glwe_sub_body
sizeof(int16_t) * polynomial_size + // glwe_mask_decomposed
sizeof(int16_t) * polynomial_size + // glwe_body_decomposed
sizeof(double2) * polynomial_size / 2 + // mask_res_fft
sizeof(double2) * polynomial_size / 2 + // body_res_fft
sizeof(double2) * polynomial_size / 2; // glwe_fft
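
The NOSM launch above now passes 0 bytes of dynamic shared memory and lets the kernel work out of the d_mem scratch buffer in global memory. A minimal sketch of that selection pattern, assuming a CUDA toolkit environment; the kernel stub, the grid/block sizes and the use of cudaDeviceGetAttribute to obtain the shared-memory limit are placeholders of mine, not the library's code.

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Stub standing in for device_batch_cmux: FULLSM works in dynamic shared
// memory, NOSM works in a per-block slice of a global-memory scratch buffer.
template <bool FULLSM>
__global__ void batch_cmux_sketch(char *global_scratch, size_t scratch_per_block) {
  extern __shared__ char shared_scratch[];
  char *scratch = FULLSM ? shared_scratch
                         : global_scratch + blockIdx.x * scratch_per_block;
  (void)scratch; // the cmux body would operate on `scratch` here
}

int main() {
  int max_shared = 0;
  cudaDeviceGetAttribute(&max_shared, cudaDevAttrMaxSharedMemoryPerBlock, 0);

  size_t polynomial_size = 1024;
  size_t memory_needed_per_block =
      sizeof(uint64_t) * polynomial_size * 2            // glwe_sub_mask, glwe_sub_body
      + 2 * sizeof(double) * (polynomial_size / 2) * 3; // mask_res_fft, body_res_fft, glwe_fft

  dim3 grid(128), thds(256); // illustrative sizes only
  char *d_mem = nullptr;
  if ((size_t)max_shared < memory_needed_per_block) {
    cudaMalloc(&d_mem, memory_needed_per_block * grid.x);
    batch_cmux_sketch<false><<<grid, thds, 0>>>(d_mem, memory_needed_per_block);
  } else {
    batch_cmux_sketch<true><<<grid, thds, memory_needed_per_block>>>(nullptr, memory_needed_per_block);
  }
  cudaDeviceSynchronize();
  cudaFree(d_mem);
  std::printf("max shared per block: %d B, needed per block: %zu B\n",
              max_shared, memory_needed_per_block);
}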

View File

@@ -12,51 +12,51 @@ private:
uint32_t mask;
uint32_t halfbg;
T offset;
int current_level;
T mask_mod_b;
T *state;
public:
__device__ GadgetMatrix(uint32_t base_log, uint32_t level_count)
: base_log(base_log), level_count(level_count) {
uint32_t bg = 1 << base_log;
this->halfbg = bg / 2;
this->mask = bg - 1;
T temp = 0;
for (int i = 0; i < this->level_count; i++) {
temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
}
this->offset = temp * this->halfbg;
}
__device__ GadgetMatrix(uint32_t base_log, uint32_t level_count, T *state)
: base_log(base_log), level_count(level_count), state(state) {
template <typename V, typename U>
__device__ void decompose_one_level(Polynomial<V, params> &result,
Polynomial<U, params> &polynomial,
uint32_t level) {
mask_mod_b = (1ll << base_log) - 1ll;
current_level = level_count;
int tid = threadIdx.x;
for (int i = 0; i < params::opt; i++) {
T s = polynomial.coefficients[tid] + this->offset;
uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
T temp1 = (s >> decal) & this->mask;
result.coefficients[tid] = (V)(temp1 - this->halfbg);
tid += params::degree / params::opt;
}
}
template <typename V, typename U>
__device__ void decompose_one_level(V *result, U *polynomial,
uint32_t level) {
int tid = threadIdx.x;
for (int i = 0; i < params::opt; i++) {
T s = polynomial[tid] + this->offset;
uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
T temp1 = (s >> decal) & this->mask;
result[tid] = (V)(temp1 - this->halfbg);
state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
}
__device__ T decompose_one_level_single(T element, uint32_t level) {
T s = element + this->offset;
uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
T temp1 = (s >> decal) & this->mask;
return (T)(temp1 - this->halfbg);
__device__ void decompose_and_compress_next(double2 *result) {
int tid = threadIdx.x;
current_level -= 1;
for (int i = 0; i < params::opt / 2; i++) {
T res_re = state[tid * 2] & mask_mod_b;
T res_im = state[tid * 2 + 1] & mask_mod_b;
state[tid * 2] >>= base_log;
state[tid * 2 + 1] >>= base_log;
T carry_re = ((res_re - 1ll) | state[tid * 2]) & res_re;
T carry_im = ((res_im - 1ll) | state[tid * 2 + 1]) & res_im;
carry_re >>= (base_log - 1);
carry_im >>= (base_log - 1);
state[tid * 2] += carry_re;
state[tid * 2 + 1] += carry_im;
res_re -= carry_re << base_log;
res_im -= carry_im << base_log;
result[tid].x = (int32_t)res_re;
result[tid].y = (int32_t)res_im;
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
}
__device__ void decompose_and_compress_level(double2 *result, int level) {
for (int i = 0; i < level_count - level; i++)
decompose_and_compress_next(result);
}
};
@@ -89,4 +89,15 @@ public:
}
};
template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mask_mod_b, int base_log) {
Torus res = state & mask_mod_b;
state >>= base_log;
Torus carry = ((res - 1ll) | state) & res;
carry >>= base_log - 1;
state += carry;
res -= carry << base_log;
return res;
}
#endif // CNCRT_CRPYTO_H
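
decompose_one is hoisted into this shared header (and removed from the keyswitch file below). The host-side sanity check that follows verifies the property the keyswitch relies on: after rounding to the closest representable value, the signed digits recompose the input exactly modulo 2^64. The rounding helper and the test parameters are my own; the digit logic copies decompose_one.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <random>

using Torus = uint64_t;

// Same digit logic as decompose_one above.
static Torus decompose_one_host(Torus &state, Torus mask_mod_b, int base_log) {
  Torus res = state & mask_mod_b;
  state >>= base_log;
  Torus carry = ((res - 1) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return res; // reads as a signed digit in [-B/2, B/2]
}

int main() {
  const int base_log = 10, level_count = 4;
  const int shift = 64 - base_log * level_count;
  std::mt19937_64 rng(42);
  for (int t = 0; t < 100000; t++) {
    Torus x = rng();
    // Stand-in for round_to_closest_multiple: round to a multiple of 2^shift.
    Torus rounded = (((x >> (shift - 1)) + 1) >> 1) << shift;

    Torus state = rounded >> shift;
    Torus mask_mod_b = ((Torus)1 << base_log) - 1;
    Torus recomposed = 0;
    // The j-th extraction is the digit of level (level_count - 1 - j),
    // whose weight is 2^(64 - base_log * (level_count - j)).
    for (int j = 0; j < level_count; j++) {
      Torus digit = decompose_one_host(state, mask_mod_b, base_log);
      recomposed += digit << (64 - base_log * (level_count - j));
    }
    assert(recomposed == rounded);
  }
  std::printf("digits recompose the rounded input exactly mod 2^64\n");
}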

View File

@@ -17,17 +17,6 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
return ptr;
}
template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
Torus res = state & mod_b_mask;
state >>= base_log;
Torus carry = ((res - 1ll) | state) & res;
carry >>= base_log - 1;
state += carry;
res -= carry << base_log;
return res;
}
/*
* keyswitch kernel
* Each thread handles a piece of the following equation:
@@ -85,12 +74,12 @@ __global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
Torus mask_mod_b = (1ll << base_log) - 1ll;
for (int j = 0; j < level_count; j++) {
auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
lwe_dimension_out, level_count);
Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
for (int k = 0; k < lwe_part_per_thd; k++) {
int idx = tid + k * blockDim.x;
local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;