Mirror of https://github.com/zama-ai/concrete.git
feat(cuda): new decomposition algorithm for pbs.
- Removes the 16-bit limitation on base_log (see the illustrative sketch after the commit metadata below).
- Optimizes shared memory use: the buffers for the decomposition are no longer used; for the amortized PBS, the rotated buffers are reused as the state buffer for the decomposition.
- Adds a private test for the CUDA PBS, as we already have for the FFT backend.
Committed by: Agnès Leroy
Parent: d59b2f6dda
Commit: 56b986da8b
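As a rough, editorial illustration of the first point (not part of the commit): the old kernels staged the decomposed polynomials in int16_t buffers, which only works while every decomposed coefficient, of magnitude up to B/2 = 2^(base_log - 1), fits in 16 bits. The new GadgetMatrix writes each coefficient straight into the compressed double2 FFT buffer, so larger bases become usable. The value base_log = 23 below is just an example parameter.

// Editorial sketch, not from the commit: why int16_t staging capped base_log at 16.
#include <cstdint>
#include <cstdio>

int main() {
  const int base_log = 23;                                 // example value only
  const int64_t half_base = int64_t{1} << (base_log - 1);  // max |decomposed coefficient| = B/2
  std::printf("B/2 = %lld, INT16_MAX = %d -> does not fit in int16_t\n",
              static_cast<long long>(half_base), INT16_MAX);
  // A double has a 53-bit mantissa, so the same coefficient round-trips exactly
  // through the double2 FFT buffer that decompose_and_compress_next fills.
  std::printf("exact as a double: %d\n",
              static_cast<int64_t>(static_cast<double>(half_base)) == half_base);
  return 0;
}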
@@ -72,14 +72,8 @@ __global__ void device_bootstrap_amortized(
    selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];

  // For GPU bootstrapping the GLWE dimension is hard-set to 1: there is only
  // one mask polynomial and 1 body to handle Also, since the decomposed
  // polynomials take coefficients between -B/2 and B/2 they can be represented
  // with only 16 bits, assuming the base log does not exceed 2^16
  int16_t *accumulator_mask_decomposed = (int16_t *)selected_memory;
  int16_t *accumulator_body_decomposed =
      (int16_t *)accumulator_mask_decomposed + polynomial_size;
  Torus *accumulator_mask = (Torus *)accumulator_body_decomposed +
                            polynomial_size / (sizeof(Torus) / sizeof(int16_t));
  // one mask polynomial and 1 body to handle.
  Torus *accumulator_mask = (Torus *)selected_memory;
  Torus *accumulator_body =
      (Torus *)accumulator_mask + (ptrdiff_t)polynomial_size;
  Torus *accumulator_mask_rotated =
@@ -100,8 +94,6 @@ __global__ void device_bootstrap_amortized(
      &lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
                  2];

  GadgetMatrix<Torus, params> gadget(base_log, level_count);

  // Put "b", the body, in [0, 2N[
  Torus b_hat = 0;
  rescale_torus_element(block_lwe_array_in[lwe_dimension], b_hat,
@@ -159,27 +151,17 @@ __global__ void device_bootstrap_amortized(
    pos += params::degree / params::opt;
  }

  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count,
                                          accumulator_mask_rotated);
  GadgetMatrix<Torus, params> gadget_body(base_log, level_count,
                                          accumulator_body_rotated);

  // Now that the rotation is done, decompose the resulting polynomial
  // coefficients so as to multiply each decomposed level with the
  // corresponding part of the bootstrapping key
  for (int level = 0; level < level_count; level++) {
  for (int level = level_count - 1; level >= 0; level--) {

    gadget.decompose_one_level(accumulator_mask_decomposed,
                               accumulator_mask_rotated, level);
    gadget.decompose_one_level(accumulator_body_decomposed,
                               accumulator_body_rotated, level);
    synchronize_threads_in_block();

    // First, perform the polynomial multiplication for the mask

    // Reduce the size of the FFT to be performed by storing
    // the real-valued polynomial into a complex polynomial
    real_to_complex_compressed<params>(accumulator_mask_decomposed,
                                       accumulator_fft);
    synchronize_threads_in_block();
    gadget_mask.decompose_and_compress_next(accumulator_fft);
    // Switch to the FFT space
    NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
    synchronize_threads_in_block();
@@ -205,12 +187,10 @@ __global__ void device_bootstrap_amortized(
        body_res_fft, accumulator_fft, bsk_body_slice);

    synchronize_threads_in_block();
    gadget_body.decompose_and_compress_next(accumulator_fft);

    // Now handle the polynomial multiplication for the body
    // in the same way
    real_to_complex_compressed<params>(accumulator_body_decomposed,
                                       accumulator_fft);
    synchronize_threads_in_block();

    NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
    synchronize_threads_in_block();
@@ -306,12 +286,10 @@ __host__ void host_bootstrap_amortized(

  uint32_t gpu_index = 0;

  int SM_FULL = sizeof(Torus) * polynomial_size +   // accumulator mask
                sizeof(Torus) * polynomial_size +   // accumulator body
                sizeof(Torus) * polynomial_size +   // accumulator mask rotated
                sizeof(Torus) * polynomial_size +   // accumulator body rotated
                sizeof(int16_t) * polynomial_size + // accumulator_dec mask
                sizeof(int16_t) * polynomial_size + // accumulator_dec_body
  int SM_FULL = sizeof(Torus) * polynomial_size +   // accumulator mask
                sizeof(Torus) * polynomial_size +   // accumulator body
                sizeof(Torus) * polynomial_size +   // accumulator mask rotated
                sizeof(Torus) * polynomial_size +   // accumulator body rotated
                sizeof(double2) * polynomial_size / 2 + // accumulator fft mask
                sizeof(double2) * polynomial_size / 2 + // accumulator fft body
                sizeof(double2) * polynomial_size / 2;  // calculate buffer fft
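For a concrete sense of the shared-memory saving in the amortized kernel, here are illustrative numbers (editorial, assuming a 64-bit Torus and polynomial_size = 1024, not values stated in the commit): dropping the two int16_t decomposition buffers removes 4 KB from the per-sample budget computed above.

// Editorial sketch mirroring the SM_FULL computation above, with example parameters.
#include <cstdint>
#include <cstdio>

int main() {
  using Torus = uint64_t;              // assumed width for the example
  const size_t polynomial_size = 1024; // example size
  const size_t sizeof_double2 = 2 * sizeof(double);

  const size_t sm_old = 4 * sizeof(Torus) * polynomial_size       // four Torus accumulators
                      + 2 * sizeof(int16_t) * polynomial_size     // two int16_t decomposition buffers
                      + 3 * sizeof_double2 * polynomial_size / 2; // three double2 fft buffers
  const size_t sm_new = 4 * sizeof(Torus) * polynomial_size
                      + 3 * sizeof_double2 * polynomial_size / 2;

  std::printf("old: %zu B, new: %zu B, saved per sample: %zu B\n",
              sm_old, sm_new, sm_old - sm_new); // 61440, 57344, 4096
  return 0;
}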
@@ -30,14 +30,10 @@ namespace cg = cooperative_groups;

template <typename Torus, class params>
__device__ void
mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
              double2 *mask_join_buffer, double2 *body_join_buffer,
              double2 *bootstrapping_key, int polynomial_size, int level_count,
              int iteration, grid_group &grid) {

  // Put the decomposed GLWE sample in the Fourier domain
  real_to_complex_compressed<params>(glwe_decomposed, fft);
  synchronize_threads_in_block();
mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
              double2 *body_join_buffer, double2 *bootstrapping_key,
              int polynomial_size, int level_count, int iteration,
              grid_group &grid) {

  // Switch to the FFT space
  NSMFFT_direct<HalfDegree<params>>(fft);
@@ -157,16 +153,13 @@ __global__ void device_bootstrap_low_latency(

  char *selected_memory = sharedmem;

  int16_t *accumulator_decomposed = (int16_t *)selected_memory;
  Torus *accumulator = (Torus *)accumulator_decomposed +
                       polynomial_size / (sizeof(Torus) / sizeof(int16_t));
  Torus *accumulator = (Torus *)selected_memory;
  Torus *accumulator_rotated =
      (Torus *)accumulator + (ptrdiff_t)polynomial_size;
  double2 *accumulator_fft =
      (double2 *)accumulator +
      (double2 *)accumulator_rotated +
      polynomial_size / (sizeof(double2) / sizeof(Torus));

  // Reuse memory from accumulator_fft for accumulator_rotated
  Torus *accumulator_rotated = (Torus *)accumulator_fft;

  // The third dimension of the block is used to determine on which ciphertext
  // this block is operating, in the case of batch bootstraps
  auto block_lwe_array_in = &lwe_array_in[blockIdx.z * (lwe_dimension + 1)];
@@ -181,7 +174,6 @@ __global__ void device_bootstrap_low_latency(
  // Since the space is L1 cache is small, we use the same memory location for
  // the rotated accumulator and the fft accumulator, since we know that the
  // rotated array is not in use anymore by the time we perform the fft
  GadgetMatrix<Torus, params> gadget(base_log, level_count);

  // Put "b" in [0, 2N[
  Torus b_hat = 0;
@@ -217,11 +209,14 @@ __global__ void device_bootstrap_low_latency(
                                params::degree / params::opt>(
      accumulator_rotated, base_log, level_count);

  synchronize_threads_in_block();

  // Decompose the accumulator. Each block gets one level of the
  // decomposition, for the mask and the body (so block 0 will have the
  // accumulator decomposed at level 0, 1 at 1, etc.)
  gadget.decompose_one_level(accumulator_decomposed, accumulator_rotated,
                             blockIdx.x);
  GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
                                         accumulator_rotated);
  gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);

  // We are using the same memory space for accumulator_fft and
  // accumulator_rotated, so we need to synchronize here to make sure they
@@ -229,9 +224,9 @@ __global__ void device_bootstrap_low_latency(
    synchronize_threads_in_block();
    // Perform G^-1(ACC) * GGSW -> GLWE
    mul_ggsw_glwe<Torus, params>(accumulator, accumulator_fft,
                                 accumulator_decomposed, block_mask_join_buffer,
                                 block_body_join_buffer, bootstrapping_key,
                                 polynomial_size, level_count, i, grid);
                                 block_mask_join_buffer, block_body_join_buffer,
                                 bootstrapping_key, polynomial_size,
                                 level_count, i, grid);
  }

  auto block_lwe_array_out = &lwe_array_out[blockIdx.z * (polynomial_size + 1)];
@@ -270,8 +265,8 @@ host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
  double2 *body_buffer_fft =
      (double2 *)cuda_malloc_async(buffer_size_per_gpu, *stream, gpu_index);

  int bytes_needed = sizeof(int16_t) * polynomial_size + // accumulator_decomp
                     sizeof(Torus) * polynomial_size +   // accumulator
  int bytes_needed = sizeof(Torus) * polynomial_size +   // accumulator_rotated
                     sizeof(Torus) * polynomial_size +   // accumulator
                     sizeof(double2) * polynomial_size / 2; // accumulator fft

  int thds = polynomial_size / params::opt;
@@ -21,14 +21,7 @@
#include "utils/memory.cuh"
#include "utils/timer.cuh"

template <class params> __device__ void fft(double2 *output, int16_t *input) {
  synchronize_threads_in_block();

  // Reduce the size of the FFT to be performed by storing
  // the real-valued polynomial into a complex polynomial
  real_to_complex_compressed<params>(input, output);
  synchronize_threads_in_block();

template <class params> __device__ void fft(double2 *output) {
  // Switch to the FFT space
  NSMFFT_direct<HalfDegree<params>>(output);
  synchronize_threads_in_block();
@@ -81,19 +74,14 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
  Torus *glwe_sub_mask = (Torus *)selected_memory;
  Torus *glwe_sub_body = (Torus *)glwe_sub_mask + (ptrdiff_t)polynomial_size;

  int16_t *glwe_mask_decomposed = (int16_t *)(glwe_sub_body + polynomial_size);
  int16_t *glwe_body_decomposed =
      (int16_t *)glwe_mask_decomposed + (ptrdiff_t)polynomial_size;

  double2 *mask_res_fft = (double2 *)(glwe_body_decomposed + polynomial_size);
  double2 *mask_res_fft = (double2 *)glwe_sub_body +
                          polynomial_size / (sizeof(double2) / sizeof(Torus));
  double2 *body_res_fft =
      (double2 *)mask_res_fft + (ptrdiff_t)polynomial_size / 2;

  double2 *glwe_fft =
      (double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);

  GadgetMatrix<Torus, params> gadget(base_log, level_count);

  /////////////////////////////////////

  // glwe2-glwe1
@@ -125,18 +113,18 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
    pos += params::degree / params::opt;
  }

  GadgetMatrix<Torus, params> gadget_mask(base_log, level_count, glwe_sub_mask);
  GadgetMatrix<Torus, params> gadget_body(base_log, level_count, glwe_sub_body);
  // Subtract each glwe operand, decompose the resulting
  // polynomial coefficients to multiply each decomposed level
  // with the corresponding part of the LUT
  for (int level = 0; level < level_count; level++) {
  for (int level = level_count - 1; level >= 0; level--) {

    // Decomposition
    gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
    gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);
    gadget_mask.decompose_and_compress_next(glwe_fft);

    // First, perform the polynomial multiplication for the mask
    synchronize_threads_in_block();
    fft<params>(glwe_fft, glwe_mask_decomposed);
    fft<params>(glwe_fft);

    // External product and accumulate
    // Get the piece necessary for the multiplication
@@ -157,7 +145,9 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
    // Now handle the polynomial multiplication for the body
    // in the same way
    synchronize_threads_in_block();
    fft<params>(glwe_fft, glwe_body_decomposed);

    gadget_body.decompose_and_compress_next(glwe_fft);
    fft<params>(glwe_fft);

    // External product and accumulate
    // Get the piece necessary for the multiplication
@@ -272,8 +262,6 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
  int memory_needed_per_block =
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
      sizeof(int16_t) * polynomial_size +     // glwe_mask_decomposed
      sizeof(int16_t) * polynomial_size +     // glwe_body_decomposed
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft
@@ -331,11 +319,11 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
  // walks horizontally through the leafs
  if (max_shared_memory < memory_needed_per_block)
    device_batch_cmux<Torus, STorus, params, NOSM>
        <<<grid, thds, memory_needed_per_block, *stream>>>(
            output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
            glwe_dimension, // k
            polynomial_size, base_log, level_count,
            layer_idx // r
        <<<grid, thds, 0, *stream>>>(output, input, d_ggsw_fft_in, d_mem,
                                     memory_needed_per_block,
                                     glwe_dimension, // k
                                     polynomial_size, base_log, level_count,
                                     layer_idx // r
        );
  else
    device_batch_cmux<Torus, STorus, params, FULLSM>
@@ -640,8 +628,6 @@ void host_blind_rotate_and_sample_extraction(
      sizeof(Torus) * polynomial_size +       // accumulator_c1 body
      sizeof(Torus) * polynomial_size +       // glwe_sub_mask
      sizeof(Torus) * polynomial_size +       // glwe_sub_body
      sizeof(int16_t) * polynomial_size +     // glwe_mask_decomposed
      sizeof(int16_t) * polynomial_size +     // glwe_body_decomposed
      sizeof(double2) * polynomial_size / 2 + // mask_res_fft
      sizeof(double2) * polynomial_size / 2 + // body_res_fft
      sizeof(double2) * polynomial_size / 2;  // glwe_fft
@@ -12,51 +12,51 @@ private:
  uint32_t mask;
  uint32_t halfbg;
  T offset;
  int current_level;
  T mask_mod_b;
  T *state;

public:
  __device__ GadgetMatrix(uint32_t base_log, uint32_t level_count)
      : base_log(base_log), level_count(level_count) {
    uint32_t bg = 1 << base_log;
    this->halfbg = bg / 2;
    this->mask = bg - 1;
    T temp = 0;
    for (int i = 0; i < this->level_count; i++) {
      temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
    }
    this->offset = temp * this->halfbg;
  }
  __device__ GadgetMatrix(uint32_t base_log, uint32_t level_count, T *state)
      : base_log(base_log), level_count(level_count), state(state) {

  template <typename V, typename U>
  __device__ void decompose_one_level(Polynomial<V, params> &result,
                                      Polynomial<U, params> &polynomial,
                                      uint32_t level) {
    mask_mod_b = (1ll << base_log) - 1ll;
    current_level = level_count;
    int tid = threadIdx.x;
    for (int i = 0; i < params::opt; i++) {
      T s = polynomial.coefficients[tid] + this->offset;
      uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
      T temp1 = (s >> decal) & this->mask;
      result.coefficients[tid] = (V)(temp1 - this->halfbg);
      tid += params::degree / params::opt;
    }
  }
  template <typename V, typename U>
  __device__ void decompose_one_level(V *result, U *polynomial,
                                      uint32_t level) {
    int tid = threadIdx.x;
    for (int i = 0; i < params::opt; i++) {
      T s = polynomial[tid] + this->offset;
      uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
      T temp1 = (s >> decal) & this->mask;
      result[tid] = (V)(temp1 - this->halfbg);
      state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
      tid += params::degree / params::opt;
    }
    synchronize_threads_in_block();
  }

  __device__ T decompose_one_level_single(T element, uint32_t level) {
    T s = element + this->offset;
    uint32_t decal = (sizeof(T) * 8 - (level + 1) * this->base_log);
    T temp1 = (s >> decal) & this->mask;
    return (T)(temp1 - this->halfbg);
  __device__ void decompose_and_compress_next(double2 *result) {
    int tid = threadIdx.x;
    current_level -= 1;
    for (int i = 0; i < params::opt / 2; i++) {
      T res_re = state[tid * 2] & mask_mod_b;
      T res_im = state[tid * 2 + 1] & mask_mod_b;
      state[tid * 2] >>= base_log;
      state[tid * 2 + 1] >>= base_log;
      T carry_re = ((res_re - 1ll) | state[tid * 2]) & res_re;
      T carry_im = ((res_im - 1ll) | state[tid * 2 + 1]) & res_im;
      carry_re >>= (base_log - 1);
      carry_im >>= (base_log - 1);
      state[tid * 2] += carry_re;
      state[tid * 2 + 1] += carry_im;
      res_re -= carry_re << base_log;
      res_im -= carry_im << base_log;

      result[tid].x = (int32_t)res_re;
      result[tid].y = (int32_t)res_im;

      tid += params::degree / params::opt;
    }
    synchronize_threads_in_block();
  }
  __device__ void decompose_and_compress_level(double2 *result, int level) {
    for (int i = 0; i < level_count - level; i++)
      decompose_and_compress_next(result);
  }
};
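The commit message also mentions a private test for the CUDA PBS. A minimal host-side sanity check of the decomposition itself (an editorial sketch, not the actual test; round_to_closest_multiple_ref and the chosen parameters are illustrative) is to re-scale each emitted term and verify the sum recomposes the rounded input; it also shows why the terms pair with level indices level_count - 1 down to 0, matching the reversed level loops in the kernels above.

// Editorial sketch: the same mask/shift/carry update as decompose_and_compress_next
// and decompose_one, checked by recomposition on the host.
#include <cassert>
#include <cstdint>

using Torus = uint64_t;

static Torus round_to_closest_multiple_ref(Torus x, int base_log, int level_count) {
  const int shift = 64 - base_log * level_count;
  const Torus half = Torus{1} << (shift - 1);
  return ((x + half) >> shift) << shift;
}

int main() {
  const int base_log = 23, level_count = 2; // a base wider than 16 bits is now allowed
  const Torus mask_mod_b = (Torus{1} << base_log) - 1;

  const Torus v = 0x0123456789abcdefULL;
  const Torus rounded = round_to_closest_multiple_ref(v, base_log, level_count);

  Torus state = rounded >> (64 - base_log * level_count); // as in the new constructor
  Torus recomposed = 0;
  for (int j = 0; j < level_count; j++) {
    Torus res = state & mask_mod_b;          // extract one term
    state >>= base_log;
    Torus carry = ((res - 1) | state) & res; // propagate the carry to keep terms balanced
    carry >>= base_log - 1;
    state += carry;
    res -= carry << base_log;                // res is now in [-B/2, B/2] (mod 2^64)
    const int level = level_count - 1 - j;   // highest level index comes out first
    recomposed += res << (64 - base_log * (level + 1));
  }
  assert(recomposed == rounded);
  return 0;
}

With level_count = 2 and base_log = 23 this exercises exactly the range the old int16_t staging could not represent.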
@@ -89,4 +89,15 @@ public:
  }
};

template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mask_mod_b, int base_log) {
  Torus res = state & mask_mod_b;
  state >>= base_log;
  Torus carry = ((res - 1ll) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return res;
}

#endif // CNCRT_CRPYTO_H
@@ -17,17 +17,6 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
  return ptr;
}

template <typename Torus>
__device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
  Torus res = state & mod_b_mask;
  state >>= base_log;
  Torus carry = ((res - 1ll) | state) & res;
  carry >>= base_log - 1;
  state += carry;
  res -= carry << base_log;
  return res;
}

/*
 * keyswitch kernel
 * Each thread handles a piece of the following equation:
@@ -85,12 +74,12 @@ __global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
      round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);

  Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
  Torus mod_b_mask = (1ll << base_log) - 1ll;
  Torus mask_mod_b = (1ll << base_log) - 1ll;

  for (int j = 0; j < level_count; j++) {
    auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
                                   lwe_dimension_out, level_count);
    Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
    Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
    for (int k = 0; k < lwe_part_per_thd; k++) {
      int idx = tid + k * blockDim.x;
      local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;